"""
Utility functions for managing user puja sessions in the database.
Phone numbers are stored in each puja session's collected_info JSON field (not in a separate table).
"""
import logging
from datetime import datetime
from typing import Optional, Dict, Any, List
from sqlmodel import Session, select, and_, or_
from app.db import engine
from app.schemas import UserPujaSession

logger = logging.getLogger(__name__)


# ============================================================================
# PUJA SESSION FUNCTIONS
# ============================================================================

def get_active_puja_session(user_id: str, session_id: str) -> Optional[UserPujaSession]:
    """
    Fetch the user's not-yet-completed puja within the given session.

    When several incomplete pujas exist for the user+session pair, the one with
    the most recent ``updated_at`` is returned. Returns None if there is none,
    or if the query fails (the failure is logged with only the exception type
    name).
    """
    try:
        filters = (
            UserPujaSession.user_id == user_id,
            UserPujaSession.session_id == session_id,
            # `== False` builds a SQL comparison; `not column` would not.
            UserPujaSession.completed == False,  # noqa: E712
        )
        newest_first = UserPujaSession.updated_at.desc()
        with Session(engine) as db:
            query = select(UserPujaSession).where(and_(*filters)).order_by(newest_first)
            return db.exec(query).first()
    except Exception as e:
        logger.error(f"Error getting active puja session: {type(e).__name__}")
        return None


def get_puja_session(user_id: str, session_id: str, puja_type: str) -> Optional[UserPujaSession]:
    """
    Look up the puja session matching the given user, session and puja type.

    If more than one row matches, the one with the latest ``updated_at`` wins.
    Returns None when nothing matches or the query fails (the failure is
    logged with only the exception type name).
    """
    try:
        criteria = (
            UserPujaSession.user_id == user_id,
            UserPujaSession.session_id == session_id,
            UserPujaSession.puja_type == puja_type,
        )
        newest_first = UserPujaSession.updated_at.desc()
        with Session(engine) as db:
            query = select(UserPujaSession).where(and_(*criteria)).order_by(newest_first)
            return db.exec(query).first()
    except Exception as e:
        logger.error(f"Error getting puja session: {type(e).__name__}")
        return None


def create_puja_session(
    user_id: str,
    session_id: str,
    puja_type: str,
    puja_name: Optional[str] = None
) -> Optional[UserPujaSession]:
    """
    Create and persist a new puja session.

    Args:
        user_id: Owner of the puja session.
        session_id: Session the puja belongs to.
        puja_type: Puja type identifier.
        puja_name: Optional display name for the puja.

    Returns:
        The committed and refreshed UserPujaSession, or None if the insert
        failed (the failure is logged with only the exception type name).
    """
    try:
        with Session(engine) as session:
            # Use a single timestamp so a brand-new row has identical
            # created_at and updated_at (two now() calls differ slightly).
            now = datetime.now()
            puja_session = UserPujaSession(
                user_id=user_id,
                session_id=session_id,
                puja_type=puja_type,
                puja_name=puja_name,
                created_at=now,
                updated_at=now
            )
            session.add(puja_session)
            session.commit()
            # Reload DB-populated fields before the session closes.
            session.refresh(puja_session)

            logger.debug("Created puja session")
            return puja_session
    except Exception as e:
        logger.error(f"Error creating puja session: {type(e).__name__}")
        return None


def update_puja_session(
    user_id: str,
    session_id: str,
    puja_type: str,
    updates: Dict[str, Any]
) -> bool:
    """
    Apply ``updates`` to the puja session identified by user, session and puja type.

    Only keys that name an existing attribute of the model are applied; other
    keys are silently skipped. ``updated_at`` is always refreshed.

    If several rows match, the most recently updated one is modified. This is
    the same row ``get_puja_session`` returns, so reads and writes agree.

    Args:
        user_id: Owner of the puja session.
        session_id: Session the puja belongs to.
        puja_type: Puja type identifier.
        updates: Mapping of attribute name -> new value.

    Returns:
        True if a session was found and committed; False if none matched or an
        error occurred (logged with only the exception type name).
    """
    try:
        with Session(engine) as session:
            statement = (
                select(UserPujaSession)
                .where(
                    and_(
                        UserPujaSession.user_id == user_id,
                        UserPujaSession.session_id == session_id,
                        UserPujaSession.puja_type == puja_type
                    )
                )
                # Deterministic target when duplicates exist; matches get_puja_session.
                .order_by(UserPujaSession.updated_at.desc())
            )
            puja_session = session.exec(statement).first()

            if not puja_session:
                logger.debug("Puja session not found for update")
                return False

            # Update only fields the model actually has
            for key, value in updates.items():
                if hasattr(puja_session, key):
                    setattr(puja_session, key, value)

            puja_session.updated_at = datetime.now()

            session.add(puja_session)
            session.commit()

            logger.debug("Updated puja session")
            return True
    except Exception as e:
        logger.error(f"Error updating puja session: {type(e).__name__}")
        return False


