airabbitX commited on
Commit
46a6768
·
verified ·
1 Parent(s): 5df473f

Upload 19 files

Browse files
app.py CHANGED
@@ -1,319 +1,60 @@
1
- import os
2
  import gradio as gr
3
- from openai import OpenAI
4
- import json
5
- from dotenv import load_dotenv
6
- from db_utils import DatabaseUtils
7
- from embedding_utils import parallel_generate_embeddings, get_embedding
8
 
9
- # Load environment variables from .env file
10
- load_dotenv()
11
-
12
- def check_credentials() -> tuple[bool, str]:
13
- """Check if required credentials are set and valid"""
14
- atlas_uri = os.getenv("ATLAS_URI")
15
- openai_key = os.getenv("OPENAI_API_KEY")
16
-
17
- if not atlas_uri:
18
- return False, """Please set up your MongoDB Atlas credentials:
19
- 1. Go to Settings tab
20
- 2. Add ATLAS_URI as a Repository Secret
21
- 3. Paste your MongoDB connection string (should start with 'mongodb+srv://')"""
22
-
23
- if not openai_key:
24
- return False, """Please set up your OpenAI API key:
25
- 1. Go to Settings tab
26
- 2. Add OPENAI_API_KEY as a Repository Secret
27
- 3. Paste your OpenAI API key"""
28
-
29
- return True, ""
30
-
31
- def init_clients():
32
- """Initialize OpenAI and MongoDB clients"""
33
- try:
34
- openai_client = OpenAI()
35
- db_utils = DatabaseUtils()
36
- return openai_client, db_utils
37
- except Exception as e:
38
- return None, None
39
-
40
- def get_field_names(db_name: str, collection_name: str) -> list[str]:
41
- """Get list of fields in the collection"""
42
- return db_utils.get_field_names(db_name, collection_name)
43
-
44
- def generate_embeddings_for_field(db_name: str, collection_name: str, field_name: str, embedding_field: str, limit: int = 10, progress=gr.Progress()) -> tuple[str, str]:
45
- """Generate embeddings for documents in parallel with progress tracking"""
46
- try:
47
- db = db_utils.client[db_name]
48
- collection = db[collection_name]
49
-
50
- # Count documents that need embeddings
51
- total_docs = collection.count_documents({field_name: {"$exists": True}})
52
- if total_docs == 0:
53
- return f"No documents found with field '{field_name}'", ""
54
-
55
- # Get total count of documents that need processing
56
- query = {
57
- field_name: {"$exists": True},
58
- embedding_field: {"$exists": False} # Only get docs without embeddings
59
- }
60
- total_to_process = collection.count_documents(query)
61
- if total_to_process == 0:
62
- return "No documents found that need embeddings", ""
63
-
64
- # Apply limit if specified
65
- if limit > 0:
66
- total_to_process = min(total_to_process, limit)
67
-
68
- print(f"\nFound {total_to_process} documents that need embeddings...")
69
-
70
- # Progress tracking
71
- progress_text = ""
72
- def update_progress(prog: float, processed: int, total: int):
73
- nonlocal progress_text
74
- progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
75
- print(progress_text) # Terminal logging
76
- progress(prog/100, f"Processed {processed}/{total} documents")
77
-
78
- # Show initial progress
79
- update_progress(0, 0, total_to_process)
80
-
81
- # Create cursor for batch processing
82
- cursor = collection.find(query)
83
- if limit > 0:
84
- cursor = cursor.limit(limit)
85
-
86
- # Generate embeddings in parallel with cursor-based batching
87
- processed = parallel_generate_embeddings(
88
- collection=collection,
89
- cursor=cursor,
90
- field_name=field_name,
91
- embedding_field=embedding_field,
92
- openai_client=openai_client,
93
- total_docs=total_to_process,
94
- callback=update_progress
95
- )
96
-
97
- # Return completion message and final progress
98
- instructions = f"""
99
- Successfully generated embeddings for {processed} documents using parallel processing!
100
-
101
- To create the vector search index in MongoDB Atlas:
102
- 1. Go to your Atlas cluster
103
- 2. Click on 'Search' tab
104
- 3. Create an index named 'vector_index' with this configuration:
105
- {{
106
- "fields": [
107
- {{
108
- "type": "vector",
109
- "path": "{embedding_field}",
110
- "numDimensions": 1536,
111
- "similarity": "dotProduct"
112
- }}
113
- ]
114
- }}
115
-
116
- You can now use the search tab with:
117
- - Field to search: {field_name}
118
- - Embedding field: {embedding_field}
119
- """
120
- return instructions, progress_text
121
-
122
- except Exception as e:
123
- return f"Error: {str(e)}", ""
124
-
125
- def vector_search(query_text: str, db_name: str, collection_name: str, embedding_field: str, index_name: str) -> str:
126
- """Perform vector search using embeddings"""
127
- try:
128
- print(f"\nProcessing query: {query_text}")
129
-
130
- db = db_utils.client[db_name]
131
- collection = db[collection_name]
132
 
