Spaces:
Sleeping
Sleeping
import os | |
from typing import List, Dict, Any, Optional | |
from pymongo import MongoClient | |
from pymongo.errors import ( | |
ConnectionFailure, | |
OperationFailure, | |
ServerSelectionTimeoutError, | |
InvalidName | |
) | |
from dotenv import load_dotenv | |
class DatabaseError(Exception):
    """Common base class for the database-related errors defined in this module.

    Catching this type handles both ConnectionError and OperationError as
    defined here.
    """
class ConnectionError(DatabaseError):
    """Raised when a connection to MongoDB Atlas cannot be established.

    NOTE: this name shadows Python's built-in ``ConnectionError`` inside this
    module; it is NOT a subclass of the built-in (or of ``OSError``), so
    ``except ConnectionError`` written against the built-in will not catch it.
    """
class OperationError(DatabaseError):
    """Raised when a database operation (listing, counting, reading) fails."""
class DatabaseUtils:
    """Utility class for MongoDB Atlas database operations.

    This class provides methods to interact with MongoDB Atlas databases and
    collections, including listing databases, collections, and retrieving
    collection information. Instances may also be used in a ``with`` statement,
    which calls ``close()`` on exit.

    Attributes:
        atlas_uri (str): MongoDB Atlas connection string
        client (MongoClient): MongoDB client instance
    """

    def __init__(self):
        """Initialize DatabaseUtils with a MongoDB Atlas connection.

        Loads environment variables via ``load_dotenv()``, reads ``ATLAS_URI``,
        creates a ``MongoClient`` and verifies connectivity with a ``ping``.

        Raises:
            ValueError: If ATLAS_URI environment variable is not set
            ConnectionError: If unable to connect to MongoDB Atlas
        """
        # Load environment variables
        load_dotenv()
        self.atlas_uri = os.getenv("ATLAS_URI")
        if not self.atlas_uri:
            raise ValueError("ATLAS_URI environment variable is not set")
        try:
            self.client = MongoClient(self.atlas_uri)
            # Force a server round trip so connectivity problems surface
            # here rather than on the first real query.
            self.client.admin.command('ping')
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            # The client was (possibly) created before the ping failed;
            # release it so a failed construction does not leak it.
            self.close()
            raise ConnectionError(f"Failed to connect to MongoDB Atlas: {str(e)}") from e
        except Exception:
            # Any other failure also leaves an open client behind; release
            # it and propagate the original exception unchanged.
            self.close()
            raise

    def __enter__(self) -> "DatabaseUtils":
        """Return this instance for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the client when leaving a ``with`` block; never suppresses errors."""
        self.close()

    @staticmethod
    def _validate_name(value: Any, label: str) -> None:
        """Raise ValueError unless ``value`` is a non-empty string.

        Args:
            value: Candidate database/collection name
            label: "Database" or "Collection", used in the error message

        Raises:
            ValueError: If value is not a str or is empty
        """
        # Type check first so truthiness is only evaluated on real strings.
        if not isinstance(value, str) or not value:
            raise ValueError(f"{label} name must be a non-empty string")

    def get_databases(self) -> List[str]:
        """Get list of all databases in Atlas cluster.

        Returns:
            List[str]: List of database names

        Raises:
            OperationError: If unable to list databases
        """
        try:
            return self.client.list_database_names()
        except OperationFailure as e:
            raise OperationError(f"Failed to list databases: {str(e)}") from e

    def get_collections(self, db_name: str) -> List[str]:
        """Get list of collections in a database.

        Args:
            db_name (str): Name of the database

        Returns:
            List[str]: List of collection names

        Raises:
            OperationError: If unable to list collections
            ValueError: If db_name is empty or invalid
        """
        self._validate_name(db_name, "Database")
        try:
            db = self.client[db_name]
            return db.list_collection_names()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to list collections for database '{db_name}': {str(e)}"
            ) from e

    def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
        """Get information about a collection including document count and sample document.

        Args:
            db_name (str): Name of the database
            collection_name (str): Name of the collection

        Returns:
            Dict[str, Any]: Dictionary containing collection information:
                - count: Number of documents in collection
                - sample: Result of ``find_one()`` on the collection
                  (``None`` if the collection is empty)

        Raises:
            OperationError: If unable to get collection information
            ValueError: If db_name or collection_name is empty or invalid
        """
        self._validate_name(db_name, "Database")
        self._validate_name(collection_name, "Collection")
        try:
            collection = self.client[db_name][collection_name]
            return {
                'count': collection.count_documents({}),
                'sample': collection.find_one()
            }
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get info for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            ) from e

    def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
        """Get list of fields in a collection based on a sample document.

        Only the top-level keys of the single document returned by
        ``get_collection_info`` are inspected.

        Args:
            db_name (str): Name of the database
            collection_name (str): Name of the collection

        Returns:
            List[str]: Field names in sample-document order, excluding ``_id``
                and any field ending in ``_embedding``; empty if the
                collection has no documents

        Raises:
            OperationError: If unable to get field names
            ValueError: If db_name or collection_name is empty or invalid
        """
        try:
            info = self.get_collection_info(db_name, collection_name)
        except DatabaseError as e:
            raise OperationError(
                f"Failed to get field names for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            ) from e
        sample = info.get('sample', {})
        if not sample:
            return []
        # Skip the primary key and any previously generated embedding fields.
        return [field for field in sample.keys()
                if field != '_id' and not field.endswith('_embedding')]

    def close(self):
        """Close MongoDB connection safely (no-op if no client was created)."""
        if hasattr(self, 'client'):
            self.client.close()