airabbitX committed
Commit bbccab6 · verified · 1 Parent(s): 787933d

Upload 7 files

Files changed (2)
  1. embedding_utils.py +90 -69
  2. run.sh +1 -0
embedding_utils.py CHANGED
@@ -1,20 +1,38 @@
 from typing import List, Tuple
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pymongo import UpdateOne
 from pymongo.collection import Collection
 import math
+import time
+import logging
 
-def get_embedding(text: str, openai_client, model="text-embedding-ada-002") -> list[float]:
-    """Get embeddings for given text using OpenAI API"""
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def get_embedding(text: str, openai_client, model="text-embedding-ada-002", max_retries=3) -> list[float]:
+    """Get embeddings for given text using OpenAI API with retry logic"""
     text = text.replace("\n", " ")
-    resp = openai_client.embeddings.create(
-        input=[text],
-        model=model
-    )
-    return resp.data[0].embedding
+
+    for attempt in range(max_retries):
+        try:
+            resp = openai_client.embeddings.create(
+                input=[text],
+                model=model
+            )
+            return resp.data[0].embedding
+        except Exception as e:
+            if attempt == max_retries - 1:
+                raise
+            error_details = f"{type(e).__name__}: {str(e)}"
+            if hasattr(e, 'response'):
+                error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
+            logger.warning(f"Embedding API error (attempt {attempt + 1}/{max_retries}):\n{error_details}")
+            time.sleep(2 ** attempt)  # Exponential backoff
 
 def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
     """Process a batch of documents to generate embeddings"""
+    logger.info(f"Processing batch of {len(docs)} documents")
     results = []
     for doc in docs:
         # Skip if embedding already exists
@@ -27,6 +45,32 @@ def process_batch(docs: List[dict], field_name: str, embedding_field: str, opena
         results.append((doc["_id"], embedding))
     return results
 
+def process_futures(futures: List, collection: Collection, embedding_field: str, processed: int, total_docs: int, callback=None) -> int:
+    """Process completed futures and update progress"""
+    completed = 0
+    for future in as_completed(futures, timeout=30):  # 30 second timeout
+        try:
+            results = future.result()
+            if results:
+                bulk_ops = [
+                    UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
+                    for doc_id, embedding in results
+                ]
+                if bulk_ops:
+                    collection.bulk_write(bulk_ops)
+                    completed += len(bulk_ops)
+
+                    # Update progress
+                    if callback:
+                        progress = ((processed + completed) / total_docs) * 100
+                        callback(progress, processed + completed, total_docs)
+        except Exception as e:
+            error_details = f"{type(e).__name__}: {str(e)}"
+            if hasattr(e, 'response'):
+                error_details += f"\nResponse: {e.response.text if hasattr(e.response, 'text') else 'No response text'}"
+            logger.error(f"Error processing future:\n{error_details}")
+    return completed
+
 def parallel_generate_embeddings(
     collection: Collection,
     cursor,
@@ -34,89 +78,66 @@ def parallel_generate_embeddings(
     embedding_field: str,
     openai_client,
     total_docs: int,
-    batch_size: int = 20,
+    batch_size: int = 10,  # Reduced initial batch size
     callback=None
 ) -> int:
-    """Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching
-
-    Args:
-        collection: MongoDB collection
-        cursor: MongoDB cursor for document iteration
-        field_name: Field containing text to embed
-        embedding_field: Field to store embeddings
-        openai_client: OpenAI client instance
-        total_docs: Total number of documents to process
-        batch_size: Size of batches for parallel processing
-        callback: Optional callback function for progress updates
-
-    Returns:
-        Number of documents processed
-    """
+    """Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching and dynamic processing"""
     if total_docs == 0:
         return 0
 
     processed = 0
+    current_batch_size = batch_size
+    max_workers = 5  # Start with fewer workers
 
-    # Initial progress update
+    logger.info(f"Starting embedding generation for {total_docs} documents")
     if callback:
         callback(0, 0, total_docs)
 
-    # Process documents in batches using cursor
-    with ThreadPoolExecutor(max_workers=20) as executor:
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
         batch = []
         futures = []
 
-        # Iterate through cursor and build batches
         for doc in cursor:
            batch.append(doc)
 
-            if len(batch) >= batch_size:
-                # Submit batch for processing
+            if len(batch) >= current_batch_size:
+                logger.info(f"Submitting batch of {len(batch)} documents (batch size: {current_batch_size})")
                 future = executor.submit(process_batch, batch.copy(), field_name, embedding_field, openai_client)
                 futures.append(future)
-                batch = []  # Clear batch for next round
+                batch = []
 
-                # Process completed futures to free up memory
-                completed_futures = [f for f in futures if f.done()]
-                for future in completed_futures:
-                    results = future.result()
-                    if results:
-                        # Batch update MongoDB
-                        bulk_ops = [
-                            UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
-                            for doc_id, embedding in results
-                        ]
-                        if bulk_ops:
-                            collection.bulk_write(bulk_ops)
-                            processed += len(bulk_ops)
-
-                            # Update progress
-                            if callback:
-                                progress = (processed / total_docs) * 100
-                                callback(progress, processed, total_docs)
-
-                futures = [f for f in futures if not f.done()]
+                # Process completed futures more frequently
+                if len(futures) >= max_workers:
+                    try:
+                        completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
+                        processed += completed
+                        futures = []  # Clear processed futures
+
+                        # Gradually increase batch size and workers if processing is successful
+                        if completed > 0:
+                            current_batch_size = min(current_batch_size + 5, 30)
+                            max_workers = min(max_workers + 2, 20)
+                            logger.info(f"Increased batch size to {current_batch_size}, workers to {max_workers}")
+                    except Exception as e:
+                        logger.error(f"Error processing futures: {str(e)}")
+                        # Reduce batch size and workers on error
+                        current_batch_size = max(5, current_batch_size - 5)
+                        max_workers = max(3, max_workers - 2)
+                        logger.info(f"Reduced batch size to {current_batch_size}, workers to {max_workers}")
 
-        # Process any remaining documents in the last batch
+        # Process remaining batch
        if batch:
+            logger.info(f"Processing final batch of {len(batch)} documents")
             future = executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
             futures.append(future)
 
-        # Wait for remaining futures to complete
-        for future in futures:
-            results = future.result()
-            if results:
-                bulk_ops = [
-                    UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
-                    for doc_id, embedding in results
-                ]
-                if bulk_ops:
-                    collection.bulk_write(bulk_ops)
-                    processed += len(bulk_ops)
-
-    # Final progress update
-    if callback:
-        progress = (processed / total_docs) * 100
-        callback(progress, processed, total_docs)
-
+        # Process remaining futures
+        if futures:
+            try:
+                completed = process_futures(futures, collection, embedding_field, processed, total_docs, callback)
+                processed += completed
+            except Exception as e:
+                logger.error(f"Error processing final futures: {str(e)}")
+
+    logger.info(f"Completed embedding generation. Processed {processed}/{total_docs} documents")
     return processed
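
For context, a minimal driver sketch showing how the updated parallel_generate_embeddings might be wired up. The Mongo URI, database and collection names, and the "text"/"embedding" field names below are illustrative assumptions, not part of this commit; OpenAI() is the v1-style client whose embeddings.create call the module already uses.

# usage_example.py -- hypothetical driver, not part of this commit
from pymongo import MongoClient
from openai import OpenAI

from embedding_utils import parallel_generate_embeddings

# Assumed connection details and field names; adjust to your deployment.
mongo = MongoClient("mongodb://localhost:27017")
collection = mongo["mydb"]["articles"]
openai_client = OpenAI()  # reads OPENAI_API_KEY from the environment

# Only embed documents that do not have an embedding yet.
query = {"embedding": {"$exists": False}}
total_docs = collection.count_documents(query)
cursor = collection.find(query)

def on_progress(progress, processed, total):
    # Matches the callback(progress, processed, total_docs) convention above.
    print(f"{progress:.1f}% ({processed}/{total})")

processed = parallel_generate_embeddings(
    collection=collection,
    cursor=cursor,
    field_name="text",            # assumed field holding the raw text
    embedding_field="embedding",  # field the vectors are written into
    openai_client=openai_client,
    total_docs=total_docs,
    callback=on_progress,
)
print(f"Embedded {processed} documents")

With the default max_retries=3, a failing embeddings.create call sleeps 2 ** attempt seconds between tries (1 s, then 2 s) and re-raises on the third failure, so transient rate-limit errors are absorbed without aborting the whole run.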
run.sh ADDED
@@ -0,0 +1 @@
+python app.py