"""Base classes for FastMCP prompts."""

from __future__ import annotations as _annotations

import inspect
import json
from collections.abc import Awaitable, Callable, Sequence
from typing import Any

import pydantic_core
from mcp.types import ContentBlock, Icon, PromptMessage, Role, TextContent
from mcp.types import Prompt as MCPPrompt
from mcp.types import PromptArgument as MCPPromptArgument
from pydantic import Field, TypeAdapter

from fastmcp.exceptions import PromptError
from fastmcp.server.dependencies import get_context
from fastmcp.utilities.components import FastMCPComponent
from fastmcp.utilities.json_schema import compress_schema
from fastmcp.utilities.logging import get_logger
from fastmcp.utilities.types import (
    FastMCPBaseModel,
    find_kwarg_by_type,
    get_cached_typeadapter,
)

logger = get_logger(__name__)


def Message(
    content: str | ContentBlock, role: Role | None = None, **kwargs: Any
) -> PromptMessage:
    """A user-friendly constructor for PromptMessage."""
    if isinstance(content, str):
        content = TextContent(type="text", text=content)
    if role is None:
        role = "user"
    return PromptMessage(content=content, role=role, **kwargs)


message_validator = TypeAdapter[PromptMessage](PromptMessage)

SyncPromptResult = (
    str
    | PromptMessage
    | dict[str, Any]
    | Sequence[str | PromptMessage | dict[str, Any]]
)
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]


class PromptArgument(FastMCPBaseModel):
    """An argument that can be passed to a prompt."""

    name: str = Field(description="Name of the argument")
    description: str | None = Field(
        default=None, description="Description of what the argument does"
    )
    required: bool = Field(
        default=False, description="Whether the argument is required"
    )


class Prompt(FastMCPComponent):
    """A prompt template that can be rendered with parameters."""

    arguments: list[PromptArgument] | None = Field(
        default=None, description="Arguments that can be passed to the prompt"
    )

    def enable(self) -> None:
        super().enable()
        try:
            context = get_context()
            context._queue_prompt_list_changed()  # type: ignore[private-use]
        except RuntimeError:
            pass  # No context available

    def disable(self) -> None:
        super().disable()
        try:
            context = get_context()
            context._queue_prompt_list_changed()  # type: ignore[private-use]
        except RuntimeError:
            pass  # No context available

    def to_mcp_prompt(
        self,
        *,
        include_fastmcp_meta: bool | None = None,
        **overrides: Any,
    ) -> MCPPrompt:
        """Convert the prompt to an MCP prompt."""
        arguments = [
            MCPPromptArgument(
                name=arg.name,
                description=arg.description,
                required=arg.required,
            )
            for arg in self.arguments or []
        ]

        return MCPPrompt(
            name=overrides.get("name", self.name),
            description=overrides.get("description", self.description),
            arguments=arguments,
            title=overrides.get("title", self.title),
            icons=overrides.get("icons", self.icons),
            _meta=overrides.get(
                "_meta", self.get_meta(include_fastmcp_meta=include_fastmcp_meta)
            ),
        )

    @staticmethod
    def from_function(
        fn: Callable[..., PromptResult | Awaitable[PromptResult]],
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        icons: list[Icon] | None = None,
        tags: set[str] | None = None,
        enabled: bool | None = None,
        meta: dict[str, Any] | None = None,
    ) -> FunctionPrompt:
        """Create a Prompt from a function.

        The function can return:
        - A string (converted to a message)
        - A Message object
        - A dict (converted to a message)
        - A sequence of any of the above
        """
        return FunctionPrompt.from_function(
            fn=fn,
            name=name,
            title=title,
            description=description,
            icons=icons,
            tags=tags,
            enabled=enabled,
            meta=meta,
        )

    async def render(
        self,
        arguments: dict[str, Any] | None = None,
    ) -> list[PromptMessage]:
        """Render the prompt with arguments.

        This method is not implemented in the base Prompt class and must be
        implemented by subclasses.
        """
        raise NotImplementedError("Subclasses must implement render()")


