Spaces:
Sleeping
Sleeping
File size: 5,806 Bytes
46a6768 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import os
from typing import List, Dict, Any, Optional
from pymongo import MongoClient
from pymongo.errors import (
ConnectionFailure,
OperationFailure,
ServerSelectionTimeoutError,
InvalidName
)
from dotenv import load_dotenv
class DatabaseError(Exception):
    """Root of the database exception hierarchy used by this module.

    Both connection failures and failed database operations derive from
    this class, so callers can catch it to handle either kind.
    """
class ConnectionError(DatabaseError):
    """Raised when a connection to MongoDB Atlas cannot be established.

    Note: this name shadows Python's built-in ``ConnectionError`` within
    this module, so ``raise ConnectionError(...)`` here raises this class.
    """
class OperationError(DatabaseError):
    """Raised when a database operation (listing, counting, reading) fails."""
class DatabaseUtils:
    """Utility class for MongoDB Atlas database operations.

    Provides methods to list databases and collections, inspect a
    collection (document count and a sample document) and derive field
    names from a sample document. Instances can be used as a context
    manager so the underlying client is always closed::

        with DatabaseUtils() as db:
            names = db.get_databases()

    Attributes:
        atlas_uri (str): MongoDB Atlas connection string, read from the
            ``ATLAS_URI`` environment variable after ``load_dotenv()``.
        client (MongoClient): MongoDB client instance (set only after a
            successful connection check).
    """

    def __init__(self):
        """Initialize DatabaseUtils with a verified MongoDB Atlas connection.

        Raises:
            ValueError: If the ATLAS_URI environment variable is not set.
            ConnectionError: If unable to connect to MongoDB Atlas (this is
                the module's ConnectionError, not the built-in one).
        """
        # Load variables from a .env file (python-dotenv) before reading ATLAS_URI.
        load_dotenv()
        self.atlas_uri = os.getenv("ATLAS_URI")
        if not self.atlas_uri:
            raise ValueError("ATLAS_URI environment variable is not set")

        client = None
        try:
            client = MongoClient(self.atlas_uri)
            # Force a server round-trip so connectivity problems surface
            # here instead of on the first real query.
            client.admin.command('ping')
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            if client is not None:
                client.close()
            raise ConnectionError(
                f"Failed to connect to MongoDB Atlas: {str(e)}"
            ) from e
        except Exception:
            # Any other failure (e.g. an OperationFailure from the ping):
            # release the client's resources but propagate the error as-is.
            if client is not None:
                client.close()
            raise
        self.client = client

    @staticmethod
    def _require_name(value: Any, label: str) -> None:
        """Raise ValueError unless *value* is a non-empty string.

        Args:
            value: Candidate database/collection name.
            label: Human-readable kind used in the message ("Database" or
                "Collection").

        Raises:
            ValueError: If *value* is empty, None or not a str.
        """
        if not value or not isinstance(value, str):
            raise ValueError(f"{label} name must be a non-empty string")

    def get_databases(self) -> List[str]:
        """Get the list of all databases in the Atlas cluster.

        Returns:
            List[str]: List of database names.

        Raises:
            OperationError: If unable to list databases.
        """
        try:
            return self.client.list_database_names()
        except OperationFailure as e:
            raise OperationError(f"Failed to list databases: {str(e)}") from e

    def get_collections(self, db_name: str) -> List[str]:
        """Get the list of collections in a database.

        Args:
            db_name (str): Name of the database.

        Returns:
            List[str]: List of collection names.

        Raises:
            OperationError: If unable to list collections (including a name
                pymongo rejects with InvalidName).
            ValueError: If db_name is empty or not a string.
        """
        self._require_name(db_name, "Database")
        try:
            return self.client[db_name].list_collection_names()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to list collections for database '{db_name}': {str(e)}"
            ) from e

    def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
        """Get a collection's document count and one sample document.

        Args:
            db_name (str): Name of the database.
            collection_name (str): Name of the collection.

        Returns:
            Dict[str, Any]: Dictionary with keys:
                - count: Number of documents in the collection.
                - sample: One document from the collection, or None when
                  ``find_one()`` finds nothing (empty collection).

        Raises:
            OperationError: If unable to get collection information.
            ValueError: If db_name or collection_name is empty or not a string.
        """
        self._require_name(db_name, "Database")
        self._require_name(collection_name, "Collection")
        try:
            collection = self.client[db_name][collection_name]
            return {
                'count': collection.count_documents({}),
                'sample': collection.find_one(),
            }
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get info for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            ) from e

    def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
        """Get the top-level field names of one sample document.

        Only a single document is inspected, so fields present only in
        other documents of a heterogeneous collection are not reported.

        Args:
            db_name (str): Name of the database.
            collection_name (str): Name of the collection.

        Returns:
            List[str]: Field names in document order, excluding ``_id`` and
            any field ending in ``_embedding``; empty if the collection
            has no documents.

        Raises:
            OperationError: If unable to read the sample document.
            ValueError: If db_name or collection_name is empty or not a string.
        """
        self._require_name(db_name, "Database")
        self._require_name(collection_name, "Collection")
        try:
            # Read just one document; a full count_documents() scan is not
            # needed to derive field names.
            sample = self.client[db_name][collection_name].find_one()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get field names for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            ) from e
        if not sample:
            return []
        # Skip the primary key and previously generated embedding fields.
        return [field for field in sample
                if field != '_id' and not field.endswith('_embedding')]

    def close(self):
        """Close the MongoDB connection; a no-op if no client was created."""
        if hasattr(self, 'client'):
            self.client.close()

    def __enter__(self) -> "DatabaseUtils":
        """Return self so the instance can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the connection on leaving the ``with`` block; never suppresses errors."""
        self.close()
|