Spaces:
Sleeping
Sleeping
| from typing import List, Tuple | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from pymongo import UpdateOne | |
| from pymongo.collection import Collection | |
| import math | |
| import time | |
| import logging | |
# Module logger used by every function below; root logging is configured at
# INFO level so the progress messages they emit are visible by default.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def get_embedding(text: str, openai_client, model="text-embedding-ada-002", max_retries=3) -> list[float]:
    """Return the embedding vector for ``text`` from the OpenAI embeddings API.

    Newlines in ``text`` are replaced with spaces before the request. A failed
    call is retried after an exponential backoff sleep of 1 s, 2 s, 4 s, ...
    until ``max_retries`` attempts have been made.

    Args:
        text: The text to embed.
        openai_client: Client exposing ``embeddings.create(input=..., model=...)``.
        model: Embedding model name passed to the API.
        max_retries: Total number of attempts; must be at least 1.

    Returns:
        ``resp.data[0].embedding`` for the single-element request.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: Whatever the client raised on the final attempt, re-raised
            unchanged.
    """
    if max_retries < 1:
        # Without this guard the retry loop never runs and the function falls
        # off the end returning None, which callers would store as an embedding.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    text = text.replace("\n", " ")
    for attempt in range(max_retries):
        try:
            resp = openai_client.embeddings.create(input=[text], model=model)
            return resp.data[0].embedding
        except Exception as e:
            if attempt == max_retries - 1:
                raise  # out of attempts: surface the original error
            error_details = f"{type(e).__name__}: {str(e)}"
            if hasattr(e, 'response'):
                # Some client errors carry the HTTP response; include its body when present.
                response_text = getattr(e.response, 'text', 'No response text')
                error_details += f"\nResponse: {response_text}"
            logger.warning(f"Embedding API error (attempt {attempt + 1}/{max_retries}):\n{error_details}")
            time.sleep(2 ** attempt)  # Exponential backoff
