Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
from openai import OpenAI | |
import json | |
from dotenv import load_dotenv | |
from db_utils import DatabaseUtils | |
from embedding_utils import parallel_generate_embeddings, get_embedding | |
# Populate os.environ from a local .env file if one is found. override=False is
# python-dotenv's default: variables already present in the real environment
# (e.g. repository secrets) win over values from the .env file.
load_dotenv(override=False)
def check_credentials() -> tuple[bool, str]:
    """Report whether the required credential environment variables are set.

    Only presence is checked: an unset or empty variable counts as missing,
    and the values themselves are not validated here.

    Returns:
        ``(True, "")`` when both ``ATLAS_URI`` and ``OPENAI_API_KEY`` are set;
        otherwise ``(False, guidance)`` with setup instructions for the first
        missing variable (``ATLAS_URI`` is checked before ``OPENAI_API_KEY``).
    """
    atlas_help = (
        "Please set up your MongoDB Atlas credentials:\n"
        "1. Go to Settings tab\n"
        "2. Add ATLAS_URI as a Repository Secret\n"
        "3. Paste your MongoDB connection string (should start with 'mongodb+srv://')"
    )
    openai_help = (
        "Please set up your OpenAI API key:\n"
        "1. Go to Settings tab\n"
        "2. Add OPENAI_API_KEY as a Repository Secret\n"
        "3. Paste your OpenAI API key"
    )
    required = (("ATLAS_URI", atlas_help), ("OPENAI_API_KEY", openai_help))
    for env_name, guidance in required:
        if not os.getenv(env_name):
            return False, guidance
    return True, ""
def init_clients():
    """Initialize the OpenAI client and the MongoDB helper.

    Both constructors are called without arguments. The OpenAI SDK falls back
    to the OPENAI_API_KEY environment variable; DatabaseUtils presumably reads
    ATLAS_URI the same way (TODO confirm in db_utils).

    Returns:
        ``(openai_client, db_utils)`` on success, or ``(None, None)`` if either
        constructor raises. Callers treat ``(None, None)`` as a connection error.
    """
    try:
        openai_client = OpenAI()
        db_utils = DatabaseUtils()
    except Exception as e:
        # Deliberately best-effort (the UI shows a generic connection-error
        # panel), but log the real cause instead of silently discarding it.
        print(f"Client initialization failed: {type(e).__name__}: {e}")
        return None, None
    return openai_client, db_utils
def get_field_names(db_name: str, collection_name: str) -> list[str]:
    """Return the field names present in ``db_name.collection_name``.

    Delegates to the module-level ``db_utils`` helper, which is only bound
    once credentials are present and the clients have been initialized.
    """
    field_names = db_utils.get_field_names(db_name, collection_name)
    return field_names
def _vector_index_instructions(processed: int, field_name: str, embedding_field: str) -> str:
    """Build the success message explaining how to create the Atlas vector index."""
    # NOTE(review): numDimensions assumes the model used by embedding_utils
    # emits 1536-dimensional vectors - update this if the model changes.
    return f"""
Successfully generated embeddings for {processed} documents using parallel processing!
To create the vector search index in MongoDB Atlas:
1. Go to your Atlas cluster
2. Click on 'Search' tab
3. Create an index named 'vector_index' with this configuration:
{{
"fields": [
{{
"type": "vector",
"path": "{embedding_field}",
"numDimensions": 1536,
"similarity": "dotProduct"
}}
]
}}
You can now use the search tab with:
- Field to search: {field_name}
- Embedding field: {embedding_field}
"""


