import gradio as gr
from typing import Tuple, Optional, List
from openai import OpenAI
from utils.db_utils import DatabaseUtils
from utils.embedding_utils import parallel_generate_embeddings


def create_embeddings_tab(
    openai_client: OpenAI,
    db_utils: DatabaseUtils,
    databases: List[str],
) -> Tuple[gr.Tab, dict]:
    """Create the embeddings generation tab UI

    Builds the "Generate Embeddings" tab: cascading database / collection /
    field dropdowns, an embedding-field textbox, a document limit and a
    button that runs ``parallel_generate_embeddings`` over the documents
    that have the source field but no embedding field yet.

    Args:
        openai_client: OpenAI client instance
        db_utils: DatabaseUtils instance
        databases: List of available databases

    Returns:
        Tuple[gr.Tab, dict]: The tab component and its interface elements
    """

    def update_collections(db_name: str) -> gr.Dropdown:
        """Update collections dropdown when database changes"""
        # Mirror update_fields: with no database selected there is nothing
        # to look up, so present an empty dropdown instead of querying.
        if not db_name:
            return gr.Dropdown(choices=[], value=None)

        collections = db_utils.get_collections(db_name)
        # If there's only one collection, select it by default
        value = collections[0] if len(collections) == 1 else None
        return gr.Dropdown(choices=collections, value=value)

    def update_fields(db_name: str, collection_name: str) -> gr.Dropdown:
        """Update fields dropdown when collection changes"""
        if db_name and collection_name:
            fields = db_utils.get_field_names(db_name, collection_name)
            return gr.Dropdown(choices=fields)
        return gr.Dropdown(choices=[])

    def generate_embeddings(
        db_name: str,
        collection_name: str,
        field_name: str,
        embedding_field: str,
        limit: Optional[float] = 10,
        progress=gr.Progress()
    ) -> Tuple[str, str]:
        """Generate embeddings for documents with progress tracking

        Args:
            db_name: Database selected in the UI
            collection_name: Collection selected in the UI
            field_name: Document field whose content is embedded
            embedding_field: Document field the embeddings are written to
            limit: Maximum number of documents to process; values <= 0 mean
                "all documents"
            progress: Gradio progress tracker (injected by Gradio)

        Returns:
            Tuple[str, str]: The result (or error) message and the last
            progress line; the progress line is empty when nothing ran.
        """
        # Validate up front so the user gets an actionable message instead
        # of a raw driver exception.
        if not (db_name and collection_name and field_name):
            return "Error: Please select a database, a collection and a field", ""

        # Surrounding whitespace in a typed field name is treated as a typo.
        embedding_field = (embedding_field or "").strip()
        if not embedding_field:
            return "Error: Embedding field name must not be empty", ""

        # The number widget may deliver None (cleared) or a float; the
        # cursor limit below needs a real integer.
        try:
            max_docs = int(limit)
        except (TypeError, ValueError, OverflowError):
            return (
                "Error: Document limit must be a whole number "
                "(0 for all documents)",
                "",
            )

        try:
            # NOTE(review): assumes db_utils.client is a pymongo MongoClient
            db = db_utils.client[db_name]
            collection = db[collection_name]

            # Count documents that have the source field at all
            total_docs = collection.count_documents({field_name: {"$exists": True}})
            if total_docs == 0:
                return f"No documents found with field '{field_name}'", ""

            # Get total count of documents that need processing
            query = {
                field_name: {"$exists": True},
                embedding_field: {"$exists": False}  # Only get docs without embeddings
            }
            total_to_process = collection.count_documents(query)
            if total_to_process == 0:
                return "No documents found that need embeddings", ""

            # Apply limit if specified
            if max_docs > 0:
                total_to_process = min(total_to_process, max_docs)

            print(f"\nFound {total_to_process} documents that need embeddings...")

            # Progress tracking
            progress_text = ""

            def update_progress(prog: float, processed: int, total: int):
                """Record the latest progress line and forward it to Gradio"""
                nonlocal progress_text
                progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
                print(progress_text)  # Terminal logging
                # prog arrives as a percentage; Gradio expects a 0-1 fraction
                progress(prog / 100, desc=f"Processed {processed}/{total} documents")

            # Show initial progress
            update_progress(0, 0, total_to_process)

            # Create cursor for batch processing
            cursor = collection.find(query)
            if max_docs > 0:
                cursor = cursor.limit(max_docs)

            # Generate embeddings in parallel with cursor-based batching
            processed = parallel_generate_embeddings(
                collection=collection,
                cursor=cursor,
                field_name=field_name,
                embedding_field=embedding_field,
                openai_client=openai_client,
                total_docs=total_to_process,
                callback=update_progress
            )

            # Return completion message and final progress.
            # Built line by line so no source indentation leaks into the
            # textbox.
            # NOTE(review): numDimensions 1536 presumably matches the model
            # used by parallel_generate_embeddings - confirm.
            instructions = "\n".join([
                f"Successfully generated embeddings for {processed} documents using parallel processing!",
                "",
                "To create the vector search index in MongoDB Atlas:",
                "1. Go to your Atlas cluster",
                "2. Click on 'Search' tab",
                "3. Create an index named 'vector_index' with this configuration:",
                "{",
                '  "fields": [',
                "    {",
                '      "type": "vector",',
                f'      "path": "{embedding_field}",',
                '      "numDimensions": 1536,',
                '      "similarity": "dotProduct"',
                "    }",
                "  ]",
                "}",
                "",
                "You can now use the search tab with:",
                f"- Field to search: {field_name}",
                f"- Embedding field: {embedding_field}",
            ])
            return instructions, progress_text

        except Exception as e:
            # UI boundary: report the failure in the Results box rather than
            # crashing the event handler, and log it to the terminal too.
            print(f"Error generating embeddings: {e}")
            return f"Error: {str(e)}", ""

    # Create the tab UI
    with gr.Tab("Generate Embeddings") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Available databases in Atlas cluster"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collections in selected database"
            )
        with gr.Row():
            field_input = gr.Dropdown(
                choices=[],
                label="Select Field for Embeddings",
                info="Fields available in collection"
            )
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field name where embeddings will be stored"
            )
            limit_input = gr.Number(
                label="Document Limit",
                value=10,
                minimum=0,
                precision=0,  # deliver whole numbers to the handler
                info="Number of documents to process (0 for all documents)"
            )

        generate_btn = gr.Button("Generate Embeddings")
        generate_output = gr.Textbox(label="Results", lines=10)
        progress_output = gr.Textbox(label="Progress", lines=3)

        # Set up event handlers
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )
        collection_input.change(
            fn=update_fields,
            inputs=[db_input, collection_input],
            outputs=[field_input]
        )
        generate_btn.click(
            fn=generate_embeddings,
            inputs=[
                db_input,
                collection_input,
                field_input,
                embedding_field_input,
                limit_input
            ],
            outputs=[generate_output, progress_output]
        )

    # Return the tab and its interface elements
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'field_input': field_input,
        'embedding_field_input': embedding_field_input,
        'limit_input': limit_input,
        'generate_btn': generate_btn,
        'generate_output': generate_output,
        'progress_output': progress_output
    }
    return tab, interface