class FunctionPrompt(Prompt):
    """A prompt that is a function."""

    fn: Callable[..., PromptResult | Awaitable[PromptResult]]

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., PromptResult | Awaitable[PromptResult]],
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        icons: list[Icon] | None = None,
        tags: set[str] | None = None,
        enabled: bool | None = None,
        meta: dict[str, Any] | None = None,
    ) -> FunctionPrompt:
        """Create a Prompt from a function.

        The function can return:
        - A string (converted to a message)
        - A Message object
        - A dict (converted to a message)
        - A sequence of any of the above
        """
        from fastmcp.server.context import Context

        func_name = name or getattr(fn, "__name__", None) or fn.__class__.__name__

        if func_name == "<lambda>":
            raise ValueError("You must provide a name for lambda functions")
            # Reject functions with *args or **kwargs
        sig = inspect.signature(fn)
        for param in sig.parameters.values():
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                raise ValueError("Functions with *args are not supported as prompts")
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                raise ValueError("Functions with **kwargs are not supported as prompts")

        description = description or inspect.getdoc(fn)

        # if the fn is a callable class, we need to get the __call__ method from here out
        if not inspect.isroutine(fn):
            fn = fn.__call__
        # if the fn is a staticmethod, we need to work with the underlying function
        if isinstance(fn, staticmethod):
            fn = fn.__func__

        type_adapter = get_cached_typeadapter(fn)
        parameters = type_adapter.json_schema()

        # Auto-detect context parameter if not provided

        context_kwarg = find_kwarg_by_type(fn, kwarg_type=Context)
        prune_params = [context_kwarg] if context_kwarg else None

        parameters = compress_schema(parameters, prune_params=prune_params)

        # Convert parameters to PromptArguments
        arguments: list[PromptArgument] = []
        if "properties" in parameters:
            for param_name, param in parameters["properties"].items():
                arg_description = param.get("description")

                # For non-string parameters, append JSON schema info to help users
                # understand the expected format when passing as strings (MCP requirement)
                if param_name in sig.parameters:
                    sig_param = sig.parameters[param_name]
                    if (
                        sig_param.annotation != inspect.Parameter.empty
                        and sig_param.annotation is not str
                        and param_name != context_kwarg
                    ):
                        # Get the JSON schema for this specific parameter type
                        try:
                            param_adapter = get_cached_typeadapter(sig_param.annotation)
                            param_schema = param_adapter.json_schema()

                            # Create compact schema representation
                            schema_str = json.dumps(param_schema, separators=(",", ":"))

                            # Append schema info to description
                            schema_note = f"Provide as a JSON string matching the following schema: {schema_str}"
                            if arg_description:
                                arg_description = f"{arg_description}\n\n{schema_note}"
                            else:
                                arg_description = schema_note
                        except Exception:
                            # If schema generation fails, skip enhancement
                            pass

                arguments.append(
                    PromptArgument(
                        name=param_name,
                        description=arg_description,
                        required=param_name in parameters.get("required", []),
                    )
                )

        return cls(
            name=func_name,
            title=title,
            description=description,
            icons=icons,
            arguments=arguments,
            tags=tags or set(),
            enabled=enabled if enabled is not None else True,
            fn=fn,
            meta=meta,
        )

    def _convert_string_arguments(self, kwargs: dict[str, Any]) -> dict[str, Any]:
        """Convert string arguments to expected types based on function signature."""
        from fastmcp.server.context import Context

        sig = inspect.signature(self.fn)
        converted_kwargs = {}

        # Find context parameter name if any
        context_param_name = find_kwarg_by_type(self.fn, kwarg_type=Context)

        for param_name, param_value in kwargs.items():
            if param_name in sig.parameters:
                param = sig.parameters[param_name]

                # Skip Context parameters - they're handled separately
                if param_name == context_param_name:
                    converted_kwargs[param_name] = param_value
                    continue

                # If parameter has no annotation or annotation is str, pass as-is
                if (
                    param.annotation == inspect.Parameter.empty
                    or param.annotation is str
                ) or not isinstance(param_value, str):
                    converted_kwargs[param_name] = param_value
                else:
                    # Try to convert string argument using type adapter
                    try:
                        adapter = get_cached_typeadapter(param.annotation)
                        # Try JSON parsing first for complex types
                        try:
                            converted_kwargs[param_name] = adapter.validate_json(
                                param_value
                            )
                        except (ValueError, TypeError, pydantic_core.ValidationError):
                            # Fallback to direct validation
                            converted_kwargs[param_name] = adapter.validate_python(
                                param_value
                            )
                    except (ValueError, TypeError, pydantic_core.ValidationError) as e:
                        # If conversion fails, provide informative error
                        raise PromptError(
                            f"Could not convert argument '{param_name}' with value '{param_value}' "
                            f"to expected type {param.annotation}. Error: {e}"
                        ) from e
            else:
                # Parameter not in function signature, pass as-is
                converted_kwargs[param_name] = param_value

        return converted_kwargs

    async def render(
        self,
        arguments: dict[str, Any] | None = None,
    ) -> list[PromptMessage]:
        """Render the prompt with arguments."""
        from fastmcp.server.context import Context

        # Validate required arguments
        if self.arguments:
            required = {arg.name for arg in self.arguments if arg.required}
            provided = set(arguments or {})
            missing = required - provided
            if missing:
                raise ValueError(f"Missing required arguments: {missing}")

        try:
            # Prepare arguments with context
            kwargs = arguments.copy() if arguments else {}
            context_kwarg = find_kwarg_by_type(self.fn, kwarg_type=Context)
            if context_kwarg and context_kwarg not in kwargs:
                kwargs[context_kwarg] = get_context()

            # Convert string arguments to expected types when needed
            kwargs = self._convert_string_arguments(kwargs)

            # Call function and check if result is a coroutine
            result = self.fn(**kwargs)
            if inspect.isawaitable(result):
                result = await result

            # Validate messages
            if not isinstance(result, list | tuple):
                result = [result]

            # Convert result to messages
            messages: list[PromptMessage] = []
            for msg in result:
                try:
                    if isinstance(msg, PromptMessage):
                        messages.append(msg)
                    elif isinstance(msg, str):
                        messages.append(
                            PromptMessage(
                                role="user",
                                content=TextContent(type="text", text=msg),
                            )
                        )
                    else:
                        content = pydantic_core.to_json(msg, fallback=str).decode()
                        messages.append(
                            PromptMessage(
                                role="user",
                                content=TextContent(type="text", text=content),
                            )
                        )
                except Exception as e:
                    raise PromptError(
                        "Could not convert prompt result to message."
                    ) from e

            return messages
        except Exception as e:
            logger.exception(f"Error rendering prompt {self.name}")
            raise PromptError(f"Error rendering prompt {self.name}.") from e
