airabbitX's picture
Upload 9 files
8fb6e2f verified
import os
from typing import List, Dict, Any, Optional
from pymongo import MongoClient
from pymongo.errors import (
ConnectionFailure,
OperationFailure,
ServerSelectionTimeoutError,
InvalidName
)
from dotenv import load_dotenv
class DatabaseError(Exception):
    """Root exception for database failures raised by this module.

    ConnectionError and OperationError derive from this class, so catching
    DatabaseError handles both connection and operation failures.
    """
class ConnectionError(DatabaseError):
    """Raised when a connection to MongoDB Atlas cannot be established.

    NOTE: this name shadows the builtin ``ConnectionError`` inside this
    module. It derives from DatabaseError, not from the builtin, so an
    ``except ConnectionError`` that refers to the builtin will not catch it.
    """
class OperationError(DatabaseError):
    """Raised when a database operation (listing, counting, reading) fails."""
class DatabaseUtils:
    """Utility class for MongoDB Atlas database operations.

    This class provides methods to interact with MongoDB Atlas databases and
    collections: listing databases and collections, and retrieving basic
    collection information (document count, a sample document, field names).

    Instances can be used as context managers so the client is closed
    automatically::

        with DatabaseUtils() as db:
            print(db.get_databases())

    Attributes:
        atlas_uri (str): MongoDB Atlas connection string, read from the
            ``ATLAS_URI`` environment variable after ``load_dotenv()``.
        client (MongoClient): MongoDB client instance.
    """

    def __init__(self):
        """Initialize DatabaseUtils with a MongoDB Atlas connection.

        Loads environment variables via ``load_dotenv()``, builds a client
        from ``ATLAS_URI`` and verifies it with a ``ping`` command.

        Raises:
            ConnectionError: If unable to connect to MongoDB Atlas.
            ValueError: If the ATLAS_URI environment variable is not set.
        """
        # Load environment variables (e.g. from a .env file)
        load_dotenv()
        self.atlas_uri = os.getenv("ATLAS_URI")
        if not self.atlas_uri:
            raise ValueError("ATLAS_URI environment variable is not set")
        try:
            self.client = MongoClient(self.atlas_uri)
            # Round-trip to the server so connection problems surface here.
            self.client.admin.command('ping')
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            # The constructor is failing, so no caller will ever hold this
            # instance to close it; release the client now (close() is a
            # no-op if the client was never created).
            self.close()
            raise ConnectionError(f"Failed to connect to MongoDB Atlas: {str(e)}") from e

    def __enter__(self) -> "DatabaseUtils":
        """Return this instance for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the client on leaving a ``with`` block; exceptions propagate."""
        self.close()

    def get_databases(self) -> List[str]:
        """Get the list of all databases in the Atlas cluster.

        Returns:
            List[str]: List of database names.

        Raises:
            OperationError: If the server rejects the listing
                (pymongo ``OperationFailure``). Other pymongo errors, such as
                connection failures, propagate unchanged.
        """
        try:
            return self.client.list_database_names()
        except OperationFailure as e:
            raise OperationError(f"Failed to list databases: {str(e)}") from e

    def get_collections(self, db_name: str) -> List[str]:
        """Get the list of collections in a database.

        Args:
            db_name (str): Name of the database.

        Returns:
            List[str]: List of collection names.

        Raises:
            OperationError: If the name is rejected by pymongo or the listing
                fails on the server.
            ValueError: If db_name is empty or not a string.
        """
        if not db_name or not isinstance(db_name, str):
            raise ValueError("Database name must be a non-empty string")
        try:
            db = self.client[db_name]
            return db.list_collection_names()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to list collections for database '{db_name}': {str(e)}"
            ) from e

    def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
        """Get the document count and one sample document of a collection.

        Args:
            db_name (str): Name of the database.
            collection_name (str): Name of the collection.

        Returns:
            Dict[str, Any]: Dictionary with:
                - ``count``: number of documents in the collection
                - ``sample``: the document returned by ``find_one()``, or
                  None when no document is found

        Raises:
            OperationError: If unable to get collection information.
            ValueError: If db_name or collection_name is empty or not a string.
        """
        if not db_name or not isinstance(db_name, str):
            raise ValueError("Database name must be a non-empty string")
        if not collection_name or not isinstance(collection_name, str):
            raise ValueError("Collection name must be a non-empty string")
        try:
            db = self.client[db_name]
            collection = db[collection_name]
            return {
                'count': collection.count_documents({}),
                'sample': collection.find_one()
            }
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get info for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            ) from e

    def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
        """Get the top-level field names of a collection's sample document.

        Only the single document returned by ``get_collection_info`` is
        inspected, so fields absent from that document are not reported.

        Args:
            db_name (str): Name of the database.
            collection_name (str): Name of the collection.

        Returns:
            List[str]: Field names in document order, excluding ``_id`` and
            any field ending in ``_embedding``. Empty if the collection has
            no sample document.

        Raises:
            OperationError: If unable to get field names.
            ValueError: If db_name or collection_name is empty or not a string
                (propagated from get_collection_info, not wrapped).
        """
        try:
            info = self.get_collection_info(db_name, collection_name)
        except DatabaseError as e:
            raise OperationError(
                f"Failed to get field names for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            ) from e
        sample = info.get('sample', {})
        if not sample:
            return []
        # Skip the primary key and any previously generated embedding fields.
        return [field for field in sample
                if field != '_id' and not field.endswith('_embedding')]

    def close(self):
        """Close the MongoDB connection if a client was created."""
        if hasattr(self, 'client'):
            self.client.close()