Debito committed on
Commit
4372b35
·
verified ·
1 Parent(s): fcf0a07

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils/utils.py +773 -0
utils/utils.py ADDED
@@ -0,0 +1,773 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # utils/utils.py - Utility Functions for Mamba Encoder Swarm Architecture
3
+ # =============================================================================
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ import time
10
+ import json
11
+ import logging
12
+ import os
13
+ import psutil
14
+ import gc
15
+ from typing import Dict, List, Tuple, Optional, Union, Any
16
+ from collections import defaultdict, deque
17
+ from datetime import datetime, timedelta
18
+ import threading
19
+ import warnings
20
+ from functools import wraps, lru_cache
21
+ import hashlib
22
+ import pickle
23
+
24
+ # Setup logging
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # =============================================================================
28
+ # PERFORMANCE MONITORING UTILITIES
29
+ # =============================================================================
30
+
31
class PerformanceMonitor:
    """Monitor and track performance metrics for the swarm architecture.

    Three kinds of data are kept, all guarded by one lock:

    * ``metrics``     - per-name lists of ``{'value', 'timestamp'}`` samples,
      trimmed to the most recent ``max_history`` entries;
    * ``start_times`` - wall-clock start of every timer currently running;
    * ``counters``    - per-name integer counters.
    """

    def __init__(self, max_history: int = 1000):
        self.metrics = defaultdict(list)
        self.max_history = max_history
        self.start_times = {}
        self.counters = defaultdict(int)
        # Must be re-entrant: end_timer() calls record_metric() and
        # get_summary() calls get_stats() while the lock is already held.
        # A plain threading.Lock deadlocks on those nested acquisitions.
        self.lock = threading.RLock()

    def start_timer(self, name: str) -> None:
        """Start (or restart) the timer called ``name``."""
        with self.lock:
            self.start_times[name] = time.time()

    def end_timer(self, name: str) -> float:
        """Stop timer ``name`` and record its duration as ``<name>_duration``.

        Returns:
            The elapsed seconds, or ``0.0`` when no such timer is running.
        """
        with self.lock:
            started = self.start_times.pop(name, None)
            if started is None:
                return 0.0
            duration = time.time() - started
            self.record_metric(f"{name}_duration", duration)
            return duration

    def record_metric(self, name: str, value: float) -> None:
        """Append one timestamped sample to metric ``name``."""
        with self.lock:
            history = self.metrics[name]
            history.append({
                'value': value,
                'timestamp': time.time()
            })
            # Keep only recent history
            overflow = len(history) - self.max_history
            if overflow > 0:
                del history[:overflow]

    def increment_counter(self, name: str, amount: int = 1) -> None:
        """Add ``amount`` to counter ``name`` (created at 0 on first use)."""
        with self.lock:
            self.counters[name] += amount

    def get_stats(self, name: str) -> Dict[str, Any]:
        """Summarise the recorded samples of metric ``name``.

        Returns:
            ``{}`` when the metric has no samples; otherwise a dict with
            ``count``, ``mean``, ``std``, ``min``, ``max``, ``median`` (plain
            Python numbers) and ``recent`` (up to the last ten raw values).
        """
        with self.lock:
            # .get() avoids creating an empty entry in the defaultdict.
            samples = self.metrics.get(name)
            if not samples:
                return {}
            values = [sample['value'] for sample in samples]

        return {
            'count': len(values),
            'mean': float(np.mean(values)),
            'std': float(np.std(values)),
            'min': float(np.min(values)),
            'max': float(np.max(values)),
            'median': float(np.median(values)),
            'recent': values[-10:]
        }

    def get_summary(self) -> Dict[str, Any]:
        """Return stats for every metric, all counters and the running timers."""
        with self.lock:
            return {
                'metrics': {name: self.get_stats(name) for name in list(self.metrics)},
                'counters': dict(self.counters),
                'active_timers': list(self.start_times.keys()),
                'timestamp': datetime.now().isoformat()
            }


# Global performance monitor instance
perf_monitor = PerformanceMonitor()


