# mongo-vector-search-util/ui/embeddings_tab.py
# airabbitX's picture
# Upload 3 files
# 50e3a95 verified
import gradio as gr
from typing import Tuple, Optional, List
from openai import OpenAI
from utils.db_utils import DatabaseUtils
from utils.embedding_utils import parallel_generate_embeddings
def create_embeddings_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
    """Create the embeddings generation tab UI.

    Builds a "Generate Embeddings" tab with database / collection / field
    selectors, an embedding-field name box, a document limit and a button
    that runs ``parallel_generate_embeddings`` over the matching documents.

    Args:
        openai_client: OpenAI client instance, passed through to
            ``parallel_generate_embeddings``.
        db_utils: DatabaseUtils instance used to list collections and field
            names and to reach the database client (``db_utils.client``).
        databases: List of available databases shown in the database dropdown.

    Returns:
        Tuple[gr.Tab, dict]: The tab component and a dict of its interface
        elements, keyed by ``db_input``, ``collection_input``,
        ``field_input``, ``embedding_field_input``, ``limit_input``,
        ``generate_btn``, ``generate_output`` and ``progress_output``.
    """

    def update_collections(db_name: str) -> gr.Dropdown:
        """Update collections dropdown when database changes."""
        # The dropdown can be cleared; do not query the server with an empty
        # database name (mirrors the guard in update_fields).
        if not db_name:
            return gr.Dropdown(choices=[], value=None)
        collections = db_utils.get_collections(db_name)
        # If there's only one collection, select it by default
        value = collections[0] if len(collections) == 1 else None
        return gr.Dropdown(choices=collections, value=value)

    def update_fields(db_name: str, collection_name: str) -> gr.Dropdown:
        """Update fields dropdown when collection changes."""
        if db_name and collection_name:
            fields = db_utils.get_field_names(db_name, collection_name)
            # Reset the selection so a field picked for a previous collection
            # is not silently carried over to the new one.
            return gr.Dropdown(choices=fields, value=None)
        return gr.Dropdown(choices=[], value=None)

    def generate_embeddings(
        db_name: str,
        collection_name: str,
        field_name: str,
        embedding_field: str,
        limit: int = 10,
        progress=gr.Progress()
    ) -> Tuple[str, str]:
        """Generate embeddings for documents with progress tracking.

        Args:
            db_name: Database selected in the UI.
            collection_name: Collection selected in the UI.
            field_name: Document field whose content is embedded.
            embedding_field: Field name where embeddings will be stored.
            limit: Maximum number of documents to process; 0 (or a negative
                value) means all documents that still lack an embedding.
            progress: Gradio progress tracker.

        Returns:
            Tuple[str, str]: (result or error message, last progress line).
            Errors are returned as text rather than raised so they show up
            in the "Results" box.
        """
        # --- Input validation -------------------------------------------
        if not (db_name and collection_name and field_name):
            return "Error: Please select a database, collection and field first", ""

        embedding_field = (embedding_field or "").strip()
        if not embedding_field:
            return "Error: Please provide an embedding field name", ""
        if embedding_field == field_name:
            # With identical names the two "$exists" clauses below would share
            # one dict key and the query would silently collapse to a single
            # condition.
            return "Error: Embedding field must differ from the source field", ""

        # gr.Number may hand over a float or None; Cursor.limit() needs an int.
        if limit is None:
            return "Error: Please enter a document limit (0 for all documents)", ""
        try:
            max_docs = max(int(limit), 0)
        except (TypeError, ValueError, OverflowError):
            return "Error: Document limit must be a whole number", ""

        try:
            db = db_utils.client[db_name]
            collection = db[collection_name]

            # Count documents that have the source field at all
            total_docs = collection.count_documents({field_name: {"$exists": True}})
            if total_docs == 0:
                return f"No documents found with field '{field_name}'", ""

            # Get total count of documents that need processing
            query = {
                field_name: {"$exists": True},
                embedding_field: {"$exists": False}  # Only get docs without embeddings
            }
            total_to_process = collection.count_documents(query)
            if total_to_process == 0:
                return "No documents found that need embeddings", ""

            # Apply limit if specified
            if max_docs > 0:
                total_to_process = min(total_to_process, max_docs)
            print(f"\nFound {total_to_process} documents that need embeddings...")

            # Progress tracking: keeps the most recent progress line so it
            # can be returned to the "Progress" box when the run finishes.
            progress_text = ""

            def update_progress(prog: float, processed: int, total: int):
                """Record and display progress; ``prog`` is a percentage (0-100)."""
                nonlocal progress_text
                progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
                print(progress_text)  # Terminal logging
                progress(prog/100, f"Processed {processed}/{total} documents")

            # Show initial progress
            update_progress(0, 0, total_to_process)

            # Create cursor for batch processing
            cursor = collection.find(query)
            if max_docs > 0:
                cursor = cursor.limit(max_docs)

            # Generate embeddings in parallel with cursor-based batching
            processed = parallel_generate_embeddings(
                collection=collection,
                cursor=cursor,
                field_name=field_name,
                embedding_field=embedding_field,
                openai_client=openai_client,
                total_docs=total_to_process,
                callback=update_progress
            )

            # Return completion message and final progress
            instructions = f"""
Successfully generated embeddings for {processed} documents using parallel processing!
To create the vector search index in MongoDB Atlas:
1. Go to your Atlas cluster
2. Click on 'Search' tab
3. Create an index named 'vector_index' with this configuration:
{{
"fields": [
{{
"type": "vector",
"path": "{embedding_field}",
"numDimensions": 1536,
"similarity": "dotProduct"
}}
]
}}
You can now use the search tab with:
- Field to search: {field_name}
- Embedding field: {embedding_field}
"""
            return instructions, progress_text
        except Exception as e:
            # UI boundary: report the failure in the results box instead of
            # letting the exception propagate out of the event handler.
            return f"Error: {str(e)}", ""

    # Create the tab UI
    with gr.Tab("Generate Embeddings") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Available databases in Atlas cluster"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collections in selected database"
            )
        with gr.Row():
            field_input = gr.Dropdown(
                choices=[],
                label="Select Field for Embeddings",
                info="Fields available in collection"
            )
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field name where embeddings will be stored"
            )
            limit_input = gr.Number(
                label="Document Limit",
                value=10,
                minimum=0,
                precision=0,  # whole documents only; value is delivered as int
                info="Number of documents to process (0 for all documents)"
            )
        generate_btn = gr.Button("Generate Embeddings")
        generate_output = gr.Textbox(label="Results", lines=10)
        progress_output = gr.Textbox(label="Progress", lines=3)

        # Set up event handlers
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )
        collection_input.change(
            fn=update_fields,
            inputs=[db_input, collection_input],
            outputs=[field_input]
        )
        generate_btn.click(
            fn=generate_embeddings,
            inputs=[
                db_input,
                collection_input,
                field_input,
                embedding_field_input,
                limit_input
            ],
            outputs=[generate_output, progress_output]
        )

    # Return the tab and its interface elements
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'field_input': field_input,
        'embedding_field_input': embedding_field_input,
        'limit_input': limit_input,
        'generate_btn': generate_btn,
        'generate_output': generate_output,
        'progress_output': progress_output
    }
    return tab, interface