from __future__ import annotations

from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Literal

from typing_extensions import assert_never

from pydantic_graph.beta.decision import Decision
from pydantic_graph.beta.id_types import NodeID
from pydantic_graph.beta.join import Join
from pydantic_graph.beta.node import EndNode, Fork, StartNode
from pydantic_graph.beta.node_types import AnyNode
from pydantic_graph.beta.paths import BroadcastMarker, DestinationMarker, LabelMarker, MapMarker, Path
from pydantic_graph.beta.step import Step

DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32'
"""The default CSS to use for highlighting nodes."""


StateDiagramDirection = Literal['TB', 'LR', 'RL', 'BT']
"""Used to specify the direction of the state diagram generated by mermaid.

- `'TB'`: Top to bottom, this is the default for mermaid charts.
- `'LR'`: Left to right
- `'RL'`: Right to left
- `'BT'`: Bottom to top
"""

NodeKind = Literal['broadcast', 'map', 'join', 'start', 'end', 'step', 'decision']
"""The kind of a `MermaidNode`, which controls how `MermaidGraph.render` draws it.

- `'start'` / `'end'`: Not declared as states; transitions to/from the start and end
  node IDs are drawn with the special `[*]` syntax.
- `'step'`: A plain state, optionally followed by `: label`.
- `'join'`: A `<<join>>` state.
- `'broadcast'` / `'map'`: Both drawn as a `<<fork>>` state; `'map'` is used for forks
  whose `is_map` flag is set.
- `'decision'`: A `<<choice>>` state, optionally with a note attached to its right.
"""


@dataclass
class MermaidNode:
    """A mermaid node.

    One node of a `MermaidGraph`. `MermaidGraph.render` turns it into zero or more
    `stateDiagram-v2` declaration lines depending on `kind`.
    """

    # Node identifier; used verbatim as the mermaid state name.
    id: str
    # Selects the rendered form: plain state, `<<join>>`, `<<fork>>`, `<<choice>>`,
    # or no declaration at all for start/end nodes.
    kind: NodeKind
    # Display text rendered as `id: label`; only populated/rendered for `'step'` nodes.
    label: str | None
    # Text rendered in a `note right of` block; only populated/rendered for `'decision'` nodes.
    note: str | None


@dataclass
class MermaidEdge:
    """A mermaid edge.

    A directed transition between two nodes of a `MermaidGraph`, rendered as
    `start --> end` (optionally `: label`).
    """

    # ID of the node the transition leaves; the start node's ID is rendered as `[*]`.
    start_id: str
    # ID of the node the transition enters; the end node's ID is rendered as `[*]`.
    end_id: str
    # Optional transition label (the most recent label marker on the path, if any);
    # rendered only when `edge_labels` is enabled at render time.
    label: str | None


def build_mermaid_graph(  # noqa: C901
    graph_nodes: dict[NodeID, AnyNode], graph_edges_by_source: dict[NodeID, list[Path]]
) -> MermaidGraph:
    """Build a `MermaidGraph` from a graph's nodes and outgoing paths.

    Args:
        graph_nodes: Mapping from node ID to node. Each node becomes one `MermaidNode`,
            in this mapping's iteration order.
        graph_edges_by_source: Mapping from source node ID to the paths leaving that node.
            Paths of `Decision` branches are additionally collected from each `Decision`
            node's `branches`.

    Returns:
        A `MermaidGraph` whose edges are grouped by source node following the order of
        `graph_nodes`. Edges whose source ID is not a key of `graph_nodes` are omitted.
    """
    nodes: list[MermaidNode] = []
    edges_by_source: dict[str, list[MermaidEdge]] = defaultdict(list)

    def _collect_edges(path: Path, last_source_id: NodeID) -> None:
        # Emit one edge per destination on `path`, labelled with the most recently seen label.
        working_label: str | None = None
        for item in path.items:
            assert not isinstance(item, MapMarker | BroadcastMarker), 'These should be removed during Graph building'
            if isinstance(item, LabelMarker):
                working_label = item.label
            elif isinstance(item, DestinationMarker):
                edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.destination_id, working_label))

    for node_id, node in graph_nodes.items():
        kind: NodeKind
        label: str | None = None
        note: str | None = None
        if isinstance(node, StartNode):
            kind = 'start'
        elif isinstance(node, EndNode):
            kind = 'end'
        elif isinstance(node, Step):
            kind = 'step'
            label = node.label
        elif isinstance(node, Join):
            kind = 'join'
        elif isinstance(node, Fork):
            kind = 'map' if node.is_map else 'broadcast'
        elif isinstance(node, Decision):
            kind = 'decision'
            note = node.note
        else:
            assert_never(node)

        nodes.append(MermaidNode(id=node_id, kind=kind, label=label, note=note))

    for source_id, paths in graph_edges_by_source.items():
        for path in paths:
            _collect_edges(path, source_id)

    # Decision branches are stored on the Decision node itself rather than in the edge map.
    for node in graph_nodes.values():
        if isinstance(node, Decision):
            for branch in node.branches:
                _collect_edges(branch.path, node.id)

    # Flatten edges in the same order the nodes were added. A flat comprehension is linear,
    # whereas `sum(lists, [])` copies the growing accumulator on every addition.
    edges: list[MermaidEdge] = [
        edge for mermaid_node in nodes for edge in edges_by_source.get(mermaid_node.id, [])
    ]
    return MermaidGraph(nodes, edges)


