from __future__ import annotations

import json
from contextlib import suppress
from dataclasses import dataclass
from typing import Any, cast
from urllib.parse import parse_qs, urlparse

from opentelemetry import context
from opentelemetry.sdk.trace import Event, ReadableSpan, Span
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.trace import SpanKind, Status, StatusCode

import logfire

from ..constants import (
    ATTRIBUTES_JSON_SCHEMA_KEY,
    ATTRIBUTES_LOG_LEVEL_NUM_KEY,
    ATTRIBUTES_MESSAGE_KEY,
    ATTRIBUTES_MESSAGE_TEMPLATE_KEY,
    ATTRIBUTES_TAGS_KEY,
    LEVEL_NUMBERS,
    log_level_attributes,
)
from ..db_statement_summary import message_from_db_statement
from ..json_schema import JsonSchemaProperties, attributes_json_schema
from ..scrubbing import BaseScrubber
from ..utils import (
    ReadableSpanDict,
    handle_internal_errors,
    is_asgi_send_receive_span_name,
    is_instrumentation_suppressed,
    span_to_dict,
    truncate_string,
)
from .wrapper import WrapperSpanProcessor


class CheckSuppressInstrumentationProcessorWrapper(WrapperSpanProcessor):
    """Root processor wrapper that honours instrumentation suppression and then applies it itself.

    If instrumentation is already suppressed in the current context, the wrapped processors are skipped
    entirely for that span. Otherwise the wrapped processors are invoked inside
    `logfire.suppress_instrumentation()`, so whatever they do runs with instrumentation suppressed.

    Placed at the root of the tree of processors.
    """

    def on_start(self, span: Span, parent_context: context.Context | None = None) -> None:
        if not is_instrumentation_suppressed():
            with logfire.suppress_instrumentation():
                super().on_start(span, parent_context)

    def on_end(self, span: ReadableSpan) -> None:
        if not is_instrumentation_suppressed():
            with logfire.suppress_instrumentation():
                super().on_end(span)


@dataclass
class MainSpanProcessorWrapper(WrapperSpanProcessor):
    """Wrapper around other processors that applies Logfire's global span tweaks.

    On start, ASGI send/receive spans are demoted to the debug log level.
    On end, the span is converted to a dict, passed through a fixed, ordered pipeline of in-place tweaks
    (names/messages, levels and status, LLM-specific attributes, and finally scrubbing), then rebuilt as a
    `ReadableSpan` before being handed to the wrapped processors. If any step raises, the error is handled by
    `handle_internal_errors` and the original, untweaked span is forwarded instead.
    """

    scrubber: BaseScrubber

    def on_start(
        self,
        span: Span,
        parent_context: context.Context | None = None,
    ) -> None:
        _set_log_level_on_asgi_send_receive_spans(span)
        super().on_start(span, parent_context)

    def on_end(self, span: ReadableSpan) -> None:
        with handle_internal_errors:
            # Order matters: later steps read attributes written by earlier ones
            # (e.g. the default response model uses request models set by the LLM transforms),
            # and scrubbing runs last over the final set of attributes.
            pipeline = (
                _tweak_asgi_send_receive_spans,
                _tweak_sqlalchemy_connect_spans,
                _tweak_http_spans,
                _tweak_fastapi_span,
                _summarize_db_statement,
                _set_error_level_and_status,
                _transform_langchain_span,
                _transform_google_genai_span,
                _transform_litellm_span,
                _default_gen_ai_response_model,
                self.scrubber.scrub_span,
            )
            span_dict = span_to_dict(span)
            for step in pipeline:
                step(span_dict)
            span = ReadableSpan(**span_dict)
        super().on_end(span)


