airabbitX commited on
Commit
50e3a95
·
verified ·
1 Parent(s): 124432f

Upload 3 files

Browse files
Files changed (3) hide show
  1. ui/__init__.py +8 -0
  2. ui/embeddings_tab.py +192 -0
  3. ui/search_tab.py +142 -0
ui/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # UI package for MongoDB Vector Search Tool
2
+ from ui.embeddings_tab import create_embeddings_tab
3
+ from ui.search_tab import create_search_tab
4
+
5
+ __all__ = [
6
+ 'create_embeddings_tab',
7
+ 'create_search_tab'
8
+ ]
ui/embeddings_tab.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from typing import Tuple, Optional, List
3
+ from openai import OpenAI
4
+ from utils.db_utils import DatabaseUtils
5
+ from utils.embedding_utils import parallel_generate_embeddings
6
+
7
+ def create_embeddings_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
8
+ """Create the embeddings generation tab UI
9
+
10
+ Args:
11
+ openai_client: OpenAI client instance
12
+ db_utils: DatabaseUtils instance
13
+ databases: List of available databases
14
+
15
+ Returns:
16
+ Tuple[gr.Tab, dict]: The tab component and its interface elements
17
+ """
18
+ def update_collections(db_name: str) -> gr.Dropdown:
19
+ """Update collections dropdown when database changes"""
20
+ collections = db_utils.get_collections(db_name)
21
+ # If there's only one collection, select it by default
22
+ value = collections[0] if len(collections) == 1 else None
23
+ return gr.Dropdown(choices=collections, value=value)
24
+
25
+ def update_fields(db_name: str, collection_name: str) -> gr.Dropdown:
26
+ """Update fields dropdown when collection changes"""
27
+ if db_name and collection_name:
28
+ fields = db_utils.get_field_names(db_name, collection_name)
29
+ return gr.Dropdown(choices=fields)
30
+ return gr.Dropdown(choices=[])
31
+
32
+ def generate_embeddings(
33
+ db_name: str,
34
+ collection_name: str,
35
+ field_name: str,
36
+ embedding_field: str,
37
+ limit: int = 10,
38
+ progress=gr.Progress()
39
+ ) -> Tuple[str, str]:
40
+ """Generate embeddings for documents with progress tracking"""
41
+ try:
42
+ db = db_utils.client[db_name]
43
+ collection = db[collection_name]
44
+
45
+ # Count documents that need embeddings
46
+ total_docs = collection.count_documents({field_name: {"$exists": True}})
47
+ if total_docs == 0:
48
+ return f"No documents found with field '{field_name}'", ""
49
+
50
+ # Get total count of documents that need processing
51
+ query = {
52
+ field_name: {"$exists": True},
53
+ embedding_field: {"$exists": False} # Only get docs without embeddings
54
+ }
55
+ total_to_process = collection.count_documents(query)
56
+ if total_to_process == 0:
57
+ return "No documents found that need embeddings", ""
58
+
59
+ # Apply limit if specified
60
+ if limit > 0:
61
+ total_to_process = min(total_to_process, limit)
62
+
63
+ print(f"\nFound {total_to_process} documents that need embeddings...")
64
+
65
+ # Progress tracking
66
+ progress_text = ""
67
+ def update_progress(prog: float, processed: int, total: int):
68
+ nonlocal progress_text
69
+ progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
70
+ print(progress_text) # Terminal logging
71
+ progress(prog/100, f"Processed {processed}/{total} documents")
72
+
73
+ # Show initial progress
74
+ update_progress(0, 0, total_to_process)
75
+
76
+ # Create cursor for batch processing
77
+ cursor = collection.find(query)
78
+ if limit > 0:
79
+ cursor = cursor.limit(limit)
80
+
81
+ # Generate embeddings in parallel with cursor-based batching
82
+ processed = parallel_generate_embeddings(
83
+ collection=collection,
84
+ cursor=cursor,
85
+ field_name=field_name,
86
+ embedding_field=embedding_field,
87
+ openai_client=openai_client,
88
+ total_docs=total_to_process,
89
+ callback=update_progress
90
+ )
91
+
92
+ # Return completion message and final progress
93
+ instructions = f"""
94
+ Successfully generated embeddings for {processed} documents using parallel processing!
95
+
96
+ To create the vector search index in MongoDB Atlas:
97
+ 1. Go to your Atlas cluster
98
+ 2. Click on 'Search' tab
99
+ 3. Create an index named 'vector_index' with this configuration:
100
+ {{
101
+ "fields": [
102
+ {{
103
+ "type": "vector",
104
+ "path": "{embedding_field}",
105
+ "numDimensions": 1536,
106
+ "similarity": "dotProduct"
107
+ }}
108
+ ]
109
+ }}
110
+
111
+ You can now use the search tab with:
112
+ - Field to search: {field_name}
113
+ - Embedding field: {embedding_field}
114
+ """
115
+ return instructions, progress_text
116
+
117
+ except Exception as e:
118
+ return f"Error: {str(e)}", ""
119
+
120
+ # Create the tab UI
121
+ with gr.Tab("Generate Embeddings") as tab:
122
+ with gr.Row():
123
+ db_input = gr.Dropdown(
124
+ choices=databases,
125
+ label="Select Database",
126
+ info="Available databases in Atlas cluster"
127
+ )
128
+ collection_input = gr.Dropdown(
129
+ choices=[],
130
+ label="Select Collection",
131
+ info="Collections in selected database"
132
+ )
133
+ with gr.Row():
134
+ field_input = gr.Dropdown(
135
+ choices=[],
136
+ label="Select Field for Embeddings",
137
+ info="Fields available in collection"
138
+ )
139
+ embedding_field_input = gr.Textbox(
140
+ label="Embedding Field Name",
141
+ value="embedding",
142
+ info="Field name where embeddings will be stored"
143
+ )
144
+ limit_input = gr.Number(
145
+ label="Document Limit",
146
+ value=10,
147
+ minimum=0,
148
+ info="Number of documents to process (0 for all documents)"
149
+ )
150
+
151
+ generate_btn = gr.Button("Generate Embeddings")
152
+ generate_output = gr.Textbox(label="Results", lines=10)
153
+ progress_output = gr.Textbox(label="Progress", lines=3)
154
+
155
+ # Set up event handlers
156
+ db_input.change(
157
+ fn=update_collections,
158
+ inputs=[db_input],
159
+ outputs=[collection_input]
160
+ )
161
+
162
+ collection_input.change(
163
+ fn=update_fields,
164
+ inputs=[db_input, collection_input],
165
+ outputs=[field_input]
166
+ )
167
+
168
+ generate_btn.click(
169
+ fn=generate_embeddings,
170
+ inputs=[
171
+ db_input,
172
+ collection_input,
173
+ field_input,
174
+ embedding_field_input,
175
+ limit_input
176
+ ],
177
+ outputs=[generate_output, progress_output]
178
+ )
179
+
180
+ # Return the tab and its interface elements
181
+ interface = {
182
+ 'db_input': db_input,
183
+ 'collection_input': collection_input,
184
+ 'field_input': field_input,
185
+ 'embedding_field_input': embedding_field_input,
186
+ 'limit_input': limit_input,
187
+ 'generate_btn': generate_btn,
188
+ 'generate_output': generate_output,
189
+ 'progress_output': progress_output
190
+ }
191
+
192
+ return tab, interface
ui/search_tab.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from typing import Tuple, List
3
+ from openai import OpenAI
4
+ from utils.db_utils import DatabaseUtils
5
+ from utils.embedding_utils import get_embedding
6
+
7
+ def create_search_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
8
+ """Create the vector search tab UI
9
+
10
+ Args:
11
+ openai_client: OpenAI client instance
12
+ db_utils: DatabaseUtils instance
13
+ databases: List of available databases
14
+
15
+ Returns:
16
+ Tuple[gr.Tab, dict]: The tab component and its interface elements
17
+ """
18
+ def update_collections(db_name: str) -> gr.Dropdown:
19
+ """Update collections dropdown when database changes"""
20
+ collections = db_utils.get_collections(db_name)
21
+ # If there's only one collection, select it by default
22
+ value = collections[0] if len(collections) == 1 else None
23
+ return gr.Dropdown(choices=collections, value=value)
24
+
25
+ def vector_search(
26
+ query_text: str,
27
+ db_name: str,
28
+ collection_name: str,
29
+ embedding_field: str,
30
+ index_name: str
31
+ ) -> str:
32
+ """Perform vector search using embeddings"""
33
+ try:
34
+ print(f"\nProcessing query: {query_text}")
35
+
36
+ db = db_utils.client[db_name]
37
+ collection = db[collection_name]
38
+
39
+ # Get embeddings for query
40
+ embedding = get_embedding(query_text, openai_client)
41
+ print("Generated embeddings successfully")
42
+
43
+ results = collection.aggregate([
44
+ {
45
+ '$vectorSearch': {
46
+ "index": index_name,
47
+ "path": embedding_field,
48
+ "queryVector": embedding,
49
+ "numCandidates": 50,
50
+ "limit": 5
51
+ }
52
+ },
53
+ {
54
+ "$project": {
55
+ "search_score": { "$meta": "vectorSearchScore" },
56
+ "document": "$$ROOT"
57
+ }
58
+ }
59
+ ])
60
+
61
+ # Format results
62
+ results_list = list(results)
63
+ formatted_results = []
64
+
65
+ for idx, result in enumerate(results_list, 1):
66
+ doc = result['document']
67
+ formatted_result = f"{idx}. Score: {result['search_score']:.4f}\n"
68
+ # Add all fields except _id and embeddings
69
+ for key, value in doc.items():
70
+ if key not in ['_id', embedding_field]:
71
+ formatted_result += f"{key}: {value}\n"
72
+ formatted_results.append(formatted_result)
73
+
74
+ return "\n".join(formatted_results) if formatted_results else "No results found"
75
+
76
+ except Exception as e:
77
+ return f"Error: {str(e)}"
78
+
79
+ # Create the tab UI
80
+ with gr.Tab("Search") as tab:
81
+ with gr.Row():
82
+ db_input = gr.Dropdown(
83
+ choices=databases,
84
+ label="Select Database",
85
+ info="Database containing the vectors"
86
+ )
87
+ collection_input = gr.Dropdown(
88
+ choices=[],
89
+ label="Select Collection",
90
+ info="Collection containing the vectors"
91
+ )
92
+ with gr.Row():
93
+ embedding_field_input = gr.Textbox(
94
+ label="Embedding Field Name",
95
+ value="embedding",
96
+ info="Field containing the vectors"
97
+ )
98
+ index_input = gr.Textbox(
99
+ label="Vector Search Index Name",
100
+ value="vector_index",
101
+ info="Index created in Atlas UI"
102
+ )
103
+
104
+ query_input = gr.Textbox(
105
+ label="Search Query",
106
+ lines=2,
107
+ placeholder="What would you like to search for?"
108
+ )
109
+ search_btn = gr.Button("Search")
110
+ search_output = gr.Textbox(label="Results", lines=10)
111
+
112
+ # Set up event handlers
113
+ db_input.change(
114
+ fn=update_collections,
115
+ inputs=[db_input],
116
+ outputs=[collection_input]
117
+ )
118
+
119
+ search_btn.click(
120
+ fn=vector_search,
121
+ inputs=[
122
+ query_input,
123
+ db_input,
124
+ collection_input,
125
+ embedding_field_input,
126
+ index_input
127
+ ],
128
+ outputs=search_output
129
+ )
130
+
131
+ # Return the tab and its interface elements
132
+ interface = {
133
+ 'db_input': db_input,
134
+ 'collection_input': collection_input,
135
+ 'embedding_field_input': embedding_field_input,
136
+ 'index_input': index_input,
137
+ 'query_input': query_input,
138
+ 'search_btn': search_btn,
139
+ 'search_output': search_output
140
+ }
141
+
142
+ return tab, interface