def monitor_performance(func_name: Optional[str] = None):
    """Decorator that reports call counts and durations to ``perf_monitor``.

    For a metric base name ``N`` (``func_name`` or ``<module>.<function>``)
    every call increments ``N_calls`` and then either ``N_success`` or
    ``N_errors``, and records one ``N_duration`` sample. Exceptions raised by
    the wrapped function propagate unchanged.

    The duration is measured with a call-local clock rather than the
    monitor's name-keyed timers, so recursive or concurrent invocations of the
    same function cannot overwrite each other's start time.
    """
    def decorator(func):
        name = func_name or f"{func.__module__}.{func.__name__}"

        @wraps(func)
        def wrapper(*args, **kwargs):
            perf_monitor.increment_counter(f"{name}_calls")
            started = time.perf_counter()
            try:
                result = func(*args, **kwargs)
            except Exception:
                perf_monitor.increment_counter(f"{name}_errors")
                raise
            else:
                perf_monitor.increment_counter(f"{name}_success")
                return result
            finally:
                perf_monitor.record_metric(
                    f"{name}_duration", time.perf_counter() - started
                )

        return wrapper
    return decorator
124
+
125
+ # =============================================================================
126
+ # MEMORY MANAGEMENT UTILITIES
127
+ # =============================================================================
128
+
129
class MemoryTracker:
    """Static helpers for inspecting and releasing process / GPU memory."""

    @staticmethod
    def get_memory_info() -> Dict[str, float]:
        """Report current process, system and per-GPU memory usage.

        Byte counts are converted to GiB (divided by 1024**3). The
        ``gpu_memory`` entry is an empty dict when CUDA is unavailable.
        """
        gib = 1024**3
        rss_bytes = psutil.Process().memory_info().rss
        system = psutil.virtual_memory()

        per_gpu = {}
        if torch.cuda.is_available():
            per_gpu = {
                f'gpu_{idx}': {
                    'allocated': torch.cuda.memory_allocated(idx) / gib,
                    'cached': torch.cuda.memory_reserved(idx) / gib,
                    'max_allocated': torch.cuda.max_memory_allocated(idx) / gib
                }
                for idx in range(torch.cuda.device_count())
            }

        report = {
            'process_memory_gb': rss_bytes / gib,
            'system_memory_percent': system.percent,
            'system_memory_available_gb': system.available / gib,
            'gpu_memory': per_gpu
        }
        return report

    @staticmethod
    def clear_gpu_cache():
        """Empty the CUDA cache when CUDA is available, then run the GC."""
        cuda_present = torch.cuda.is_available()
        if cuda_present:
            torch.cuda.empty_cache()
        gc.collect()

    @staticmethod
    def optimize_memory():
        """Run the GC first, then empty the CUDA cache when CUDA is available."""
        gc.collect()
        cuda_present = torch.cuda.is_available()
        if cuda_present:
            torch.cuda.empty_cache()
168
+
169
def memory_efficient(clear_cache: bool = True):
    """Decorator factory that clears the GPU cache around each call.

    When ``clear_cache`` is true, ``MemoryTracker.clear_gpu_cache()`` runs
    before the wrapped function and again afterwards, even if it raises.
    When false the wrapped function is simply called.
    """
    def _maybe_clear():
        if clear_cache:
            MemoryTracker.clear_gpu_cache()

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            _maybe_clear()
            try:
                return func(*args, **kwargs)
            finally:
                _maybe_clear()

        return wrapper

    return decorator