def _set_error_level_and_status(span: ReadableSpanDict) -> None:
    """Keep the log level and the OTel status code consistent for errors.

    - An ERROR status on a span with no explicit level gives it the 'error' level.
    - An UNSET status on a span whose level number is at least 'error' becomes an ERROR status,
      keeping the existing description.

    This makes querying for `level` and `otel_status_code` interchangeable ways to find errors.
    """
    attrs = span['attributes']
    current_status = span['status']

    if current_status.status_code == StatusCode.ERROR and ATTRIBUTES_LOG_LEVEL_NUM_KEY not in attrs:
        span['attributes'] = {**attrs, **log_level_attributes('error')}
        return

    if not current_status.is_unset:
        return

    level_num = attrs.get(ATTRIBUTES_LOG_LEVEL_NUM_KEY)
    if isinstance(level_num, int) and level_num >= LEVEL_NUMBERS['error']:
        span['status'] = Status(status_code=StatusCode.ERROR, description=current_status.description)


def _set_log_level_on_asgi_send_receive_spans(span: Span) -> None:
    """Demote ASGI send/receive spans to the 'debug' log level.

    Spans without an explicit level default to 'info', which is too prominent for ASGI send/receive spans:
    they are generated for every request and are rarely interesting on their own.
    """
    if not _is_asgi_send_receive_span(span.name, span.instrumentation_scope):
        return
    span.set_attributes(log_level_attributes('debug'))


def _tweak_sqlalchemy_connect_spans(span: ReadableSpanDict) -> None:
    """Lower SQLAlchemy 'connect' spans to the debug level so they are hidden by default.

    Only spans named 'connect' from the SQLAlchemy OTel instrumentation that have `db.system`
    but no `db.statement` are changed; everything else is left untouched.
    """
    # https://pydanticlogfire.slack.com/archives/C06EDRBSAH3/p1720205732316029
    if span['name'] != 'connect':
        return

    scope = span['instrumentation_scope']
    from_sqlalchemy = scope is not None and scope.name == 'opentelemetry.instrumentation.sqlalchemy'
    if not from_sqlalchemy:  # pragma: no cover
        return

    attrs = span['attributes']
    # A 'connect' span is never expected to carry db.statement; checking it anyway
    # guarantees we never accidentally hide a real query span.
    is_plain_connect = 'db.system' in attrs and 'db.statement' not in attrs
    if not is_plain_connect:  # pragma: no cover
        return

    span['attributes'] = {**attrs, **log_level_attributes('debug')}


def _tweak_asgi_send_receive_spans(span: ReadableSpanDict) -> None:
    """Make the name/message of spans generated by OTEL's ASGI middleware more useful.

    A single request typically produces two 'send' spans with the same message, e.g. 'GET /foo http send'.
    Where it adds information, the part of the ASGI event type after its 'http.'/'websocket.' prefix is
    appended to the span name, giving e.g. 'http send response.start' and 'http send response.body'.
    The message attribute is set to the resulting name.
    """
    name = span['name']
    if not _is_asgi_send_receive_span(name, span['instrumentation_scope']):
        return

    attrs = span['attributes']
    # The attribute name should be `asgi.event.type` after this is merged and released:
    # https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2300
    event_type = attrs.get('asgi.event.type') or attrs.get('type')
    has_expected_shape = (
        isinstance(event_type, str)
        and event_type.startswith(('http.', 'websocket.'))
        and attrs.get(ATTRIBUTES_MESSAGE_KEY) == name
    )
    if not has_expected_shape:  # pragma: no cover
        return

    if event_type in ('websocket.send', 'websocket.receive'):
        # Appending here would just repeat the verb, e.g. 'websocket send send'.
        # No other event types in https://asgi.readthedocs.io/en/latest/specs/www.html are redundant like this.
        new_name = name
    else:
        _, event_suffix = event_type.split('.', 1)
        new_name = f'{name} {event_suffix}'
        span['name'] = new_name

    span['attributes'] = {**attrs, ATTRIBUTES_MESSAGE_KEY: new_name}


def _is_asgi_send_receive_span(name: str, instrumentation_scope: InstrumentationScope | None) -> bool:
    """Return whether a span is a send/receive span from one of the ASGI-based OTel instrumentations.

    The scope must be the ASGI, Starlette or FastAPI instrumentation, and the name must be recognised by
    `is_asgi_send_receive_span_name`.
    """
    if instrumentation_scope is None:
        return False
    asgi_scopes = (
        'opentelemetry.instrumentation.asgi',
        'opentelemetry.instrumentation.starlette',
        'opentelemetry.instrumentation.fastapi',
    )
    if instrumentation_scope.name not in asgi_scopes:
        return False
    return is_asgi_send_receive_span_name(name)


