from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Generic, Literal

from mcp.server.elicitation import (
    CancelledElicitation,
    DeclinedElicitation,
)
from pydantic import BaseModel
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic_core import core_schema
from typing_extensions import TypeVar

from fastmcp.utilities.json_schema import compress_schema
from fastmcp.utilities.logging import get_logger
from fastmcp.utilities.types import get_cached_typeadapter

__all__ = [
    "AcceptedElicitation",
    "CancelledElicitation",
    "DeclinedElicitation",
    "ScalarElicitationType",
    "get_elicitation_schema",
]

logger = get_logger(__name__)

T = TypeVar("T", default=Any)


class ElicitationJsonSchema(GenerateJsonSchema):
    """Custom JSON schema generator for MCP elicitation that always inlines enums.

    MCP elicitation requires inline enum schemas without $ref/$defs references.
    This generator ensures enums are always generated inline for compatibility.
    Optionally adds enumNames for better UI display when available.
    """

    def generate_inner(self, schema: core_schema.CoreSchema) -> JsonSchemaValue:
        """Override to prevent ref generation for enums."""
        # For enum schemas, bypass the ref mechanism entirely
        if schema["type"] == "enum":
            # Directly call our custom enum_schema without going through handler
            # This prevents the ref/defs mechanism from being invoked
            return self.enum_schema(schema)
        # For all other types, use the default implementation
        return super().generate_inner(schema)

    def enum_schema(self, schema: core_schema.EnumSchema) -> JsonSchemaValue:
        """Generate inline enum schema with optional enumNames for better UI.

        If enum members have a _display_name_ attribute or custom __str__,
        we'll include enumNames for better UI representation.
        """
        # Get the base schema from parent
        result = super().enum_schema(schema)

        # Try to add enumNames if the enum has display-friendly names
        enum_cls = schema.get("cls")
        if enum_cls:
            members = schema.get("members", [])
            enum_names = []
            has_custom_names = False

            for member in members:
                # Check if member has a custom display name attribute
                if hasattr(member, "_display_name_"):
                    enum_names.append(member._display_name_)
                    has_custom_names = True
                # Or use the member name with better formatting
                else:
                    # Convert SNAKE_CASE to Title Case for display
                    display_name = member.name.replace("_", " ").title()
                    enum_names.append(display_name)
                    if display_name != member.value:
                        has_custom_names = True

            # Only add enumNames if they differ from the values
            if has_custom_names:
                result["enumNames"] = enum_names

        return result


# we can't use the low-level AcceptedElicitation because it only works with BaseModels
class AcceptedElicitation(BaseModel, Generic[T]):
    """Result when user accepts the elicitation."""

    action: Literal["accept"] = "accept"
    data: T


@dataclass
class ScalarElicitationType(Generic[T]):
    value: T


def get_elicitation_schema(response_type: type[T]) -> dict[str, Any]:
    """Get the schema for an elicitation response.

    Args:
        response_type: The type of the response
    """

    # Use custom schema generator that inlines enums for MCP compatibility
    schema = get_cached_typeadapter(response_type).json_schema(
        schema_generator=ElicitationJsonSchema
    )
    schema = compress_schema(schema)

    # Validate the schema to ensure it follows MCP elicitation requirements
    validate_elicitation_json_schema(schema)

    return schema


def validate_elicitation_json_schema(schema: dict[str, Any]) -> None:
    """Validate that a JSON schema follows MCP elicitation requirements.

    This ensures the schema is compatible with MCP elicitation requirements:
    - Must be an object schema
    - Must only contain primitive field types (string, number, integer, boolean)
    - Must be flat (no nested objects or arrays of objects)
    - Allows const fields (for Literal types) and enum fields (for Enum types)
    - Only primitive types and their nullable variants are allowed

    Args:
        schema: The JSON schema to validate

    Raises:
        TypeError: If the schema doesn't meet MCP elicitation requirements
    """
    ALLOWED_TYPES = {"string", "number", "integer", "boolean"}

    # Check that the schema is an object
    if schema.get("type") != "object":
        raise TypeError(
            f"Elicitation schema must be an object schema, got type '{schema.get('type')}'. "
            "Elicitation schemas are limited to flat objects with primitive properties only."
        )

    properties = schema.get("properties", {})

    for prop_name, prop_schema in properties.items():
        prop_type = prop_schema.get("type")

        # Handle nullable types
        if isinstance(prop_type, list):
            if "null" in prop_type:
                prop_type = [t for t in prop_type if t != "null"]
                if len(prop_type) == 1:
                    prop_type = prop_type[0]
        elif prop_schema.get("nullable", False):
            continue  # Nullable with no other type is fine

        # Handle const fields (Literal types)
        if "const" in prop_schema:
            continue  # const fields are allowed regardless of type

        # Handle enum fields (Enum types)
        if "enum" in prop_schema:
            continue  # enum fields are allowed regardless of type

        # Handle references to definitions (like Enum types)
        if "$ref" in prop_schema:
            # Get the referenced definition
            ref_path = prop_schema["$ref"]
            if ref_path.startswith("#/$defs/"):
                def_name = ref_path[8:]  # Remove "#/$defs/" prefix
                ref_def = schema.get("$defs", {}).get(def_name, {})
                # If the referenced definition has an enum, it's allowed
                if "enum" in ref_def:
                    continue
                # If the referenced definition has a type that's allowed, it's allowed
                ref_type = ref_def.get("type")
                if ref_type in ALLOWED_TYPES:
                    continue
            # If we can't determine what the ref points to, reject it for safety
            raise TypeError(
                f"Elicitation schema field '{prop_name}' contains a reference '{ref_path}' "
                "that could not be validated. Only references to enum types or primitive types are allowed."
            )

        # Handle union types (oneOf/anyOf)
        if "oneOf" in prop_schema or "anyOf" in prop_schema:
            union_schemas = prop_schema.get("oneOf", []) + prop_schema.get("anyOf", [])
            for union_schema in union_schemas:
                # Allow const and enum in unions
                if "const" in union_schema or "enum" in union_schema:
                    continue
                union_type = union_schema.get("type")
                if union_type not in ALLOWED_TYPES:
                    raise TypeError(
                        f"Elicitation schema field '{prop_name}' has union type '{union_type}' which is not "
                        f"a primitive type. Only {ALLOWED_TYPES} are allowed in elicitation schemas."
                    )
            continue

        # Check if it's a primitive type
        if prop_type not in ALLOWED_TYPES:
            raise TypeError(
                f"Elicitation schema field '{prop_name}' has type '{prop_type}' which is not "
                f"a primitive type. Only {ALLOWED_TYPES} are allowed in elicitation schemas."
            )

        # Check for nested objects or arrays of objects (not allowed)
        if prop_type == "object":
            raise TypeError(
                f"Elicitation schema field '{prop_name}' is an object, but nested objects are not allowed. "
                "Elicitation schemas must be flat objects with primitive properties only."
            )

        if prop_type == "array":
            items_schema = prop_schema.get("items", {})
            if items_schema.get("type") == "object":
                raise TypeError(
                    f"Elicitation schema field '{prop_name}' is an array of objects, but arrays of objects are not allowed. "
                    "Elicitation schemas must be flat objects with primitive properties only."
                )
