File size: 6,263 Bytes
8fb6e2f
bbccab6
8fb6e2f
 
 
bbccab6
 
8fb6e2f
bbccab6
 
 
 
 
 
8fb6e2f
bbccab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fb6e2f
 
 
bbccab6
8fb6e2f
 
 
 
 
 
 
 
 
 
 
 
bbccab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fb6e2f
 
 
 
 
 
 
bbccab6
8fb6e2f
 
bbccab6
8fb6e2f
 
 
 
bbccab6
7301668
8fb6e2f
bbccab6
8fb6e2f
 
 
bbccab6
8fb6e2f
 
 
 
 
 
bbccab6
 
8fb6e2f
 
bbccab6
8fb6e2f
bbccab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7301668
bbccab6
8fb6e2f
bbccab6
8fb6e2f
bbccab6
8fb6e2f
 
 
bbccab6
 
 
 
 
 
 
 
 
8fb6e2f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from typing import List, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from pymongo import UpdateOne
from pymongo.collection import Collection
import math
import time
import logging

# Obtain this module's logger; records propagate to the root logger set up next.
logger = logging.getLogger(__name__)

# NOTE(review): basicConfig configures the *root* logger as an import-time side
# effect, which applies to every program importing this module; consider moving
# it to the application's entry point.
logging.basicConfig(level=logging.INFO)

def get_embedding(text: str, openai_client, model="text-embedding-ada-002", max_retries=3) -> list[float]:
    """Get embeddings for given text using OpenAI API with retry logic.

    Args:
        text: Text to embed; newlines are replaced with spaces before the call.
        openai_client: Client exposing ``embeddings.create(input=..., model=...)``.
        model: Embedding model name passed through to the API.
        max_retries: Total number of attempts. Values below 1 are treated as 1,
            so at least one request is always made.

    Returns:
        The embedding vector of the single input (``resp.data[0].embedding``).

    Raises:
        Whatever exception the final failed attempt raised.
    """
    text = text.replace("\n", " ")
    # Previously max_retries <= 0 skipped the loop entirely and the function
    # fell through returning None; always make at least one attempt.
    attempts = max(1, max_retries)

    for attempt in range(attempts):
        try:
            resp = openai_client.embeddings.create(
                input=[text],
                model=model
            )
            return resp.data[0].embedding
        except Exception as e:
            if attempt == attempts - 1:
                raise
            error_details = f"{type(e).__name__}: {str(e)}"
            if hasattr(e, 'response'):
                # Some client errors carry the HTTP response; include its body if present.
                response_text = getattr(e.response, 'text', 'No response text')
                error_details += f"\nResponse: {response_text}"
            logger.warning(f"Embedding API error (attempt {attempt + 1}/{attempts}):\n{error_details}")
            time.sleep(2 ** attempt)  # Exponential backoff: 1s, 2s, 4s, ...