def _tweak_http_spans(span: ReadableSpanDict) -> None:
    """Tweak spans from HTTP instrumentations, particularly the span name and message.

    Also derives `http.target` from `http.url` if needed.

    The span names from OTEL instrumentations are an inconsistent and generally lacking mess.
    This is partly due to not having a concept of 'message' separate from span names.

    This function checks if the current name is some combination of method and route/target, and if so sets:
    - The span name to method + route (low cardinality)
    - The message to method + target (more information)
    In both cases, if only one of method and route/target is available, it just uses that.

    For CLIENT spans the message target is additionally prefixed with a server name (from `server.address`,
    `http.server_name`, `http.host`, or the hostname parsed from `http.url`, falling back to `http.url` itself).
    If `http.url` has query parameters, their decoded values are appended to the message.

    For some spans (e.g. ASGI) this actually removes information (the target) from the span name,
    but leaves it in the message.

    Mutates `span` in place; `span['attributes']` is replaced with a new dict rather than modified.
    """
    attributes = span['attributes']

    # Check that this generally looks like a span not generated by logfire methods.
    # This is intended for OTEL instrumentations of frameworks like FastAPI, but written to be general.
    if ATTRIBUTES_MESSAGE_TEMPLATE_KEY in attributes:
        return

    name = span['name']
    # Only handle spans whose message is still exactly the span name.
    if name != attributes.get(ATTRIBUTES_MESSAGE_KEY):  # pragma: no cover
        return

    method = attributes.get('http.method')
    route = attributes.get('http.route')
    target = attributes.get('http.target')
    url: Any = attributes.get('http.url')
    if not (method or route or target or url):
        return

    # Derive a missing `http.target` from the URL's path, and record it on the span as well.
    if not target and url and isinstance(url, str):
        try:
            target = urlparse(url).path
            span['attributes'] = attributes = {**attributes, 'http.target': target}
        except Exception:  # pragma: no cover
            pass

    # Names like 'HTTP' or 'HTTP <target/route>' without an `http.method` attribute: treat 'HTTP' as the method
    # so that the name is recognised below as a method/target combination.
    if not method and name in ('HTTP', f'HTTP {target}', f'HTTP {route}'):
        method = 'HTTP'

    # Build up a list of possible span names and messages in order from worst to best
    names: list[str] = []
    messages: list[str] = []
    if method and isinstance(method, str):
        names.append(method)
        messages.append(method)
    if target and isinstance(target, str):  # pragma: no branch
        message_target = target
        if span['kind'] == SpanKind.CLIENT:
            # For outgoing requests, we also want the domain, not just the path.
            server_name: Any = (
                attributes.get('server.address') or attributes.get('http.server_name') or attributes.get('http.host')
            )
            if not server_name:
                try:
                    server_name = urlparse(url).hostname
                except Exception:  # pragma: no cover
                    pass
            # NOTE(review): falling back to the whole URL means the path can appear twice in the message
            # (url + target), e.g. for scheme-less URLs where `hostname` is None — confirm this is intended.
            server_name = server_name or url
            if server_name and isinstance(server_name, str):  # pragma: no branch
                message_target = server_name + message_target
        messages.append(message_target)
    if route and isinstance(route, str):
        names.append(route)

    # If both method and target/route are present, also use the combination
    for lst in (names, messages):
        if len(lst) == 2:
            lst.append(' '.join(lst))

    # If the name doesn't already consist of method and/or target/route, leave it alone
    if name not in names + messages:
        return

    # For each of name and message, update to the best option, which is the last in the list.
    # Minor optimization: only do this if there's a change.
    if names and (new_name := names[-1]) != name:
        span['name'] = new_name

    if not messages:  # pragma: no cover
        return

    message = messages[-1]

    # Add query params to the message if:
    # 1. The message currently ends with the target
    # 2. We have a URL to parse query params from
    # 3. Some query params exist
    # 4. The target doesn't already end with the query string
    #       (it's supposed to according to the spec, but the OTEL libraries don't include it)
    if (
        url and target and isinstance(url, str) and isinstance(target, str) and message.endswith(target)
    ):  # pragma: no branch
        query_string = urlparse(url).query
        # parse_qs drops parameters with blank values by default (keep_blank_values=False).
        query_params = parse_qs(query_string)
        if query_params and not target.endswith(query_string):
            pairs = [(k, v) for k, vs in query_params.items() for v in vs]
            # Put shorter query params first so that they'll be visible in the UI even if the whole message isn't.
            pairs.sort(key=lambda pair: (len(pair[0]) + len(pair[1]), pair))
            # Limit keys and values to 20 chars each.
            truncated_pairs = [[truncate_string(s, max_length=20, middle='…') for s in pair] for pair in pairs]
            # Show
            #   /path?foo=1&bar=2%203
            # as:
            #   /path ? foo='1' & bar='2 3'
            # to make things nice and readable.
            # Note that we show decoded values, e.g. %20 -> ' '.
            message += ' ? ' + ' & '.join(f'{k}={v!r}' for k, v in truncated_pairs)

    # The current message equals `name` (checked at the top), so only write when it actually changed.
    if message != name:
        span['attributes'] = {**attributes, ATTRIBUTES_MESSAGE_KEY: message}


