from typing import Any, List, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from pymongo import UpdateOne
from pymongo.collection import Collection
import time
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _format_error(e: Exception) -> str:
    """Build a detailed error string, including the HTTP response body when present."""
    details = f"{type(e).__name__}: {e}"
    response = getattr(e, "response", None)
    if response is not None:
        details += f"\nResponse: {getattr(response, 'text', 'No response text')}"
    return details


def get_embedding(text: str, openai_client, model="text-embedding-ada-002", max_retries=3) -> List[float]:
    """Get an embedding for the given text using the OpenAI API, with retry logic."""
    text = text.replace("\n", " ")
    for attempt in range(max_retries):
        try:
            resp = openai_client.embeddings.create(input=[text], model=model)
            return resp.data[0].embedding
        except Exception as e:
            if attempt == max_retries - 1:
                raise
            logger.warning(
                f"Embedding API error (attempt {attempt + 1}/{max_retries}):\n{_format_error(e)}"
            )
            time.sleep(2 ** attempt)  # Exponential backoff: 1s, 2s, 4s, ...


def process_batch(docs: List[dict], field_name: str, embedding_field: str,
                  openai_client) -> List[Tuple[Any, List[float]]]:
    """Generate embeddings for a batch of documents; returns (_id, embedding) pairs."""
    logger.info(f"Processing batch of {len(docs)} documents")
    results = []
    for doc in docs:
        # Skip documents that already have an embedding
        if embedding_field in doc:
            continue
        text = doc[field_name]
        if isinstance(text, str):
            embedding = get_embedding(text, openai_client)
            results.append((doc["_id"], embedding))
    return results


def process_futures(futures: List, collection: Collection, embedding_field: str,
                    processed: int, total_docs: int, callback=None) -> int:
    """Write completed futures' embeddings back to MongoDB and update progress."""
    completed = 0
    # as_completed raises TimeoutError if the futures do not all finish within 30 seconds
    for future in as_completed(futures, timeout=30):
        try:
            results = future.result()
            if results:
                bulk_ops = [
                    UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
                    for doc_id, embedding in results
                ]
                collection.bulk_write(bulk_ops)
                completed += len(bulk_ops)
                # Update progress
                if callback:
                    progress = ((processed + completed) / total_docs) * 100
                    callback(progress, processed + completed, total_docs)
        except Exception as e:
            logger.error(f"Error processing future:\n{_format_error(e)}")
    return completed


def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 10,  # Conservative initial batch size; grows on success
    callback=None,
) -> int:
    """Generate embeddings in parallel using a ThreadPoolExecutor, batching documents
    from the cursor and adapting the batch size to observed success and failure."""
    if total_docs == 0:
        return 0

    processed = 0
    current_batch_size = batch_size
    # Note: the executor's pool size is fixed at construction; the max_workers
    # variable below only controls how many in-flight batches accumulate
    # before they are flushed to MongoDB.
    max_workers = 10  # Start with fewer workers

    logger.info(f"Starting embedding generation for {total_docs} documents")
    if callback:
        callback(0, 0, total_docs)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        batch = []
        futures = []

        for doc in cursor:
            batch.append(doc)
            if len(batch) >= current_batch_size:
                logger.info(f"Submitting batch of {len(batch)} documents (batch size: {current_batch_size})")
                future = executor.submit(process_batch, batch.copy(), field_name,
                                         embedding_field, openai_client)
                futures.append(future)
                batch = []

                # Flush in-flight futures regularly instead of queueing everything
                if len(futures) >= max_workers:
                    try:
                        completed = process_futures(futures, collection, embedding_field,
                                                    processed, total_docs, callback)
                        processed += completed
                        futures = []  # Clear processed futures

                        # Gradually ramp up throughput while batches succeed
                        if completed > 0:
                            current_batch_size = min(current_batch_size + 5, 30)
                            max_workers = min(max_workers + 2, 20)
                            logger.info(f"Increased batch size to {current_batch_size}, workers to {max_workers}")
                    except Exception as e:
                        logger.error(f"Error processing futures: {e}")
                        # Back off on error: shrink the batch size and flush threshold.
                        # The futures list is kept, so the next flush retries it;
                        # $set updates are idempotent.
                        current_batch_size = max(5, current_batch_size - 5)
                        max_workers = max(5, max_workers - 2)
                        logger.info(f"Reduced batch size to {current_batch_size}, workers to {max_workers}")

        # Submit any leftover documents as a final batch
        if batch:
            logger.info(f"Processing final batch of {len(batch)} documents")
            future = executor.submit(process_batch, batch, field_name,
                                     embedding_field, openai_client)
            futures.append(future)

        # Drain whatever futures are still outstanding
        if futures:
            try:
                completed = process_futures(futures, collection, embedding_field,
                                            processed, total_docs, callback)
                processed += completed
            except Exception as e:
                logger.error(f"Error processing final futures: {e}")

    logger.info(f"Completed embedding generation. Processed {processed}/{total_docs} documents")
    return processed
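

# --- Usage sketch (illustrative) -------------------------------------------
# A minimal example of wiring the pieces together. The connection string,
# database/collection names, and field names below are placeholders, not
# values assumed by the functions above.
if __name__ == "__main__":
    from openai import OpenAI
    from pymongo import MongoClient

    openai_client = OpenAI()  # Reads OPENAI_API_KEY from the environment
    mongo_client = MongoClient("mongodb://localhost:27017")  # Placeholder URI
    collection = mongo_client["mydb"]["articles"]  # Placeholder names

    # Only fetch documents that do not yet have an embedding
    query = {"embedding": {"$exists": False}}
    total = collection.count_documents(query)
    cursor = collection.find(query)

    def report(progress: float, done: int, total_docs: int) -> None:
        logger.info(f"Progress: {progress:.1f}% ({done}/{total_docs})")

    parallel_generate_embeddings(
        collection,
        cursor,
        field_name="text",  # Placeholder: the field holding the source text
        embedding_field="embedding",
        openai_client=openai_client,
        total_docs=total,
        callback=report,
    )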