|
from typing import List, Tuple |
|
from concurrent.futures import ThreadPoolExecutor |
|
from pymongo import UpdateOne |
|
from pymongo.collection import Collection |
|
import math |
|
|
|
def get_embedding(text: str, openai_client, model="text-embedding-ada-002") -> list[float]:
    """Return the embedding vector for ``text`` from the OpenAI embeddings endpoint.

    Newlines in ``text`` are replaced by single spaces before the request is
    sent. The text is submitted as a one-element input list and the embedding
    of the first (only) returned item is handed back.

    Args:
        text: Text to embed.
        openai_client: Client object exposing ``embeddings.create(input=..., model=...)``.
        model: Name of the embedding model to request.

    Returns:
        The ``embedding`` attribute of the first entry in the response's ``data``.
    """
    single_line = text.replace("\n", " ")
    response = openai_client.embeddings.create(model=model, input=[single_line])
    first_item = response.data[0]
    return first_item.embedding
|
|
|
def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Generate embeddings for the documents in a batch that still need one.

    A document is skipped when it already contains ``embedding_field`` (whatever
    its value), or when its ``field_name`` value is missing or not a string.

    Args:
        docs: Documents to inspect; each embedded document must have an ``_id``.
        field_name: Key holding the text to embed.
        embedding_field: Key under which embeddings are stored; its presence
            marks a document as already processed.
        openai_client: Client passed through to ``get_embedding``.

    Returns:
        A list of ``(_id, embedding)`` pairs, one per document that was embedded,
        in the order the documents were given.
    """
    results = []
    for doc in docs:
        # Already embedded: nothing to do for this document.
        if embedding_field in doc:
            continue

        # Use .get() so a document without the text field is skipped like any
        # other non-string value instead of raising KeyError and failing the
        # whole batch.
        text = doc.get(field_name)
        if isinstance(text, str):
            results.append((doc["_id"], get_embedding(text, openai_client)))
    return results
|
|
|
def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 20,
    callback=None,
    max_workers: int = 20,
) -> int:
    """Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching.

    Documents are read from ``cursor``, grouped into batches of ``batch_size``
    and handed to ``process_batch`` on a thread pool. Finished batches are
    written back with a single ``bulk_write`` of ``$set`` updates each. All
    database writes and callback invocations happen on the calling thread.

    Args:
        collection: MongoDB collection that receives the embedding updates.
        cursor: Iterable of documents to process.
        field_name: Field containing text to embed.
        embedding_field: Field to store embeddings.
        openai_client: OpenAI client instance.
        total_docs: Total number of documents to process; used only as the
            denominator for progress reporting. ``0`` returns immediately.
        batch_size: Number of documents per submitted batch.
        callback: Optional ``callback(percent, processed, total_docs)``. Called
            once with ``(0, 0, total_docs)`` before work starts and then after
            each finished batch has been written.
        max_workers: Number of worker threads in the pool.

    Returns:
        Number of documents whose embedding was written. Documents skipped by
        ``process_batch`` are not counted, so the reported percentage can stay
        below 100.

    Raises:
        Any exception raised inside a worker is re-raised here when that
        batch's result is collected.
    """
    if total_docs == 0:
        return 0

    processed = 0

    if callback:
        callback(0, 0, total_docs)

    def flush(future) -> None:
        """Persist one finished batch and report progress."""
        nonlocal processed
        results = future.result()
        if results:
            bulk_ops = [
                UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
                for doc_id, embedding in results
            ]
            collection.bulk_write(bulk_ops)
            processed += len(bulk_ops)
        if callback:
            callback((processed / total_docs) * 100, processed, total_docs)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        batch = []
        futures = []

        for doc in cursor:
            batch.append(doc)
            if len(batch) < batch_size:
                continue

            # The worker keeps this list; a fresh one is started below, so no
            # defensive copy is needed.
            futures.append(
                executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
            )
            batch = []

            # Drain whatever has finished so far. done() is evaluated exactly
            # once per future: checking it a second time to rebuild the pending
            # list would silently drop any batch that completed in between.
            still_pending = []
            for future in futures:
                if future.done():
                    flush(future)
                else:
                    still_pending.append(future)
            futures = still_pending

        # Submit the final, partially filled batch.
        if batch:
            futures.append(
                executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
            )

        # Wait for and persist everything still outstanding.
        for future in futures:
            flush(future)

    return processed
|
|