Debito committed on
Commit fcf0a07 · verified · 1 Parent(s): 687ec98

Upload 4 files

system/inference.py ADDED
@@ -0,0 +1,138 @@
1
+ # =============================================================================
2
+ # system/inference.py
3
+ # =============================================================================
4
+ import torch
5
+ from typing import Dict, List, Optional, Union
6
+ import time
7
+
8
+ class MambaInferenceEngine:
9
+ """Optimized inference engine for Mamba swarm"""
10
+
11
+ def __init__(self, swarm_engine):
12
+ self.swarm_engine = swarm_engine
13
+ self.config = swarm_engine.config
14
+
15
+ # Inference optimizations
16
+ self.use_half_precision = True
17
+ self.use_torch_compile = hasattr(torch, 'compile')
18
+
19
+ # Apply optimizations
20
+ self._optimize_models()
21
+
22
+ def _optimize_models(self):
23
+ """Apply inference optimizations"""
24
+ if self.use_half_precision and self.config.device != 'cpu':
25
+ # Convert to half precision for faster inference
26
+ for specialist in self.swarm_engine.tlm_manager.specialists.values():
27
+ specialist.model = specialist.model.half()
28
+ self.swarm_engine.aggregator = self.swarm_engine.aggregator.half()
29
+
30
+ if self.use_torch_compile:
31
+ try:
32
+ # Compile models for faster inference (PyTorch 2.0+)
33
+ for specialist in self.swarm_engine.tlm_manager.specialists.values():
34
+ specialist.model = torch.compile(specialist.model)
35
+ self.swarm_engine.aggregator = torch.compile(self.swarm_engine.aggregator)
36
+ print("Models compiled for faster inference")
37
+ except Exception as e:
38
+ print(f"Could not compile models: {e}")
39
+
40
+ def generate(self, prompt: str, max_tokens: int = 100,
41
+ temperature: float = 0.7, top_k: int = 50) -> Dict:
42
+ """
43
+ Generate text response with advanced sampling
44
+
45
+ Args:
46
+ prompt: Input text prompt
47
+ max_tokens: Maximum tokens to generate
48
+ temperature: Sampling temperature
49
+ top_k: Top-k sampling parameter
50
+
51
+ Returns:
52
+ Dict with generated text and metadata
53
+ """
54
+ start_time = time.time()
55
+
56
+ # Process through swarm
57
+ result = self.swarm_engine.process_request(prompt, max_tokens)
58
+
59
+ if not result['success']:
60
+ return result
61
+
62
+ # Add inference metadata
63
+ result.update({
64
+ 'temperature': temperature,
65
+ 'top_k': top_k,
66
+ 'inference_time': time.time() - start_time,
67
+ 'tokens_per_second': max_tokens / max(time.time() - start_time, 1e-6)  # note: based on requested max_tokens, not the actual generated count
68
+ })
69
+
70
+ return result
71
+
72
+ def stream_generate(self, prompt: str, max_tokens: int = 100):
73
+ """
74
+ Stream generation token by token (placeholder implementation)
75
+ """
76
+ # This would implement streaming generation
77
+ # For now, return the full response
78
+ result = self.generate(prompt, max_tokens)
79
+ yield result['response']
80
+
81
+ def chat_completion(self, messages: List[Dict], max_tokens: int = 100) -> Dict:
82
+ """
83
+ Chat completion interface similar to OpenAI API
84
+
85
+ Args:
86
+ messages: List of message dicts with 'role' and 'content'
87
+ max_tokens: Maximum tokens to generate
88
+
89
+ Returns:
90
+ Chat completion response
91
+ """
92
+ # Convert messages to single prompt
93
+ prompt = self._format_chat_prompt(messages)
94
+
95
+ # Generate response
96
+ result = self.generate(prompt, max_tokens)
97
+
98
+ if result['success']:
99
+ # Format as chat completion
100
+ return {
101
+ 'choices': [{
102
+ 'message': {
103
+ 'role': 'assistant',
104
+ 'content': result['response']
105
+ },
106
+ 'finish_reason': 'stop'
107
+ }],
108
+ 'usage': {
109
+ 'prompt_tokens': len(prompt.split()),
110
+ 'completion_tokens': len(result['response'].split()),
111
+ 'total_tokens': len(prompt.split()) + len(result['response'].split())
112
+ },
113
+ 'model': 'mamba-swarm-70m',
114
+ 'inference_time': result.get('inference_time', 0)
115
+ }
116
+ else:
117
+ return {
118
+ 'error': result.get('error', 'Unknown error'),
119
+ 'success': False
120
+ }
121
+
122
+ def _format_chat_prompt(self, messages: List[Dict]) -> str:
123
+ """Format chat messages into a single prompt"""
124
+ formatted = ""
125
+
126
+ for message in messages:
127
+ role = message.get('role', 'user')
128
+ content = message.get('content', '')
129
+
130
+ if role == 'system':
131
+ formatted += f"System: {content}\n"
132
+ elif role == 'user':
133
+ formatted += f"User: {content}\n"
134
+ elif role == 'assistant':
135
+ formatted += f"Assistant: {content}\n"
136
+
137
+ formatted += "Assistant: "
138
+ return formatted
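
A minimal usage sketch for the engine above. Assumptions: the repository root is on PYTHONPATH so the system.* imports resolve, and the wrapped swarm actually exposes the config, tlm_manager.specialists and aggregator attributes that _optimize_models reaches for (the use_pretrained=False path is the one that builds a tlm_manager).

    # Hypothetical driver, not part of this commit
    from system.mambaSwarm import UnifiedMambaSwarm
    from system.inference import MambaInferenceEngine

    swarm = UnifiedMambaSwarm(tier="demo", use_pretrained=False)
    engine = MambaInferenceEngine(swarm)

    # Plain generation, with the metadata that generate() adds
    result = engine.generate("Explain state-space models", max_tokens=64)
    print(result.get("response"))
    print("tokens/sec:", result.get("tokens_per_second"))

    # OpenAI-style chat interface defined in this file
    chat = engine.chat_completion([
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize Mamba in one sentence."},
    ])
    print(chat["choices"][0]["message"]["content"] if "choices" in chat else chat)
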
system/mambaSwarm.py ADDED
@@ -0,0 +1,816 @@
1
+ # =============================================================================
2
+ # system/mambaSwarm.py - Unified Scalable Mamba Swarm Engine
3
+ # =============================================================================
4
+ import torch
5
+ import time
6
+ import os
7
+ import asyncio
8
+ from typing import Dict, List, Tuple, Optional, Union
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+
12
+ # Core imports
13
+ from core.config import MambaConfig, MambaSwarmConfig, auto_detect_tier
14
+ from core.tokenizer import MambaTokenizer
15
+ from core.preprocess import TextPreprocessor
16
+ from core.model import MambaModel
17
+ from core.mamba_swarm_integration import MambaEncoderSwarmModel, create_swarm_from_existing_config
18
+
19
+ # Routing imports
20
+ from routing.router import TopicRouter, ContentBasedRouter
21
+ from routing.tlm_manager import TLMManager
22
+ from routing.aggregator import AttentionAggregator, WeightedAggregator
23
+ from utils.domain_configs import DomainConfigs
24
+
25
+
26
+ class UnifiedMambaSwarm:
27
+ """
28
+ Unified Mamba Swarm Engine combining the best of both architectures:
29
+ - Scalable tier-based system with auto-detection
30
+ - Production-ready async processing and monitoring
31
+ - Graceful fallback to simulation mode
32
+ - Support for both custom and pre-trained models
33
+ """
34
+
35
+ def __init__(self,
36
+ tier: Optional[str] = None,
37
+ config: Optional[Union[MambaConfig, MambaSwarmConfig]] = None,
38
+ use_pretrained: bool = True,
39
+ config_override: Optional[Dict] = None):
40
+ """
41
+ Initialize the unified swarm engine
42
+
43
+ Args:
44
+ tier: Scaling tier (demo/small/medium/large/full) or None for auto-detect
45
+ config: Either MambaConfig for custom models or MambaSwarmConfig for scaling
46
+ use_pretrained: Whether to use HuggingFace pretrained models
47
+ config_override: Dictionary to override config settings
48
+ """
49
+ # Auto-detect tier if not specified
50
+ if tier is None:
51
+ tier = auto_detect_tier()
52
+ print(f"Auto-detected tier: {tier}")
53
+
54
+ self.tier = tier
55
+ self.use_pretrained = use_pretrained
56
+
57
+ # Initialize configuration
58
+ if config is None:
59
+ if use_pretrained:
60
+ self.swarm_config = MambaSwarmConfig(tier=tier)
61
+ if config_override:
62
+ self.swarm_config.config.update(config_override)
63
+ self.config = self._create_legacy_config()
64
+ else:
65
+ # Use custom config for legacy components
66
+ self.config = MambaConfig() # Default config
67
+ self.swarm_config = None
68
+ else:
69
+ if isinstance(config, MambaSwarmConfig):
70
+ self.swarm_config = config
71
+ self.config = self._create_legacy_config()
72
+ else:
73
+ self.config = config
74
+ self.swarm_config = None
75
+
76
+ self.device = getattr(self.config, 'device', 'cuda' if torch.cuda.is_available() else 'cpu')
77
+
78
+ # System properties
79
+ if self.swarm_config:
80
+ self.num_encoders = self.swarm_config.config["num_encoders"]
81
+ self.encoder_size = self.swarm_config.config["encoder_size"]
82
+ else:
83
+ self.num_encoders = getattr(self.config, 'num_specialists', 5)
84
+ self.encoder_size = "130M"
85
+
86
+ # Initialize components
87
+ self.encoders = []
88
+ self.tokenizer = None
89
+ self.preprocessor = None
90
+ self.router = None
91
+ self.aggregator = None
92
+ self.tlm_manager = None
93
+
94
+ # Performance tracking
95
+ self.stats = {
96
+ 'total_requests': 0,
97
+ 'total_tokens_processed': 0,
98
+ 'avg_response_time': 0.0,
99
+ 'specialist_usage': {i: 0 for i in range(self.num_encoders)},
100
+ 'simulation_mode': False,
101
+ 'model_load_errors': 0
102
+ }
103
+
104
+ # Initialize system
105
+ self._initialize_system()
106
+
107
+ print(f"✅ Unified Mamba Swarm initialized: {self.tier} tier, {self.num_encoders} encoders")
108
+
109
+ def _create_legacy_config(self) -> MambaConfig:
110
+ """Create legacy MambaConfig from SwarmConfig for compatibility"""
111
+ legacy_config = MambaConfig()
112
+ if self.swarm_config:
113
+ legacy_config.num_specialists = self.swarm_config.config["num_encoders"]
114
+ legacy_config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
115
+ return legacy_config
116
+
117
+ def _initialize_system(self):
118
+ """Initialize the complete swarm system"""
119
+ try:
120
+ # Initialize tokenizer and preprocessor
121
+ self._initialize_tokenizer()
122
+ self._initialize_preprocessor()
123
+
124
+ # Initialize encoders/specialists
125
+ if self.use_pretrained:
126
+ self._initialize_pretrained_encoders()
127
+ else:
128
+ self._initialize_custom_specialists()
129
+
130
+ # Initialize routing system
131
+ self._initialize_routing()
132
+
133
+ # Initialize aggregation system
134
+ self._initialize_aggregation()
135
+
136
+ print(f"🚀 System initialization complete!")
137
+
138
+ except Exception as e:
139
+ print(f"⚠️ Error during initialization: {e}")
140
+ self._fallback_to_simulation()
141
+
142
+ def _initialize_tokenizer(self):
143
+ """Initialize tokenizer based on mode"""
144
+ if self.use_pretrained:
145
+ base_model_name = self._get_base_model_name()
146
+ try:
147
+ self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
148
+ if self.tokenizer.pad_token is None:
149
+ self.tokenizer.pad_token = self.tokenizer.eos_token
150
+ print(f"📝 Loaded HuggingFace tokenizer: {base_model_name}")
151
+ except Exception:
152
+ print("⚠️ HuggingFace tokenizer failed, using custom tokenizer")
153
+ self.tokenizer = MambaTokenizer(self.config)
154
+ else:
155
+ self.tokenizer = MambaTokenizer(self.config)
156
+
157
+ def _initialize_preprocessor(self):
158
+ """Initialize text preprocessor"""
159
+ self.preprocessor = TextPreprocessor(self.config)
160
+
161
+ def _get_base_model_name(self):
162
+ """Get the appropriate base model for current tier"""
163
+ model_mapping = {
164
+ "130M": "state-spaces/mamba-130m",
165
+ "370M": "state-spaces/mamba-370m",
166
+ "790M": "state-spaces/mamba-790m",
167
+ "1.4B": "state-spaces/mamba-1.4b",
168
+ "2.8B": "state-spaces/mamba-2.8b"
169
+ }
170
+ return model_mapping.get(self.encoder_size, "state-spaces/mamba-130m")
171
+
172
+ def _initialize_pretrained_encoders(self):
173
+ """Initialize pretrained encoder swarm"""
174
+ print(f"🔄 Loading {self.num_encoders} pretrained encoders...")
175
+
176
+ base_model_name = self._get_base_model_name()
177
+
178
+ try:
179
+ # Load base model
180
+ base_model = AutoModelForCausalLM.from_pretrained(
181
+ base_model_name,
182
+ torch_dtype=torch.float16 if self.num_encoders > 5 else torch.float32,
183
+ device_map="auto" if torch.cuda.is_available() else "cpu"
184
+ )
185
+
186
+ # Create encoder instances
187
+ for i in range(self.num_encoders):
188
+ domain_info = self.swarm_config.domain_assignments[i] if self.swarm_config else {
189
+ "domain": f"general_{i}", "specialty": "general"
190
+ }
191
+
192
+ if self.tier == "demo" or self.num_encoders <= 5:
193
+ # Share model instance for smaller configurations
194
+ encoder = {
195
+ "id": i,
196
+ "model": base_model,
197
+ "domain": domain_info["domain"],
198
+ "specialty": domain_info["specialty"],
199
+ "shared": True
200
+ }
201
+ else:
202
+ # Separate instances for larger configurations
203
+ encoder = {
204
+ "id": i,
205
+ "model": AutoModelForCausalLM.from_pretrained(
206
+ base_model_name,
207
+ torch_dtype=torch.float16,
208
+ device_map="auto"
209
+ ),
210
+ "domain": domain_info["domain"],
211
+ "specialty": domain_info["specialty"],
212
+ "shared": False
213
+ }
214
+
215
+ self.encoders.append(encoder)
216
+ print(f" ✓ Encoder {i}: {encoder['domain']} specialist")
217
+
218
+ except Exception as e:
219
+ print(f"❌ Failed to load pretrained models: {e}")
220
+ self.stats['model_load_errors'] += 1
221
+ self._create_simulated_encoders()
222
+
223
+ def _initialize_custom_specialists(self):
224
+ """Initialize custom TLM specialists or native Mamba swarm"""
225
+ try:
226
+ if hasattr(self, 'use_native_swarm') and self.use_native_swarm:
227
+ # Use the native Mamba swarm integration
228
+ self.native_swarm_model = create_swarm_from_existing_config(
229
+ self.config, num_encoders=self.num_encoders
230
+ )
231
+ print(f"✓ Initialized native Mamba swarm with {self.num_encoders} encoders")
232
+ else:
233
+ # Use TLM manager (legacy approach)
234
+ self.tlm_manager = TLMManager(self.config)
235
+ print(f"✓ Initialized {self.num_encoders} custom specialists")
236
+ except Exception as e:
237
+ print(f"⚠️ Custom specialists failed: {e}")
238
+ self._create_simulated_encoders()
239
+
240
+ def _create_simulated_encoders(self):
241
+ """Create simulated encoders for demonstration/fallback"""
242
+ print("🎭 Creating simulated encoders...")
243
+ self.stats['simulation_mode'] = True
244
+
245
+ for i in range(self.num_encoders):
246
+ domain_info = self.swarm_config.domain_assignments[i] if self.swarm_config else {
247
+ "domain": f"general_{i}", "specialty": "general"
248
+ }
249
+
250
+ encoder = {
251
+ "id": i,
252
+ "model": None,
253
+ "domain": domain_info["domain"],
254
+ "specialty": domain_info["specialty"],
255
+ "simulated": True
256
+ }
257
+ self.encoders.append(encoder)
258
+
259
+ def _initialize_routing(self):
260
+ """Initialize routing system"""
261
+ try:
262
+ if self.use_pretrained and self.swarm_config:
263
+ # Use content-based router for pretrained models
264
+ router_config = self.swarm_config.get_router_config()
265
+ self.router = ContentBasedRouter(
266
+ num_encoders=self.num_encoders,
267
+ domain_assignments=self.swarm_config.domain_assignments,
268
+ config=router_config
269
+ )
270
+ else:
271
+ # Use topic router for custom models
272
+ domain_configs = DomainConfigs.get_domain_configs(self.num_encoders)
273
+ self.router = TopicRouter(self.config, domain_configs)
274
+ if hasattr(self.router, 'to'):
275
+ self.router.to(self.device)
276
+
277
+ print("🧭 Router initialized")
278
+
279
+ except Exception as e:
280
+ print(f"⚠️ Router initialization failed: {e}")
281
+ # Create basic fallback router
282
+ self.router = self._create_fallback_router()
283
+
284
+ def _initialize_aggregation(self):
285
+ """Initialize aggregation system"""
286
+ try:
287
+ if self.use_pretrained:
288
+ self.aggregator = WeightedAggregator(
289
+ num_encoders=self.num_encoders,
290
+ hidden_dim=768
291
+ )
292
+ else:
293
+ self.aggregator = AttentionAggregator(self.config)
294
+ if hasattr(self.aggregator, 'to'):
295
+ self.aggregator.to(self.device)
296
+
297
+ print("🔄 Aggregator initialized")
298
+
299
+ except Exception as e:
300
+ print(f"⚠️ Aggregator initialization failed: {e}")
301
+ self.aggregator = None
302
+
303
+ def _create_fallback_router(self):
304
+ """Create a simple fallback router"""
305
+ class FallbackRouter:
306
+ def __init__(self, num_encoders):
307
+ self.num_encoders = num_encoders
308
+
309
+ def route(self, text):
310
+ # Fallback: pick a random subset of encoders
311
+ import random
312
+ num_selected = min(3, self.num_encoders)
313
+ return {
314
+ "selected_encoders": random.sample(range(self.num_encoders), num_selected)
315
+ }
316
+
317
+ def chunk_and_route(self, text):
318
+ return [{"specialists": [(0, 1.0)], "chunk": text}]
319
+
320
+ return FallbackRouter(self.num_encoders)
321
+
322
+ def _fallback_to_simulation(self):
323
+ """Complete fallback to simulation mode"""
324
+ print("🎭 Entering full simulation mode")
325
+ self.stats['simulation_mode'] = True
326
+ self._create_simulated_encoders()
327
+ if not self.router:
328
+ self.router = self._create_fallback_router()
329
+
330
+ # =============================================================================
331
+ # MAIN PROCESSING METHODS
332
+ # =============================================================================
333
+
334
+ def generate(self, prompt: str, max_length: int = 100, temperature: float = 0.7,
335
+ show_routing: bool = True) -> Dict:
336
+ """
337
+ Generate response using the swarm (from swarmEngine2 style)
338
+
339
+ Args:
340
+ prompt: Input text prompt
341
+ max_length: Maximum tokens to generate
342
+ temperature: Sampling temperature
343
+ show_routing: Whether to display routing information
344
+
345
+ Returns:
346
+ Dict with response and metadata
347
+ """
348
+ start_time = time.time()
349
+
350
+ try:
351
+ # Route to appropriate encoders
352
+ if hasattr(self.router, 'route'):
353
+ routing_decision = self.router.route(prompt)
354
+ selected_encoders = routing_decision.get("selected_encoders", [0])
355
+ else:
356
+ # Fallback routing
357
+ selected_encoders = [0]
358
+
359
+ if show_routing:
360
+ print(f"🔀 Routing: Selected {len(selected_encoders)} encoders")
361
+ for enc_id in selected_encoders[:3]:
362
+ if enc_id < len(self.encoders):
363
+ domain = self.encoders[enc_id]["domain"]
364
+ print(f" Encoder {enc_id}: {domain}")
365
+
366
+ # Generate response
367
+ if self.stats['simulation_mode'] or any(enc.get("simulated") for enc in self.encoders):
368
+ response = self._simulate_generation(prompt, selected_encoders, max_length)
369
+ else:
370
+ response = self._real_generation(prompt, selected_encoders, max_length, temperature)
371
+
372
+ # Update statistics
373
+ processing_time = time.time() - start_time
374
+ self._update_stats_simple(prompt, selected_encoders, processing_time)
375
+
376
+ return {
377
+ "response": response,
378
+ "processing_time": processing_time,
379
+ "routing_info": {
380
+ "selected_encoders": selected_encoders,
381
+ "num_active": len(selected_encoders),
382
+ "total_encoders": self.num_encoders,
383
+ "domains": [self.encoders[i]["domain"] for i in selected_encoders
384
+ if i < len(self.encoders)]
385
+ },
386
+ "success": True
387
+ }
388
+
389
+ except Exception as e:
390
+ return {
391
+ "response": f"Error generating response: {str(e)}",
392
+ "processing_time": time.time() - start_time,
393
+ "success": False,
394
+ "error": str(e)
395
+ }
396
+
397
+ def process_request(self, text: str, max_new_tokens: int = 100) -> Dict:
398
+ """
399
+ Process request using traditional pipeline (from swarm_engine style)
400
+
401
+ Args:
402
+ text: Input text to process
403
+ max_new_tokens: Maximum tokens to generate
404
+
405
+ Returns:
406
+ Dict with response and metadata
407
+ """
408
+ start_time = time.time()
409
+
410
+ try:
411
+ # Step 1: Preprocess input
412
+ if self.preprocessor:
413
+ clean_text = self.preprocessor.clean_text(text)
414
+ else:
415
+ clean_text = text
416
+
417
+ # Step 2: Route to specialists
418
+ if hasattr(self.router, 'chunk_and_route'):
419
+ routing_results = self.router.chunk_and_route(clean_text)
420
+ else:
421
+ # Fallback for content-based router
422
+ routing_decision = self.router.route(clean_text)
423
+ routing_results = [{"specialists": [(enc_id, 1.0) for enc_id in routing_decision["selected_encoders"]],
424
+ "chunk": clean_text}]
425
+
426
+ # Step 3: Process chunks
427
+ if self.tlm_manager and not self.stats['simulation_mode']:
428
+ specialist_outputs = self.tlm_manager.encode_parallel(routing_results)
429
+ else:
430
+ # Simulate processing
431
+ specialist_outputs = [{"response": f"Processed chunk: {res['chunk'][:50]}..."}
432
+ for res in routing_results]
433
+
434
+ # Step 4: Aggregate results
435
+ if self.aggregator and not self.stats['simulation_mode']:
436
+ response = self.aggregator.generate_response(specialist_outputs, max_new_tokens)
437
+ else:
438
+ # Simple aggregation fallback
439
+ response = " ".join([out.get("response", "") for out in specialist_outputs])
440
+
441
+ # Update stats
442
+ processing_time = time.time() - start_time
443
+ self._update_stats(text, routing_results, processing_time)
444
+
445
+ return {
446
+ 'response': response,
447
+ 'processing_time': processing_time,
448
+ 'chunks_processed': len(routing_results),
449
+ 'specialists_used': self._get_specialists_used(routing_results),
450
+ 'success': True
451
+ }
452
+
453
+ except Exception as e:
454
+ return {
455
+ 'response': f"Error processing request: {str(e)}",
456
+ 'processing_time': time.time() - start_time,
457
+ 'success': False,
458
+ 'error': str(e)
459
+ }
460
+
461
+ # =============================================================================
462
+ # ASYNC AND BATCH PROCESSING
463
+ # =============================================================================
464
+
465
+ async def process_request_async(self, text: str, max_new_tokens: int = 100) -> Dict:
466
+ """Async version of process_request"""
467
+ loop = asyncio.get_running_loop()
468
+
469
+ with ThreadPoolExecutor() as executor:
470
+ result = await loop.run_in_executor(
471
+ executor, self.process_request, text, max_new_tokens
472
+ )
473
+
474
+ return result
475
+
476
+ async def generate_async(self, prompt: str, max_length: int = 100,
477
+ temperature: float = 0.7) -> Dict:
478
+ """Async version of generate"""
479
+ loop = asyncio.get_running_loop()
480
+
481
+ with ThreadPoolExecutor() as executor:
482
+ result = await loop.run_in_executor(
483
+ executor, self.generate, prompt, max_length, temperature, False
484
+ )
485
+
486
+ return result
487
+
488
+ def batch_process(self, texts: List[str], max_new_tokens: int = 100,
489
+ method: str = "process") -> List[Dict]:
490
+ """
491
+ Process multiple texts in batch
492
+
493
+ Args:
494
+ texts: List of input texts
495
+ max_new_tokens: Maximum tokens to generate
496
+ method: "process" or "generate" for processing method
497
+ """
498
+ results = []
499
+
500
+ for text in texts:
501
+ if method == "generate":
502
+ result = self.generate(text, max_new_tokens, show_routing=False)
503
+ else:
504
+ result = self.process_request(text, max_new_tokens)
505
+ results.append(result)
506
+
507
+ return results
508
+
509
+ # =============================================================================
510
+ # GENERATION METHODS
511
+ # =============================================================================
512
+
513
+ def _simulate_generation(self, prompt: str, selected_encoders: List[int], max_length: int) -> str:
514
+ """Simulate generation for demo/fallback purposes"""
515
+ import random
516
+
517
+ # Determine response type based on selected encoder domains
518
+ domains = [self.encoders[i]["domain"] for i in selected_encoders if i < len(self.encoders)]
519
+
520
+ if any("code" in domain.lower() for domain in domains):
521
+ return f"Here's a solution for '{prompt[:30]}...':\n\n```python\ndef solution():\n # Implementation here\n return result\n```"
522
+ elif any("medical" in domain.lower() for domain in domains):
523
+ return f"Regarding '{prompt[:30]}...': This medical topic requires careful consideration. Please consult healthcare professionals."
524
+ elif any("science" in domain.lower() for domain in domains):
525
+ return f"From a scientific perspective on '{prompt[:30]}...': Current research indicates several key factors..."
526
+ else:
527
+ return f"Thank you for asking about '{prompt[:30]}...'. Based on expertise from {len(selected_encoders)} specialized domains, here's a comprehensive response..."
528
+
529
+ def _real_generation(self, prompt: str, selected_encoders: List[int],
530
+ max_length: int, temperature: float) -> str:
531
+ """Real generation using loaded models"""
532
+ if not selected_encoders or selected_encoders[0] >= len(self.encoders):
533
+ return "No valid encoders available for generation."
534
+
535
+ try:
536
+ # Use primary encoder for generation
537
+ primary_encoder = self.encoders[selected_encoders[0]]
538
+
539
+ if primary_encoder.get("simulated") or not primary_encoder["model"]:
540
+ return self._simulate_generation(prompt, selected_encoders, max_length)
541
+
542
+ # Tokenize input
543
+ if hasattr(self.tokenizer, 'encode'):
544
+ inputs = self.tokenizer(prompt, return_tensors="pt")
545
+ else:
546
+ # Fallback tokenization
547
+ return self._simulate_generation(prompt, selected_encoders, max_length)
548
+
549
+ # Generate with model
550
+ with torch.no_grad():
551
+ outputs = primary_encoder["model"].generate(
552
+ **inputs,
553
+ max_length=max_length,
554
+ temperature=temperature,
555
+ do_sample=True,
556
+ pad_token_id=self.tokenizer.eos_token_id if hasattr(self.tokenizer, 'eos_token_id') else 0
557
+ )
558
+
559
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
560
+ # Remove original prompt from response
561
+ response = response[len(prompt):].strip()
562
+
563
+ return response if response else "Generated response was empty."
564
+
565
+ except Exception as e:
566
+ print(f"⚠️ Real generation failed: {e}")
567
+ return self._simulate_generation(prompt, selected_encoders, max_length)
568
+
569
+ # =============================================================================
570
+ # UTILITY METHODS
571
+ # =============================================================================
572
+
573
+ def _get_specialists_used(self, routing_results: List[Dict]) -> List[int]:
574
+ """Extract specialist IDs used in routing"""
575
+ specialists_used = set()
576
+
577
+ for chunk_info in routing_results:
578
+ if 'specialists' in chunk_info:
579
+ for specialist_id, _ in chunk_info['specialists']:
580
+ specialists_used.add(specialist_id)
581
+
582
+ return list(specialists_used)
583
+
584
+ def _update_stats(self, text: str, routing_results: List[Dict], processing_time: float):
585
+ """Update detailed performance statistics"""
586
+ self.stats['total_requests'] += 1
587
+ self.stats['total_tokens_processed'] += len(text.split())
588
+
589
+ # Update average response time
590
+ prev_avg = self.stats['avg_response_time']
591
+ n = self.stats['total_requests']
592
+ self.stats['avg_response_time'] = (prev_avg * (n-1) + processing_time) / n
593
+
594
+ # Update specialist usage
595
+ specialists_used = self._get_specialists_used(routing_results)
596
+ for specialist_id in specialists_used:
597
+ if specialist_id in self.stats['specialist_usage']:
598
+ self.stats['specialist_usage'][specialist_id] += 1
599
+
600
+ def _update_stats_simple(self, text: str, selected_encoders: List[int], processing_time: float):
601
+ """Update simple statistics for generate method"""
602
+ self.stats['total_requests'] += 1
603
+ self.stats['total_tokens_processed'] += len(text.split())
604
+
605
+ # Update average response time
606
+ prev_avg = self.stats['avg_response_time']
607
+ n = self.stats['total_requests']
608
+ self.stats['avg_response_time'] = (prev_avg * (n-1) + processing_time) / n
609
+
610
+ # Update encoder usage
611
+ for enc_id in selected_encoders:
612
+ if enc_id in self.stats['specialist_usage']:
613
+ self.stats['specialist_usage'][enc_id] += 1
614
+
615
+ # =============================================================================
616
+ # SCALING AND MANAGEMENT
617
+ # =============================================================================
618
+
619
+ def scale_up(self, new_tier: str):
620
+ """Scale up to a higher tier"""
621
+ if new_tier not in ["demo", "small", "medium", "large", "full"]:
622
+ raise ValueError(f"Invalid tier: {new_tier}")
623
+
624
+ print(f"🚀 Scaling from {self.tier} to {new_tier}")
625
+
626
+ # Preserve current stats
627
+ old_stats = self.stats.copy()
628
+
629
+ # Reinitialize with new tier
630
+ self.__init__(tier=new_tier, use_pretrained=self.use_pretrained)
631
+
632
+ # Restore relevant stats
633
+ self.stats['total_requests'] = old_stats['total_requests']
634
+ self.stats['total_tokens_processed'] = old_stats['total_tokens_processed']
635
+ self.stats['avg_response_time'] = old_stats['avg_response_time']
636
+
637
+ def get_system_info(self) -> Dict:
638
+ """Get comprehensive system information"""
639
+ info = {
640
+ "tier": self.tier,
641
+ "num_encoders": self.num_encoders,
642
+ "encoder_size": self.encoder_size,
643
+ "use_pretrained": self.use_pretrained,
644
+ "simulation_mode": self.stats['simulation_mode'],
645
+ "device": self.device,
646
+ "domains": list(set(enc["domain"] for enc in self.encoders)),
647
+ }
648
+
649
+ if self.swarm_config:
650
+ info.update({
651
+ "total_parameters": self.swarm_config.config["total_params"],
652
+ "memory_estimate": self.swarm_config.config["memory_estimate"],
653
+ "hardware_recommendation": self.swarm_config.config["hardware"]
654
+ })
655
+
656
+ return info
657
+
658
+ def get_stats(self) -> Dict:
659
+ """Get current performance statistics"""
660
+ return self.stats.copy()
661
+
662
+ def load_models(self, checkpoint_path: str):
663
+ """Load trained models from checkpoint"""
664
+ if not os.path.exists(checkpoint_path):
665
+ print(f"❌ Checkpoint not found: {checkpoint_path}")
666
+ return
667
+
668
+ try:
669
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
670
+
671
+ # Load aggregator
672
+ if self.aggregator and 'aggregator_state' in checkpoint:
673
+ self.aggregator.load_state_dict(checkpoint['aggregator_state'])
674
+
675
+ # Load specialists (if using custom models)
676
+ if self.tlm_manager and 'specialist_states' in checkpoint:
677
+ for specialist_id, state_dict in checkpoint['specialist_states'].items():
678
+ if specialist_id in self.tlm_manager.specialists:
679
+ self.tlm_manager.specialists[specialist_id].model.load_state_dict(state_dict)
680
+
681
+ print(f"✅ Models loaded from {checkpoint_path}")
682
+
683
+ except Exception as e:
684
+ print(f"❌ Error loading models: {e}")
685
+
686
+ def set_eval_mode(self):
687
+ """Set all models to evaluation mode"""
688
+ if self.tlm_manager:
689
+ for specialist in self.tlm_manager.specialists.values():
690
+ if hasattr(specialist, 'model'):
691
+ specialist.model.eval()
692
+
693
+ if self.aggregator and hasattr(self.aggregator, 'eval'):
694
+ self.aggregator.eval()
695
+
696
+ if self.router and hasattr(self.router, 'eval'):
697
+ self.router.eval()
698
+
699
+ # Set pretrained encoders to eval mode
700
+ for encoder in self.encoders:
701
+ if encoder.get("model") and hasattr(encoder["model"], 'eval'):
702
+ encoder["model"].eval()
703
+
704
+ def set_train_mode(self):
705
+ """Set all models to training mode"""
706
+ if self.tlm_manager:
707
+ for specialist in self.tlm_manager.specialists.values():
708
+ if hasattr(specialist, 'model'):
709
+ specialist.model.train()
710
+
711
+ if self.aggregator and hasattr(self.aggregator, 'train'):
712
+ self.aggregator.train()
713
+
714
+ if self.router and hasattr(self.router, 'train'):
715
+ self.router.train()
716
+
717
+
718
+ # =============================================================================
719
+ # FACTORY FUNCTIONS
720
+ # =============================================================================
721
+
722
+ def create_mamba_swarm(tier: str = "auto", use_pretrained: bool = True,
723
+ config_override: Optional[Dict] = None) -> UnifiedMambaSwarm:
724
+ """
725
+ Factory function to create appropriately configured swarm
726
+
727
+ Args:
728
+ tier: Scaling tier or "auto" for auto-detection
729
+ use_pretrained: Whether to use pretrained HuggingFace models
730
+ config_override: Dictionary to override default config
731
+
732
+ Returns:
733
+ Configured UnifiedMambaSwarm instance
734
+ """
735
+ if tier == "auto":
736
+ tier = auto_detect_tier()
737
+
738
+ return UnifiedMambaSwarm(
739
+ tier=tier,
740
+ use_pretrained=use_pretrained,
741
+ config_override=config_override
742
+ )
743
+
744
+
745
+ def create_production_swarm(tier: str = "medium") -> UnifiedMambaSwarm:
746
+ """Create production-ready swarm with optimal settings"""
747
+ return UnifiedMambaSwarm(
748
+ tier=tier,
749
+ use_pretrained=True,
750
+ config_override={
751
+ "batch_size": 32,
752
+ "max_sequence_length": 2048
753
+ }
754
+ )
755
+
756
+
757
+ def create_development_swarm() -> UnifiedMambaSwarm:
758
+ """Create development swarm with simulation fallback"""
759
+ return UnifiedMambaSwarm(
760
+ tier="demo",
761
+ use_pretrained=True,
762
+ config_override={
763
+ "simulation_fallback": True
764
+ }
765
+ )
766
+
767
+
768
+ # =============================================================================
769
+ # MAIN EXECUTION
770
+ # =============================================================================
771
+
772
+ if __name__ == "__main__":
773
+ print("🧪 Testing Unified Mamba Swarm...")
774
+
775
+ # Create swarm instance
776
+ swarm = create_mamba_swarm(tier="demo")
777
+
778
+ # Display system info
779
+ print("\n📊 System Information:")
780
+ info = swarm.get_system_info()
781
+ for key, value in info.items():
782
+ print(f" {key}: {value}")
783
+
784
+ # Test both processing methods
785
+ test_prompts = [
786
+ "Write a Python function to calculate fibonacci numbers",
787
+ "Explain the process of photosynthesis",
788
+ "What are the symptoms of diabetes?"
789
+ ]
790
+
791
+ print("\n🧪 Testing generate method:")
792
+ for prompt in test_prompts[:2]:
793
+ result = swarm.generate(prompt, max_length=150)
794
+ print(f"\nPrompt: {prompt}")
795
+ print(f"Response: {result['response'][:100]}...")
796
+ print(f"Processing time: {result['processing_time']:.3f}s")
797
+ print(f"Routing: {result['routing_info']['domains']}")
798
+
799
+ print("\n🧪 Testing process_request method:")
800
+ result = swarm.process_request(test_prompts[2])
801
+ print(f"Response: {result['response'][:100]}...")
802
+ print(f"Success: {result['success']}")
803
+
804
+ # Test batch processing
805
+ print("\n🧪 Testing batch processing:")
806
+ batch_results = swarm.batch_process(test_prompts, method="generate")
807
+ print(f"Processed {len(batch_results)} requests in batch")
808
+
809
+ # Display final stats
810
+ print("\n📈 Final Statistics:")
811
+ stats = swarm.get_stats()
812
+ for key, value in stats.items():
813
+ if key != 'specialist_usage':
814
+ print(f" {key}: {value}")
815
+
816
+ print("\n✅ Testing complete!")
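
A quick smoke-test sketch for the factory functions above. Assumptions: it runs from the repository root so the core.*, routing.* and utils.* imports at the top of this file resolve; if the pretrained checkpoints cannot be loaded, the swarm drops into its simulation fallback and the script still completes.

    # Hypothetical driver, not part of this commit
    import asyncio
    from system.mambaSwarm import create_mamba_swarm

    swarm = create_mamba_swarm(tier="demo", use_pretrained=True)
    print(swarm.get_system_info())

    # Synchronous path
    out = swarm.generate("Write a haiku about routing", max_length=64, show_routing=False)
    print(out["response"])

    # Async path: process_request is pushed onto a thread pool
    async def main():
        result = await swarm.process_request_async("Explain photosynthesis", max_new_tokens=64)
        print(result["success"], result["response"][:80])

    asyncio.run(main())
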
system/memory_manager.py ADDED
@@ -0,0 +1,306 @@
1
+ """
2
+ Memory Manager for Mamba Swarm
3
+ Handles memory optimization, caching, and distributed memory management
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import gc
9
+ import psutil
10
+ import threading
11
+ from typing import Dict, Any, Optional, List, Tuple
12
+ from dataclasses import dataclass
13
+ from collections import OrderedDict
14
+ import numpy as np
15
+ import logging
+ import time  # used for wall-clock timestamps in monitor_memory_usage
16
+
17
+ @dataclass
18
+ class MemoryStats:
19
+ total_memory: float
20
+ used_memory: float
21
+ free_memory: float
22
+ gpu_memory: float
23
+ gpu_free: float
24
+ cache_size: float
25
+
26
+ class LRUCache:
27
+ """Least Recently Used cache for model states and activations"""
28
+
29
+ def __init__(self, max_size: int = 1000):
30
+ self.max_size = max_size
31
+ self.cache = OrderedDict()
32
+ self.lock = threading.Lock()
33
+
34
+ def get(self, key: str) -> Optional[torch.Tensor]:
35
+ with self.lock:
36
+ if key in self.cache:
37
+ # Move to end (most recently used)
38
+ value = self.cache.pop(key)
39
+ self.cache[key] = value
40
+ return value
41
+ return None
42
+
43
+ def put(self, key: str, value: torch.Tensor):
44
+ with self.lock:
45
+ if key in self.cache:
46
+ self.cache.pop(key)
47
+ elif len(self.cache) >= self.max_size:
48
+ # Remove least recently used
49
+ oldest_key = next(iter(self.cache))
50
+ old_value = self.cache.pop(oldest_key)
51
+ del old_value
52
+
53
+ self.cache[key] = value.clone() if isinstance(value, torch.Tensor) else value
54
+
55
+ def clear(self):
56
+ with self.lock:
57
+ self.cache.clear()
58
+ gc.collect()
59
+
60
+ class GradientAccumulator:
61
+ """Manages gradient accumulation across multiple steps"""
62
+
63
+ def __init__(self, accumulation_steps: int = 8):
64
+ self.accumulation_steps = accumulation_steps
65
+ self.current_step = 0
66
+ self.accumulated_gradients = {}
67
+
68
+ def accumulate(self, model: nn.Module):
69
+ """Accumulate gradients from current backward pass"""
70
+ for name, param in model.named_parameters():
71
+ if param.grad is not None:
72
+ if name not in self.accumulated_gradients:
73
+ self.accumulated_gradients[name] = param.grad.clone()
74
+ else:
75
+ self.accumulated_gradients[name] += param.grad
76
+
77
+ self.current_step += 1
78
+
79
+ def should_update(self) -> bool:
80
+ """Check if we should perform optimizer step"""
81
+ return self.current_step % self.accumulation_steps == 0
82
+
83
+ def get_averaged_gradients(self) -> Dict[str, torch.Tensor]:
84
+ """Get accumulated gradients averaged over accumulation steps"""
85
+ averaged = {}
86
+ for name, grad in self.accumulated_gradients.items():
87
+ averaged[name] = grad / self.accumulation_steps
88
+ return averaged
89
+
90
+ def reset(self):
91
+ """Reset accumulator"""
92
+ self.accumulated_gradients.clear()
93
+ self.current_step = 0
94
+
95
+ class MemoryManager:
96
+ """Comprehensive memory management for Mamba Swarm"""
97
+
98
+ def __init__(self,
99
+ max_cache_size: int = 2000,
100
+ gradient_accumulation_steps: int = 8,
101
+ auto_cleanup: bool = True,
102
+ memory_threshold: float = 0.85):
103
+
104
+ self.logger = logging.getLogger(__name__)
105
+ self.max_cache_size = max_cache_size
106
+ self.gradient_accumulation_steps = gradient_accumulation_steps
107
+ self.auto_cleanup = auto_cleanup
108
+ self.memory_threshold = memory_threshold
109
+
110
+ # Initialize components
111
+ self.activation_cache = LRUCache(max_cache_size)
112
+ self.state_cache = LRUCache(max_cache_size // 2)
113
+ self.gradient_accumulator = GradientAccumulator(gradient_accumulation_steps)
114
+
115
+ # Memory tracking
116
+ self.peak_memory_usage = 0.0
117
+ self.memory_history = []
118
+ self.cleanup_threshold = memory_threshold
119
+
120
+ # Device management
121
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
122
+ self.setup_memory_optimization()
123
+
124
+ def setup_memory_optimization(self):
125
+ """Setup memory optimization settings"""
126
+ if torch.cuda.is_available():
127
+ # Enable memory mapping for large tensors
128
+ torch.backends.cuda.matmul.allow_tf32 = True
129
+ torch.backends.cudnn.allow_tf32 = True
130
+
131
+ # Set memory fraction
132
+ if hasattr(torch.cuda, 'set_per_process_memory_fraction'):
133
+ torch.cuda.set_per_process_memory_fraction(0.9)
134
+
135
+ def get_memory_stats(self) -> MemoryStats:
136
+ """Get current memory statistics"""
137
+ # System memory
138
+ memory = psutil.virtual_memory()
139
+ total_memory = memory.total / (1024**3) # GB
140
+ used_memory = memory.used / (1024**3)
141
+ free_memory = memory.available / (1024**3)
142
+
143
+ # GPU memory
144
+ gpu_memory = 0.0
145
+ gpu_free = 0.0
146
+ if torch.cuda.is_available():
147
+ gpu_memory = torch.cuda.memory_allocated() / (1024**3)
148
+ gpu_free = (torch.cuda.memory_reserved() - torch.cuda.memory_allocated()) / (1024**3)
149
+
150
+ # Cache size estimation
151
+ cache_size = (len(self.activation_cache.cache) + len(self.state_cache.cache)) * 0.001 # Rough estimate
152
+
153
+ stats = MemoryStats(
154
+ total_memory=total_memory,
155
+ used_memory=used_memory,
156
+ free_memory=free_memory,
157
+ gpu_memory=gpu_memory,
158
+ gpu_free=gpu_free,
159
+ cache_size=cache_size
160
+ )
161
+
162
+ # Update peak usage
163
+ current_usage = used_memory + gpu_memory
164
+ if current_usage > self.peak_memory_usage:
165
+ self.peak_memory_usage = current_usage
166
+
167
+ return stats
168
+
169
+ def check_memory_pressure(self) -> bool:
170
+ """Check if system is under memory pressure"""
171
+ stats = self.get_memory_stats()
172
+ memory_usage_ratio = stats.used_memory / stats.total_memory
173
+
174
+ if torch.cuda.is_available():
175
+ gpu_usage_ratio = stats.gpu_memory / (stats.gpu_memory + stats.gpu_free + 1e-6)
176
+ return memory_usage_ratio > self.cleanup_threshold or gpu_usage_ratio > self.cleanup_threshold
177
+
178
+ return memory_usage_ratio > self.cleanup_threshold
179
+
180
+ def cleanup_memory(self, aggressive: bool = False):
181
+ """Perform memory cleanup"""
182
+ if aggressive:
183
+ self.activation_cache.clear()
184
+ self.state_cache.clear()
185
+ self.gradient_accumulator.reset()
186
+
187
+ # Python garbage collection
188
+ gc.collect()
189
+
190
+ # GPU memory cleanup
191
+ if torch.cuda.is_available():
192
+ torch.cuda.empty_cache()
193
+ torch.cuda.synchronize()
194
+
195
+ self.logger.info(f"Memory cleanup completed. Aggressive: {aggressive}")
196
+
197
+ def cache_activation(self, key: str, activation: torch.Tensor):
198
+ """Cache activation with memory pressure check"""
199
+ if self.auto_cleanup and self.check_memory_pressure():
200
+ self.cleanup_memory()
201
+
202
+ self.activation_cache.put(key, activation)
203
+
204
+ def get_cached_activation(self, key: str) -> Optional[torch.Tensor]:
205
+ """Retrieve cached activation"""
206
+ return self.activation_cache.get(key)
207
+
208
+ def cache_hidden_state(self, key: str, state: torch.Tensor):
209
+ """Cache hidden state"""
210
+ self.state_cache.put(key, state)
211
+
212
+ def get_cached_state(self, key: str) -> Optional[torch.Tensor]:
213
+ """Retrieve cached hidden state"""
214
+ return self.state_cache.get(key)
215
+
216
+ def manage_gradient_accumulation(self, model: nn.Module) -> bool:
217
+ """Manage gradient accumulation and return if optimizer step should be taken"""
218
+ self.gradient_accumulator.accumulate(model)
219
+
220
+ if self.gradient_accumulator.should_update():
221
+ # Apply accumulated gradients
222
+ averaged_grads = self.gradient_accumulator.get_averaged_gradients()
223
+
224
+ for name, param in model.named_parameters():
225
+ if name in averaged_grads:
226
+ param.grad = averaged_grads[name]
227
+
228
+ self.gradient_accumulator.reset()
229
+ return True
230
+
231
+ return False
232
+
233
+ def optimize_model_memory(self, model: nn.Module):
234
+ """Optimize model memory usage"""
235
+ # Enable gradient checkpointing for large models
236
+ for module in model.modules():
237
+ if hasattr(module, 'gradient_checkpointing'):
238
+ module.gradient_checkpointing = True
239
+
240
+ # Convert to half precision if possible
241
+ if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7:
242
+ model = model.half()
243
+
244
+ return model
245
+
246
+ def create_memory_efficient_dataloader(self, dataset, batch_size: int, **kwargs):
247
+ """Create memory-efficient dataloader"""
248
+ # Adjust batch size based on available memory
249
+ stats = self.get_memory_stats()
250
+
251
+ if stats.free_memory < 2.0: # Less than 2GB free
252
+ batch_size = max(1, batch_size // 2)
253
+ self.logger.warning(f"Reduced batch size to {batch_size} due to low memory")
254
+
255
+ return torch.utils.data.DataLoader(
256
+ dataset,
257
+ batch_size=batch_size,
258
+ num_workers=min(4, psutil.cpu_count()),
259
+ pin_memory=torch.cuda.is_available(),
260
+ prefetch_factor=2,
261
+ **kwargs
262
+ )
263
+
264
+ def monitor_memory_usage(self):
265
+ """Monitor and log memory usage"""
266
+ stats = self.get_memory_stats()
267
+ self.memory_history.append({
268
+ 'timestamp': time.time(),  # wall-clock time; a CUDA Event is not a usable timestamp
269
+ 'stats': stats
270
+ })
271
+
272
+ # Keep only recent history
273
+ if len(self.memory_history) > 100:
274
+ self.memory_history = self.memory_history[-50:]
275
+
276
+ self.logger.debug(f"Memory - System: {stats.used_memory:.2f}GB/{stats.total_memory:.2f}GB, "
277
+ f"GPU: {stats.gpu_memory:.2f}GB, Cache: {stats.cache_size:.2f}GB")
278
+
279
+ def get_memory_report(self) -> Dict[str, Any]:
280
+ """Generate comprehensive memory report"""
281
+ stats = self.get_memory_stats()
282
+
283
+ return {
284
+ 'current_stats': stats.__dict__,
285
+ 'peak_usage': self.peak_memory_usage,
286
+ 'cache_stats': {
287
+ 'activation_cache_size': len(self.activation_cache.cache),
288
+ 'state_cache_size': len(self.state_cache.cache),
289
+ 'max_cache_size': self.max_cache_size
290
+ },
291
+ 'gradient_accumulation': {
292
+ 'current_step': self.gradient_accumulator.current_step,
293
+ 'accumulation_steps': self.gradient_accumulation_steps,
294
+ 'accumulated_params': len(self.gradient_accumulator.accumulated_gradients)
295
+ },
296
+ 'memory_pressure': self.check_memory_pressure(),
297
+ 'device': str(self.device)
298
+ }
299
+
300
+ def __enter__(self):
301
+ """Context manager entry"""
302
+ return self
303
+
304
+ def __exit__(self, exc_type, exc_val, exc_tb):
305
+ """Context manager exit with cleanup"""
306
+ self.cleanup_memory(aggressive=True)
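
A short sketch of how the manager above is meant to sit inside a training loop. Assumptions: the model and the commented-out optimizer calls are placeholders; only the MemoryManager calls come from the class above.

    # Hypothetical training-loop excerpt, not part of this commit
    import torch.nn as nn
    from system.memory_manager import MemoryManager

    model = nn.Linear(16, 16)  # stand-in model

    with MemoryManager(max_cache_size=512, gradient_accumulation_steps=4) as mm:
        mm.monitor_memory_usage()

        # After loss.backward() on each micro-batch:
        if mm.manage_gradient_accumulation(model):
            pass  # optimizer.step(); optimizer.zero_grad()

        print("under pressure:", mm.get_memory_report()["memory_pressure"])
    # __exit__ runs cleanup_memory(aggressive=True)
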
system/weight_manager.py ADDED
@@ -0,0 +1,168 @@
1
+ # =============================================================================
2
+ # system/weight_manager.py
3
+ # =============================================================================
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Dict, List, Optional
7
+ import os
8
+ from pathlib import Path
9
+
10
+ class WeightManager:
11
+ """Manages hierarchical weight sharing and loading/saving"""
12
+
13
+ def __init__(self, config, tlm_manager):
14
+ self.config = config
15
+ self.tlm_manager = tlm_manager
16
+
17
+ # Track shared weights
18
+ self.shared_embeddings = None
19
+ self.shared_foundation_layers = {}
20
+
21
+ def setup_hierarchical_sharing(self):
22
+ """Setup hierarchical weight sharing between specialists"""
23
+ print("Setting up hierarchical weight sharing...")
24
+
25
+ # Create shared embedding if enabled
26
+ if self.config.shared_embedding:
27
+ self.shared_embeddings = nn.Embedding(
28
+ self.config.vocab_size,
29
+ self.config.d_model
30
+ ).to(self.config.device)
31
+
32
+ # Share embedding across all specialists
33
+ for specialist in self.tlm_manager.specialists.values():
34
+ specialist.model.embedding.token_embedding = self.shared_embeddings
35
+
36
+ # Setup foundation layer sharing
37
+ self._setup_foundation_sharing()
38
+
39
+ print("Hierarchical weight sharing setup complete!")
40
+
41
+ def _setup_foundation_sharing(self):
42
+ """Setup sharing of foundation layers"""
43
+ num_shared_layers = self.config.n_layers // 2
44
+
45
+ # Group specialists by domain similarity
46
+ domain_groups = self._group_specialists_by_domain()
47
+
48
+ for group_name, specialist_ids in domain_groups.items():
49
+ if len(specialist_ids) > 1:
50
+ # Create shared foundation layers for this group
51
+ reference_specialist = self.tlm_manager.specialists[specialist_ids[0]]
52
+ shared_layers = reference_specialist.model.layers[:num_shared_layers]
53
+
54
+ # Share with other specialists in the group
55
+ for specialist_id in specialist_ids[1:]:
56
+ specialist = self.tlm_manager.specialists[specialist_id]
57
+ for i in range(num_shared_layers):
58
+ specialist.model.layers[i] = shared_layers[i]
59
+
60
+ self.shared_foundation_layers[group_name] = shared_layers
61
+
62
+ def _group_specialists_by_domain(self) -> Dict[str, List[int]]:
63
+ """Group specialists by domain for weight sharing"""
64
+ domain_groups = {
65
+ 'stem': [],
66
+ 'programming': [],
67
+ 'language': [],
68
+ 'business': [],
69
+ 'general': []
70
+ }
71
+
72
+ for specialist_id, specialist in self.tlm_manager.specialists.items():
73
+ domain_name = specialist.domain_info['name'].lower()
74
+
75
+ if any(x in domain_name for x in ['math', 'physics', 'chemistry', 'biology']):
76
+ domain_groups['stem'].append(specialist_id)
77
+ elif any(x in domain_name for x in ['python', 'javascript', 'systems']):
78
+ domain_groups['programming'].append(specialist_id)
79
+ elif any(x in domain_name for x in ['writing', 'translation']):
80
+ domain_groups['language'].append(specialist_id)
81
+ elif any(x in domain_name for x in ['business', 'legal']):
82
+ domain_groups['business'].append(specialist_id)
83
+ else:
84
+ domain_groups['general'].append(specialist_id)
85
+
86
+ return {k: v for k, v in domain_groups.items() if len(v) > 1}
87
+
88
+ def save_weights(self, save_path: str):
89
+ """Save all weights with hierarchical structure"""
90
+ save_path = Path(save_path)
91
+ save_path.mkdir(parents=True, exist_ok=True)
92
+
93
+ # Save shared embeddings
94
+ if self.shared_embeddings is not None:
95
+ torch.save(
96
+ self.shared_embeddings.state_dict(),
97
+ save_path / "shared_embeddings.pt"
98
+ )
99
+
100
+ # Save shared foundation layers
101
+ for group_name, layers in self.shared_foundation_layers.items():
102
+ group_state = {}
103
+ for i, layer in enumerate(layers):
104
+ group_state[f"layer_{i}"] = layer.state_dict()
105
+ torch.save(group_state, save_path / f"shared_foundation_{group_name}.pt")
106
+
107
+ # Save specialist-specific weights
108
+ specialists_path = save_path / "specialists"
109
+ specialists_path.mkdir(exist_ok=True)
110
+
111
+ for specialist_id, specialist in self.tlm_manager.specialists.items():
112
+ torch.save(
113
+ specialist.model.state_dict(),
114
+ specialists_path / f"specialist_{specialist_id}.pt"
115
+ )
116
+
117
+ print(f"Weights saved to {save_path}")
118
+
119
+ def load_weights(self, load_path: str):
120
+ """Load weights with hierarchical structure"""
121
+ load_path = Path(load_path)
122
+
123
+ if not load_path.exists():
124
+ raise FileNotFoundError(f"Weight path {load_path} not found")
125
+
126
+ # Load shared embeddings
127
+ embeddings_path = load_path / "shared_embeddings.pt"
128
+ if embeddings_path.exists() and self.shared_embeddings is not None:
129
+ self.shared_embeddings.load_state_dict(torch.load(embeddings_path))
130
+
131
+ # Load shared foundation layers
132
+ for group_name in self.shared_foundation_layers.keys():
133
+ foundation_path = load_path / f"shared_foundation_{group_name}.pt"
134
+ if foundation_path.exists():
135
+ group_state = torch.load(foundation_path)
136
+ for i, layer in enumerate(self.shared_foundation_layers[group_name]):
137
+ if f"layer_{i}" in group_state:
138
+ layer.load_state_dict(group_state[f"layer_{i}"])
139
+
140
+ # Load specialist weights
141
+ specialists_path = load_path / "specialists"
142
+ if specialists_path.exists():
143
+ for specialist_id, specialist in self.tlm_manager.specialists.items():
144
+ specialist_path = specialists_path / f"specialist_{specialist_id}.pt"
145
+ if specialist_path.exists():
146
+ specialist.model.load_state_dict(torch.load(specialist_path))
147
+
148
+ print(f"Weights loaded from {load_path}")
149
+
150
+ def get_memory_usage(self) -> Dict[str, int]:
151
+ """Get memory usage breakdown"""
152
+ usage = {}
153
+
154
+ # Shared embedding memory
155
+ if self.shared_embeddings is not None:
156
+ usage['shared_embeddings'] = sum(
157
+ p.numel() * p.element_size()
158
+ for p in self.shared_embeddings.parameters()
159
+ )
160
+
161
+ # Shared foundation layer memory
162
+ total_foundation = 0
163
+ for layers in self.shared_foundation_layers.values():
164
+ for layer in layers:
165
+ total_foundation += sum(
166
+ p.numel() * p.element_size()
167
+ for p in layer.parameters()
168
+ )
+
+ usage['shared_foundation_layers'] = total_foundation
+
+ return usage
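
A usage sketch for the weight manager above. Assumptions: config carries the shared_embedding, vocab_size, d_model, device and n_layers fields the class reads, and tlm_manager is the TLMManager built by the swarm's custom-specialist path; both names here are placeholders.

    # Hypothetical wiring, not part of this commit
    from system.weight_manager import WeightManager

    wm = WeightManager(config, tlm_manager)
    wm.setup_hierarchical_sharing()           # share embeddings + foundation layers across groups
    print(wm.get_memory_usage())              # byte counts per shared component

    wm.save_weights("checkpoints/swarm_v1")   # hierarchical on-disk layout
    wm.load_weights("checkpoints/swarm_v1")
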