"""Helpers for inspecting databases and collections on a MongoDB Atlas cluster."""

import os
from typing import List, Dict, Any, Optional

from pymongo import MongoClient
from pymongo.errors import (
    ConnectionFailure,
    OperationFailure,
    ServerSelectionTimeoutError,
    InvalidName
)
from dotenv import load_dotenv


class DatabaseError(Exception):
    """Base class for database operation errors"""
    pass


# NOTE: this name shadows the builtin ``ConnectionError`` inside this module.
# It is kept as-is because callers import and catch it under this name.
class ConnectionError(DatabaseError):
    """Error when connecting to MongoDB Atlas"""
    pass


class OperationError(DatabaseError):
    """Error during database operations"""
    pass


class DatabaseUtils:
    """Utility class for MongoDB Atlas database operations

    This class provides methods to interact with MongoDB Atlas databases and
    collections, including listing databases, collections, and retrieving
    collection information.

    Instances can be used as context managers; the connection is closed when
    the ``with`` block exits.

    Attributes:
        atlas_uri (str): MongoDB Atlas connection string
        client (MongoClient): MongoDB client instance
    """

    def __init__(self):
        """Initialize DatabaseUtils with MongoDB Atlas connection

        Raises:
            ConnectionError: If unable to connect to MongoDB Atlas
            ValueError: If ATLAS_URI environment variable is not set
        """
        # Load environment variables (a .env file is honoured if present)
        load_dotenv()

        self.atlas_uri = os.getenv("ATLAS_URI")
        if not self.atlas_uri:
            raise ValueError("ATLAS_URI environment variable is not set")

        try:
            self.client = MongoClient(self.atlas_uri)
            # Test connection: force a round trip so failures surface here
            # rather than on the first real operation.
            self.client.admin.command('ping')
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            # Do not leave a half-initialised client open behind us.
            self.close()
            raise ConnectionError(
                f"Failed to connect to MongoDB Atlas: {e}"
            ) from e
        except Exception:
            # Any other failure propagates unchanged, but the client created
            # above (if any) is still released first.
            self.close()
            raise

    def __enter__(self) -> "DatabaseUtils":
        """Return this instance for use in a ``with`` statement."""
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc: Optional[BaseException],
        tb: Any,
    ) -> None:
        """Close the connection on leaving the ``with`` block.

        Exceptions raised inside the block are not suppressed.
        """
        self.close()

    @staticmethod
    def _require_name(value: Any, label: str) -> None:
        """Raise ValueError unless ``value`` is a non-empty string.

        Args:
            value: Candidate database or collection name
            label: Word used at the start of the error message
                ("Database" or "Collection")
        """
        if not value or not isinstance(value, str):
            raise ValueError(f"{label} name must be a non-empty string")

    def get_databases(self) -> List[str]:
        """Get list of all databases in Atlas cluster

        Returns:
            List[str]: List of database names

        Raises:
            OperationError: If unable to list databases
        """
        try:
            return self.client.list_database_names()
        except OperationFailure as e:
            raise OperationError(f"Failed to list databases: {e}") from e

    def get_collections(self, db_name: str) -> List[str]:
        """Get list of collections in a database

        Args:
            db_name (str): Name of the database

        Returns:
            List[str]: List of collection names

        Raises:
            OperationError: If unable to list collections
            ValueError: If db_name is empty or invalid
        """
        self._require_name(db_name, "Database")

        try:
            db = self.client[db_name]
            return db.list_collection_names()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to list collections for database '{db_name}': {e}"
            ) from e

    def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
        """Get information about a collection including document count and sample document

        Args:
            db_name (str): Name of the database
            collection_name (str): Name of the collection

        Returns:
            Dict[str, Any]: Dictionary containing collection information:
                - count: Number of documents in collection
                - sample: Sample document from collection (None if the
                  collection has no documents)

        Raises:
            OperationError: If unable to get collection information
            ValueError: If db_name or collection_name is empty or invalid
        """
        self._require_name(db_name, "Database")
        self._require_name(collection_name, "Collection")

        try:
            db = self.client[db_name]
            collection = db[collection_name]
            return {
                'count': collection.count_documents({}),
                'sample': collection.find_one()
            }
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get info for collection '{collection_name}' "
                f"in database '{db_name}': {e}"
            ) from e

    def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
        """Get list of fields in a collection based on sample document

        Only the top-level keys of a single sample document are inspected, so
        fields that are absent from that document are not reported.

        Args:
            db_name (str): Name of the database
            collection_name (str): Name of the collection

        Returns:
            List[str]: List of field names (excluding _id and embedding fields);
                empty if the collection has no sample document

        Raises:
            OperationError: If unable to get field names
            ValueError: If db_name or collection_name is empty or invalid
        """
        try:
            info = self.get_collection_info(db_name, collection_name)
        except DatabaseError as e:
            raise OperationError(
                f"Failed to get field names for collection '{collection_name}' "
                f"in database '{db_name}': {e}"
            ) from e

        sample = info.get('sample')
        if not sample:
            return []

        # Get all field names except _id and any existing embedding fields
        return [
            field for field in sample
            if field != '_id' and not field.endswith('_embedding')
        ]

    def close(self):
        """Close MongoDB connection safely

        Safe to call even when ``__init__`` failed before ``client`` was set.
        """
        if hasattr(self, 'client'):
            self.client.close()