airabbitX committed on
Commit
124432f
·
verified ·
1 Parent(s): b1156c2
ui/app.py DELETED
@@ -1,60 +0,0 @@
1
- import gradio as gr
2
- from utils.credentials import check_credentials, init_clients
3
- from ui.embeddings_tab import create_embeddings_tab
4
- from ui.search_tab import create_search_tab
5
-
6
def create_app():
    """Create and configure the Gradio application.

    Returns:
        gr.Blocks: The assembled interface. Depending on the environment,
        the page shows a setup message (missing credentials), an error
        message (connection failure), or the embeddings and search tabs.
    """
    with gr.Blocks(title="MongoDB Vector Search Tool") as iface:
        gr.Markdown("# MongoDB Vector Search Tool")

        # Check credentials first — nothing else works without them.
        has_creds, cred_message = check_credentials()
        if not has_creds:
            gr.Markdown(f"""
            ## ⚠️ Setup Required

            {cred_message}

            After setting up credentials, refresh this page.
            """)
        else:
            # Initialize OpenAI and MongoDB clients.
            openai_client, db_utils = init_clients()
            if not openai_client or not db_utils:
                gr.Markdown("""
                ## ⚠️ Connection Error

                Failed to connect to MongoDB Atlas or OpenAI. Please check your credentials and try again.
                """)
            else:
                # Get available databases; fall back to an empty list so the
                # tabs can still be rendered after a listing failure.
                try:
                    databases = db_utils.get_databases()
                except Exception as e:
                    gr.Markdown(f"""
                    ## ⚠️ Database Error

                    Failed to list databases: {str(e)}
                    Please check your MongoDB Atlas connection and try again.
                    """)
                    databases = []

                # Create tabs. The returned (tab, interface) pairs are not
                # used here, so the results are intentionally discarded
                # (previously bound to unused locals).
                create_embeddings_tab(
                    openai_client=openai_client,
                    db_utils=db_utils,
                    databases=databases
                )

                create_search_tab(
                    openai_client=openai_client,
                    db_utils=db_utils,
                    databases=databases
                )

    return iface
