|
import os |
|
import gradio as gr |
|
from openai import OpenAI |
|
import json |
|
from dotenv import load_dotenv |
|
from db_utils import DatabaseUtils |
|
from embedding_utils import parallel_generate_embeddings, get_embedding |
|
|
|
|
|
# Populate the process environment from a local .env file (if one exists)
# before any client below is constructed.
load_dotenv()

# Module-wide OpenAI client shared by the embedding helpers.
# NOTE(review): presumably picks up its API key from the environment
# loaded above - confirm against deployment configuration.
openai_client: OpenAI = OpenAI()

# Module-wide database helper; the functions below use its `client`
# attribute and its get_databases / get_collections / get_field_names methods.
db_utils: DatabaseUtils = DatabaseUtils()
|
|
|
def get_field_names(db_name: str, collection_name: str) -> list[str]:
    """Return the field names of a collection.

    Thin wrapper that forwards both arguments unchanged to the shared
    ``db_utils`` helper and returns whatever it reports.

    Args:
        db_name: Name of the database holding the collection.
        collection_name: Name of the collection to inspect.

    Returns:
        The field names reported by ``db_utils.get_field_names``.
    """
    field_names = db_utils.get_field_names(db_name, collection_name)
    return field_names
|
|
|
def _normalize_limit(limit) -> int:
    """Convert the UI-supplied document limit to a non-negative whole number.

    The value originates from a numeric form field, so it may arrive as
    ``None`` (field cleared) or as a float instead of an ``int``. Missing,
    zero and negative values all mean "no limit" and map to ``0``.
    Fractional values are rounded up so that e.g. ``0.5`` is never mistaken
    for "no limit".
    """
    if limit is None or limit <= 0:
        return 0
    whole = int(limit)
    return whole if whole == limit else whole + 1


def _index_setup_instructions(processed: int, field_name: str, embedding_field: str) -> str:
    """Build the success message plus the Atlas vector-index setup steps."""
    return f"""
Successfully generated embeddings for {processed} documents using parallel processing!

To create the vector search index in MongoDB Atlas:
1. Go to your Atlas cluster
2. Click on 'Search' tab
3. Create an index named 'vector_index' with this configuration:
{{
  "fields": [
    {{
      "type": "vector",
      "path": "{embedding_field}",
      "numDimensions": 1536,
      "similarity": "dotProduct"
    }}
  ]
}}

You can now use the search tab with:
- Field to search: {field_name}
- Embedding field: {embedding_field}
"""


def generate_embeddings_for_field(db_name: str, collection_name: str, field_name: str, embedding_field: str, limit: int = 10, progress=gr.Progress()) -> tuple[str, str]:
    """Generate embeddings for documents in parallel with progress tracking.

    Selects documents that have ``field_name`` but no ``embedding_field`` yet
    and hands them to ``parallel_generate_embeddings``.

    Args:
        db_name: Database containing the collection.
        collection_name: Collection whose documents are processed.
        field_name: Field whose presence qualifies a document for embedding.
        embedding_field: Field that marks a document as already embedded;
            documents that have it are skipped.
        limit: Maximum number of documents to process; ``0`` (or any
            non-positive / missing value) means all matching documents.
        progress: Gradio progress tracker, called with a 0-1 fraction and a
            description on every progress update.

    Returns:
        ``(message, progress_text)``. ``message`` is either the setup
        instructions on success, a "nothing to do" notice, or an
        ``"Error: ..."`` string. ``progress_text`` is the last progress line
        produced, or ``""`` when no processing happened.
    """
    try:
        # Normalize once up front: a float or None limit would otherwise break
        # the comparisons below and Cursor.limit(), which requires an integer.
        limit = _normalize_limit(limit)

        collection = db_utils.client[db_name][collection_name]

        # Distinguish "field does not exist anywhere" from "nothing left to do".
        if collection.count_documents({field_name: {"$exists": True}}) == 0:
            return f"No documents found with field '{field_name}'", ""

        # Only documents that have the source field but no embedding yet.
        query = {
            field_name: {"$exists": True},
            embedding_field: {"$exists": False},
        }
        total_to_process = collection.count_documents(query)
        if total_to_process == 0:
            return "No documents found that need embeddings", ""

        if limit > 0:
            total_to_process = min(total_to_process, limit)

        print(f"\nFound {total_to_process} documents that need embeddings...")

        progress_text = ""

        def update_progress(prog: float, processed: int, total: int):
            """Record, print and forward one progress update (``prog`` is a percentage)."""
            nonlocal progress_text
            progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
            print(progress_text)
            # The tracker is given a 0-1 fraction, hence the division.
            progress(prog/100, f"Processed {processed}/{total} documents")

        # Show 0% immediately so the UI reflects that work has started.
        update_progress(0, 0, total_to_process)

        cursor = collection.find(query)
        if limit > 0:
            cursor = cursor.limit(limit)

        processed = parallel_generate_embeddings(
            collection=collection,
            cursor=cursor,
            field_name=field_name,
            embedding_field=embedding_field,
            openai_client=openai_client,
            total_docs=total_to_process,
            callback=update_progress,
        )

        return _index_setup_instructions(processed, field_name, embedding_field), progress_text

    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"Error: {str(e)}", ""
|
|
|
def _format_search_hit(rank: int, hit: dict, embedding_field: str) -> str:
    """Render one search hit: a ranked score header followed by its fields.

    The ``_id`` and the embedding vector are omitted from the listing.
    """
    parts = [f"{rank}. Score: {hit['search_score']:.4f}\n"]
    parts.extend(
        f"{key}: {value}\n"
        for key, value in hit['document'].items()
        if key not in ('_id', embedding_field)
    )
    return "".join(parts)


