import gradio as gr
from typing import Tuple, Optional, List
from openai import OpenAI

from utils.db_utils import DatabaseUtils
from utils.embedding_utils import parallel_generate_embeddings


def create_embeddings_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
    """Create the embeddings generation tab UI

    Args:
        openai_client: OpenAI client instance
        db_utils: DatabaseUtils instance
        databases: List of available databases

    Returns:
        Tuple[gr.Tab, dict]: The tab component and its interface elements
    """
    def update_collections(db_name: str) -> gr.Dropdown:
        """Update collections dropdown when database changes"""
        collections = db_utils.get_collections(db_name)
        # If there's only one collection, select it by default
        value = collections[0] if len(collections) == 1 else None
        return gr.Dropdown(choices=collections, value=value)

    def update_fields(db_name: str, collection_name: str) -> gr.Dropdown:
        """Update fields dropdown when collection changes"""
        if db_name and collection_name:
            fields = db_utils.get_field_names(db_name, collection_name)
            return gr.Dropdown(choices=fields)
        return gr.Dropdown(choices=[])
    def generate_embeddings(
        db_name: str,
        collection_name: str,
        field_name: str,
        embedding_field: str,
        limit: int = 10,
        progress=gr.Progress()
    ) -> Tuple[str, str]:
        """Generate embeddings for documents with progress tracking"""
        try:
            # gr.Number delivers floats by default; cursor.limit() further down requires an int
            limit = int(limit) if limit else 0
            db = db_utils.client[db_name]
            collection = db[collection_name]
            # Count documents that contain the source field at all
            total_docs = collection.count_documents({field_name: {"$exists": True}})
            if total_docs == 0:
                return f"No documents found with field '{field_name}'", ""

            # Count documents that still need embeddings (source field present, embedding missing)
            query = {
                field_name: {"$exists": True},
                embedding_field: {"$exists": False}  # Only get docs without embeddings
            }
            total_to_process = collection.count_documents(query)
            if total_to_process == 0:
                return "No documents found that need embeddings", ""

            # Apply limit if specified
            if limit > 0:
                total_to_process = min(total_to_process, limit)
            print(f"\nFound {total_to_process} documents that need embeddings...")
            # Progress tracking
            progress_text = ""

            def update_progress(prog: float, processed: int, total: int):
                """Callback invoked with progress as a percentage (0-100) plus document counts"""
                nonlocal progress_text
                progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
                print(progress_text)  # Terminal logging
                progress(prog / 100, f"Processed {processed}/{total} documents")  # gr.Progress expects a 0-1 fraction

            # Show initial progress
            update_progress(0, 0, total_to_process)

            # Create cursor for batch processing
            cursor = collection.find(query)
            if limit > 0:
                cursor = cursor.limit(limit)

            # Generate embeddings in parallel with cursor-based batching
            processed = parallel_generate_embeddings(
                collection=collection,
                cursor=cursor,
                field_name=field_name,
                embedding_field=embedding_field,
                openai_client=openai_client,
                total_docs=total_to_process,
                callback=update_progress
            )
            # Return completion message and final progress
            instructions = f"""
Successfully generated embeddings for {processed} documents using parallel processing!

To create the vector search index in MongoDB Atlas:
1. Go to your Atlas cluster
2. Click on the 'Search' tab
3. Create an index named 'vector_index' with this configuration:

{{
  "fields": [
    {{
      "type": "vector",
      "path": "{embedding_field}",
      "numDimensions": 1536,
      "similarity": "dotProduct"
    }}
  ]
}}

You can now use the search tab with:
- Field to search: {field_name}
- Embedding field: {embedding_field}
"""
            return instructions, progress_text
        except Exception as e:
            return f"Error: {str(e)}", ""
    # Create the tab UI
    with gr.Tab("Generate Embeddings") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Available databases in Atlas cluster"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collections in selected database"
            )
        with gr.Row():
            field_input = gr.Dropdown(
                choices=[],
                label="Select Field for Embeddings",
                info="Fields available in collection"
            )
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field name where embeddings will be stored"
            )
            limit_input = gr.Number(
                label="Document Limit",
                value=10,
                minimum=0,
                info="Number of documents to process (0 for all documents)"
            )
        generate_btn = gr.Button("Generate Embeddings")
        generate_output = gr.Textbox(label="Results", lines=10)
        progress_output = gr.Textbox(label="Progress", lines=3)
        # Set up event handlers
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )
        collection_input.change(
            fn=update_fields,
            inputs=[db_input, collection_input],
            outputs=[field_input]
        )
        generate_btn.click(
            fn=generate_embeddings,
            inputs=[
                db_input,
                collection_input,
                field_input,
                embedding_field_input,
                limit_input
            ],
            outputs=[generate_output, progress_output]
        )
    # Return the tab and its interface elements
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'field_input': field_input,
        'embedding_field_input': embedding_field_input,
        'limit_input': limit_input,
        'generate_btn': generate_btn,
        'generate_output': generate_output,
        'progress_output': progress_output
    }
    return tab, interface
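

# Usage sketch (illustrative): minimal wiring of this tab into a standalone Gradio app.
# Assumptions not taken from this module: DatabaseUtils() can be constructed without
# arguments (e.g. it reads the MongoDB connection string from the environment), and
# OPENAI_API_KEY is set so OpenAI() can authenticate. Adjust to the real app entry point.
if __name__ == "__main__":
    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    utils = DatabaseUtils()  # assumption: zero-argument constructor
    database_names = utils.client.list_database_names()  # PyMongo: list databases on the cluster

    with gr.Blocks() as demo:
        create_embeddings_tab(client, utils, database_names)

    demo.launch()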