Spaces:
Sleeping
Sleeping
Delete ui
Browse files- ui/app.py +0 -60
- ui/db_utils.py +0 -159
- ui/embedding_utils.py +0 -143
- ui/embeddings_tab.py +0 -192
- ui/list_collections.py +0 -38
- ui/list_db.py +0 -27
- ui/run.sh +0 -1
- ui/search_tab.py +0 -142
- ui/setup_data.py +0 -94
- ui/ui/__init__.py +0 -8
- ui/ui/__pycache__/embeddings_tab.cpython-312.pyc +0 -0
- ui/ui/__pycache__/search_tab.cpython-312.pyc +0 -0
- ui/ui/embeddings_tab.py +0 -192
- ui/ui/search_tab.py +0 -142
- ui/utils/__init__.py +0 -12
- ui/utils/credentials.py +0 -47
- ui/utils/db_utils.py +0 -159
- ui/utils/embedding_utils.py +0 -143
ui/app.py
DELETED
@@ -1,60 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from utils.credentials import check_credentials, init_clients
|
3 |
-
from ui.embeddings_tab import create_embeddings_tab
|
4 |
-
from ui.search_tab import create_search_tab
|
5 |
-
|
6 |
-
def create_app():
|
7 |
-
"""Create and configure the Gradio application"""
|
8 |
-
with gr.Blocks(title="MongoDB Vector Search Tool") as iface:
|
9 |
-
gr.Markdown("# MongoDB Vector Search Tool")
|
10 |
-
|
11 |
-
# Check credentials first
|
12 |
-
has_creds, cred_message = check_credentials()
|
13 |
-
if not has_creds:
|
14 |
-
gr.Markdown(f"""
|
15 |
-
## ⚠️ Setup Required
|
16 |
-
|
17 |
-
{cred_message}
|
18 |
-
|
19 |
-
After setting up credentials, refresh this page.
|
20 |
-
""")
|
21 |
-
else:
|
22 |
-
# Initialize clients
|
23 |
-
openai_client, db_utils = init_clients()
|
24 |
-
if not openai_client or not db_utils:
|
25 |
-
gr.Markdown("""
|
26 |
-
## ⚠️ Connection Error
|
27 |
-
|
28 |
-
Failed to connect to MongoDB Atlas or OpenAI. Please check your credentials and try again.
|
29 |
-
""")
|
30 |
-
else:
|
31 |
-
# Get available databases
|
32 |
-
try:
|
33 |
-
databases = db_utils.get_databases()
|
34 |
-
except Exception as e:
|
35 |
-
gr.Markdown(f"""
|
36 |
-
## ⚠️ Database Error
|
37 |
-
|
38 |
-
Failed to list databases: {str(e)}
|
39 |
-
Please check your MongoDB Atlas connection and try again.
|
40 |
-
""")
|
41 |
-
databases = []
|
42 |
-
|
43 |
-
# Create tabs
|
44 |
-
embeddings_tab, embeddings_interface = create_embeddings_tab(
|
45 |
-
openai_client=openai_client,
|
46 |
-
db_utils=db_utils,
|
47 |
-
databases=databases
|
48 |
-
)
|
49 |
-
|
50 |
-
search_tab, search_interface = create_search_tab(
|
51 |
-
openai_client=openai_client,
|
52 |
-
db_utils=db_utils,
|
53 |
-
databases=databases
|
54 |
-
)
|
55 |
-
|
56 |
-
return iface
|
57 |
-
|
58 |
-
if __name__ == "__main__":
|
59 |
-
app = create_app()
|
60 |
-
app.launch(server_name="0.0.0.0")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/db_utils.py
DELETED
@@ -1,159 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from typing import List, Dict, Any, Optional
|
3 |
-
from pymongo import MongoClient
|
4 |
-
from pymongo.errors import (
|
5 |
-
ConnectionFailure,
|
6 |
-
OperationFailure,
|
7 |
-
ServerSelectionTimeoutError,
|
8 |
-
InvalidName
|
9 |
-
)
|
10 |
-
from dotenv import load_dotenv
|
11 |
-
|
12 |
-
class DatabaseError(Exception):
|
13 |
-
"""Base class for database operation errors"""
|
14 |
-
pass
|
15 |
-
|
16 |
-
class ConnectionError(DatabaseError):
|
17 |
-
"""Error when connecting to MongoDB Atlas"""
|
18 |
-
pass
|
19 |
-
|
20 |
-
class OperationError(DatabaseError):
|
21 |
-
"""Error during database operations"""
|
22 |
-
pass
|
23 |
-
|
24 |
-
class DatabaseUtils:
|
25 |
-
"""Utility class for MongoDB Atlas database operations
|
26 |
-
|
27 |
-
This class provides methods to interact with MongoDB Atlas databases and collections,
|
28 |
-
including listing databases, collections, and retrieving collection information.
|
29 |
-
|
30 |
-
Attributes:
|
31 |
-
atlas_uri (str): MongoDB Atlas connection string
|
32 |
-
client (MongoClient): MongoDB client instance
|
33 |
-
"""
|
34 |
-
|
35 |
-
def __init__(self):
|
36 |
-
"""Initialize DatabaseUtils with MongoDB Atlas connection
|
37 |
-
|
38 |
-
Raises:
|
39 |
-
ConnectionError: If unable to connect to MongoDB Atlas
|
40 |
-
ValueError: If ATLAS_URI environment variable is not set
|
41 |
-
"""
|
42 |
-
# Load environment variables
|
43 |
-
load_dotenv()
|
44 |
-
|
45 |
-
self.atlas_uri = os.getenv("ATLAS_URI")
|
46 |
-
if not self.atlas_uri:
|
47 |
-
raise ValueError("ATLAS_URI environment variable is not set")
|
48 |
-
|
49 |
-
try:
|
50 |
-
self.client = MongoClient(self.atlas_uri)
|
51 |
-
# Test connection
|
52 |
-
self.client.admin.command('ping')
|
53 |
-
except (ConnectionFailure, ServerSelectionTimeoutError) as e:
|
54 |
-
raise ConnectionError(f"Failed to connect to MongoDB Atlas: {str(e)}")
|
55 |
-
|
56 |
-
def get_databases(self) -> List[str]:
|
57 |
-
"""Get list of all databases in Atlas cluster
|
58 |
-
|
59 |
-
Returns:
|
60 |
-
List[str]: List of database names
|
61 |
-
|
62 |
-
Raises:
|
63 |
-
OperationError: If unable to list databases
|
64 |
-
"""
|
65 |
-
try:
|
66 |
-
return self.client.list_database_names()
|
67 |
-
except OperationFailure as e:
|
68 |
-
raise OperationError(f"Failed to list databases: {str(e)}")
|
69 |
-
|
70 |
-
def get_collections(self, db_name: str) -> List[str]:
|
71 |
-
"""Get list of collections in a database
|
72 |
-
|
73 |
-
Args:
|
74 |
-
db_name (str): Name of the database
|
75 |
-
|
76 |
-
Returns:
|
77 |
-
List[str]: List of collection names
|
78 |
-
|
79 |
-
Raises:
|
80 |
-
OperationError: If unable to list collections
|
81 |
-
ValueError: If db_name is empty or invalid
|
82 |
-
"""
|
83 |
-
if not db_name or not isinstance(db_name, str):
|
84 |
-
raise ValueError("Database name must be a non-empty string")
|
85 |
-
|
86 |
-
try:
|
87 |
-
db = self.client[db_name]
|
88 |
-
return db.list_collection_names()
|
89 |
-
except (OperationFailure, InvalidName) as e:
|
90 |
-
raise OperationError(f"Failed to list collections for database '{db_name}': {str(e)}")
|
91 |
-
|
92 |
-
def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
|
93 |
-
"""Get information about a collection including document count and sample document
|
94 |
-
|
95 |
-
Args:
|
96 |
-
db_name (str): Name of the database
|
97 |
-
collection_name (str): Name of the collection
|
98 |
-
|
99 |
-
Returns:
|
100 |
-
Dict[str, Any]: Dictionary containing collection information:
|
101 |
-
- count: Number of documents in collection
|
102 |
-
- sample: Sample document from collection (if exists)
|
103 |
-
|
104 |
-
Raises:
|
105 |
-
OperationError: If unable to get collection information
|
106 |
-
ValueError: If db_name or collection_name is empty or invalid
|
107 |
-
"""
|
108 |
-
if not db_name or not isinstance(db_name, str):
|
109 |
-
raise ValueError("Database name must be a non-empty string")
|
110 |
-
if not collection_name or not isinstance(collection_name, str):
|
111 |
-
raise ValueError("Collection name must be a non-empty string")
|
112 |
-
|
113 |
-
try:
|
114 |
-
db = self.client[db_name]
|
115 |
-
collection = db[collection_name]
|
116 |
-
|
117 |
-
return {
|
118 |
-
'count': collection.count_documents({}),
|
119 |
-
'sample': collection.find_one()
|
120 |
-
}
|
121 |
-
except (OperationFailure, InvalidName) as e:
|
122 |
-
raise OperationError(
|
123 |
-
f"Failed to get info for collection '{collection_name}' "
|
124 |
-
f"in database '{db_name}': {str(e)}"
|
125 |
-
)
|
126 |
-
|
127 |
-
def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
|
128 |
-
"""Get list of fields in a collection based on sample document
|
129 |
-
|
130 |
-
Args:
|
131 |
-
db_name (str): Name of the database
|
132 |
-
collection_name (str): Name of the collection
|
133 |
-
|
134 |
-
Returns:
|
135 |
-
List[str]: List of field names (excluding _id and embedding fields)
|
136 |
-
|
137 |
-
Raises:
|
138 |
-
OperationError: If unable to get field names
|
139 |
-
ValueError: If db_name or collection_name is empty or invalid
|
140 |
-
"""
|
141 |
-
try:
|
142 |
-
info = self.get_collection_info(db_name, collection_name)
|
143 |
-
sample = info.get('sample', {})
|
144 |
-
|
145 |
-
if sample:
|
146 |
-
# Get all field names except _id and any existing embedding fields
|
147 |
-
return [field for field in sample.keys()
|
148 |
-
if field != '_id' and not field.endswith('_embedding')]
|
149 |
-
return []
|
150 |
-
except DatabaseError as e:
|
151 |
-
raise OperationError(
|
152 |
-
f"Failed to get field names for collection '{collection_name}' "
|
153 |
-
f"in database '{db_name}': {str(e)}"
|
154 |
-
)
|
155 |
-
|
156 |
-
def close(self):
|
157 |
-
"""Close MongoDB connection safely"""
|
158 |
-
if hasattr(self, 'client'):
|
159 |
-
self.client.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/embedding_utils.py
DELETED
@@ -1,143 +0,0 @@
|
|
1 |
-
from typing import List, Tuple
|
2 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
3 |
-
from pymongo import UpdateOne
|
4 |
-
from pymongo.collection import Collection
|
5 |
-
import math
|
6 |
-
import time
|
7 |
-
import logging
|
8 |
-
|
9 |
-
# Configure logging
|
10 |
-
logging.basicConfig(level=logging.INFO)
|
11 |
-
logger = logging.getLogger(__name__)
|
12 |
-
|
13 |
-
def get_embedding(text: str, openai_client, model="text-embedding-ada-002", max_retries=3) -> list[float]:
|
14 |
-
"""Get embeddings for given text using OpenAI API with retry logic"""
|
15 |
-
text = text.replace("\n", " ")
|
16 |
-
|
17 |
-
for attempt in range(max_retries):
|
18 |
-
try:
|
19 |
-
resp = openai_client.embeddings.create(
|
20 |
-
input=[text],
|
21 |
-
model=model
|
22 |
-
)
|
23 |
-
return resp.data[0].embedding
|
24 |
-
except Exception as e:
|
25 |
-
if attempt == max_retries - 1:
|
26 |
-
raise
|
27 |
-
error_details = f"{type(e).__name__}: {str(e)}"
|
28 |
-
if hasattr(e, 'response'):
|
29 |
-
error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
|
30 |
-
logger.warning(f"Embedding API error (attempt {attempt + 1}/{max_retries}):\n{error_details}")
|
31 |
-
time.sleep(2 ** attempt) # Exponential backoff
|
32 |
-
|
33 |
-
def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
|
34 |
-
"""Process a batch of documents to generate embeddings"""
|
35 |
-
logger.info(f"Processing batch of {len(docs)} documents")
|
36 |
-
results = []
|
37 |
-
for doc in docs:
|
38 |
-
# Skip if embedding already exists
|
39 |
-
if embedding_field in doc:
|
40 |
-
continue
|
41 |
-
|
42 |
-
text = doc[field_name]
|
43 |
-
if isinstance(text, str):
|
44 |
-
embedding = get_embedding(text, openai_client)
|
45 |
-
results.append((doc["_id"], embedding))
|
46 |
-
return results
|
47 |
-
|
48 |
-
def process_futures(futures: List, collection: Collection, embedding_field: str, processed: int, total_docs: int, callback=None) -> int:
|
49 |
-
"""Process completed futures and update progress"""
|
50 |
-
completed = 0
|
51 |
-
for future in as_completed(futures, timeout=30): # 30 second timeout
|
52 |
-
try:
|
53 |
-
results = future.result()
|
54 |
-
if results:
|
55 |
-
bulk_ops = [
|
56 |
-
UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
|
57 |
-
for doc_id, embedding in results
|
58 |
-
]
|
59 |
-
if bulk_ops:
|
60 |
-
collection.bulk_write(bulk_ops)
|
61 |
-
completed += len(bulk_ops)
|
62 |
-
|
63 |
-
# Update progress
|
64 |
-
if callback:
|
65 |
-
progress = ((processed + completed) / total_docs) * 100
|
66 |
-
callback(progress, processed + completed, total_docs)
|
67 |
-
except Exception as e:
|
68 |
-
error_details = f"{type(e).__name__}: {str(e)}"
|
69 |
-
if hasattr(e, 'response'):
|
70 |
-
error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
|
71 |
-
logger.error(f"Error processing future:\n{error_details}")
|
72 |
-
return completed
|
73 |
-
|
74 |
-
def parallel_generate_embeddings(
|
75 |
-
collection: Collection,
|
76 |
-
cursor,
|
77 |
-
field_name: str,
|
78 |
-
embedding_field: str,
|
79 |
-
openai_client,
|
80 |
-
total_docs: int,
|
81 |
-
batch_size: int = 10, # Reduced initial batch size
|
82 |
-
callback=None
|
83 |
-
) -> int:
|
84 |
-
"""Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching and dynamic processing"""
|
85 |
-
if total_docs == 0:
|
86 |
-
return 0
|
87 |
-
|
88 |
-
processed = 0
|
89 |
-
current_batch_size = batch_size
|
90 |
-
max_workers = 10 # Start with fewer workers
|
91 |
-
|
92 |
-
logger.info(f"Starting embedding generation for {total_docs} documents")
|
93 |
-
if callback:
|
94 |
-
callback(0, 0, total_docs)
|
95 |
-
|
96 |
-
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
97 |
-
batch = []
|
98 |
-
futures = []
|
99 |
-
|
100 |
-
for doc in cursor:
|
101 |
-
batch.append(doc)
|
102 |
-
|
103 |
-
if len(batch) >= current_batch_size:
|
104 |
-
logger.info(f"Submitting batch of {len(batch)} documents (batch size: {current_batch_size})")
|
105 |
-
future = executor.submit(process_batch, batch.copy(), field_name, embedding_field, openai_client)
|
106 |
-
futures.append(future)
|
107 |
-
batch = []
|
108 |
-
|
109 |
-
# Process completed futures more frequently
|
110 |
-
if len(futures) >= max_workers:
|
111 |
-
try:
|
112 |
-
completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
|
113 |
-
processed += completed
|
114 |
-
futures = [] # Clear processed futures
|
115 |
-
|
116 |
-
# Gradually increase batch size and workers if processing is successful
|
117 |
-
if completed > 0:
|
118 |
-
current_batch_size = min(current_batch_size + 5, 30)
|
119 |
-
max_workers = min(max_workers + 2, 20)
|
120 |
-
logger.info(f"Increased batch size to {current_batch_size}, workers to {max_workers}")
|
121 |
-
except Exception as e:
|
122 |
-
logger.error(f"Error processing futures: {str(e)}")
|
123 |
-
# Reduce batch size and workers on error
|
124 |
-
current_batch_size = max(5, current_batch_size - 5)
|
125 |
-
max_workers = max(5, max_workers - 2)
|
126 |
-
logger.info(f"Reduced batch size to {current_batch_size}, workers to {max_workers}")
|
127 |
-
|
128 |
-
# Process remaining batch
|
129 |
-
if batch:
|
130 |
-
logger.info(f"Processing final batch of {len(batch)} documents")
|
131 |
-
future = executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
|
132 |
-
futures.append(future)
|
133 |
-
|
134 |
-
# Process remaining futures
|
135 |
-
if futures:
|
136 |
-
try:
|
137 |
-
completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
|
138 |
-
processed += completed
|
139 |
-
except Exception as e:
|
140 |
-
logger.error(f"Error processing final futures: {str(e)}")
|
141 |
-
|
142 |
-
logger.info(f"Completed embedding generation. Processed {processed}/{total_docs} documents")
|
143 |
-
return processed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/embeddings_tab.py
DELETED
@@ -1,192 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from typing import Tuple, Optional, List
|
3 |
-
from openai import OpenAI
|
4 |
-
from utils.db_utils import DatabaseUtils
|
5 |
-
from utils.embedding_utils import parallel_generate_embeddings
|
6 |
-
|
7 |
-
def create_embeddings_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
|
8 |
-
"""Create the embeddings generation tab UI
|
9 |
-
|
10 |
-
Args:
|
11 |
-
openai_client: OpenAI client instance
|
12 |
-
db_utils: DatabaseUtils instance
|
13 |
-
databases: List of available databases
|
14 |
-
|
15 |
-
Returns:
|
16 |
-
Tuple[gr.Tab, dict]: The tab component and its interface elements
|
17 |
-
"""
|
18 |
-
def update_collections(db_name: str) -> gr.Dropdown:
|
19 |
-
"""Update collections dropdown when database changes"""
|
20 |
-
collections = db_utils.get_collections(db_name)
|
21 |
-
# If there's only one collection, select it by default
|
22 |
-
value = collections[0] if len(collections) == 1 else None
|
23 |
-
return gr.Dropdown(choices=collections, value=value)
|
24 |
-
|
25 |
-
def update_fields(db_name: str, collection_name: str) -> gr.Dropdown:
|
26 |
-
"""Update fields dropdown when collection changes"""
|
27 |
-
if db_name and collection_name:
|
28 |
-
fields = db_utils.get_field_names(db_name, collection_name)
|
29 |
-
return gr.Dropdown(choices=fields)
|
30 |
-
return gr.Dropdown(choices=[])
|
31 |
-
|
32 |
-
def generate_embeddings(
|
33 |
-
db_name: str,
|
34 |
-
collection_name: str,
|
35 |
-
field_name: str,
|
36 |
-
embedding_field: str,
|
37 |
-
limit: int = 10,
|
38 |
-
progress=gr.Progress()
|
39 |
-
) -> Tuple[str, str]:
|
40 |
-
"""Generate embeddings for documents with progress tracking"""
|
41 |
-
try:
|
42 |
-
db = db_utils.client[db_name]
|
43 |
-
collection = db[collection_name]
|
44 |
-
|
45 |
-
# Count documents that need embeddings
|
46 |
-
total_docs = collection.count_documents({field_name: {"$exists": True}})
|
47 |
-
if total_docs == 0:
|
48 |
-
return f"No documents found with field '{field_name}'", ""
|
49 |
-
|
50 |
-
# Get total count of documents that need processing
|
51 |
-
query = {
|
52 |
-
field_name: {"$exists": True},
|
53 |
-
embedding_field: {"$exists": False} # Only get docs without embeddings
|
54 |
-
}
|
55 |
-
total_to_process = collection.count_documents(query)
|
56 |
-
if total_to_process == 0:
|
57 |
-
return "No documents found that need embeddings", ""
|
58 |
-
|
59 |
-
# Apply limit if specified
|
60 |
-
if limit > 0:
|
61 |
-
total_to_process = min(total_to_process, limit)
|
62 |
-
|
63 |
-
print(f"\nFound {total_to_process} documents that need embeddings...")
|
64 |
-
|
65 |
-
# Progress tracking
|
66 |
-
progress_text = ""
|
67 |
-
def update_progress(prog: float, processed: int, total: int):
|
68 |
-
nonlocal progress_text
|
69 |
-
progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
|
70 |
-
print(progress_text) # Terminal logging
|
71 |
-
progress(prog/100, f"Processed {processed}/{total} documents")
|
72 |
-
|
73 |
-
# Show initial progress
|
74 |
-
update_progress(0, 0, total_to_process)
|
75 |
-
|
76 |
-
# Create cursor for batch processing
|
77 |
-
cursor = collection.find(query)
|
78 |
-
if limit > 0:
|
79 |
-
cursor = cursor.limit(limit)
|
80 |
-
|
81 |
-
# Generate embeddings in parallel with cursor-based batching
|
82 |
-
processed = parallel_generate_embeddings(
|
83 |
-
collection=collection,
|
84 |
-
cursor=cursor,
|
85 |
-
field_name=field_name,
|
86 |
-
embedding_field=embedding_field,
|
87 |
-
openai_client=openai_client,
|
88 |
-
total_docs=total_to_process,
|
89 |
-
callback=update_progress
|
90 |
-
)
|
91 |
-
|
92 |
-
# Return completion message and final progress
|
93 |
-
instructions = f"""
|
94 |
-
Successfully generated embeddings for {processed} documents using parallel processing!
|
95 |
-
|
96 |
-
To create the vector search index in MongoDB Atlas:
|
97 |
-
1. Go to your Atlas cluster
|
98 |
-
2. Click on 'Search' tab
|
99 |
-
3. Create an index named 'vector_index' with this configuration:
|
100 |
-
{{
|
101 |
-
"fields": [
|
102 |
-
{{
|
103 |
-
"type": "vector",
|
104 |
-
"path": "{embedding_field}",
|
105 |
-
"numDimensions": 1536,
|
106 |
-
"similarity": "dotProduct"
|
107 |
-
}}
|
108 |
-
]
|
109 |
-
}}
|
110 |
-
|
111 |
-
You can now use the search tab with:
|
112 |
-
- Field to search: {field_name}
|
113 |
-
- Embedding field: {embedding_field}
|
114 |
-
"""
|
115 |
-
return instructions, progress_text
|
116 |
-
|
117 |
-
except Exception as e:
|
118 |
-
return f"Error: {str(e)}", ""
|
119 |
-
|
120 |
-
# Create the tab UI
|
121 |
-
with gr.Tab("Generate Embeddings") as tab:
|
122 |
-
with gr.Row():
|
123 |
-
db_input = gr.Dropdown(
|
124 |
-
choices=databases,
|
125 |
-
label="Select Database",
|
126 |
-
info="Available databases in Atlas cluster"
|
127 |
-
)
|
128 |
-
collection_input = gr.Dropdown(
|
129 |
-
choices=[],
|
130 |
-
label="Select Collection",
|
131 |
-
info="Collections in selected database"
|
132 |
-
)
|
133 |
-
with gr.Row():
|
134 |
-
field_input = gr.Dropdown(
|
135 |
-
choices=[],
|
136 |
-
label="Select Field for Embeddings",
|
137 |
-
info="Fields available in collection"
|
138 |
-
)
|
139 |
-
embedding_field_input = gr.Textbox(
|
140 |
-
label="Embedding Field Name",
|
141 |
-
value="embedding",
|
142 |
-
info="Field name where embeddings will be stored"
|
143 |
-
)
|
144 |
-
limit_input = gr.Number(
|
145 |
-
label="Document Limit",
|
146 |
-
value=10,
|
147 |
-
minimum=0,
|
148 |
-
info="Number of documents to process (0 for all documents)"
|
149 |
-
)
|
150 |
-
|
151 |
-
generate_btn = gr.Button("Generate Embeddings")
|
152 |
-
generate_output = gr.Textbox(label="Results", lines=10)
|
153 |
-
progress_output = gr.Textbox(label="Progress", lines=3)
|
154 |
-
|
155 |
-
# Set up event handlers
|
156 |
-
db_input.change(
|
157 |
-
fn=update_collections,
|
158 |
-
inputs=[db_input],
|
159 |
-
outputs=[collection_input]
|
160 |
-
)
|
161 |
-
|
162 |
-
collection_input.change(
|
163 |
-
fn=update_fields,
|
164 |
-
inputs=[db_input, collection_input],
|
165 |
-
outputs=[field_input]
|
166 |
-
)
|
167 |
-
|
168 |
-
generate_btn.click(
|
169 |
-
fn=generate_embeddings,
|
170 |
-
inputs=[
|
171 |
-
db_input,
|
172 |
-
collection_input,
|
173 |
-
field_input,
|
174 |
-
embedding_field_input,
|
175 |
-
limit_input
|
176 |
-
],
|
177 |
-
outputs=[generate_output, progress_output]
|
178 |
-
)
|
179 |
-
|
180 |
-
# Return the tab and its interface elements
|
181 |
-
interface = {
|
182 |
-
'db_input': db_input,
|
183 |
-
'collection_input': collection_input,
|
184 |
-
'field_input': field_input,
|
185 |
-
'embedding_field_input': embedding_field_input,
|
186 |
-
'limit_input': limit_input,
|
187 |
-
'generate_btn': generate_btn,
|
188 |
-
'generate_output': generate_output,
|
189 |
-
'progress_output': progress_output
|
190 |
-
}
|
191 |
-
|
192 |
-
return tab, interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/list_collections.py
DELETED
@@ -1,38 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from pymongo import MongoClient
|
3 |
-
from dotenv import load_dotenv
|
4 |
-
|
5 |
-
# Load environment variables
|
6 |
-
load_dotenv()
|
7 |
-
|
8 |
-
# Initialize MongoDB client
|
9 |
-
atlas_uri = os.getenv("ATLAS_URI")
|
10 |
-
client = MongoClient(atlas_uri)
|
11 |
-
|
12 |
-
def list_all_collections():
|
13 |
-
"""List all databases and their collections in the Atlas cluster"""
|
14 |
-
try:
|
15 |
-
# Get all database names
|
16 |
-
db_names = client.list_database_names()
|
17 |
-
|
18 |
-
print("\nDatabases and Collections in your Atlas cluster:\n")
|
19 |
-
|
20 |
-
# For each database, get and print collections
|
21 |
-
for db_name in db_names:
|
22 |
-
print(f"Database: {db_name}")
|
23 |
-
db = client[db_name]
|
24 |
-
collections = db.list_collection_names()
|
25 |
-
|
26 |
-
for collection in collections:
|
27 |
-
# Get count of documents in collection
|
28 |
-
count = db[collection].count_documents({})
|
29 |
-
print(f" └── Collection: {collection} ({count} documents)")
|
30 |
-
print()
|
31 |
-
|
32 |
-
except Exception as e:
|
33 |
-
print(f"Error: {str(e)}")
|
34 |
-
finally:
|
35 |
-
client.close()
|
36 |
-
|
37 |
-
if __name__ == "__main__":
|
38 |
-
list_all_collections()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/list_db.py
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
from db_utils import DatabaseUtils
|
2 |
-
|
3 |
-
def main():
|
4 |
-
db_utils = DatabaseUtils()
|
5 |
-
try:
|
6 |
-
print("\nDatabases and Collections in your Atlas cluster:\n")
|
7 |
-
|
8 |
-
# Get all databases
|
9 |
-
databases = db_utils.get_databases()
|
10 |
-
|
11 |
-
# For each database, show collections and counts
|
12 |
-
for db_name in databases:
|
13 |
-
print(f"Database: {db_name}")
|
14 |
-
collections = db_utils.get_collections(db_name)
|
15 |
-
|
16 |
-
for coll_name in collections:
|
17 |
-
info = db_utils.get_collection_info(db_name, coll_name)
|
18 |
-
print(f" └── Collection: {coll_name} ({info['count']} documents)")
|
19 |
-
print()
|
20 |
-
|
21 |
-
except Exception as e:
|
22 |
-
print(f"Error: {str(e)}")
|
23 |
-
finally:
|
24 |
-
db_utils.close()
|
25 |
-
|
26 |
-
if __name__ == "__main__":
|
27 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/run.sh
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
python app.py
|
|
|
|
ui/search_tab.py
DELETED
@@ -1,142 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from typing import Tuple, List
|
3 |
-
from openai import OpenAI
|
4 |
-
from utils.db_utils import DatabaseUtils
|
5 |
-
from utils.embedding_utils import get_embedding
|
6 |
-
|
7 |
-
def create_search_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
|
8 |
-
"""Create the vector search tab UI
|
9 |
-
|
10 |
-
Args:
|
11 |
-
openai_client: OpenAI client instance
|
12 |
-
db_utils: DatabaseUtils instance
|
13 |
-
databases: List of available databases
|
14 |
-
|
15 |
-
Returns:
|
16 |
-
Tuple[gr.Tab, dict]: The tab component and its interface elements
|
17 |
-
"""
|
18 |
-
def update_collections(db_name: str) -> gr.Dropdown:
|
19 |
-
"""Update collections dropdown when database changes"""
|
20 |
-
collections = db_utils.get_collections(db_name)
|
21 |
-
# If there's only one collection, select it by default
|
22 |
-
value = collections[0] if len(collections) == 1 else None
|
23 |
-
return gr.Dropdown(choices=collections, value=value)
|
24 |
-
|
25 |
-
def vector_search(
|
26 |
-
query_text: str,
|
27 |
-
db_name: str,
|
28 |
-
collection_name: str,
|
29 |
-
embedding_field: str,
|
30 |
-
index_name: str
|
31 |
-
) -> str:
|
32 |
-
"""Perform vector search using embeddings"""
|
33 |
-
try:
|
34 |
-
print(f"\nProcessing query: {query_text}")
|
35 |
-
|
36 |
-
db = db_utils.client[db_name]
|
37 |
-
collection = db[collection_name]
|
38 |
-
|
39 |
-
# Get embeddings for query
|
40 |
-
embedding = get_embedding(query_text, openai_client)
|
41 |
-
print("Generated embeddings successfully")
|
42 |
-
|
43 |
-
results = collection.aggregate([
|
44 |
-
{
|
45 |
-
'$vectorSearch': {
|
46 |
-
"index": index_name,
|
47 |
-
"path": embedding_field,
|
48 |
-
"queryVector": embedding,
|
49 |
-
"numCandidates": 50,
|
50 |
-
"limit": 5
|
51 |
-
}
|
52 |
-
},
|
53 |
-
{
|
54 |
-
"$project": {
|
55 |
-
"search_score": { "$meta": "vectorSearchScore" },
|
56 |
-
"document": "$$ROOT"
|
57 |
-
}
|
58 |
-
}
|
59 |
-
])
|
60 |
-
|
61 |
-
# Format results
|
62 |
-
results_list = list(results)
|
63 |
-
formatted_results = []
|
64 |
-
|
65 |
-
for idx, result in enumerate(results_list, 1):
|
66 |
-
doc = result['document']
|
67 |
-
formatted_result = f"{idx}. Score: {result['search_score']:.4f}\n"
|
68 |
-
# Add all fields except _id and embeddings
|
69 |
-
for key, value in doc.items():
|
70 |
-
if key not in ['_id', embedding_field]:
|
71 |
-
formatted_result += f"{key}: {value}\n"
|
72 |
-
formatted_results.append(formatted_result)
|
73 |
-
|
74 |
-
return "\n".join(formatted_results) if formatted_results else "No results found"
|
75 |
-
|
76 |
-
except Exception as e:
|
77 |
-
return f"Error: {str(e)}"
|
78 |
-
|
79 |
-
# Create the tab UI
|
80 |
-
with gr.Tab("Search") as tab:
|
81 |
-
with gr.Row():
|
82 |
-
db_input = gr.Dropdown(
|
83 |
-
choices=databases,
|
84 |
-
label="Select Database",
|
85 |
-
info="Database containing the vectors"
|
86 |
-
)
|
87 |
-
collection_input = gr.Dropdown(
|
88 |
-
choices=[],
|
89 |
-
label="Select Collection",
|
90 |
-
info="Collection containing the vectors"
|
91 |
-
)
|
92 |
-
with gr.Row():
|
93 |
-
embedding_field_input = gr.Textbox(
|
94 |
-
label="Embedding Field Name",
|
95 |
-
value="embedding",
|
96 |
-
info="Field containing the vectors"
|
97 |
-
)
|
98 |
-
index_input = gr.Textbox(
|
99 |
-
label="Vector Search Index Name",
|
100 |
-
value="vector_index",
|
101 |
-
info="Index created in Atlas UI"
|
102 |
-
)
|
103 |
-
|
104 |
-
query_input = gr.Textbox(
|
105 |
-
label="Search Query",
|
106 |
-
lines=2,
|
107 |
-
placeholder="What would you like to search for?"
|
108 |
-
)
|
109 |
-
search_btn = gr.Button("Search")
|
110 |
-
search_output = gr.Textbox(label="Results", lines=10)
|
111 |
-
|
112 |
-
# Set up event handlers
|
113 |
-
db_input.change(
|
114 |
-
fn=update_collections,
|
115 |
-
inputs=[db_input],
|
116 |
-
outputs=[collection_input]
|
117 |
-
)
|
118 |
-
|
119 |
-
search_btn.click(
|
120 |
-
fn=vector_search,
|
121 |
-
inputs=[
|
122 |
-
query_input,
|
123 |
-
db_input,
|
124 |
-
collection_input,
|
125 |
-
embedding_field_input,
|
126 |
-
index_input
|
127 |
-
],
|
128 |
-
outputs=search_output
|
129 |
-
)
|
130 |
-
|
131 |
-
# Return the tab and its interface elements
|
132 |
-
interface = {
|
133 |
-
'db_input': db_input,
|
134 |
-
'collection_input': collection_input,
|
135 |
-
'embedding_field_input': embedding_field_input,
|
136 |
-
'index_input': index_input,
|
137 |
-
'query_input': query_input,
|
138 |
-
'search_btn': search_btn,
|
139 |
-
'search_output': search_output
|
140 |
-
}
|
141 |
-
|
142 |
-
return tab, interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/setup_data.py
DELETED
@@ -1,94 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from pymongo import MongoClient
|
3 |
-
from openai import OpenAI
|
4 |
-
from dotenv import load_dotenv
|
5 |
-
|
6 |
-
# Load environment variables
|
7 |
-
load_dotenv()
|
8 |
-
|
9 |
-
# Initialize clients
|
10 |
-
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
11 |
-
atlas_uri = os.getenv("ATLAS_URI")
|
12 |
-
client = MongoClient(atlas_uri)
|
13 |
-
db = client['sample_mflix']
|
14 |
-
collection = db['embedded_movies']
|
15 |
-
|
16 |
-
# Sample movie data
|
17 |
-
sample_movies = [
|
18 |
-
{
|
19 |
-
"title": "The Matrix",
|
20 |
-
"year": 1999,
|
21 |
-
"plot": "A computer programmer discovers that reality as he knows it is a simulation created by machines, and joins a rebellion to overthrow them."
|
22 |
-
},
|
23 |
-
{
|
24 |
-
"title": "Inception",
|
25 |
-
"year": 2010,
|
26 |
-
"plot": "A thief who enters the dreams of others to steal secrets from their subconscious is offered a chance to regain his old life in exchange for a task considered impossible: inception."
|
27 |
-
},
|
28 |
-
{
|
29 |
-
"title": "The Shawshank Redemption",
|
30 |
-
"year": 1994,
|
31 |
-
"plot": "Two imprisoned men bond over a number of years, finding solace and eventual redemption through acts of common decency."
|
32 |
-
},
|
33 |
-
{
|
34 |
-
"title": "Jurassic Park",
|
35 |
-
"year": 1993,
|
36 |
-
"plot": "A pragmatic paleontologist visiting an almost complete theme park is tasked with protecting a couple of kids after a power failure causes the park's cloned dinosaurs to run loose."
|
37 |
-
},
|
38 |
-
{
|
39 |
-
"title": "The Lord of the Rings: The Fellowship of the Ring",
|
40 |
-
"year": 2001,
|
41 |
-
"plot": "A young hobbit, Frodo, who has found the One Ring that belongs to the Dark Lord Sauron, begins his journey with eight companions to Mount Doom, the only place where it can be destroyed."
|
42 |
-
}
|
43 |
-
]
|
44 |
-
|
45 |
-
def get_embedding(text: str, model="text-embedding-ada-002") -> list[float]:
    """Return the OpenAI embedding vector for *text*.

    Newlines are collapsed to single spaces before the API call, per
    OpenAI's guidance for embedding inputs.

    Args:
        text: Raw text to embed.
        model: OpenAI embedding model name.

    Returns:
        list[float]: The embedding vector for the (cleaned) input text.
    """
    cleaned = " ".join(text.split("\n"))
    response = openai_client.embeddings.create(
        model=model,
        input=[cleaned],
    )
    return response.data[0].embedding
