import os
import gradio as gr
from openai import OpenAI
import json
from dotenv import load_dotenv
from db_utils import DatabaseUtils
from embedding_utils import parallel_generate_embeddings, get_embedding

# Load environment variables from .env file
load_dotenv()

# Initialize OpenAI client
openai_client = OpenAI()

# Initialize database utils
db_utils = DatabaseUtils()


def get_field_names(db_name: str, collection_name: str) -> list[str]:
    """Get list of fields in the collection"""
    return db_utils.get_field_names(db_name, collection_name)


def generate_embeddings_for_field(db_name: str, collection_name: str, field_name: str,
                                  embedding_field: str, limit: int = 10,
                                  progress=gr.Progress()) -> tuple[str, str]:
    """Generate embeddings for documents in parallel with progress tracking"""
    try:
        db = db_utils.client[db_name]
        collection = db[collection_name]

        # Count documents that need embeddings
        total_docs = collection.count_documents({field_name: {"$exists": True}})
        if total_docs == 0:
            return f"No documents found with field '{field_name}'", ""

        # Get total count of documents that need processing
        query = {
            field_name: {"$exists": True},
            embedding_field: {"$exists": False}  # Only get docs without embeddings
        }
        total_to_process = collection.count_documents(query)
        if total_to_process == 0:
            return "No documents found that need embeddings", ""

        # Apply limit if specified
        if limit > 0:
            total_to_process = min(total_to_process, limit)

        print(f"\nFound {total_to_process} documents that need embeddings...")

        # Progress tracking
        progress_text = ""

        def update_progress(prog: float, processed: int, total: int):
            nonlocal progress_text
            progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
            print(progress_text)  # Terminal logging
            progress(prog / 100, f"Processed {processed}/{total} documents")

        # Show initial progress
        update_progress(0, 0, total_to_process)

        # Create cursor for batch processing
        cursor = collection.find(query)
        if limit > 0:
            cursor = cursor.limit(limit)

        # Generate embeddings in parallel with cursor-based batching
        processed = parallel_generate_embeddings(
            collection=collection,
            cursor=cursor,
            field_name=field_name,
            embedding_field=embedding_field,
            openai_client=openai_client,
            total_docs=total_to_process,
            callback=update_progress
        )

        # Return completion message and final progress
        instructions = f"""
Successfully generated embeddings for {processed} documents using parallel processing!

To create the vector search index in MongoDB Atlas:
1. Go to your Atlas cluster
2. Click on 'Search' tab
3. Create an index named 'vector_index' with this configuration:

{{
  "fields": [
    {{
      "type": "vector",
      "path": "{embedding_field}",
      "numDimensions": 1536,
      "similarity": "dotProduct"
    }}
  ]
}}

You can now use the search tab with:
- Field to search: {field_name}
- Embedding field: {embedding_field}
"""
        return instructions, progress_text

    except Exception as e:
        return f"Error: {str(e)}", ""


def vector_search(query_text: str, db_name: str, collection_name: str,
                  embedding_field: str, index_name: str) -> str:
    """Perform vector search using embeddings"""
    try:
        print(f"\nProcessing query: {query_text}")
        db = db_utils.client[db_name]
        collection = db[collection_name]

        # Get embeddings for query
        embedding = get_embedding(query_text, openai_client)
        print("Generated embeddings successfully")

        results = collection.aggregate([
            {
                '$vectorSearch': {
                    "index": index_name,
                    "path": embedding_field,
                    "queryVector": embedding,
                    "numCandidates": 50,
                    "limit": 5
                }
            },
            {
                "$project": {
                    "search_score": {"$meta": "vectorSearchScore"},
                    "document": "$$ROOT"
                }
            }
        ])

        # Format results
        results_list = list(results)
        formatted_results = []
        for idx, result in enumerate(results_list, 1):
            doc = result['document']
            formatted_result = f"{idx}. Score: {result['search_score']:.4f}\n"
            # Add all fields except _id and embeddings
            for key, value in doc.items():
                if key not in ['_id', embedding_field]:
                    formatted_result += f"{key}: {value}\n"
            formatted_results.append(formatted_result)

        return "\n".join(formatted_results) if formatted_results else "No results found"

    except Exception as e:
        return f"Error: {str(e)}"


# Create Gradio interface with tabs
with gr.Blocks(title="MongoDB Vector Search Tool") as iface:
    gr.Markdown("# MongoDB Vector Search Tool")

    # Get available databases
    databases = db_utils.get_databases()

    with gr.Tab("Generate Embeddings"):
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Available databases in Atlas cluster"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collections in selected database"
            )

        with gr.Row():
            field_input = gr.Dropdown(
                choices=[],
                label="Select Field for Embeddings",
                info="Fields available in collection"
            )
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field name where embeddings will be stored"
            )
            limit_input = gr.Number(
                label="Document Limit",
                value=10,
                minimum=0,
                info="Number of documents to process (0 for all documents)"
            )

        def update_collections(db_name):
            collections = db_utils.get_collections(db_name)
            # If there's only one collection, select it by default
            value = collections[0] if len(collections) == 1 else None
            return gr.Dropdown(choices=collections, value=value)

        def update_fields(db_name, collection_name):
            if db_name and collection_name:
                fields = get_field_names(db_name, collection_name)
                return gr.Dropdown(choices=fields)
            return gr.Dropdown(choices=[])

        # Update collections when database changes
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )

        # Update fields when collection changes
        collection_input.change(
            fn=update_fields,
            inputs=[db_input, collection_input],
            outputs=[field_input]
        )

        generate_btn = gr.Button("Generate Embeddings")
        generate_output = gr.Textbox(label="Results", lines=10)
        progress_output = gr.Textbox(label="Progress", lines=3)

        generate_btn.click(
            generate_embeddings_for_field,
            inputs=[db_input, collection_input, field_input,
                    embedding_field_input, limit_input],
            outputs=[generate_output, progress_output]
        )

    with gr.Tab("Search"):
        with gr.Row():
            search_db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Database containing the vectors"
            )
            search_collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collection containing the vectors"
            )

        with gr.Row():
            search_embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field containing the vectors"
            )
            search_index_input = gr.Textbox(
                label="Vector Search Index Name",
                value="vector_index",
                info="Index created in Atlas UI"
            )

        # Update collections when database changes
        search_db_input.change(
            fn=update_collections,
            inputs=[search_db_input],
            outputs=[search_collection_input]
        )

        query_input = gr.Textbox(
            label="Search Query",
            lines=2,
            placeholder="What would you like to search for?"
        )

        search_btn = gr.Button("Search")
        search_output = gr.Textbox(label="Results", lines=10)

        search_btn.click(
            vector_search,
            inputs=[
                query_input,
                search_db_input,
                search_collection_input,
                search_embedding_field_input,
                search_index_input
            ],
            outputs=search_output
        )

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)