import os
import gradio as gr
from openai import OpenAI
import json
from dotenv import load_dotenv
from db_utils import DatabaseUtils
from embedding_utils import parallel_generate_embeddings, get_embedding
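# Third-party requirements: gradio, python-dotenv, and the openai>=1.0 client
# (the `from openai import OpenAI` style); db_utils and embedding_utils are
# presumably local helper modules shipped alongside this file.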
# Load environment variables from .env file
load_dotenv()
# Initialize OpenAI client
openai_client = OpenAI()
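# OpenAI() picks up OPENAI_API_KEY from the environment (loaded above via load_dotenv)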
# Initialize database utils
db_utils = DatabaseUtils()
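# DatabaseUtils is assumed to manage the MongoDB Atlas connection (e.g. from a
# connection string in .env) and to expose the `client` attribute plus the
# get_databases() / get_collections() / get_field_names() helpers used below.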
def get_field_names(db_name: str, collection_name: str) -> list[str]:
    """Get list of fields in the collection"""
    return db_utils.get_field_names(db_name, collection_name)
def generate_embeddings_for_field(db_name: str, collection_name: str, field_name: str, embedding_field: str, limit: int = 10, progress=gr.Progress()) -> tuple[str, str]:
    """Generate embeddings for documents in parallel with progress tracking"""
    try:
        db = db_utils.client[db_name]
        collection = db[collection_name]

        # Check that at least one document has the source field
        total_docs = collection.count_documents({field_name: {"$exists": True}})
        if total_docs == 0:
            return f"No documents found with field '{field_name}'", ""

        # Count documents that still need embeddings (field present, no embedding yet)
        query = {
            field_name: {"$exists": True},
            embedding_field: {"$exists": False}  # Only get docs without embeddings
        }
        total_to_process = collection.count_documents(query)
        if total_to_process == 0:
            return "No documents found that need embeddings", ""

        # Apply limit if specified
        if limit > 0:
            total_to_process = min(total_to_process, limit)
        print(f"\nFound {total_to_process} documents that need embeddings...")

        # Progress tracking
        progress_text = ""

        def update_progress(prog: float, processed: int, total: int):
            nonlocal progress_text
            progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
            print(progress_text)  # Terminal logging
            progress(prog / 100, f"Processed {processed}/{total} documents")

        # Show initial progress
        update_progress(0, 0, total_to_process)

        # Create cursor for batch processing
        cursor = collection.find(query)
        if limit > 0:
            cursor = cursor.limit(limit)

        # Generate embeddings in parallel with cursor-based batching
        processed = parallel_generate_embeddings(
            collection=collection,
            cursor=cursor,
            field_name=field_name,
            embedding_field=embedding_field,
            openai_client=openai_client,
            total_docs=total_to_process,
            callback=update_progress
        )

        # Return completion message and final progress
        instructions = f"""
Successfully generated embeddings for {processed} documents using parallel processing!

To create the vector search index in MongoDB Atlas:
1. Go to your Atlas cluster
2. Click on the 'Search' tab
3. Create an index named 'vector_index' with this configuration:

{{
  "fields": [
    {{
      "type": "vector",
      "path": "{embedding_field}",
      "numDimensions": 1536,
      "similarity": "dotProduct"
    }}
  ]
}}

You can now use the Search tab with:
- Field to search: {field_name}
- Embedding field: {embedding_field}
"""
        return instructions, progress_text
    except Exception as e:
        return f"Error: {str(e)}", ""
def vector_search(query_text: str, db_name: str, collection_name: str, embedding_field: str, index_name: str) -> str:
    """Perform vector search using embeddings"""
    try:
        print(f"\nProcessing query: {query_text}")
        db = db_utils.client[db_name]
        collection = db[collection_name]

        # Get embeddings for query
        embedding = get_embedding(query_text, openai_client)
        print("Generated embeddings successfully")

        results = collection.aggregate([
            {
                "$vectorSearch": {
                    "index": index_name,
                    "path": embedding_field,
                    "queryVector": embedding,
                    "numCandidates": 50,
                    "limit": 5
                }
            },
            {
                "$project": {
                    "search_score": {"$meta": "vectorSearchScore"},
                    "document": "$$ROOT"
                }
            }
        ])

        # Format results
        results_list = list(results)
        formatted_results = []
        for idx, result in enumerate(results_list, 1):
            doc = result['document']
            formatted_result = f"{idx}. Score: {result['search_score']:.4f}\n"
            # Add all fields except _id and embeddings
            for key, value in doc.items():
                if key not in ['_id', embedding_field]:
                    formatted_result += f"{key}: {value}\n"
            formatted_results.append(formatted_result)

        return "\n".join(formatted_results) if formatted_results else "No results found"
    except Exception as e:
        return f"Error: {str(e)}"
# Create Gradio interface with tabs
with gr.Blocks(title="MongoDB Vector Search Tool") as iface:
    gr.Markdown("# MongoDB Vector Search Tool")

    # Get available databases
    databases = db_utils.get_databases()

    with gr.Tab("Generate Embeddings"):
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Available databases in Atlas cluster"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collections in selected database"
            )

        with gr.Row():
            field_input = gr.Dropdown(
                choices=[],
                label="Select Field for Embeddings",
                info="Fields available in collection"
            )
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field name where embeddings will be stored"
            )
            limit_input = gr.Number(
                label="Document Limit",
                value=10,
                minimum=0,
                info="Number of documents to process (0 for all documents)"
            )

        def update_collections(db_name):
            collections = db_utils.get_collections(db_name)
            # If there's only one collection, select it by default
            value = collections[0] if len(collections) == 1 else None
            return gr.Dropdown(choices=collections, value=value)

        def update_fields(db_name, collection_name):
            if db_name and collection_name:
                fields = get_field_names(db_name, collection_name)
                return gr.Dropdown(choices=fields)
            return gr.Dropdown(choices=[])

        # Update collections when database changes
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )

        # Update fields when collection changes
        collection_input.change(
            fn=update_fields,
            inputs=[db_input, collection_input],
            outputs=[field_input]
        )

        generate_btn = gr.Button("Generate Embeddings")
        generate_output = gr.Textbox(label="Results", lines=10)
        progress_output = gr.Textbox(label="Progress", lines=3)

        generate_btn.click(
            generate_embeddings_for_field,
            inputs=[db_input, collection_input, field_input, embedding_field_input, limit_input],
            outputs=[generate_output, progress_output]
        )

    with gr.Tab("Search"):
        with gr.Row():
            search_db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Database containing the vectors"
            )
            search_collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collection containing the vectors"
            )

        with gr.Row():
            search_embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field containing the vectors"
            )
            search_index_input = gr.Textbox(
                label="Vector Search Index Name",
                value="vector_index",
                info="Index created in Atlas UI"
            )

        # Update collections when database changes
        search_db_input.change(
            fn=update_collections,
            inputs=[search_db_input],
            outputs=[search_collection_input]
        )

        query_input = gr.Textbox(
            label="Search Query",
            lines=2,
            placeholder="What would you like to search for?"
        )
        search_btn = gr.Button("Search")
        search_output = gr.Textbox(label="Results", lines=10)

        search_btn.click(
            vector_search,
            inputs=[
                query_input,
                search_db_input,
                search_collection_input,
                search_embedding_field_input,
                search_index_input
            ],
            outputs=search_output
        )
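
# 0.0.0.0:7860 is the host/port a Hugging Face Space expects the Gradio app to bind to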
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)