133
- # Get embeddings for query
134
- embedding = get_embedding(query_text, openai_client)
135
- print("Generated embeddings successfully")
136
-
137
- results = collection.aggregate([
138
- {
139
- '$vectorSearch': {
140
- "index": index_name,
141
- "path": embedding_field,
142
- "queryVector": embedding,
143
- "numCandidates": 50,
144
- "limit": 5
145
- }
146
- },
147
- {
148
- "$project": {
149
- "search_score": { "$meta": "vectorSearchScore" },
150
- "document": "$$ROOT"
151
- }
152
- }
153
- ])
154
-
155
- # Format results
156
- results_list = list(results)
157
- formatted_results = []
158
-
159
- for idx, result in enumerate(results_list, 1):
160
- doc = result['document']
161
- formatted_result = f"{idx}. Score: {result['search_score']:.4f}\n"
162
- # Add all fields except _id and embeddings
163
- for key, value in doc.items():
164
- if key not in ['_id', embedding_field]:
165
- formatted_result += f"{key}: {value}\n"
166
- formatted_results.append(formatted_result)
167
 
168
- return "\n".join(formatted_results) if formatted_results else "No results found"
169
-
170
- except Exception as e:
171
- return f"Error: {str(e)}"
172
-
173
- # Create Gradio interface with tabs
174
- with gr.Blocks(title="MongoDB Vector Search Tool") as iface:
175
- gr.Markdown("# MongoDB Vector Search Tool")
176
-
177
- # Check credentials first
178
- has_creds, cred_message = check_credentials()
179
- if not has_creds:
180
- gr.Markdown(f"""
181
- ## ⚠️ Setup Required
182
-
183
- {cred_message}
184
-
185
- After setting up credentials, refresh this page.
186
- """)
187
- else:
188
- # Initialize clients
189
- openai_client, db_utils = init_clients()
190
- if not openai_client or not db_utils:
191
- gr.Markdown("""
192
- ## ⚠️ Connection Error
193
 
194
- Failed to connect to MongoDB Atlas or OpenAI. Please check your credentials and try again.
195
  """)
196
  else:
197
- # Get available databases
198
- try:
199
- databases = db_utils.get_databases()
200
-
201
- with gr.Tab("Generate Embeddings"):
202
- with gr.Row():
203
- db_input = gr.Dropdown(
204
- choices=databases,
205
- label="Select Database",
206
- info="Available databases in Atlas cluster"
207
- )
208
- collection_input = gr.Dropdown(
209
- choices=[],
210
- label="Select Collection",
211
- info="Collections in selected database"
212
- )
213
- with gr.Row():
214
- field_input = gr.Dropdown(
215
- choices=[],
216
- label="Select Field for Embeddings",
217
- info="Fields available in collection"
218
- )
219
- embedding_field_input = gr.Textbox(
220
- label="Embedding Field Name",
221
- value="embedding",
222
- info="Field name where embeddings will be stored"
223
- )
224
- limit_input = gr.Number(
225
- label="Document Limit",
226
- value=10,
227
- minimum=0,
228
- info="Number of documents to process (0 for all documents)"
229
- )
230
-
231
- def update_collections(db_name):
232
- collections = db_utils.get_collections(db_name)
233
- # If there's only one collection, select it by default
234
- value = collections[0] if len(collections) == 1 else None
235
- return gr.Dropdown(choices=collections, value=value)
236
-
237
- def update_fields(db_name, collection_name):
238
- if db_name and collection_name:
239
- fields = get_field_names(db_name, collection_name)
240
- return gr.Dropdown(choices=fields)
241
- return gr.Dropdown(choices=[])
242
-
243
- # Update collections when database changes
244
- db_input.change(
245
- fn=update_collections,
246
- inputs=[db_input],
247
- outputs=[collection_input]
248
- )
249
-
250
- # Update fields when collection changes
251
- collection_input.change(
252
- fn=update_fields,
253
- inputs=[db_input, collection_input],
254
- outputs=[field_input]
255
- )
256
-
257
- generate_btn = gr.Button("Generate Embeddings")
258
- generate_output = gr.Textbox(label="Results", lines=10)
259
- progress_output = gr.Textbox(label="Progress", lines=3)
260
-
261
- generate_btn.click(
262
- generate_embeddings_for_field,
263
- inputs=[db_input, collection_input, field_input, embedding_field_input, limit_input],
264
- outputs=[generate_output, progress_output]
265
- )
266
 