def _summarize_db_statement(span: ReadableSpanDict):
    """Replace the span message with a summary of its DB statement, when one can be produced.

    The summary comes from `message_from_db_statement`, given the attributes, the current message
    and the span name; if it returns None the span is left unchanged.
    """
    attrs = span['attributes']
    current_message: str | None = attrs.get(ATTRIBUTES_MESSAGE_KEY)  # type: ignore
    summary = message_from_db_statement(attrs, current_message, span['name'])
    if summary is None:
        return
    span['attributes'] = {**attrs, ATTRIBUTES_MESSAGE_KEY: summary}


def _tweak_fastapi_span(span: ReadableSpanDict):
    """Drop duplicated exception events from FastAPI request spans.

    Our fastapi instrumentation records some exceptions directly on the request span. These might be
    handled and not seen again, or they may bubble through and be recorded by the OTel middleware too,
    thus appearing twice on the same span.

    Events are scanned from last to first. An exception event (with `exception.type` and
    `exception.message`) is dropped only when a later exception event with the same type and message was
    already kept AND the event is flagged `recorded_by_logfire_fastapi`. This keeps the later event,
    which has a fuller traceback. Event order is otherwise preserved.
    """
    scope = span['instrumentation_scope']
    if not scope or scope.name != 'opentelemetry.instrumentation.fastapi':
        return

    already_seen: set[tuple[Any, Any]] = set()
    kept_newest_first: list[Event] = []
    for event in span['events'][::-1]:
        attrs = event.attributes
        is_exception_event = (
            event.name == 'exception' and attrs and 'exception.type' in attrs and 'exception.message' in attrs
        )
        if is_exception_event:
            identity = (attrs['exception.type'], attrs['exception.message'])
            if identity in already_seen and attrs.get('recorded_by_logfire_fastapi'):
                continue
            already_seen.add(identity)
        kept_newest_first.append(event)
    span['events'] = kept_newest_first[::-1]