186
+
187
+ # =============================================================================
188
+ # TENSOR UTILITIES
189
+ # =============================================================================
190
+
191
class TensorUtils:
    """Utility functions for tensor operations"""

    @staticmethod
    def safe_tensor_to_device(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
        """Move ``tensor`` to ``device``, falling back to the original tensor.

        A ``RuntimeError`` raised by the move is logged as a warning and the
        unmoved tensor is returned, so callers must not assume the result
        actually lives on ``device``.
        """
        try:
            if tensor.device != device:
                return tensor.to(device)
            return tensor
        except RuntimeError as e:
            logger.warning(f"Failed to move tensor to {device}: {e}")
            return tensor

    @staticmethod
    def get_tensor_info(tensor: torch.Tensor) -> Dict[str, Any]:
        """Describe a tensor: shape, dtype, device, grad flag, size and layout."""
        return {
            'shape': list(tensor.shape),
            'dtype': str(tensor.dtype),
            'device': str(tensor.device),
            'requires_grad': tensor.requires_grad,
            # numel * bytes-per-element, expressed in MiB
            'memory_mb': tensor.numel() * tensor.element_size() / 1024**2,
            'is_contiguous': tensor.is_contiguous(),
            'stride': tensor.stride() if tensor.dim() > 0 else []
        }

    @staticmethod
    def batch_tensors(tensors: List[torch.Tensor], pad_value: float = 0.0) -> torch.Tensor:
        """Stack tensors along a new first axis, right-padding the last axis.

        Every tensor must have at least one dimension and share the leading
        (all-but-last) shape of ``tensors[0]``; only the last axis may differ
        in length. Tensors of any rank >= 1 are supported.

        Args:
            tensors: Tensors to batch. An empty list yields ``torch.empty(0)``.
            pad_value: Fill value for positions past each tensor's own length.

        Returns:
            A tensor of shape ``(len(tensors), *leading_shape, max_len)`` with
            the dtype and device of ``tensors[0]``.
        """
        if not tensors:
            return torch.empty(0)

        first = tensors[0]
        max_len = max(t.size(-1) for t in tensors)

        batched = torch.full(
            (len(tensors), *first.shape[:-1], max_len),
            pad_value,
            dtype=first.dtype,
            device=first.device,
        )

        for i, tensor in enumerate(tensors):
            # Ellipsis covers any number of leading dims (none for 1-D input).
            batched[i, ..., :tensor.size(-1)] = tensor

        return batched

    @staticmethod
    def split_tensor_by_chunks(tensor: torch.Tensor, chunk_size: int) -> List[torch.Tensor]:
        """Split ``tensor`` along dim 0 into slices of at most ``chunk_size`` rows.

        Raises:
            ValueError: If ``chunk_size`` is not positive. A negative size
                would otherwise silently produce an empty list.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        if tensor.size(0) <= chunk_size:
            return [tensor]

        return [tensor[i:i + chunk_size] for i in range(0, tensor.size(0), chunk_size)]
248
+
249
+ # =============================================================================
250
+ # ROUTING UTILITIES
251
+ # =============================================================================
252
+
253
class RoutingUtils:
    """Utilities for encoder routing and load balancing"""

    @staticmethod
    def calculate_load_balance_loss(routing_weights: torch.Tensor, epsilon: float = 1e-8) -> torch.Tensor:
        """Penalise uneven average usage of the encoders.

        The weights are averaged over dims 0 and 1 and the (unbiased) variance
        of the per-encoder averages is divided by ``(1 / num_encoders)**2``.

        Args:
            routing_weights: Expected as ``[batch_size, seq_len, num_encoders]``
                (per the original comment; the code reduces over dims 0 and 1).
            epsilon: Added to the denominator for numerical safety.

        Returns:
            A scalar tensor; zero when there is only one encoder.
        """
        avg_routing = routing_weights.mean(dim=[0, 1])  # [num_encoders]

        # The unbiased variance of a single element is NaN; one encoder is
        # trivially balanced, so report zero loss instead.
        if avg_routing.numel() <= 1:
            return avg_routing.new_zeros(())

        # Variance penalty to encourage uniform distribution
        target_load = 1.0 / routing_weights.size(-1)
        return torch.var(avg_routing) / (target_load ** 2 + epsilon)

    @staticmethod
    def apply_top_k_routing(
        logits: torch.Tensor, k: int, temperature: float = 1.0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Keep the top-k logits per position and weight them with a noisy softmax.

        Gumbel noise is added to the selected logits before the softmax, so
        the returned weights are stochastic on every call.

        Args:
            logits: Routing scores; the last dim indexes the candidates.
            k: Number of candidates to keep along the last dim.
            temperature: Softmax temperature (previously hard-coded to 1.0).

        Returns:
            ``(weights, mask)``, both shaped like ``logits``. ``weights`` is
            zero outside the top-k and sums to 1 over the last dim; ``mask``
            is 1.0 at the top-k positions and 0.0 elsewhere.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        # Get top-k indices
        top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)

        # Create mask for top-k
        mask = torch.zeros_like(logits)
        mask.scatter_(-1, top_k_indices, 1.0)

        # Gumbel(0, 1) sample: -log(-log(U)), with eps guarding both logs
        eps = 1e-8
        uniform = torch.rand_like(top_k_logits)
        gumbel_noise = -torch.log(-torch.log(uniform + eps) + eps)
        top_k_weights = F.softmax((top_k_logits + gumbel_noise) / temperature, dim=-1)

        # Reconstruct full weights
        weights = torch.zeros_like(logits)
        weights.scatter_(-1, top_k_indices, top_k_weights)

        return weights, mask

    @staticmethod
    def entropy_regularization(routing_weights: torch.Tensor) -> torch.Tensor:
        """Return the negated mean entropy of the routing distributions.

        Entropy is taken over the last dim after clamping weights to at least
        1e-8 (avoids log(0)). The sign is flipped so that minimising the
        returned value maximises entropy.
        """
        routing_weights = torch.clamp(routing_weights, min=1e-8)
        entropy = -torch.sum(routing_weights * torch.log(routing_weights), dim=-1)
        return -entropy.mean()
295
+
296
+ # =============================================================================
297
+ # TEXT PROCESSING UTILITIES
298
+ # =============================================================================
299
+
300
+ class TextUtils:
301
+ """Utilities for text processing and analysis"""
302
+
303
+ @staticmethod
304
+ def chunk_text(text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
305
+ """Split text into overlapping chunks"""
306
+ words = text.split()
307
+ if len(words) <= chunk_size:
308
+ return [text]
309
+
310
+ chunks = []
311
+ start = 0
312
+
313
+ while start < len(words):
314
+ end = min(start + chunk_size, len(words))
315
+ chunk = ' '.join(words[start:end])
316
+ chunks.append(chunk)
317
+
318
+ if end >= len(words):
319
+ break
320
+
321
+ start = end - overlap
322
+
323
+ return chunks
324
+
325
+ @staticmethod
326
+ def estimate_tokens(text: str, chars_per_token: float = 4.0) -> int:
327
+ """Estimate number of tokens in text"""
328
+ return max(1, int(len(text) / chars_per_token))
329
+
330
+ @staticmethod
331
+ def clean_text(text: str) -> str:
332
+ """Clean and normalize text"""
333
+ if not text:
334
+ return ""
335
+
336
+ # Remove excessive whitespace
337
+ text = ' '.join(text.split())
338
+
339
+ # Remove control characters
340
+ text = ''.join(char for char in text if ord(char) >= 32 or char in '\n\t')
341
+
342
+ return text.strip()
343
+
344
+ @staticmethod
345
+ def detect_language(text: str) -> str:
346
+ """Simple language detection based on character patterns"""
347
+ # This is a simplified version - for production, use langdetect library
348
+ if not text:
349
+ return "unknown"
350
+
351
+ # Count character types
352
+ ascii_count = sum(1 for c in text if ord(c) < 128)
353
+ total_chars = len(text)
354
+
355
+ if total_chars == 0:
356
+ return "unknown"
357
+
358
+ ascii_ratio = ascii_count / total_chars
359
+
360
+ if ascii_ratio > 0.9:
361
+ return "en" # Likely English
362
+ elif ascii_ratio > 0.7:
363
+ return "mixed"
364
+ else:
365
+ return "non-latin"
366
+
367
+ # =============================================================================
368
+ # CONFIGURATION UTILITIES
369
+ # =============================================================================
370
+
371
class ConfigUtils:
    """Utilities for configuration management"""

    @staticmethod
    def load_config(config_path: str) -> Dict[str, Any]:
        """Load a JSON configuration file.

        Returns:
            The parsed content, or ``{}`` if the file cannot be read or
            parsed (the failure is logged, not raised).
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            logger.info(f"Loaded configuration from {config_path}")
            return config
        except Exception as e:
            logger.error(f"Failed to load config from {config_path}: {e}")
            return {}

    @staticmethod
    def save_config(config: Dict[str, Any], config_path: str) -> bool:
        """Write ``config`` as indented UTF-8 JSON, creating parent directories.

        Returns:
            ``True`` on success, ``False`` if anything failed (logged).
        """
        try:
            # dirname is "" for a bare filename and os.makedirs("") raises,
            # so only create the parent when the path actually has one.
            parent = os.path.dirname(config_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=2, ensure_ascii=False)
            logger.info(f"Saved configuration to {config_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to save config to {config_path}: {e}")
            return False

    @staticmethod
    def merge_configs(base_config: Dict[str, Any], override_config: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively merge ``override_config`` on top of ``base_config``.

        Keys whose values are dicts on both sides are merged key by key; any
        other override value replaces the base value. Neither input dict is
        modified, but nested values that are not themselves merged are shared
        with the inputs (the copy is shallow).
        """
        merged = base_config.copy()

        for key, value in override_config.items():
            if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
                merged[key] = ConfigUtils.merge_configs(merged[key], value)
            else:
                merged[key] = value

        return merged

    @staticmethod
    def validate_config(config: Dict[str, Any], required_keys: List[str]) -> List[str]:
        """Return the entries of ``required_keys`` that ``config`` lacks.

        A key containing dots (``"a.b.c"``) is resolved through nested dicts;
        it counts as missing as soon as one path component is absent or the
        value at that level is not a dict.
        """
        missing_keys = []

        for key in required_keys:
            current = config
            # A key without dots is simply a one-component path.
            for part in key.split('.'):
                if not isinstance(current, dict) or part not in current:
                    missing_keys.append(key)
                    break
                current = current[part]

        return missing_keys
431
+
432
+ # =============================================================================
433
+ # CACHING UTILITIES
434
+ # =============================================================================
435
+
436
class CacheManager:
    """Thread-safe in-memory cache with TTL expiry and LRU eviction.

    ``cache`` maps key -> ``{'value', 'timestamp', 'ttl'}`` and
    ``access_times`` maps key -> last access time, used to pick the eviction
    victim. Expired entries are removed lazily, when they are next looked up.
    """

    def __init__(self, max_size: int = 1000, ttl_seconds: int = 3600):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.cache = {}
        self.access_times = {}
        self.lock = threading.Lock()
        # Lookup counters backing stats()['hit_ratio']
        self._hits = 0
        self._requests = 0

    def _generate_key(self, *args, **kwargs) -> str:
        """Derive a cache key by hashing the pickled arguments.

        The pickle is only hashed, never loaded back. Arguments that cannot
        be pickled make this raise.
        """
        key_data = {
            'args': args,
            'kwargs': sorted(kwargs.items())
        }
        return hashlib.md5(pickle.dumps(key_data)).hexdigest()

    def lookup(self, key: str) -> Tuple[bool, Any]:
        """Look up ``key`` and report whether a live entry was found.

        Unlike :meth:`get`, this distinguishes a cached ``None`` from a miss.

        Returns:
            ``(True, value)`` on a hit, ``(False, None)`` on a miss or when
            the entry has expired (in which case it is removed).
        """
        with self.lock:
            self._requests += 1
            entry = self.cache.get(key)
            if entry is None:
                return False, None

            now = time.time()
            ttl = entry.get('ttl')
            if ttl is None:
                ttl = self.ttl_seconds
            if now - entry['timestamp'] > ttl:
                self._remove_key(key)
                return False, None

            self.access_times[key] = now
            self._hits += 1
            return True, entry['value']

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or ``None`` if absent or expired."""
        found, value = self.lookup(key)
        return value if found else None

    def put(self, key: str, value: Any, ttl_seconds: Optional[float] = None) -> None:
        """Store ``value`` under ``key``.

        Args:
            key: Cache key.
            value: Value to store.
            ttl_seconds: Optional per-entry lifetime; ``None`` uses the
                manager-wide ``ttl_seconds``.
        """
        with self.lock:
            # Only a genuinely new key can grow the cache; overwriting an
            # existing key must not evict an unrelated entry.
            if key not in self.cache and len(self.cache) >= self.max_size:
                self._evict_lru()

            now = time.time()
            self.cache[key] = {
                'value': value,
                'timestamp': now,
                'ttl': ttl_seconds
            }
            self.access_times[key] = now

    def _remove_key(self, key: str) -> None:
        """Drop ``key`` from both internal maps (caller holds the lock)."""
        self.cache.pop(key, None)
        self.access_times.pop(key, None)

    def _evict_lru(self) -> None:
        """Evict the least recently accessed entry (caller holds the lock)."""
        if not self.access_times:
            return

        lru_key = min(self.access_times, key=self.access_times.get)
        self._remove_key(lru_key)

    def clear(self) -> None:
        """Remove every cached entry (hit/request counters are kept)."""
        with self.lock:
            self.cache.clear()
            self.access_times.clear()

    def stats(self) -> Dict[str, Any]:
        """Return current size, capacity, hit ratio and default TTL."""
        with self.lock:
            return {
                'size': len(self.cache),
                'max_size': self.max_size,
                'hit_ratio': self._hits / max(self._requests, 1),
                'ttl_seconds': self.ttl_seconds
            }


# Global cache manager
cache_manager = CacheManager()


def cached(ttl_seconds: int = 3600):
    """Decorator that memoises a function in the global ``cache_manager``.

    Args:
        ttl_seconds: Lifetime of each cached result. It is stored per entry,
            so it applies regardless of the manager's default TTL.

    Results equal to ``None`` are cached like any other value. If the call's
    arguments cannot be pickled into a cache key, the function is simply
    executed without caching.
    """
    def decorator(func):
        # Qualify with the module so same-named functions do not collide.
        func_id = f"{func.__module__}.{func.__qualname__}"

        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                cache_key = cache_manager._generate_key(func_id, *args, **kwargs)
            except (pickle.PicklingError, TypeError, AttributeError):
                return func(*args, **kwargs)

            # Try to get from cache
            found, value = cache_manager.lookup(cache_key)
            if found:
                return value

            # Compute and cache
            result = func(*args, **kwargs)
            cache_manager.put(cache_key, result, ttl_seconds=ttl_seconds)

            return result

        return wrapper
    return decorator
535
+
536
+ # =============================================================================
537
+ # DEBUGGING AND LOGGING UTILITIES
538
+ # =============================================================================
539
+
540
class DebugUtils:
    """Utilities for debugging the swarm architecture"""

    @staticmethod
    def log_tensor_stats(tensor: torch.Tensor, name: str) -> None:
        """Log shape, dtype, device and value statistics of ``tensor`` at DEBUG.

        Nothing is computed unless DEBUG logging is enabled, because every
        statistic reduces the tensor and pulls a scalar back with ``.item()``.
        """
        if not logger.isEnabledFor(logging.DEBUG):
            return

        element_count = tensor.numel()
        if not element_count:
            logger.debug(f"{name}: Empty tensor")
            return

        as_float = tensor.float()
        stats = {
            'shape': list(tensor.shape),
            'dtype': str(tensor.dtype),
            'device': str(tensor.device),
            'mean': as_float.mean().item(),
            # The sample std of a single element is NaN; report 0.0 instead.
            'std': as_float.std().item() if element_count > 1 else 0.0,
            'min': tensor.min().item(),
            'max': tensor.max().item(),
            'has_nan': torch.isnan(tensor).any().item(),
            'has_inf': torch.isinf(tensor).any().item()
        }

        logger.debug(f"{name} stats: {stats}")

    @staticmethod
    def validate_tensor(tensor: torch.Tensor, name: str, check_finite: bool = True) -> bool:
        """Check that ``tensor`` is a non-empty tensor with finite values.

        Args:
            tensor: Object to validate.
            name: Label used in log messages.
            check_finite: When true, NaN or infinite values fail validation.

        Returns:
            ``True`` if every check passes; ``False`` otherwise, after the
            first failing check has been logged.
        """
        if not isinstance(tensor, torch.Tensor):
            logger.error(f"{name}: Not a tensor, got {type(tensor)}")
            return False

        if tensor.numel() == 0:
            logger.warning(f"{name}: Empty tensor")
            return False

        if check_finite:
            if torch.isnan(tensor).any():
                logger.error(f"{name}: Contains NaN values")
                return False

            if torch.isinf(tensor).any():
                logger.error(f"{name}: Contains infinite values")
                return False

        return True

    @staticmethod
    def trace_function_calls(func):
        """Decorator logging each call's argument summary, duration and failure.

        Entry and successful completion are logged at DEBUG, an exception at
        ERROR before it is re-raised unchanged.
        """
        @wraps(func)
        def wrapper(*args, **kwargs):
            logger.debug(f"Calling {func.__name__} with args: {len(args)}, kwargs: {list(kwargs.keys())}")
            start_time = time.time()

            try:
                result = func(*args, **kwargs)
            except Exception as e:
                duration = time.time() - start_time
                logger.error(f"{func.__name__} failed after {duration:.4f}s: {e}")
                raise

            duration = time.time() - start_time
            logger.debug(f"{func.__name__} completed in {duration:.4f}s")
            return result

        return wrapper
605
+
606
+ # =============================================================================
607
+ # SYSTEM UTILITIES
608
+ # =============================================================================
609
+
610
class SystemUtils:
    """System-level utilities"""

    @staticmethod
    def get_system_info() -> Dict[str, Any]:
        """Collect CPU, memory, GPU and runtime version information.

        Note: ``psutil.cpu_percent`` is called with ``interval=1``.
        NOTE(review): per the psutil docs that samples over the interval, so
        this call presumably blocks for about one second - confirm.
        """
        import sys

        cpu_info = {
            'cpu_count': psutil.cpu_count(),
            'cpu_percent': psutil.cpu_percent(interval=1),
            # os.getloadavg does not exist on every platform
            'load_average': os.getloadavg() if hasattr(os, 'getloadavg') else None
        }

        memory_info = psutil.virtual_memory()._asdict()

        gpu_info = {}
        if torch.cuda.is_available():
            gpu_info = {
                'device_count': torch.cuda.device_count(),
                'current_device': torch.cuda.current_device(),
                'devices': [
                    {
                        'name': torch.cuda.get_device_name(i),
                        'memory_total': torch.cuda.get_device_properties(i).total_memory,
                        'memory_allocated': torch.cuda.memory_allocated(i),
                        'memory_cached': torch.cuda.memory_reserved(i)
                    }
                    for i in range(torch.cuda.device_count())
                ]
            }

        return {
            'cpu': cpu_info,
            'memory': memory_info,
            'gpu': gpu_info,
            'python_version': f"{sys.version_info.major}.{sys.version_info.minor}",
            'torch_version': torch.__version__,
            'timestamp': datetime.now().isoformat()
        }

    @staticmethod
    def ensure_directory(path: str) -> None:
        """Create ``path`` (and missing parents) if it does not already exist."""
        os.makedirs(path, exist_ok=True)

    @staticmethod
    def safe_file_write(content: str, filepath: str, backup: bool = True) -> bool:
        """Write ``content`` to ``filepath`` as UTF-8, optionally keeping a backup.

        Args:
            content: Text to write.
            filepath: Destination; missing parent directories are created.
            backup: When true and the file already exists, it is first copied
                to ``<filepath>.backup``.

        Returns:
            ``True`` on success, ``False`` if any step failed (logged).
        """
        import shutil

        try:
            # dirname is "" for a bare filename and os.makedirs("") raises,
            # so only create the parent when the path actually has one.
            parent = os.path.dirname(filepath)
            if parent:
                os.makedirs(parent, exist_ok=True)

            # Create backup if file exists
            if backup and os.path.exists(filepath):
                shutil.copy2(filepath, f"{filepath}.backup")

            # Write content
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(content)

            return True
        except Exception as e:
            logger.error(f"Failed to write file {filepath}: {e}")
            return False
675
+
676
+ # =============================================================================
677
+ # EXPORT UTILITIES
678
+ # =============================================================================
679
+
680
def format_model_size(num_params: int) -> str:
    """Format a parameter count with a K/M/B/T/P suffix (powers of 1000).

    The value is shown with one decimal place. Negative inputs are scaled by
    magnitude and keep their sign instead of being returned unscaled.
    """
    sign = "-" if num_params < 0 else ""
    magnitude = abs(num_params)
    for unit in ('', 'K', 'M', 'B', 'T'):
        if magnitude < 1000:
            return f"{sign}{magnitude:.1f}{unit}"
        magnitude /= 1000
    return f"{sign}{magnitude:.1f}P"
687
+
688
def format_memory_size(bytes_size: int) -> str:
    """Format a byte count with a B/KB/MB/GB/TB/PB suffix (powers of 1024).

    The value is shown with one decimal place. Negative inputs (for example a
    memory delta) are scaled by magnitude and keep their sign instead of
    being returned unscaled.
    """
    sign = "-" if bytes_size < 0 else ""
    magnitude = abs(bytes_size)
    for unit in ('B', 'KB', 'MB', 'GB', 'TB'):
        if magnitude < 1024:
            return f"{sign}{magnitude:.1f}{unit}"
        magnitude /= 1024
    return f"{sign}{magnitude:.1f}PB"
695
+
696
def format_duration(seconds: float) -> str:
    """Render a duration in the largest fitting unit (ms, s, m or h).

    Values below one second are shown in milliseconds, below a minute in
    seconds, below an hour in minutes, and anything longer in hours; each
    with one decimal place.
    """
    if seconds < 1:
        millis = seconds * 1000
        return f"{millis:.1f}ms"
    if seconds < 60:
        return f"{seconds:.1f}s"
    if seconds < 3600:
        return f"{seconds / 60:.1f}m"
    return f"{seconds / 3600:.1f}h"
708
+
709
+ # =============================================================================
710
+ # INITIALIZATION
711
+ # =============================================================================
712
+
713
def initialize_logging(log_level: str = "INFO", log_file: Optional[str] = None) -> None:
    """Configure root logging with a stream handler and an optional file handler.

    Args:
        log_level: Name of a ``logging`` level, case-insensitive. A name that
            is not an attribute of the ``logging`` module falls back to INFO.
        log_file: When given, records are additionally written to this file.
    """
    resolved_level = getattr(logging, log_level.upper(), logging.INFO)

    sinks = [logging.StreamHandler()]
    if log_file:
        sinks.append(logging.FileHandler(log_file))

    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=resolved_level,
        handlers=sinks,
    )
726
+
727
def setup_warnings() -> None:
    """Install "ignore" filters for two noisy warning sources.

    Suppresses ``UserWarning`` from modules matching "torch" and
    ``FutureWarning`` from modules matching "transformers".
    """
    noisy_sources = (
        (UserWarning, "torch"),
        (FutureWarning, "transformers"),
    )
    for category, module in noisy_sources:
        warnings.filterwarnings("ignore", category=category, module=module)


# Initialize on import
setup_warnings()
735
+
736
+ # =============================================================================
737
+ # MAIN UTILITIES EXPORT
738
+ # =============================================================================
739
+
740
# Public API of this module, grouped by area.
__all__ = [
    # Performance monitoring
    'PerformanceMonitor',
    'perf_monitor',
    'monitor_performance',
    # Memory management
    'MemoryTracker',
    'memory_efficient',
    # Tensor utilities
    'TensorUtils',
    # Routing utilities
    'RoutingUtils',
    # Text processing
    'TextUtils',
    # Configuration
    'ConfigUtils',
    # Caching
    'CacheManager',
    'cache_manager',
    'cached',
    # Debugging
    'DebugUtils',
    # System utilities
    'SystemUtils',
    # Formatting utilities
    'format_model_size',
    'format_memory_size',
    'format_duration',
    # Initialization
    'initialize_logging',
    'setup_warnings',
]