# mongo-vector-search-util / embedding_utils.py
from typing import List, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed

from pymongo import UpdateOne
from pymongo.collection import Collection

import time
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def get_embedding(text: str, openai_client, model="text-embedding-ada-002", max_retries=3) -> List[float]:
    """Get an embedding for the given text via the OpenAI API, retrying with exponential backoff."""
    text = text.replace("\n", " ")
    for attempt in range(max_retries):
        try:
            resp = openai_client.embeddings.create(
                input=[text],
                model=model,
            )
            return resp.data[0].embedding
        except Exception as e:
            # Re-raise on the final attempt so the caller sees the real failure
            if attempt == max_retries - 1:
                raise
            error_details = f"{type(e).__name__}: {e}"
            if hasattr(e, "response"):
                error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
            logger.warning(f"Embedding API error (attempt {attempt + 1}/{max_retries}):\n{error_details}")
            time.sleep(2 ** attempt)  # Exponential backoff: 1s, 2s, 4s, ...
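
# Example (sketch): a one-off call, assuming an `openai.OpenAI()` client is
# available as `client`; text-embedding-ada-002 returns a 1536-dimension vector.
#
#     vec = get_embedding("hello world", client)
#     assert len(vec) == 1536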

def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Generate embeddings for a batch of documents, returning (_id, embedding) pairs."""
    logger.info(f"Processing batch of {len(docs)} documents")
    results = []
    for doc in docs:
        # Skip documents that already have an embedding
        if embedding_field in doc:
            continue
        text = doc[field_name]
        # Only embed string values; non-string fields are silently skipped
        if isinstance(text, str):
            embedding = get_embedding(text, openai_client)
            results.append((doc["_id"], embedding))
    return results

def process_futures(futures: List, collection: Collection, embedding_field: str, processed: int, total_docs: int, callback=None) -> int:
    """Collect completed futures, bulk-write their embeddings, and report progress."""
    completed = 0
    # Note: the timeout bounds the whole iteration, not each individual future
    for future in as_completed(futures, timeout=30):
        try:
            results = future.result()
            if results:
                bulk_ops = [
                    UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
                    for doc_id, embedding in results
                ]
                collection.bulk_write(bulk_ops)
                completed += len(bulk_ops)
                # Report progress to the caller
                if callback:
                    progress = ((processed + completed) / total_docs) * 100
                    callback(progress, processed + completed, total_docs)
        except Exception as e:
            error_details = f"{type(e).__name__}: {e}"
            if hasattr(e, "response"):
                error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
            logger.error(f"Error processing future:\n{error_details}")
    return completed
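
# Example progress callback (sketch): any callable matching the
# (percent, done, total) signature used above will work, e.g.
#
#     def on_progress(pct: float, done: int, total: int) -> None:
#         logger.info(f"Progress: {pct:.1f}% ({done}/{total})")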

def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 10,  # Start small; grows adaptively up to 30
    callback=None,
) -> int:
    """Generate embeddings in parallel with a ThreadPoolExecutor, batching documents
    straight off the cursor and adapting batch size to observed success/failure."""
    if total_docs == 0:
        return 0

    processed = 0
    current_batch_size = batch_size
    max_workers = 10  # Start with fewer workers

    logger.info(f"Starting embedding generation for {total_docs} documents")
    if callback:
        callback(0, 0, total_docs)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        batch = []
        futures = []
        for doc in cursor:
            batch.append(doc)
            if len(batch) >= current_batch_size:
                logger.info(f"Submitting batch of {len(batch)} documents (batch size: {current_batch_size})")
                future = executor.submit(process_batch, batch.copy(), field_name, embedding_field, openai_client)
                futures.append(future)
                batch = []

                # Drain completed futures regularly so progress stays current
                if len(futures) >= max_workers:
                    try:
                        completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
                        processed += completed
                        futures = []  # All futures in the list have been consumed
                        # Ramp up batch size and the queueing threshold after a
                        # successful round. Note: raising max_workers here only
                        # loosens the drain threshold above; the executor's
                        # thread count was fixed when the pool was created.
                        if completed > 0:
                            current_batch_size = min(current_batch_size + 5, 30)
                            max_workers = min(max_workers + 2, 20)
                            logger.info(f"Increased batch size to {current_batch_size}, workers to {max_workers}")
                    except Exception as e:
                        logger.error(f"Error processing futures: {e}")
                        # Back off on error
                        current_batch_size = max(5, current_batch_size - 5)
                        max_workers = max(5, max_workers - 2)
                        logger.info(f"Reduced batch size to {current_batch_size}, workers to {max_workers}")

        # Submit whatever is left in the final, possibly short batch
        if batch:
            logger.info(f"Processing final batch of {len(batch)} documents")
            future = executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
            futures.append(future)

        # Drain any futures still outstanding
        if futures:
            try:
                completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
                processed += completed
            except Exception as e:
                logger.error(f"Error processing final futures: {e}")

    logger.info(f"Completed embedding generation. Processed {processed}/{total_docs} documents")
    return processed
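

# Usage sketch (not part of the library): shows how the pieces above fit
# together. The connection URI, database/collection names, and source field
# below are hypothetical; adjust them to your deployment.
if __name__ == "__main__":
    from pymongo import MongoClient
    from openai import OpenAI

    mongo = MongoClient("mongodb://localhost:27017")  # hypothetical URI
    coll = mongo["mydb"]["articles"]                  # hypothetical collection
    oai = OpenAI()  # reads OPENAI_API_KEY from the environment

    # Fetch only documents that still need an embedding, and count them up
    # front so the progress percentages are meaningful.
    query = {"embedding": {"$exists": False}}
    total = coll.count_documents(query)
    cursor = coll.find(query)

    done = parallel_generate_embeddings(
        collection=coll,
        cursor=cursor,
        field_name="text",            # hypothetical source field
        embedding_field="embedding",
        openai_client=oai,
        total_docs=total,
    )
    print(f"Embedded {done} documents")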