"""Dataset management for pydantic evals.

This module provides functionality for creating, loading, saving, and evaluating datasets of test cases.
Each case must have inputs, and can optionally have a name, expected output, metadata, and case-specific evaluators.

Datasets can be loaded from and saved to YAML or JSON files, and can be evaluated against
a task function to produce an evaluation report.
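
Example (a minimal sketch, assuming a simple dict-based task):
```python
from pydantic_evals import Case, Dataset

dataset = Dataset(cases=[Case(name='double', inputs={'x': 2}, expected_output=4)])


def double(inputs: dict) -> int:
    return inputs['x'] * 2


report = dataset.evaluate_sync(double, progress=False)
report.print()
```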
"""

from __future__ import annotations as _annotations

import functools
import inspect
import sys
import time
import traceback
import warnings
from collections.abc import Awaitable, Callable, Mapping, Sequence
from contextlib import AsyncExitStack, nullcontext
from contextvars import ContextVar
from dataclasses import dataclass, field
from inspect import iscoroutinefunction
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast

import anyio
import logfire_api
import yaml
from anyio import to_thread
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, ValidationError, model_serializer
from pydantic._internal import _typing_extra
from pydantic_core import to_json
from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapHandler
from rich.progress import Progress
from typing_extensions import NotRequired, Self, TypedDict, TypeVar

from pydantic_evals._utils import get_event_loop

from ._utils import get_unwrapped_function_name, logfire_span, task_group_gather
from .evaluators import EvaluationResult, Evaluator
from .evaluators._run_evaluator import run_evaluator
from .evaluators.common import DEFAULT_EVALUATORS
from .evaluators.context import EvaluatorContext
from .evaluators.evaluator import EvaluatorFailure
from .evaluators.spec import EvaluatorSpec
from .otel import SpanTree
from .otel._context_subtree import context_subtree
from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure

if TYPE_CHECKING:
    from pydantic_ai.retries import RetryConfig

if sys.version_info < (3, 11):
    from exceptiongroup import ExceptionGroup  # pragma: lax no cover
else:
    ExceptionGroup = ExceptionGroup  # pragma: lax no cover

__all__ = (
    'Case',
    'Dataset',
    'increment_eval_metric',
    'set_eval_attribute',
)

_logfire = logfire_api.Logfire(otel_scope='pydantic-evals')

InputsT = TypeVar('InputsT', default=Any)
"""Generic type for the inputs to the task being evaluated."""
OutputT = TypeVar('OutputT', default=Any)
"""Generic type for the expected output of the task being evaluated."""
MetadataT = TypeVar('MetadataT', default=Any)
"""Generic type for the metadata associated with the task being evaluated."""

DEFAULT_DATASET_PATH = './test_cases.yaml'
"""Default path for saving/loading datasets."""
DEFAULT_SCHEMA_PATH_TEMPLATE = './{stem}_schema.json'
"""Default template for schema file paths, where {stem} is replaced with the dataset filename stem."""
_YAML_SCHEMA_LINE_PREFIX = '# yaml-language-server: $schema='


_REPORT_CASES_ADAPTER = TypeAdapter(list[ReportCase])
_REPORT_CASE_FAILURES_ADAPTER = TypeAdapter(list[ReportCaseFailure])
_REPORT_CASE_AGGREGATE_ADAPTER = TypeAdapter(ReportCaseAggregate)


class _CaseModel(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid'):
    """Internal model for a case, used for serialization/deserialization."""

    name: str | None = None
    inputs: InputsT
    metadata: MetadataT | None = None
    expected_output: OutputT | None = None
    evaluators: list[EvaluatorSpec] = Field(default_factory=list)


class _DatasetModel(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid'):
    """Internal model for a dataset, used for serialization/deserialization."""

    # $schema is included to avoid validation failures caused by the `$schema` key; see `_add_json_schema` below for context
    json_schema_path: str | None = Field(default=None, alias='$schema')
    name: str | None = None
    cases: list[_CaseModel[InputsT, OutputT, MetadataT]]
    evaluators: list[EvaluatorSpec] = Field(default_factory=list)


@dataclass(init=False)
class Case(Generic[InputsT, OutputT, MetadataT]):
    """A single row of a [`Dataset`][pydantic_evals.Dataset].

    Each case represents a single test scenario with inputs to test. A case may optionally specify a name, expected
    outputs to compare against, and arbitrary metadata.

    Cases can also have their own specific evaluators which are run in addition to dataset-level evaluators.

    Example:
    ```python
    from pydantic_evals import Case

    case = Case(
        name='Simple addition',
        inputs={'a': 1, 'b': 2},
        expected_output=3,
        metadata={'description': 'Tests basic addition'},
    )
    ```
    """

    name: str | None
    """Name of the case. This is used to identify the case in the report and can be used to filter cases."""
    inputs: InputsT
    """Inputs to the task. This is the input to the task that will be evaluated."""
    metadata: MetadataT | None = None
    """Metadata to be used in the evaluation.

    This can be used to provide additional information about the case to the evaluators.
    """
    expected_output: OutputT | None = None
    """Expected output of the task. This is the expected output of the task that will be evaluated."""
    evaluators: list[Evaluator[InputsT, OutputT, MetadataT]] = field(default_factory=list)
    """Evaluators to be used just on this case."""

    def __init__(
        self,
        *,
        name: str | None = None,
        inputs: InputsT,
        metadata: MetadataT | None = None,
        expected_output: OutputT | None = None,
        evaluators: tuple[Evaluator[InputsT, OutputT, MetadataT], ...] = (),
    ):
        """Initialize a new test case.

        Args:
            name: Optional name for the case. If not provided, a generic name (e.g. 'Case 1') will be used in evaluation reports.
            inputs: The inputs to the task being evaluated.
            metadata: Optional metadata for the case, which can be used by evaluators.
            expected_output: Optional expected output of the task, used for comparison in evaluators.
            evaluators: Tuple of evaluators specific to this case. These are in addition to any
                dataset-level evaluators.
        """
        # Note: `evaluators` must be a tuple instead of Sequence due to misbehavior with pyright's generic parameter
        # inference if it has type `Sequence`
        self.name = name
        self.inputs = inputs
        self.metadata = metadata
        self.expected_output = expected_output
        self.evaluators = list(evaluators)


class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', arbitrary_types_allowed=True):
    """A dataset of test [cases][pydantic_evals.Case].

    Datasets allow you to organize a collection of test cases and evaluate them against a task function.
    They can be loaded from and saved to YAML or JSON files, and can have dataset-level evaluators that
    apply to all cases.

    Example:
    ```python
    # Create a dataset with two test cases
    from dataclasses import dataclass

    from pydantic_evals import Case, Dataset
    from pydantic_evals.evaluators import Evaluator, EvaluatorContext


    @dataclass
    class ExactMatch(Evaluator):
        def evaluate(self, ctx: EvaluatorContext) -> bool:
            return ctx.output == ctx.expected_output

    dataset = Dataset(
        cases=[
            Case(name='test1', inputs={'text': 'Hello'}, expected_output='HELLO'),
            Case(name='test2', inputs={'text': 'World'}, expected_output='WORLD'),
        ],
        evaluators=[ExactMatch()],
    )

    # Evaluate the dataset against a task function
    async def uppercase(inputs: dict) -> str:
        return inputs['text'].upper()

    async def main():
        report = await dataset.evaluate(uppercase)
        report.print()
    '''
       Evaluation Summary: uppercase
    ┏━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━┓
    ┃ Case ID  ┃ Assertions ┃ Duration ┃
    ┡━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━┩
    │ test1    │ ✔          │     10ms │
    ├──────────┼────────────┼──────────┤
    │ test2    │ ✔          │     10ms │
    ├──────────┼────────────┼──────────┤
    │ Averages │ 100.0% ✔   │     10ms │
    └──────────┴────────────┴──────────┘
    '''
    ```
    """

    name: str | None = None
    """Optional name of the dataset."""
    cases: list[Case[InputsT, OutputT, MetadataT]]
    """List of test cases in the dataset."""
    evaluators: list[Evaluator[InputsT, OutputT, MetadataT]] = []
    """List of evaluators to be used on all cases in the dataset."""

    def __init__(
        self,
        *,
        name: str | None = None,
        cases: Sequence[Case[InputsT, OutputT, MetadataT]],
        evaluators: Sequence[Evaluator[InputsT, OutputT, MetadataT]] = (),
    ):
        """Initialize a new dataset with test cases and optional evaluators.

        Args:
            name: Optional name for the dataset.
            cases: Sequence of test cases to include in the dataset.
            evaluators: Optional sequence of evaluators to apply to all cases in the dataset.
        """
        case_names = set[str]()
        for case in cases:
            if case.name is None:
                continue
            if case.name in case_names:
                raise ValueError(f'Duplicate case name: {case.name!r}')
            case_names.add(case.name)

        super().__init__(
            name=name,
            cases=cases,
            evaluators=list(evaluators),
        )

    # TODO in v2: Make all non-required parameters keyword-only
    async def evaluate(
        self,
        task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],
        name: str | None = None,
        max_concurrency: int | None = None,
        progress: bool = True,
        retry_task: RetryConfig | None = None,
        retry_evaluators: RetryConfig | None = None,
        *,
        task_name: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
        """Evaluates the test cases in the dataset using the given task.

        This method runs the task on each case in the dataset, applies evaluators,
        and collects results into a report. Cases are run concurrently, limited by `max_concurrency` if specified.

        Args:
            task: The task to evaluate. This should be a callable that takes the inputs of the case
                and returns the output.
            name: The name of the experiment being run; this is used to identify the experiment in the report.
                If omitted, `task_name` is used; if that is also not specified, the name of the task function is used.
            max_concurrency: The maximum number of concurrent evaluations of the task to allow.
                If None, all cases will be evaluated concurrently.
            progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
            retry_task: Optional retry configuration for the task execution.
            retry_evaluators: Optional retry configuration for evaluator execution.
            task_name: Optional override for the name of the task being executed; otherwise the name of the task
                function will be used.
            metadata: Optional dict of experiment metadata.

        Returns:
            A report containing the results of the evaluation.
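
        Example (a minimal sketch, assuming an async task that takes a dict of inputs):
        ```python
        from pydantic_evals import Case, Dataset

        dataset = Dataset(cases=[Case(name='one', inputs={'x': 1}, expected_output=1)])


        async def identity(inputs: dict) -> int:
            return inputs['x']


        async def main():
            report = await dataset.evaluate(identity, max_concurrency=10, metadata={'run': 'example'})
            report.print()
        ```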
        """
        task_name = task_name or get_unwrapped_function_name(task)
        name = name or task_name
        total_cases = len(self.cases)
        progress_bar = Progress() if progress else None

        limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()

        extra_attributes: dict[str, Any] = {'gen_ai.operation.name': 'experiment'}
        if metadata is not None:
            extra_attributes['metadata'] = metadata
        with (
            logfire_span(
                'evaluate {name}',
                name=name,
                task_name=task_name,
                dataset_name=self.name,
                n_cases=len(self.cases),
                **extra_attributes,
            ) as eval_span,
            progress_bar or nullcontext(),
        ):
            task_id = progress_bar.add_task(f'Evaluating {task_name}', total=total_cases) if progress_bar else None

            async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
                async with limiter:
                    result = await _run_task_and_evaluators(
                        task, case, report_case_name, self.evaluators, retry_task, retry_evaluators
                    )
                    if progress_bar and task_id is not None:  # pragma: no branch
                        progress_bar.update(task_id, advance=1)
                    return result

            if (context := eval_span.context) is None:  # pragma: no cover
                trace_id = None
                span_id = None
            else:
                trace_id = f'{context.trace_id:032x}'
                span_id = f'{context.span_id:016x}'
            cases_and_failures = await task_group_gather(
                [
                    lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
                    for i, case in enumerate(self.cases, 1)
                ]
            )
            cases: list[ReportCase] = []
            failures: list[ReportCaseFailure] = []
            for item in cases_and_failures:
                if isinstance(item, ReportCase):
                    cases.append(item)
                else:
                    failures.append(item)
            report = EvaluationReport(
                name=name,
                cases=cases,
                failures=failures,
                experiment_metadata=metadata,
                span_id=span_id,
                trace_id=trace_id,
            )
            full_experiment_metadata: dict[str, Any] = {'n_cases': len(self.cases)}
            if metadata is not None:
                full_experiment_metadata['metadata'] = metadata
            if (averages := report.averages()) is not None:
                full_experiment_metadata['averages'] = averages
                if averages.assertions is not None:
                    eval_span.set_attribute('assertion_pass_rate', averages.assertions)
            eval_span.set_attribute('logfire.experiment.metadata', full_experiment_metadata)
        return report

    def evaluate_sync(
        self,
        task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],
        name: str | None = None,
        max_concurrency: int | None = None,
        progress: bool = True,
        retry_task: RetryConfig | None = None,
        retry_evaluators: RetryConfig | None = None,
        *,
        task_name: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
        """Evaluates the test cases in the dataset using the given task.

        This is a synchronous wrapper around [`evaluate`][pydantic_evals.Dataset.evaluate] provided for convenience.

        Args:
            task: The task to evaluate. This should be a callable that takes the inputs of the case
                and returns the output.
            name: The name of the experiment being run; this is used to identify the experiment in the report.
                If omitted, `task_name` is used; if that is also not specified, the name of the task function is used.
            max_concurrency: The maximum number of concurrent evaluations of the task to allow.
                If None, all cases will be evaluated concurrently.
            progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
            retry_task: Optional retry configuration for the task execution.
            retry_evaluators: Optional retry configuration for evaluator execution.
            task_name: Optional override for the name of the task being executed; otherwise the name of the task
                function will be used.
            metadata: Optional dict of experiment metadata.

        Returns:
            A report containing the results of the evaluation.
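
        Example (a minimal sketch with a synchronous task):
        ```python
        from pydantic_evals import Case, Dataset

        dataset = Dataset(cases=[Case(name='one', inputs={'x': 1}, expected_output=1)])


        def identity(inputs: dict) -> int:
            return inputs['x']


        report = dataset.evaluate_sync(identity, progress=False)
        report.print()
        ```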
        """
        return get_event_loop().run_until_complete(
            self.evaluate(
                task,
                name=name,
                max_concurrency=max_concurrency,
                progress=progress,
                retry_task=retry_task,
                retry_evaluators=retry_evaluators,
                task_name=task_name,
                metadata=metadata,
            )
        )

    def add_case(
        self,
        *,
        name: str | None = None,
        inputs: InputsT,
        metadata: MetadataT | None = None,
        expected_output: OutputT | None = None,
        evaluators: tuple[Evaluator[InputsT, OutputT, MetadataT], ...] = (),
    ) -> None:
        """Adds a case to the dataset.

        This is a convenience method for creating a [`Case`][pydantic_evals.Case] and adding it to the dataset.

        Args:
            name: Optional name for the case. If not provided, a generic name (e.g. 'Case 1') will be used in evaluation reports.
            inputs: The inputs to the task being evaluated.
            metadata: Optional metadata for the case, which can be used by evaluators.
            expected_output: Optional expected output of the task, used for comparison in evaluators.
            evaluators: Tuple of evaluators specific to this case, in addition to dataset-level evaluators.
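
        Example (a minimal sketch):
        ```python
        from pydantic_evals import Dataset

        dataset = Dataset(cases=[])
        dataset.add_case(name='subtraction', inputs={'a': 5, 'b': 3}, expected_output=2)
        ```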
        """
        if name is not None and name in {case.name for case in self.cases}:
            raise ValueError(f'Duplicate case name: {name!r}')

        case = Case[InputsT, OutputT, MetadataT](
            name=name,
            inputs=inputs,
            metadata=metadata,
            expected_output=expected_output,
            evaluators=evaluators,
        )
        self.cases.append(case)

    def add_evaluator(
        self,
        evaluator: Evaluator[InputsT, OutputT, MetadataT],
        specific_case: str | None = None,
    ) -> None:
        """Adds an evaluator to the dataset or a specific case.

        Args:
            evaluator: The evaluator to add.
            specific_case: If provided, the evaluator will only be added to the case with this name.
                If None, the evaluator will be added to all cases in the dataset.

        Raises:
            ValueError: If `specific_case` is provided but no case with that name exists in the dataset.
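
        Example (a minimal sketch, assuming the custom `ExactMatch` evaluator defined below):
        ```python
        from dataclasses import dataclass

        from pydantic_evals import Case, Dataset
        from pydantic_evals.evaluators import Evaluator, EvaluatorContext


        @dataclass
        class ExactMatch(Evaluator):
            def evaluate(self, ctx: EvaluatorContext) -> bool:
                return ctx.output == ctx.expected_output


        dataset = Dataset(cases=[Case(name='one', inputs={'x': 1}, expected_output=1)])
        dataset.add_evaluator(ExactMatch())  # applied to every case in the dataset
        dataset.add_evaluator(ExactMatch(), specific_case='one')  # applied only to the named case
        ```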
        """
        if specific_case is None:
            self.evaluators.append(evaluator)
        else:
            # If this is too slow, we could try to add a case lookup dict.
            # Note that if we do that, we'd need to make the cases list private to prevent modification.
            added = False
            for case in self.cases:
                if case.name == specific_case:
                    case.evaluators.append(evaluator)
                    added = True
            if not added:
                raise ValueError(f'Case {specific_case!r} not found in the dataset')

    @classmethod
    @functools.cache
    def _params(cls) -> tuple[type[InputsT], type[OutputT], type[MetadataT]]:
        """Get the type parameters for the Dataset class.

        Returns:
            A tuple of (InputsT, OutputT, MetadataT) types.
        """
        for c in cls.__mro__:
            metadata = getattr(c, '__pydantic_generic_metadata__', {})
            if len(args := (metadata.get('args', ()) or getattr(c, '__args__', ()))) == 3:  # pragma: no branch
                return args
        else:  # pragma: no cover
            warnings.warn(
                f'Could not determine the generic parameters for {cls}; using `Any` for each.'
                f' You should explicitly set the generic parameters via `Dataset[MyInputs, MyOutput, MyMetadata]`'
                f' when serializing or deserializing.',
                UserWarning,
            )
            return Any, Any, Any  # type: ignore

    @classmethod
    def from_file(
        cls,
        path: Path | str,
        fmt: Literal['yaml', 'json'] | None = None,
        custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
    ) -> Self:
        """Load a dataset from a file.

        Args:
            path: Path to the file to load.
            fmt: Format of the file. If None, the format will be inferred from the file extension.
                Must be either 'yaml' or 'json'.
            custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset.
                These are additional evaluators beyond the default ones.

        Returns:
            A new Dataset instance loaded from the file.

        Raises:
            ValidationError: If the file cannot be parsed as a valid dataset.
            ValueError: If the format cannot be inferred from the file extension.
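
        Example (a minimal sketch; assumes `test_cases.yaml` exists and matches this dataset's schema):
        ```python
        from typing import Any

        from pydantic_evals import Dataset

        dataset = Dataset[Any, Any, Any].from_file('test_cases.yaml')
        ```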
        """
        path = Path(path)
        fmt = cls._infer_fmt(path, fmt)

        raw = path.read_text()
        try:
            return cls.from_text(raw, fmt=fmt, custom_evaluator_types=custom_evaluator_types, default_name=path.stem)
        except ValidationError as e:  # pragma: no cover
            raise ValueError(f'{path} contains data that does not match the schema for {cls.__name__}:\n{e}.') from e

    @classmethod
    def from_text(
        cls,
        contents: str,
        fmt: Literal['yaml', 'json'] = 'yaml',
        custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
        *,
        default_name: str | None = None,
    ) -> Self:
        """Load a dataset from a string.

        Args:
            contents: The string content to parse.
            fmt: Format of the content. Must be either 'yaml' or 'json'.
            custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset.
                These are additional evaluators beyond the default ones.
            default_name: Default name of the dataset, to be used if not specified in the serialized contents.

        Returns:
            A new Dataset instance parsed from the string.

        Raises:
            ValidationError: If the content cannot be parsed as a valid dataset.
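
        Example (a minimal sketch with inline YAML content):
        ```python
        from typing import Any

        from pydantic_evals import Dataset

        yaml_content = '''
        cases:
          - name: one
            inputs:
              x: 1
            expected_output: 1
        '''
        dataset = Dataset[Any, Any, Any].from_text(yaml_content, fmt='yaml')
        ```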
        """
        if fmt == 'yaml':
            loaded = yaml.safe_load(contents)
            return cls.from_dict(loaded, custom_evaluator_types, default_name=default_name)
        else:
            dataset_model_type = cls._serialization_type()
            dataset_model = dataset_model_type.model_validate_json(contents)
            return cls._from_dataset_model(dataset_model, custom_evaluator_types, default_name)

    @classmethod
    def from_dict(
        cls,
        data: dict[str, Any],
        custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
        *,
        default_name: str | None = None,
    ) -> Self:
        """Load a dataset from a dictionary.

        Args:
            data: Dictionary representation of the dataset.
            custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset.
                These are additional evaluators beyond the default ones.
            default_name: Default name of the dataset, to be used if not specified in the data.

        Returns:
            A new Dataset instance created from the dictionary.

        Raises:
            ValidationError: If the dictionary cannot be converted to a valid dataset.
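
        Example (a minimal sketch):
        ```python
        from typing import Any

        from pydantic_evals import Dataset

        dataset = Dataset[Any, Any, Any].from_dict(
            {'cases': [{'name': 'one', 'inputs': {'x': 1}, 'expected_output': 1}]}
        )
        ```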
        """
        dataset_model_type = cls._serialization_type()
        dataset_model = dataset_model_type.model_validate(data)
        return cls._from_dataset_model(dataset_model, custom_evaluator_types, default_name)

    @classmethod
    def _from_dataset_model(
        cls,
        dataset_model: _DatasetModel[InputsT, OutputT, MetadataT],
        custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
        default_name: str | None = None,
    ) -> Self:
        """Create a Dataset from a _DatasetModel.

        Args:
            dataset_model: The _DatasetModel to convert.
            custom_evaluator_types: Custom evaluator classes to register for deserialization.
            default_name: Default name of the dataset, to be used if the value is `None` in the provided model.

        Returns:
            A new Dataset instance created from the _DatasetModel.
        """
        registry = _get_registry(custom_evaluator_types)

        cases: list[Case[InputsT, OutputT, MetadataT]] = []
        errors: list[ValueError] = []
        dataset_evaluators: list[Evaluator] = []
        for spec in dataset_model.evaluators:
            try:
                dataset_evaluator = _load_evaluator_from_registry(registry, None, spec)
            except ValueError as e:
                errors.append(e)
                continue
            dataset_evaluators.append(dataset_evaluator)

        for row in dataset_model.cases:
            evaluators: list[Evaluator] = []
            for spec in row.evaluators:
                try:
                    evaluator = _load_evaluator_from_registry(registry, row.name, spec)
                except ValueError as e:
                    errors.append(e)
                    continue
                evaluators.append(evaluator)
            row = Case[InputsT, OutputT, MetadataT](
                name=row.name,
                inputs=row.inputs,
                metadata=row.metadata,
                expected_output=row.expected_output,
            )
            row.evaluators = evaluators
            cases.append(row)
        if errors:
            raise ExceptionGroup(f'{len(errors)} error(s) loading evaluators from registry', errors[:3])
        result = cls(name=dataset_model.name, cases=cases)
        if result.name is None:
            result.name = default_name
        result.evaluators = dataset_evaluators
        return result

    def to_file(
        self,
        path: Path | str,
        fmt: Literal['yaml', 'json'] | None = None,
        schema_path: Path | str | None = DEFAULT_SCHEMA_PATH_TEMPLATE,
        custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
    ):
        """Save the dataset to a file.

        Args:
            path: Path to save the dataset to.
            fmt: Format to use. If None, the format will be inferred from the file extension.
                Must be either 'yaml' or 'json'.
            schema_path: Path to save the JSON schema to. If None, no schema will be saved.
                Can be a string template with {stem} which will be replaced with the dataset filename stem.
            custom_evaluator_types: Custom evaluator classes to include in the schema.
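
        Example (a minimal sketch; also writes `my_cases_schema.json` next to the dataset file by default):
        ```python
        from typing import Any

        from pydantic_evals import Case, Dataset

        dataset = Dataset[Any, Any, Any](
            cases=[Case(name='one', inputs={'x': 1}, expected_output=1)],
        )
        dataset.to_file('my_cases.yaml')
        ```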
        """
        path = Path(path)
        fmt = self._infer_fmt(path, fmt)

        schema_ref: str | None = None
        if schema_path is not None:  # pragma: no branch
            if isinstance(schema_path, str):  # pragma: no branch
                schema_path = Path(schema_path.format(stem=path.stem))

            if not schema_path.is_absolute():
                schema_ref = str(schema_path)
                schema_path = path.parent / schema_path
            elif schema_path.is_relative_to(path):  # pragma: no cover
                schema_ref = str(_get_relative_path_reference(schema_path, path))
            else:  # pragma: no cover
                schema_ref = str(schema_path)
            self._save_schema(schema_path, custom_evaluator_types)

        context: dict[str, Any] = {'use_short_form': True}
        if fmt == 'yaml':
            dumped_data = self.model_dump(mode='json', by_alias=True, context=context)
            content = yaml.dump(dumped_data, sort_keys=False)
            if schema_ref:  # pragma: no branch
                yaml_language_server_line = f'{_YAML_SCHEMA_LINE_PREFIX}{schema_ref}'
                content = f'{yaml_language_server_line}\n{content}'
            path.write_text(content)
        else:
            context['$schema'] = schema_ref
            json_data = self.model_dump_json(indent=2, by_alias=True, context=context)
            path.write_text(json_data + '\n')

    @classmethod
    def model_json_schema_with_evaluators(
        cls,
        custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
    ) -> dict[str, Any]:
        """Generate a JSON schema for this dataset type, including evaluator details.

        This is useful for generating a schema that can be used to validate YAML-format dataset files.

        Args:
            custom_evaluator_types: Custom evaluator classes to include in the schema.

        Returns:
            A dictionary representing the JSON schema.
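
        Example (a minimal sketch):
        ```python
        from typing import Any

        from pydantic_evals import Dataset

        # The resulting schema can be written to a JSON file and referenced from YAML dataset files.
        schema = Dataset[Any, Any, Any].model_json_schema_with_evaluators()
        ```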
        """
        # Note: this function could maybe be simplified now that Evaluators are always dataclasses
        registry = _get_registry(custom_evaluator_types)

        evaluator_schema_types: list[Any] = []
        for name, evaluator_class in registry.items():
            type_hints = _typing_extra.get_function_type_hints(evaluator_class)
            type_hints.pop('return', None)
            required_type_hints: dict[str, Any] = {}

            for p in inspect.signature(evaluator_class).parameters.values():
                type_hints.setdefault(p.name, Any)
                if p.default is not p.empty:
                    type_hints[p.name] = NotRequired[type_hints[p.name]]
                else:
                    required_type_hints[p.name] = type_hints[p.name]

            def _make_typed_dict(cls_name_prefix: str, fields: dict[str, Any]) -> Any:
                td = TypedDict(f'{cls_name_prefix}_{name}', fields)  # pyright: ignore[reportArgumentType]
                config = ConfigDict(extra='forbid', arbitrary_types_allowed=True)
                # TODO: Replace with pydantic.with_config once pydantic 2.11 is the min supported version
                td.__pydantic_config__ = config  # pyright: ignore[reportAttributeAccessIssue]
                return td

            # Shortest form: just the call name
            if len(type_hints) == 0 or not required_type_hints:
                evaluator_schema_types.append(Literal[name])

            # Short form: can be called with only one parameter
            if len(type_hints) == 1:
                [type_hint_type] = type_hints.values()
                evaluator_schema_types.append(_make_typed_dict('short_evaluator', {name: type_hint_type}))
            elif len(required_type_hints) == 1:  # pragma: no branch
                [type_hint_type] = required_type_hints.values()
                evaluator_schema_types.append(_make_typed_dict('short_evaluator', {name: type_hint_type}))

            # Long form: multiple parameters, possibly required
            if len(type_hints) > 1:
                params_td = _make_typed_dict('evaluator_params', type_hints)
                evaluator_schema_types.append(_make_typed_dict('evaluator', {name: params_td}))

        in_type, out_type, meta_type = cls._params()

        # Note: we shadow the `Case` and `Dataset` class names here to generate a clean JSON schema
        class Case(BaseModel, extra='forbid'):  # pyright: ignore[reportUnusedClass]  # this _is_ used below, but pyright doesn't seem to notice..
            name: str | None = None
            inputs: in_type  # pyright: ignore[reportInvalidTypeForm]
            metadata: meta_type | None = None  # pyright: ignore[reportInvalidTypeForm,reportUnknownVariableType]
            expected_output: out_type | None = None  # pyright: ignore[reportInvalidTypeForm,reportUnknownVariableType]
            if evaluator_schema_types:  # pragma: no branch
                evaluators: list[Union[tuple(evaluator_schema_types)]] = []  # pyright: ignore  # noqa UP007

        class Dataset(BaseModel, extra='forbid'):
            name: str | None = None
            cases: list[Case]
            if evaluator_schema_types:  # pragma: no branch
                evaluators: list[Union[tuple(evaluator_schema_types)]] = []  # pyright: ignore  # noqa UP007

        json_schema = Dataset.model_json_schema()
        # See `_add_json_schema` below, since `$schema` is added to the JSON, it has to be supported in the JSON
        json_schema['properties']['$schema'] = {'type': 'string'}
        return json_schema

    @classmethod
    def _save_schema(
        cls, path: Path | str, custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = ()
    ):
        """Save the JSON schema for this dataset type to a file.

        Args:
            path: Path to save the schema to.
            custom_evaluator_types: Custom evaluator classes to include in the schema.
        """
        path = Path(path)
        json_schema = cls.model_json_schema_with_evaluators(custom_evaluator_types)
        schema_content = to_json(json_schema, indent=2).decode() + '\n'
        if not path.exists() or path.read_text() != schema_content:  # pragma: no branch
            path.write_text(schema_content)

    @classmethod
    @functools.cache
    def _serialization_type(cls) -> type[_DatasetModel[InputsT, OutputT, MetadataT]]:
        """Get the serialization type for this dataset class.

        Returns:
            A _DatasetModel type with the same generic parameters as this Dataset class.
        """
        input_type, output_type, metadata_type = cls._params()
        return _DatasetModel[input_type, output_type, metadata_type]

    @classmethod
    def _infer_fmt(cls, path: Path, fmt: Literal['yaml', 'json'] | None) -> Literal['yaml', 'json']:
        """Infer the format to use for a file based on its extension.

        Args:
            path: The path to infer the format for.
            fmt: The explicitly provided format, if any.

        Returns:
            The inferred format ('yaml' or 'json').

        Raises:
            ValueError: If the format cannot be inferred from the file extension.
        """
        if fmt is not None:
            return fmt
        suffix = path.suffix.lower()
        if suffix in {'.yaml', '.yml'}:
            return 'yaml'
        elif suffix == '.json':
            return 'json'
        raise ValueError(
            f'Could not infer format for filename {path.name!r}. Use the `fmt` argument to specify the format.'
        )

    @model_serializer(mode='wrap')
    def _add_json_schema(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo) -> dict[str, Any]:
        """Add the JSON schema path to the serialized output.

        See <https://github.com/json-schema-org/json-schema-spec/issues/828> for context, that seems to be the nearest
        there is to a spec for this.
        """
        context = cast(dict[str, Any] | None, info.context)
        if isinstance(context, dict) and (schema := context.get('$schema')):
            return {'$schema': schema} | nxt(self)
        else:
            return nxt(self)


def _get_relative_path_reference(target: Path, source: Path, _prefix: str = '') -> Path:  # pragma: no cover
    """Get a relative path reference from source to target.

    Recursively resolve a relative path to target from source, adding '..' as needed.
    This is useful for creating a relative path reference from a source file to a target file.

    Args:
        target: The target path to reference.
        source: The source path to reference from.
        _prefix: Internal prefix used during recursion.

    Returns:
        A Path object representing the relative path from source to target.

    Example:
        If source is '/a/b/c.py' and target is '/a/d/e.py', the relative path reference
        would be '../../d/e.py'.
    """
    if not target.is_absolute():
        target = target.resolve()
    try:
        return Path(f'{_prefix}{Path(target).relative_to(source)}')
    except ValueError:
        return _get_relative_path_reference(target, source.parent, _prefix=f'{_prefix}../')


@dataclass
class _TaskRun:
    """Internal class to track metrics and attributes for a task run."""

    attributes: dict[str, Any] = field(init=False, default_factory=dict)
    metrics: dict[str, int | float] = field(init=False, default_factory=dict)

    def record_metric(self, name: str, value: int | float) -> None:
        """Record a metric value.

        Args:
            name: The name of the metric.
            value: The value of the metric.
        """
        self.metrics[name] = value

    def increment_metric(self, name: str, amount: int | float) -> None:
        """Increment a metric value.

        Args:
            name: The name of the metric.
            amount: The amount to increment by.

        Note:
            If the current value is 0 and the increment amount is 0, no metric will be recorded.
        """
        current_value = self.metrics.get(name, 0)
        incremented_value = current_value + amount
        if current_value == 0 and incremented_value == 0:
            return  # Avoid recording a metric that is always zero
        self.record_metric(name, incremented_value)

    def record_attribute(self, name: str, value: Any) -> None:
        """Record an attribute value.

        Args:
            name: The name of the attribute.
            value: The value of the attribute.
        """
        self.attributes[name] = value


async def _run_task(
    task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
    case: Case[InputsT, OutputT, MetadataT],
    retry: RetryConfig | None = None,
) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
    """Run a task on a case and return the context for evaluators.

    Args:
        task: The task to run.
        case: The case to run the task on.
        retry: The retry config to use.

    Returns:
        An EvaluatorContext containing the inputs, actual output, expected output, and metadata.

    Raises:
        Exception: Any exception raised by the task.
    """

    async def _run_once():
        task_run_ = _TaskRun()
        if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
            raise RuntimeError('A task run has already been entered. Task runs should not be nested')

        token = _CURRENT_TASK_RUN.set(task_run_)
        try:
            with (
                logfire_span('execute {task}', task=get_unwrapped_function_name(task)) as task_span,
                context_subtree() as span_tree_,
            ):
                t0 = time.perf_counter()
                if iscoroutinefunction(task):
                    task_output_ = cast(OutputT, await task(case.inputs))
                else:
                    task_output_ = cast(OutputT, await to_thread.run_sync(task, case.inputs))
                fallback_duration = time.perf_counter() - t0
            duration_ = _get_span_duration(task_span, fallback_duration)
            return task_run_, task_output_, duration_, span_tree_
        finally:
            _CURRENT_TASK_RUN.reset(token)

    if retry:
        # import from pydantic_ai.retries to trigger more descriptive import error if tenacity is missing
        from pydantic_ai.retries import retry as tenacity_retry

        _run_once = tenacity_retry(**retry)(_run_once)

    task_run, task_output, duration, span_tree = await _run_once()

    if isinstance(span_tree, SpanTree):  # pragma: no branch
        # Idea for making this more configurable: replace the following logic with a call to a user-provided function
        #   of type Callable[[_TaskRun, SpanTree], None] or similar, (maybe no _TaskRun and just use the public APIs).
        #   That way users can customize this logic. We'd default to a function that does the current thing but also
        #   allow `None` to disable it entirely.
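        # Aggregate request counts, cost, and token usage from the recorded OTel span attributes.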
        for node in span_tree:
            for k, v in node.attributes.items():
                if k == 'gen_ai.operation.name' and v == 'chat':
                    task_run.increment_metric('requests', 1)
                elif not isinstance(v, int | float):
                    continue
                elif k == 'operation.cost':
                    task_run.increment_metric('cost', v)
                elif k.startswith('gen_ai.usage.details.'):
                    task_run.increment_metric(k.removeprefix('gen_ai.usage.details.'), v)
                elif k.startswith('gen_ai.usage.'):
                    task_run.increment_metric(k.removeprefix('gen_ai.usage.'), v)

    return EvaluatorContext[InputsT, OutputT, MetadataT](
        name=case.name,
        inputs=case.inputs,
        metadata=case.metadata,
        expected_output=case.expected_output,
        output=task_output,
        duration=duration,
        _span_tree=span_tree,
        attributes=task_run.attributes,
        metrics=task_run.metrics,
    )


async def _run_task_and_evaluators(
    task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],
    case: Case[InputsT, OutputT, MetadataT],
    report_case_name: str,
    dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
    retry_task: RetryConfig | None,
    retry_evaluators: RetryConfig | None,
) -> ReportCase[InputsT, OutputT, MetadataT] | ReportCaseFailure[InputsT, OutputT, MetadataT]:
    """Run a task on a case and evaluate the results.

    Args:
        task: The task to run.
        case: The case to run the task on.
        report_case_name: The name to use for this case in the report.
        dataset_evaluators: Evaluators from the dataset to apply to this case.
        retry_task: The retry config to use for running the task.
        retry_evaluators: The retry config to use for running the evaluators.

    Returns:
        A ReportCase containing the evaluation results.
    """
    trace_id: str | None = None
    span_id: str | None = None
    try:
        with logfire_span(
            'case: {case_name}',
            task_name=get_unwrapped_function_name(task),
            case_name=report_case_name,
            inputs=case.inputs,
            metadata=case.metadata,
            expected_output=case.expected_output,
        ) as case_span:
            context = case_span.context
            if context is not None:  # pragma: no branch
                trace_id = f'{context.trace_id:032x}'
                span_id = f'{context.span_id:016x}'

            t0 = time.time()
            scoring_context = await _run_task(task, case, retry_task)

            case_span.set_attribute('output', scoring_context.output)
            case_span.set_attribute('task_duration', scoring_context.duration)
            case_span.set_attribute('metrics', scoring_context.metrics)
            case_span.set_attribute('attributes', scoring_context.attributes)

            evaluators = case.evaluators + dataset_evaluators
            evaluator_outputs: list[EvaluationResult] = []
            evaluator_failures: list[EvaluatorFailure] = []
            if evaluators:
                evaluator_outputs_by_task = await task_group_gather(
                    [lambda ev=ev: run_evaluator(ev, scoring_context, retry_evaluators) for ev in evaluators]
                )
                for outputs in evaluator_outputs_by_task:
                    if isinstance(outputs, EvaluatorFailure):
                        evaluator_failures.append(outputs)
                    else:
                        evaluator_outputs.extend(outputs)

            assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
            case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
            case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
            case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
        fallback_duration = time.time() - t0

        return ReportCase[InputsT, OutputT, MetadataT](
            name=report_case_name,
            inputs=case.inputs,
            metadata=case.metadata,
            expected_output=case.expected_output,
            output=scoring_context.output,
            metrics=scoring_context.metrics,
            attributes=scoring_context.attributes,
            scores=scores,
            labels=labels,
            assertions=assertions,
            task_duration=scoring_context.duration,
            total_duration=_get_span_duration(case_span, fallback_duration),
            trace_id=trace_id,
            span_id=span_id,
            evaluator_failures=evaluator_failures,
        )
    except Exception as exc:
        return ReportCaseFailure[InputsT, OutputT, MetadataT](
            name=report_case_name,
            inputs=case.inputs,
            metadata=case.metadata,
            expected_output=case.expected_output,
            error_message=f'{type(exc).__name__}: {exc}',
            error_stacktrace=traceback.format_exc(),
            trace_id=trace_id,
            span_id=span_id,
        )


_evaluation_results_adapter = TypeAdapter(Mapping[str, EvaluationResult])


def _group_evaluator_outputs_by_type(
    evaluation_results: Sequence[EvaluationResult],
) -> tuple[
    dict[str, EvaluationResult[bool]],
    dict[str, EvaluationResult[int | float]],
    dict[str, EvaluationResult[str]],
]:
    """Group evaluator outputs by their result type.

    Args:
        evaluation_results: Sequence of evaluation results to group.

    Returns:
        A tuple of dictionaries mapping evaluator names to their results, grouped by result type:
        (success_evaluations, metric_evaluations, string_evaluations)
    """
    assertions: dict[str, EvaluationResult[bool]] = {}
    scores: dict[str, EvaluationResult[int | float]] = {}
    labels: dict[str, EvaluationResult[str]] = {}
    seen_names = set[str]()
    for er in evaluation_results:
        name = er.name
        # Dedupe repeated names by adding a numeric suffix
        if name in seen_names:
            suffix = 2
            while f'{name}_{suffix}' in seen_names:
                suffix += 1
            name = f'{name}_{suffix}'
        seen_names.add(name)
        if assertion := er.downcast(bool):
            assertions[name] = assertion
        elif score := er.downcast(int, float):
            scores[name] = score
        elif label := er.downcast(str):  # pragma: no branch
            labels[name] = label
    return assertions, scores, labels


_CURRENT_TASK_RUN = ContextVar['_TaskRun | None']('_CURRENT_TASK_RUN', default=None)


def set_eval_attribute(name: str, value: Any) -> None:
    """Set an attribute on the current task run.

    Args:
        name: The name of the attribute.
        value: The value of the attribute.
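
    Example (a minimal sketch; the attribute name and task are illustrative):
    ```python
    from pydantic_evals.dataset import set_eval_attribute


    async def my_task(inputs: dict) -> str:
        set_eval_attribute('selected_model', 'example-model')
        return 'output'
    ```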
    """
    current_case = _CURRENT_TASK_RUN.get()
    if current_case is not None:  # pragma: no branch
        current_case.record_attribute(name, value)


def increment_eval_metric(name: str, amount: int | float) -> None:
    """Increment a metric on the current task run.

    Args:
        name: The name of the metric.
        amount: The amount to increment by.
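
    Example (a minimal sketch; the metric name and task are illustrative):
    ```python
    from pydantic_evals.dataset import increment_eval_metric


    async def my_task(inputs: dict) -> str:
        increment_eval_metric('external_api_calls', 1)
        return 'output'
    ```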
    """
    current_case = _CURRENT_TASK_RUN.get()
    if current_case is not None:  # pragma: no branch
        current_case.increment_metric(name, amount)


def _get_span_duration(span: logfire_api.LogfireSpan, fallback: float) -> float:
    """Calculate the duration of a span in seconds.

    We prefer to obtain the duration from a span for the sake of consistency with observability and to make
    the values more reliable during testing. However, if the span is not available (e.g. when using logfire_api
    without logfire installed), we fall back to the provided duration.

    Args:
        span: The span to calculate the duration for.
        fallback: The fallback duration to use if unable to obtain the duration from the span.

    Returns:
        The duration of the span in seconds.
    """
    try:
        return (span.end_time - span.start_time) / 1_000_000_000  # type: ignore
    except (AttributeError, TypeError):  # pragma: lax no cover
        return fallback


def _get_registry(
    custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]],
) -> Mapping[str, type[Evaluator[InputsT, OutputT, MetadataT]]]:
    """Create a registry of evaluator types from default and custom evaluators.

    Args:
        custom_evaluator_types: Additional evaluator classes to include in the registry.

    Returns:
        A mapping from evaluator names to evaluator classes.
    """
    registry: dict[str, type[Evaluator[InputsT, OutputT, MetadataT]]] = {}

    for evaluator_class in custom_evaluator_types:
        if not issubclass(evaluator_class, Evaluator):
            raise ValueError(
                f'All custom evaluator classes must be subclasses of Evaluator, but {evaluator_class} is not'
            )
        if '__dataclass_fields__' not in evaluator_class.__dict__:
            raise ValueError(
                f'All custom evaluator classes must be decorated with `@dataclass`, but {evaluator_class} is not'
            )
        name = evaluator_class.get_serialization_name()
        if name in registry:
            raise ValueError(f'Duplicate evaluator class name: {name!r}')
        registry[name] = evaluator_class

    for evaluator_class in DEFAULT_EVALUATORS:
        # Allow custom evaluators to override the default evaluators without raising an error
        registry.setdefault(evaluator_class.get_serialization_name(), evaluator_class)

    return registry


def _load_evaluator_from_registry(
    registry: Mapping[str, type[Evaluator[InputsT, OutputT, MetadataT]]],
    case_name: str | None,
    spec: EvaluatorSpec,
) -> Evaluator[InputsT, OutputT, MetadataT]:
    """Load an evaluator from the registry based on a specification.

    Args:
        registry: Mapping from evaluator names to evaluator classes.
        case_name: Name of the case this evaluator will be used for, or None for dataset-level evaluators.
        spec: Specification of the evaluator to load.

    Returns:
        An initialized evaluator instance.

    Raises:
        ValueError: If the evaluator name is not found in the registry.
    """
    evaluator_class = registry.get(spec.name)
    if evaluator_class is None:
        raise ValueError(
            f'Evaluator {spec.name!r} is not in the provided `custom_evaluator_types`. Valid choices: {list(registry.keys())}.'
            f' If you are trying to use a custom evaluator, you must include its type in the `custom_evaluator_types` argument.'
        )
    try:
        return evaluator_class(*spec.args, **spec.kwargs)
    except Exception as e:
        case_detail = f'case {case_name!r}' if case_name is not None else 'dataset'
        raise ValueError(f'Failed to instantiate evaluator {spec.name!r} for {case_detail}: {e}') from e
