AbstractPhil committed
Commit c71cf49 · verified · 1 Parent(s): 85441c2

Create bert_handler.py

Files changed (1)
  1. bert_handler.py +558 -0
bert_handler.py ADDED
@@ -0,0 +1,558 @@
+ import torch
+ import torch.nn as nn
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+ from pathlib import Path
+ import json
+ import re
+ import gc
+
+
+ class BERTHandler:
+     """
+     VRAM-safe BERT model handler for loading, tokenization, and saving
+     Handles all token management and checkpoint operations with proper cleanup
+     """
+
+     def __init__(self, symbolic_tokens=None):
+         # Default symbolic tokens
+         self.symbolic_tokens = symbolic_tokens or [
+             "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
+             "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
+             "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
+             "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
+             "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
+             "<fabric>", "<jewelry>"
+         ]
+
+         # Generate shunt tokens
+         self.shunt_tokens = [f"[SHUNT_{1000000 + i}]" for i in range(len(self.symbolic_tokens))]
+         self.all_special_tokens = self.symbolic_tokens + self.shunt_tokens
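+         # One shunt token is generated per symbolic token, matched by index:
+         # "<subject>" (index 0) pairs with "[SHUNT_1000000]", "<pose>" (index 3) with "[SHUNT_1000003]", etc.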
+
+         # Model components
+         self.tokenizer = None
+         self.model = None
+         self.current_step = 0
+         self.current_epoch = 1
+
+         print(f"🎯 BERTHandler initialized with {len(self.all_special_tokens)} special tokens")
+
+     def __del__(self):
+         """Destructor to ensure cleanup when object is deleted"""
+         self._cleanup_model()
+
+     def _cleanup_model(self):
+         """
+         CRITICAL: Comprehensive model cleanup to free VRAM
+         This is the core method that prevents VRAM accumulation
+         """
+         if hasattr(self, 'model') and self.model is not None:
+             print("🧹 Cleaning up existing model from VRAM...")
+
+             # Move model to CPU first to free GPU memory
+             if torch.cuda.is_available() and next(self.model.parameters(), None) is not None:
+                 if next(self.model.parameters()).is_cuda:
+                     self.model = self.model.cpu()
+
+             # Delete the model
+             del self.model
+             self.model = None
+
+             # Force garbage collection
+             gc.collect()
+
+             # Clear CUDA cache
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+                 torch.cuda.synchronize()  # Ensure all CUDA operations complete
+
+             print("✅ Model cleanup complete")
+
+     def _print_vram_usage(self, prefix=""):
+         """Print current VRAM usage for monitoring"""
+         if torch.cuda.is_available():
+             allocated = torch.cuda.memory_allocated() / 1e9
+             reserved = torch.cuda.memory_reserved() / 1e9
+             print(f"🎯 {prefix}VRAM: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
+         else:
+             print(f"🎯 {prefix}CUDA not available")
+
+     def load_fresh_model(self, model_name="nomic-ai/nomic-bert-2048"):
+         """Load fresh model and add special tokens with proper VRAM management"""
+         print(f"🆕 Loading fresh model: {model_name}")
+         self._print_vram_usage("Before cleanup: ")
+
+         # CRITICAL: Clean up existing model first
+         self._cleanup_model()
+         self._print_vram_usage("After cleanup: ")
+
+         try:
+             # Load base model and tokenizer
+             print("📥 Loading base tokenizer...")
+             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+             print("📥 Loading base model...")
+             self.model = AutoModelForMaskedLM.from_pretrained(
+                 model_name,
+                 trust_remote_code=True,
+                 torch_dtype=torch.float32  # Explicit dtype for consistency
+             )
+
+             # Add special tokens (ONLY for fresh models)
+             original_size = len(self.tokenizer)
+             special_tokens_dict = {"additional_special_tokens": self.all_special_tokens}
+             num_added = self.tokenizer.add_special_tokens(special_tokens_dict)
+
+             print(f" - Original vocab size: {original_size}")
+             print(f" - Added {num_added} special tokens")
+             print(f" - New vocab size: {len(self.tokenizer)}")
+
+             # Resize model embeddings (ONLY for fresh models)
+             if num_added > 0:
+                 self._resize_embeddings()
+
+             # Reset training state
+             self.current_step = 0
+             self.current_epoch = 1
+
+             print("✅ Fresh model loaded successfully")
+             self._print_vram_usage("After loading: ")
+             return self.model, self.tokenizer
+
+         except Exception as e:
+             print(f"❌ Failed to load fresh model: {e}")
+             # Clean up on failure
+             self._cleanup_model()
+             raise
+
+     def load_checkpoint(self, checkpoint_path):
+         """Load model from checkpoint - use saved tokenizer as-is, no modifications"""
+         print(f"📂 Loading checkpoint: {checkpoint_path}")
+         self._print_vram_usage("Before cleanup: ")
+
+         # CRITICAL: Clean up existing model first
+         self._cleanup_model()
+         self._print_vram_usage("After cleanup: ")
+
+         try:
+             # Load saved tokenizer AS-IS (already contains special tokens)
+             print("📥 Loading saved tokenizer...")
+             self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
+             print(f" - Tokenizer loaded: {len(self.tokenizer)} tokens (already includes special tokens)")
+
+             # Load saved model AS-IS (already matches tokenizer)
+             print("📥 Loading saved model...")
+             self.model = AutoModelForMaskedLM.from_pretrained(
+                 checkpoint_path,
+                 trust_remote_code=True,
+                 torch_dtype=torch.float32,
+             )
+
+             print(f"✅ Model loaded successfully")
+             print(f" - Model vocab size: {self.model.config.vocab_size}")
+             print(f" - Embedding size: {self.model.bert.embeddings.word_embeddings.weight.shape[0]}")
+             print(f" - Tokenizer size: {len(self.tokenizer)}")
+
+             # DO NOT MODIFY ANYTHING - checkpoint is self-consistent
+
+             # Load training state
+             self._load_training_state(checkpoint_path)
+
+             print(f"✅ Checkpoint loaded - Step: {self.current_step}, Epoch: {self.current_epoch}")
+             self._print_vram_usage("After loading: ")
+             return self.model, self.tokenizer
+
+         except Exception as e:
+             print(f"❌ Failed to load checkpoint: {e}")
+             # Clean up on failure
+             self._cleanup_model()
+             raise
+
+     def save_checkpoint(self, save_path, step=None, epoch=None):
+         """Save model checkpoint with consistency verification"""
+         if self.model is None or self.tokenizer is None:
+             raise RuntimeError("No model loaded to save")
+
+         step = step or self.current_step
+         epoch = epoch or self.current_epoch
+
+         # CRITICAL: Verify consistency before saving
+         tokenizer_size = len(self.tokenizer)
+         model_vocab_size = self.model.config.vocab_size
+         embedding_size = self.model.bert.embeddings.word_embeddings.weight.shape[0]
+
+         if not (tokenizer_size == model_vocab_size == embedding_size):
+             print(f"⚠️ CONSISTENCY CHECK FAILED before saving:")
+             print(f" - Tokenizer size: {tokenizer_size}")
+             print(f" - Model config vocab_size: {model_vocab_size}")
+             print(f" - Embedding size: {embedding_size}")
+
+             # Force consistency before saving
+             print(f"🔧 Forcing consistency to tokenizer size: {tokenizer_size}")
+             self.model.config.vocab_size = tokenizer_size
+
+             # Resize embeddings if needed
+             if embedding_size != tokenizer_size:
+                 print(f"🔧 Resizing embeddings to match tokenizer: {embedding_size} → {tokenizer_size}")
+                 self._resize_embeddings()
+
+         # Create checkpoint directory
+         checkpoint_dir = Path(save_path) / f"symbolic_bert_step{step}_epoch{epoch}"
+         checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+         print(f"💾 Saving checkpoint: {checkpoint_dir}")
+
+         try:
+             # Save model and tokenizer
+             print("💾 Saving model...")
+             self.model.save_pretrained(checkpoint_dir)
+
+             print("💾 Saving tokenizer...")
+             self.tokenizer.save_pretrained(checkpoint_dir)
+
+             # Save training state with consistency info
+             training_state = {
+                 "step": step,
+                 "epoch": epoch,
+                 "vocab_size": len(self.tokenizer),
+                 "model_vocab_size": self.model.config.vocab_size,
+                 "embedding_size": self.model.bert.embeddings.word_embeddings.weight.shape[0],
+                 "consistency_verified": True,
+                 "special_tokens_count": len(self.all_special_tokens)
+             }
+
+             with open(checkpoint_dir / "training_config.json", "w") as f:
+                 json.dump(training_state, f, indent=2)
+
+             # Save token mappings
+             self._save_token_mappings(checkpoint_dir)
+
+             # VERIFICATION: Load and check consistency
+             print("🔍 Verifying saved checkpoint consistency...")
+             test_tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
+             test_config_path = checkpoint_dir / "config.json"
+
+             with open(test_config_path) as f:
+                 test_config = json.load(f)
+
+             saved_tokenizer_size = len(test_tokenizer)
+             saved_model_vocab = test_config["vocab_size"]
+
+             if saved_tokenizer_size != saved_model_vocab:
+                 raise RuntimeError(
+                     f"CHECKPOINT SAVE FAILED! Inconsistency detected:\n"
+                     f" Saved tokenizer size: {saved_tokenizer_size}\n"
+                     f" Saved model vocab: {saved_model_vocab}"
+                 )
+
+             # Update internal state
+             self.current_step = step
+             self.current_epoch = epoch
+
+             print(f"✅ Checkpoint saved and verified successfully")
+             print(f" - Consistent vocab size: {saved_tokenizer_size}")
+             return checkpoint_dir
+
+         except Exception as e:
+             print(f"❌ Failed to save checkpoint: {e}")
+             raise
+
+     def find_latest_checkpoint(self, base_path, pattern="symbolic_bert"):
+         """Find latest checkpoint in directory"""
+         path = Path(base_path)
+         if not path.exists():
+             print(f"⚠️ Checkpoint directory does not exist: {base_path}")
+             return None
+
+         # Find checkpoints
+         checkpoints = list(path.glob(f"{pattern}_step*_epoch*"))
+         if not checkpoints:
+             print(f"⚠️ No checkpoints found in {base_path}")
+             return None
+
+         # Sort by step number (more reliable than modification time)
+         def extract_step(checkpoint_path):
+             match = re.search(r"step(\d+)", checkpoint_path.name)
+             return int(match.group(1)) if match else 0
+
+         checkpoints.sort(key=extract_step, reverse=True)
+         latest = checkpoints[0]
+
+         print(f"📂 Found latest checkpoint: {latest}")
+         return latest
+
+     def get_token_mappings(self):
+         """Get token ID mappings"""
+         if self.tokenizer is None:
+             return {}, {}
+
+         symbolic_ids = {}
+         shunt_ids = {}
+
+         for token in self.symbolic_tokens:
+             token_id = self.tokenizer.convert_tokens_to_ids(token)
+             if token_id != self.tokenizer.unk_token_id:
+                 symbolic_ids[token] = token_id
+
+         for token in self.shunt_tokens:
+             token_id = self.tokenizer.convert_tokens_to_ids(token)
+             if token_id != self.tokenizer.unk_token_id:
+                 shunt_ids[token] = token_id
+
+         return symbolic_ids, shunt_ids
+
+     def to_device(self, device):
+         """Move model to device with VRAM monitoring"""
+         if self.model is not None:
+             print(f"📱 Moving model to {device}...")
+             self._print_vram_usage("Before device move: ")
+
+             self.model = self.model.to(device)
+
+             # Clear cache after moving to device
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+             print(f"✅ Model moved to {device}")
+             self._print_vram_usage("After device move: ")
+         else:
+             print(f"⚠️ No model loaded to move to {device}")
+         return self
+
+     def _resize_embeddings(self):
+         """Resize model embeddings to match tokenizer (handles both expansion and shrinking)"""
+         if self.model is None:
+             raise RuntimeError("No model loaded")
+
+         old_embeddings = self.model.bert.embeddings.word_embeddings
+         old_size, embedding_dim = old_embeddings.weight.shape
+         new_size = len(self.tokenizer)
+
+         if old_size == new_size:
+             print(f"✅ Embeddings already correct size: {new_size}")
+             return
+
+         print(f"🔄 Resizing embeddings: {old_size} → {new_size}")
+
+         try:
+             # Create new embeddings
+             new_embeddings = nn.Embedding(new_size, embedding_dim)
+
+             # Copy existing embeddings (handle both expansion and shrinking)
+             with torch.no_grad():
+                 # Copy the minimum of old_size and new_size
+                 copy_size = min(old_size, new_size)
+                 new_embeddings.weight.data[:copy_size] = old_embeddings.weight.data[:copy_size].clone()
+
+                 # If expanding, initialize new token embeddings
+                 if new_size > old_size:
+                     num_added = new_size - old_size
+                     # Use small random initialization for new tokens
+                     new_embeddings.weight.data[old_size:] = torch.randn(
+                         num_added, embedding_dim, device=old_embeddings.weight.device
+                     ) * 0.02
+                     print(f" - Added {num_added} new token embeddings")
+                 elif new_size < old_size:
+                     num_removed = old_size - new_size
+                     print(f" - Removed {num_removed} token embeddings")
+
+             # Replace embeddings
+             self.model.bert.embeddings.word_embeddings = new_embeddings
+
+             # Resize decoder if it exists
+             if hasattr(self.model.cls.predictions, "decoder"):
+                 old_decoder = self.model.cls.predictions.decoder
+                 new_decoder = nn.Linear(embedding_dim, new_size, bias=True)
+
+                 with torch.no_grad():
+                     # Copy existing weights (handle both expansion and shrinking)
+                     copy_size = min(old_decoder.weight.shape[0], new_size)
+                     new_decoder.weight.data[:copy_size] = old_decoder.weight.data[:copy_size].clone()
+
+                     # Handle bias
+                     if old_decoder.bias is not None:
+                         new_decoder.bias.data[:copy_size] = old_decoder.bias.data[:copy_size].clone()
+
+                     # If expanding, tie new decoder weights to new embeddings and init bias
+                     if new_size > old_decoder.weight.shape[0]:
+                         start_idx = old_decoder.weight.shape[0]
+                         new_decoder.weight.data[start_idx:] = new_embeddings.weight.data[start_idx:].clone()
+                         if old_decoder.bias is not None:
+                             new_decoder.bias.data[start_idx:] = torch.zeros(
+                                 new_size - start_idx, device=old_decoder.bias.device
+                             )
+
+                 self.model.cls.predictions.decoder = new_decoder
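+                 # The decoder rows appended above are copies of the matching new embedding rows,
+                 # so MLM logits for the added tokens start out aligned with their input embeddings.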
+
+             # Update config
+             self.model.config.vocab_size = new_size
+
+             print(f"✅ Embeddings resized successfully")
+
+         except Exception as e:
+             print(f"❌ Failed to resize embeddings: {e}")
+             raise
+
+     def _load_training_state(self, checkpoint_path):
+         """Load training state from checkpoint"""
+         # Try training_config.json first
+         config_path = Path(checkpoint_path) / "training_config.json"
+         if config_path.exists():
+             try:
+                 with open(config_path) as f:
+                     config = json.load(f)
+                 self.current_step = config.get("step", 0)
+                 self.current_epoch = config.get("epoch", 1)
+                 print(f"📊 Loaded training state: step {self.current_step}, epoch {self.current_epoch}")
+                 return
+             except Exception as e:
+                 print(f"⚠️ Failed to load training_config.json: {e}")
+
+         # Fallback: extract from path name
+         match = re.search(r"step(\d+)_epoch(\d+)", str(checkpoint_path))
+         if match:
+             self.current_step = int(match.group(1))
+             self.current_epoch = int(match.group(2))
+             print(f"📊 Extracted training state from path: step {self.current_step}, epoch {self.current_epoch}")
+         else:
+             self.current_step = 0
+             self.current_epoch = 1
+             print(f"⚠️ Could not determine training state, using defaults: step 0, epoch 1")
+
+     def _save_token_mappings(self, checkpoint_dir):
+         """Save token ID mappings"""
+         try:
+             symbolic_ids, shunt_ids = self.get_token_mappings()
+
+             token_mappings = {
+                 "symbolic_token_ids": symbolic_ids,
+                 "shunt_token_ids": shunt_ids,
+                 "symbolic_tokens": self.symbolic_tokens,
+                 "shunt_tokens": self.shunt_tokens,
+                 "total_special_tokens": len(self.all_special_tokens)
+             }
+
+             with open(checkpoint_dir / "special_token_ids.json", "w") as f:
+                 json.dump(token_mappings, f, indent=2)
+
+             print(f"💾 Saved {len(symbolic_ids)} symbolic and {len(shunt_ids)} shunt token mappings")
+
+         except Exception as e:
+             print(f"⚠️ Failed to save token mappings: {e}")
+
+     def summary(self):
+         """Print comprehensive handler summary"""
+         print(f"\n📋 BERT HANDLER SUMMARY:")
+
+         if self.model is None:
+             print("❌ No model loaded")
+             return
+
+         symbolic_ids, shunt_ids = self.get_token_mappings()
+
+         print(f" 📚 Tokenizer:")
+         print(f" - Size: {len(self.tokenizer)}")
+         print(f" - Special tokens: {len(self.tokenizer.additional_special_tokens or [])}")
+
+         print(f" 🤖 Model:")
+         print(f" - Config vocab size: {self.model.config.vocab_size}")
+         print(f" - Embedding vocab size: {self.model.bert.embeddings.word_embeddings.weight.shape[0]}")
+         print(f" - Embedding dim: {self.model.bert.embeddings.word_embeddings.weight.shape[1]}")
+
+         if hasattr(self.model.cls.predictions, "decoder"):
+             decoder = self.model.cls.predictions.decoder
+             print(f" - Decoder output size: {decoder.weight.shape[0]}")
+
+         print(f" 🎯 Special Tokens:")
+         print(f" - Symbolic tokens mapped: {len(symbolic_ids)}")
+         print(f" - Shunt tokens mapped: {len(shunt_ids)}")
+         print(f" - Total defined: {len(self.all_special_tokens)}")
+
+         print(f" 📊 Training State:")
+         print(f" - Current step: {self.current_step}")
+         print(f" - Current epoch: {self.current_epoch}")
+
+         # VRAM usage
+         self._print_vram_usage(" 🎯 ")
+
+         # Check for vocab consistency
+         tokenizer_size = len(self.tokenizer)
+         model_config_size = self.model.config.vocab_size
+         embedding_size = self.model.bert.embeddings.word_embeddings.weight.shape[0]
+
+         if tokenizer_size == model_config_size == embedding_size:
+             print(f" ✅ All vocab sizes consistent: {tokenizer_size}")
+         else:
+             print(f" ⚠️ Vocab size mismatch detected:")
+             print(f" - Tokenizer: {tokenizer_size}")
+             print(f" - Model config: {model_config_size}")
+             print(f" - Embeddings: {embedding_size}")
+
+     def clear_vram(self):
+         """Explicit method to clear VRAM for debugging"""
+         print("🧹 Explicit VRAM cleanup requested...")
+         self._cleanup_model()
+         self._print_vram_usage("After cleanup: ")
+
+
+ # Utility functions for safe usage patterns
+
+ def create_handler_with_fresh_model(model_name="nomic-ai/nomic-bert-2048", symbolic_tokens=None):
+     """Factory function to create handler and load fresh model safely"""
+     print("🔄 Creating new BERTHandler with fresh model...")
+     handler = BERTHandler(symbolic_tokens=symbolic_tokens)
+     model, tokenizer = handler.load_fresh_model(model_name)
+     return handler, model, tokenizer
+
+
+ def create_handler_from_checkpoint(checkpoint_path, symbolic_tokens=None):
+     """Factory function to create handler and load from checkpoint safely"""
+     print("🔄 Creating new BERTHandler from checkpoint...")
+     handler = BERTHandler(symbolic_tokens=symbolic_tokens)
+     model, tokenizer = handler.load_checkpoint(checkpoint_path)
+     return handler, model, tokenizer
+
+
+ # Usage examples and testing
+ if __name__ == "__main__":
+     # Example usage with comprehensive error handling
+
+     def test_vram_safety():
+         """Test VRAM safety by loading multiple models"""
+         print("🧪 Testing VRAM safety...")
+
+         handler = BERTHandler()
+
+         # Load model 1
+         print("\n--- Loading Model 1 ---")
+         handler.load_fresh_model("bert-base-uncased")
+         handler.summary()
+
+         # Load model 2 (should clean up model 1)
+         print("\n--- Loading Model 2 (should cleanup Model 1) ---")
+         handler.load_fresh_model("distilbert-base-uncased")
+         handler.summary()
+
+         # Explicit cleanup
+         print("\n--- Explicit Cleanup ---")
+         handler.clear_vram()
+
+         print("✅ VRAM safety test complete")
+
+     # Uncomment to run test
+     # test_vram_safety()
+
+     """
+     USAGE EXAMPLES:
+
+     # Safe way to work with fresh models:
+     handler, model, tokenizer = create_handler_with_fresh_model("nomic-ai/nomic-bert-2048")
+
+     # Safe way to work with checkpoints:
+     handler, model, tokenizer = create_handler_from_checkpoint("/path/to/checkpoint")
+
+     # Manual cleanup when needed:
+     handler.clear_vram()
+
+     # Always check summary for consistency:
+     handler.summary()
+     """