airabbitX committed
Commit 8fb6e2f · verified · 1 Parent(s): c7e373c

Upload 9 files

Files changed (9):
  1. Changelog.md +2 -0
  2. README.md +96 -12
  3. app.py +276 -0
  4. db_utils.py +159 -0
  5. embedding_utils.py +122 -0
  6. list_collections.py +38 -0
  7. list_db.py +27 -0
  8. requirements.txt +57 -0
  9. setup_data.py +94 -0
Changelog.md ADDED
@@ -0,0 +1,2 @@
+ - 2025-01-26:
+ - 2025-01-26:
README.md CHANGED
@@ -1,12 +1,96 @@
- ---
- title: Mongo Vector Search Util
- emoji: 💻
- colorFrom: pink
- colorTo: purple
- sdk: gradio
- sdk_version: 5.13.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Vector Search Demo App
+
+ This is a Gradio web application that demonstrates vector search capabilities using MongoDB Atlas and OpenAI embeddings.
+
+ ## Prerequisites
+
+ 1. MongoDB Atlas account with vector search enabled
+ 2. OpenAI API key
+ 3. Python 3.8+
+ 4. Sample movie data loaded in MongoDB Atlas (sample_mflix database)
+
+ ## Setup
+
+ 1. Clone this repository
+
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 3. Set up environment variables:
+ ```bash
+ export OPENAI_API_KEY="your-openai-api-key"
+ export ATLAS_URI="your-mongodb-atlas-connection-string"
+ ```
+
+ 4. Ensure your MongoDB Atlas setup:
+    - Database name: sample_mflix
+    - Collection: embedded_movies
+    - Vector search index: idx_plot_embedding
+    - Index configuration:
+    ```json
+    {
+      "fields": [
+        {
+          "type": "vector",
+          "path": "plot_embedding",
+          "numDimensions": 1536,
+          "similarity": "dotProduct"
+        }
+      ]
+    }
+    ```
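If you would rather script the index than click through the Atlas UI, the same definition can be created programmatically. This is a hedged sketch, not part of this commit; it assumes pymongo 4.7+ (the pinned 4.10.1 qualifies) and an Atlas tier that allows search-index creation through the driver:

```python
# create_index.py (hypothetical helper, not included in this commit)
import os
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel

client = MongoClient(os.environ["ATLAS_URI"])
collection = client["sample_mflix"]["embedded_movies"]

# Same definition as the JSON above, applied via the driver
index_model = SearchIndexModel(
    definition={
        "fields": [
            {
                "type": "vector",
                "path": "plot_embedding",
                "numDimensions": 1536,
                "similarity": "dotProduct",
            }
        ]
    },
    name="idx_plot_embedding",
    type="vectorSearch",
)
collection.create_search_index(model=index_model)
client.close()
```

The index builds asynchronously, so allow a short delay before querying it.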
+
+ ## Running the App
+
+ Start the application:
+ ```bash
+ python app.py
+ ```
+
+ The app will be available at http://localhost:7860
+
+ ## Usage
+
+ ### Generating Embeddings
+ 1. Select your database and collection from the dropdowns
+ 2. Choose the field to generate embeddings for
+ 3. Specify the embedding field name (defaults to "embedding")
+ 4. Set a document limit (0 for all documents)
+ 5. Click "Generate Embeddings" to start processing
+
+ The app uses memory-efficient cursor-based batch processing that can handle large collections:
+ - Documents are processed in batches (default 20 documents per batch)
+ - Memory usage is optimized through cursor-based iteration
+ - Real-time progress tracking shows completed/total documents
+ - Supports processing of large collections (100,000+ documents)
+ - Automatically resumes from where it left off if embeddings already exist (see the sketch below)
+
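The resume behaviour comes from the filter that app.py builds before generating embeddings: only documents that have the source field but no embedding yet are fetched, so re-running the job picks up the remainder. A minimal sketch of that selection, simplified from `generate_embeddings_for_field` (the helper name here is illustrative):

```python
from pymongo.collection import Collection

def pending_docs(collection: Collection, field_name: str, embedding_field: str, limit: int = 0):
    """Return a cursor over documents that still need an embedding (simplified from app.py)."""
    query = {
        field_name: {"$exists": True},        # has text to embed
        embedding_field: {"$exists": False},  # not embedded yet, so a rerun resumes here
    }
    cursor = collection.find(query)
    return cursor.limit(limit) if limit > 0 else cursor
```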
+ ### Searching
+ 1. Enter a natural language query in the text box (e.g., "humans fighting aliens")
+ 2. Click "Search" to run the query
+ 3. View the results showing matching documents with their similarity scores
+
+ ## Example Queries
+
+ - "humans fighting aliens"
+ - "relationship drama between two good friends"
+ - "comedy about family vacation"
+ - "detective solving mysterious murder"
+
+ ## Performance Notes
+
+ The application is optimized for handling large datasets:
+ - Uses cursor-based batch processing to avoid memory issues
+ - Processes documents in configurable batch sizes (default: 20)
+ - Implements parallel processing with ThreadPoolExecutor
+ - Provides real-time progress tracking
+ - Automatically handles memory cleanup during processing
+ - Supports resuming interrupted operations
+
+ ## Notes
+
+ - The search uses OpenAI's text-embedding-ada-002 model to create embeddings
+ - Results are limited to the top 5 matches
+ - Similarity scores range from 0 to 1, with higher scores indicating better matches
app.py ADDED
@@ -0,0 +1,276 @@
+ import os
+ import gradio as gr
+ from openai import OpenAI
+ import json
+ from dotenv import load_dotenv
+ from db_utils import DatabaseUtils
+ from embedding_utils import parallel_generate_embeddings, get_embedding
+
+ # Load environment variables from .env file
+ load_dotenv()
+
+ # Initialize OpenAI client
+ openai_client = OpenAI()
+
+ # Initialize database utils
+ db_utils = DatabaseUtils()
+
+ def get_field_names(db_name: str, collection_name: str) -> list[str]:
+     """Get list of fields in the collection"""
+     return db_utils.get_field_names(db_name, collection_name)
+
+ def generate_embeddings_for_field(db_name: str, collection_name: str, field_name: str, embedding_field: str, limit: int = 10, progress=gr.Progress()) -> tuple[str, str]:
+     """Generate embeddings for documents in parallel with progress tracking"""
+     try:
+         db = db_utils.client[db_name]
+         collection = db[collection_name]
+
+         # Count documents that need embeddings
+         total_docs = collection.count_documents({field_name: {"$exists": True}})
+         if total_docs == 0:
+             return f"No documents found with field '{field_name}'", ""
+
+         # Get total count of documents that need processing
+         query = {
+             field_name: {"$exists": True},
+             embedding_field: {"$exists": False}  # Only get docs without embeddings
+         }
+         total_to_process = collection.count_documents(query)
+         if total_to_process == 0:
+             return "No documents found that need embeddings", ""
+
+         # Apply limit if specified
+         if limit > 0:
+             total_to_process = min(total_to_process, limit)
+
+         print(f"\nFound {total_to_process} documents that need embeddings...")
+
+         # Progress tracking
+         progress_text = ""
+         def update_progress(prog: float, processed: int, total: int):
+             nonlocal progress_text
+             progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
+             print(progress_text)  # Terminal logging
+             progress(prog/100, f"Processed {processed}/{total} documents")
+
+         # Show initial progress
+         update_progress(0, 0, total_to_process)
+
+         # Create cursor for batch processing
+         cursor = collection.find(query)
+         if limit > 0:
+             cursor = cursor.limit(limit)
+
+         # Generate embeddings in parallel with cursor-based batching
+         processed = parallel_generate_embeddings(
+             collection=collection,
+             cursor=cursor,
+             field_name=field_name,
+             embedding_field=embedding_field,
+             openai_client=openai_client,
+             total_docs=total_to_process,
+             callback=update_progress
+         )
+
+         # Return completion message and final progress
+         instructions = f"""
+         Successfully generated embeddings for {processed} documents using parallel processing!
+
+         To create the vector search index in MongoDB Atlas:
+         1. Go to your Atlas cluster
+         2. Click on 'Search' tab
+         3. Create an index named 'vector_index' with this configuration:
+         {{
+           "fields": [
+             {{
+               "type": "vector",
+               "path": "{embedding_field}",
+               "numDimensions": 1536,
+               "similarity": "dotProduct"
+             }}
+           ]
+         }}
+
+         You can now use the search tab with:
+         - Field to search: {field_name}
+         - Embedding field: {embedding_field}
+         """
+         return instructions, progress_text
+
+     except Exception as e:
+         return f"Error: {str(e)}", ""
+
+ def vector_search(query_text: str, db_name: str, collection_name: str, embedding_field: str, index_name: str) -> str:
+     """Perform vector search using embeddings"""
+     try:
+         print(f"\nProcessing query: {query_text}")
+
+         db = db_utils.client[db_name]
+         collection = db[collection_name]
+
+         # Get embeddings for query
+         embedding = get_embedding(query_text, openai_client)
+         print("Generated embeddings successfully")
+
+         results = collection.aggregate([
+             {
+                 '$vectorSearch': {
+                     "index": index_name,
+                     "path": embedding_field,
+                     "queryVector": embedding,
+                     "numCandidates": 50,
+                     "limit": 5
+                 }
+             },
+             {
+                 "$project": {
+                     "search_score": { "$meta": "vectorSearchScore" },
+                     "document": "$$ROOT"
+                 }
+             }
+         ])
+
+         # Format results
+         results_list = list(results)
+         formatted_results = []
+
+         for idx, result in enumerate(results_list, 1):
+             doc = result['document']
+             formatted_result = f"{idx}. Score: {result['search_score']:.4f}\n"
+             # Add all fields except _id and embeddings
+             for key, value in doc.items():
+                 if key not in ['_id', embedding_field]:
+                     formatted_result += f"{key}: {value}\n"
+             formatted_results.append(formatted_result)
+
+         return "\n".join(formatted_results) if formatted_results else "No results found"
+
+     except Exception as e:
+         return f"Error: {str(e)}"
+
+ # Create Gradio interface with tabs
+ with gr.Blocks(title="MongoDB Vector Search Tool") as iface:
+     gr.Markdown("# MongoDB Vector Search Tool")
+
+     # Get available databases
+     databases = db_utils.get_databases()
+
+     with gr.Tab("Generate Embeddings"):
+         with gr.Row():
+             db_input = gr.Dropdown(
+                 choices=databases,
+                 label="Select Database",
+                 info="Available databases in Atlas cluster"
+             )
+             collection_input = gr.Dropdown(
+                 choices=[],
+                 label="Select Collection",
+                 info="Collections in selected database"
+             )
+         with gr.Row():
+             field_input = gr.Dropdown(
+                 choices=[],
+                 label="Select Field for Embeddings",
+                 info="Fields available in collection"
+             )
+             embedding_field_input = gr.Textbox(
+                 label="Embedding Field Name",
+                 value="embedding",
+                 info="Field name where embeddings will be stored"
+             )
+             limit_input = gr.Number(
+                 label="Document Limit",
+                 value=10,
+                 minimum=0,
+                 info="Number of documents to process (0 for all documents)"
+             )
+
+         def update_collections(db_name):
+             collections = db_utils.get_collections(db_name)
+             # If there's only one collection, select it by default
+             value = collections[0] if len(collections) == 1 else None
+             return gr.Dropdown(choices=collections, value=value)
+
+         def update_fields(db_name, collection_name):
+             if db_name and collection_name:
+                 fields = get_field_names(db_name, collection_name)
+                 return gr.Dropdown(choices=fields)
+             return gr.Dropdown(choices=[])
+
+         # Update collections when database changes
+         db_input.change(
+             fn=update_collections,
+             inputs=[db_input],
+             outputs=[collection_input]
+         )
+
+         # Update fields when collection changes
+         collection_input.change(
+             fn=update_fields,
+             inputs=[db_input, collection_input],
+             outputs=[field_input]
+         )
+
+         generate_btn = gr.Button("Generate Embeddings")
+         generate_output = gr.Textbox(label="Results", lines=10)
+         progress_output = gr.Textbox(label="Progress", lines=3)
+
+         generate_btn.click(
+             generate_embeddings_for_field,
+             inputs=[db_input, collection_input, field_input, embedding_field_input, limit_input],
+             outputs=[generate_output, progress_output]
+         )
+
+     with gr.Tab("Search"):
+         with gr.Row():
+             search_db_input = gr.Dropdown(
+                 choices=databases,
+                 label="Select Database",
+                 info="Database containing the vectors"
+             )
+             search_collection_input = gr.Dropdown(
+                 choices=[],
+                 label="Select Collection",
+                 info="Collection containing the vectors"
+             )
+         with gr.Row():
+             search_embedding_field_input = gr.Textbox(
+                 label="Embedding Field Name",
+                 value="embedding",
+                 info="Field containing the vectors"
+             )
+             search_index_input = gr.Textbox(
+                 label="Vector Search Index Name",
+                 value="vector_index",
+                 info="Index created in Atlas UI"
+             )
+
+         # Update collections when database changes
+         search_db_input.change(
+             fn=update_collections,
+             inputs=[search_db_input],
+             outputs=[search_collection_input]
+         )
+
+         query_input = gr.Textbox(
+             label="Search Query",
+             lines=2,
+             placeholder="What would you like to search for?"
+         )
+         search_btn = gr.Button("Search")
+         search_output = gr.Textbox(label="Results", lines=10)
+
+         search_btn.click(
+             vector_search,
+             inputs=[
+                 query_input,
+                 search_db_input,
+                 search_collection_input,
+                 search_embedding_field_input,
+                 search_index_input
+             ],
+             outputs=search_output
+         )
+
+ if __name__ == "__main__":
+     iface.launch(server_name="0.0.0.0", server_port=7860)
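Because `vector_search` is a plain function, it can be exercised without the UI once embeddings and the index exist. A hedged sketch (importing app builds the Gradio Blocks and connects to Atlas but does not launch the server; the field and index names below match the README's sample setup):

```python
from app import vector_search

print(vector_search(
    query_text="humans fighting aliens",
    db_name="sample_mflix",
    collection_name="embedded_movies",
    embedding_field="plot_embedding",
    index_name="idx_plot_embedding",
))
```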
db_utils.py ADDED
@@ -0,0 +1,159 @@
+ import os
+ from typing import List, Dict, Any, Optional
+ from pymongo import MongoClient
+ from pymongo.errors import (
+     ConnectionFailure,
+     OperationFailure,
+     ServerSelectionTimeoutError,
+     InvalidName
+ )
+ from dotenv import load_dotenv
+
+ class DatabaseError(Exception):
+     """Base class for database operation errors"""
+     pass
+
+ class ConnectionError(DatabaseError):
+     """Error when connecting to MongoDB Atlas"""
+     pass
+
+ class OperationError(DatabaseError):
+     """Error during database operations"""
+     pass
+
+ class DatabaseUtils:
+     """Utility class for MongoDB Atlas database operations
+
+     This class provides methods to interact with MongoDB Atlas databases and collections,
+     including listing databases, collections, and retrieving collection information.
+
+     Attributes:
+         atlas_uri (str): MongoDB Atlas connection string
+         client (MongoClient): MongoDB client instance
+     """
+
+     def __init__(self):
+         """Initialize DatabaseUtils with MongoDB Atlas connection
+
+         Raises:
+             ConnectionError: If unable to connect to MongoDB Atlas
+             ValueError: If ATLAS_URI environment variable is not set
+         """
+         # Load environment variables
+         load_dotenv()
+
+         self.atlas_uri = os.getenv("ATLAS_URI")
+         if not self.atlas_uri:
+             raise ValueError("ATLAS_URI environment variable is not set")
+
+         try:
+             self.client = MongoClient(self.atlas_uri)
+             # Test connection
+             self.client.admin.command('ping')
+         except (ConnectionFailure, ServerSelectionTimeoutError) as e:
+             raise ConnectionError(f"Failed to connect to MongoDB Atlas: {str(e)}")
+
+     def get_databases(self) -> List[str]:
+         """Get list of all databases in Atlas cluster
+
+         Returns:
+             List[str]: List of database names
+
+         Raises:
+             OperationError: If unable to list databases
+         """
+         try:
+             return self.client.list_database_names()
+         except OperationFailure as e:
+             raise OperationError(f"Failed to list databases: {str(e)}")
+
+     def get_collections(self, db_name: str) -> List[str]:
+         """Get list of collections in a database
+
+         Args:
+             db_name (str): Name of the database
+
+         Returns:
+             List[str]: List of collection names
+
+         Raises:
+             OperationError: If unable to list collections
+             ValueError: If db_name is empty or invalid
+         """
+         if not db_name or not isinstance(db_name, str):
+             raise ValueError("Database name must be a non-empty string")
+
+         try:
+             db = self.client[db_name]
+             return db.list_collection_names()
+         except (OperationFailure, InvalidName) as e:
+             raise OperationError(f"Failed to list collections for database '{db_name}': {str(e)}")
+
+     def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
+         """Get information about a collection including document count and sample document
+
+         Args:
+             db_name (str): Name of the database
+             collection_name (str): Name of the collection
+
+         Returns:
+             Dict[str, Any]: Dictionary containing collection information:
+                 - count: Number of documents in collection
+                 - sample: Sample document from collection (if exists)
+
+         Raises:
+             OperationError: If unable to get collection information
+             ValueError: If db_name or collection_name is empty or invalid
+         """
+         if not db_name or not isinstance(db_name, str):
+             raise ValueError("Database name must be a non-empty string")
+         if not collection_name or not isinstance(collection_name, str):
+             raise ValueError("Collection name must be a non-empty string")
+
+         try:
+             db = self.client[db_name]
+             collection = db[collection_name]
+
+             return {
+                 'count': collection.count_documents({}),
+                 'sample': collection.find_one()
+             }
+         except (OperationFailure, InvalidName) as e:
+             raise OperationError(
+                 f"Failed to get info for collection '{collection_name}' "
+                 f"in database '{db_name}': {str(e)}"
+             )
+
+     def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
+         """Get list of fields in a collection based on sample document
+
+         Args:
+             db_name (str): Name of the database
+             collection_name (str): Name of the collection
+
+         Returns:
+             List[str]: List of field names (excluding _id and embedding fields)
+
+         Raises:
+             OperationError: If unable to get field names
+             ValueError: If db_name or collection_name is empty or invalid
+         """
+         try:
+             info = self.get_collection_info(db_name, collection_name)
+             sample = info.get('sample', {})
+
+             if sample:
+                 # Get all field names except _id and any existing embedding fields
+                 return [field for field in sample.keys()
+                         if field != '_id' and not field.endswith('_embedding')]
+             return []
+         except DatabaseError as e:
+             raise OperationError(
+                 f"Failed to get field names for collection '{collection_name}' "
+                 f"in database '{db_name}': {str(e)}"
+             )
+
+     def close(self):
+         """Close MongoDB connection safely"""
+         if hasattr(self, 'client'):
+             self.client.close()
embedding_utils.py ADDED
@@ -0,0 +1,122 @@
+ from typing import List, Tuple
+ from concurrent.futures import ThreadPoolExecutor
+ from pymongo import UpdateOne
+ from pymongo.collection import Collection
+ import math
+
+ def get_embedding(text: str, openai_client, model="text-embedding-ada-002") -> list[float]:
+     """Get embeddings for given text using OpenAI API"""
+     text = text.replace("\n", " ")
+     resp = openai_client.embeddings.create(
+         input=[text],
+         model=model
+     )
+     return resp.data[0].embedding
+
+ def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
+     """Process a batch of documents to generate embeddings"""
+     results = []
+     for doc in docs:
+         # Skip if embedding already exists
+         if embedding_field in doc:
+             continue
+
+         text = doc[field_name]
+         if isinstance(text, str):
+             embedding = get_embedding(text, openai_client)
+             results.append((doc["_id"], embedding))
+     return results
+
+ def parallel_generate_embeddings(
+     collection: Collection,
+     cursor,
+     field_name: str,
+     embedding_field: str,
+     openai_client,
+     total_docs: int,
+     batch_size: int = 20,
+     callback=None
+ ) -> int:
+     """Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching
+
+     Args:
+         collection: MongoDB collection
+         cursor: MongoDB cursor for document iteration
+         field_name: Field containing text to embed
+         embedding_field: Field to store embeddings
+         openai_client: OpenAI client instance
+         total_docs: Total number of documents to process
+         batch_size: Size of batches for parallel processing
+         callback: Optional callback function for progress updates
+
+     Returns:
+         Number of documents processed
+     """
+     if total_docs == 0:
+         return 0
+
+     processed = 0
+
+     # Initial progress update
+     if callback:
+         callback(0, 0, total_docs)
+
+     # Process documents in batches using cursor
+     with ThreadPoolExecutor(max_workers=20) as executor:
+         batch = []
+         futures = []
+
+         # Iterate through cursor and build batches
+         for doc in cursor:
+             batch.append(doc)
+
+             if len(batch) >= batch_size:
+                 # Submit batch for processing
+                 future = executor.submit(process_batch, batch.copy(), field_name, embedding_field, openai_client)
+                 futures.append(future)
+                 batch = []  # Clear batch for next round
+
+                 # Process completed futures to free up memory
+                 completed_futures = [f for f in futures if f.done()]
+                 for future in completed_futures:
+                     results = future.result()
+                     if results:
+                         # Batch update MongoDB
+                         bulk_ops = [
+                             UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
+                             for doc_id, embedding in results
+                         ]
+                         if bulk_ops:
+                             collection.bulk_write(bulk_ops)
+                             processed += len(bulk_ops)
+
+                         # Update progress
+                         if callback:
+                             progress = (processed / total_docs) * 100
+                             callback(progress, processed, total_docs)
+
+                 futures = [f for f in futures if not f.done()]
+
+         # Process any remaining documents in the last batch
+         if batch:
+             future = executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
+             futures.append(future)
+
+         # Wait for remaining futures to complete
+         for future in futures:
+             results = future.result()
+             if results:
+                 bulk_ops = [
+                     UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
+                     for doc_id, embedding in results
+                 ]
+                 if bulk_ops:
+                     collection.bulk_write(bulk_ops)
+                     processed += len(bulk_ops)
+
+     # Final progress update
+     if callback:
+         progress = (processed / total_docs) * 100
+         callback(progress, processed, total_docs)
+
+     return processed
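For reference, a minimal standalone driver for these helpers outside the Gradio UI. This is a sketch assuming OPENAI_API_KEY and ATLAS_URI are set; the filter mirrors the one app.py builds before calling `parallel_generate_embeddings`:

```python
import os
from pymongo import MongoClient
from openai import OpenAI
from embedding_utils import parallel_generate_embeddings

client = MongoClient(os.environ["ATLAS_URI"])
collection = client["sample_mflix"]["embedded_movies"]
openai_client = OpenAI()

# Only documents that still lack an embedding, mirroring app.py's query
query = {"plot": {"$exists": True}, "plot_embedding": {"$exists": False}}
total = collection.count_documents(query)

processed = parallel_generate_embeddings(
    collection=collection,
    cursor=collection.find(query),
    field_name="plot",
    embedding_field="plot_embedding",
    openai_client=openai_client,
    total_docs=total,
    callback=lambda pct, done, count: print(f"{pct:.1f}% ({done}/{count})"),
)
print(f"Embedded {processed} documents")
client.close()
```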
list_collections.py ADDED
@@ -0,0 +1,38 @@
+ import os
+ from pymongo import MongoClient
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # Initialize MongoDB client
+ atlas_uri = os.getenv("ATLAS_URI")
+ client = MongoClient(atlas_uri)
+
+ def list_all_collections():
+     """List all databases and their collections in the Atlas cluster"""
+     try:
+         # Get all database names
+         db_names = client.list_database_names()
+
+         print("\nDatabases and Collections in your Atlas cluster:\n")
+
+         # For each database, get and print collections
+         for db_name in db_names:
+             print(f"Database: {db_name}")
+             db = client[db_name]
+             collections = db.list_collection_names()
+
+             for collection in collections:
+                 # Get count of documents in collection
+                 count = db[collection].count_documents({})
+                 print(f"  └── Collection: {collection} ({count} documents)")
+             print()
+
+     except Exception as e:
+         print(f"Error: {str(e)}")
+     finally:
+         client.close()
+
+ if __name__ == "__main__":
+     list_all_collections()
list_db.py ADDED
@@ -0,0 +1,27 @@
+ from db_utils import DatabaseUtils
+
+ def main():
+     db_utils = DatabaseUtils()
+     try:
+         print("\nDatabases and Collections in your Atlas cluster:\n")
+
+         # Get all databases
+         databases = db_utils.get_databases()
+
+         # For each database, show collections and counts
+         for db_name in databases:
+             print(f"Database: {db_name}")
+             collections = db_utils.get_collections(db_name)
+
+             for coll_name in collections:
+                 info = db_utils.get_collection_info(db_name, coll_name)
+                 print(f"  └── Collection: {coll_name} ({info['count']} documents)")
+             print()
+
+     except Exception as e:
+         print(f"Error: {str(e)}")
+     finally:
+         db_utils.close()
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,57 @@
+ aiofiles==23.2.1
+ annotated-types==0.7.0
+ anyio==4.8.0
+ certifi==2024.12.14
+ charset-normalizer==3.4.1
+ click==8.1.8
+ distro==1.9.0
+ dnspython==2.7.0
+ fastapi==0.115.7
+ ffmpy==0.5.0
+ filelock==3.17.0
+ fsspec==2024.12.0
+ gradio==5.13.1
+ gradio_client==1.6.0
+ h11==0.14.0
+ httpcore==1.0.7
+ httpx==0.28.1
+ huggingface-hub==0.27.1
+ idna==3.10
+ Jinja2==3.1.5
+ jiter==0.8.2
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ mdurl==0.1.2
+ numpy==2.2.2
+ openai==1.60.1
+ orjson==3.10.15
+ packaging==24.2
+ pandas==2.2.3
+ pillow==11.1.0
+ pydantic==2.10.6
+ pydantic_core==2.27.2
+ pydub==0.25.1
+ Pygments==2.19.1
+ pymongo==4.10.1
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ python-multipart==0.0.20
+ pytz==2024.2
+ PyYAML==6.0.2
+ requests==2.32.3
+ rich==13.9.4
+ ruff==0.9.3
+ safehttpx==0.1.6
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ starlette==0.45.3
+ tomlkit==0.13.2
+ tqdm==4.67.1
+ typer==0.15.1
+ typing_extensions==4.12.2
+ tzdata==2025.1
+ urllib3==2.3.0
+ uvicorn==0.34.0
+ websockets==14.2
setup_data.py ADDED
@@ -0,0 +1,94 @@
+ import os
+ from pymongo import MongoClient
+ from openai import OpenAI
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # Initialize clients
+ openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+ atlas_uri = os.getenv("ATLAS_URI")
+ client = MongoClient(atlas_uri)
+ db = client['sample_mflix']
+ collection = db['embedded_movies']
+
+ # Sample movie data
+ sample_movies = [
+     {
+         "title": "The Matrix",
+         "year": 1999,
+         "plot": "A computer programmer discovers that reality as he knows it is a simulation created by machines, and joins a rebellion to overthrow them."
+     },
+     {
+         "title": "Inception",
+         "year": 2010,
+         "plot": "A thief who enters the dreams of others to steal secrets from their subconscious is offered a chance to regain his old life in exchange for a task considered impossible: inception."
+     },
+     {
+         "title": "The Shawshank Redemption",
+         "year": 1994,
+         "plot": "Two imprisoned men bond over a number of years, finding solace and eventual redemption through acts of common decency."
+     },
+     {
+         "title": "Jurassic Park",
+         "year": 1993,
+         "plot": "A pragmatic paleontologist visiting an almost complete theme park is tasked with protecting a couple of kids after a power failure causes the park's cloned dinosaurs to run loose."
+     },
+     {
+         "title": "The Lord of the Rings: The Fellowship of the Ring",
+         "year": 2001,
+         "plot": "A young hobbit, Frodo, who has found the One Ring that belongs to the Dark Lord Sauron, begins his journey with eight companions to Mount Doom, the only place where it can be destroyed."
+     }
+ ]
+
+ def get_embedding(text: str, model="text-embedding-ada-002") -> list[float]:
+     """Get embeddings for given text using OpenAI API"""
+     text = text.replace("\n", " ")
+     resp = openai_client.embeddings.create(
+         input=[text],
+         model=model
+     )
+     return resp.data[0].embedding
+
+ def setup_data():
+     try:
+         # Drop existing collection if it exists
+         collection.drop()
+         print("Dropped existing collection")
+
+         # Add embeddings to movies and insert into collection
+         for movie in sample_movies:
+             # Generate embedding for plot
+             embedding = get_embedding(movie["plot"])
+             movie["plot_embedding"] = embedding
+
+             # Insert movie with embedding
+             collection.insert_one(movie)
+             print(f"Inserted movie: {movie['title']}")
+
+         print("\nData setup completed successfully!")
+         print("\nIMPORTANT: You need to create the vector search index manually in the Atlas UI:")
+         print("1. Go to your Atlas cluster")
+         print("2. Click on 'Search' tab")
+         print("3. Create an index named 'idx_plot_embedding' with this definition:")
+         print("""
+         {
+           "fields": [
+             {
+               "type": "vector",
+               "path": "plot_embedding",
+               "numDimensions": 1536,
+               "similarity": "dotProduct"
+             }
+           ]
+         }
+         """)
+
+     except Exception as e:
+         print(f"Error during setup: {str(e)}")
+     finally:
+         client.close()
+
+ if __name__ == "__main__":
+     setup_data()
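After running `python setup_data.py` and creating the index, a quick sanity check that the five sample documents carry 1536-dimensional embeddings. A minimal sketch, assuming ATLAS_URI is set and the names above are unchanged:

```python
import os
from pymongo import MongoClient

client = MongoClient(os.environ["ATLAS_URI"])
collection = client["sample_mflix"]["embedded_movies"]

# Should report 5 documents, each with a 1536-element plot_embedding
print(collection.count_documents({"plot_embedding": {"$exists": True}}))
doc = collection.find_one({}, {"title": 1, "plot_embedding": 1})
print(doc["title"], len(doc["plot_embedding"]))
client.close()
```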