|
import gradio as gr |
|
from typing import Tuple, List |
|
from openai import OpenAI |
|
from utils.db_utils import DatabaseUtils |
|
from utils.embedding_utils import get_embedding |
|
|
|
def create_search_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
    """Create the vector search tab UI.

    Args:
        openai_client: OpenAI client used to embed the user's query text.
        db_utils: DatabaseUtils instance wrapping the MongoDB client.
        databases: List of database names offered in the database dropdown.

    Returns:
        Tuple[gr.Tab, dict]: The tab component and a dict of its interface
        elements keyed by name (db_input, collection_input,
        embedding_field_input, index_input, query_input, search_btn,
        search_output).
    """

    def update_collections(db_name: str) -> gr.Dropdown:
        """Refresh the collections dropdown when the selected database changes."""
        collections = db_utils.get_collections(db_name)

        # Auto-select when there is exactly one collection; otherwise leave
        # the dropdown unset so the user must choose explicitly.
        value = collections[0] if len(collections) == 1 else None
        return gr.Dropdown(choices=collections, value=value)

    def vector_search(
        query_text: str,
        db_name: str,
        collection_name: str,
        embedding_field: str,
        index_name: str
    ) -> str:
        """Embed the query and run an Atlas $vectorSearch over the collection.

        Returns a human-readable string of the top matches. Errors are
        returned as a message string (never raised) so the Gradio UI can
        display them inline.
        """
        # Guard clauses: give actionable feedback instead of calling the
        # embedding API with an empty query or raising a raw KeyError/
        # TypeError when no database/collection is selected.
        if not query_text or not query_text.strip():
            return "Error: Please enter a search query"
        if not db_name or not collection_name:
            return "Error: Please select a database and collection"

        try:
            print(f"\nProcessing query: {query_text}")

            db = db_utils.client[db_name]
            collection = db[collection_name]

            embedding = get_embedding(query_text, openai_client)
            print("Generated embeddings successfully")

            results = collection.aggregate([
                {
                    '$vectorSearch': {
                        "index": index_name,
                        "path": embedding_field,
                        "queryVector": embedding,
                        # numCandidates > limit widens the ANN candidate pool
                        # for better recall; top 5 are returned.
                        "numCandidates": 50,
                        "limit": 5
                    }
                },
                {
                    "$project": {
                        "search_score": { "$meta": "vectorSearchScore" },
                        "document": "$$ROOT"
                    }
                }
            ])

            formatted_results = []
            for idx, result in enumerate(results, 1):
                doc = result['document']
                lines = [f"{idx}. Score: {result['search_score']:.4f}"]
                # Hide MongoDB's internal _id and the raw embedding vector;
                # show every other stored field.
                lines.extend(
                    f"{key}: {value}"
                    for key, value in doc.items()
                    if key not in ['_id', embedding_field]
                )
                formatted_results.append("\n".join(lines) + "\n")

            return "\n".join(formatted_results) if formatted_results else "No results found"

        except Exception as e:
            # Surface any failure (connection, bad index name, embedding
            # error) in the results box rather than crashing the UI.
            return f"Error: {str(e)}"

    with gr.Tab("Search") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Database containing the vectors"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collection containing the vectors"
            )
        with gr.Row():
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field containing the vectors"
            )
            index_input = gr.Textbox(
                label="Vector Search Index Name",
                value="vector_index",
                info="Index created in Atlas UI"
            )

        query_input = gr.Textbox(
            label="Search Query",
            lines=2,
            placeholder="What would you like to search for?"
        )
        search_btn = gr.Button("Search")
        search_output = gr.Textbox(label="Results", lines=10)

        # Repopulate the collection dropdown whenever the database changes.
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )

        search_btn.click(
            fn=vector_search,
            inputs=[
                query_input,
                db_input,
                collection_input,
                embedding_field_input,
                index_input
            ],
            outputs=search_output
        )

    # Expose the components so callers (e.g. the main app) can wire
    # additional behavior onto them.
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'embedding_field_input': embedding_field_input,
        'index_input': index_input,
        'query_input': query_input,
        'search_btn': search_btn,
        'search_output': search_output
    }

    return tab, interface
|
|