# =============================================================================
# system/mambaSwarm.py - Unified Scalable Mamba Swarm Engine
# =============================================================================
import torch
import time
import os
import asyncio
from typing import Dict, List, Tuple, Optional, Union
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoModelForCausalLM, AutoTokenizer

# Core imports
from core.config import MambaConfig, MambaSwarmConfig, auto_detect_tier
from core.tokenizer import MambaTokenizer
from core.preprocess import TextPreprocessor
from core.model import MambaModel
from core.mamba_swarm_integration import MambaEncoderSwarmModel, create_swarm_from_existing_config

# Routing imports
from routing.router import TopicRouter, ContentBasedRouter
from routing.tlm_manager import TLMManager
from routing.aggregator import AttentionAggregator, WeightedAggregator
from utils.domain_configs import DomainConfigs


class UnifiedMambaSwarm:
    """
    Unified Mamba Swarm Engine combining the best of both architectures:
    - Scalable tier-based system with auto-detection
    - Production-ready async processing and monitoring
    - Graceful fallback to simulation mode
    - Support for both custom and pre-trained models
    """

    def __init__(self,
                 tier: Optional[str] = None,
                 config: Optional[Union[MambaConfig, MambaSwarmConfig]] = None,
                 use_pretrained: bool = True,
                 config_override: Optional[Dict] = None):
        """
        Initialize the unified swarm engine.

        Args:
            tier: Scaling tier (demo/small/medium/large/full) or None for auto-detect
            config: Either MambaConfig for custom models or MambaSwarmConfig for scaling
            use_pretrained: Whether to use HuggingFace pretrained models
            config_override: Dictionary to override config settings
        """
        # Auto-detect tier if not specified
        if tier is None:
            tier = auto_detect_tier()
            print(f"Auto-detected tier: {tier}")

        self.tier = tier
        self.use_pretrained = use_pretrained

        # Initialize configuration
        if config is None:
            if use_pretrained:
                self.swarm_config = MambaSwarmConfig(tier=tier)
                if config_override:
                    self.swarm_config.config.update(config_override)
                self.config = self._create_legacy_config()
            else:
                # Use custom config for legacy components
                self.config = MambaConfig()  # Default config
                self.swarm_config = None
        else:
            if isinstance(config, MambaSwarmConfig):
                self.swarm_config = config
                self.config = self._create_legacy_config()
            else:
                self.config = config
                self.swarm_config = None

        self.device = getattr(self.config, 'device', 'cuda' if torch.cuda.is_available() else 'cpu')

        # System properties
        if self.swarm_config:
            self.num_encoders = self.swarm_config.config["num_encoders"]
            self.encoder_size = self.swarm_config.config["encoder_size"]
        else:
            self.num_encoders = getattr(self.config, 'num_specialists', 5)
            self.encoder_size = "130M"

        # Initialize components
        self.encoders = []
        self.tokenizer = None
        self.preprocessor = None
        self.router = None
        self.aggregator = None
        self.tlm_manager = None

        # Performance tracking
        self.stats = {
            'total_requests': 0,
            'total_tokens_processed': 0,
            'avg_response_time': 0.0,
            'specialist_usage': {i: 0 for i in range(self.num_encoders)},
            'simulation_mode': False,
            'model_load_errors': 0
        }

        # Initialize system
        self._initialize_system()
        print(f"✅ Unified Mamba Swarm initialized: {self.tier} tier, {self.num_encoders} encoders")

    def _create_legacy_config(self) -> MambaConfig:
        """Create legacy MambaConfig from SwarmConfig for compatibility"""
        legacy_config = MambaConfig()
        if self.swarm_config:
            legacy_config.num_specialists = self.swarm_config.config["num_encoders"]
        legacy_config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return legacy_config

    def _initialize_system(self):
        """Initialize the complete swarm system"""
        try:
            # Initialize tokenizer and preprocessor
            self._initialize_tokenizer()
            self._initialize_preprocessor()

            # Initialize encoders/specialists
            if self.use_pretrained:
                self._initialize_pretrained_encoders()
            else:
                self._initialize_custom_specialists()

            # Initialize routing system
            self._initialize_routing()

            # Initialize aggregation system
            self._initialize_aggregation()

            print("🚀 System initialization complete!")
        except Exception as e:
            print(f"⚠️ Error during initialization: {e}")
            self._fallback_to_simulation()

    def _initialize_tokenizer(self):
        """Initialize tokenizer based on mode"""
        if self.use_pretrained:
            base_model_name = self._get_base_model_name()
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                print(f"📝 Loaded HuggingFace tokenizer: {base_model_name}")
            except Exception as e:
                print(f"⚠️ HuggingFace tokenizer failed ({e}), using custom tokenizer")
                self.tokenizer = MambaTokenizer(self.config)
        else:
            self.tokenizer = MambaTokenizer(self.config)

    def _initialize_preprocessor(self):
        """Initialize text preprocessor"""
        self.preprocessor = TextPreprocessor(self.config)

    def _get_base_model_name(self):
        """Get the appropriate base model for the current tier"""
        model_mapping = {
            "130M": "state-spaces/mamba-130m",
            "370M": "state-spaces/mamba-370m",
            "790M": "state-spaces/mamba-790m",
            "1.4B": "state-spaces/mamba-1.4b",
            "2.8B": "state-spaces/mamba-2.8b"
        }
        return model_mapping.get(self.encoder_size, "state-spaces/mamba-130m")

    def _initialize_pretrained_encoders(self):
        """Initialize pretrained encoder swarm"""
        print(f"🔄 Loading {self.num_encoders} pretrained encoders...")
        base_model_name = self._get_base_model_name()
        try:
            # Load base model
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16 if self.num_encoders > 5 else torch.float32,
                device_map="auto" if torch.cuda.is_available() else "cpu"
            )

            # Create encoder instances
            for i in range(self.num_encoders):
                domain_info = self.swarm_config.domain_assignments[i] if self.swarm_config else {
                    "domain": f"general_{i}", "specialty": "general"
                }

                if self.tier == "demo" or self.num_encoders <= 5:
                    # Share model instance for smaller configurations
                    encoder = {
                        "id": i,
                        "model": base_model,
                        "domain": domain_info["domain"],
                        "specialty": domain_info["specialty"],
                        "shared": True
                    }
                else:
                    # Separate instances for larger configurations
                    encoder = {
                        "id": i,
                        "model": AutoModelForCausalLM.from_pretrained(
                            base_model_name,
                            torch_dtype=torch.float16,
                            device_map="auto"
                        ),
                        "domain": domain_info["domain"],
                        "specialty": domain_info["specialty"],
                        "shared": False
                    }

                self.encoders.append(encoder)
                print(f"  ✓ Encoder {i}: {encoder['domain']} specialist")
        except Exception as e:
            print(f"❌ Failed to load pretrained models: {e}")
            self.stats['model_load_errors'] += 1
            self._create_simulated_encoders()
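
    # Note on memory (a descriptive note, inferred from the branches above):
    # for the demo tier or <= 5 encoders, every encoder entry points at the one
    # shared base model, so "specialization" comes from routing metadata rather
    # than separate weights; larger tiers load a dedicated fp16 copy per
    # encoder, which scales GPU memory roughly linearly with num_encoders.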

    def _initialize_custom_specialists(self):
        """Initialize custom TLM specialists or native Mamba swarm"""
        try:
            if hasattr(self, 'use_native_swarm') and self.use_native_swarm:
                # Use the native Mamba swarm integration
                self.native_swarm_model = create_swarm_from_existing_config(
                    self.config, num_encoders=self.num_encoders
                )
                print(f"✓ Initialized native Mamba swarm with {self.num_encoders} encoders")
            else:
                # Use TLM manager (legacy approach)
                self.tlm_manager = TLMManager(self.config)
                print(f"✓ Initialized {self.num_encoders} custom specialists")
        except Exception as e:
            print(f"⚠️ Custom specialists failed: {e}")
            self._create_simulated_encoders()

    def _create_simulated_encoders(self):
        """Create simulated encoders for demonstration/fallback"""
        print("🎭 Creating simulated encoders...")
        self.stats['simulation_mode'] = True

        for i in range(self.num_encoders):
            domain_info = self.swarm_config.domain_assignments[i] if self.swarm_config else {
                "domain": f"general_{i}", "specialty": "general"
            }
            encoder = {
                "id": i,
                "model": None,
                "domain": domain_info["domain"],
                "specialty": domain_info["specialty"],
                "simulated": True
            }
            self.encoders.append(encoder)

    def _initialize_routing(self):
        """Initialize routing system"""
        try:
            if self.use_pretrained and self.swarm_config:
                # Use content-based router for pretrained models
                router_config = self.swarm_config.get_router_config()
                self.router = ContentBasedRouter(
                    num_encoders=self.num_encoders,
                    domain_assignments=self.swarm_config.domain_assignments,
                    config=router_config
                )
            else:
                # Use topic router for custom models
                domain_configs = DomainConfigs.get_domain_configs(self.num_encoders)
                self.router = TopicRouter(self.config, domain_configs)

            if hasattr(self.router, 'to'):
                self.router.to(self.device)
            print("🧭 Router initialized")
        except Exception as e:
            print(f"⚠️ Router initialization failed: {e}")
            # Create basic fallback router
            self.router = self._create_fallback_router()

    def _initialize_aggregation(self):
        """Initialize aggregation system"""
        try:
            if self.use_pretrained:
                self.aggregator = WeightedAggregator(
                    num_encoders=self.num_encoders,
                    hidden_dim=768
                )
            else:
                self.aggregator = AttentionAggregator(self.config)

            if hasattr(self.aggregator, 'to'):
                self.aggregator.to(self.device)
            print("🔄 Aggregator initialized")
        except Exception as e:
            print(f"⚠️ Aggregator initialization failed: {e}")
            self.aggregator = None

    def _create_fallback_router(self):
        """Create a simple fallback router"""
        class FallbackRouter:
            def __init__(self, num_encoders):
                self.num_encoders = num_encoders

            def route(self, text):
                # Randomly sample a small subset of encoders
                import random
                num_selected = min(3, self.num_encoders)
                return {
                    "selected_encoders": random.sample(range(self.num_encoders), num_selected)
                }

            def chunk_and_route(self, text):
                return [{"specialists": [(0, 1.0)], "chunk": text}]

        return FallbackRouter(self.num_encoders)
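
    # The swarm relies only on a small duck-typed router contract, which the
    # fallback above also satisfies (result shapes inferred from how generate()
    # and process_request() consume them):
    #
    #   route(text)           -> {"selected_encoders": [int, ...]}
    #   chunk_and_route(text) -> [{"specialists": [(encoder_id, weight), ...],
    #                              "chunk": str}, ...]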

    def _fallback_to_simulation(self):
        """Complete fallback to simulation mode"""
        print("🎭 Entering full simulation mode")
        self.stats['simulation_mode'] = True
        self._create_simulated_encoders()
        if not self.router:
            self.router = self._create_fallback_router()

    # =============================================================================
    # MAIN PROCESSING METHODS
    # =============================================================================
    def generate(self, prompt: str, max_length: int = 100, temperature: float = 0.7,
                 show_routing: bool = True) -> Dict:
        """
        Generate a response using the swarm (swarmEngine2-style interface).

        Args:
            prompt: Input text prompt
            max_length: Maximum tokens to generate
            temperature: Sampling temperature
            show_routing: Whether to display routing information

        Returns:
            Dict with response and metadata
        """
        start_time = time.time()
        try:
            # Route to appropriate encoders
            if hasattr(self.router, 'route'):
                routing_decision = self.router.route(prompt)
                selected_encoders = routing_decision.get("selected_encoders", [0])
            else:
                # Fallback routing
                selected_encoders = [0]

            if show_routing:
                print(f"🔀 Routing: Selected {len(selected_encoders)} encoders")
                for enc_id in selected_encoders[:3]:
                    if enc_id < len(self.encoders):
                        domain = self.encoders[enc_id]["domain"]
                        print(f"  Encoder {enc_id}: {domain}")

            # Generate response
            if self.stats['simulation_mode'] or any(enc.get("simulated") for enc in self.encoders):
                response = self._simulate_generation(prompt, selected_encoders, max_length)
            else:
                response = self._real_generation(prompt, selected_encoders, max_length, temperature)

            # Update statistics
            processing_time = time.time() - start_time
            self._update_stats_simple(prompt, selected_encoders, processing_time)

            return {
                "response": response,
                "processing_time": processing_time,
                "routing_info": {
                    "selected_encoders": selected_encoders,
                    "num_active": len(selected_encoders),
                    "total_encoders": self.num_encoders,
                    "domains": [self.encoders[i]["domain"] for i in selected_encoders
                                if i < len(self.encoders)]
                },
                "success": True
            }
        except Exception as e:
            return {
                "response": f"Error generating response: {str(e)}",
                "processing_time": time.time() - start_time,
                "success": False,
                "error": str(e)
            }
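
    # Illustrative usage of generate() (a sketch, not executed here; assumes a
    # constructed swarm instance):
    #
    #   swarm = UnifiedMambaSwarm(tier="demo")
    #   out = swarm.generate("Explain state-space models", max_length=120)
    #   if out["success"]:
    #       print(out["response"])
    #       print(out["routing_info"]["domains"])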

    def process_request(self, text: str, max_new_tokens: int = 100) -> Dict:
        """
        Process a request using the traditional pipeline (swarm_engine-style interface).

        Args:
            text: Input text to process
            max_new_tokens: Maximum tokens to generate

        Returns:
            Dict with response and metadata
        """
        start_time = time.time()
        try:
            # Step 1: Preprocess input
            if self.preprocessor:
                clean_text = self.preprocessor.clean_text(text)
            else:
                clean_text = text

            # Step 2: Route to specialists
            if hasattr(self.router, 'chunk_and_route'):
                routing_results = self.router.chunk_and_route(clean_text)
            else:
                # Fallback for content-based router
                routing_decision = self.router.route(clean_text)
                routing_results = [{"specialists": [(enc_id, 1.0) for enc_id in routing_decision["selected_encoders"]],
                                    "chunk": clean_text}]

            # Step 3: Process chunks
            if self.tlm_manager and not self.stats['simulation_mode']:
                specialist_outputs = self.tlm_manager.encode_parallel(routing_results)
            else:
                # Simulate processing
                specialist_outputs = [{"response": f"Processed chunk: {res['chunk'][:50]}..."}
                                      for res in routing_results]

            # Step 4: Aggregate results
            if self.aggregator and not self.stats['simulation_mode']:
                response = self.aggregator.generate_response(specialist_outputs, max_new_tokens)
            else:
                # Simple aggregation fallback
                response = " ".join([out.get("response", "") for out in specialist_outputs])

            # Update stats
            processing_time = time.time() - start_time
            self._update_stats(text, routing_results, processing_time)

            return {
                'response': response,
                'processing_time': processing_time,
                'chunks_processed': len(routing_results),
                'specialists_used': self._get_specialists_used(routing_results),
                'success': True
            }
        except Exception as e:
            return {
                'response': f"Error processing request: {str(e)}",
                'processing_time': time.time() - start_time,
                'success': False,
                'error': str(e)
            }

    # =============================================================================
    # ASYNC AND BATCH PROCESSING
    # =============================================================================
    async def process_request_async(self, text: str, max_new_tokens: int = 100) -> Dict:
        """Async version of process_request"""
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            result = await loop.run_in_executor(
                executor, self.process_request, text, max_new_tokens
            )
        return result

    async def generate_async(self, prompt: str, max_length: int = 100,
                             temperature: float = 0.7) -> Dict:
        """Async version of generate"""
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            result = await loop.run_in_executor(
                executor, self.generate, prompt, max_length, temperature, False
            )
        return result

    def batch_process(self, texts: List[str], max_new_tokens: int = 100,
                      method: str = "process") -> List[Dict]:
        """
        Process multiple texts in batch.

        Args:
            texts: List of input texts
            max_new_tokens: Maximum tokens to generate
            method: "process" or "generate" for processing method
        """
        results = []
        for text in texts:
            if method == "generate":
                result = self.generate(text, max_length=max_new_tokens, show_routing=False)
            else:
                result = self.process_request(text, max_new_tokens)
            results.append(result)
        return results
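
    # Illustrative async usage (a sketch, not executed here; the async wrappers
    # just offload the blocking calls to a thread pool, so they can be awaited
    # or gathered concurrently):
    #
    #   async def main(swarm):
    #       return await asyncio.gather(
    #           swarm.generate_async("prompt one"),
    #           swarm.process_request_async("prompt two"),
    #       )
    #
    #   results = asyncio.run(main(UnifiedMambaSwarm(tier="demo")))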

    # =============================================================================
    # GENERATION METHODS
    # =============================================================================
    def _simulate_generation(self, prompt: str, selected_encoders: List[int], max_length: int) -> str:
        """Simulate generation for demo/fallback purposes"""
        import random

        # Determine response type based on selected encoder domains
        domains = [self.encoders[i]["domain"] for i in selected_encoders if i < len(self.encoders)]

        if any("code" in domain.lower() for domain in domains):
            return f"Here's a solution for '{prompt[:30]}...':\n\n```python\ndef solution():\n    # Implementation here\n    return result\n```"
        elif any("medical" in domain.lower() for domain in domains):
            return f"Regarding '{prompt[:30]}...': This medical topic requires careful consideration. Please consult healthcare professionals."
        elif any("science" in domain.lower() for domain in domains):
            return f"From a scientific perspective on '{prompt[:30]}...': Current research indicates several key factors..."
        else:
            return f"Thank you for asking about '{prompt[:30]}...'. Based on expertise from {len(selected_encoders)} specialized domains, here's a comprehensive response..."

    def _real_generation(self, prompt: str, selected_encoders: List[int],
                         max_length: int, temperature: float) -> str:
        """Real generation using loaded models"""
        if not selected_encoders or selected_encoders[0] >= len(self.encoders):
            return "No valid encoders available for generation."

        try:
            # Use primary encoder for generation
            primary_encoder = self.encoders[selected_encoders[0]]
            if primary_encoder.get("simulated") or not primary_encoder["model"]:
                return self._simulate_generation(prompt, selected_encoders, max_length)

            # Tokenize input
            if hasattr(self.tokenizer, 'encode'):
                inputs = self.tokenizer(prompt, return_tensors="pt")
            else:
                # Fallback tokenization
                return self._simulate_generation(prompt, selected_encoders, max_length)

            # Generate with model; cap the generated continuation so long prompts
            # don't exhaust the token budget
            with torch.no_grad():
                outputs = primary_encoder["model"].generate(
                    **inputs,
                    max_new_tokens=max_length,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id if hasattr(self.tokenizer, 'eos_token_id') else 0
                )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Remove original prompt from response
            response = response[len(prompt):].strip()
            return response if response else "Generated response was empty."
        except Exception as e:
            print(f"⚠️ Real generation failed: {e}")
            return self._simulate_generation(prompt, selected_encoders, max_length)
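
    # Note: _real_generation assumes an HF-style tokenizer (callable with
    # return_tensors="pt", plus decode/eos_token_id) and an HF-compatible
    # model.generate(); with the custom MambaTokenizer it deliberately falls
    # back to _simulate_generation instead of attempting real decoding.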

    # =============================================================================
    # UTILITY METHODS
    # =============================================================================
    def _get_specialists_used(self, routing_results: List[Dict]) -> List[int]:
        """Extract specialist IDs used in routing"""
        specialists_used = set()
        for chunk_info in routing_results:
            if 'specialists' in chunk_info:
                for specialist_id, _ in chunk_info['specialists']:
                    specialists_used.add(specialist_id)
        return list(specialists_used)

    def _update_stats(self, text: str, routing_results: List[Dict], processing_time: float):
        """Update detailed performance statistics"""
        self.stats['total_requests'] += 1
        self.stats['total_tokens_processed'] += len(text.split())

        # Update average response time
        prev_avg = self.stats['avg_response_time']
        n = self.stats['total_requests']
        self.stats['avg_response_time'] = (prev_avg * (n - 1) + processing_time) / n

        # Update specialist usage
        specialists_used = self._get_specialists_used(routing_results)
        for specialist_id in specialists_used:
            if specialist_id in self.stats['specialist_usage']:
                self.stats['specialist_usage'][specialist_id] += 1

    def _update_stats_simple(self, text: str, selected_encoders: List[int], processing_time: float):
        """Update simple statistics for the generate method"""
        self.stats['total_requests'] += 1
        self.stats['total_tokens_processed'] += len(text.split())

        # Update average response time
        prev_avg = self.stats['avg_response_time']
        n = self.stats['total_requests']
        self.stats['avg_response_time'] = (prev_avg * (n - 1) + processing_time) / n

        # Update encoder usage
        for enc_id in selected_encoders:
            if enc_id in self.stats['specialist_usage']:
                self.stats['specialist_usage'][enc_id] += 1
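
    # Both stats helpers keep a running latency average without storing every
    # sample, using the standard incremental mean:
    #
    #   avg_n = (avg_{n-1} * (n - 1) + t_n) / n
    #         = avg_{n-1} + (t_n - avg_{n-1}) / n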

    # =============================================================================
    # SCALING AND MANAGEMENT
    # =============================================================================
    def scale_up(self, new_tier: str):
        """Scale up to a higher tier"""
        if new_tier not in ["demo", "small", "medium", "large", "full"]:
            raise ValueError(f"Invalid tier: {new_tier}")

        print(f"🚀 Scaling from {self.tier} to {new_tier}")

        # Preserve current stats
        old_stats = self.stats.copy()

        # Reinitialize with new tier
        self.__init__(tier=new_tier, use_pretrained=self.use_pretrained)

        # Restore relevant stats
        self.stats['total_requests'] = old_stats['total_requests']
        self.stats['total_tokens_processed'] = old_stats['total_tokens_processed']
        self.stats['avg_response_time'] = old_stats['avg_response_time']

    def get_system_info(self) -> Dict:
        """Get comprehensive system information"""
        info = {
            "tier": self.tier,
            "num_encoders": self.num_encoders,
            "encoder_size": self.encoder_size,
            "use_pretrained": self.use_pretrained,
            "simulation_mode": self.stats['simulation_mode'],
            "device": self.device,
            "domains": list(set(enc["domain"] for enc in self.encoders)),
        }

        if self.swarm_config:
            info.update({
                "total_parameters": self.swarm_config.config["total_params"],
                "memory_estimate": self.swarm_config.config["memory_estimate"],
                "hardware_recommendation": self.swarm_config.config["hardware"]
            })
        return info

    def get_stats(self) -> Dict:
        """Get current performance statistics"""
        return self.stats.copy()

    def load_models(self, checkpoint_path: str):
        """Load trained models from checkpoint"""
        if not os.path.exists(checkpoint_path):
            print(f"❌ Checkpoint not found: {checkpoint_path}")
            return

        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

            # Load aggregator
            if self.aggregator and 'aggregator_state' in checkpoint:
                self.aggregator.load_state_dict(checkpoint['aggregator_state'])

            # Load specialists (if using custom models)
            if self.tlm_manager and 'specialist_states' in checkpoint:
                for specialist_id, state_dict in checkpoint['specialist_states'].items():
                    if specialist_id in self.tlm_manager.specialists:
                        self.tlm_manager.specialists[specialist_id].model.load_state_dict(state_dict)

            print(f"✅ Models loaded from {checkpoint_path}")
        except Exception as e:
            print(f"❌ Error loading models: {e}")

    def set_eval_mode(self):
        """Set all models to evaluation mode"""
        if self.tlm_manager:
            for specialist in self.tlm_manager.specialists.values():
                if hasattr(specialist, 'model'):
                    specialist.model.eval()
        if self.aggregator and hasattr(self.aggregator, 'eval'):
            self.aggregator.eval()
        if self.router and hasattr(self.router, 'eval'):
            self.router.eval()
        # Set pretrained encoders to eval mode
        for encoder in self.encoders:
            if encoder.get("model") and hasattr(encoder["model"], 'eval'):
                encoder["model"].eval()

    def set_train_mode(self):
        """Set all models to training mode"""
        if self.tlm_manager:
            for specialist in self.tlm_manager.specialists.values():
                if hasattr(specialist, 'model'):
                    specialist.model.train()
        if self.aggregator and hasattr(self.aggregator, 'train'):
            self.aggregator.train()
        if self.router and hasattr(self.router, 'train'):
            self.router.train()


# =============================================================================
# FACTORY FUNCTIONS
# =============================================================================
def create_mamba_swarm(tier: str = "auto", use_pretrained: bool = True,
                       config_override: Optional[Dict] = None) -> UnifiedMambaSwarm:
    """
    Factory function to create an appropriately configured swarm.

    Args:
        tier: Scaling tier or "auto" for auto-detection
        use_pretrained: Whether to use pretrained HuggingFace models
        config_override: Dictionary to override default config

    Returns:
        Configured UnifiedMambaSwarm instance
    """
    if tier == "auto":
        tier = auto_detect_tier()

    return UnifiedMambaSwarm(
        tier=tier,
        use_pretrained=use_pretrained,
        config_override=config_override
    )
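
# Illustrative usage of the factory (a sketch, not executed on import; tiers
# above "demo" may need a GPU and enough memory for multiple encoders):
#
#   swarm = create_mamba_swarm(tier="auto", use_pretrained=True)
#   print(swarm.get_system_info())
#   print(swarm.generate("Summarize Mamba state-space models")["response"])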


def create_production_swarm(tier: str = "medium") -> UnifiedMambaSwarm:
    """Create a production-ready swarm with optimal settings"""
    return UnifiedMambaSwarm(
        tier=tier,
        use_pretrained=True,
        config_override={
            "batch_size": 32,
            "max_sequence_length": 2048
        }
    )


def create_development_swarm() -> UnifiedMambaSwarm:
    """Create a development swarm with simulation fallback"""
    return UnifiedMambaSwarm(
        tier="demo",
        use_pretrained=True,
        config_override={
            "simulation_fallback": True
        }
    )


# =============================================================================
# MAIN EXECUTION
# =============================================================================
if __name__ == "__main__":
    print("🧪 Testing Unified Mamba Swarm...")

    # Create swarm instance
    swarm = create_mamba_swarm(tier="demo")

    # Display system info
    print("\n📊 System Information:")
    info = swarm.get_system_info()
    for key, value in info.items():
        print(f"  {key}: {value}")

    # Test both processing methods
    test_prompts = [
        "Write a Python function to calculate fibonacci numbers",
        "Explain the process of photosynthesis",
        "What are the symptoms of diabetes?"
    ]

    print("\n🧪 Testing generate method:")
    for prompt in test_prompts[:2]:
        result = swarm.generate(prompt, max_length=150)
        print(f"\nPrompt: {prompt}")
        print(f"Response: {result['response'][:100]}...")
        print(f"Processing time: {result['processing_time']:.3f}s")
        print(f"Routing: {result['routing_info']['domains']}")

    print("\n🧪 Testing process_request method:")
    result = swarm.process_request(test_prompts[2])
    print(f"Response: {result['response'][:100]}...")
    print(f"Success: {result['success']}")

    # Test batch processing
    print("\n🧪 Testing batch processing:")
    batch_results = swarm.batch_process(test_prompts, method="generate")
    print(f"Processed {len(batch_results)} requests in batch")

    # Display final stats
    print("\n📈 Final Statistics:")
    stats = swarm.get_stats()
    for key, value in stats.items():
        if key != 'specialist_usage':
            print(f"  {key}: {value}")

    print("\n✅ Testing complete!")