import gradio as gr
from typing import Tuple, Optional, List
from openai import OpenAI
from utils.db_utils import DatabaseUtils
from utils.embedding_utils import parallel_generate_embeddings
def create_embeddings_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
    """Create the embeddings generation tab UI.

    Args:
        openai_client: OpenAI client instance used to generate embeddings.
        db_utils: DatabaseUtils instance wrapping the MongoDB client.
        databases: List of available database names shown in the dropdown.

    Returns:
        Tuple[gr.Tab, dict]: The tab component and a dict of its interface
        elements (inputs, button, and output textboxes) keyed by name.
    """
    def update_collections(db_name: str) -> gr.Dropdown:
        """Update collections dropdown when the selected database changes."""
        collections = db_utils.get_collections(db_name)
        # If there's only one collection, select it by default
        value = collections[0] if len(collections) == 1 else None
        return gr.Dropdown(choices=collections, value=value)

    def update_fields(db_name: str, collection_name: str) -> gr.Dropdown:
        """Update fields dropdown when the selected collection changes."""
        if db_name and collection_name:
            fields = db_utils.get_field_names(db_name, collection_name)
            return gr.Dropdown(choices=fields)
        # No database/collection selected yet: present an empty dropdown.
        return gr.Dropdown(choices=[])

    def generate_embeddings(
        db_name: str,
        collection_name: str,
        field_name: str,
        embedding_field: str,
        limit: int = 10,
        progress=gr.Progress()
    ) -> Tuple[str, str]:
        """Generate embeddings for documents with progress tracking.

        Args:
            db_name: Source database name.
            collection_name: Source collection name.
            field_name: Document field whose content is embedded.
            embedding_field: Field name where the embedding vector is stored.
            limit: Maximum number of documents to process (0 means all).
            progress: Gradio progress tracker (injected by Gradio).

        Returns:
            Tuple[str, str]: (result/instructions message, final progress text).
            On failure, the first element is an error message and the second
            is empty — errors are reported in the UI rather than raised.
        """
        try:
            # BUGFIX: gr.Number delivers its value as a float; PyMongo's
            # cursor.limit() requires an int and raises TypeError otherwise.
            # `or 0` also guards against None from a cleared input field.
            limit = int(limit or 0)

            db = db_utils.client[db_name]
            collection = db[collection_name]

            # Count documents that have the source field at all, so we can
            # distinguish "field missing" from "everything already embedded".
            total_docs = collection.count_documents({field_name: {"$exists": True}})
            if total_docs == 0:
                return f"No documents found with field '{field_name}'", ""

            # Get total count of documents that need processing
            query = {
                field_name: {"$exists": True},
                embedding_field: {"$exists": False}  # Only get docs without embeddings
            }
            total_to_process = collection.count_documents(query)
            if total_to_process == 0:
                return "No documents found that need embeddings", ""

            # Apply limit if specified
            if limit > 0:
                total_to_process = min(total_to_process, limit)
            print(f"\nFound {total_to_process} documents that need embeddings...")

            # Progress tracking
            progress_text = ""

            def update_progress(prog: float, processed: int, total: int):
                """Forward progress (percent) to the terminal and the Gradio bar."""
                nonlocal progress_text
                progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
                print(progress_text)  # Terminal logging
                # Gradio's progress() expects a 0..1 fraction.
                progress(prog/100, f"Processed {processed}/{total} documents")

            # Show initial progress
            update_progress(0, 0, total_to_process)

            # Create cursor for batch processing
            cursor = collection.find(query)
            if limit > 0:
                cursor = cursor.limit(limit)

            # Generate embeddings in parallel with cursor-based batching
            processed = parallel_generate_embeddings(
                collection=collection,
                cursor=cursor,
                field_name=field_name,
                embedding_field=embedding_field,
                openai_client=openai_client,
                total_docs=total_to_process,
                callback=update_progress
            )

            # Return completion message and final progress
            instructions = f"""
Successfully generated embeddings for {processed} documents using parallel processing!
To create the vector search index in MongoDB Atlas:
1. Go to your Atlas cluster
2. Click on 'Search' tab
3. Create an index named 'vector_index' with this configuration:
{{
"fields": [
{{
"type": "vector",
"path": "{embedding_field}",
"numDimensions": 1536,
"similarity": "dotProduct"
}}
]
}}
You can now use the search tab with:
- Field to search: {field_name}
- Embedding field: {embedding_field}
"""
            return instructions, progress_text
        except Exception as e:
            # UI boundary: surface the error in the results box instead of
            # letting it propagate and crash the Gradio handler.
            return f"Error: {str(e)}", ""

    # Create the tab UI
    with gr.Tab("Generate Embeddings") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Available databases in Atlas cluster"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collections in selected database"
            )
        with gr.Row():
            field_input = gr.Dropdown(
                choices=[],
                label="Select Field for Embeddings",
                info="Fields available in collection"
            )
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field name where embeddings will be stored"
            )
            limit_input = gr.Number(
                label="Document Limit",
                value=10,
                minimum=0,
                precision=0,  # Integer-only: the limit is a document count
                info="Number of documents to process (0 for all documents)"
            )
        generate_btn = gr.Button("Generate Embeddings")
        generate_output = gr.Textbox(label="Results", lines=10)
        progress_output = gr.Textbox(label="Progress", lines=3)

        # Set up event handlers
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )
        collection_input.change(
            fn=update_fields,
            inputs=[db_input, collection_input],
            outputs=[field_input]
        )
        generate_btn.click(
            fn=generate_embeddings,
            inputs=[
                db_input,
                collection_input,
                field_input,
                embedding_field_input,
                limit_input
            ],
            outputs=[generate_output, progress_output]
        )

    # Return the tab and its interface elements
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'field_input': field_input,
        'embedding_field_input': embedding_field_input,
        'limit_input': limit_input,
        'generate_btn': generate_btn,
        'generate_output': generate_output,
        'progress_output': progress_output
    }
    return tab, interface