|
import os |
|
from typing import List, Dict, Any, Optional |
|
from pymongo import MongoClient |
|
from pymongo.errors import ( |
|
ConnectionFailure, |
|
OperationFailure, |
|
ServerSelectionTimeoutError, |
|
InvalidName |
|
) |
|
from dotenv import load_dotenv |
|
|
|
class DatabaseError(Exception):
    """Root of this module's database exception hierarchy.

    Both connection failures and failed database operations derive from
    this class, so callers can catch it to handle either kind.
    """
|
|
|
class ConnectionError(DatabaseError):
    """Raised when a connection to MongoDB Atlas cannot be established.

    NOTE: this name shadows Python's built-in ``ConnectionError`` inside this
    module (and for anyone importing it by name). It derives only from
    ``DatabaseError``, not from the built-in ``OSError`` hierarchy.
    """
|
|
|
class OperationError(DatabaseError):
    """Raised when a database operation (listing, counting, reading) fails."""
|
|
|
class DatabaseUtils:
    """Utility class for MongoDB Atlas database operations.

    Provides methods to list the databases of an Atlas cluster, list the
    collections of a database, and inspect a collection (document count, a
    sample document, and the top-level field names of that sample).

    An instance owns an open ``MongoClient``. Call :meth:`close` when done,
    or use the instance as a context manager::

        with DatabaseUtils() as db:
            names = db.get_databases()

    Attributes:
        atlas_uri (str): MongoDB Atlas connection string read from the
            ``ATLAS_URI`` environment variable (may embed credentials).
        client (MongoClient): MongoDB client instance.
    """

    def __init__(self):
        """Initialize DatabaseUtils with a verified MongoDB Atlas connection.

        Loads variables from a ``.env`` file via ``load_dotenv``, reads
        ``ATLAS_URI``, creates a client and pings the ``admin`` database to
        confirm the cluster responds.

        Raises:
            ValueError: If the ATLAS_URI environment variable is not set.
            ConnectionError: If the ping fails with a connection error or is
                rejected by the server with an ``OperationFailure``. The
                client created for the attempt is closed before raising.
        """
        load_dotenv()

        self.atlas_uri = os.getenv("ATLAS_URI")
        if not self.atlas_uri:
            raise ValueError("ATLAS_URI environment variable is not set")

        client = None
        try:
            client = MongoClient(self.atlas_uri)
            # The ping forces a round trip to the server so that connectivity
            # problems surface here instead of on the first real query.
            client.admin.command('ping')
        except (ConnectionFailure, ServerSelectionTimeoutError, OperationFailure) as e:
            # The constructor is about to fail, so no caller could ever
            # close() this client; release it here to avoid leaking it.
            if client is not None:
                client.close()
            raise ConnectionError(f"Failed to connect to MongoDB Atlas: {str(e)}") from e

        self.client = client

    def __enter__(self) -> "DatabaseUtils":
        """Return this instance for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb) -> None:
        """Close the client on leaving the ``with`` block; never suppresses errors."""
        self.close()

    @staticmethod
    def _validate_name(value: Any, label: str) -> None:
        """Raise ValueError unless *value* is a non-empty string.

        Args:
            value: Candidate database or collection name.
            label: Human-readable kind of name ("Database", "Collection")
                used as the first word of the error message.

        Raises:
            ValueError: If *value* is not a ``str`` or is empty.
        """
        if not isinstance(value, str) or not value:
            raise ValueError(f"{label} name must be a non-empty string")

    def get_databases(self) -> List[str]:
        """Get the names of all databases in the Atlas cluster.

        Returns:
            List[str]: List of database names.

        Raises:
            OperationError: If the server rejects the listing.
        """
        try:
            return self.client.list_database_names()
        except OperationFailure as e:
            raise OperationError(f"Failed to list databases: {str(e)}") from e

    def get_collections(self, db_name: str) -> List[str]:
        """Get the names of the collections in a database.

        Args:
            db_name (str): Name of the database.

        Returns:
            List[str]: List of collection names.

        Raises:
            ValueError: If db_name is not a non-empty string.
            OperationError: If the name is rejected by the driver
                (``InvalidName``) or the listing fails on the server.
        """
        self._validate_name(db_name, "Database")

        try:
            return self.client[db_name].list_collection_names()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to list collections for database '{db_name}': {str(e)}"
            ) from e

    def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
        """Get a collection's document count and one sample document.

        Args:
            db_name (str): Name of the database.
            collection_name (str): Name of the collection.

        Returns:
            Dict[str, Any]: Dictionary with keys:
                - count: Number of documents in the collection.
                - sample: A document from the collection, or None if it is
                  empty (``find_one`` result).

        Raises:
            ValueError: If db_name or collection_name is not a non-empty string.
            OperationError: If a name is rejected by the driver or the
                count/read fails on the server.
        """
        self._validate_name(db_name, "Database")
        self._validate_name(collection_name, "Collection")

        try:
            collection = self.client[db_name][collection_name]
            return {
                'count': collection.count_documents({}),
                'sample': collection.find_one()
            }
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get info for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            ) from e

    def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
        """Get the top-level field names of one sample document.

        Only a single document (the ``sample`` from
        :meth:`get_collection_info`) is inspected, so fields that appear
        only in other documents are not reported.

        Args:
            db_name (str): Name of the database.
            collection_name (str): Name of the collection.

        Returns:
            List[str]: Field names in document order, excluding ``_id`` and
            any field whose name ends with ``_embedding``. Empty if the
            collection has no documents.

        Raises:
            ValueError: If db_name or collection_name is not a non-empty string.
            OperationError: If fetching the sample document fails.
        """
        try:
            sample = self.get_collection_info(db_name, collection_name).get('sample')
        except DatabaseError as e:
            raise OperationError(
                f"Failed to get field names for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            ) from e

        if not sample:
            return []

        # `*_embedding` fields are presumably stored vector embeddings; they
        # are excluded together with the `_id` primary key.
        return [field for field in sample
                if field != '_id' and not field.endswith('_embedding')]

    def close(self):
        """Close the MongoDB client if one was created; otherwise do nothing."""
        if hasattr(self, 'client'):
            self.client.close()
|
|