File size: 4,547 Bytes
8fb6e2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from typing import List, Tuple
from concurrent.futures import ThreadPoolExecutor
from pymongo import UpdateOne
from pymongo.collection import Collection
import math

def get_embedding(text: str, openai_client, model="text-embedding-ada-002") -> list[float]:
    """Return the embedding vector for *text* via the OpenAI embeddings API.

    Newlines are replaced with spaces before the request is sent.

    Args:
        text: Text to embed.
        openai_client: OpenAI client instance exposing ``embeddings.create``.
        model: Embedding model name.

    Returns:
        The embedding as a list of floats.
    """
    cleaned = text.replace("\n", " ")
    response = openai_client.embeddings.create(input=[cleaned], model=model)
    return response.data[0].embedding

def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Generate embeddings for a batch of documents.

    A document is skipped when it already carries *embedding_field*, when it
    has no *field_name* key, or when the field's value is not a string.

    Args:
        docs: Documents to process.
        field_name: Name of the field containing the text to embed.
        embedding_field: Name of the field that stores the embedding.
        openai_client: OpenAI client instance.

    Returns:
        List of ``(_id, embedding)`` tuples for the documents embedded.
    """
    results = []
    for doc in docs:
        # Skip documents embedded by a previous run.
        if embedding_field in doc:
            continue

        # .get() instead of [] so documents missing the text field are
        # skipped rather than raising KeyError mid-batch.
        text = doc.get(field_name)
        if isinstance(text, str):
            embedding = get_embedding(text, openai_client)
            results.append((doc["_id"], embedding))
    return results

def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 20,
    callback=None
) -> int:
    """Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching.

    Args:
        collection: MongoDB collection
        cursor: MongoDB cursor for document iteration
        field_name: Field containing text to embed
        embedding_field: Field to store embeddings
        openai_client: OpenAI client instance
        total_docs: Total number of documents to process
        batch_size: Size of batches for parallel processing
        callback: Optional ``callback(progress_pct, processed, total_docs)``
            for progress updates

    Returns:
        Number of documents processed
    """
    if total_docs == 0:
        return 0

    processed = 0

    def _write_results(results) -> None:
        """Bulk-write one batch of (doc_id, embedding) pairs and report progress."""
        nonlocal processed
        if not results:
            return
        bulk_ops = [
            UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
            for doc_id, embedding in results
        ]
        collection.bulk_write(bulk_ops)
        processed += len(bulk_ops)
        if callback:
            callback((processed / total_docs) * 100, processed, total_docs)

    # Initial progress update
    if callback:
        callback(0, 0, total_docs)

    with ThreadPoolExecutor(max_workers=20) as executor:
        batch = []
        futures = []

        # Iterate through cursor and build batches
        for doc in cursor:
            batch.append(doc)

            if len(batch) >= batch_size:
                futures.append(
                    executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
                )
                batch = []  # Rebind (don't mutate) so the submitted batch stays intact

                # Drain already-finished futures to bound memory. Snapshot the
                # done set ONCE and filter against it: re-testing f.done() in a
                # second pass could drop a future that finished between the two
                # checks, silently losing its results.
                done = [f for f in futures if f.done()]
                for finished in done:
                    _write_results(finished.result())
                futures = [f for f in futures if f not in done]

        # Process any remaining documents in the last (partial) batch
        if batch:
            futures.append(
                executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
            )

        # Wait for all outstanding futures and write their results
        for future in futures:
            _write_results(future.result())

    return processed