def delete_puja_session(user_id: str, session_id: str, puja_type: str) -> bool:
    """
    Delete a puja session (e.g., when user starts a new puja in same session).

    Only one row is deleted. If several rows match, it is the most recently
    updated one, i.e. the same row ``get_puja_session`` would return.

    Returns:
        True if a row was deleted; False if none matched or an error occurred
        (logged with only the exception type name).
    """
    try:
        with Session(engine) as session:
            statement = (
                select(UserPujaSession)
                .where(
                    and_(
                        UserPujaSession.user_id == user_id,
                        UserPujaSession.session_id == session_id,
                        UserPujaSession.puja_type == puja_type
                    )
                )
                # Deterministic target when duplicates exist; matches get_puja_session.
                .order_by(UserPujaSession.updated_at.desc())
            )
            puja_session = session.exec(statement).first()

            if puja_session:
                session.delete(puja_session)
                session.commit()
                logger.debug("Deleted puja session")
                return True
            return False
    except Exception as e:
        logger.error(f"Error deleting puja session: {type(e).__name__}")
        return False


def get_all_user_puja_sessions(user_id: str, session_id: Optional[str] = None) -> list:
    """
    List every puja session belonging to ``user_id``, most recently updated first.

    A truthy ``session_id`` restricts the result to that session; None or an
    empty string returns the user's sessions across all session_ids. Returns an
    empty list if the query fails (logged with only the exception type name).
    """
    try:
        with Session(engine) as db:
            query = select(UserPujaSession).where(UserPujaSession.user_id == user_id)
            if session_id:
                # Chained .where() calls are ANDed together.
                query = query.where(UserPujaSession.session_id == session_id)
            query = query.order_by(UserPujaSession.updated_at.desc())
            return db.exec(query).all()
    except Exception as e:
        logger.error(f"Error getting all puja sessions: {type(e).__name__}")
        return []


def mark_puja_completed(user_id: str, session_id: str, puja_type: str) -> bool:
    """
    Flag the matching puja session as completed.

    Delegates to ``update_puja_session``; returns its result (True on success).
    """
    completion_patch = {"completed": True}
    return update_puja_session(user_id, session_id, puja_type, updates=completion_patch)


def puja_session_to_dict(puja_session: "UserPujaSession") -> Dict[str, Any]:
    """
    Convert a UserPujaSession model to a plain dictionary (for backward compatibility).

    Args:
        puja_session: The model instance to convert. A falsy value (e.g. None)
            yields an empty dict.

    Returns:
        Dict of the session's fields. ``created_at`` and ``updated_at`` are
        ISO-8601 strings, or None when the timestamp is unset.
    """
    if not puja_session:
        return {}

    # Guard the timestamps: calling .isoformat() on an unset (None) value would raise.
    created_at = puja_session.created_at
    updated_at = puja_session.updated_at

    return {
        "user_id": puja_session.user_id,
        "session_id": puja_session.session_id,
        "puja_type": puja_session.puja_type,
        "puja_name": puja_session.puja_name,
        "current_question": puja_session.current_question,
        "collected_info": puja_session.collected_info,
        "completed": puja_session.completed,
        "kb_checked": puja_session.kb_checked,
        "invalid_attempts": puja_session.invalid_attempts,
        "awaiting_phone": puja_session.awaiting_phone,
        "phone_collected": puja_session.phone_collected,
        "created_at": created_at.isoformat() if created_at is not None else None,
        "updated_at": updated_at.isoformat() if updated_at is not None else None
    }


# ============================================================================
# PHONE NUMBER FUNCTIONS
# ============================================================================

def save_user_phone_number(
    user_id: str,
    session_id: str,
    phone_number: str,
    puja_type: str,
    puja_id: int = None,
    phone_country_code: str = "+91"
) -> bool:
    """
    Save user's phone number into the puja session's collected_info JSON field.

    The phone number is merged into the existing collected_info, so details
    already collected for the puja are preserved. ``phone_collected`` is set
    to True.

    NOTE: Phone numbers are stored in collected_info, not a separate table.
    ``puja_id`` and ``phone_country_code`` are currently unused; they are kept
    so existing callers that pass them continue to work.

    Returns:
        True if the session was found and updated, False otherwise (errors are
        logged with only the exception type name).
    """
    try:
        puja_session = get_puja_session(user_id, session_id, puja_type)
        if puja_session is None:
            logger.debug("Failed to update phone - session not found")
            return False

        # Merge rather than replace: writing {"phone_number": ...} alone would
        # wipe every other field in collected_info. Copy into a new dict instead
        # of mutating the loaded object's dict in place.
        existing_info = puja_session.collected_info
        merged_info = dict(existing_info) if isinstance(existing_info, dict) else {}
        merged_info["phone_number"] = phone_number

        update_result = update_puja_session(
            user_id=user_id,
            session_id=session_id,
            puja_type=puja_type,
            updates={
                "collected_info": merged_info,
                "phone_collected": True
            }
        )

        if update_result:
            return True
        logger.debug("Failed to update phone - session not found")
        return False
    except Exception as e:
        logger.error(f"Error saving phone: {type(e).__name__}")
        return False


