from typing import List, Tuple |
from concurrent.futures import ThreadPoolExecutor, as_completed |
from pymongo import UpdateOne |
from pymongo.collection import Collection |
import math |
import time |
import logging |
# Module-level logger used by every function below.
# NOTE(review): basicConfig sets up INFO-level root logging as a side effect
# of importing this module; applications usually prefer to own that call.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def get_embedding(text: str, openai_client, model="text-embedding-ada-002", max_retries=3) -> list[float]:
    """Get the embedding for ``text`` from the OpenAI API, retrying on failure.

    Newlines in ``text`` are replaced with spaces before the request. A failed
    attempt is logged and retried after an exponential backoff of
    ``2 ** attempt`` seconds (1s, 2s, 4s, ...); the last failure is re-raised.

    Args:
        text: Text to embed.
        openai_client: Client exposing ``embeddings.create(input=..., model=...)``.
        model: Embedding model name forwarded to the API.
        max_retries: Total number of attempts; must be at least 1.

    Returns:
        The embedding of the single input, i.e. ``resp.data[0].embedding``.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: Whatever the final failed attempt raised.
    """
    if max_retries < 1:
        # Without this guard the retry loop never runs and the function would
        # silently return None instead of an embedding.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    text = text.replace("\n", " ")
    for attempt in range(max_retries):
        try:
            resp = openai_client.embeddings.create(
                input=[text],
                model=model
            )
            # Kept inside the try so a malformed response is retried as well.
            return resp.data[0].embedding
        except Exception as e:
            if attempt == max_retries - 1:
                raise
            error_details = f"{type(e).__name__}: {str(e)}"
            if hasattr(e, 'response'):
                error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
            logger.warning(f"Embedding API error (attempt {attempt + 1}/{max_retries}):\n{error_details}")
            time.sleep(2 ** attempt)