57
-
58
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from containers/VMs.
    create_app().launch(server_name="0.0.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/db_utils.py DELETED
@@ -1,159 +0,0 @@
1
- import os
2
- from typing import List, Dict, Any, Optional
3
- from pymongo import MongoClient
4
- from pymongo.errors import (
5
- ConnectionFailure,
6
- OperationFailure,
7
- ServerSelectionTimeoutError,
8
- InvalidName
9
- )
10
- from dotenv import load_dotenv
11
-
12
class DatabaseError(Exception):
    """Root of the database-error hierarchy for this module."""
    pass

class ConnectionError(DatabaseError):
    """Raised when a connection to MongoDB Atlas cannot be established.

    NOTE(review): this name shadows the builtin ``ConnectionError`` within
    this module; kept for backward compatibility with existing callers.
    """
    pass

class OperationError(DatabaseError):
    """Raised when a database operation fails on an established connection."""
    pass
23
-
24
class DatabaseUtils:
    """Utility class for MongoDB Atlas database operations.

    Wraps a ``MongoClient`` and exposes convenience helpers for listing
    databases and collections and inspecting collection contents.

    Attributes:
        atlas_uri (str): MongoDB Atlas connection string.
        client (MongoClient): MongoDB client instance.
    """

    @staticmethod
    def _require_name(value, what: str) -> None:
        """Raise ValueError unless *value* is a non-empty string."""
        if not value or not isinstance(value, str):
            raise ValueError(f"{what} must be a non-empty string")

    def __init__(self):
        """Connect to MongoDB Atlas using the ATLAS_URI environment variable.

        Raises:
            ConnectionError: If unable to connect to MongoDB Atlas.
            ValueError: If the ATLAS_URI environment variable is not set.
        """
        # Pick up ATLAS_URI from a local .env file if present.
        load_dotenv()

        self.atlas_uri = os.getenv("ATLAS_URI")
        if not self.atlas_uri:
            raise ValueError("ATLAS_URI environment variable is not set")

        try:
            self.client = MongoClient(self.atlas_uri)
            # A ping proves the connection actually works, not just parses.
            self.client.admin.command('ping')
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            raise ConnectionError(f"Failed to connect to MongoDB Atlas: {str(e)}")

    def get_databases(self) -> List[str]:
        """Return the names of all databases in the Atlas cluster.

        Raises:
            OperationError: If the databases cannot be listed.
        """
        try:
            return self.client.list_database_names()
        except OperationFailure as e:
            raise OperationError(f"Failed to list databases: {str(e)}")

    def get_collections(self, db_name: str) -> List[str]:
        """Return the names of the collections in *db_name*.

        Args:
            db_name: Name of the database.

        Raises:
            OperationError: If the collections cannot be listed.
            ValueError: If db_name is empty or not a string.
        """
        self._require_name(db_name, "Database name")

        try:
            return self.client[db_name].list_collection_names()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(f"Failed to list collections for database '{db_name}': {str(e)}")

    def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
        """Describe a collection by document count and one sample document.

        Args:
            db_name: Name of the database.
            collection_name: Name of the collection.

        Returns:
            Dict[str, Any]: ``{'count': <int>, 'sample': <doc or None>}``.

        Raises:
            OperationError: If the collection info cannot be fetched.
            ValueError: If either name is empty or not a string.
        """
        self._require_name(db_name, "Database name")
        self._require_name(collection_name, "Collection name")

        try:
            coll = self.client[db_name][collection_name]

            return {
                'count': coll.count_documents({}),
                'sample': coll.find_one()
            }
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get info for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            )

    def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
        """Infer the field names of a collection from one sample document.

        Excludes ``_id`` and any field ending in ``_embedding``; returns an
        empty list when the collection has no documents.

        Args:
            db_name: Name of the database.
            collection_name: Name of the collection.

        Raises:
            OperationError: If the sample document cannot be fetched.
            ValueError: If either name is empty or invalid.
        """
        try:
            sample = self.get_collection_info(db_name, collection_name).get('sample', {})

            if not sample:
                return []
            # Iterating a dict yields its keys in insertion order.
            return [name for name in sample
                    if name != '_id' and not name.endswith('_embedding')]
        except DatabaseError as e:
            raise OperationError(
                f"Failed to get field names for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            )

    def close(self):
        """Close the MongoDB connection if one was ever established."""
        conn = getattr(self, 'client', None)
        if conn is not None:
            conn.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/embedding_utils.py DELETED
@@ -1,143 +0,0 @@
1
- from typing import List, Tuple
2
- from concurrent.futures import ThreadPoolExecutor, as_completed
3
- from pymongo import UpdateOne
4
- from pymongo.collection import Collection
5
- import math
6
- import time
7
- import logging
8
-
9
# Configure logging once for this module; all helpers below share `logger`.
# NOTE(review): basicConfig at import time affects the root logger of the
# whole process — confirm this is intended rather than per-app config.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
12
-
13
def get_embedding(text: str, openai_client, model="text-embedding-ada-002", max_retries=3) -> list[float]:
    """Embed *text* with the OpenAI API, retrying with exponential backoff.

    Args:
        text: Input text; newlines are flattened to spaces before the call.
        openai_client: Client object exposing ``embeddings.create``.
        model: Embedding model name.
        max_retries: Number of attempts before the last error is re-raised.

    Returns:
        list[float]: The embedding vector for the single input.
    """
    cleaned = text.replace("\n", " ")

    for attempt in range(max_retries):
        try:
            response = openai_client.embeddings.create(
                input=[cleaned],
                model=model
            )
            return response.data[0].embedding
        except Exception as exc:
            # Out of retries: surface the original exception unchanged.
            if attempt == max_retries - 1:
                raise
            details = f"{type(exc).__name__}: {str(exc)}"
            if hasattr(exc, 'response'):
                details += f"\nResponse: {exc.response.text if hasattr(exc.response, 'text') else 'No response text'}"
            logger.warning(f"Embedding API error (attempt {attempt + 1}/{max_retries}):\n{details}")
            time.sleep(2 ** attempt)  # Exponential backoff: 1s, 2s, 4s, ...
32
-
33
def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Embed one batch of documents, returning (_id, embedding) pairs.

    Documents that already carry *embedding_field*, or whose *field_name*
    value is not a string, are silently skipped.
    """
    logger.info(f"Processing batch of {len(docs)} documents")
    pairs = []
    for document in docs:
        # Already embedded on a previous run — nothing to do.
        if embedding_field in document:
            continue

        value = document[field_name]
        if isinstance(value, str):
            pairs.append((document["_id"], get_embedding(value, openai_client)))
    return pairs
47
-
48
def process_futures(futures: List, collection: Collection, embedding_field: str, processed: int, total_docs: int, callback=None) -> int:
    """Persist finished embedding batches and report cumulative progress.

    Waits for each future (30s timeout per wait), bulk-writes the
    (_id, embedding) pairs it produced, and invokes *callback* with
    ``(percent, processed_so_far, total)`` after each write.

    Returns:
        int: Number of documents updated across all completed futures.
    """
    written = 0
    for fut in as_completed(futures, timeout=30):  # 30 second timeout
        try:
            pairs = fut.result()
            if pairs:
                ops = [
                    UpdateOne({"_id": doc_id}, {"$set": {embedding_field: vector}})
                    for doc_id, vector in pairs
                ]
                if ops:
                    collection.bulk_write(ops)
                    written += len(ops)

                # Progress is cumulative: earlier batches + this call's writes.
                if callback:
                    done = processed + written
                    callback((done / total_docs) * 100, done, total_docs)
        except Exception as exc:
            details = f"{type(exc).__name__}: {str(exc)}"
            if hasattr(exc, 'response'):
                details += f"\nResponse: {exc.response.text if hasattr(exc.response, 'text') else 'No response text'}"
            logger.error(f"Error processing future:\n{details}")
    return written
73
-
74
def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 10,  # Reduced initial batch size
    callback=None
) -> int:
    """Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching and dynamic processing.

    Streams documents from *cursor*, groups them into batches, and submits
    each batch to a thread pool. Batch size grows (to 30) after successful
    flushes and shrinks (to 5) after errors.

    Args:
        collection: Target MongoDB collection that receives the embeddings.
        cursor: Cursor yielding the documents to embed.
        field_name: Document field whose text is embedded.
        embedding_field: Field name the embedding is written to.
        openai_client: OpenAI client used by the worker threads.
        total_docs: Expected document count (drives progress reporting).
        batch_size: Initial number of documents per submitted task.
        callback: Optional progress callback ``(percent, processed, total)``.

    Returns:
        int: Number of documents actually updated with embeddings.
    """
    if total_docs == 0:
        return 0

    processed = 0
    current_batch_size = batch_size
    max_workers = 10  # Start with fewer workers

    logger.info(f"Starting embedding generation for {total_docs} documents")
    if callback:
        callback(0, 0, total_docs)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        batch = []
        futures = []

        for doc in cursor:
            batch.append(doc)

            if len(batch) >= current_batch_size:
                logger.info(f"Submitting batch of {len(batch)} documents (batch size: {current_batch_size})")
                # batch.copy() so clearing `batch` below can't race the worker.
                future = executor.submit(process_batch, batch.copy(), field_name, embedding_field, openai_client)
                futures.append(future)
                batch = []

            # Process completed futures more frequently
            if len(futures) >= max_workers:
                try:
                    completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
                    processed += completed
                    futures = []  # Clear processed futures

                    # Gradually increase batch size and workers if processing is successful
                    # NOTE(review): the executor keeps its original max_workers;
                    # raising this variable only changes the flush threshold above.
                    if completed > 0:
                        current_batch_size = min(current_batch_size + 5, 30)
                        max_workers = min(max_workers + 2, 20)
                        logger.info(f"Increased batch size to {current_batch_size}, workers to {max_workers}")
                except Exception as e:
                    logger.error(f"Error processing futures: {str(e)}")
                    # Reduce batch size and workers on error
                    current_batch_size = max(5, current_batch_size - 5)
                    max_workers = max(5, max_workers - 2)
                    logger.info(f"Reduced batch size to {current_batch_size}, workers to {max_workers}")

        # Process remaining batch
        if batch:
            logger.info(f"Processing final batch of {len(batch)} documents")
            future = executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
            futures.append(future)

        # Process remaining futures
        if futures:
            try:
                completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
                processed += completed
            except Exception as e:
                logger.error(f"Error processing final futures: {str(e)}")

    logger.info(f"Completed embedding generation. Processed {processed}/{total_docs} documents")
    return processed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/embeddings_tab.py DELETED
@@ -1,192 +0,0 @@
1
- import gradio as gr
2
- from typing import Tuple, Optional, List
3
- from openai import OpenAI
4
- from utils.db_utils import DatabaseUtils
5
- from utils.embedding_utils import parallel_generate_embeddings
6
-
7
def create_embeddings_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
    """Create the embeddings generation tab UI.

    Args:
        openai_client: OpenAI client instance
        db_utils: DatabaseUtils instance
        databases: List of available databases

    Returns:
        Tuple[gr.Tab, dict]: The tab component and its interface elements
    """
    def update_collections(db_name: str) -> gr.Dropdown:
        """Update collections dropdown when database changes."""
        collections = db_utils.get_collections(db_name)
        # If there's only one collection, select it by default
        value = collections[0] if len(collections) == 1 else None
        return gr.Dropdown(choices=collections, value=value)

    def update_fields(db_name: str, collection_name: str) -> gr.Dropdown:
        """Update fields dropdown when collection changes."""
        if db_name and collection_name:
            fields = db_utils.get_field_names(db_name, collection_name)
            return gr.Dropdown(choices=fields)
        return gr.Dropdown(choices=[])

    def generate_embeddings(
        db_name: str,
        collection_name: str,
        field_name: str,
        embedding_field: str,
        limit: int = 10,
        progress=gr.Progress()
    ) -> Tuple[str, str]:
        """Generate embeddings for documents with progress tracking.

        Returns:
            Tuple[str, str]: (result/instructions message, final progress text)
        """
        try:
            # FIX: gr.Number delivers its value as a float, but pymongo's
            # cursor.limit() requires an int — coerce it up front.
            limit = int(limit) if limit else 0

            db = db_utils.client[db_name]
            collection = db[collection_name]

            # Count documents that carry the source field at all
            total_docs = collection.count_documents({field_name: {"$exists": True}})
            if total_docs == 0:
                return f"No documents found with field '{field_name}'", ""

            # Only documents still missing the embedding field need work
            query = {
                field_name: {"$exists": True},
                embedding_field: {"$exists": False}  # Only get docs without embeddings
            }
            total_to_process = collection.count_documents(query)
            if total_to_process == 0:
                return "No documents found that need embeddings", ""

            # Apply limit if specified (0 means process everything)
            if limit > 0:
                total_to_process = min(total_to_process, limit)

            print(f"\nFound {total_to_process} documents that need embeddings...")

            # Progress text is shared between the UI and terminal logging
            progress_text = ""
            def update_progress(prog: float, processed: int, total: int):
                nonlocal progress_text
                progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
                print(progress_text)  # Terminal logging
                progress(prog/100, f"Processed {processed}/{total} documents")

            # Show initial progress
            update_progress(0, 0, total_to_process)

            # Create cursor for batch processing
            cursor = collection.find(query)
            if limit > 0:
                cursor = cursor.limit(limit)

            # Generate embeddings in parallel with cursor-based batching
            processed = parallel_generate_embeddings(
                collection=collection,
                cursor=cursor,
                field_name=field_name,
                embedding_field=embedding_field,
                openai_client=openai_client,
                total_docs=total_to_process,
                callback=update_progress
            )

            # Return completion message and final progress
            instructions = f"""
            Successfully generated embeddings for {processed} documents using parallel processing!

            To create the vector search index in MongoDB Atlas:
            1. Go to your Atlas cluster
            2. Click on 'Search' tab
            3. Create an index named 'vector_index' with this configuration:
            {{
                "fields": [
                    {{
                        "type": "vector",
                        "path": "{embedding_field}",
                        "numDimensions": 1536,
                        "similarity": "dotProduct"
                    }}
                ]
            }}

            You can now use the search tab with:
            - Field to search: {field_name}
            - Embedding field: {embedding_field}
            """
            return instructions, progress_text

        except Exception as e:
            return f"Error: {str(e)}", ""

    # Create the tab UI
    with gr.Tab("Generate Embeddings") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Available databases in Atlas cluster"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collections in selected database"
            )
        with gr.Row():
            field_input = gr.Dropdown(
                choices=[],
                label="Select Field for Embeddings",
                info="Fields available in collection"
            )
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field name where embeddings will be stored"
            )
            limit_input = gr.Number(
                label="Document Limit",
                value=10,
                minimum=0,
                info="Number of documents to process (0 for all documents)"
            )

        generate_btn = gr.Button("Generate Embeddings")
        generate_output = gr.Textbox(label="Results", lines=10)
        progress_output = gr.Textbox(label="Progress", lines=3)

        # Set up event handlers
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )

        collection_input.change(
            fn=update_fields,
            inputs=[db_input, collection_input],
            outputs=[field_input]
        )

        generate_btn.click(
            fn=generate_embeddings,
            inputs=[
                db_input,
                collection_input,
                field_input,
                embedding_field_input,
                limit_input
            ],
            outputs=[generate_output, progress_output]
        )

    # Return the tab and its interface elements
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'field_input': field_input,
        'embedding_field_input': embedding_field_input,
        'limit_input': limit_input,
        'generate_btn': generate_btn,
        'generate_output': generate_output,
        'progress_output': progress_output
    }

    return tab, interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/list_collections.py DELETED
@@ -1,38 +0,0 @@
1
- import os
2
- from pymongo import MongoClient
3
- from dotenv import load_dotenv
4
-
5
# Load environment variables (.env may define ATLAS_URI)
load_dotenv()

# Initialize MongoDB client at import time.
# NOTE(review): ATLAS_URI may be None here; confirm the env var is always
# set before this script (or any importer) uses `client`.
atlas_uri = os.getenv("ATLAS_URI")
client = MongoClient(atlas_uri)
11
-
12
def list_all_collections():
    """Print every database and collection in the Atlas cluster.

    Uses its own MongoClient so the module-level ``client`` stays usable
    and the function can be called more than once (the previous version
    closed the shared client in its ``finally`` block, making both the
    function and the module-level client single-use).
    """
    # Dedicated connection for this call, always closed on exit.
    local_client = MongoClient(os.getenv("ATLAS_URI"))
    try:
        # Get all database names
        db_names = local_client.list_database_names()

        print("\nDatabases and Collections in your Atlas cluster:\n")

        # For each database, get and print collections
        for db_name in db_names:
            print(f"Database: {db_name}")
            db = local_client[db_name]
            collections = db.list_collection_names()

            for collection in collections:
                # Get count of documents in collection
                count = db[collection].count_documents({})
                print(f" └── Collection: {collection} ({count} documents)")
            print()

    except Exception as e:
        print(f"Error: {str(e)}")
    finally:
        local_client.close()

if __name__ == "__main__":
    list_all_collections()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/list_db.py DELETED
@@ -1,27 +0,0 @@
1
- from db_utils import DatabaseUtils
2
-
3
def main():
    """Print a tree of all databases, collections and document counts."""
    db_utils = DatabaseUtils()
    try:
        print("\nDatabases and Collections in your Atlas cluster:\n")

        # Walk every database and show its collections with counts.
        for db_name in db_utils.get_databases():
            print(f"Database: {db_name}")
            for coll_name in db_utils.get_collections(db_name):
                info = db_utils.get_collection_info(db_name, coll_name)
                print(f" └── Collection: {coll_name} ({info['count']} documents)")
            print()

    except Exception as e:
        print(f"Error: {str(e)}")
    finally:
        # Always release the connection, even after a failure.
        db_utils.close()

if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/run.sh DELETED
@@ -1 +0,0 @@
1
#!/usr/bin/env bash
# Launch the Gradio app. Run from the ui/ directory so app.py resolves.
python app.py
 
 
ui/search_tab.py DELETED
@@ -1,142 +0,0 @@
1
- import gradio as gr
2
- from typing import Tuple, List
3
- from openai import OpenAI
4
- from utils.db_utils import DatabaseUtils
5
- from utils.embedding_utils import get_embedding
6
-
7
def create_search_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
    """Create the vector search tab UI

    Args:
        openai_client: OpenAI client instance
        db_utils: DatabaseUtils instance
        databases: List of available databases

    Returns:
        Tuple[gr.Tab, dict]: The tab component and its interface elements
    """
    def update_collections(db_name: str) -> gr.Dropdown:
        """Update collections dropdown when database changes"""
        collections = db_utils.get_collections(db_name)
        # If there's only one collection, select it by default
        value = collections[0] if len(collections) == 1 else None
        return gr.Dropdown(choices=collections, value=value)

    def vector_search(
        query_text: str,
        db_name: str,
        collection_name: str,
        embedding_field: str,
        index_name: str
    ) -> str:
        """Perform vector search using embeddings

        Embeds the query with OpenAI, runs a $vectorSearch aggregation
        against the selected collection, and formats the top matches as
        plain text. Any failure is returned as an "Error: ..." string.
        """
        try:
            print(f"\nProcessing query: {query_text}")

            db = db_utils.client[db_name]
            collection = db[collection_name]

            # Get embeddings for query
            embedding = get_embedding(query_text, openai_client)
            print("Generated embeddings successfully")

            # Top 5 matches out of 50 candidates (both hard-coded here).
            results = collection.aggregate([
                {
                    '$vectorSearch': {
                        "index": index_name,
                        "path": embedding_field,
                        "queryVector": embedding,
                        "numCandidates": 50,
                        "limit": 5
                    }
                },
                {
                    # Keep the similarity score alongside the full document.
                    "$project": {
                        "search_score": { "$meta": "vectorSearchScore" },
                        "document": "$$ROOT"
                    }
                }
            ])

            # Format results
            results_list = list(results)
            formatted_results = []

            for idx, result in enumerate(results_list, 1):
                doc = result['document']
                formatted_result = f"{idx}. Score: {result['search_score']:.4f}\n"
                # Add all fields except _id and embeddings
                for key, value in doc.items():
                    if key not in ['_id', embedding_field]:
                        formatted_result += f"{key}: {value}\n"
                formatted_results.append(formatted_result)

            return "\n".join(formatted_results) if formatted_results else "No results found"

        except Exception as e:
            return f"Error: {str(e)}"

    # Create the tab UI
    with gr.Tab("Search") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Database containing the vectors"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collection containing the vectors"
            )
        with gr.Row():
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field containing the vectors"
            )
            index_input = gr.Textbox(
                label="Vector Search Index Name",
                value="vector_index",
                info="Index created in Atlas UI"
            )

        query_input = gr.Textbox(
            label="Search Query",
            lines=2,
            placeholder="What would you like to search for?"
        )
        search_btn = gr.Button("Search")
        search_output = gr.Textbox(label="Results", lines=10)

        # Set up event handlers
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )

        search_btn.click(
            fn=vector_search,
            inputs=[
                query_input,
                db_input,
                collection_input,
                embedding_field_input,
                index_input
            ],
            outputs=search_output
        )

    # Return the tab and its interface elements
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'embedding_field_input': embedding_field_input,
        'index_input': index_input,
        'query_input': query_input,
        'search_btn': search_btn,
        'search_output': search_output
    }

    return tab, interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/setup_data.py DELETED
@@ -1,94 +0,0 @@
1
- import os
2
- from pymongo import MongoClient
3
- from openai import OpenAI
4
- from dotenv import load_dotenv
5
-
6
# Load environment variables
load_dotenv()

# Initialize clients at import time.
# NOTE(review): requires OPENAI_API_KEY and ATLAS_URI to be set in the
# environment/.env — confirm before running this script.
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
atlas_uri = os.getenv("ATLAS_URI")
client = MongoClient(atlas_uri)
# Target demo database/collection for the seed data.
db = client['sample_mflix']
collection = db['embedded_movies']

# Sample movie data used to seed the demo collection (plots get embedded).
sample_movies = [
    {
        "title": "The Matrix",
        "year": 1999,
        "plot": "A computer programmer discovers that reality as he knows it is a simulation created by machines, and joins a rebellion to overthrow them."
    },
    {
        "title": "Inception",
        "year": 2010,
        "plot": "A thief who enters the dreams of others to steal secrets from their subconscious is offered a chance to regain his old life in exchange for a task considered impossible: inception."
    },
    {
        "title": "The Shawshank Redemption",
        "year": 1994,
        "plot": "Two imprisoned men bond over a number of years, finding solace and eventual redemption through acts of common decency."
    },
    {
        "title": "Jurassic Park",
        "year": 1993,
        "plot": "A pragmatic paleontologist visiting an almost complete theme park is tasked with protecting a couple of kids after a power failure causes the park's cloned dinosaurs to run loose."
    },
    {
        "title": "The Lord of the Rings: The Fellowship of the Ring",
        "year": 2001,
        "plot": "A young hobbit, Frodo, who has found the One Ring that belongs to the Dark Lord Sauron, begins his journey with eight companions to Mount Doom, the only place where it can be destroyed."
    }
]
44
-
45
def get_embedding(text: str, model="text-embedding-ada-002") -> list[float]:
    """Return the OpenAI embedding vector for *text*.

    Newlines are flattened to spaces before the API call; uses the
    module-level ``openai_client``.
    """
    cleaned = text.replace("\n", " ")
    response = openai_client.embeddings.create(
        input=[cleaned],
        model=model
    )
    return response.data[0].embedding
53
-
54
def setup_data():
    """Recreate the demo collection: drop it, embed each sample movie's plot,
    insert the documents, and print manual index-creation instructions.

    Uses the module-level ``collection``, ``client``, ``sample_movies`` and
    ``get_embedding``. Always closes the client on exit.
    """
    try:
        # Drop existing collection if it exists (destructive reset)
        collection.drop()
        print("Dropped existing collection")

        # Add embeddings to movies and insert into collection
        for movie in sample_movies:
            # Generate embedding for plot (mutates the dict in place)
            embedding = get_embedding(movie["plot"])
            movie["plot_embedding"] = embedding

            # Insert movie with embedding
            collection.insert_one(movie)
            print(f"Inserted movie: {movie['title']}")

        print("\nData setup completed successfully!")
        print("\nIMPORTANT: You need to create the vector search index manually in the Atlas UI:")
        print("1. Go to your Atlas cluster")
        print("2. Click on 'Search' tab")
        print("3. Create an index named 'idx_plot_embedding' with this definition:")
        print("""
        {
          "fields": [
            {
              "type": "vector",
              "path": "plot_embedding",
              "numDimensions": 1536,
              "similarity": "dotProduct"
            }
          ]
        }
        """)

    except Exception as e:
        print(f"Error during setup: {str(e)}")
    finally:
        # Release the MongoDB connection whether or not setup succeeded
        client.close()

if __name__ == "__main__":
    setup_data()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/ui/__init__.py DELETED
@@ -1,8 +0,0 @@
1
# UI package for MongoDB Vector Search Tool
from ui.embeddings_tab import create_embeddings_tab
from ui.search_tab import create_search_tab

# Explicit public API of the package.
__all__ = [
    'create_embeddings_tab',
    'create_search_tab'
]
 
 
 
 
 
 
 
 
 
ui/ui/__pycache__/embeddings_tab.cpython-312.pyc DELETED
Binary file (6.98 kB)
 
ui/ui/__pycache__/search_tab.cpython-312.pyc DELETED
Binary file (5.06 kB)
 
ui/ui/embeddings_tab.py DELETED
@@ -1,192 +0,0 @@
1
- import gradio as gr
2
- from typing import Tuple, Optional, List
3
- from openai import OpenAI
4
- from utils.db_utils import DatabaseUtils
5
- from utils.embedding_utils import parallel_generate_embeddings
6
-
7
def create_embeddings_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
    """Create the embeddings generation tab UI

    Args:
        openai_client: OpenAI client instance
        db_utils: DatabaseUtils instance
        databases: List of available databases

    Returns:
        Tuple[gr.Tab, dict]: The tab component and its interface elements
    """
    def update_collections(db_name: str) -> gr.Dropdown:
        """Update collections dropdown when database changes"""
        collections = db_utils.get_collections(db_name)
        # If there's only one collection, select it by default
        value = collections[0] if len(collections) == 1 else None
        return gr.Dropdown(choices=collections, value=value)

    def update_fields(db_name: str, collection_name: str) -> gr.Dropdown:
        """Update fields dropdown when collection changes"""
        if db_name and collection_name:
            fields = db_utils.get_field_names(db_name, collection_name)
            return gr.Dropdown(choices=fields)
        return gr.Dropdown(choices=[])

    def generate_embeddings(
        db_name: str,
        collection_name: str,
        field_name: str,
        embedding_field: str,
        limit: int = 10,
        progress=gr.Progress()
    ) -> Tuple[str, str]:
        """Generate embeddings for documents with progress tracking

        Returns:
            Tuple[str, str]: (result/instructions message, progress text)
        """
        try:
            # BUG FIX: gr.Number delivers the value as a float, but pymongo's
            # Cursor.limit() only accepts an int (it raises TypeError on a
            # float). Coerce up front; None/0 both mean "no limit".
            limit = int(limit or 0)

            db = db_utils.client[db_name]
            collection = db[collection_name]

            # Count documents that have the source field at all.
            total_docs = collection.count_documents({field_name: {"$exists": True}})
            if total_docs == 0:
                return f"No documents found with field '{field_name}'", ""

            # Only documents that still lack an embedding need processing.
            query = {
                field_name: {"$exists": True},
                embedding_field: {"$exists": False}  # Only get docs without embeddings
            }
            total_to_process = collection.count_documents(query)
            if total_to_process == 0:
                return "No documents found that need embeddings", ""

            # Apply limit if specified
            if limit > 0:
                total_to_process = min(total_to_process, limit)

            print(f"\nFound {total_to_process} documents that need embeddings...")

            # Progress tracking: mirror progress to the UI and the terminal.
            progress_text = ""
            def update_progress(prog: float, processed: int, total: int):
                nonlocal progress_text
                progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
                print(progress_text)  # Terminal logging
                progress(prog/100, f"Processed {processed}/{total} documents")

            # Show initial progress
            update_progress(0, 0, total_to_process)

            # Create cursor for batch processing
            cursor = collection.find(query)
            if limit > 0:
                cursor = cursor.limit(limit)

            # Generate embeddings in parallel with cursor-based batching
            processed = parallel_generate_embeddings(
                collection=collection,
                cursor=cursor,
                field_name=field_name,
                embedding_field=embedding_field,
                openai_client=openai_client,
                total_docs=total_to_process,
                callback=update_progress
            )

            # Return completion message and final progress
            instructions = f"""
Successfully generated embeddings for {processed} documents using parallel processing!

To create the vector search index in MongoDB Atlas:
1. Go to your Atlas cluster
2. Click on 'Search' tab
3. Create an index named 'vector_index' with this configuration:
{{
    "fields": [
        {{
            "type": "vector",
            "path": "{embedding_field}",
            "numDimensions": 1536,
            "similarity": "dotProduct"
        }}
    ]
}}

You can now use the search tab with:
- Field to search: {field_name}
- Embedding field: {embedding_field}
"""
            return instructions, progress_text

        except Exception as e:
            return f"Error: {str(e)}", ""

    # Create the tab UI
    with gr.Tab("Generate Embeddings") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Available databases in Atlas cluster"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collections in selected database"
            )
        with gr.Row():
            field_input = gr.Dropdown(
                choices=[],
                label="Select Field for Embeddings",
                info="Fields available in collection"
            )
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field name where embeddings will be stored"
            )
            limit_input = gr.Number(
                label="Document Limit",
                value=10,
                minimum=0,
                info="Number of documents to process (0 for all documents)"
            )

        generate_btn = gr.Button("Generate Embeddings")
        generate_output = gr.Textbox(label="Results", lines=10)
        progress_output = gr.Textbox(label="Progress", lines=3)

        # Set up event handlers
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )

        collection_input.change(
            fn=update_fields,
            inputs=[db_input, collection_input],
            outputs=[field_input]
        )

        generate_btn.click(
            fn=generate_embeddings,
            inputs=[
                db_input,
                collection_input,
                field_input,
                embedding_field_input,
                limit_input
            ],
            outputs=[generate_output, progress_output]
        )

    # Return the tab and its interface elements
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'field_input': field_input,
        'embedding_field_input': embedding_field_input,
        'limit_input': limit_input,
        'generate_btn': generate_btn,
        'generate_output': generate_output,
        'progress_output': progress_output
    }

    return tab, interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/ui/search_tab.py DELETED
@@ -1,142 +0,0 @@
1
- import gradio as gr
2
- from typing import Tuple, List
3
- from openai import OpenAI
4
- from utils.db_utils import DatabaseUtils
5
- from utils.embedding_utils import get_embedding
6
-
7
def create_search_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
    """Build the vector search tab.

    Args:
        openai_client: OpenAI client instance
        db_utils: DatabaseUtils instance
        databases: List of available databases

    Returns:
        Tuple[gr.Tab, dict]: the tab component and a dict of its widgets
    """
    def update_collections(db_name: str) -> gr.Dropdown:
        """Refresh the collection dropdown for the chosen database."""
        available = db_utils.get_collections(db_name)
        # Pre-select when the database has exactly one collection.
        preselected = available[0] if len(available) == 1 else None
        return gr.Dropdown(choices=available, value=preselected)

    def vector_search(
        query_text: str,
        db_name: str,
        collection_name: str,
        embedding_field: str,
        index_name: str
    ) -> str:
        """Embed the query and run an Atlas $vectorSearch, returning formatted hits."""
        try:
            print(f"\nProcessing query: {query_text}")

            target = db_utils.client[db_name][collection_name]

            # Turn the free-text query into a vector.
            query_vector = get_embedding(query_text, openai_client)
            print("Generated embeddings successfully")

            pipeline = [
                {
                    '$vectorSearch': {
                        "index": index_name,
                        "path": embedding_field,
                        "queryVector": query_vector,
                        "numCandidates": 50,
                        "limit": 5
                    }
                },
                {
                    "$project": {
                        "search_score": {"$meta": "vectorSearchScore"},
                        "document": "$$ROOT"
                    }
                }
            ]
            hits = list(target.aggregate(pipeline))

            rendered = []
            for rank, hit in enumerate(hits, 1):
                doc = hit['document']
                parts = [f"{rank}. Score: {hit['search_score']:.4f}\n"]
                # Show every field except the id and the raw embedding vector.
                parts.extend(
                    f"{key}: {value}\n"
                    for key, value in doc.items()
                    if key not in ['_id', embedding_field]
                )
                rendered.append("".join(parts))

            return "\n".join(rendered) if rendered else "No results found"

        except Exception as e:
            return f"Error: {str(e)}"

    # Lay out the tab.
    with gr.Tab("Search") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Database containing the vectors"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collection containing the vectors"
            )
        with gr.Row():
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field containing the vectors"
            )
            index_input = gr.Textbox(
                label="Vector Search Index Name",
                value="vector_index",
                info="Index created in Atlas UI"
            )

        query_input = gr.Textbox(
            label="Search Query",
            lines=2,
            placeholder="What would you like to search for?"
        )
        search_btn = gr.Button("Search")
        search_output = gr.Textbox(label="Results", lines=10)

        # Wire up the interactions.
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )

        search_btn.click(
            fn=vector_search,
            inputs=[
                query_input,
                db_input,
                collection_input,
                embedding_field_input,
                index_input
            ],
            outputs=search_output
        )

    # Expose the widgets so callers can reference them.
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'embedding_field_input': embedding_field_input,
        'index_input': index_input,
        'query_input': query_input,
        'search_btn': search_btn,
        'search_output': search_output
    }

    return tab, interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/utils/__init__.py DELETED
@@ -1,12 +0,0 @@
1
# Utils package for MongoDB Vector Search Tool
from utils.credentials import check_credentials, init_clients
from utils.db_utils import DatabaseUtils
from utils.embedding_utils import get_embedding, parallel_generate_embeddings

# Names re-exported as the package's public API.
__all__ = [
    'check_credentials',
    'init_clients',
    'DatabaseUtils',
    'get_embedding',
    'parallel_generate_embeddings',
]
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/utils/credentials.py DELETED
@@ -1,47 +0,0 @@
1
- import os
2
- from typing import Tuple
3
- from dotenv import load_dotenv
4
- from openai import OpenAI
5
- from utils.db_utils import DatabaseUtils
6
-
7
def check_credentials() -> Tuple[bool, str]:
    """Verify that the required secrets are configured.

    Returns:
        Tuple[bool, str]: (is_valid, message)
        - is_valid: True when every required credential is present
        - message: setup instructions for the first missing credential,
          empty when everything is set
    """
    # Pull secrets from a local .env file if one exists.
    load_dotenv()

    # Each required variable maps to the help text shown when it is absent.
    setup_help = {
        "ATLAS_URI": """Please set up your MongoDB Atlas credentials:
1. Go to Settings tab
2. Add ATLAS_URI as a Repository Secret
3. Paste your MongoDB connection string (should start with 'mongodb+srv://')""",
        "OPENAI_API_KEY": """Please set up your OpenAI API key:
1. Go to Settings tab
2. Add OPENAI_API_KEY as a Repository Secret
3. Paste your OpenAI API key""",
    }

    # Dicts preserve insertion order, so ATLAS_URI is reported first,
    # matching the original check order.
    for var_name, help_text in setup_help.items():
        if not os.getenv(var_name):
            return False, help_text

    return True, ""
34
-
35
def init_clients():
    """Initialize OpenAI and MongoDB clients

    Returns:
        Tuple[OpenAI, DatabaseUtils]: OpenAI client and DatabaseUtils instance
        or (None, None) if initialization fails
    """
    try:
        openai_client = OpenAI()
        db_utils = DatabaseUtils()
        return openai_client, db_utils
    except Exception:
        # Deliberate best-effort: callers treat (None, None) as "not
        # configured" and show a setup message. The previously-bound but
        # unused exception variable was removed (flake8 F841); consider
        # logging the cause here if debugging connection issues.
        return None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/utils/db_utils.py DELETED
@@ -1,159 +0,0 @@
1
- import os
2
- from typing import List, Dict, Any, Optional
3
- from pymongo import MongoClient
4
- from pymongo.errors import (
5
- ConnectionFailure,
6
- OperationFailure,
7
- ServerSelectionTimeoutError,
8
- InvalidName
9
- )
10
- from dotenv import load_dotenv
11
-
12
class DatabaseError(Exception):
    """Root of this module's database-related exception hierarchy."""
15
-
16
class ConnectionError(DatabaseError):
    """Raised when a connection to MongoDB Atlas cannot be established.

    NOTE(review): this name shadows the builtin ConnectionError inside this
    module — renaming would break callers, so it is only flagged here.
    """
19
-
20
class OperationError(DatabaseError):
    """Raised when a database operation (list/count/query) fails."""
23
-
24
class DatabaseUtils:
    """Helper around a MongoDB Atlas cluster.

    Wraps a MongoClient and exposes convenience methods for listing
    databases and collections and inspecting collection contents.

    Attributes:
        atlas_uri (str): MongoDB Atlas connection string
        client (MongoClient): MongoDB client instance
    """

    @staticmethod
    def _check_name(value, kind: str) -> None:
        # Shared argument validation used by the public methods below.
        if not value or not isinstance(value, str):
            raise ValueError(f"{kind} name must be a non-empty string")

    def __init__(self):
        """Connect to Atlas using the ATLAS_URI environment variable.

        Raises:
            ValueError: if ATLAS_URI is not set
            ConnectionError: if the cluster cannot be reached
        """
        load_dotenv()

        self.atlas_uri = os.getenv("ATLAS_URI")
        if not self.atlas_uri:
            raise ValueError("ATLAS_URI environment variable is not set")

        try:
            self.client = MongoClient(self.atlas_uri)
            # Fail fast if the cluster is unreachable.
            self.client.admin.command('ping')
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            raise ConnectionError(f"Failed to connect to MongoDB Atlas: {str(e)}")

    def get_databases(self) -> List[str]:
        """Return the names of all databases in the Atlas cluster.

        Raises:
            OperationError: if the databases cannot be listed
        """
        try:
            return self.client.list_database_names()
        except OperationFailure as e:
            raise OperationError(f"Failed to list databases: {str(e)}")

    def get_collections(self, db_name: str) -> List[str]:
        """Return the collection names inside *db_name*.

        Raises:
            ValueError: if db_name is empty or not a string
            OperationError: if the collections cannot be listed
        """
        self._check_name(db_name, "Database")

        try:
            return self.client[db_name].list_collection_names()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(f"Failed to list collections for database '{db_name}': {str(e)}")

    def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
        """Return a document count and one sample document for a collection.

        Returns:
            Dict[str, Any]: {'count': <int>, 'sample': <doc or None>}

        Raises:
            ValueError: if either name is empty or not a string
            OperationError: if the collection cannot be queried
        """
        self._check_name(db_name, "Database")
        self._check_name(collection_name, "Collection")

        try:
            coll = self.client[db_name][collection_name]
            return {
                'count': coll.count_documents({}),
                'sample': coll.find_one()
            }
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get info for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            )

    def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
        """Return the field names of a sample document, minus ids/embeddings.

        Raises:
            ValueError: if either name is invalid (propagated from
                get_collection_info)
            OperationError: if the sample cannot be fetched
        """
        try:
            sample = self.get_collection_info(db_name, collection_name).get('sample', {})
            if not sample:
                return []
            # Hide the id and any previously generated *_embedding fields.
            return [field for field in sample.keys()
                    if field != '_id' and not field.endswith('_embedding')]
        except DatabaseError as e:
            raise OperationError(
                f"Failed to get field names for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            )

    def close(self):
        """Close MongoDB connection safely"""
        if hasattr(self, 'client'):
            self.client.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/utils/embedding_utils.py DELETED
@@ -1,143 +0,0 @@
1
- from typing import List, Tuple
2
- from concurrent.futures import ThreadPoolExecutor, as_completed
3
- from pymongo import UpdateOne
4
- from pymongo.collection import Collection
5
- import math
6
- import time
7
- import logging
8
-
9
- # Configure logging
10
- logging.basicConfig(level=logging.INFO)
11
- logger = logging.getLogger(__name__)
12
-
13
def get_embedding(text: str, openai_client, model="text-embedding-ada-002", max_retries=3) -> list[float]:
    """Embed *text* via the OpenAI API, retrying transient failures.

    Newlines are flattened to spaces before the request. Failed calls are
    retried with exponential backoff; the final failure is re-raised.
    """
    cleaned = text.replace("\n", " ")

    for attempt in range(max_retries):
        try:
            response = openai_client.embeddings.create(
                input=[cleaned],
                model=model
            )
        except Exception as exc:
            if attempt == max_retries - 1:
                # Out of retries — surface the original error to the caller.
                raise
            error_details = f"{type(exc).__name__}: {str(exc)}"
            if hasattr(exc, 'response'):
                error_details += f"\nResponse: {exc.response.text if hasattr(exc.response, 'text') else 'No response text'}"
            logger.warning(f"Embedding API error (attempt {attempt + 1}/{max_retries}):\n{error_details}")
            time.sleep(2 ** attempt)  # Exponential backoff
        else:
            return response.data[0].embedding
32
-
33
def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Embed one batch of documents, returning (_id, embedding) pairs.

    Documents that already carry the embedding field, or whose source field
    is not a string, are silently skipped.
    """
    logger.info(f"Processing batch of {len(docs)} documents")
    pairs = []
    for document in docs:
        if embedding_field in document:
            # Already embedded — nothing to do.
            continue

        value = document[field_name]
        if isinstance(value, str):
            pairs.append((document["_id"], get_embedding(value, openai_client)))
    return pairs
47
-
48
def process_futures(futures: List, collection: Collection, embedding_field: str, processed: int, total_docs: int, callback=None) -> int:
    """Drain completed embedding futures, bulk-write results, report progress.

    Returns the number of documents updated in this drain. Failures of an
    individual future are logged and do not abort the remaining futures.
    """
    completed = 0
    # NOTE(review): as_completed's 30s timeout raises TimeoutError, which is
    # not caught here and will propagate to the caller — confirm intended.
    for future in as_completed(futures, timeout=30):  # 30 second timeout
        try:
            batch_results = future.result()
            if batch_results:
                updates = [
                    UpdateOne({"_id": doc_id}, {"$set": {embedding_field: vector}})
                    for doc_id, vector in batch_results
                ]
                if updates:
                    collection.bulk_write(updates)
                    completed += len(updates)

            # Report cumulative progress after each future resolves.
            if callback:
                progress = ((processed + completed) / total_docs) * 100
                callback(progress, processed + completed, total_docs)
        except Exception as e:
            error_details = f"{type(e).__name__}: {str(e)}"
            if hasattr(e, 'response'):
                error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
            logger.error(f"Error processing future:\n{error_details}")
    return completed
73
-
74
def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 10,  # Reduced initial batch size
    callback=None
) -> int:
    """Embed documents from *cursor* in parallel and write results back.

    Batches documents off the cursor, submits each batch to a thread pool,
    and periodically drains finished futures. Batch size and the drain
    threshold adapt up on success and down on failure.

    Returns the number of documents successfully updated.
    """
    if total_docs == 0:
        return 0

    processed = 0
    current_batch_size = batch_size
    max_workers = 10  # Start with fewer workers

    logger.info(f"Starting embedding generation for {total_docs} documents")
    if callback:
        callback(0, 0, total_docs)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        batch = []
        futures = []

        for doc in cursor:
            batch.append(doc)
            if len(batch) < current_batch_size:
                continue

            # Batch is full — hand it to the pool and start a fresh one.
            logger.info(f"Submitting batch of {len(batch)} documents (batch size: {current_batch_size})")
            futures.append(
                executor.submit(process_batch, batch.copy(), field_name, embedding_field, openai_client)
            )
            batch = []

            # Drain once enough futures have accumulated.
            if len(futures) < max_workers:
                continue
            try:
                completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
                processed += completed
                futures = []  # drained successfully

                # Adapt upward after a successful drain.
                # NOTE(review): the executor's pool size is fixed at creation;
                # raising max_workers here only raises the drain threshold,
                # not the actual worker count — confirm intended.
                if completed > 0:
                    current_batch_size = min(current_batch_size + 5, 30)
                    max_workers = min(max_workers + 2, 20)
                    logger.info(f"Increased batch size to {current_batch_size}, workers to {max_workers}")
            except Exception as e:
                logger.error(f"Error processing futures: {str(e)}")
                # Back off on error; the undrained futures are retried on the
                # next pass since the list is not cleared here.
                current_batch_size = max(5, current_batch_size - 5)
                max_workers = max(5, max_workers - 2)
                logger.info(f"Reduced batch size to {current_batch_size}, workers to {max_workers}")

        # Submit whatever is left over after the cursor is exhausted.
        if batch:
            logger.info(f"Processing final batch of {len(batch)} documents")
            futures.append(
                executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
            )

        # Final drain of any outstanding futures.
        if futures:
            try:
                processed += process_futures(futures, collection, embedding_field, processed, total_docs, callback)
            except Exception as e:
                logger.error(f"Error processing final futures: {str(e)}")

    logger.info(f"Completed embedding generation. Processed {processed}/{total_docs} documents")
    return processed