gakim committed · Commit a6fb103 · 1 Parent(s): 3f835a5

embedding migration script

src/know_lang_bot/utils/migration/embedding_migrations.py ADDED
@@ -0,0 +1,286 @@
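"""Migrate ChromaDB embeddings to a new collection via the OpenAI Batch API.

The migration runs in three steps (driven by main()):
1. prepare_batches: export documents from the source collection into JSONL request files.
2. submit_batches: upload each file and create a batch embedding job.
3. process_batch_results: download completed results into a new ChromaDB collection.
"""
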
import asyncio
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import chromadb
from chromadb.errors import InvalidCollectionException
from openai import OpenAI
from rich.console import Console
from rich.progress import Progress

from know_lang_bot.config import AppConfig
from know_lang_bot.utils.fancy_log import FancyLogger

LOG = FancyLogger(__name__)
console = Console()

BATCH_SIZE = 2000  # Max requests per batch file
MAX_CHARS_PER_CHUNK = 10000  # Conservative character cap to stay under the ~8k-token embedding input limit

class BatchState:
    """Track batch processing state on disk."""

    def __init__(self, root_dir: Path):
        self.root_dir = root_dir
        self.batch_dir = root_dir / "batches"
        self.results_dir = root_dir / "results"
        self.metadata_dir = root_dir / "metadata"

        # Create directories
        for directory in (self.batch_dir, self.results_dir, self.metadata_dir):
            directory.mkdir(parents=True, exist_ok=True)

    def save_batch_metadata(self, batch_id: str, metadata: Dict):
        """Save batch processing metadata"""
        with open(self.metadata_dir / f"{batch_id}.json", "w") as f:
            json.dump(metadata, f, indent=2)

def truncate_chunk(text: str, max_chars: int = MAX_CHARS_PER_CHUNK) -> str:
    """Truncate text to the approximate token limit while preserving structure"""
    if len(text) <= max_chars:
        return text

    # Split into CODE and SUMMARY sections
    parts = text.split("\nSUMMARY:\n")
    if len(parts) != 2:
        # If the structure is not found, just truncate
        return text[:max_chars]

    code, summary = parts

    # Allocate the available space proportionally between the two sections
    total_len = len(code) + len(summary)
    code_ratio = len(code) / total_len

    code_chars = int(max_chars * code_ratio)
    summary_chars = max_chars - code_chars

    truncated_code = code[:code_chars]
    truncated_summary = summary[:summary_chars]

    return f"{truncated_code}\nSUMMARY:\n{truncated_summary}"

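# Example: a 25,000-character chunk that splits into 20,000 characters of code
# and 5,000 of summary is cut to 8,000 and 2,000 characters respectively,
# preserving the original 80/20 ratio under the 10,000-character cap.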
async def prepare_batches(config: AppConfig, batch_state: BatchState) -> List[str]:
    """Prepare batch files from ChromaDB and return batch IDs"""
    source_client = chromadb.PersistentClient(path=str(config.db.persist_directory))
    source_collection = source_client.get_collection(name=config.db.collection_name)

    # Get all documents
    results = source_collection.get(include=['documents', 'metadatas'])

    if not results['ids']:
        console.print("[red]No documents found in source collection!")
        return []
    total_documents = len(results['ids'])

    batch_ids = []
    with Progress() as progress:
        task = progress.add_task("Preparing batches...", total=total_documents)

        current_batch = []
        current_batch_ids = []
        current_batch_num = 0

        for i, (doc_id, doc, _metadata) in enumerate(zip(
            results['ids'],
            results['documents'],
            results['metadatas']
        )):
            # Truncate document if needed
            truncated_doc = truncate_chunk(doc)

            current_batch.append((doc_id, truncated_doc))
            current_batch_ids.append(doc_id)

            # Create batch file when size limit reached or at end
            if len(current_batch) >= BATCH_SIZE or i == total_documents - 1:
                batch_file = batch_state.batch_dir / f"batch_{current_batch_num}.jsonl"

                with open(batch_file, 'w') as f:
                    for bid, bdoc in current_batch:
                        request = {
                            "custom_id": bid,
                            "method": "POST",
                            "url": "/v1/embeddings",
                            "body": {
                                "model": config.embedding.model_name,
                                "input": bdoc
                            }
                        }
                        f.write(json.dumps(request) + '\n')

                # Save batch metadata
                batch_metadata = {
                    "batch_id": f"batch_{current_batch_num}",
                    "created_at": datetime.now().isoformat(),
                    "document_ids": current_batch_ids,
                    "size": len(current_batch),
                    "status": "prepared"
                }
                batch_state.save_batch_metadata(f"batch_{current_batch_num}", batch_metadata)

                batch_ids.append(f"batch_{current_batch_num}")
                current_batch = []
                current_batch_ids = []
                current_batch_num += 1

            progress.advance(task)

    return batch_ids

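# Each line in a prepared batch file is one self-contained embedding request.
# Illustrative example (custom_id and model name are placeholders; the model
# comes from config.embedding.model_name):
# {"custom_id": "chunk-42", "method": "POST", "url": "/v1/embeddings",
#  "body": {"model": "text-embedding-3-small", "input": "...\nSUMMARY:\n..."}}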
async def submit_batches(batch_state: BatchState, batch_ids: List[str]):
    """Submit prepared batches to OpenAI"""
    client = OpenAI()

    with Progress() as progress:
        task = progress.add_task("Submitting batches...", total=len(batch_ids))

        for batch_id in batch_ids:
            batch_file = batch_state.batch_dir / f"{batch_id}.jsonl"

            # Upload batch file
            with open(batch_file, "rb") as fh:
                file = client.files.create(
                    file=fh,
                    purpose="batch"
                )

            # Create batch job
            batch = client.batches.create(
                input_file_id=file.id,
                endpoint="/v1/embeddings",
                completion_window="24h"
            )

            # Update metadata
            with open(batch_state.metadata_dir / f"{batch_id}.json", "r") as f:
                metadata = json.load(f)

            metadata.update({
                "openai_batch_id": batch.id,
                "file_id": file.id,
                "status": "submitted",
                "submitted_at": datetime.now().isoformat()
            })

            batch_state.save_batch_metadata(batch_id, metadata)
            progress.advance(task)

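# Batch jobs complete asynchronously within the 24h window. Each line of a
# finished job's output file pairs the original custom_id with the embedding
# response; abbreviated, illustrative shape:
# {"custom_id": "chunk-42", "response": {"status_code": 200,
#  "body": {"data": [{"embedding": [0.0123, ...]}]}}}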
async def process_batch_results(
    batch_state: BatchState,
    config: AppConfig,
    batch_ids: Optional[List[str]] = None
):
    """Process completed batches and store in new ChromaDB"""
    client = OpenAI()

    # Initialize target DB
    target_path = Path(config.db.persist_directory).parent / "batch_embeddings_db"
    target_path.mkdir(exist_ok=True)
    target_client = chromadb.PersistentClient(path=str(target_path))

    # Create or get collection
    new_collection_name = f"{config.db.collection_name}_batch"
    try:
        target_collection = target_client.get_collection(name=new_collection_name)
        console.print(f"[yellow]Collection {new_collection_name} exists, appending...")
    except InvalidCollectionException:
        target_collection = target_client.create_collection(
            name=new_collection_name,
            metadata={"hnsw:space": "cosine"}
        )

    # Source DB for the original documents and metadata
    source_client = chromadb.PersistentClient(path=str(config.db.persist_directory))
    source_collection = source_client.get_collection(name=config.db.collection_name)

    # Process each batch
    if batch_ids is None:
        batch_ids = [f.stem for f in batch_state.metadata_dir.glob("*.json")]

    with Progress() as progress:
        task = progress.add_task("Processing results...", total=len(batch_ids))

        for batch_id in batch_ids:
            # Load batch metadata
            with open(batch_state.metadata_dir / f"{batch_id}.json", "r") as f:
                metadata = json.load(f)

            if metadata["status"] != "submitted":
                console.print(f"[yellow]Skipping {batch_id} - not submitted")
                progress.advance(task)
                continue

            # Check batch status
            batch_status = client.batches.retrieve(metadata["openai_batch_id"])
            if batch_status.status != "completed":
                console.print(f"[yellow]Batch {batch_id} not complete, status: {batch_status.status}")
                progress.advance(task)
                continue

            # Download results
            output_file = batch_state.results_dir / f"{batch_id}_output.jsonl"
            response = client.files.content(batch_status.output_file_id)
            with open(output_file, "wb") as f:
                f.write(response.read())

            # Get original documents and metadata
            results = source_collection.get(
                ids=metadata["document_ids"],
                include=['documents', 'metadatas']
            )
            # Map ids to their position in the fetched results, since get()
            # is not guaranteed to return items in the requested order
            id_to_idx = {doc_id: idx for idx, doc_id in enumerate(results["ids"])}

            # Process results file
            embeddings = []
            processed_ids = []
            processed_docs = []
            processed_metadatas = []

            with open(output_file) as f:
                for line in f:
                    result = json.loads(line)
                    if result["response"]["status_code"] == 200:
                        doc_idx = id_to_idx[result["custom_id"]]

                        embeddings.append(result["response"]["body"]["data"][0]["embedding"])
                        processed_ids.append(result["custom_id"])
                        processed_docs.append(results["documents"][doc_idx])
                        processed_metadatas.append(results["metadatas"][doc_idx])
                    else:
                        LOG.warning(f"Embedding request {result['custom_id']} failed with status {result['response']['status_code']}")

            # Add to new collection
            if processed_ids:
                target_collection.add(
                    embeddings=embeddings,
                    documents=processed_docs,
                    metadatas=processed_metadatas,
                    ids=processed_ids
                )

            # Update metadata
            metadata.update({
                "status": "processed",
                "processed_at": datetime.now().isoformat(),
                "processed_count": len(processed_ids)
            })
            batch_state.save_batch_metadata(batch_id, metadata)

            progress.advance(task)

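# process_batch_results is safe to re-run: jobs that are not yet "completed"
# are left in the "submitted" state and retried on the next pass, while
# finished batches are marked "processed" and skipped afterwards.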
async def main():
    config = AppConfig()
    batch_state = BatchState(Path("embedding_migration"))

    # Step 1: Prepare batches
    console.print("[green]Step 1: Preparing batches...")
    batch_ids = await prepare_batches(config, batch_state)

    # Step 2: Submit batches
    console.print("\n[green]Step 2: Submitting batches...")
    await submit_batches(batch_state, batch_ids)

    # Step 3: Process results
    console.print("\n[green]Step 3: Processing results...")
    await process_batch_results(batch_state, config)

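# Assumed invocation given the file location, e.g.:
#   python -m know_lang_bot.utils.migration.embedding_migrations
# OpenAI() reads OPENAI_API_KEY from the environment.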
if __name__ == "__main__":
    asyncio.run(main())