"""Error handling middleware for consistent error responses and tracking."""

import asyncio
import logging
import traceback
from collections.abc import Callable
from typing import Any

import anyio
from mcp import McpError
from mcp.types import ErrorData

from fastmcp.exceptions import NotFoundError

from .middleware import CallNext, Middleware, MiddlewareContext


class ErrorHandlingMiddleware(Middleware):
    """Middleware that provides consistent error handling and logging.

    Catches exceptions, logs them appropriately, and converts them to
    proper MCP error responses. Also tracks error patterns for monitoring.

    Example:
        ```python
        from fastmcp.server.middleware.error_handling import ErrorHandlingMiddleware
        import logging

        # Configure logging to see error details
        logging.basicConfig(level=logging.ERROR)

        mcp = FastMCP("MyServer")
        mcp.add_middleware(ErrorHandlingMiddleware())
        ```
    """

    def __init__(
        self,
        logger: logging.Logger | None = None,
        include_traceback: bool = False,
        error_callback: Callable[[Exception, MiddlewareContext], None] | None = None,
        transform_errors: bool = True,
    ):
        """Initialize error handling middleware.

        Args:
            logger: Logger instance for error logging. If None, uses 'fastmcp.errors'
            include_traceback: Whether to include full traceback in error logs
            error_callback: Optional callback function called for each error
            transform_errors: Whether to transform non-MCP errors to McpError
        """
        self.logger = logger or logging.getLogger("fastmcp.errors")
        self.include_traceback = include_traceback
        self.error_callback = error_callback
        self.transform_errors = transform_errors
        self.error_counts = {}

    def _log_error(self, error: Exception, context: MiddlewareContext) -> None:
        """Log error with appropriate detail level."""
        error_type = type(error).__name__
        method = context.method or "unknown"

        # Track error counts
        error_key = f"{error_type}:{method}"
        self.error_counts[error_key] = self.error_counts.get(error_key, 0) + 1

        base_message = f"Error in {method}: {error_type}: {error!s}"

        if self.include_traceback:
            self.logger.error(f"{base_message}\n{traceback.format_exc()}")
        else:
            self.logger.error(base_message)

        # Call custom error callback if provided
        if self.error_callback:
            try:
                self.error_callback(error, context)
            except Exception as callback_error:
                self.logger.error(f"Error in error callback: {callback_error}")

    def _transform_error(self, error: Exception) -> Exception:
        """Transform non-MCP errors to proper MCP errors."""
        if isinstance(error, McpError):
            return error

        if not self.transform_errors:
            return error

        # Map common exceptions to appropriate MCP error codes
        error_type = type(error)

        if error_type in (ValueError, TypeError):
            return McpError(
                ErrorData(code=-32602, message=f"Invalid params: {error!s}")
            )
        elif error_type in (FileNotFoundError, KeyError, NotFoundError):
            return McpError(
                ErrorData(code=-32001, message=f"Resource not found: {error!s}")
            )
        elif error_type is PermissionError:
            return McpError(
                ErrorData(code=-32000, message=f"Permission denied: {error!s}")
            )
        # asyncio.TimeoutError is a subclass of TimeoutError in Python 3.10, alias in 3.11+
        elif error_type in (TimeoutError, asyncio.TimeoutError):
            return McpError(
                ErrorData(code=-32000, message=f"Request timeout: {error!s}")
            )
        else:
            return McpError(
                ErrorData(code=-32603, message=f"Internal error: {error!s}")
            )

    async def on_message(self, context: MiddlewareContext, call_next: CallNext) -> Any:
        """Handle errors for all messages."""
        try:
            return await call_next(context)
        except Exception as error:
            self._log_error(error, context)

            # Transform and re-raise
            transformed_error = self._transform_error(error)
            raise transformed_error from error

    def get_error_stats(self) -> dict[str, int]:
        """Get error statistics for monitoring."""
        return self.error_counts.copy()


class RetryMiddleware(Middleware):
    """Middleware that implements automatic retry logic for failed requests.

    Retries requests that fail with transient errors, using exponential
    backoff to avoid overwhelming the server or external dependencies.

    Example:
        ```python
        from fastmcp.server.middleware.error_handling import RetryMiddleware

        # Retry up to 3 times with exponential backoff
        retry_middleware = RetryMiddleware(
            max_retries=3,
            retry_exceptions=(ConnectionError, TimeoutError)
        )

        mcp = FastMCP("MyServer")
        mcp.add_middleware(retry_middleware)
        ```
    """

    def __init__(
        self,
        max_retries: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        backoff_multiplier: float = 2.0,
        retry_exceptions: tuple[type[Exception], ...] = (ConnectionError, TimeoutError),
        logger: logging.Logger | None = None,
    ):
        """Initialize retry middleware.

        Args:
            max_retries: Maximum number of retry attempts
            base_delay: Initial delay between retries in seconds
            max_delay: Maximum delay between retries in seconds
            backoff_multiplier: Multiplier for exponential backoff
            retry_exceptions: Tuple of exception types that should trigger retries
            logger: Logger for retry attempts
        """
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.backoff_multiplier = backoff_multiplier
        self.retry_exceptions = retry_exceptions
        self.logger = logger or logging.getLogger("fastmcp.retry")

    def _should_retry(self, error: Exception) -> bool:
        """Determine if an error should trigger a retry."""
        return isinstance(error, self.retry_exceptions)

    def _calculate_delay(self, attempt: int) -> float:
        """Calculate delay for the given attempt number."""
        delay = self.base_delay * (self.backoff_multiplier**attempt)
        return min(delay, self.max_delay)

    async def on_request(self, context: MiddlewareContext, call_next: CallNext) -> Any:
        """Implement retry logic for requests."""
        last_error = None

        for attempt in range(self.max_retries + 1):
            try:
                return await call_next(context)
            except Exception as error:
                last_error = error

                # Don't retry on the last attempt or if it's not a retryable error
                if attempt == self.max_retries or not self._should_retry(error):
                    break

                delay = self._calculate_delay(attempt)
                self.logger.warning(
                    f"Request {context.method} failed (attempt {attempt + 1}/{self.max_retries + 1}): "
                    f"{type(error).__name__}: {error!s}. Retrying in {delay:.1f}s..."
                )

                await anyio.sleep(delay)

        # Re-raise the last error if all retries failed
        if last_error:
            raise last_error