def generate_embeddings_for_field(db_name: str, collection_name: str, field_name: str, embedding_field: str, limit: int = 10, progress=gr.Progress()) -> tuple[str, str]:
    """Generate embeddings in parallel for documents that do not have one yet.

    Selects documents where ``field_name`` exists and ``embedding_field`` does
    not, and hands them to ``parallel_generate_embeddings``, which writes the
    vectors into ``embedding_field``.

    Args:
        db_name: database to use (None if not yet selected in the UI).
        collection_name: collection to use (None if not yet selected).
        field_name: source field whose content is embedded.
        embedding_field: destination field name for the vectors.
        limit: maximum number of documents to process; 0, negative or empty
            means "all". Coerced to int because Gradio's Number component
            delivers floats (e.g. 10.0) unless configured otherwise.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(message, progress_text)``. On success ``message`` holds Atlas index
        setup instructions and ``progress_text`` the last progress line; when
        there is nothing to do or an error occurs, an explanation and ``""``.
    """
    # Fail fast with a readable message instead of a driver TypeError when a
    # dropdown is still empty (Gradio passes None for it).
    if not db_name or not collection_name or not field_name:
        return "Please select a database, collection and field first", ""
    if not embedding_field or not embedding_field.strip():
        return "Please provide an embedding field name", ""

    # pymongo's Cursor.limit() only accepts ints, and None > 0 would raise.
    try:
        limit = int(limit) if limit else 0
    except (TypeError, ValueError, OverflowError):
        return f"Invalid document limit: {limit!r}", ""

    try:
        db = db_utils.client[db_name]
        collection = db[collection_name]

        # Nothing to do if no document has the source field at all.
        if collection.count_documents({field_name: {"$exists": True}}) == 0:
            return f"No documents found with field '{field_name}'", ""

        # Only documents that have the source field but no embedding yet.
        query = {
            field_name: {"$exists": True},
            embedding_field: {"$exists": False},
        }
        total_to_process = collection.count_documents(query)
        if total_to_process == 0:
            return "No documents found that need embeddings", ""

        if limit > 0:
            total_to_process = min(total_to_process, limit)
        print(f"\nFound {total_to_process} documents that need embeddings...")

        progress_text = ""

        def update_progress(prog: float, processed: int, total: int):
            """Record, log and report progress; ``prog`` is a percentage (0-100)."""
            nonlocal progress_text
            progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
            print(progress_text)  # Terminal logging
            progress(prog / 100, f"Processed {processed}/{total} documents")

        update_progress(0, 0, total_to_process)

        cursor = collection.find(query)
        if limit > 0:
            cursor = cursor.limit(limit)

        processed = parallel_generate_embeddings(
            collection=collection,
            cursor=cursor,
            field_name=field_name,
            embedding_field=embedding_field,
            openai_client=openai_client,
            total_docs=total_to_process,
            callback=update_progress,
        )
        return _vector_index_instructions(processed, field_name, embedding_field), progress_text
    except Exception as e:
        return f"Error: {str(e)}", ""
def vector_search(query_text: str, db_name: str, collection_name: str, embedding_field: str, index_name: str, num_results: int = 5, num_candidates: int = 50) -> str:
    """Run an Atlas ``$vectorSearch`` for ``query_text`` and format the hits.

    Args:
        query_text: free-text query; it is embedded with ``get_embedding``.
        db_name: database to search (None/empty if not selected in the UI).
        collection_name: collection to search (None/empty if not selected).
        embedding_field: document field holding the stored vectors.
        index_name: name of the Atlas vector search index on that field.
        num_results: maximum number of hits (``limit``); defaults to 5.
        num_candidates: ``numCandidates`` for the approximate search; raised
            to at least ``num_results`` because Atlas rejects a smaller value.

    Returns:
        Numbered, newline-separated results showing the score and every field
        except ``_id`` and ``embedding_field``. Returns "No results found" when
        nothing matches, a prompt when a required input is missing, and
        ``"Error: ..."`` on any failure.
    """
    # Validate before spending an embedding API call on an unusable request.
    if not query_text or not query_text.strip():
        return "Please enter a search query"
    if not db_name or not collection_name:
        return "Please select a database and collection"
    if not embedding_field or not index_name:
        return "Please provide the embedding field and vector search index names"

    try:
        print(f"\nProcessing query: {query_text}")
        collection = db_utils.client[db_name][collection_name]

        embedding = get_embedding(query_text, openai_client)
        print("Generated embeddings successfully")

        pipeline = [
            {
                "$vectorSearch": {
                    "index": index_name,
                    "path": embedding_field,
                    "queryVector": embedding,
                    "numCandidates": max(num_candidates, num_results),
                    "limit": num_results,
                }
            },
            {
                "$project": {
                    "search_score": {"$meta": "vectorSearchScore"},
                    "document": "$$ROOT",
                }
            },
        ]

        formatted_results = []
        for idx, result in enumerate(collection.aggregate(pipeline), 1):
            doc = result["document"]
            lines = [f"{idx}. Score: {result['search_score']:.4f}"]
            # Show every field except the id and the (large) embedding vector.
            lines.extend(
                f"{key}: {value}"
                for key, value in doc.items()
                if key not in ("_id", embedding_field)
            )
            formatted_results.append("\n".join(lines) + "\n")

        return "\n".join(formatted_results) if formatted_results else "No results found"
    except Exception as e:
        return f"Error: {str(e)}"