def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Process a batch of documents to generate embeddings.

    Documents are skipped when they already contain ``embedding_field``, when
    they lack ``field_name``, or when the ``field_name`` value is not a string.

    Args:
        docs: Documents to embed (each must have an ``_id`` key if embedded).
        field_name: Key whose string value is sent to the embedding API.
        embedding_field: Key that, if already present, marks a document as done.
        openai_client: Client forwarded to ``get_embedding``.

    Returns:
        ``(doc["_id"], embedding)`` pairs for every document that was embedded.

    Raises:
        Any exception ``get_embedding`` raises once its retries are exhausted.
    """
    logger.info(f"Processing batch of {len(docs)} documents")
    results = []
    for doc in docs:
        # Skip if embedding already exists
        if embedding_field in doc:
            continue

        # .get(): a document missing the source field used to raise KeyError,
        # aborting the batch and discarding embeddings already generated for it.
        text = doc.get(field_name)
        if not isinstance(text, str):
            continue

        results.append((doc["_id"], get_embedding(text, openai_client)))
    return results

def process_futures(futures: List, collection: Collection, embedding_field: str, processed: int, total_docs: int, callback=None, timeout: float = 30) -> int:
    """Process completed futures and update progress.

    Each future's result is a list of ``(doc_id, embedding)`` pairs; every pair
    becomes an ``UpdateOne`` that ``$set``s ``embedding_field`` on that ``_id``.

    Args:
        futures: Futures to wait on, in any order.
        collection: Collection the bulk updates are written to.
        embedding_field: Field that receives the embedding.
        processed: Documents already written before this call (for progress).
        total_docs: Denominator for the progress percentage.
        callback: Optional ``callback(percent, done, total)`` called after each
            successful bulk write.
        timeout: Seconds ``as_completed`` waits for all futures (``None`` waits
            indefinitely). Defaults to the previous hard-coded 30 seconds.

    Returns:
        Number of update operations written by this call.

    Raises:
        concurrent.futures.TimeoutError: raised by ``as_completed`` when not all
            futures finish within ``timeout``. Writes already done are kept.
            Errors from an individual future, write, or callback are logged and
            do not raise.
    """
    completed = 0
    for future in as_completed(futures, timeout=timeout):
        try:
            results = future.result()
            if not results:
                continue

            bulk_ops = [
                UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
                for doc_id, embedding in results
            ]
            collection.bulk_write(bulk_ops)
            completed += len(bulk_ops)

            # Update progress
            if callback:
                done = processed + completed
                # Guard: total_docs == 0 used to raise ZeroDivisionError here,
                # which was then misreported as a failed future.
                progress = (done / total_docs) * 100 if total_docs else 100.0
                callback(progress, done, total_docs)
        except Exception as e:
            error_details = f"{type(e).__name__}: {str(e)}"
            if hasattr(e, 'response'):
                response_text = getattr(e.response, 'text', 'No response text')
                error_details += f"\nResponse: {response_text}"
            logger.error(f"Error processing future:\n{error_details}")
    return completed

def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 10,  # Reduced initial batch size
    callback=None
) -> int:
    """Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching and dynamic processing.

    Documents from ``cursor`` are grouped into batches and submitted to
    ``process_batch`` on a thread pool. Once ``max_workers`` batches are
    pending they are drained through ``process_futures``, which writes the
    embeddings to ``collection``. After a drain that wrote something, the batch
    size (up to 30) and the pending-batch limit (up to 20) grow. After a drain
    that raised, both shrink (down to 5).

    Args:
        collection: Collection the embeddings are written to.
        cursor: Iterable of documents to embed.
        field_name: Document key holding the text to embed.
        embedding_field: Document key that receives the embedding.
        openai_client: Client forwarded to ``process_batch``.
        total_docs: Expected document count, used for progress reporting.
        batch_size: Initial number of documents per submitted batch.
        callback: Optional ``callback(percent, done, total)`` progress hook.

    Returns:
        Number of documents whose embeddings were written.
    """
    if total_docs == 0:
        return 0

    processed = 0
    current_batch_size = batch_size
    max_workers = 10  # Start with fewer workers; current limit on pending batches
    # The pool is sized for the largest limit the adaptive logic can reach, so
    # raising ``max_workers`` really increases concurrency. Previously the pool
    # was fixed at 10 threads and the "increased workers" step had no effect.
    worker_cap = 20

    logger.info(f"Starting embedding generation for {total_docs} documents")
    if callback:
        callback(0, 0, total_docs)

    with ThreadPoolExecutor(max_workers=worker_cap) as executor:
        batch = []
        futures = []

        for doc in cursor:
            batch.append(doc)
            if len(batch) < current_batch_size:
                continue

            logger.info(f"Submitting batch of {len(batch)} documents (batch size: {current_batch_size})")
            # ``batch`` is rebound (never mutated) afterwards, so no copy is needed.
            futures.append(executor.submit(process_batch, batch, field_name, embedding_field, openai_client))
            batch = []

            # Process completed futures more frequently
            if len(futures) < max_workers:
                continue

            try:
                completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
            except Exception as e:
                logger.error(f"Error processing futures: {str(e)}")
                # Reduce batch size and workers on error
                current_batch_size = max(5, current_batch_size - 5)
                max_workers = max(5, max_workers - 2)
                logger.info(f"Reduced batch size to {current_batch_size}, workers to {max_workers}")
            else:
                processed += completed
                # Gradually increase batch size and workers if processing is successful
                if completed > 0:
                    current_batch_size = min(current_batch_size + 5, 30)
                    max_workers = min(max_workers + 2, worker_cap)
                    logger.info(f"Increased batch size to {current_batch_size}, workers to {max_workers}")
            finally:
                # Always start a fresh list. A failed drain used to leave these
                # futures in place, so the next drain re-wrote and re-counted
                # batches that were already saved. Results of batches still
                # running at a timeout are not written here. A re-run can pick
                # those documents up, because process_batch skips documents
                # that already carry embedding_field.
                futures = []

        # Process remaining batch
        if batch:
            logger.info(f"Processing final batch of {len(batch)} documents")
            futures.append(executor.submit(process_batch, batch, field_name, embedding_field, openai_client))

        # Process remaining futures
        if futures:
            try:
                completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
                processed += completed
            except Exception as e:
                logger.error(f"Error processing final futures: {str(e)}")

    logger.info(f"Completed embedding generation. Processed {processed}/{total_docs} documents")
    return processed