def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Generate embeddings for the documents in ``docs`` that still need one.

    A document is skipped when it already contains ``embedding_field``, or
    when its ``field_name`` value is missing, not a string, or empty.

    Args:
        docs: Documents to process; each embedded one must have an ``_id``.
        field_name: Field holding the text to embed.
        embedding_field: Field whose presence marks a document as done.
        openai_client: Client forwarded to ``get_embedding``.

    Returns:
        ``(doc["_id"], embedding)`` pairs, in input order, for the documents
        that were embedded.
    """
    logger.info(f"Processing batch of {len(docs)} documents")
    results = []
    for doc in docs:
        if embedding_field in doc:
            continue
        # .get(): indexing raised KeyError for a document lacking the field,
        # which aborted the whole batch and discarded every embedding already
        # computed for it. A missing value is skipped like a non-string one.
        text = doc.get(field_name)
        # The embeddings API rejects empty input; don't spend retries on it.
        if not isinstance(text, str) or not text:
            continue
        results.append((doc["_id"], get_embedding(text, openai_client)))
    return results
def process_futures(futures: List, collection: Collection, embedding_field: str, processed: int, total_docs: int, callback=None, timeout=None) -> int:
    """Write the results of finished ``process_batch`` futures to the collection.

    Futures are consumed in completion order. Each non-empty result list is
    written with a single ``bulk_write`` of ``$set`` updates. When ``callback``
    is given, it is then called as ``callback(percent, done, total_docs)`` with
    ``done = processed + completed``. An exception from one future (or from
    its write) is logged, and that batch is skipped.

    Args:
        futures: Futures whose results are lists of ``(doc_id, embedding)``.
        collection: Collection receiving the updates.
        embedding_field: Field the embedding is stored under.
        processed: Documents already written before this call (progress only).
        total_docs: Expected total number of documents (progress only).
        callback: Optional progress callback.
        timeout: Optional overall limit, in seconds, for waiting on all
            ``futures``. ``None`` (default) waits until every future finishes.
            The former hard-coded 30 s total was easily exceeded by large
            batches or retry backoff. When the limit is hit, the unfinished
            batches are logged and not written.

    Returns:
        Number of documents written during this call.
    """
    # concurrent.futures.TimeoutError is a distinct class before Python 3.11
    # (an alias of the builtin TimeoutError from 3.11 on).
    from concurrent.futures import TimeoutError as FuturesTimeoutError

    completed = 0
    try:
        for future in as_completed(futures, timeout=timeout):
            try:
                results = future.result()
                if results:
                    bulk_ops = [
                        UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
                        for doc_id, embedding in results
                    ]
                    collection.bulk_write(bulk_ops)
                    completed += len(bulk_ops)
                    if callback:
                        done = processed + completed
                        # Guard: a zero total used to raise ZeroDivisionError here.
                        progress = (done / total_docs) * 100 if total_docs else 100.0
                        callback(progress, done, total_docs)
            except Exception as e:
                error_details = f"{type(e).__name__}: {str(e)}"
                if hasattr(e, 'response'):
                    error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
                logger.error(f"Error processing future:\n{error_details}")
    except FuturesTimeoutError:
        # Raised by the as_completed iterator itself, outside the per-future
        # handler. Handle it here so the count of documents already written in
        # this call is still returned to the caller.
        unfinished = sum(1 for f in futures if not f.done())
        logger.error(f"Timed out after {timeout}s waiting for batches; {unfinished} unfinished batch(es) will not be written")
    return completed
def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 10,
    callback=None
) -> int:
    """Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching and dynamic processing.

    Documents are pulled from ``cursor`` and grouped into batches of
    ``current_batch_size``. Each batch is embedded on the thread pool by
    ``process_batch``. Once ``max_workers`` batches have been submitted, the
    round is drained with ``process_futures``, which writes the embeddings to
    ``collection``. After a round that wrote something, the batch size and the
    in-flight cap grow; after a round whose draining raised, they shrink.

    Args:
        collection: Collection the embeddings are written to.
        cursor: Iterable of documents to embed.
        field_name: Document field holding the text to embed.
        embedding_field: Field the embedding is stored under.
        openai_client: Client forwarded to ``process_batch``.
        total_docs: Expected number of documents. Used for progress reporting
            and for the early return when it is 0.
        batch_size: Initial number of documents per batch.
        callback: Optional ``callback(percent, done, total_docs)`` progress hook.

    Returns:
        Number of documents whose embedding was written.
    """
    if total_docs == 0:
        return 0

    # Adaptive bounds. The batch-size bounds include the caller's batch_size,
    # so an "increase" never shrinks it (50 used to become 30) and a "reduce"
    # never grows it (2 used to become 5).
    min_batch_size = min(5, batch_size)
    max_batch_size = max(30, batch_size)
    min_in_flight, max_in_flight = 5, 20

    processed = 0
    current_batch_size = batch_size
    # Number of batches submitted before a round is drained. The pool is sized
    # at max_in_flight so this cap is the real concurrency level. With the old
    # fixed 10-thread pool, raising it above 10 only queued more work.
    max_workers = 10
    logger.info(f"Starting embedding generation for {total_docs} documents")
    if callback:
        callback(0, 0, total_docs)

    with ThreadPoolExecutor(max_workers=max_in_flight) as executor:
        batch = []
        futures = []
        for doc in cursor:
            batch.append(doc)
            if len(batch) < current_batch_size:
                continue
            logger.info(f"Submitting batch of {len(batch)} documents (batch size: {current_batch_size})")
            # `batch` is rebound immediately below, so no defensive copy is needed.
            futures.append(executor.submit(process_batch, batch, field_name, embedding_field, openai_client))
            batch = []
            if len(futures) < max_workers:
                continue
            # Detach this round before draining. Previously, if draining raised,
            # `futures` was never cleared, so the same futures were drained again
            # next round: their results were written twice and counted twice.
            round_futures, futures = futures, []
            try:
                completed = process_futures(round_futures, collection, embedding_field, processed, total_docs, callback)
            except Exception as e:
                logger.error(f"Error processing futures: {str(e)}")
                current_batch_size = max(min_batch_size, current_batch_size - 5)
                max_workers = max(min_in_flight, max_workers - 2)
                logger.info(f"Reduced batch size to {current_batch_size}, workers to {max_workers}")
            else:
                processed += completed
                if completed > 0:
                    current_batch_size = min(current_batch_size + 5, max_batch_size)
                    max_workers = min(max_workers + 2, max_in_flight)
                    logger.info(f"Increased batch size to {current_batch_size}, workers to {max_workers}")
        if batch:
            logger.info(f"Processing final batch of {len(batch)} documents")
            futures.append(executor.submit(process_batch, batch, field_name, embedding_field, openai_client))
        if futures:
            try:
                completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
            except Exception as e:
                logger.error(f"Error processing final futures: {str(e)}")
            else:
                processed += completed
    logger.info(f"Completed embedding generation. Processed {processed}/{total_docs} documents")
    return processed