from typing import List, Tuple
from concurrent.futures import ThreadPoolExecutor
from pymongo import UpdateOne
from pymongo.collection import Collection
import math


def get_embedding(text: str, openai_client, model: str = "text-embedding-ada-002") -> list[float]:
    """Return the embedding vector for *text* via the OpenAI embeddings API.

    Newlines are replaced with spaces before the request (recommended for
    the ada-002 embedding model).

    Args:
        text: Text to embed.
        openai_client: OpenAI client instance.
        model: Embedding model name.
    """
    text = text.replace("\n", " ")
    resp = openai_client.embeddings.create(input=[text], model=model)
    return resp.data[0].embedding


def process_batch(
    docs: List[dict],
    field_name: str,
    embedding_field: str,
    openai_client,
) -> List[Tuple[str, list]]:
    """Generate embeddings for a batch of documents.

    Documents are skipped when they already contain *embedding_field*,
    when *field_name* is absent, or when its value is not a string.

    Args:
        docs: Documents to process.
        field_name: Field containing the text to embed.
        embedding_field: Field where the embedding will be stored.
        openai_client: OpenAI client instance.

    Returns:
        List of ``(document _id, embedding vector)`` pairs.
    """
    results: List[Tuple[str, list]] = []
    for doc in docs:
        if embedding_field in doc:
            continue  # embedding already present; nothing to do
        # .get() instead of doc[field_name]: tolerate documents that lack
        # the text field rather than raising KeyError mid-batch.
        text = doc.get(field_name)
        if isinstance(text, str):
            results.append((doc["_id"], get_embedding(text, openai_client)))
    return results


def _write_embeddings(collection: Collection, embedding_field: str, results: List[Tuple[str, list]]) -> int:
    """Bulk-write (doc_id, embedding) pairs into *collection*.

    Returns the number of documents updated (0 for an empty result set).
    """
    if not results:
        return 0
    ops = [
        UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
        for doc_id, embedding in results
    ]
    collection.bulk_write(ops)
    return len(ops)


def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 20,
    callback=None,
) -> int:
    """Generate embeddings in parallel using a ThreadPoolExecutor with cursor-based batching.

    Documents are pulled from *cursor* in batches of *batch_size*; each batch
    is embedded on a worker thread and the results are bulk-written back to
    MongoDB. Finished futures are drained opportunistically while iterating
    so pending results do not accumulate without bound.

    Args:
        collection: MongoDB collection receiving the updates.
        cursor: MongoDB cursor yielding the documents to process.
        field_name: Field containing text to embed.
        embedding_field: Field to store embeddings in.
        openai_client: OpenAI client instance.
        total_docs: Total number of documents (used for progress reporting).
        batch_size: Number of documents per worker batch.
        callback: Optional ``callback(progress_pct, processed, total_docs)``
            invoked on progress updates.

    Returns:
        Number of documents actually updated with new embeddings.
    """
    if total_docs == 0:
        return 0

    processed = 0
    if callback:
        callback(0, 0, total_docs)  # initial progress update

    def _report() -> None:
        # Progress is relative to total_docs, not to documents updated,
        # so it may end below 100% when some docs are skipped.
        if callback:
            callback((processed / total_docs) * 100, processed, total_docs)

    with ThreadPoolExecutor(max_workers=20) as executor:
        batch: List[dict] = []
        futures = []

        for doc in cursor:
            batch.append(doc)
            if len(batch) >= batch_size:
                futures.append(
                    executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
                )
                batch = []  # rebind; the submitted list is no longer shared

                # Drain any futures that have already finished to keep
                # memory bounded while the cursor is still being consumed.
                still_pending = []
                for future in futures:
                    if future.done():
                        written = _write_embeddings(collection, embedding_field, future.result())
                        if written:
                            processed += written
                            _report()
                    else:
                        still_pending.append(future)
                futures = still_pending

        # Submit whatever is left in the final, possibly short, batch.
        if batch:
            futures.append(
                executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
            )

        # Block on everything still outstanding.
        for future in futures:
            written = _write_embeddings(collection, embedding_field, future.result())
            if written:
                processed += written
                _report()

    return processed