File size: 5,806 Bytes
46a6768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
from typing import List, Dict, Any, Optional
from pymongo import MongoClient
from pymongo.errors import (
    ConnectionFailure, 
    OperationFailure, 
    ServerSelectionTimeoutError,
    InvalidName
)
from dotenv import load_dotenv

class DatabaseError(Exception):
    """Root of this module's exception hierarchy.

    Catch this type to handle any database error defined in this module
    with a single ``except`` clause.
    """

class ConnectionError(DatabaseError):
    """Raised when a connection to MongoDB Atlas cannot be established.

    NOTE: inside this module the name shadows Python's builtin
    ``ConnectionError``; this class derives from ``DatabaseError``, not
    from ``OSError``.
    """

class OperationError(DatabaseError):
    """Raised when a database operation (listing, counting, sampling) fails."""

class DatabaseUtils:
    """Utility class for MongoDB Atlas database operations

    This class provides methods to interact with MongoDB Atlas databases and collections,
    including listing databases, collections, and retrieving collection information.

    Instances can be used as context managers; the connection is closed when
    the ``with`` block exits::

        with DatabaseUtils() as utils:
            names = utils.get_databases()

    Attributes:
        atlas_uri (str): MongoDB Atlas connection string
        client (MongoClient): MongoDB client instance
    """

    def __init__(self) -> None:
        """Initialize DatabaseUtils with MongoDB Atlas connection

        Reads the connection string from the ``ATLAS_URI`` environment
        variable (after loading a ``.env`` file, if any) and verifies the
        connection with a ``ping`` command.

        Raises:
            ConnectionError: If unable to connect to MongoDB Atlas
            ValueError: If ATLAS_URI environment variable is not set
        """
        # Load environment variables
        load_dotenv()

        self.atlas_uri = os.getenv("ATLAS_URI")
        if not self.atlas_uri:
            raise ValueError("ATLAS_URI environment variable is not set")

        client = None
        try:
            client = MongoClient(self.atlas_uri)
            # Test connection
            client.admin.command('ping')
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            # Do not leave a half-initialised client open when the ping fails.
            if client is not None:
                client.close()
            raise ConnectionError(f"Failed to connect to MongoDB Atlas: {str(e)}") from e
        except Exception:
            # Any other failure propagates unchanged, but the client is
            # still released first.
            if client is not None:
                client.close()
            raise
        self.client = client

    def __enter__(self) -> "DatabaseUtils":
        """Return this instance for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the connection on leaving the ``with`` block.

        Exceptions raised inside the block are not suppressed.
        """
        self.close()

    @staticmethod
    def _require_name(value: Any, label: str) -> None:
        """Raise ValueError unless ``value`` is a non-empty string.

        Args:
            value: Candidate database or collection name
            label (str): Word that starts the error message
                ("Database" or "Collection")
        """
        if not value or not isinstance(value, str):
            raise ValueError(f"{label} name must be a non-empty string")

    def get_databases(self) -> List[str]:
        """Get list of all databases in Atlas cluster

        Returns:
            List[str]: List of database names

        Raises:
            OperationError: If unable to list databases
        """
        try:
            return self.client.list_database_names()
        except OperationFailure as e:
            raise OperationError(f"Failed to list databases: {str(e)}") from e

    def get_collections(self, db_name: str) -> List[str]:
        """Get list of collections in a database

        Args:
            db_name (str): Name of the database

        Returns:
            List[str]: List of collection names

        Raises:
            OperationError: If unable to list collections
            ValueError: If db_name is empty or invalid
        """
        self._require_name(db_name, "Database")

        try:
            db = self.client[db_name]
            return db.list_collection_names()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to list collections for database '{db_name}': {str(e)}"
            ) from e

    def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
        """Get information about a collection including document count and sample document

        Args:
            db_name (str): Name of the database
            collection_name (str): Name of the collection

        Returns:
            Dict[str, Any]: Dictionary containing collection information:
                - count: Number of documents in collection
                - sample: Sample document from collection (None if the
                  collection has no documents)

        Raises:
            OperationError: If unable to get collection information
            ValueError: If db_name or collection_name is empty or invalid
        """
        self._require_name(db_name, "Database")
        self._require_name(collection_name, "Collection")

        try:
            db = self.client[db_name]
            collection = db[collection_name]

            return {
                'count': collection.count_documents({}),
                'sample': collection.find_one()
            }
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get info for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            ) from e

    def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
        """Get list of fields in a collection based on sample document

        Only the top-level keys of the single sample document returned by
        ``get_collection_info`` are inspected, so fields that are absent from
        that document are not reported.

        Args:
            db_name (str): Name of the database
            collection_name (str): Name of the collection

        Returns:
            List[str]: List of field names (excluding _id and embedding fields);
                empty if the collection has no sample document

        Raises:
            OperationError: If unable to get field names
            ValueError: If db_name or collection_name is empty or invalid
        """
        try:
            info = self.get_collection_info(db_name, collection_name)
        except DatabaseError as e:
            raise OperationError(
                f"Failed to get field names for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            ) from e

        sample = info.get('sample') or {}
        # Get all field names except _id and any existing embedding fields
        return [field for field in sample
                if field != '_id' and not field.endswith('_embedding')]

    def close(self) -> None:
        """Close MongoDB connection safely

        Does nothing when no client was ever attached to this instance.
        """
        client = getattr(self, 'client', None)
        if client is not None:
            client.close()