Upload 9 files

- Changelog.md +2 -0
- README.md +96 -12
- app.py +276 -0
- db_utils.py +159 -0
- embedding_utils.py +122 -0
- list_collections.py +38 -0
- list_db.py +27 -0
- requirements.txt +57 -0
- setup_data.py +94 -0
Changelog.md
ADDED
@@ -0,0 +1,2 @@
- 2025-01-26:
- 2025-01-26:
README.md
CHANGED
@@ -1,12 +1,96 @@
# Vector Search Demo App

This is a Gradio web application that demonstrates vector search capabilities using MongoDB Atlas and OpenAI embeddings.

## Prerequisites

1. MongoDB Atlas account with vector search enabled
2. OpenAI API key
3. Python 3.8+
4. Sample movie data loaded in MongoDB Atlas (sample_mflix database)

## Setup

1. Clone this repository

2. Install dependencies:

```bash
pip install -r requirements.txt
```

3. Set up environment variables:

```bash
export OPENAI_API_KEY="your-openai-api-key"
export ATLAS_URI="your-mongodb-atlas-connection-string"
```
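
Because the app calls `load_dotenv()`, you can also keep these values in a local `.env` file instead of exporting them. To confirm the connection string works before launching the UI, you can run a quick ping — a minimal sketch (not part of this commit) that mirrors the startup check in `db_utils.py`:

```python
# ping_atlas.py — connectivity smoke test (sketch; assumes ATLAS_URI is set)
import os
from dotenv import load_dotenv
from pymongo import MongoClient

load_dotenv()  # picks up ATLAS_URI from the environment or a .env file
client = MongoClient(os.environ["ATLAS_URI"])
client.admin.command("ping")  # raises on connection failure
print("Connected to Atlas")
client.close()
```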

4. Ensure your MongoDB Atlas setup:
   - Database name: sample_mflix
   - Collection: embedded_movies
   - Vector search index: idx_plot_embedding
   - Index configuration:

```json
{
  "fields": [
    {
      "type": "vector",
      "path": "plot_embedding",
      "numDimensions": 1536,
      "similarity": "dotProduct"
    }
  ]
}
```
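
If you prefer to create the index from code rather than in the Atlas UI, recent pymongo releases (4.7+; requirements.txt pins 4.10.1) support Atlas vector search index creation. A hedged sketch, not part of this commit:

```python
# create_index.py — programmatic alternative to the Atlas UI (sketch)
import os
from dotenv import load_dotenv
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel

load_dotenv()
collection = MongoClient(os.environ["ATLAS_URI"])["sample_mflix"]["embedded_movies"]

index = SearchIndexModel(
    definition={
        "fields": [
            {
                "type": "vector",
                "path": "plot_embedding",
                "numDimensions": 1536,
                "similarity": "dotProduct",
            }
        ]
    },
    name="idx_plot_embedding",
    type="vectorSearch",  # requires pymongo >= 4.7 and an Atlas cluster
)
print(collection.create_search_index(model=index))  # prints the index name
```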

## Running the App

Start the application:

```bash
python app.py
```

The app will be available at http://localhost:7860

## Usage

### Generating Embeddings

1. Select your database and collection from the dropdowns
2. Choose the field to generate embeddings for
3. Specify the embedding field name (defaults to "embedding")
4. Set a document limit (0 for all documents)
5. Click "Generate Embeddings" to start processing

The app uses memory-efficient, cursor-based batch processing (implemented in embedding_utils.py below) that can handle large collections:

- Documents are processed in batches (default: 20 documents per batch)
- Memory usage stays low because documents are streamed through a cursor rather than loaded all at once
- Real-time progress tracking shows completed/total documents
- Large collections (100,000+ documents) are supported
- Runs resume where they left off: documents that already have embeddings are skipped

### Searching

1. Enter a natural language query in the text box (e.g., "humans fighting aliens")
2. Click "Submit" to search
3. View the results showing matching documents with their similarity scores

## Example Queries

- "humans fighting aliens"
- "relationship drama between two good friends"
- "comedy about family vacation"
- "detective solving mysterious murder"

## Performance Notes

The application is optimized for handling large datasets:

- Uses cursor-based batch processing to avoid memory issues
- Processes documents in configurable batch sizes (default: 20)
- Generates embeddings in parallel with a ThreadPoolExecutor
- Provides real-time progress tracking
- Frees completed batches during processing to keep memory bounded
- Supports resuming interrupted operations

## Notes

- The search uses OpenAI's text-embedding-ada-002 model to create embeddings for both documents and queries
- Results are limited to the top 5 matches
- Similarity scores range from 0 to 1, with higher scores indicating better matches
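
For reference, the same search the Search tab performs can be run without the UI. A minimal sketch (not part of this commit) that reuses the helpers added below; the database, collection, and index names are the defaults from this README:

```python
# standalone_search.py — sketch reusing db_utils.py and embedding_utils.py
from openai import OpenAI
from db_utils import DatabaseUtils
from embedding_utils import get_embedding

db_utils = DatabaseUtils()
openai_client = OpenAI()

# defaults from this README; change to match your data
collection = db_utils.client["sample_mflix"]["embedded_movies"]
query_vector = get_embedding("humans fighting aliens", openai_client)

pipeline = [
    {"$vectorSearch": {
        "index": "idx_plot_embedding",
        "path": "plot_embedding",
        "queryVector": query_vector,
        "numCandidates": 50,
        "limit": 5,
    }},
    {"$project": {"title": 1, "plot": 1, "_id": 0,
                  "score": {"$meta": "vectorSearchScore"}}},
]
for doc in collection.aggregate(pipeline):
    print(f"{doc['score']:.4f}  {doc['title']}")

db_utils.close()
```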
app.py
ADDED
@@ -0,0 +1,276 @@
import os
import gradio as gr
from openai import OpenAI
from dotenv import load_dotenv
from db_utils import DatabaseUtils
from embedding_utils import parallel_generate_embeddings, get_embedding

# Load environment variables from .env file
load_dotenv()

# Initialize OpenAI client
openai_client = OpenAI()

# Initialize database utils
db_utils = DatabaseUtils()

def get_field_names(db_name: str, collection_name: str) -> list[str]:
    """Get list of fields in the collection"""
    return db_utils.get_field_names(db_name, collection_name)

def generate_embeddings_for_field(db_name: str, collection_name: str, field_name: str, embedding_field: str, limit: int = 10, progress=gr.Progress()) -> tuple[str, str]:
    """Generate embeddings for documents in parallel with progress tracking"""
    try:
        db = db_utils.client[db_name]
        collection = db[collection_name]

        # Count documents that have the source field
        total_docs = collection.count_documents({field_name: {"$exists": True}})
        if total_docs == 0:
            return f"No documents found with field '{field_name}'", ""

        # Count documents that still need embeddings
        query = {
            field_name: {"$exists": True},
            embedding_field: {"$exists": False}  # Only get docs without embeddings
        }
        total_to_process = collection.count_documents(query)
        if total_to_process == 0:
            return "No documents found that need embeddings", ""

        # Apply limit if specified
        if limit > 0:
            total_to_process = min(total_to_process, limit)

        print(f"\nFound {total_to_process} documents that need embeddings...")

        # Progress tracking
        progress_text = ""
        def update_progress(prog: float, processed: int, total: int):
            nonlocal progress_text
            progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
            print(progress_text)  # Terminal logging
            progress(prog / 100, f"Processed {processed}/{total} documents")

        # Show initial progress
        update_progress(0, 0, total_to_process)

        # Create cursor for batch processing
        cursor = collection.find(query)
        if limit > 0:
            cursor = cursor.limit(limit)

        # Generate embeddings in parallel with cursor-based batching
        processed = parallel_generate_embeddings(
            collection=collection,
            cursor=cursor,
            field_name=field_name,
            embedding_field=embedding_field,
            openai_client=openai_client,
            total_docs=total_to_process,
            callback=update_progress
        )

        # Return completion message and final progress
        instructions = f"""
Successfully generated embeddings for {processed} documents using parallel processing!

To create the vector search index in MongoDB Atlas:
1. Go to your Atlas cluster
2. Click on 'Search' tab
3. Create an index named 'vector_index' with this configuration:
{{
    "fields": [
        {{
            "type": "vector",
            "path": "{embedding_field}",
            "numDimensions": 1536,
            "similarity": "dotProduct"
        }}
    ]
}}

You can now use the search tab with:
- Field to search: {field_name}
- Embedding field: {embedding_field}
"""
        return instructions, progress_text

    except Exception as e:
        return f"Error: {str(e)}", ""

def vector_search(query_text: str, db_name: str, collection_name: str, embedding_field: str, index_name: str) -> str:
    """Perform vector search using embeddings"""
    try:
        print(f"\nProcessing query: {query_text}")

        db = db_utils.client[db_name]
        collection = db[collection_name]

        # Get embedding for the query text
        embedding = get_embedding(query_text, openai_client)
        print("Generated embeddings successfully")

        results = collection.aggregate([
            {
                '$vectorSearch': {
                    "index": index_name,
                    "path": embedding_field,
                    "queryVector": embedding,
                    "numCandidates": 50,
                    "limit": 5
                }
            },
            {
                "$project": {
                    "search_score": {"$meta": "vectorSearchScore"},
                    "document": "$$ROOT"
                }
            }
        ])

        # Format results
        results_list = list(results)
        formatted_results = []

        for idx, result in enumerate(results_list, 1):
            doc = result['document']
            formatted_result = f"{idx}. Score: {result['search_score']:.4f}\n"
            # Add all fields except _id and embeddings
            for key, value in doc.items():
                if key not in ['_id', embedding_field]:
                    formatted_result += f"{key}: {value}\n"
            formatted_results.append(formatted_result)

        return "\n".join(formatted_results) if formatted_results else "No results found"

    except Exception as e:
        return f"Error: {str(e)}"

# Create Gradio interface with tabs
with gr.Blocks(title="MongoDB Vector Search Tool") as iface:
    gr.Markdown("# MongoDB Vector Search Tool")

    # Get available databases
    databases = db_utils.get_databases()

    with gr.Tab("Generate Embeddings"):
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Available databases in Atlas cluster"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collections in selected database"
            )
        with gr.Row():
            field_input = gr.Dropdown(
                choices=[],
                label="Select Field for Embeddings",
                info="Fields available in collection"
            )
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field name where embeddings will be stored"
            )
            limit_input = gr.Number(
                label="Document Limit",
                value=10,
                minimum=0,
                info="Number of documents to process (0 for all documents)"
            )

        def update_collections(db_name):
            collections = db_utils.get_collections(db_name)
            # If there's only one collection, select it by default
            value = collections[0] if len(collections) == 1 else None
            return gr.Dropdown(choices=collections, value=value)

        def update_fields(db_name, collection_name):
            if db_name and collection_name:
                fields = get_field_names(db_name, collection_name)
                return gr.Dropdown(choices=fields)
            return gr.Dropdown(choices=[])

        # Update collections when database changes
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )

        # Update fields when collection changes
        collection_input.change(
            fn=update_fields,
            inputs=[db_input, collection_input],
            outputs=[field_input]
        )

        generate_btn = gr.Button("Generate Embeddings")
        generate_output = gr.Textbox(label="Results", lines=10)
        progress_output = gr.Textbox(label="Progress", lines=3)

        generate_btn.click(
            generate_embeddings_for_field,
            inputs=[db_input, collection_input, field_input, embedding_field_input, limit_input],
            outputs=[generate_output, progress_output]
        )

    with gr.Tab("Search"):
        with gr.Row():
            search_db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Database containing the vectors"
            )
            search_collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collection containing the vectors"
            )
        with gr.Row():
            search_embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field containing the vectors"
            )
            search_index_input = gr.Textbox(
                label="Vector Search Index Name",
                value="vector_index",
                info="Index created in Atlas UI"
            )

        # Update collections when database changes
        search_db_input.change(
            fn=update_collections,
            inputs=[search_db_input],
            outputs=[search_collection_input]
        )

        query_input = gr.Textbox(
            label="Search Query",
            lines=2,
            placeholder="What would you like to search for?"
        )
        search_btn = gr.Button("Search")
        search_output = gr.Textbox(label="Results", lines=10)

        search_btn.click(
            vector_search,
            inputs=[
                query_input,
                search_db_input,
                search_collection_input,
                search_embedding_field_input,
                search_index_input
            ],
            outputs=search_output
        )

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)
db_utils.py
ADDED
@@ -0,0 +1,159 @@
import os
from typing import List, Dict, Any
from pymongo import MongoClient
from pymongo.errors import (
    ConnectionFailure,
    OperationFailure,
    ServerSelectionTimeoutError,
    InvalidName
)
from dotenv import load_dotenv

class DatabaseError(Exception):
    """Base class for database operation errors"""
    pass

class ConnectionError(DatabaseError):
    """Error when connecting to MongoDB Atlas"""
    pass

class OperationError(DatabaseError):
    """Error during database operations"""
    pass

class DatabaseUtils:
    """Utility class for MongoDB Atlas database operations

    This class provides methods to interact with MongoDB Atlas databases and collections,
    including listing databases, collections, and retrieving collection information.

    Attributes:
        atlas_uri (str): MongoDB Atlas connection string
        client (MongoClient): MongoDB client instance
    """

    def __init__(self):
        """Initialize DatabaseUtils with MongoDB Atlas connection

        Raises:
            ConnectionError: If unable to connect to MongoDB Atlas
            ValueError: If ATLAS_URI environment variable is not set
        """
        # Load environment variables
        load_dotenv()

        self.atlas_uri = os.getenv("ATLAS_URI")
        if not self.atlas_uri:
            raise ValueError("ATLAS_URI environment variable is not set")

        try:
            self.client = MongoClient(self.atlas_uri)
            # Test connection
            self.client.admin.command('ping')
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            raise ConnectionError(f"Failed to connect to MongoDB Atlas: {str(e)}")

    def get_databases(self) -> List[str]:
        """Get list of all databases in Atlas cluster

        Returns:
            List[str]: List of database names

        Raises:
            OperationError: If unable to list databases
        """
        try:
            return self.client.list_database_names()
        except OperationFailure as e:
            raise OperationError(f"Failed to list databases: {str(e)}")

    def get_collections(self, db_name: str) -> List[str]:
        """Get list of collections in a database

        Args:
            db_name (str): Name of the database

        Returns:
            List[str]: List of collection names

        Raises:
            OperationError: If unable to list collections
            ValueError: If db_name is empty or invalid
        """
        if not db_name or not isinstance(db_name, str):
            raise ValueError("Database name must be a non-empty string")

        try:
            db = self.client[db_name]
            return db.list_collection_names()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(f"Failed to list collections for database '{db_name}': {str(e)}")

    def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
        """Get information about a collection including document count and sample document

        Args:
            db_name (str): Name of the database
            collection_name (str): Name of the collection

        Returns:
            Dict[str, Any]: Dictionary containing collection information:
                - count: Number of documents in collection
                - sample: Sample document from collection (if exists)

        Raises:
            OperationError: If unable to get collection information
            ValueError: If db_name or collection_name is empty or invalid
        """
        if not db_name or not isinstance(db_name, str):
            raise ValueError("Database name must be a non-empty string")
        if not collection_name or not isinstance(collection_name, str):
            raise ValueError("Collection name must be a non-empty string")

        try:
            db = self.client[db_name]
            collection = db[collection_name]

            return {
                'count': collection.count_documents({}),
                'sample': collection.find_one()
            }
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get info for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            )

    def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
        """Get list of fields in a collection based on sample document

        Args:
            db_name (str): Name of the database
            collection_name (str): Name of the collection

        Returns:
            List[str]: List of field names (excluding _id and embedding fields)

        Raises:
            OperationError: If unable to get field names
            ValueError: If db_name or collection_name is empty or invalid
        """
        try:
            info = self.get_collection_info(db_name, collection_name)
            sample = info.get('sample', {})

            if sample:
                # Get all field names except _id and any existing embedding fields
                return [field for field in sample.keys()
                        if field != '_id' and not field.endswith('_embedding')]
            return []
        except DatabaseError as e:
            raise OperationError(
                f"Failed to get field names for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            )

    def close(self):
        """Close MongoDB connection safely"""
        if hasattr(self, 'client'):
            self.client.close()
embedding_utils.py
ADDED
@@ -0,0 +1,122 @@
from typing import List, Tuple
from concurrent.futures import ThreadPoolExecutor
from pymongo import UpdateOne
from pymongo.collection import Collection

def get_embedding(text: str, openai_client, model="text-embedding-ada-002") -> list[float]:
    """Get embeddings for given text using OpenAI API"""
    text = text.replace("\n", " ")
    resp = openai_client.embeddings.create(
        input=[text],
        model=model
    )
    return resp.data[0].embedding

def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Process a batch of documents to generate embeddings"""
    results = []
    for doc in docs:
        # Skip if embedding already exists
        if embedding_field in doc:
            continue

        text = doc[field_name]
        if isinstance(text, str):
            embedding = get_embedding(text, openai_client)
            results.append((doc["_id"], embedding))
    return results

def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 20,
    callback=None
) -> int:
    """Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching

    Args:
        collection: MongoDB collection
        cursor: MongoDB cursor for document iteration
        field_name: Field containing text to embed
        embedding_field: Field to store embeddings
        openai_client: OpenAI client instance
        total_docs: Total number of documents to process
        batch_size: Size of batches for parallel processing
        callback: Optional callback function for progress updates

    Returns:
        Number of documents processed
    """
    if total_docs == 0:
        return 0

    processed = 0

    # Initial progress update
    if callback:
        callback(0, 0, total_docs)

    # Process documents in batches using the cursor
    with ThreadPoolExecutor(max_workers=20) as executor:
        batch = []
        futures = []

        # Iterate through cursor and build batches
        for doc in cursor:
            batch.append(doc)

            if len(batch) >= batch_size:
                # Submit batch for processing
                future = executor.submit(process_batch, batch.copy(), field_name, embedding_field, openai_client)
                futures.append(future)
                batch = []  # Clear batch for next round

                # Process completed futures as we go to free up memory
                completed_futures = [f for f in futures if f.done()]
                for future in completed_futures:
                    results = future.result()
                    if results:
                        # Batch update MongoDB
                        bulk_ops = [
                            UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
                            for doc_id, embedding in results
                        ]
                        if bulk_ops:
                            collection.bulk_write(bulk_ops)
                            processed += len(bulk_ops)

                    # Update progress
                    if callback:
                        progress = (processed / total_docs) * 100
                        callback(progress, processed, total_docs)

                futures = [f for f in futures if not f.done()]

        # Process any remaining documents in the last batch
        if batch:
            future = executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
            futures.append(future)

        # Wait for remaining futures to complete
        for future in futures:
            results = future.result()
            if results:
                bulk_ops = [
                    UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
                    for doc_id, embedding in results
                ]
                if bulk_ops:
                    collection.bulk_write(bulk_ops)
                    processed += len(bulk_ops)

    # Final progress update
    if callback:
        progress = (processed / total_docs) * 100
        callback(progress, processed, total_docs)

    return processed
list_collections.py
ADDED
@@ -0,0 +1,38 @@
import os
from pymongo import MongoClient
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Initialize MongoDB client
atlas_uri = os.getenv("ATLAS_URI")
client = MongoClient(atlas_uri)

def list_all_collections():
    """List all databases and their collections in the Atlas cluster"""
    try:
        # Get all database names
        db_names = client.list_database_names()

        print("\nDatabases and Collections in your Atlas cluster:\n")

        # For each database, get and print collections
        for db_name in db_names:
            print(f"Database: {db_name}")
            db = client[db_name]
            collections = db.list_collection_names()

            for collection in collections:
                # Get count of documents in collection
                count = db[collection].count_documents({})
                print(f"  └── Collection: {collection} ({count} documents)")
            print()

    except Exception as e:
        print(f"Error: {str(e)}")
    finally:
        client.close()

if __name__ == "__main__":
    list_all_collections()
list_db.py
ADDED
@@ -0,0 +1,27 @@
from db_utils import DatabaseUtils

def main():
    db_utils = DatabaseUtils()
    try:
        print("\nDatabases and Collections in your Atlas cluster:\n")

        # Get all databases
        databases = db_utils.get_databases()

        # For each database, show collections and counts
        for db_name in databases:
            print(f"Database: {db_name}")
            collections = db_utils.get_collections(db_name)

            for coll_name in collections:
                info = db_utils.get_collection_info(db_name, coll_name)
                print(f"  └── Collection: {coll_name} ({info['count']} documents)")
            print()

    except Exception as e:
        print(f"Error: {str(e)}")
    finally:
        db_utils.close()

if __name__ == "__main__":
    main()
requirements.txt
ADDED
@@ -0,0 +1,57 @@
aiofiles==23.2.1
annotated-types==0.7.0
anyio==4.8.0
certifi==2024.12.14
charset-normalizer==3.4.1
click==8.1.8
distro==1.9.0
dnspython==2.7.0
fastapi==0.115.7
ffmpy==0.5.0
filelock==3.17.0
fsspec==2024.12.0
gradio==5.13.1
gradio_client==1.6.0
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
huggingface-hub==0.27.1
idna==3.10
Jinja2==3.1.5
jiter==0.8.2
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
numpy==2.2.2
openai==1.60.1
orjson==3.10.15
packaging==24.2
pandas==2.2.3
pillow==11.1.0
pydantic==2.10.6
pydantic_core==2.27.2
pydub==0.25.1
Pygments==2.19.1
pymongo==4.10.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.20
pytz==2024.2
PyYAML==6.0.2
requests==2.32.3
rich==13.9.4
ruff==0.9.3
safehttpx==0.1.6
semantic-version==2.10.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
starlette==0.45.3
tomlkit==0.13.2
tqdm==4.67.1
typer==0.15.1
typing_extensions==4.12.2
tzdata==2025.1
urllib3==2.3.0
uvicorn==0.34.0
websockets==14.2
setup_data.py
ADDED
@@ -0,0 +1,94 @@
import os
from pymongo import MongoClient
from openai import OpenAI
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Initialize clients
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
atlas_uri = os.getenv("ATLAS_URI")
client = MongoClient(atlas_uri)
db = client['sample_mflix']
collection = db['embedded_movies']

# Sample movie data
sample_movies = [
    {
        "title": "The Matrix",
        "year": 1999,
        "plot": "A computer programmer discovers that reality as he knows it is a simulation created by machines, and joins a rebellion to overthrow them."
    },
    {
        "title": "Inception",
        "year": 2010,
        "plot": "A thief who enters the dreams of others to steal secrets from their subconscious is offered a chance to regain his old life in exchange for a task considered impossible: inception."
    },
    {
        "title": "The Shawshank Redemption",
        "year": 1994,
        "plot": "Two imprisoned men bond over a number of years, finding solace and eventual redemption through acts of common decency."
    },
    {
        "title": "Jurassic Park",
        "year": 1993,
        "plot": "A pragmatic paleontologist visiting an almost complete theme park is tasked with protecting a couple of kids after a power failure causes the park's cloned dinosaurs to run loose."
    },
    {
        "title": "The Lord of the Rings: The Fellowship of the Ring",
        "year": 2001,
        "plot": "A young hobbit, Frodo, who has found the One Ring that belongs to the Dark Lord Sauron, begins his journey with eight companions to Mount Doom, the only place where it can be destroyed."
    }
]

def get_embedding(text: str, model="text-embedding-ada-002") -> list[float]:
    """Get embeddings for given text using OpenAI API"""
    text = text.replace("\n", " ")
    resp = openai_client.embeddings.create(
        input=[text],
        model=model
    )
    return resp.data[0].embedding

def setup_data():
    try:
        # Drop existing collection if it exists
        collection.drop()
        print("Dropped existing collection")

        # Add embeddings to movies and insert into collection
        for movie in sample_movies:
            # Generate embedding for plot
            embedding = get_embedding(movie["plot"])
            movie["plot_embedding"] = embedding

            # Insert movie with embedding
            collection.insert_one(movie)
            print(f"Inserted movie: {movie['title']}")

        print("\nData setup completed successfully!")
        print("\nIMPORTANT: You need to create the vector search index manually in the Atlas UI:")
        print("1. Go to your Atlas cluster")
        print("2. Click on 'Search' tab")
        print("3. Create an index named 'idx_plot_embedding' with this definition:")
        print("""
{
    "fields": [
        {
            "type": "vector",
            "path": "plot_embedding",
            "numDimensions": 1536,
            "similarity": "dotProduct"
        }
    ]
}
        """)

    except Exception as e:
        print(f"Error during setup: {str(e)}")
    finally:
        client.close()

if __name__ == "__main__":
    setup_data()