import gradio as gr
from typing import Tuple, List
from openai import OpenAI
from utils.db_utils import DatabaseUtils
from utils.embedding_utils import get_embedding


def create_search_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
    """Create the vector search tab UI

    Args:
        openai_client: OpenAI client instance
        db_utils: DatabaseUtils instance
        databases: List of available databases

    Returns:
        Tuple[gr.Tab, dict]: The tab component and its interface elements
    """
    def update_collections(db_name: str) -> gr.Dropdown:
        """Update the collections dropdown when the selected database changes"""
        collections = db_utils.get_collections(db_name)
        # If there's only one collection, select it by default
        value = collections[0] if len(collections) == 1 else None
        return gr.Dropdown(choices=collections, value=value)
    def vector_search(
        query_text: str,
        db_name: str,
        collection_name: str,
        embedding_field: str,
        index_name: str
    ) -> str:
        """Perform vector search using embeddings"""
        try:
            print(f"\nProcessing query: {query_text}")
            db = db_utils.client[db_name]
            collection = db[collection_name]

            # Get the embedding for the query text
            embedding = get_embedding(query_text, openai_client)
            print("Generated embeddings successfully")
            results = collection.aggregate([
                {
                    "$vectorSearch": {
                        "index": index_name,
                        "path": embedding_field,
                        "queryVector": embedding,
                        "numCandidates": 50,
                        "limit": 5
                    }
                },
                {
                    "$project": {
                        "search_score": {"$meta": "vectorSearchScore"},
                        "document": "$$ROOT"
                    }
                }
            ])
            # Format results
            results_list = list(results)
            formatted_results = []
            for idx, result in enumerate(results_list, 1):
                doc = result['document']
                formatted_result = f"{idx}. Score: {result['search_score']:.4f}\n"
                # Add all fields except _id and the embedding vector itself
                for key, value in doc.items():
                    if key not in ['_id', embedding_field]:
                        formatted_result += f"{key}: {value}\n"
                formatted_results.append(formatted_result)

            return "\n".join(formatted_results) if formatted_results else "No results found"

        except Exception as e:
            return f"Error: {str(e)}"
    # Create the tab UI
    with gr.Tab("Search") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Database containing the vectors"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collection containing the vectors"
            )

        with gr.Row():
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field containing the vectors"
            )
            index_input = gr.Textbox(
                label="Vector Search Index Name",
                value="vector_index",
                info="Index created in Atlas UI"
            )

        query_input = gr.Textbox(
            label="Search Query",
            lines=2,
            placeholder="What would you like to search for?"
        )
        search_btn = gr.Button("Search")
        search_output = gr.Textbox(label="Results", lines=10)
        # Set up event handlers
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )
        search_btn.click(
            fn=vector_search,
            inputs=[
                query_input,
                db_input,
                collection_input,
                embedding_field_input,
                index_input
            ],
            outputs=search_output
        )
    # Return the tab and its interface elements
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'embedding_field_input': embedding_field_input,
        'index_input': index_input,
        'query_input': query_input,
        'search_btn': search_btn,
        'search_output': search_output
    }
    return tab, interface
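

# Example wiring (sketch only, not part of this module): the host app is assumed
# to construct the OpenAI client, DatabaseUtils, and database list itself and to
# call this factory inside a gr.Blocks()/gr.Tabs() context, roughly like:
#
#     with gr.Blocks() as demo:
#         with gr.Tabs():
#             search_tab, search_ui = create_search_tab(openai_client, db_utils, databases)
#     demo.launch()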