import os
from typing import List, Dict, Any, Optional
from pymongo import MongoClient
from pymongo.errors import (
    ConnectionFailure, 
    OperationFailure, 
    ServerSelectionTimeoutError,
    InvalidName
)
from dotenv import load_dotenv

class DatabaseError(Exception):
    """Common base class for errors raised by database operations."""

class ConnectionError(DatabaseError):
    """Raised when connecting to MongoDB Atlas fails.

    Note: within this module the name shadows Python's built-in
    ``ConnectionError``; this class derives from ``DatabaseError`` only.
    """

class OperationError(DatabaseError):
    """Raised when a database operation fails."""

class DatabaseUtils:
    """Utility class for MongoDB Atlas database operations.

    Provides helpers for listing databases and collections and for
    retrieving basic collection information (document count, one sample
    document, and the field names of that sample).

    Instances can be used as context managers; the client is closed when
    the ``with`` block exits::

        with DatabaseUtils() as utils:
            names = utils.get_databases()

    Attributes:
        atlas_uri (str): MongoDB Atlas connection string, read from the
            ``ATLAS_URI`` environment variable.
        client (MongoClient): MongoDB client instance.
    """

    def __init__(self):
        """Initialize DatabaseUtils with a MongoDB Atlas connection.

        Loads environment variables via ``load_dotenv()``, reads
        ``ATLAS_URI``, creates the client and pings the ``admin`` database
        to verify that the server is reachable.

        Raises:
            ConnectionError: If unable to connect to MongoDB Atlas.
            ValueError: If ATLAS_URI environment variable is not set.
        """
        load_dotenv()

        self.atlas_uri = os.getenv("ATLAS_URI")
        if not self.atlas_uri:
            raise ValueError("ATLAS_URI environment variable is not set")

        try:
            self.client = MongoClient(self.atlas_uri)
            # Issue a ping so connection problems surface here rather than
            # on the first real operation.
            self.client.admin.command('ping')
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            # The constructor is about to fail, so nobody else can close the
            # client; release it here instead of leaking it. close() is a
            # no-op if the client was never assigned.
            self.close()
            raise ConnectionError(
                f"Failed to connect to MongoDB Atlas: {e}"
            ) from e

    def __enter__(self) -> "DatabaseUtils":
        """Return this instance so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the client on leaving the ``with`` block.

        Returns None, so an exception raised inside the block propagates.
        """
        self.close()

    @staticmethod
    def _validate_name(value: Any, label: str) -> None:
        """Ensure ``value`` is a non-empty string.

        Args:
            value: The candidate name to check.
            label (str): Prefix for the error message ("Database" or
                "Collection").

        Raises:
            ValueError: If ``value`` is falsy or not a ``str``.
        """
        if not value or not isinstance(value, str):
            raise ValueError(f"{label} name must be a non-empty string")

    def get_databases(self) -> List[str]:
        """Get list of all databases in the Atlas cluster.

        Returns:
            List[str]: List of database names.

        Raises:
            OperationError: If the server rejects the request
                (``OperationFailure``). Other pymongo errors are not
                caught here and propagate unchanged.
        """
        try:
            return self.client.list_database_names()
        except OperationFailure as e:
            raise OperationError(f"Failed to list databases: {e}") from e

    def get_collections(self, db_name: str) -> List[str]:
        """Get list of collections in a database.

        Args:
            db_name (str): Name of the database.

        Returns:
            List[str]: List of collection names.

        Raises:
            OperationError: If unable to list collections
                (``OperationFailure`` or ``InvalidName``).
            ValueError: If db_name is empty or not a string.
        """
        self._validate_name(db_name, "Database")

        try:
            return self.client[db_name].list_collection_names()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to list collections for database '{db_name}': {e}"
            ) from e

    def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
        """Get a collection's document count and one sample document.

        Args:
            db_name (str): Name of the database.
            collection_name (str): Name of the collection.

        Returns:
            Dict[str, Any]: Dictionary containing collection information:
                - count: Result of ``count_documents({})`` on the collection.
                - sample: Result of ``find_one()`` on the collection.

        Raises:
            OperationError: If unable to get collection information
                (``OperationFailure`` or ``InvalidName``).
            ValueError: If db_name or collection_name is empty or not a
                string.
        """
        self._validate_name(db_name, "Database")
        self._validate_name(collection_name, "Collection")

        try:
            collection = self.client[db_name][collection_name]
            return {
                'count': collection.count_documents({}),
                'sample': collection.find_one(),
            }
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get info for collection '{collection_name}' "
                f"in database '{db_name}': {e}"
            ) from e

    def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
        """Get list of fields in a collection based on its sample document.

        Only the keys of the single sample document returned by
        ``get_collection_info`` are inspected.

        Args:
            db_name (str): Name of the database.
            collection_name (str): Name of the collection.

        Returns:
            List[str]: Field names of the sample document, excluding ``_id``
                and any field ending in ``_embedding``. Empty when there is
                no sample document.

        Raises:
            OperationError: If unable to get field names.
            ValueError: If db_name or collection_name is empty or not a
                string (raised by ``get_collection_info``).
        """
        # Only the lookup can raise DatabaseError, so keep the try minimal.
        try:
            info = self.get_collection_info(db_name, collection_name)
        except DatabaseError as e:
            raise OperationError(
                f"Failed to get field names for collection '{collection_name}' "
                f"in database '{db_name}': {e}"
            ) from e

        sample = info.get('sample')
        if not sample:
            return []

        # Skip the primary key and any existing embedding fields.
        return [
            field for field in sample
            if field != '_id' and not field.endswith('_embedding')
        ]

    def close(self) -> None:
        """Close the MongoDB connection safely.

        Does nothing if ``client`` was never assigned (for example when
        ``__init__`` failed before creating it).
        """
        if hasattr(self, 'client'):
            self.client.close()