def _transform_langchain_span(span: ReadableSpanDict) -> None:
    """Transform spans generated by LangSmith to work better in the Logfire UI.

    - Add attribute names to the JSON schema so that they get parsed as JSON.
    - Add OTel semconv attributes.
    - Add `all_messages_events` to display the conversation in the Generation panel.

    Only spans from the `openinference.instrumentation.langchain` or `langsmith` scopes that do not already
    have a JSON schema attribute are changed. The semconv and message steps are best-effort: an exception
    inside one of the `suppress(Exception)` blocks just skips (the rest of) that step.
    """
    scope = span['instrumentation_scope']

    # This was originally written for and tested with openinference.instrumentation.langchain,
    # which produces essentially the same spans as langsmith but with different attribute names.
    if not (scope and scope.name in ('openinference.instrumentation.langchain', 'langsmith')):
        return

    attributes = span['attributes']
    existing_json_schema = attributes.get(ATTRIBUTES_JSON_SCHEMA_KEY)
    # Leave spans that already carry a JSON schema alone.
    if existing_json_schema:  # pragma: no cover
        return

    # Parse every string attribute that looks like a JSON object or array.
    properties = JsonSchemaProperties({})
    parsed_attributes: dict[str, Any] = {}
    for key, value in attributes.items():
        if not isinstance(value, str) or not value.startswith(('{"', '[')):
            continue
        try:
            parsed_attributes[key] = json.loads(value)
        except json.JSONDecodeError:  # pragma: no cover
            continue
        # Tell the Logfire backend to parse this attribute as JSON.
        properties[key] = {'type': 'object' if value.startswith('{') else 'array'}

    new_attributes: dict[str, Any] = {}

    # OTel semconv attributes, needed for displaying costs.
    with suppress(Exception):
        new_attributes['gen_ai.request.model'] = parsed_attributes['llm.invocation_parameters']['model']
    with suppress(Exception):
        new_attributes['gen_ai.response.model'] = model = parsed_attributes['gen_ai.completion']['llm_output'][
            'model_name'
        ]
        # Fall back to the response model as the request model if none was found above.
        new_attributes.setdefault('gen_ai.request.model', model)

    # A request model already present on the span takes precedence over one derived above.
    request_model: str = attributes.get('gen_ai.request.model') or new_attributes.get('gen_ai.request.model', '')  # type: ignore

    if not request_model and 'gen_ai.usage.input_tokens' in attributes:  # pragma: no cover
        # Only keep usage attributes on spans with actual token usage, i.e. model requests,
        # to prevent double counting costs in the UI.
        # This applies to older langsmith versions
        attributes = {k: v for k, v in attributes.items() if not k.startswith('gen_ai.usage.')}

    guessed_system = guess_system(request_model)
    actual_system = attributes.get('gen_ai.system')
    if guessed_system:
        # Replace a missing or generic 'langchain' system with the provider guessed from the model name.
        if actual_system in (None, 'langchain'):  # pragma: no cover
            new_attributes['gen_ai.system'] = guessed_system
    elif actual_system == 'langchain':
        # Remove gen_ai.system=langchain as this also interferes with costs in the UI.
        attributes = {k: v for k, v in attributes.items() if k != 'gen_ai.system'}

    # Add `all_messages_events`
    with suppress(Exception):
        # `input.value` is tried first, then `gen_ai.prompt` (the two instrumentations use different names).
        input_messages = parsed_attributes.get('input.value', parsed_attributes.get('gen_ai.prompt', {}))['messages']
        # Unwrap a single nested batch, e.g. [[msg1, msg2]] -> [msg1, msg2].
        if len(input_messages) == 1 and isinstance(input_messages[0], list):
            [input_messages] = input_messages

        message_events = [_transform_langchain_message(old_message) for old_message in input_messages]

        # If we fail to parse output messages, fine, but only try if we've succeeded to parse input messages.
        with suppress(Exception):
            output_value = parsed_attributes.get('output.value', parsed_attributes.get('gen_ai.completion', {}))
            # Try the known output shapes in turn: `generations`, then `messages`, then `output`.
            try:
                # Multiple generations mean multiple choices, we can only display one.
                message_events += [_transform_langchain_message(output_value['generations'][0][0]['message'])]
            except Exception:
                try:
                    output_message_events = [_transform_langchain_message(m) for m in output_value['messages']]
                    if (
                        message_events
                        and len(message_events) <= len(output_message_events)
                        and all(
                            all(om.get(k) == im.get(k) for k in im)
                            for im, om in zip(message_events, output_message_events)
                        )
                    ):
                        # If the input messages are a prefix of the output messages, we can just use the output messages.
                        message_events = output_message_events
                    else:
                        message_events += output_message_events
                except Exception:
                    message_events += [_transform_langchain_message(output_value['output'])]

        new_attributes['all_messages_events'] = json.dumps(message_events)
        properties['all_messages_events'] = {'type': 'array'}

    # `new_attributes` are unpacked last, so they override existing attributes with the same key.
    span['attributes'] = {
        **attributes,
        ATTRIBUTES_JSON_SCHEMA_KEY: attributes_json_schema(properties),
        **new_attributes,
    }


