import uuid
from datetime import datetime
from typing import Any, Dict, Optional

from sqlalchemy import JSON, Text, UniqueConstraint
from sqlmodel import Column, Field, SQLModel


def generate_session_id() -> str:
    """Return a freshly generated random (version 4) UUID as a string.

    Used as the ``default_factory`` for ``ChatHistory.session_id``.
    """
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)


class ChatHistory(SQLModel, table=True):
    """A single chat message row, mapped to the ``ai_chat_history`` table.

    The column names and lengths mirror the MySQL schema, so field order
    and names must stay in sync with it.
    """

    __tablename__ = "ai_chat_history"

    # Integer primary key; defaults to None on a new instance.
    id: Optional[int] = Field(primary_key=True, default=None)

    # Indexed identifier of the user the message belongs to.
    user_id: str = Field(index=True, max_length=50)

    # Conversation identifier; a new UUID string is generated when omitted.
    session_id: str = Field(
        max_length=50,
        default_factory=generate_session_id,
    )

    # Indexed timestamp; defaults to the local (naive) time at instantiation.
    timestamp: datetime = Field(index=True, default_factory=datetime.now)

    # Message author role: 'user' or 'assistant'.
    type: str = Field(max_length=20)

    # Message body, stored in a TEXT column.
    message: str = Field(sa_column=Column(Text))

    # Conversation context label; "general" unless specified.
    context: str = Field(max_length=50, default="general")


class UserPujaSession(SQLModel, table=True):
    """
    User puja session model for tracking puja bookings.

    Key Features:
    - UNIQUE constraint on (user_id, session_id, puja_type), declared in
      ``__table_args__``, ensures ONE row per combination
    - Phone number stored in collected_info JSON as 'phone_number' key
    - All booking data (name, location, date, phone, etc.) stored in collected_info
    - ``updated_at`` is refreshed automatically on every UPDATE of the row
      (unless the caller sets it explicitly)
    """
    __tablename__ = "ai_user_puja_sessions"

    # Declare the documented uniqueness rule on the model itself so that
    # tables created from this metadata actually enforce it.
    # NOTE(review): if the table already exists in MySQL with this constraint
    # under a different name, align `name` with it — TODO confirm.
    __table_args__ = (
        UniqueConstraint(
            "user_id",
            "session_id",
            "puja_type",
            name="uq_ai_user_puja_sessions_user_session_puja",
        ),
    )

    id: Optional[int] = Field(default=None, primary_key=True)
    user_id: str = Field(max_length=50)
    session_id: str = Field(max_length=100)
    puja_type: str = Field(max_length=50)
    puja_name: Optional[str] = Field(default=None, max_length=100)

    # Puja state
    current_question: int = Field(default=0)
    collected_info: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON))
    # collected_info contains: name, location, date, phone_number, muhurtam, additional_context, etc.
    # NOTE(review): a plain JSON column does not track in-place dict mutation;
    # callers presumably reassign the dict (or use flag_modified) — verify.

    completed: bool = Field(default=False)
    kb_checked: bool = Field(default=False)
    invalid_attempts: int = Field(default=0)

    # Phone collection
    awaiting_phone: bool = Field(default=False)
    phone_collected: bool = Field(default=False)

    # Timestamps (naive local time, as produced by datetime.now)
    created_at: datetime = Field(default_factory=datetime.now)
    # `default_factory` alone only fires on instantiation; `onupdate` makes
    # SQLAlchemy refresh the value whenever the row is UPDATEd.
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column_kwargs={"onupdate": datetime.now},
    )