|
53 |
-
|
54 |
-
def setup_data():
    """Seed the 'embedded_movies' collection with sample movies and plot embeddings.

    Drops any existing collection, embeds each movie's plot via OpenAI,
    inserts the documents, and prints manual instructions for creating the
    Atlas vector search index. Always closes the MongoDB client on exit.
    """
    try:
        # Start from a clean slate so reruns don't duplicate documents.
        collection.drop()
        print("Dropped existing collection")

        # Embed each plot, attach it to the document, and insert.
        for movie in sample_movies:
            movie["plot_embedding"] = get_embedding(movie["plot"])
            collection.insert_one(movie)
            print(f"Inserted movie: {movie['title']}")

        print("\nData setup completed successfully!")
        print("\nIMPORTANT: You need to create the vector search index manually in the Atlas UI:")
        print("1. Go to your Atlas cluster")
        print("2. Click on 'Search' tab")
        print("3. Create an index named 'idx_plot_embedding' with this definition:")
        print("""
    {
      "fields": [
        {
          "type": "vector",
          "path": "plot_embedding",
          "numDimensions": 1536,
          "similarity": "dotProduct"
        }
      ]
    }
    """)

    except Exception as e:
        print(f"Error during setup: {str(e)}")
    finally:
        client.close()
|
92 |
-
|
93 |
-
if __name__ == "__main__":
|
94 |
-
setup_data()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/ui/__init__.py
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
# UI package for MongoDB Vector Search Tool
|
2 |
-
from ui.embeddings_tab import create_embeddings_tab
|
3 |
-
from ui.search_tab import create_search_tab
|
4 |
-
|
5 |
-
__all__ = [
|
6 |
-
'create_embeddings_tab',
|
7 |
-
'create_search_tab'
|
8 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/ui/__pycache__/embeddings_tab.cpython-312.pyc
DELETED
Binary file (6.98 kB)
|
|
ui/ui/__pycache__/search_tab.cpython-312.pyc
DELETED
Binary file (5.06 kB)
|
|
ui/ui/embeddings_tab.py
DELETED
@@ -1,192 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from typing import Tuple, Optional, List
|
3 |
-
from openai import OpenAI
|
4 |
-
from utils.db_utils import DatabaseUtils
|
5 |
-
from utils.embedding_utils import parallel_generate_embeddings
|
6 |
-
|
7 |
-
def create_embeddings_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
    """Create the embeddings generation tab UI

    Args:
        openai_client: OpenAI client instance
        db_utils: DatabaseUtils instance
        databases: List of available databases

    Returns:
        Tuple[gr.Tab, dict]: The tab component and its interface elements
    """
    def update_collections(db_name: str) -> gr.Dropdown:
        """Update collections dropdown when database changes"""
        collections = db_utils.get_collections(db_name)
        # If there's only one collection, select it by default
        value = collections[0] if len(collections) == 1 else None
        return gr.Dropdown(choices=collections, value=value)

    def update_fields(db_name: str, collection_name: str) -> gr.Dropdown:
        """Update fields dropdown when collection changes"""
        if db_name and collection_name:
            fields = db_utils.get_field_names(db_name, collection_name)
            return gr.Dropdown(choices=fields)
        # Either selection missing: clear the field choices.
        return gr.Dropdown(choices=[])

    def generate_embeddings(
        db_name: str,
        collection_name: str,
        field_name: str,
        embedding_field: str,
        limit: int = 10,
        progress=gr.Progress()  # mutable default is Gradio's documented progress-injection pattern
    ) -> Tuple[str, str]:
        """Generate embeddings for documents with progress tracking

        Returns:
            Tuple[str, str]: (result/instructions message, last progress line).
            On error, returns the error text and an empty progress string.
        """
        try:
            db = db_utils.client[db_name]
            collection = db[collection_name]

            # Count documents that need embeddings
            total_docs = collection.count_documents({field_name: {"$exists": True}})
            if total_docs == 0:
                return f"No documents found with field '{field_name}'", ""

            # Get total count of documents that need processing
            query = {
                field_name: {"$exists": True},
                embedding_field: {"$exists": False}  # Only get docs without embeddings
            }
            total_to_process = collection.count_documents(query)
            if total_to_process == 0:
                return "No documents found that need embeddings", ""

            # Apply limit if specified (0 means "process all")
            if limit > 0:
                total_to_process = min(total_to_process, limit)

            print(f"\nFound {total_to_process} documents that need embeddings...")

            # Progress tracking: keep the latest progress line so it can be
            # returned alongside the final message.
            progress_text = ""
            def update_progress(prog: float, processed: int, total: int):
                nonlocal progress_text
                progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
                print(progress_text)  # Terminal logging
                progress(prog/100, f"Processed {processed}/{total} documents")

            # Show initial progress
            update_progress(0, 0, total_to_process)

            # Create cursor for batch processing
            # (pymongo treats limit(0) as "no limit", consistent with the UI contract)
            cursor = collection.find(query)
            if limit > 0:
                cursor = cursor.limit(limit)

            # Generate embeddings in parallel with cursor-based batching
            processed = parallel_generate_embeddings(
                collection=collection,
                cursor=cursor,
                field_name=field_name,
                embedding_field=embedding_field,
                openai_client=openai_client,
                total_docs=total_to_process,
                callback=update_progress
            )

            # Return completion message and final progress
            instructions = f"""
Successfully generated embeddings for {processed} documents using parallel processing!

To create the vector search index in MongoDB Atlas:
1. Go to your Atlas cluster
2. Click on 'Search' tab
3. Create an index named 'vector_index' with this configuration:
{{
    "fields": [
        {{
            "type": "vector",
            "path": "{embedding_field}",
            "numDimensions": 1536,
            "similarity": "dotProduct"
        }}
    ]
}}

You can now use the search tab with:
- Field to search: {field_name}
- Embedding field: {embedding_field}
"""
            return instructions, progress_text

        except Exception as e:
            return f"Error: {str(e)}", ""

    # Create the tab UI
    with gr.Tab("Generate Embeddings") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Available databases in Atlas cluster"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collections in selected database"
            )
        with gr.Row():
            field_input = gr.Dropdown(
                choices=[],
                label="Select Field for Embeddings",
                info="Fields available in collection"
            )
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field name where embeddings will be stored"
            )
            limit_input = gr.Number(
                label="Document Limit",
                value=10,
                minimum=0,
                info="Number of documents to process (0 for all documents)"
            )

        generate_btn = gr.Button("Generate Embeddings")
        generate_output = gr.Textbox(label="Results", lines=10)
        progress_output = gr.Textbox(label="Progress", lines=3)

        # Set up event handlers
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )

        collection_input.change(
            fn=update_fields,
            inputs=[db_input, collection_input],
            outputs=[field_input]
        )

        generate_btn.click(
            fn=generate_embeddings,
            inputs=[
                db_input,
                collection_input,
                field_input,
                embedding_field_input,
                limit_input
            ],
            outputs=[generate_output, progress_output]
        )

    # Return the tab and its interface elements so callers can wire
    # cross-tab behavior if needed.
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'field_input': field_input,
        'embedding_field_input': embedding_field_input,
        'limit_input': limit_input,
        'generate_btn': generate_btn,
        'generate_output': generate_output,
        'progress_output': progress_output
    }

    return tab, interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/ui/search_tab.py