def _transform_langchain_message(old_message: dict[str, Any]) -> dict[str, Any]:
    """Convert a serialized LangChain message into an OpenAI-style chat message dict.

    - Unwraps LangChain's `{'type': 'constructor', 'kwargs': {...}}` serialization envelope.
    - Sets `role` from an explicit `role`, otherwise from `type` ('human' -> 'user', 'ai' -> 'assistant',
      any other type is used as-is).
    - Drops LangChain bookkeeping keys (`type`, `additional_kwargs`, `response_metadata`, `id`,
      `usage_metadata`) and merges `additional_kwargs` into the top level.
    - Converts LangChain tool calls (`name`/`args` with `type='tool_call'`) into OpenAI-style
      `{'type': 'function', 'function': {'name': ..., 'arguments': ...}}` entries; an empty/falsy
      `tool_calls` is removed.
    - Renames `tool_call_id` to `id`.

    The input message is not modified: converted tool calls are new dicts.

    Raises:
        KeyError: if the message has neither a truthy `role` nor a `type`. Callers in this module use
            exceptions to detect messages in an unexpected format, so this is deliberately not caught.
    """
    if old_message.get('type') == 'constructor' and 'kwargs' in old_message:
        kwargs = old_message['kwargs']
    else:
        kwargs = old_message

    role = kwargs.get('role') or {'human': 'user', 'ai': 'assistant'}.get(kwargs['type'], kwargs['type'])
    result: dict[str, Any] = {
        **{
            k: v
            for k, v in kwargs.items()
            if k not in ('type', 'additional_kwargs', 'response_metadata', 'id', 'usage_metadata')
        },
        **kwargs.get('additional_kwargs', {}),
        'role': role,
    }

    if tool_calls := result.get('tool_calls'):
        # Build a new list of (possibly converted) tool calls instead of mutating the caller's dicts in place.
        result['tool_calls'] = [_langchain_tool_call_to_openai(tool_call) for tool_call in tool_calls]
    else:
        result.pop('tool_calls', None)

    if 'tool_call_id' in result:
        result['id'] = result.pop('tool_call_id')
    return result


def _langchain_tool_call_to_openai(tool_call: Any) -> Any:
    """Return an OpenAI-style copy of a LangChain tool call, or the tool call unchanged if it isn't one."""
    if not (
        'function' not in tool_call
        and 'name' in tool_call
        and 'args' in tool_call
        and tool_call.get('type') == 'tool_call'
    ):  # pragma: no cover
        return tool_call
    # Same key order as the previous in-place pop/update: remaining keys, `type` updated in place, `function` last.
    converted = {k: v for k, v in tool_call.items() if k not in ('name', 'args')}
    converted.update(
        function=dict(
            name=tool_call['name'],
            arguments=tool_call['args'],
        ),
        type='function',
    )
    return converted


def _default_gen_ai_response_model(span: ReadableSpanDict):
    """Copy `gen_ai.request.model` into `gen_ai.response.model` when only the request model is set."""
    attrs = span['attributes']
    if 'gen_ai.request.model' not in attrs or 'gen_ai.response.model' in attrs:
        return
    request_model = attrs['gen_ai.request.model']
    span['attributes'] = {**attrs, 'gen_ai.response.model': request_model}