267
- with gr.Tab("Search"):
268
- with gr.Row():
269
- search_db_input = gr.Dropdown(
270
- choices=databases,
271
- label="Select Database",
272
- info="Database containing the vectors"
273
- )
274
- search_collection_input = gr.Dropdown(
275
- choices=[],
276
- label="Select Collection",
277
- info="Collection containing the vectors"
278
- )
279
- with gr.Row():
280
- search_embedding_field_input = gr.Textbox(
281
- label="Embedding Field Name",
282
- value="embedding",
283
- info="Field containing the vectors"
284
- )
285
- search_index_input = gr.Textbox(
286
- label="Vector Search Index Name",
287
- value="vector_index",
288
- info="Index created in Atlas UI"
289
- )
290
-
291
- # Update collections when database changes
292
- search_db_input.change(
293
- fn=update_collections,
294
- inputs=[search_db_input],
295
- outputs=[search_collection_input]
296
- )
297
-
298
- query_input = gr.Textbox(
299
- label="Search Query",
300
- lines=2,
301
- placeholder="What would you like to search for?"
302
- )
303
- search_btn = gr.Button("Search")
304
- search_output = gr.Textbox(label="Results", lines=10)
305
-
306
- search_btn.click(
307
- vector_search,
308
- inputs=[
309
- query_input,
310
- search_db_input,
311
- search_collection_input,
312
- search_embedding_field_input,
313
- search_index_input
314
- ],
315
- outputs=search_output
316
- )
317
 
318
  if __name__ == "__main__":
319
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
1
  import gradio as gr
2
+ from utils.credentials import check_credentials, init_clients
3
+ from ui.embeddings_tab import create_embeddings_tab
4
+ from ui.search_tab import create_search_tab
 
 
5
 
6
+ def create_app():
7
+ """Create and configure the Gradio application"""
8
+ with gr.Blocks(title="MongoDB Vector Search Tool") as iface:
9
+ gr.Markdown("# MongoDB Vector Search Tool")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # Check credentials first
12
+ has_creds, cred_message = check_credentials()
13
+ if not has_creds:
14
+ gr.Markdown(f"""
15
+ ## ⚠️ Setup Required
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ {cred_message}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ After setting up credentials, refresh this page.
20
  """)
21
  else:
22
+ # Initialize clients
23
+ openai_client, db_utils = init_clients()
24
+ if not openai_client or not db_utils:
25
+ gr.Markdown("""
26
+ ## ⚠️ Connection Error
27
+
28
+ Failed to connect to MongoDB Atlas or OpenAI. Please check your credentials and try again.
29
+ """)
30
+ else:
31
+ # Get available databases
32
+ try:
33
+ databases = db_utils.get_databases()
34
+ except Exception as e:
35
+ gr.Markdown(f"""
36
+ ## ⚠️ Database Error
37
+
38
+ Failed to list databases: {str(e)}
39
+ Please check your MongoDB Atlas connection and try again.
40
+ """)
41
+ databases = []
42
+
43
+ # Create tabs
44
+ embeddings_tab, embeddings_interface = create_embeddings_tab(
45
+ openai_client=openai_client,
46
+ db_utils=db_utils,
47
+ databases=databases
48
+ )
49
+
50
+ search_tab, search_interface = create_search_tab(
51
+ openai_client=openai_client,
52
+ db_utils=db_utils,
53
+ databases=databases
54
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ return iface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  if __name__ == "__main__":
59
+ app = create_app()
60
+ app.launch(server_name="0.0.0.0")
ui/__pycache__/embeddings_tab.cpython-312.pyc ADDED
Binary file (6.98 kB). View file
 
ui/__pycache__/search_tab.cpython-312.pyc ADDED
Binary file (5.06 kB). View file
 
ui/embeddings_tab.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from typing import Tuple, Optional, List
3
+ from openai import OpenAI
4
+ from utils.db_utils import DatabaseUtils
5
+ from utils.embedding_utils import parallel_generate_embeddings
6
+
7
+ def create_embeddings_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
8
+ """Create the embeddings generation tab UI
9
+
10
+ Args:
11
+ openai_client: OpenAI client instance
12
+ db_utils: DatabaseUtils instance
13
+ databases: List of available databases
14
+
15
+ Returns:
16
+ Tuple[gr.Tab, dict]: The tab component and its interface elements
17
+ """
18
+ def update_collections(db_name: str) -> gr.Dropdown:
19
+ """Update collections dropdown when database changes"""
20
+ collections = db_utils.get_collections(db_name)
21
+ # If there's only one collection, select it by default
22
+ value = collections[0] if len(collections) == 1 else None
23
+ return gr.Dropdown(choices=collections, value=value)
24
+
25
+ def update_fields(db_name: str, collection_name: str) -> gr.Dropdown:
26
+ """Update fields dropdown when collection changes"""
27
+ if db_name and collection_name:
28
+ fields = db_utils.get_field_names(db_name, collection_name)
29
+ return gr.Dropdown(choices=fields)
30
+ return gr.Dropdown(choices=[])
31
+
32
+ def generate_embeddings(
33
+ db_name: str,
34
+ collection_name: str,
35
+ field_name: str,
36
+ embedding_field: str,
37
+ limit: int = 10,
38
+ progress=gr.Progress()
39
+ ) -> Tuple[str, str]:
40
+ """Generate embeddings for documents with progress tracking"""
41
+ try:
42
+ db = db_utils.client[db_name]
43
+ collection = db[collection_name]
44
+
45
+ # Count documents that need embeddings
46
+ total_docs = collection.count_documents({field_name: {"$exists": True}})
47
+ if total_docs == 0:
48
+ return f"No documents found with field '{field_name}'", ""
49
+
50
+ # Get total count of documents that need processing
51
+ query = {
52
+ field_name: {"$exists": True},
53
+ embedding_field: {"$exists": False} # Only get docs without embeddings
54
+ }
55
+ total_to_process = collection.count_documents(query)
56
+ if total_to_process == 0:
57
+ return "No documents found that need embeddings", ""
58
+
59
+ # Apply limit if specified
60
+ if limit > 0:
61
+ total_to_process = min(total_to_process, limit)
62
+
63
+ print(f"\nFound {total_to_process} documents that need embeddings...")
64
+
65
+ # Progress tracking
66
+ progress_text = ""
67
+ def update_progress(prog: float, processed: int, total: int):
68
+ nonlocal progress_text
69
+ progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
70
+ print(progress_text) # Terminal logging
71
+ progress(prog/100, f"Processed {processed}/{total} documents")
72
+
73
+ # Show initial progress
74
+ update_progress(0, 0, total_to_process)
75
+
76
+ # Create cursor for batch processing
77
+ cursor = collection.find(query)
78
+ if limit > 0:
79
+ cursor = cursor.limit(limit)
80
+
81
+ # Generate embeddings in parallel with cursor-based batching
82
+ processed = parallel_generate_embeddings(
83
+ collection=collection,
84
+ cursor=cursor,
85
+ field_name=field_name,
86
+ embedding_field=embedding_field,
87
+ openai_client=openai_client,
88
+ total_docs=total_to_process,
89
+ callback=update_progress
90
+ )
91
+
92
+ # Return completion message and final progress
93
+ instructions = f"""
94
+ Successfully generated embeddings for {processed} documents using parallel processing!
95
+
96
+ To create the vector search index in MongoDB Atlas:
97
+ 1. Go to your Atlas cluster
98
+ 2. Click on 'Search' tab
99
+ 3. Create an index named 'vector_index' with this configuration:
100
+ {{
101
+ "fields": [
102
+ {{
103
+ "type": "vector",
104
+ "path": "{embedding_field}",
105
+ "numDimensions": 1536,
106
+ "similarity": "dotProduct"
107
+ }}
108
+ ]
109
+ }}
110
+
111
+ You can now use the search tab with:
112
+ - Field to search: {field_name}
113
+ - Embedding field: {embedding_field}
114
+ """
115
+ return instructions, progress_text
116
+
117
+ except Exception as e:
118
+ return f"Error: {str(e)}", ""
119
+
120
+ # Create the tab UI
121
+ with gr.Tab("Generate Embeddings") as tab:
122
+ with gr.Row():
123
+ db_input = gr.Dropdown(
124
+ choices=databases,
125
+ label="Select Database",
126
+ info="Available databases in Atlas cluster"
127
+ )
128
+ collection_input = gr.Dropdown(
129
+ choices=[],
130
+ label="Select Collection",
131
+ info="Collections in selected database"
132
+ )
133
+ with gr.Row():
134
+ field_input = gr.Dropdown(
135
+ choices=[],
136
+ label="Select Field for Embeddings",
137
+ info="Fields available in collection"
138
+ )
139
+ embedding_field_input = gr.Textbox(
140
+ label="Embedding Field Name",
141
+ value="embedding",
142
+ info="Field name where embeddings will be stored"
143
+ )
144
+ limit_input = gr.Number(
145
+ label="Document Limit",
146
+ value=10,
147
+ minimum=0,
148
+ info="Number of documents to process (0 for all documents)"
149
+ )
150
+
151
+ generate_btn = gr.Button("Generate Embeddings")
152
+ generate_output = gr.Textbox(label="Results", lines=10)
153
+ progress_output = gr.Textbox(label="Progress", lines=3)
154
+
155
+ # Set up event handlers
156
+ db_input.change(
157
+ fn=update_collections,
158
+ inputs=[db_input],
159
+ outputs=[collection_input]
160
+ )
161
+
162
+ collection_input.change(
163
+ fn=update_fields,
164
+ inputs=[db_input, collection_input],
165
+ outputs=[field_input]
166
+ )
167
+
168
+ generate_btn.click(
169
+ fn=generate_embeddings,
170
+ inputs=[
171
+ db_input,
172
+ collection_input,
173
+ field_input,
174
+ embedding_field_input,
175
+ limit_input
176
+ ],
177
+ outputs=[generate_output, progress_output]
178
+ )
179
+
180
+ # Return the tab and its interface elements
181
+ interface = {
182
+ 'db_input': db_input,
183
+ 'collection_input': collection_input,
184
+ 'field_input': field_input,
185
+ 'embedding_field_input': embedding_field_input,
186
+ 'limit_input': limit_input,
187
+ 'generate_btn': generate_btn,
188
+ 'generate_output': generate_output,
189
+ 'progress_output': progress_output
190
+ }
191
+
192
+ return tab, interface
ui/search_tab.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from typing import Tuple, List
3
+ from openai import OpenAI
4
+ from utils.db_utils import DatabaseUtils
5
+ from utils.embedding_utils import get_embedding
6
+
7
+ def create_search_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
8
+ """Create the vector search tab UI
9
+
10
+ Args:
11
+ openai_client: OpenAI client instance
12
+ db_utils: DatabaseUtils instance
13
+ databases: List of available databases
14
+
15
+ Returns:
16
+ Tuple[gr.Tab, dict]: The tab component and its interface elements
17
+ """
18
+ def update_collections(db_name: str) -> gr.Dropdown:
19
+ """Update collections dropdown when database changes"""
20
+ collections = db_utils.get_collections(db_name)
21
+ # If there's only one collection, select it by default
22
+ value = collections[0] if len(collections) == 1 else None
23
+ return gr.Dropdown(choices=collections, value=value)
24
+
25
+ def vector_search(
26
+ query_text: str,
27
+ db_name: str,
28
+ collection_name: str,
29
+ embedding_field: str,
30
+ index_name: str
31
+ ) -> str:
32
+ """Perform vector search using embeddings"""
33
+ try:
34
+ print(f"\nProcessing query: {query_text}")
35
+
36
+ db = db_utils.client[db_name]
37
+ collection = db[collection_name]
38
+
39
+ # Get embeddings for query
40
+ embedding = get_embedding(query_text, openai_client)
41
+ print("Generated embeddings successfully")
42
+
43
+ results = collection.aggregate([
44
+ {
45
+ '$vectorSearch': {
46
+ "index": index_name,
47
+ "path": embedding_field,
48
+ "queryVector": embedding,
49
+ "numCandidates": 50,
50
+ "limit": 5
51
+ }
52
+ },
53
+ {
54
+ "$project": {
55
+ "search_score": { "$meta": "vectorSearchScore" },
56
+ "document": "$$ROOT"
57
+ }
58
+ }
59
+ ])
60
+
61
+ # Format results
62
+ results_list = list(results)
63
+ formatted_results = []
64
+
65
+ for idx, result in enumerate(results_list, 1):
66
+ doc = result['document']
67
+ formatted_result = f"{idx}. Score: {result['search_score']:.4f}\n"
68
+ # Add all fields except _id and embeddings
69
+ for key, value in doc.items():
70
+ if key not in ['_id', embedding_field]:
71
+ formatted_result += f"{key}: {value}\n"
72
+ formatted_results.append(formatted_result)
73
+
74
+ return "\n".join(formatted_results) if formatted_results else "No results found"
75
+
76
+ except Exception as e:
77
+ return f"Error: {str(e)}"
78
+
79
+ # Create the tab UI
80
+ with gr.Tab("Search") as tab:
81
+ with gr.Row():
82
+ db_input = gr.Dropdown(
83
+ choices=databases,
84
+ label="Select Database",
85
+ info="Database containing the vectors"
86
+ )
87
+ collection_input = gr.Dropdown(
88
+ choices=[],
89
+ label="Select Collection",
90
+ info="Collection containing the vectors"
91
+ )
92
+ with gr.Row():
93
+ embedding_field_input = gr.Textbox(
94
+ label="Embedding Field Name",
95
+ value="embedding",
96
+ info="Field containing the vectors"
97
+ )
98
+ index_input = gr.Textbox(
99
+ label="Vector Search Index Name",
100
+ value="vector_index",
101
+ info="Index created in Atlas UI"
102
+ )
103
+
104
+ query_input = gr.Textbox(
105
+ label="Search Query",
106
+ lines=2,
107
+ placeholder="What would you like to search for?"
108
+ )
109
+ search_btn = gr.Button("Search")
110
+ search_output = gr.Textbox(label="Results", lines=10)
111
+
112
+ # Set up event handlers
113
+ db_input.change(
114
+ fn=update_collections,
115
+ inputs=[db_input],
116
+ outputs=[collection_input]
117
+ )
118
+
119
+ search_btn.click(
120
+ fn=vector_search,
121
+ inputs=[
122
+ query_input,
123
+ db_input,
124
+ collection_input,
125
+ embedding_field_input,
126
+ index_input
127
+ ],
128
+ outputs=search_output
129
+ )
130
+
131
+ # Return the tab and its interface elements
132
+ interface = {
133
+ 'db_input': db_input,
134
+ 'collection_input': collection_input,
135
+ 'embedding_field_input': embedding_field_input,
136
+ 'index_input': index_input,
137
+ 'query_input': query_input,
138
+ 'search_btn': search_btn,
139
+ 'search_output': search_output
140
+ }
141
+
142
+ return tab, interface
utils/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Utils package for MongoDB Vector Search Tool
2
+ from utils.credentials import check_credentials, init_clients
3
+ from utils.db_utils import DatabaseUtils
4
+ from utils.embedding_utils import get_embedding, parallel_generate_embeddings
5
+
6
+ __all__ = [
7
+ 'check_credentials',
8
+ 'init_clients',
9
+ 'DatabaseUtils',
10
+ 'get_embedding',
11
+ 'parallel_generate_embeddings'
12
+ ]
utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (427 Bytes). View file
 
utils/__pycache__/credentials.cpython-312.pyc ADDED
Binary file (1.79 kB). View file
 
utils/__pycache__/db_utils.cpython-312.pyc ADDED
Binary file (7.62 kB). View file
 
utils/__pycache__/embedding_utils.cpython-312.pyc ADDED
Binary file (7.2 kB). View file
 
utils/credentials.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Tuple
3
+ from dotenv import load_dotenv
4
+ from openai import OpenAI
5
+ from utils.db_utils import DatabaseUtils
6
+
7
+ def check_credentials() -> Tuple[bool, str]:
8
+ """Check if required credentials are set and valid
9
+
10
+ Returns:
11
+ Tuple[bool, str]: (is_valid, message)
12
+ - is_valid: True if all credentials are valid
13
+ - message: Error message if credentials are invalid
14
+ """
15
+ # Load environment variables
16
+ load_dotenv()
17
+
18
+ atlas_uri = os.getenv("ATLAS_URI")
19
+ openai_key = os.getenv("OPENAI_API_KEY")
20
+
21
+ if not atlas_uri:
22
+ return False, """Please set up your MongoDB Atlas credentials:
23
+ 1. Go to Settings tab
24
+ 2. Add ATLAS_URI as a Repository Secret
25
+ 3. Paste your MongoDB connection string (should start with 'mongodb+srv://')"""
26
+
27
+ if not openai_key:
28
+ return False, """Please set up your OpenAI API key:
29
+ 1. Go to Settings tab
30
+ 2. Add OPENAI_API_KEY as a Repository Secret
31
+ 3. Paste your OpenAI API key"""
32
+
33
+ return True, ""
34
+
35
+ def init_clients():
36
+ """Initialize OpenAI and MongoDB clients
37
+
38
+ Returns:
39
+ Tuple[OpenAI, DatabaseUtils]: OpenAI client and DatabaseUtils instance
40
+ or (None, None) if initialization fails
41
+ """
42
+ try:
43
+ openai_client = OpenAI()
44
+ db_utils = DatabaseUtils()
45
+ return openai_client, db_utils
46
+ except Exception as e:
47
+ return None, None
utils/db_utils.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Dict, Any, Optional
3
+ from pymongo import MongoClient
4
+ from pymongo.errors import (
5
+ ConnectionFailure,
6
+ OperationFailure,
7
+ ServerSelectionTimeoutError,
8
+ InvalidName
9
+ )
10
+ from dotenv import load_dotenv
11
+
12
+ class DatabaseError(Exception):
13
+ """Base class for database operation errors"""
14
+ pass
15
+
16
+ class ConnectionError(DatabaseError):
17
+ """Error when connecting to MongoDB Atlas"""
18
+ pass
19
+
20
+ class OperationError(DatabaseError):
21
+ """Error during database operations"""
22
+ pass
23
+
24
+ class DatabaseUtils:
25
+ """Utility class for MongoDB Atlas database operations
26
+
27
+ This class provides methods to interact with MongoDB Atlas databases and collections,
28
+ including listing databases, collections, and retrieving collection information.
29
+
30
+ Attributes:
31
+ atlas_uri (str): MongoDB Atlas connection string
32
+ client (MongoClient): MongoDB client instance
33
+ """
34
+
35
+ def __init__(self):
36
+ """Initialize DatabaseUtils with MongoDB Atlas connection
37
+
38
+ Raises:
39
+ ConnectionError: If unable to connect to MongoDB Atlas
40
+ ValueError: If ATLAS_URI environment variable is not set
41
+ """
42
+ # Load environment variables
43
+ load_dotenv()
44
+
45
+ self.atlas_uri = os.getenv("ATLAS_URI")
46
+ if not self.atlas_uri:
47
+ raise ValueError("ATLAS_URI environment variable is not set")
48
+
49
+ try:
50
+ self.client = MongoClient(self.atlas_uri)
51
+ # Test connection
52
+ self.client.admin.command('ping')
53
+ except (ConnectionFailure, ServerSelectionTimeoutError) as e:
54
+ raise ConnectionError(f"Failed to connect to MongoDB Atlas: {str(e)}")
55
+
56
+ def get_databases(self) -> List[str]:
57
+ """Get list of all databases in Atlas cluster
58
+
59
+ Returns:
60
+ List[str]: List of database names
61
+
62
+ Raises:
63
+ OperationError: If unable to list databases
64
+ """
65
+ try:
66
+ return self.client.list_database_names()
67
+ except OperationFailure as e:
68
+ raise OperationError(f"Failed to list databases: {str(e)}")
69
+
70
+ def get_collections(self, db_name: str) -> List[str]:
71
+ """Get list of collections in a database
72
+
73
+ Args:
74
+ db_name (str): Name of the database
75
+
76
+ Returns:
77
+ List[str]: List of collection names
78
+
79
+ Raises:
80
+ OperationError: If unable to list collections
81
+ ValueError: If db_name is empty or invalid
82
+ """
83
+ if not db_name or not isinstance(db_name, str):
84
+ raise ValueError("Database name must be a non-empty string")
85
+
86
+ try:
87
+ db = self.client[db_name]
88
+ return db.list_collection_names()
89
+ except (OperationFailure, InvalidName) as e:
90
+ raise OperationError(f"Failed to list collections for database '{db_name}': {str(e)}")
91
+
92
+ def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
93
+ """Get information about a collection including document count and sample document
94
+
95
+ Args:
96
+ db_name (str): Name of the database
97
+ collection_name (str): Name of the collection
98
+
99
+ Returns:
100
+ Dict[str, Any]: Dictionary containing collection information:
101
+ - count: Number of documents in collection
102
+ - sample: Sample document from collection (if exists)
103
+
104
+ Raises:
105
+ OperationError: If unable to get collection information
106
+ ValueError: If db_name or collection_name is empty or invalid
107
+ """
108
+ if not db_name or not isinstance(db_name, str):
109
+ raise ValueError("Database name must be a non-empty string")
110
+ if not collection_name or not isinstance(collection_name, str):
111
+ raise ValueError("Collection name must be a non-empty string")
112
+
113
+ try:
114
+ db = self.client[db_name]
115
+ collection = db[collection_name]
116
+
117
+ return {
118
+ 'count': collection.count_documents({}),
119
+ 'sample': collection.find_one()
120
+ }
121
+ except (OperationFailure, InvalidName) as e:
122
+ raise OperationError(
123
+ f"Failed to get info for collection '{collection_name}' "
124
+ f"in database '{db_name}': {str(e)}"
125
+ )
126
+
127
+ def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
128
+ """Get list of fields in a collection based on sample document
129
+
130
+ Args:
131
+ db_name (str): Name of the database
132
+ collection_name (str): Name of the collection
133
+
134
+ Returns:
135
+ List[str]: List of field names (excluding _id and embedding fields)
136
+
137
+ Raises:
138
+ OperationError: If unable to get field names
139
+ ValueError: If db_name or collection_name is empty or invalid
140
+ """
141
+ try:
142
+ info = self.get_collection_info(db_name, collection_name)
143
+ sample = info.get('sample', {})
144
+
145
+ if sample:
146
+ # Get all field names except _id and any existing embedding fields
147
+ return [field for field in sample.keys()
148
+ if field != '_id' and not field.endswith('_embedding')]
149
+ return []
150
+ except DatabaseError as e:
151
+ raise OperationError(
152
+ f"Failed to get field names for collection '{collection_name}' "
153
+ f"in database '{db_name}': {str(e)}"
154
+ )
155
+
156
+ def close(self):
157
+ """Close MongoDB connection safely"""
158
+ if hasattr(self, 'client'):
159
+ self.client.close()
utils/embedding_utils.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+ from concurrent.futures import ThreadPoolExecutor, as_completed
3
+ from pymongo import UpdateOne
4
+ from pymongo.collection import Collection
5
+ import math
6
+ import time
7
+ import logging
8
+
9
+ # Configure logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
def get_embedding(text: str, openai_client, model="text-embedding-ada-002", max_retries=3) -> list[float]:
    """Request an embedding vector for *text* from the OpenAI API.

    Newlines are collapsed to spaces before the request. On failure the call
    is retried up to ``max_retries`` times with exponential backoff
    (1s, 2s, 4s, ...); the final failure is re-raised to the caller.
    """
    cleaned = text.replace("\n", " ")

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = openai_client.embeddings.create(
                input=[cleaned],
                model=model
            )
            return response.data[0].embedding
        except Exception as e:
            if is_last_attempt:
                # Out of retries: surface the original error unchanged.
                raise
            details = f"{type(e).__name__}: {str(e)}"
            if hasattr(e, 'response'):
                body = e.response.text if hasattr(e.response, 'text') else 'No response text'
                details += f"\nResponse: {body}"
            logger.warning(f"Embedding API error (attempt {attempt + 1}/{max_retries}):\n{details}")
            time.sleep(2 ** attempt)  # Exponential backoff
32
+
33
def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Generate embeddings for one batch of documents.

    Documents that already carry ``embedding_field`` are skipped, as are
    documents whose ``field_name`` value is not a string. Returns a list of
    ``(_id, embedding)`` pairs for the documents actually embedded.
    """
    logger.info(f"Processing batch of {len(docs)} documents")
    pairs = []
    for document in docs:
        if embedding_field in document:
            # Embedding already present -- nothing to do for this document.
            continue
        value = document[field_name]
        if not isinstance(value, str):
            # Only string values can be embedded; silently skip anything else.
            continue
        pairs.append((document["_id"], get_embedding(value, openai_client)))
    return pairs
47
+
48
def process_futures(futures: List, collection: Collection, embedding_field: str, processed: int, total_docs: int, callback=None) -> int:
    """Drain completed embedding futures and persist their results.

    Each finished future yields ``(_id, embedding)`` pairs that are applied
    with one bulk write; ``callback`` (if given) is then invoked with the
    overall progress. Per-future errors are logged and skipped. Returns the
    number of documents updated.
    """
    written = 0
    for future in as_completed(futures, timeout=30):  # 30 second timeout
        try:
            pairs = future.result()
            if pairs:
                operations = [
                    UpdateOne({"_id": doc_id}, {"$set": {embedding_field: vector}})
                    for doc_id, vector in pairs
                ]
                if operations:
                    collection.bulk_write(operations)
                    written += len(operations)

                # Report cumulative progress to the caller, if requested.
                if callback:
                    pct = ((processed + written) / total_docs) * 100
                    callback(pct, processed + written, total_docs)
        except Exception as e:
            details = f"{type(e).__name__}: {str(e)}"
            if hasattr(e, 'response'):
                details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
            logger.error(f"Error processing future:\n{details}")
    return written
73
+
74
def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 10,  # Reduced initial batch size
    callback=None
) -> int:
    """Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching and dynamic processing.

    Streams documents from `cursor`, groups them into batches, and submits
    each batch to `process_batch` on a thread pool.  Completed futures are
    drained via `process_futures`, which bulk-writes the embeddings into
    `collection` and reports progress through `callback(percent, done, total)`.
    The batch size and the drain threshold adapt: they grow after successful
    drains and shrink after errors.

    Returns the number of documents actually updated (may be less than
    `total_docs` if documents were skipped, e.g. already embedded or
    non-string fields -- see process_batch).
    """
    if total_docs == 0:
        return 0

    processed = 0
    current_batch_size = batch_size
    max_workers = 10  # Start with fewer workers

    logger.info(f"Starting embedding generation for {total_docs} documents")
    if callback:
        callback(0, 0, total_docs)

    # NOTE(review): the executor pool is fixed at 10 threads here; increasing
    # the `max_workers` variable below does NOT grow the pool -- it only
    # raises the threshold at which pending futures are drained.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        batch = []
        futures = []

        for doc in cursor:
            batch.append(doc)

            if len(batch) >= current_batch_size:
                logger.info(f"Submitting batch of {len(batch)} documents (batch size: {current_batch_size})")
                # batch.copy() so clearing `batch` below cannot race the worker.
                future = executor.submit(process_batch, batch.copy(), field_name, embedding_field, openai_client)
                futures.append(future)
                batch = []

            # Process completed futures more frequently
            if len(futures) >= max_workers:
                try:
                    completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
                    processed += completed
                    futures = []  # Clear processed futures

                    # Gradually increase batch size and workers if processing is successful
                    if completed > 0:
                        current_batch_size = min(current_batch_size + 5, 30)
                        max_workers = min(max_workers + 2, 20)
                        logger.info(f"Increased batch size to {current_batch_size}, workers to {max_workers}")
                except Exception as e:
                    logger.error(f"Error processing futures: {str(e)}")
                    # Reduce batch size and workers on error
                    current_batch_size = max(5, current_batch_size - 5)
                    max_workers = max(5, max_workers - 2)
                    logger.info(f"Reduced batch size to {current_batch_size}, workers to {max_workers}")

        # Process remaining batch (tail smaller than current_batch_size)
        if batch:
            logger.info(f"Processing final batch of {len(batch)} documents")
            future = executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
            futures.append(future)

        # Process remaining futures
        if futures:
            try:
                completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
                processed += completed
            except Exception as e:
                # Best-effort: errors in the final drain are logged, not raised.
                logger.error(f"Error processing final futures: {str(e)}")

    logger.info(f"Completed embedding generation. Processed {processed}/{total_docs} documents")
    return processed