# =============================================================================
# system/mambaSwarm.py - Unified Scalable Mamba Swarm Engine
# =============================================================================
import torch
import time
import os
import asyncio
from typing import Dict, List, Tuple, Optional, Union
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoModelForCausalLM, AutoTokenizer

# Core imports
from core.config import MambaConfig, MambaSwarmConfig, auto_detect_tier
from core.tokenizer import MambaTokenizer
from core.preprocess import TextPreprocessor
from core.model import MambaModel
from core.mamba_swarm_integration import MambaEncoderSwarmModel, create_swarm_from_existing_config

# Routing imports
from routing.router import TopicRouter, ContentBasedRouter
from routing.tlm_manager import TLMManager
from routing.aggregator import AttentionAggregator, WeightedAggregator
from utils.domain_configs import DomainConfigs


class UnifiedMambaSwarm:
    """
    Unified Mamba Swarm Engine combining the best of both architectures:
    - Scalable tier-based system with auto-detection
    - Production-ready async processing and monitoring
    - Graceful fallback to simulation mode
    - Support for both custom and pre-trained models
    """

    def __init__(self, tier: Optional[str] = None,
                 config: Optional[Union[MambaConfig, MambaSwarmConfig]] = None,
                 use_pretrained: bool = True,
                 config_override: Optional[Dict] = None):
        """
        Initialize the unified swarm engine

        Args:
            tier: Scaling tier (demo/small/medium/large/full) or None for auto-detect
            config: Either MambaConfig for custom models or MambaSwarmConfig for scaling
            use_pretrained: Whether to use HuggingFace pretrained models
            config_override: Dictionary to override config settings
        """
        # Auto-detect tier if not specified
        if tier is None:
            tier = auto_detect_tier()
            print(f"Auto-detected tier: {tier}")

        self.tier = tier
        self.use_pretrained = use_pretrained

        # Initialize configuration
        if config is None:
            if use_pretrained:
                self.swarm_config = MambaSwarmConfig(tier=tier)
                if config_override:
                    self.swarm_config.config.update(config_override)
                self.config = self._create_legacy_config()
            else:
                # Use custom config for legacy components
                self.config = MambaConfig()  # Default config
                self.swarm_config = None
        else:
            if isinstance(config, MambaSwarmConfig):
                self.swarm_config = config
                self.config = self._create_legacy_config()
            else:
                self.config = config
                self.swarm_config = None

        self.device = getattr(self.config, 'device', 'cuda' if torch.cuda.is_available() else 'cpu')

        # System properties
        if self.swarm_config:
            self.num_encoders = self.swarm_config.config["num_encoders"]
            self.encoder_size = self.swarm_config.config["encoder_size"]
        else:
            self.num_encoders = getattr(self.config, 'num_specialists', 5)
            self.encoder_size = "130M"

        # Initialize components
        self.encoders = []
        self.tokenizer = None
        self.preprocessor = None
        self.router = None
        self.aggregator = None
        self.tlm_manager = None

        # Performance tracking
        self.stats = {
            'total_requests': 0,
            'total_tokens_processed': 0,
            'avg_response_time': 0.0,
            'specialist_usage': {i: 0 for i in range(self.num_encoders)},
            'simulation_mode': False,
            'model_load_errors': 0
        }

        # Initialize system
        self._initialize_system()

        print(f"✅ Unified Mamba Swarm initialized: {self.tier} tier, {self.num_encoders} encoders")
self.swarm_config.config["num_encoders"] legacy_config.device = 'cuda' if torch.cuda.is_available() else 'cpu' return legacy_config def _initialize_system(self): """Initialize the complete swarm system""" try: # Initialize tokenizer and preprocessor self._initialize_tokenizer() self._initialize_preprocessor() # Initialize encoders/specialists if self.use_pretrained: self._initialize_pretrained_encoders() else: self._initialize_custom_specialists() # Initialize routing system self._initialize_routing() # Initialize aggregation system self._initialize_aggregation() print(f"๐Ÿš€ System initialization complete!") except Exception as e: print(f"โš ๏ธ Error during initialization: {e}") self._fallback_to_simulation() def _initialize_tokenizer(self): """Initialize tokenizer based on mode""" if self.use_pretrained: base_model_name = self._get_base_model_name() try: self.tokenizer = AutoTokenizer.from_pretrained(base_model_name) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token print(f"๐Ÿ“ Loaded HuggingFace tokenizer: {base_model_name}") except: print("โš ๏ธ HuggingFace tokenizer failed, using custom tokenizer") self.tokenizer = MambaTokenizer(self.config) else: self.tokenizer = MambaTokenizer(self.config) def _initialize_preprocessor(self): """Initialize text preprocessor""" self.preprocessor = TextPreprocessor(self.config) def _get_base_model_name(self): """Get the appropriate base model for current tier""" model_mapping = { "130M": "state-spaces/mamba-130m", "370M": "state-spaces/mamba-370m", "790M": "state-spaces/mamba-790m", "1.4B": "state-spaces/mamba-1.4b", "2.8B": "state-spaces/mamba-2.8b" } return model_mapping.get(self.encoder_size, "state-spaces/mamba-130m") def _initialize_pretrained_encoders(self): """Initialize pretrained encoder swarm""" print(f"๐Ÿ”„ Loading {self.num_encoders} pretrained encoders...") base_model_name = self._get_base_model_name() try: # Load base model base_model = AutoModelForCausalLM.from_pretrained( base_model_name, torch_dtype=torch.float16 if self.num_encoders > 5 else torch.float32, device_map="auto" if torch.cuda.is_available() else "cpu" ) # Create encoder instances for i in range(self.num_encoders): domain_info = self.swarm_config.domain_assignments[i] if self.swarm_config else { "domain": f"general_{i}", "specialty": "general" } if self.tier == "demo" or self.num_encoders <= 5: # Share model instance for smaller configurations encoder = { "id": i, "model": base_model, "domain": domain_info["domain"], "specialty": domain_info["specialty"], "shared": True } else: # Separate instances for larger configurations encoder = { "id": i, "model": AutoModelForCausalLM.from_pretrained( base_model_name, torch_dtype=torch.float16, device_map="auto" ), "domain": domain_info["domain"], "specialty": domain_info["specialty"], "shared": False } self.encoders.append(encoder) print(f" โœ“ Encoder {i}: {encoder['domain']} specialist") except Exception as e: print(f"โŒ Failed to load pretrained models: {e}") self.stats['model_load_errors'] += 1 self._create_simulated_encoders() def _initialize_custom_specialists(self): """Initialize custom TLM specialists or native Mamba swarm""" try: if hasattr(self, 'use_native_swarm') and self.use_native_swarm: # Use the native Mamba swarm integration self.native_swarm_model = create_swarm_from_existing_config( self.config, num_encoders=self.num_encoders ) print(f"โœ“ Initialized native Mamba swarm with {self.num_encoders} encoders") else: # Use TLM manager (legacy approach) self.tlm_manager = 
    def _create_simulated_encoders(self):
        """Create simulated encoders for demonstration/fallback"""
        print("🎭 Creating simulated encoders...")
        self.stats['simulation_mode'] = True

        for i in range(self.num_encoders):
            domain_info = self.swarm_config.domain_assignments[i] if self.swarm_config else {
                "domain": f"general_{i}",
                "specialty": "general"
            }

            encoder = {
                "id": i,
                "model": None,
                "domain": domain_info["domain"],
                "specialty": domain_info["specialty"],
                "simulated": True
            }
            self.encoders.append(encoder)

    def _initialize_routing(self):
        """Initialize routing system"""
        try:
            if self.use_pretrained and self.swarm_config:
                # Use content-based router for pretrained models
                router_config = self.swarm_config.get_router_config()
                self.router = ContentBasedRouter(
                    num_encoders=self.num_encoders,
                    domain_assignments=self.swarm_config.domain_assignments,
                    config=router_config
                )
            else:
                # Use topic router for custom models
                domain_configs = DomainConfigs.get_domain_configs(self.num_encoders)
                self.router = TopicRouter(self.config, domain_configs)

            if hasattr(self.router, 'to'):
                self.router.to(self.device)

            print("🧭 Router initialized")

        except Exception as e:
            print(f"⚠️ Router initialization failed: {e}")
            # Create basic fallback router
            self.router = self._create_fallback_router()

    def _initialize_aggregation(self):
        """Initialize aggregation system"""
        try:
            if self.use_pretrained:
                self.aggregator = WeightedAggregator(
                    num_encoders=self.num_encoders,
                    hidden_dim=768
                )
            else:
                self.aggregator = AttentionAggregator(self.config)

            if hasattr(self.aggregator, 'to'):
                self.aggregator.to(self.device)

            print("🔄 Aggregator initialized")

        except Exception as e:
            print(f"⚠️ Aggregator initialization failed: {e}")
            self.aggregator = None

    def _create_fallback_router(self):
        """Create a simple fallback router"""
        class FallbackRouter:
            def __init__(self, num_encoders):
                self.num_encoders = num_encoders

            def route(self, text):
                # Simple random routing over the available encoders
                import random
                num_selected = min(3, self.num_encoders)
                return {
                    "selected_encoders": random.sample(range(self.num_encoders), num_selected)
                }

            def chunk_and_route(self, text):
                return [{"specialists": [(0, 1.0)], "chunk": text}]

        return FallbackRouter(self.num_encoders)

    def _fallback_to_simulation(self):
        """Complete fallback to simulation mode"""
        print("🎭 Entering full simulation mode")
        self.stats['simulation_mode'] = True
        self._create_simulated_encoders()
        if not self.router:
            self.router = self._create_fallback_router()

    # =========================================================================
    # MAIN PROCESSING METHODS
    # =========================================================================
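    # For orientation, the router contract that generate() and process_request()
    # below assume: route() returns a dict shaped like the fallback router's
    # output, e.g.
    #
    #     {"selected_encoders": [0, 3, 7]}
    #
    # and chunk_and_route() returns a list of dicts shaped like
    #
    #     [{"specialists": [(0, 1.0)], "chunk": "..."}]
    #
    # (shapes taken from the fallback router and the call sites in this file).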
    def generate(self, prompt: str, max_length: int = 100, temperature: float = 0.7,
                 show_routing: bool = True) -> Dict:
        """
        Generate response using the swarm (from swarmEngine2 style)

        Args:
            prompt: Input text prompt
            max_length: Maximum tokens to generate
            temperature: Sampling temperature
            show_routing: Whether to display routing information

        Returns:
            Dict with response and metadata
        """
        start_time = time.time()

        try:
            # Route to appropriate encoders
            if hasattr(self.router, 'route'):
                routing_decision = self.router.route(prompt)
                selected_encoders = routing_decision.get("selected_encoders", [0])
            else:
                # Fallback routing
                selected_encoders = [0]

            if show_routing:
                print(f"🔀 Routing: Selected {len(selected_encoders)} encoders")
                for enc_id in selected_encoders[:3]:
                    if enc_id < len(self.encoders):
                        domain = self.encoders[enc_id]["domain"]
                        print(f"  Encoder {enc_id}: {domain}")

            # Generate response
            if self.stats['simulation_mode'] or any(enc.get("simulated") for enc in self.encoders):
                response = self._simulate_generation(prompt, selected_encoders, max_length)
            else:
                response = self._real_generation(prompt, selected_encoders, max_length, temperature)

            # Update statistics
            processing_time = time.time() - start_time
            self._update_stats_simple(prompt, selected_encoders, processing_time)

            return {
                "response": response,
                "processing_time": processing_time,
                "routing_info": {
                    "selected_encoders": selected_encoders,
                    "num_active": len(selected_encoders),
                    "total_encoders": self.num_encoders,
                    "domains": [self.encoders[i]["domain"] for i in selected_encoders
                                if i < len(self.encoders)]
                },
                "success": True
            }

        except Exception as e:
            return {
                "response": f"Error generating response: {str(e)}",
                "processing_time": time.time() - start_time,
                "success": False,
                "error": str(e)
            }

    def process_request(self, text: str, max_new_tokens: int = 100) -> Dict:
        """
        Process request using traditional pipeline (from swarm_engine style)

        Args:
            text: Input text to process
            max_new_tokens: Maximum tokens to generate

        Returns:
            Dict with response and metadata
        """
        start_time = time.time()

        try:
            # Step 1: Preprocess input
            if self.preprocessor:
                clean_text = self.preprocessor.clean_text(text)
            else:
                clean_text = text

            # Step 2: Route to specialists
            if hasattr(self.router, 'chunk_and_route'):
                routing_results = self.router.chunk_and_route(clean_text)
            else:
                # Fallback for content-based router
                routing_decision = self.router.route(clean_text)
                routing_results = [{
                    "specialists": [(enc_id, 1.0) for enc_id in routing_decision["selected_encoders"]],
                    "chunk": clean_text
                }]

            # Step 3: Process chunks
            if self.tlm_manager and not self.stats['simulation_mode']:
                specialist_outputs = self.tlm_manager.encode_parallel(routing_results)
            else:
                # Simulate processing
                specialist_outputs = [{"response": f"Processed chunk: {res['chunk'][:50]}..."}
                                      for res in routing_results]

            # Step 4: Aggregate results
            if self.aggregator and not self.stats['simulation_mode']:
                response = self.aggregator.generate_response(specialist_outputs, max_new_tokens)
            else:
                # Simple aggregation fallback
                response = " ".join([out.get("response", "") for out in specialist_outputs])

            # Update stats
            processing_time = time.time() - start_time
            self._update_stats(text, routing_results, processing_time)

            return {
                'response': response,
                'processing_time': processing_time,
                'chunks_processed': len(routing_results),
                'specialists_used': self._get_specialists_used(routing_results),
                'success': True
            }

        except Exception as e:
            return {
                'response': f"Error processing request: {str(e)}",
                'processing_time': time.time() - start_time,
                'success': False,
                'error': str(e)
            }

    # =========================================================================
    # ASYNC AND BATCH PROCESSING
    # =========================================================================

    async def process_request_async(self, text: str, max_new_tokens: int = 100) -> Dict:
        """Async version of process_request"""
        loop = asyncio.get_event_loop()
        with ThreadPoolExecutor() as executor:
            result = await loop.run_in_executor(
                executor, self.process_request, text, max_new_tokens
            )
        return result

    async def generate_async(self, prompt: str, max_length: int = 100,
                             temperature: float = 0.7) -> Dict:
        """Async version of generate"""
        loop = asyncio.get_event_loop()
        with ThreadPoolExecutor() as executor:
            result = await loop.run_in_executor(
                executor, self.generate, prompt, max_length, temperature, False
            )
        return result
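    # A minimal sketch (not executed here) of how the async helpers above might be
    # driven concurrently from caller code; the tier and prompts are placeholders:
    #
    #     async def _demo():
    #         swarm = UnifiedMambaSwarm(tier="demo")
    #         return await asyncio.gather(
    #             swarm.generate_async("Explain photosynthesis"),
    #             swarm.process_request_async("Summarize state-space models"),
    #         )
    #
    #     results = asyncio.run(_demo())
    #
    # Each entry in results is the same Dict the synchronous methods return.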
    def batch_process(self, texts: List[str], max_new_tokens: int = 100,
                      method: str = "process") -> List[Dict]:
        """
        Process multiple texts in batch

        Args:
            texts: List of input texts
            max_new_tokens: Maximum tokens to generate
            method: "process" or "generate" for processing method
        """
        results = []
        for text in texts:
            if method == "generate":
                result = self.generate(text, max_new_tokens, show_routing=False)
            else:
                result = self.process_request(text, max_new_tokens)
            results.append(result)
        return results

    # =========================================================================
    # GENERATION METHODS
    # =========================================================================

    def _simulate_generation(self, prompt: str, selected_encoders: List[int],
                             max_length: int) -> str:
        """Simulate generation for demo/fallback purposes"""
        import random

        # Determine response type based on selected encoder domains
        domains = [self.encoders[i]["domain"] for i in selected_encoders if i < len(self.encoders)]

        if any("code" in domain.lower() for domain in domains):
            return (f"Here's a solution for '{prompt[:30]}...':\n\n"
                    "```python\ndef solution():\n    # Implementation here\n    return result\n```")
        elif any("medical" in domain.lower() for domain in domains):
            return (f"Regarding '{prompt[:30]}...': This medical topic requires careful "
                    "consideration. Please consult healthcare professionals.")
        elif any("science" in domain.lower() for domain in domains):
            return (f"From a scientific perspective on '{prompt[:30]}...': "
                    "Current research indicates several key factors...")
        else:
            return (f"Thank you for asking about '{prompt[:30]}...'. Based on expertise from "
                    f"{len(selected_encoders)} specialized domains, here's a comprehensive response...")

    def _real_generation(self, prompt: str, selected_encoders: List[int],
                         max_length: int, temperature: float) -> str:
        """Real generation using loaded models"""
        if not selected_encoders or selected_encoders[0] >= len(self.encoders):
            return "No valid encoders available for generation."

        try:
            # Use primary encoder for generation
            primary_encoder = self.encoders[selected_encoders[0]]

            if primary_encoder.get("simulated") or not primary_encoder["model"]:
                return self._simulate_generation(prompt, selected_encoders, max_length)

            # Tokenize input
            if hasattr(self.tokenizer, 'encode'):
                inputs = self.tokenizer(prompt, return_tensors="pt")
            else:
                # Fallback tokenization
                return self._simulate_generation(prompt, selected_encoders, max_length)

            # Move inputs onto the same device as the primary encoder's weights
            model_device = next(primary_encoder["model"].parameters()).device
            inputs = {k: v.to(model_device) for k, v in inputs.items()}

            # Generate with model
            with torch.no_grad():
                outputs = primary_encoder["model"].generate(
                    **inputs,
                    max_length=max_length,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id if hasattr(self.tokenizer, 'eos_token_id') else 0
                )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Remove original prompt from response
            response = response[len(prompt):].strip()

            return response if response else "Generated response was empty."

        except Exception as e:
            print(f"⚠️ Real generation failed: {e}")
            return self._simulate_generation(prompt, selected_encoders, max_length)
    # =========================================================================
    # UTILITY METHODS
    # =========================================================================

    def _get_specialists_used(self, routing_results: List[Dict]) -> List[int]:
        """Extract specialist IDs used in routing"""
        specialists_used = set()
        for chunk_info in routing_results:
            if 'specialists' in chunk_info:
                for specialist_id, _ in chunk_info['specialists']:
                    specialists_used.add(specialist_id)
        return list(specialists_used)

    def _update_stats(self, text: str, routing_results: List[Dict], processing_time: float):
        """Update detailed performance statistics"""
        self.stats['total_requests'] += 1
        self.stats['total_tokens_processed'] += len(text.split())

        # Update average response time
        prev_avg = self.stats['avg_response_time']
        n = self.stats['total_requests']
        self.stats['avg_response_time'] = (prev_avg * (n - 1) + processing_time) / n

        # Update specialist usage
        specialists_used = self._get_specialists_used(routing_results)
        for specialist_id in specialists_used:
            if specialist_id in self.stats['specialist_usage']:
                self.stats['specialist_usage'][specialist_id] += 1

    def _update_stats_simple(self, text: str, selected_encoders: List[int], processing_time: float):
        """Update simple statistics for generate method"""
        self.stats['total_requests'] += 1
        self.stats['total_tokens_processed'] += len(text.split())

        # Update average response time
        prev_avg = self.stats['avg_response_time']
        n = self.stats['total_requests']
        self.stats['avg_response_time'] = (prev_avg * (n - 1) + processing_time) / n

        # Update encoder usage
        for enc_id in selected_encoders:
            if enc_id in self.stats['specialist_usage']:
                self.stats['specialist_usage'][enc_id] += 1

    # =========================================================================
    # SCALING AND MANAGEMENT
    # =========================================================================

    def scale_up(self, new_tier: str):
        """Scale up to a higher tier"""
        if new_tier not in ["demo", "small", "medium", "large", "full"]:
            raise ValueError(f"Invalid tier: {new_tier}")

        print(f"🚀 Scaling from {self.tier} to {new_tier}")

        # Preserve current stats
        old_stats = self.stats.copy()

        # Reinitialize with new tier
        self.__init__(tier=new_tier, use_pretrained=self.use_pretrained)

        # Restore relevant stats
        self.stats['total_requests'] = old_stats['total_requests']
        self.stats['total_tokens_processed'] = old_stats['total_tokens_processed']
        self.stats['avg_response_time'] = old_stats['avg_response_time']

    def get_system_info(self) -> Dict:
        """Get comprehensive system information"""
        info = {
            "tier": self.tier,
            "num_encoders": self.num_encoders,
            "encoder_size": self.encoder_size,
            "use_pretrained": self.use_pretrained,
            "simulation_mode": self.stats['simulation_mode'],
            "device": self.device,
            "domains": list(set(enc["domain"] for enc in self.encoders)),
        }

        if self.swarm_config:
            info.update({
                "total_parameters": self.swarm_config.config["total_params"],
                "memory_estimate": self.swarm_config.config["memory_estimate"],
                "hardware_recommendation": self.swarm_config.config["hardware"]
            })

        return info

    def get_stats(self) -> Dict:
        """Get current performance statistics"""
        return self.stats.copy()

    def load_models(self, checkpoint_path: str):
        """Load trained models from checkpoint"""
        if not os.path.exists(checkpoint_path):
            print(f"❌ Checkpoint not found: {checkpoint_path}")
            return

        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

            # Load aggregator
            if self.aggregator and 'aggregator_state' in checkpoint:
                self.aggregator.load_state_dict(checkpoint['aggregator_state'])

            # Load specialists (if using custom models)
            if self.tlm_manager and 'specialist_states' in checkpoint:
                for specialist_id, state_dict in checkpoint['specialist_states'].items():
                    if specialist_id in self.tlm_manager.specialists:
                        self.tlm_manager.specialists[specialist_id].model.load_state_dict(state_dict)

            print(f"✅ Models loaded from {checkpoint_path}")

        except Exception as e:
            print(f"❌ Error loading models: {e}")
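    # For reference, a sketch of the checkpoint layout load_models() above expects.
    # Producing such a file happens outside this module, so the save call below is
    # illustrative only; the key names come from load_models() itself and the
    # variable names are placeholders:
    #
    #     torch.save({
    #         "aggregator_state": aggregator.state_dict(),
    #         "specialist_states": {sid: spec.model.state_dict()
    #                               for sid, spec in tlm_manager.specialists.items()},
    #     }, "swarm_checkpoint.pt")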
    def set_eval_mode(self):
        """Set all models to evaluation mode"""
        if self.tlm_manager:
            for specialist in self.tlm_manager.specialists.values():
                if hasattr(specialist, 'model'):
                    specialist.model.eval()

        if self.aggregator and hasattr(self.aggregator, 'eval'):
            self.aggregator.eval()

        if self.router and hasattr(self.router, 'eval'):
            self.router.eval()

        # Set pretrained encoders to eval mode
        for encoder in self.encoders:
            if encoder.get("model") and hasattr(encoder["model"], 'eval'):
                encoder["model"].eval()

    def set_train_mode(self):
        """Set all models to training mode"""
        if self.tlm_manager:
            for specialist in self.tlm_manager.specialists.values():
                if hasattr(specialist, 'model'):
                    specialist.model.train()

        if self.aggregator and hasattr(self.aggregator, 'train'):
            self.aggregator.train()

        if self.router and hasattr(self.router, 'train'):
            self.router.train()


# =============================================================================
# FACTORY FUNCTIONS
# =============================================================================

def create_mamba_swarm(tier: str = "auto", use_pretrained: bool = True,
                       config_override: Optional[Dict] = None) -> UnifiedMambaSwarm:
    """
    Factory function to create an appropriately configured swarm

    Args:
        tier: Scaling tier or "auto" for auto-detection
        use_pretrained: Whether to use pretrained HuggingFace models
        config_override: Dictionary to override default config

    Returns:
        Configured UnifiedMambaSwarm instance
    """
    if tier == "auto":
        tier = auto_detect_tier()

    return UnifiedMambaSwarm(
        tier=tier,
        use_pretrained=use_pretrained,
        config_override=config_override
    )


def create_production_swarm(tier: str = "medium") -> UnifiedMambaSwarm:
    """Create production-ready swarm with optimal settings"""
    return UnifiedMambaSwarm(
        tier=tier,
        use_pretrained=True,
        config_override={
            "batch_size": 32,
            "max_sequence_length": 2048
        }
    )


def create_development_swarm() -> UnifiedMambaSwarm:
    """Create development swarm with simulation fallback"""
    return UnifiedMambaSwarm(
        tier="demo",
        use_pretrained=True,
        config_override={
            "simulation_fallback": True
        }
    )
# =============================================================================
# MAIN EXECUTION
# =============================================================================

if __name__ == "__main__":
    print("🧪 Testing Unified Mamba Swarm...")

    # Create swarm instance
    swarm = create_mamba_swarm(tier="demo")

    # Display system info
    print("\n📊 System Information:")
    info = swarm.get_system_info()
    for key, value in info.items():
        print(f"  {key}: {value}")

    # Test both processing methods
    test_prompts = [
        "Write a Python function to calculate fibonacci numbers",
        "Explain the process of photosynthesis",
        "What are the symptoms of diabetes?"
    ]

    print("\n🧪 Testing generate method:")
    for prompt in test_prompts[:2]:
        result = swarm.generate(prompt, max_length=150)
        print(f"\nPrompt: {prompt}")
        print(f"Response: {result['response'][:100]}...")
        print(f"Processing time: {result['processing_time']:.3f}s")
        print(f"Routing: {result['routing_info']['domains']}")

    print("\n🧪 Testing process_request method:")
    result = swarm.process_request(test_prompts[2])
    print(f"Response: {result['response'][:100]}...")
    print(f"Success: {result['success']}")

    # Test batch processing
    print("\n🧪 Testing batch processing:")
    batch_results = swarm.batch_process(test_prompts, method="generate")
    print(f"Processed {len(batch_results)} requests in batch")

    # Display final stats
    print("\n📈 Final Statistics:")
    stats = swarm.get_stats()
    for key, value in stats.items():
        if key != 'specialist_usage':
            print(f"  {key}: {value}")

    print("\n✅ Testing complete!")
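# A further usage sketch (not executed by this script): a production-style caller
# would typically combine the factory helpers with batch processing. The tier,
# prompts, and token budget below are placeholders:
#
#     prod_swarm = create_production_swarm(tier="large")
#     prod_swarm.set_eval_mode()
#     answers = prod_swarm.batch_process(
#         ["prompt one", "prompt two"], max_new_tokens=200, method="generate"
#     )
#     print(prod_swarm.get_stats()["avg_response_time"])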