def _transform_google_genai_span(span: ReadableSpanDict):
    """Move GenAI events of Google GenAI spans into a JSON `events` attribute.

    Events whose name starts with `gen_ai.` and which carry a string `event_body` are JSON-decoded and
    collected, in order, into the `events` attribute (declared as a JSON array in the span's JSON schema).
    All other events, including ones whose `event_body` is not valid JSON, stay on the span as events.
    The span is also marked with `gen_ai.operation.name = 'chat'`.

    If no event could be converted, the span is left untouched rather than being given an empty
    `events` attribute, a 'chat' operation name and a replaced JSON schema.
    """
    scope = span['instrumentation_scope']
    if not (scope and scope.name == 'opentelemetry.instrumentation.google_genai' and span['events']):
        return

    new_events: list[Event] = []
    events_attr: list[dict[str, Any]] = []
    for event in span['events']:
        if not (
            event.name.startswith('gen_ai.')
            and event.attributes
            and isinstance(event_attrs_string := event.attributes.get('event_body'), str)
        ):  # pragma: no cover
            new_events.append(event)
            continue
        try:
            event_attrs: dict[str, Any] = json.loads(event_attrs_string)
        except json.JSONDecodeError:  # pragma: no cover
            # Keep an unparseable event on the span instead of aborting the whole span transformation.
            new_events.append(event)
            continue
        events_attr.append(event_attrs)

    if not events_attr:  # pragma: no cover
        # Nothing was converted, so there is nothing to move into attributes.
        return

    span['attributes'] = {
        **span['attributes'],
        'events': json.dumps(events_attr),
        'gen_ai.operation.name': 'chat',
        ATTRIBUTES_JSON_SCHEMA_KEY: attributes_json_schema(JsonSchemaProperties({'events': {'type': 'array'}})),
    }
    span['events'] = new_events


def _transform_litellm_span(span: ReadableSpanDict):
    """Reshape spans from the OpenInference LiteLLM instrumentation for the Logfire LLM UI.

    Adds `request_data` (the raw `input.value`) and `response_data` (a JSON `{'message': ...}` built from the
    output), GenAI semconv model/usage/system attributes when they are all available, an 'LLM' tag, and a JSON
    schema marking `request_data`/`response_data` as objects. Spans whose input/output can't be read are left
    unchanged.
    """
    scope = span['instrumentation_scope']
    if not scope or scope.name != 'openinference.instrumentation.litellm':
        return

    attrs = span['attributes']
    try:
        raw_output = attrs['output.value']
        extra: dict[str, Any] = {'request_data': attrs['input.value']}
        if raw_output == attrs.get('llm.output_messages.0.message.content'):
            # The output is exactly the first output message's content, so rebuild that message
            # from the flattened output-message attributes.
            message = {
                'content': raw_output,
                'role': attrs.get('llm.output_messages.0.message.role', 'assistant'),
            }
        else:
            response = json.loads(cast(str, raw_output))
            message = response['choices'][0]['message']
            if 'model' in response:  # pragma: no branch
                extra['gen_ai.response.model'] = response['model']
        extra['response_data'] = json.dumps({'message': message})
    except Exception:  # pragma: no cover
        return

    # Semconv attributes are all-or-nothing: if any source attribute is missing, none are added.
    with suppress(Exception):
        model_name = cast(str, attrs['llm.model_name'])
        semconv = {
            'gen_ai.request.model': model_name,
            'gen_ai.usage.input_tokens': attrs['llm.token_count.prompt'],
            'gen_ai.usage.output_tokens': attrs['llm.token_count.completion'],
            'gen_ai.system': guess_system(model_name, 'litellm'),
        }
        extra.update(semconv)

    schema_properties = JsonSchemaProperties(
        {
            'request_data': {'type': 'object'},
            'response_data': {'type': 'object'},
        }
    )
    span['attributes'] = {
        **attrs,
        **extra,
        ATTRIBUTES_TAGS_KEY: ['LLM'],
        ATTRIBUTES_JSON_SCHEMA_KEY: attributes_json_schema(schema_properties),
    }


def guess_system(model: str, default: str = ''):
    """Guess the GenAI provider ('openai', 'google' or 'anthropic') from a model name.

    Matching is a case-insensitive substring search. Providers are checked in the order
    openai, google, anthropic, and the first match wins. Returns `default` when nothing matches.
    """
    lowered = model.lower()
    provider_hints = (
        ('openai', ('openai', 'gpt-4', 'gpt-3.5')),
        ('google', ('google', 'gemini')),
        ('anthropic', ('anthropic', 'claude')),
    )
    for provider, needles in provider_hints:
        if any(needle in lowered for needle in needles):
            return provider
    return default