@dataclass
class MermaidGraph:
    """A mermaid graph.

    `title` and `direction` act as defaults for `render` when the corresponding
    argument is not passed.
    """

    nodes: list[MermaidNode]
    edges: list[MermaidEdge]

    title: str | None = None
    direction: StateDiagramDirection | None = None

    def render(
        self,
        direction: StateDiagramDirection | None = None,
        title: str | None = None,
        edge_labels: bool = True,
    ) -> str:
        """Render the graph as mermaid `stateDiagram-v2` source.

        Args:
            direction: Diagram direction; falls back to `self.direction` when `None`.
                When both are `None`, no `direction` line is emitted (mermaid's default applies).
            title: Diagram title, emitted as `---` front matter; falls back to `self.title`
                when `None`. An empty string suppresses the title.
            edge_labels: Whether to append `: label` to transitions that have a label.

        Returns:
            The mermaid source, joined with newlines.
        """
        # Explicit arguments win; otherwise use the values stored on the graph.
        if direction is None:
            direction = self.direction
        if title is None:
            title = self.title

        lines: list[str] = []
        if title:
            lines = ['---', f'title: {title}', '---']
        lines.append('stateDiagram-v2')
        if direction is not None:
            lines.append(f'  direction {direction}')

        nodes, edges = _topological_sort(self.nodes, self.edges)
        for node in nodes:
            # Declare states in order of their BFS distance from the start node
            node_lines: list[str] = []
            if node.kind == 'start' or node.kind == 'end':
                pass  # Start and end nodes use special [*] syntax in edges
            elif node.kind == 'step':
                line = f'  {node.id}'
                if node.label:
                    line += f': {node.label}'
                node_lines.append(line)
            elif node.kind == 'join':
                node_lines = [f'  state {node.id} <<join>>']
            elif node.kind == 'broadcast' or node.kind == 'map':
                node_lines = [f'  state {node.id} <<fork>>']
            elif node.kind == 'decision':
                node_lines = [f'  state {node.id} <<choice>>']
                if node.note:
                    node_lines.append(f'  note right of {node.id}\n    {node.note}\n  end note')
            else:  # pragma: no cover
                assert_never(node.kind)
            lines.extend(node_lines)

        # Blank line separating state declarations from transitions.
        lines.append('')

        for edge in edges:
            # Use special [*] syntax for start/end nodes
            render_start_id = '[*]' if edge.start_id == StartNode.id else edge.start_id
            render_end_id = '[*]' if edge.end_id == EndNode.id else edge.end_id
            edge_line = f'  {render_start_id} --> {render_end_id}'
            if edge.label and edge_labels:
                edge_line += f': {edge.label}'
            lines.append(edge_line)

        return '\n'.join(lines)


def _topological_sort(
    nodes: list[MermaidNode], edges: list[MermaidEdge]
) -> tuple[list[MermaidNode], list[MermaidEdge]]:
    """Order nodes and edges by their BFS distance from the start node.

    This is not a strict topological sort (the graph may contain cycles). Each node is
    assigned the length of the shortest edge path from `StartNode.id`, and then:
    - nodes are sorted by depth, then by id, for a deterministic order;
    - edges are sorted by source depth, then target depth, then source id and target id.

    Nodes unreachable from the start node get depth `inf` and are sorted last.

    Args:
        nodes: The nodes to order.
        edges: The edges to order; also used to compute the depths.

    Returns:
        A `(sorted_nodes, sorted_edges)` tuple of new lists; the inputs are not modified.
    """
    # Build adjacency list for BFS
    adjacency: dict[str, list[str]] = defaultdict(list)
    for edge in edges:
        adjacency[edge.start_id].append(edge.end_id)

    # BFS to assign depth to each node (distance from start). A deque gives O(1)
    # pops from the left, whereas list.pop(0) shifts every remaining element.
    depths: dict[str, int] = {StartNode.id: 0}
    queue: deque[tuple[str, int]] = deque([(StartNode.id, 0)])

    while queue:
        node_id, depth = queue.popleft()
        for next_id in adjacency[node_id]:
            if next_id not in depths:  # pragma: no branch
                depths[next_id] = depth + 1
                queue.append((next_id, depth + 1))

    # Sort nodes by depth (distance from start), then by id for stability
    # Nodes not reachable from start get infinity depth (sorted to end)
    sorted_nodes = sorted(nodes, key=lambda n: (depths.get(n.id, float('inf')), n.id))

    # Sort edges by source depth, then target depth; edges with identical endpoints keep
    # their input order because `sorted` is stable.
    sorted_edges = sorted(
        edges,
        key=lambda e: (
            depths.get(e.start_id, float('inf')),
            depths.get(e.end_id, float('inf')),
            e.start_id,
            e.end_id,
        ),
    )

    return sorted_nodes, sorted_edges
