from typing import List, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from pymongo import UpdateOne
from pymongo.collection import Collection
import time
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_embedding(text: str, openai_client, model="text-embedding-ada-002", max_retries=3) -> List[float]:
"""Get embeddings for given text using OpenAI API with retry logic"""
text = text.replace("\n", " ")
for attempt in range(max_retries):
try:
resp = openai_client.embeddings.create(
input=[text],
model=model
)
return resp.data[0].embedding
except Exception as e:
if attempt == max_retries - 1:
raise
error_details = f"{type(e).__name__}: {str(e)}"
if hasattr(e, 'response'):
error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
logger.warning(f"Embedding API error (attempt {attempt + 1}/{max_retries}):\n{error_details}")
time.sleep(2 ** attempt) # Exponential backoff
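# Example usage (illustrative sketch; assumes the OpenAI v1 Python SDK, which
# this module does not import itself):
#   from openai import OpenAI
#   client = OpenAI()  # reads OPENAI_API_KEY from the environment
#   vector = get_embedding("hello world", client)  # ada-002 -> 1536-dim list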
def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
"""Process a batch of documents to generate embeddings"""
logger.info(f"Processing batch of {len(docs)} documents")
results = []
for doc in docs:
# Skip if embedding already exists
if embedding_field in doc:
continue
        text = doc.get(field_name)
        if isinstance(text, str):
            embedding = get_embedding(text, openai_client)
            results.append((doc["_id"], embedding))
        else:
            logger.warning(f"Skipping document {doc.get('_id')}: field '{field_name}' is missing or not a string")
return results
def process_futures(futures: List, collection: Collection, embedding_field: str, processed: int, total_docs: int, callback=None) -> int:
"""Process completed futures and update progress"""
completed = 0
    for future in as_completed(futures, timeout=30):  # 30s cap on waiting for this entire set of futures
try:
results = future.result()
if results:
bulk_ops = [
UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
for doc_id, embedding in results
]
if bulk_ops:
collection.bulk_write(bulk_ops)
completed += len(bulk_ops)
# Update progress
if callback:
progress = ((processed + completed) / total_docs) * 100
callback(progress, processed + completed, total_docs)
except Exception as e:
error_details = f"{type(e).__name__}: {str(e)}"
if hasattr(e, 'response'):
error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
logger.error(f"Error processing future:\n{error_details}")
return completed
def parallel_generate_embeddings(
collection: Collection,
cursor,
field_name: str,
embedding_field: str,
openai_client,
total_docs: int,
    batch_size: int = 10,  # initial batch size; grows and shrinks adaptively below
callback=None
) -> int:
"""Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching and dynamic processing"""
if total_docs == 0:
return 0
processed = 0
current_batch_size = batch_size
    max_workers = 10  # initial pool size; also reused below as the futures flush threshold
logger.info(f"Starting embedding generation for {total_docs} documents")
if callback:
callback(0, 0, total_docs)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
batch = []
futures = []
for doc in cursor:
batch.append(doc)
if len(batch) >= current_batch_size:
logger.info(f"Submitting batch of {len(batch)} documents (batch size: {current_batch_size})")
future = executor.submit(process_batch, batch.copy(), field_name, embedding_field, openai_client)
futures.append(future)
batch = []
# Process completed futures more frequently
if len(futures) >= max_workers:
try:
completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
processed += completed
futures = [] # Clear processed futures
                        # Gradually grow the batch size and flush threshold on success.
                        # Note: the executor's pool size is fixed at creation, so raising
                        # max_workers here only lets more futures accumulate before a flush.
                        if completed > 0:
                            current_batch_size = min(current_batch_size + 5, 30)
                            max_workers = min(max_workers + 2, 20)
                            logger.info(f"Increased batch size to {current_batch_size}, flush threshold to {max_workers}")
except Exception as e:
logger.error(f"Error processing futures: {str(e)}")
# Reduce batch size and workers on error
current_batch_size = max(5, current_batch_size - 5)
max_workers = max(5, max_workers - 2)
logger.info(f"Reduced batch size to {current_batch_size}, workers to {max_workers}")
# Process remaining batch
if batch:
logger.info(f"Processing final batch of {len(batch)} documents")
future = executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
futures.append(future)
# Process remaining futures
if futures:
try:
completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
processed += completed
except Exception as e:
logger.error(f"Error processing final futures: {str(e)}")
logger.info(f"Completed embedding generation. Processed {processed}/{total_docs} documents")
return processed
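# Minimal end-to-end sketch (a hypothetical driver, not part of the original
# module): assumes the OpenAI v1 SDK, a local MongoDB instance, and the made-up
# database/collection/field names below.
if __name__ == "__main__":
    from openai import OpenAI
    from pymongo import MongoClient

    mongo = MongoClient("mongodb://localhost:27017")  # assumed connection string
    articles = mongo["demo_db"]["articles"]           # hypothetical collection

    def report(progress: float, done: int, total: int) -> None:
        logger.info(f"Progress: {progress:.1f}% ({done}/{total})")

    # Only touch documents that still lack an embedding
    query = {"embedding": {"$exists": False}}
    total = articles.count_documents(query)
    processed = parallel_generate_embeddings(
        collection=articles,
        cursor=articles.find(query),
        field_name="text",            # hypothetical source text field
        embedding_field="embedding",
        openai_client=OpenAI(),       # reads OPENAI_API_KEY from the environment
        total_docs=total,
        callback=report,
    )
    logger.info(f"Done: {processed} documents updated")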