# Create Gradio interface with tabs
with gr.Blocks(title="MongoDB Vector Search Tool") as iface:
    gr.Markdown("# MongoDB Vector Search Tool")
    # Check credentials first: without them only a setup notice is rendered.
    has_creds, cred_message = check_credentials()
    if not has_creds:
        gr.Markdown(f"""
## ⚠️ Setup Required
{cred_message}
After setting up credentials, refresh this page.
""")
    else:
        # Assigned at module level, so these become the globals that the
        # handler functions (e.g. vector_search) read at call time.
        openai_client, db_utils = init_clients()
        if not openai_client or not db_utils:
            gr.Markdown("""
## ⚠️ Connection Error
Failed to connect to MongoDB Atlas or OpenAI. Please check your credentials and try again.
""")
        else:
            # Listing databases queries the cluster; keep the try minimal and
            # report a failure in the UI instead of aborting app start-up.
            try:
                databases = db_utils.get_databases()
            except Exception as e:
                print(f"Failed to list databases: {type(e).__name__}: {e}")
                gr.Markdown(f"""
## ⚠️ Database Error
Failed to list databases in your MongoDB Atlas cluster: {e}
""")
            else:
                def update_collections(db_name):
                    """Return a collection-dropdown update for ``db_name``.

                    Clears the choices when no database is selected and
                    pre-selects the collection when there is exactly one.
                    """
                    if not db_name:
                        return gr.Dropdown(choices=[], value=None)
                    collections = db_utils.get_collections(db_name)
                    value = collections[0] if len(collections) == 1 else None
                    return gr.Dropdown(choices=collections, value=value)

                def update_fields(db_name, collection_name):
                    """Return a field-dropdown update listing the collection's fields.

                    The value is reset so a field picked for a previous
                    collection is not carried over.
                    """
                    if db_name and collection_name:
                        fields = get_field_names(db_name, collection_name)
                        return gr.Dropdown(choices=fields, value=None)
                    return gr.Dropdown(choices=[], value=None)

                with gr.Tab("Generate Embeddings"):
                    with gr.Row():
                        db_input = gr.Dropdown(
                            choices=databases,
                            label="Select Database",
                            info="Available databases in Atlas cluster",
                        )
                        collection_input = gr.Dropdown(
                            choices=[],
                            label="Select Collection",
                            info="Collections in selected database",
                        )
                    with gr.Row():
                        field_input = gr.Dropdown(
                            choices=[],
                            label="Select Field for Embeddings",
                            info="Fields available in collection",
                        )
                        embedding_field_input = gr.Textbox(
                            label="Embedding Field Name",
                            value="embedding",
                            info="Field name where embeddings will be stored",
                        )
                        # precision=0 makes Gradio deliver an int, which
                        # pymongo's Cursor.limit() requires.
                        limit_input = gr.Number(
                            label="Document Limit",
                            value=10,
                            minimum=0,
                            precision=0,
                            info="Number of documents to process (0 for all documents)",
                        )

                    # Update collections when database changes
                    db_input.change(
                        fn=update_collections,
                        inputs=[db_input],
                        outputs=[collection_input],
                    )
                    # Update fields when collection changes
                    collection_input.change(
                        fn=update_fields,
                        inputs=[db_input, collection_input],
                        outputs=[field_input],
                    )

                    generate_btn = gr.Button("Generate Embeddings")
                    generate_output = gr.Textbox(label="Results", lines=10)
                    progress_output = gr.Textbox(label="Progress", lines=3)
                    generate_btn.click(
                        generate_embeddings_for_field,
                        inputs=[db_input, collection_input, field_input, embedding_field_input, limit_input],
                        outputs=[generate_output, progress_output],
                    )

                with gr.Tab("Search"):
                    with gr.Row():
                        search_db_input = gr.Dropdown(
                            choices=databases,
                            label="Select Database",
                            info="Database containing the vectors",
                        )
                        search_collection_input = gr.Dropdown(
                            choices=[],
                            label="Select Collection",
                            info="Collection containing the vectors",
                        )
                    with gr.Row():
                        search_embedding_field_input = gr.Textbox(
                            label="Embedding Field Name",
                            value="embedding",
                            info="Field containing the vectors",
                        )
                        search_index_input = gr.Textbox(
                            label="Vector Search Index Name",
                            value="vector_index",
                            info="Index created in Atlas UI",
                        )

                    # Update collections when database changes
                    search_db_input.change(
                        fn=update_collections,
                        inputs=[search_db_input],
                        outputs=[search_collection_input],
                    )

                    query_input = gr.Textbox(
                        label="Search Query",
                        lines=2,
                        placeholder="What would you like to search for?",
                    )
                    search_btn = gr.Button("Search")
                    search_output = gr.Textbox(label="Results", lines=10)
                    search_btn.click(
                        vector_search,
                        inputs=[
                            query_input,
                            search_db_input,
                            search_collection_input,
                            search_embedding_field_input,
                            search_index_input,
                        ],
                        outputs=search_output,
                    )
if __name__ == "__main__":
    # Listen on every network interface (0.0.0.0) rather than Gradio's
    # localhost-only default, on port 7860.
    iface.launch(server_port=7860, server_name="0.0.0.0")