DELETED
@@ -1,142 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from typing import Tuple, List
|
3 |
-
from openai import OpenAI
|
4 |
-
from utils.db_utils import DatabaseUtils
|
5 |
-
from utils.embedding_utils import get_embedding
|
6 |
-
|
7 |
-
def create_search_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
    """Create the vector search tab UI

    Args:
        openai_client: OpenAI client instance
        db_utils: DatabaseUtils instance
        databases: List of available databases

    Returns:
        Tuple[gr.Tab, dict]: The tab component and its interface elements
    """
    def update_collections(db_name: str) -> gr.Dropdown:
        """Update collections dropdown when database changes"""
        collections = db_utils.get_collections(db_name)
        # If there's only one collection, select it by default
        value = collections[0] if len(collections) == 1 else None
        return gr.Dropdown(choices=collections, value=value)

    def vector_search(
        query_text: str,
        db_name: str,
        collection_name: str,
        embedding_field: str,
        index_name: str
    ) -> str:
        """Perform vector search using embeddings

        Embeds the query with OpenAI, then runs an Atlas $vectorSearch
        aggregation (top 5 of 50 candidates) and formats the hits as text.
        Returns an error string rather than raising on failure.
        """
        try:
            print(f"\nProcessing query: {query_text}")

            db = db_utils.client[db_name]
            collection = db[collection_name]

            # Get embeddings for query
            embedding = get_embedding(query_text, openai_client)
            print("Generated embeddings successfully")

            # $vectorSearch must be the first pipeline stage in Atlas.
            results = collection.aggregate([
                {
                    '$vectorSearch': {
                        "index": index_name,
                        "path": embedding_field,
                        "queryVector": embedding,
                        "numCandidates": 50,
                        "limit": 5
                    }
                },
                {
                    "$project": {
                        "search_score": { "$meta": "vectorSearchScore" },
                        "document": "$$ROOT"
                    }
                }
            ])

            # Format results
            results_list = list(results)
            formatted_results = []

            for idx, result in enumerate(results_list, 1):
                doc = result['document']
                formatted_result = f"{idx}. Score: {result['search_score']:.4f}\n"
                # Add all fields except _id and embeddings
                for key, value in doc.items():
                    if key not in ['_id', embedding_field]:
                        formatted_result += f"{key}: {value}\n"
                formatted_results.append(formatted_result)

            return "\n".join(formatted_results) if formatted_results else "No results found"

        except Exception as e:
            return f"Error: {str(e)}"

    # Create the tab UI
    with gr.Tab("Search") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Database containing the vectors"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collection containing the vectors"
            )
        with gr.Row():
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field containing the vectors"
            )
            index_input = gr.Textbox(
                label="Vector Search Index Name",
                value="vector_index",
                info="Index created in Atlas UI"
            )

        query_input = gr.Textbox(
            label="Search Query",
            lines=2,
            placeholder="What would you like to search for?"
        )
        search_btn = gr.Button("Search")
        search_output = gr.Textbox(label="Results", lines=10)

        # Set up event handlers
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )

        search_btn.click(
            fn=vector_search,
            inputs=[
                query_input,
                db_input,
                collection_input,
                embedding_field_input,
                index_input
            ],
            outputs=search_output
        )

    # Return the tab and its interface elements so callers can wire
    # cross-tab behavior if needed.
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'embedding_field_input': embedding_field_input,
        'index_input': index_input,
        'query_input': query_input,
        'search_btn': search_btn,
        'search_output': search_output
    }

    return tab, interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/utils/__init__.py