def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Generate embeddings for the documents in one batch that still need them.

    A document is skipped when it already contains ``embedding_field``, or
    when its ``field_name`` value is missing, not a string, or blank. Every
    other document's text is embedded with ``get_embedding``.

    If embedding one document fails (after ``get_embedding``'s own retries),
    the error is logged and only that document is skipped. Embeddings already
    generated for the rest of the batch are still returned instead of being
    discarded with the whole batch.

    Args:
        docs: Documents from the collection; each is expected to have ``_id``.
        field_name: Key holding the text to embed.
        embedding_field: Key the embedding is stored under; its presence
            means the document is already done.
        openai_client: Client passed through to ``get_embedding``.

    Returns:
        ``(doc["_id"], embedding)`` pairs for the documents that were embedded.
    """
    logger.info(f"Processing batch of {len(docs)} documents")
    results = []
    for doc in docs:
        if embedding_field in doc:
            continue  # embedding already exists

        # .get() so a document lacking the source field is skipped rather than
        # raising KeyError and failing every other document in the batch.
        text = doc.get(field_name)
        if not isinstance(text, str) or not text.strip():
            # Nothing meaningful to embed; an empty input is also typically
            # rejected by the embeddings API.
            continue

        try:
            embedding = get_embedding(text, openai_client)
        except Exception as e:
            logger.error(f"Failed to embed document {doc.get('_id')!r}: {type(e).__name__}: {e}")
            continue
        results.append((doc["_id"], embedding))
    return results
def process_futures(futures: List, collection: "Collection", embedding_field: str, processed: int, total_docs: int, callback=None, timeout=None) -> int:
    """Write the results of ``process_batch`` futures to MongoDB as they finish.

    Futures are handled in completion order. Each non-empty result list is
    written with a single ``bulk_write`` of ``$set`` updates. After the write,
    ``callback(progress_percent, done_count, total_docs)`` is invoked if a
    callback was given. An exception raised by a future, by the write, or by
    the callback is logged and that batch is skipped; the remaining futures
    are still processed.

    Args:
        futures: Futures whose results are lists of ``(doc_id, embedding)``.
        collection: Collection the embeddings are written to.
        embedding_field: Field the embedding is stored under.
        processed: Documents already counted before this call (for progress).
        total_docs: Denominator for the progress percentage.
        callback: Optional progress callback.
        timeout: Seconds to wait for the whole set of futures, or ``None``
            (default) to wait until every future finishes.

    Returns:
        Number of documents updated by this call.

    Raises:
        concurrent.futures.TimeoutError: Only when ``timeout`` is given and
            expires before every future has finished.
    """
    completed = 0
    # as_completed's timeout is measured from this call for the whole set, not
    # per future. The former fixed 30 s limit therefore raised TimeoutError
    # out of this loop (outside the try below) whenever a round of batches took
    # longer in total, leaving the remaining results unwritten.
    for future in as_completed(futures, timeout=timeout):
        try:
            results = future.result()  # re-raises any process_batch exception
            if not results:
                continue
            bulk_ops = [
                UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
                for doc_id, embedding in results
            ]
            collection.bulk_write(bulk_ops)
            completed += len(bulk_ops)
            # Update progress
            if callback:
                done = processed + completed
                # Guard against a zero total instead of letting ZeroDivisionError
                # be reported as a failed future after the write succeeded.
                progress = (done / total_docs) * 100 if total_docs else 100.0
                callback(progress, done, total_docs)
        except Exception as e:
            error_details = f"{type(e).__name__}: {str(e)}"
            if hasattr(e, 'response'):
                error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
            logger.error(f"Error processing future:\n{error_details}")
    return completed
def parallel_generate_embeddings(
    collection: "Collection",
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 10,  # initial batch size; adapted at runtime
    callback=None
) -> int:
    """Embed the documents yielded by ``cursor`` in parallel and store the results.

    Documents are grouped into batches of ``current_batch_size`` and each batch
    is submitted to ``process_batch`` on a thread pool. Once ``max_workers``
    batches are in flight, the round is waited on and flushed to
    ``collection`` via ``process_futures``. Progress callbacks therefore
    arrive once per round.

    After a round that updated at least one document:
      * the batch size becomes ``min(size + 5, 30)``;
      * the in-flight limit becomes ``min(limit + 2, 20)``.
    If flushing a round raises, both limits shrink (not below 5).

    Args:
        collection: Collection the embeddings are written to.
        cursor: Iterable of documents to embed.
        field_name: Key holding the text to embed.
        embedding_field: Key the embedding is stored under.
        openai_client: Client passed through to the embedding calls.
        total_docs: Expected document count, used for progress; 0 returns
            immediately without reading ``cursor``.
        batch_size: Initial number of documents per batch.
        callback: Optional ``callback(percent, done, total)`` progress hook.

    Returns:
        Number of documents whose embeddings were written.
    """
    from concurrent.futures import wait

    if total_docs == 0:
        return 0

    min_batch_size, max_batch_size = 5, 30
    min_workers, max_workers_cap = 5, 20

    processed = 0
    current_batch_size = batch_size
    max_workers = 10  # current limit on in-flight batches; adapted below

    logger.info(f"Starting embedding generation for {total_docs} documents")
    if callback:
        callback(0, 0, total_docs)

    # The pool is sized to the cap so that raising ``max_workers`` really adds
    # concurrency. The number of in-flight batches never exceeds the current
    # ``max_workers``, because a round is flushed as soon as it reaches that size.
    with ThreadPoolExecutor(max_workers=max_workers_cap) as executor:
        batch = []
        futures = []
        for doc in cursor:
            batch.append(doc)
            if len(batch) < current_batch_size:
                continue

            logger.info(f"Submitting batch of {len(batch)} documents (batch size: {current_batch_size})")
            # ``batch`` is rebound right after, so the submitted list is never
            # mutated again and needs no copy.
            futures.append(executor.submit(process_batch, batch, field_name, embedding_field, openai_client))
            batch = []

            if len(futures) < max_workers:
                continue

            # Let the whole round finish first. process_futures then only sees
            # completed futures, so its internal timeout cannot fire mid-round
            # and leave results unwritten.
            wait(futures)
            try:
                completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
            except Exception as e:
                logger.error(f"Error processing futures: {str(e)}")
                # Reduce batch size and workers on error
                current_batch_size = max(min_batch_size, current_batch_size - 5)
                max_workers = max(min_workers, max_workers - 2)
                logger.info(f"Reduced batch size to {current_batch_size}, workers to {max_workers}")
            else:
                processed += completed
                # Gradually increase batch size and workers if processing is successful
                if completed > 0:
                    current_batch_size = min(current_batch_size + 5, max_batch_size)
                    max_workers = min(max_workers + 2, max_workers_cap)
                    logger.info(f"Increased batch size to {current_batch_size}, workers to {max_workers}")
            finally:
                # Always start a fresh round. Otherwise, after an error, the same
                # futures would be written and counted again on the next flush.
                futures = []

        # Process remaining batch
        if batch:
            logger.info(f"Processing final batch of {len(batch)} documents")
            futures.append(executor.submit(process_batch, batch, field_name, embedding_field, openai_client))

        # Process remaining futures
        if futures:
            wait(futures)
            try:
                completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
                processed += completed
            except Exception as e:
                logger.error(f"Error processing final futures: {str(e)}")

    logger.info(f"Completed embedding generation. Processed {processed}/{total_docs} documents")
    return processed