# =============================================================================
# system/inference.py
# =============================================================================
import torch
from typing import Dict, List
import time


class MambaInferenceEngine:
    """Optimized inference engine for Mamba swarm"""

    def __init__(self, swarm_engine):
        self.swarm_engine = swarm_engine
        self.config = swarm_engine.config

        # Inference optimizations
        self.use_half_precision = True
        self.use_torch_compile = hasattr(torch, 'compile')

        # Apply optimizations
        self._optimize_models()

    def _optimize_models(self):
        """Apply inference optimizations"""
        if self.use_half_precision and self.config.device != 'cpu':
            # Convert to half precision for faster inference
            for specialist in self.swarm_engine.tlm_manager.specialists.values():
                specialist.model = specialist.model.half()
            self.swarm_engine.aggregator = self.swarm_engine.aggregator.half()

        if self.use_torch_compile:
            try:
                # Compile models for faster inference (PyTorch 2.0+)
                for specialist in self.swarm_engine.tlm_manager.specialists.values():
                    specialist.model = torch.compile(specialist.model)
                self.swarm_engine.aggregator = torch.compile(self.swarm_engine.aggregator)
                print("Models compiled for faster inference")
            except Exception as e:
                print(f"Could not compile models: {e}")

    def generate(self, prompt: str, max_tokens: int = 100,
                 temperature: float = 0.7, top_k: int = 50) -> Dict:
        """
        Generate a text response with sampling metadata.

        Args:
            prompt: Input text prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_k: Top-k sampling parameter

        Returns:
            Dict with generated text and metadata
        """
        start_time = time.time()

        # Process through swarm
        # NOTE: temperature and top_k are currently recorded as metadata only;
        # process_request does not receive them.
        result = self.swarm_engine.process_request(prompt, max_tokens)

        if not result['success']:
            return result

        elapsed = time.time() - start_time

        # Add inference metadata (tokens/sec is approximate: it uses the
        # requested token budget rather than the exact count generated)
        result.update({
            'temperature': temperature,
            'top_k': top_k,
            'inference_time': elapsed,
            'tokens_per_second': max_tokens / elapsed if elapsed > 0 else 0.0
        })

        return result

    def stream_generate(self, prompt: str, max_tokens: int = 100):
        """
        Stream generation token by token (placeholder implementation)
        """
        # This would implement streaming generation.
        # For now, yield the full response in a single chunk.
        result = self.generate(prompt, max_tokens)
        yield result['response']

    def chat_completion(self, messages: List[Dict], max_tokens: int = 100) -> Dict:
        """
        Chat completion interface similar to the OpenAI API.

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_tokens: Maximum tokens to generate

        Returns:
            Chat completion response
        """
        # Convert messages to a single prompt
        prompt = self._format_chat_prompt(messages)

        # Generate response
        result = self.generate(prompt, max_tokens)

        if result['success']:
            # Format as a chat completion response
            return {
                'choices': [{
                    'message': {
                        'role': 'assistant',
                        'content': result['response']
                    },
                    'finish_reason': 'stop'
                }],
                'usage': {
                    # Whitespace-split word counts, used as a rough token estimate
                    'prompt_tokens': len(prompt.split()),
                    'completion_tokens': len(result['response'].split()),
                    'total_tokens': len(prompt.split()) + len(result['response'].split())
                },
                'model': 'mamba-swarm-70m',
                'inference_time': result.get('inference_time', 0)
            }
        else:
            return {
                'error': result.get('error', 'Unknown error'),
                'success': False
            }

    def _format_chat_prompt(self, messages: List[Dict]) -> str:
        """Format chat messages into a single prompt"""
        formatted = ""

        for message in messages:
            role = message.get('role', 'user')
            content = message.get('content', '')

            if role == 'system':
                formatted += f"System: {content}\n"
            elif role == 'user':
                formatted += f"User: {content}\n"
            elif role == 'assistant':
                formatted += f"Assistant: {content}\n"

        formatted += "Assistant: "
        return formatted