# mongo-vector-search-util / embedding_utils.py
from typing import List, Tuple
from concurrent.futures import ThreadPoolExecutor
from pymongo import UpdateOne
from pymongo.collection import Collection
import math
def get_embedding(text: str, openai_client, model="text-embedding-ada-002") -> list[float]:
    """Return the embedding vector for *text* from the OpenAI embeddings API.

    Newlines in the input are replaced with single spaces before the
    request is made.

    Args:
        text: The text to embed.
        openai_client: OpenAI client instance exposing ``embeddings.create``.
        model: Embedding model name to use.

    Returns:
        The embedding vector for the (single) input text.
    """
    cleaned = " ".join(text.split("\n"))
    response = openai_client.embeddings.create(input=[cleaned], model=model)
    first_item = response.data[0]
    return first_item.embedding
def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Generate embeddings for a batch of documents.

    Args:
        docs: Documents to process.
        field_name: Field containing the text to embed.
        embedding_field: Field under which embeddings are stored; documents
            that already have this field are skipped.
        openai_client: OpenAI client instance passed through to get_embedding.

    Returns:
        List of ``(document _id, embedding vector)`` pairs for the documents
        that were actually embedded.
    """
    results: List[Tuple[str, list]] = []
    for doc in docs:
        # Skip documents that already carry an embedding.
        if embedding_field in doc:
            continue
        # .get() tolerates documents that lack the text field entirely
        # (direct indexing raised KeyError and aborted the whole batch);
        # None falls through the isinstance check below and is skipped.
        text = doc.get(field_name)
        # Only embed string payloads; ignore missing / non-string values.
        if isinstance(text, str):
            results.append((doc["_id"], get_embedding(text, openai_client)))
    return results
def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 20,
    callback=None
) -> int:
    """Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching.

    Args:
        collection: MongoDB collection receiving the bulk updates.
        cursor: MongoDB cursor (any iterable of documents) to process.
        field_name: Field containing text to embed.
        embedding_field: Field to store embeddings.
        openai_client: OpenAI client instance.
        total_docs: Total number of documents to process (used for progress
            percentages; progress may stay below 100 if documents are skipped).
        batch_size: Size of batches for parallel processing.
        callback: Optional ``callback(progress_pct, processed, total_docs)``
            hook for progress updates.

    Returns:
        Number of documents actually updated with new embeddings.
    """

    def _flush(results: List[Tuple[str, list]]) -> int:
        """Bulk-write one worker's (doc_id, embedding) pairs; return the count.

        Centralizes the bulk-update logic that was previously duplicated in
        the mid-stream drain and the final drain.
        """
        if not results:
            return 0
        bulk_ops = [
            UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
            for doc_id, embedding in results
        ]
        collection.bulk_write(bulk_ops)
        return len(bulk_ops)

    if total_docs == 0:
        return 0

    processed = 0
    # Initial progress update.
    if callback:
        callback(0, 0, total_docs)

    # Process documents in batches using the cursor so the whole result set
    # is never materialized in memory at once.
    with ThreadPoolExecutor(max_workers=20) as executor:
        batch: List[dict] = []
        futures = []
        for doc in cursor:
            batch.append(doc)
            if len(batch) >= batch_size:
                # Submit batch for processing (copy: `batch` is reused below).
                futures.append(
                    executor.submit(process_batch, batch.copy(), field_name, embedding_field, openai_client)
                )
                batch = []
                # Eagerly drain futures that have already finished to keep
                # memory bounded while the cursor is still streaming.
                done = [f for f in futures if f.done()]
                for future in done:
                    n = _flush(future.result())
                    if n:
                        processed += n
                        if callback:
                            callback((processed / total_docs) * 100, processed, total_docs)
                # BUGFIX: filter against the snapshot we actually drained.
                # The original re-checked f.done() here, so a future that
                # completed *during* the drain above was removed from the
                # list without its results ever being written (lost updates).
                futures = [f for f in futures if f not in done]
        # Process any remaining documents in the last (partial) batch.
        if batch:
            futures.append(
                executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
            )
        # Wait for the remaining futures to complete and write their results.
        for future in futures:
            n = _flush(future.result())
            if n:
                processed += n
                if callback:
                    callback((processed / total_docs) * 100, processed, total_docs)
    return processed