def vector_search(query_text: str, db_name: str, collection_name: str, embedding_field: str, index_name: str) -> str:
    """Perform vector search using embeddings.

    Embeds ``query_text`` via ``get_embedding`` and runs a ``$vectorSearch``
    aggregation (50 candidates, top 5 results) against the given collection.

    Args:
        query_text: Free-text query to embed and search for.
        db_name: Database containing the collection.
        collection_name: Collection that stores the embedded documents.
        embedding_field: Document field holding the stored vectors.
        index_name: Name of the vector search index to query.

    Returns:
        The formatted hits separated by blank lines, ``"No results found"``
        when the search yields nothing, a prompt when the query is blank, or
        an ``"Error: ..."`` string if anything fails.
    """
    # A blank query cannot produce a meaningful embedding; skip the remote
    # calls entirely and tell the user what is missing.
    if not query_text or not query_text.strip():
        return "Please enter a search query"

    try:
        print(f"\nProcessing query: {query_text}")

        collection = db_utils.client[db_name][collection_name]

        embedding = get_embedding(query_text, openai_client)
        print("Generated embeddings successfully")

        pipeline = [
            {
                '$vectorSearch': {
                    "index": index_name,
                    "path": embedding_field,
                    "queryVector": embedding,
                    "numCandidates": 50,
                    "limit": 5,
                }
            },
            {
                # Keep the whole matched document next to its similarity score.
                "$project": {
                    "search_score": {"$meta": "vectorSearchScore"},
                    "document": "$$ROOT",
                }
            },
        ]
        hits = list(collection.aggregate(pipeline))
        if not hits:
            return "No results found"

        return "\n".join(
            _format_search_hit(rank, hit, embedding_field)
            for rank, hit in enumerate(hits, 1)
        )

    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"Error: {str(e)}"
|
|
|
|
|
# Gradio UI: one tab to generate embeddings for a collection field and one
# tab to run vector searches against the stored embeddings.
with gr.Blocks(title="MongoDB Vector Search Tool") as iface:
    gr.Markdown("# MongoDB Vector Search Tool")

    # Fetched once while the UI is built and shared by both database dropdowns.
    databases = db_utils.get_databases()

    with gr.Tab("Generate Embeddings"):
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Available databases in Atlas cluster"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collections in selected database"
            )
        with gr.Row():
            field_input = gr.Dropdown(
                choices=[],
                label="Select Field for Embeddings",
                info="Fields available in collection"
            )
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field name where embeddings will be stored"
            )
            limit_input = gr.Number(
                label="Document Limit",
                value=10,
                minimum=0,
                # Deliver a whole number: the value is used as a document
                # count and as a cursor limit downstream.
                precision=0,
                info="Number of documents to process (0 for all documents)"
            )

        def update_collections(db_name):
            """Refresh a collection dropdown for the selected database.

            Returns an empty dropdown when no database is selected; when the
            database has exactly one collection it is pre-selected.
            """
            if not db_name:
                return gr.Dropdown(choices=[], value=None)

            collections = db_utils.get_collections(db_name)
            value = collections[0] if len(collections) == 1 else None
            return gr.Dropdown(choices=collections, value=value)

        def update_fields(db_name, collection_name):
            """Refresh the field dropdown for the selected collection.

            The current selection is cleared so a field from a previously
            selected collection cannot linger.
            """
            if db_name and collection_name:
                fields = get_field_names(db_name, collection_name)
                return gr.Dropdown(choices=fields, value=None)
            return gr.Dropdown(choices=[], value=None)

        # Database choice drives the collection list ...
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )

        # ... and the collection choice drives the field list.
        collection_input.change(
            fn=update_fields,
            inputs=[db_input, collection_input],
            outputs=[field_input]
        )

        generate_btn = gr.Button("Generate Embeddings")
        generate_output = gr.Textbox(label="Results", lines=10)
        progress_output = gr.Textbox(label="Progress", lines=3)

        generate_btn.click(
            generate_embeddings_for_field,
            inputs=[db_input, collection_input, field_input, embedding_field_input, limit_input],
            outputs=[generate_output, progress_output]
        )

    with gr.Tab("Search"):
        with gr.Row():
            search_db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Database containing the vectors"
            )
            search_collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collection containing the vectors"
            )
        with gr.Row():
            search_embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field containing the vectors"
            )
            search_index_input = gr.Textbox(
                label="Vector Search Index Name",
                value="vector_index",
                info="Index created in Atlas UI"
            )

        # Reuses the same collection-refresh handler as the first tab.
        search_db_input.change(
            fn=update_collections,
            inputs=[search_db_input],
            outputs=[search_collection_input]
        )

        query_input = gr.Textbox(
            label="Search Query",
            lines=2,
            placeholder="What would you like to search for?"
        )
        search_btn = gr.Button("Search")
        search_output = gr.Textbox(label="Results", lines=10)

        search_btn.click(
            vector_search,
            inputs=[
                query_input,
                search_db_input,
                search_collection_input,
                search_embedding_field_input,
                search_index_input
            ],
            outputs=search_output
        )
|
|
|
if __name__ == "__main__":
    # Host and port default to the previous hard-coded values (all interfaces,
    # port 7860) but can be overridden through Gradio's conventional
    # environment variables without editing this file.
    iface.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME") or "0.0.0.0",
        server_port=int(os.environ.get("GRADIO_SERVER_PORT") or "7860"),
    )
|
|