DELETED
@@ -1,12 +0,0 @@
|
|
1 |
-
# Utils package for MongoDB Vector Search Tool
|
2 |
-
from utils.credentials import check_credentials, init_clients
|
3 |
-
from utils.db_utils import DatabaseUtils
|
4 |
-
from utils.embedding_utils import get_embedding, parallel_generate_embeddings
|
5 |
-
|
6 |
-
__all__ = [
|
7 |
-
'check_credentials',
|
8 |
-
'init_clients',
|
9 |
-
'DatabaseUtils',
|
10 |
-
'get_embedding',
|
11 |
-
'parallel_generate_embeddings'
|
12 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/utils/credentials.py
DELETED
@@ -1,47 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from typing import Tuple
|
3 |
-
from dotenv import load_dotenv
|
4 |
-
from openai import OpenAI
|
5 |
-
from utils.db_utils import DatabaseUtils
|
6 |
-
|
7 |
-
def check_credentials() -> Tuple[bool, str]:
    """Check that the required credentials are configured.

    Loads environment variables from a .env file (if present) and verifies
    that both ATLAS_URI and OPENAI_API_KEY are set.

    Returns:
        Tuple[bool, str]: (is_valid, message)
        - is_valid: True when both credentials are present
        - message: setup instructions when a credential is missing, else ""
    """
    load_dotenv()

    atlas_help = (
        "Please set up your MongoDB Atlas credentials:\n"
        "1. Go to Settings tab\n"
        "2. Add ATLAS_URI as a Repository Secret\n"
        "3. Paste your MongoDB connection string (should start with 'mongodb+srv://')"
    )
    openai_help = (
        "Please set up your OpenAI API key:\n"
        "1. Go to Settings tab\n"
        "2. Add OPENAI_API_KEY as a Repository Secret\n"
        "3. Paste your OpenAI API key"
    )

    if not os.getenv("ATLAS_URI"):
        return False, atlas_help
    if not os.getenv("OPENAI_API_KEY"):
        return False, openai_help
    return True, ""
|
34 |
-
|
35 |
-
def init_clients():
    """Initialize OpenAI and MongoDB clients

    Returns:
        Tuple[OpenAI, DatabaseUtils]: OpenAI client and DatabaseUtils instance
        or (None, None) if initialization fails
    """
    try:
        openai_client = OpenAI()
        db_utils = DatabaseUtils()
        return openai_client, db_utils
    except Exception as e:
        # Surface the failure instead of silently discarding it (the original
        # bound `e` but never used it). Callers still receive (None, None),
        # so the UI's generic error handling is unchanged.
        print(f"Failed to initialize clients: {e}")
        return None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/utils/db_utils.py
DELETED
@@ -1,159 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from typing import List, Dict, Any, Optional
|
3 |
-
from pymongo import MongoClient
|
4 |
-
from pymongo.errors import (
|
5 |
-
ConnectionFailure,
|
6 |
-
OperationFailure,
|
7 |
-
ServerSelectionTimeoutError,
|
8 |
-
InvalidName
|
9 |
-
)
|
10 |
-
from dotenv import load_dotenv
|
11 |
-
|
12 |
-
class DatabaseError(Exception):
    """Base class for database operation errors"""
    pass

class ConnectionError(DatabaseError):
    """Error when connecting to MongoDB Atlas"""
    # NOTE(review): this name shadows the builtin ConnectionError inside this
    # module, so `except ConnectionError` here catches only this class.
    # Renaming would break existing callers, so it is kept as-is.
    pass

class OperationError(DatabaseError):
    """Error during database operations"""
    pass
|
23 |
-
|
24 |
-
class DatabaseUtils:
    """Utility class for MongoDB Atlas database operations

    This class provides methods to interact with MongoDB Atlas databases and collections,
    including listing databases, collections, and retrieving collection information.

    Attributes:
        atlas_uri (str): MongoDB Atlas connection string
        client (MongoClient): MongoDB client instance
    """

    def __init__(self):
        """Initialize DatabaseUtils with MongoDB Atlas connection

        Raises:
            ConnectionError: If unable to connect to MongoDB Atlas
            ValueError: If ATLAS_URI environment variable is not set
        """
        # Load environment variables
        load_dotenv()

        self.atlas_uri = os.getenv("ATLAS_URI")
        if not self.atlas_uri:
            raise ValueError("ATLAS_URI environment variable is not set")

        try:
            self.client = MongoClient(self.atlas_uri)
            # Test connection eagerly so a bad URI fails here, not on first use
            self.client.admin.command('ping')
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            raise ConnectionError(f"Failed to connect to MongoDB Atlas: {str(e)}")

    def get_databases(self) -> List[str]:
        """Get list of all databases in Atlas cluster

        Returns:
            List[str]: List of database names

        Raises:
            OperationError: If unable to list databases
        """
        try:
            return self.client.list_database_names()
        except OperationFailure as e:
            raise OperationError(f"Failed to list databases: {str(e)}")

    def get_collections(self, db_name: str) -> List[str]:
        """Get list of collections in a database

        Args:
            db_name (str): Name of the database

        Returns:
            List[str]: List of collection names

        Raises:
            OperationError: If unable to list collections
            ValueError: If db_name is empty or invalid
        """
        if not db_name or not isinstance(db_name, str):
            raise ValueError("Database name must be a non-empty string")

        try:
            db = self.client[db_name]
            return db.list_collection_names()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(f"Failed to list collections for database '{db_name}': {str(e)}")

    def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
        """Get information about a collection including document count and sample document

        Args:
            db_name (str): Name of the database
            collection_name (str): Name of the collection

        Returns:
            Dict[str, Any]: Dictionary containing collection information:
                - count: Number of documents in collection
                - sample: Sample document from collection (if exists)

        Raises:
            OperationError: If unable to get collection information
            ValueError: If db_name or collection_name is empty or invalid
        """
        if not db_name or not isinstance(db_name, str):
            raise ValueError("Database name must be a non-empty string")
        if not collection_name or not isinstance(collection_name, str):
            raise ValueError("Collection name must be a non-empty string")

        try:
            db = self.client[db_name]
            collection = db[collection_name]

            # NOTE: count_documents({}) does a full scan on large collections
            return {
                'count': collection.count_documents({}),
                'sample': collection.find_one()
            }
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get info for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            )

    def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
        """Get list of fields in a collection based on sample document

        Args:
            db_name (str): Name of the database
            collection_name (str): Name of the collection

        Returns:
            List[str]: List of field names (excluding _id and embedding fields)

        Raises:
            OperationError: If unable to get field names
            ValueError: If db_name or collection_name is empty or invalid
        """
        try:
            info = self.get_collection_info(db_name, collection_name)
            sample = info.get('sample', {})

            if sample:
                # Get all field names except _id and any existing embedding fields
                # (fields ending in '_embedding' are assumed to be vectors)
                return [field for field in sample.keys()
                       if field != '_id' and not field.endswith('_embedding')]
            return []
        except DatabaseError as e:
            raise OperationError(
                f"Failed to get field names for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            )

    def close(self):
        """Close MongoDB connection safely"""
        # hasattr guard: __init__ may have raised before self.client existed
        if hasattr(self, 'client'):
            self.client.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ui/utils/embedding_utils.py
DELETED
@@ -1,143 +0,0 @@
|
|
1 |
-
from typing import List, Tuple
|
2 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
3 |
-
from pymongo import UpdateOne
|
4 |
-
from pymongo.collection import Collection
|
5 |
-
import math
|
6 |
-
import time
|
7 |
-
import logging
|
8 |
-
|
9 |
-
# Configure logging
|
10 |
-
logging.basicConfig(level=logging.INFO)
|
11 |
-
logger = logging.getLogger(__name__)
|
12 |
-
|
13 |
-
def get_embedding(text: str, openai_client, model="text-embedding-ada-002", max_retries=3) -> list[float]:
    """Get embeddings for given text using OpenAI API with retry logic

    Newlines are replaced by spaces before embedding. On failure the call
    is retried up to *max_retries* times with exponential backoff
    (2**attempt seconds); the final failure is re-raised.

    Args:
        text: Text to embed.
        openai_client: OpenAI client instance.
        model: Embedding model name.
        max_retries: Total number of attempts before giving up.

    Returns:
        list[float]: The embedding vector.
    """
    normalized = text.replace("\n", " ")

    attempt = 0
    while True:
        try:
            response = openai_client.embeddings.create(
                input=[normalized],
                model=model
            )
            return response.data[0].embedding
        except Exception as exc:
            # Out of retries: propagate the last error to the caller.
            if attempt >= max_retries - 1:
                raise
            error_details = f"{type(exc).__name__}: {str(exc)}"
            if hasattr(exc, 'response'):
                error_details += f"\nResponse: {exc.response.text if hasattr(exc.response, 'text') else 'No response text'}"
            logger.warning(f"Embedding API error (attempt {attempt + 1}/{max_retries}):\n{error_details}")
            time.sleep(2 ** attempt)  # Exponential backoff
            attempt += 1
|
32 |
-
|
33 |
-
def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Generate embeddings for every document in *docs* that still needs one.

    Documents that already contain *embedding_field*, or whose *field_name*
    value is not a string, are skipped.

    Args:
        docs: Batch of MongoDB documents.
        field_name: Source text field to embed.
        embedding_field: Field name where the embedding will be stored.
        openai_client: OpenAI client instance.

    Returns:
        List of (_id, embedding) pairs for the documents that were embedded.
    """
    logger.info(f"Processing batch of {len(docs)} documents")
    pairs: List[Tuple[str, list]] = []
    for document in docs:
        if embedding_field in document:
            continue  # already embedded
        source_text = document[field_name]
        if not isinstance(source_text, str):
            continue  # only string content can be embedded
        pairs.append((document["_id"], get_embedding(source_text, openai_client)))
    return pairs
|
47 |
-
|
48 |
-
def process_futures(futures: List, collection: Collection, embedding_field: str, processed: int, total_docs: int, callback=None) -> int:
    """Process completed futures and update progress

    Waits for each future (batches of (_id, embedding) pairs), bulk-writes
    the embeddings into *collection*, and reports cumulative progress via
    *callback*. Errors from individual futures are logged and skipped.

    Args:
        futures: Futures returned by the batch executor.
        collection: Target MongoDB collection for the bulk updates.
        embedding_field: Field name to $set with each embedding.
        processed: Count of documents already processed before this call.
        total_docs: Total number of documents being processed overall.
        callback: Optional fn(progress_pct, processed_so_far, total_docs).

    Returns:
        int: Number of documents updated by this call.
    """
    completed = 0
    # NOTE(review): as_completed's timeout raises concurrent.futures.TimeoutError
    # (outside the inner try) if any future is still pending after 30s — that
    # exception propagates to the caller.
    for future in as_completed(futures, timeout=30):  # 30 second timeout
        try:
            results = future.result()
            if results:
                bulk_ops = [
                    UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
                    for doc_id, embedding in results
                ]
                if bulk_ops:
                    collection.bulk_write(bulk_ops)
                    completed += len(bulk_ops)

            # Update progress
            if callback:
                progress = ((processed + completed) / total_docs) * 100
                callback(progress, processed + completed, total_docs)
        except Exception as e:
            error_details = f"{type(e).__name__}: {str(e)}"
            if hasattr(e, 'response'):
                error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
            logger.error(f"Error processing future:\n{error_details}")
    return completed
|
73 |
-
|
74 |
-
def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 10,  # Reduced initial batch size
    callback=None
) -> int:
    """Generate embeddings in parallel with cursor-based batching and adaptive sizing.

    Streams documents from ``cursor``, submits batches of ``current_batch_size``
    to a ThreadPoolExecutor running ``process_batch``, and periodically drains
    completed futures via ``process_futures``, which performs the bulk writes.
    Batch size grows on success and shrinks on failure.

    Args:
        collection: Target MongoDB collection for bulk embedding writes.
        cursor: Iterable of documents to embed.
        field_name: Source-text field to embed.
        embedding_field: Field the embedding is stored under.
        openai_client: Client handed through to ``process_batch``.
        total_docs: Total document count, used for progress percentages.
        batch_size: Initial documents per submitted batch (default 10).
        callback: Optional ``callback(percent, done, total)`` progress hook.

    Returns:
        Number of documents successfully processed (0 if ``total_docs`` is 0).
    """
    if total_docs == 0:
        return 0

    processed = 0
    current_batch_size = batch_size
    max_workers = 10  # Start with fewer workers

    logger.info(f"Starting embedding generation for {total_docs} documents")
    if callback:
        callback(0, 0, total_docs)

    # NOTE: the pool is created with the *initial* max_workers; increasing the
    # variable later only raises the drain threshold below — it does not add
    # threads to this already-constructed executor.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        batch = []
        futures = []

        for doc in cursor:
            batch.append(doc)

            if len(batch) >= current_batch_size:
                logger.info(f"Submitting batch of {len(batch)} documents (batch size: {current_batch_size})")
                future = executor.submit(process_batch, batch.copy(), field_name, embedding_field, openai_client)
                futures.append(future)
                batch = []

            # Process completed futures more frequently
            if len(futures) >= max_workers:
                try:
                    completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
                    processed += completed

                    # Gradually increase batch size and workers if processing is successful
                    if completed > 0:
                        current_batch_size = min(current_batch_size + 5, 30)
                        max_workers = min(max_workers + 2, 20)
                        logger.info(f"Increased batch size to {current_batch_size}, workers to {max_workers}")
                except Exception as e:
                    logger.error(f"Error processing futures: {str(e)}")
                    # Reduce batch size and workers on error
                    current_batch_size = max(5, current_batch_size - 5)
                    max_workers = max(5, max_workers - 2)
                    logger.info(f"Reduced batch size to {current_batch_size}, workers to {max_workers}")
                finally:
                    # BUG FIX: always drop the drained futures. Previously this
                    # reset only happened on the success path, so after an
                    # exception the same (possibly already-written) futures were
                    # fed to process_futures again — duplicate bulk writes and
                    # double-counted progress.
                    futures = []

        # Process remaining batch
        if batch:
            logger.info(f"Processing final batch of {len(batch)} documents")
            future = executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
            futures.append(future)

        # Process remaining futures
        if futures:
            try:
                completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
                processed += completed
            except Exception as e:
                logger.error(f"Error processing final futures: {str(e)}")

    logger.info(f"Completed embedding generation. Processed {processed}/{total_docs} documents")
    return processed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|