def get_user_phone_number(user_id: str, session_id: str, puja_type: str) -> Optional[str]:
    """
    Read the phone number saved in a puja session's collected_info.

    The session is resolved with ``get_puja_session``.

    Returns:
        The stored ``phone_number`` value, or None if the session is missing,
        its collected_info is empty, the key is absent, or an error occurs
        (logged with only the exception type name).
    """
    try:
        record = get_puja_session(user_id, session_id, puja_type)
        if not record or not record.collected_info:
            return None
        return record.collected_info.get("phone_number")
    except Exception as e:
        logger.error(f"Error getting phone: {type(e).__name__}")
        return None


def get_all_user_phones(user_id: str) -> List[str]:
    """
    Get all phone numbers for a user across all puja sessions.

    Sessions are scanned newest-created first. Each distinct, truthy
    ``phone_number`` found in collected_info is kept once, in first-seen order.

    Returns:
        List of unique phone numbers, or an empty list on error (logged with
        only the exception type name).
    """
    try:
        with Session(engine) as session:
            statement = (
                select(UserPujaSession)
                .where(UserPujaSession.user_id == user_id)
                .order_by(UserPujaSession.created_at.desc())
            )
            sessions = session.exec(statement).all()

            phones: List[str] = []
            # Set gives O(1) duplicate checks; the list preserves first-seen order.
            seen = set()
            for puja_session in sessions:
                info = puja_session.collected_info
                if not info or "phone_number" not in info:
                    continue
                phone = info["phone_number"]
                if phone and phone not in seen:
                    seen.add(phone)
                    phones.append(phone)

            return phones
    except Exception as e:
        logger.error(f"Error getting phones: {type(e).__name__}")
        return []


# ============================================================================
# PHONE NUMBER VALIDATION
# ============================================================================

def clean_indian_phone_number(phone: str) -> Optional[str]:
    """
    Clean and normalize Indian phone number to 10-digit format.
    Handles various formats and validates against Indian mobile number rules.

    Accepts formats:
    - 10 digits: 9876543210
    - With country code: +91-9876543210, 0091-9876543210, 91-9876543210
    - With leading zero: 09876543210
    - With spaces/dashes/parentheses: (987) 654-3210
    - Non-ASCII decimal digits (e.g. Devanagari ९८७६५४३२१०, full-width digits),
      which are converted to ASCII
    - Non-string values (e.g. an int), which are converted with str()

    Args:
        phone: Phone number to clean.

    Returns:
        10-digit ASCII phone number string if valid, None otherwise.
    """
    import re

    # Handle None or placeholder values
    if not phone or phone == "@phone":
        return None
    if not isinstance(phone, str):
        phone = str(phone)

    # Map every Unicode decimal digit to its ASCII form. Otherwise a regex `\d`
    # would accept e.g. Devanagari digits and return them un-normalized.
    normalized = "".join(
        str(int(ch)) if ch.isdecimal() else ch for ch in phone.strip()
    )

    # Remove all spaces, dashes, and parentheses
    cleaned = re.sub(r'[\s\-\(\)]', '', normalized)

    # Remove country code / trunk prefix if present
    if cleaned.startswith('+91'):
        cleaned = cleaned[3:]
    elif cleaned.startswith('0091'):
        cleaned = cleaned[4:]
    elif cleaned.startswith('91') and len(cleaned) == 12:
        cleaned = cleaned[2:]
    elif cleaned.startswith('0') and len(cleaned) == 11:
        cleaned = cleaned[1:]

    # Exactly 10 ASCII digits starting with 6-9 (valid Indian mobile prefixes).
    # fullmatch avoids `$` also matching before a trailing newline.
    if re.fullmatch(r'[6-9][0-9]{9}', cleaned):
        return cleaned

    return None


def validate_indian_phone_number(phone: str) -> bool:
    """
    Tell whether ``phone`` is a valid Indian mobile number.

    A number is valid exactly when ``clean_indian_phone_number`` can normalize
    it. It therefore accepts whatever that function accepts, e.g. bare 10
    digits, +91 / 0091 / 91 / leading-0 prefixes, and spaces, dashes or
    parentheses as separators.

    Args:
        phone: Phone number string to check.

    Returns:
        True if the number normalizes successfully, False otherwise.
    """
    normalized = clean_indian_phone_number(phone)
    return normalized is not None
