# =============================================================================
# system/inference.py
# =============================================================================
import time
from typing import Dict, List

import torch


class MambaInferenceEngine:
    """Optimized inference engine for Mamba swarm"""

    def __init__(self, swarm_engine):
        self.swarm_engine = swarm_engine
        self.config = swarm_engine.config

        # Inference optimizations
        self.use_half_precision = True
        self.use_torch_compile = hasattr(torch, 'compile')

        # Apply optimizations
        self._optimize_models()

    def _optimize_models(self):
        """Apply inference optimizations"""
        if self.use_half_precision and self.config.device != 'cpu':
            # Convert to half precision for faster inference
            for specialist in self.swarm_engine.tlm_manager.specialists.values():
                specialist.model = specialist.model.half()
            self.swarm_engine.aggregator = self.swarm_engine.aggregator.half()

        if self.use_torch_compile:
            try:
                # Compile models for faster inference (PyTorch 2.0+)
                for specialist in self.swarm_engine.tlm_manager.specialists.values():
                    specialist.model = torch.compile(specialist.model)
                self.swarm_engine.aggregator = torch.compile(self.swarm_engine.aggregator)
                print("Models compiled for faster inference")
            except Exception as e:
                print(f"Could not compile models: {e}")

    def generate(self, prompt: str, max_tokens: int = 100,
                 temperature: float = 0.7, top_k: int = 50) -> Dict:
        """
        Generate a text response through the swarm.

        Args:
            prompt: Input text prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature (recorded in the returned metadata)
            top_k: Top-k sampling parameter (recorded in the returned metadata)

        Returns:
            Dict with generated text and metadata
        """
        start_time = time.time()

        # Process through swarm
        result = self.swarm_engine.process_request(prompt, max_tokens)

        if not result['success']:
            return result

        # Add inference metadata; compute the elapsed time once so the
        # throughput figure matches the reported inference time
        elapsed = time.time() - start_time
        result.update({
            'temperature': temperature,
            'top_k': top_k,
            'inference_time': elapsed,
            'tokens_per_second': max_tokens / elapsed if elapsed > 0 else 0.0
        })

        return result

    def stream_generate(self, prompt: str, max_tokens: int = 100):
        """
        Stream generation token by token (placeholder implementation)
        """
        # This would implement streaming generation.
        # For now, yield the full response as a single chunk.
        result = self.generate(prompt, max_tokens)
        yield result['response']

    def chat_completion(self, messages: List[Dict], max_tokens: int = 100) -> Dict:
        """
        Chat completion interface similar to the OpenAI API.

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_tokens: Maximum tokens to generate

        Returns:
            Chat completion response
        """
        # Convert messages to a single prompt
        prompt = self._format_chat_prompt(messages)

        # Generate response
        result = self.generate(prompt, max_tokens)

        if result['success']:
            # Format as a chat completion
            return {
                'choices': [{
                    'message': {
                        'role': 'assistant',
                        'content': result['response']
                    },
                    'finish_reason': 'stop'
                }],
                'usage': {
                    'prompt_tokens': len(prompt.split()),
                    'completion_tokens': len(result['response'].split()),
                    'total_tokens': len(prompt.split()) + len(result['response'].split())
                },
                'model': 'mamba-swarm-70m',
                'inference_time': result.get('inference_time', 0)
            }
        else:
            return {
                'error': result.get('error', 'Unknown error'),
                'success': False
            }

    def _format_chat_prompt(self, messages: List[Dict]) -> str:
        """Format chat messages into a single prompt"""
        formatted = ""

        for message in messages:
            role = message.get('role', 'user')
            content = message.get('content', '')

            if role == 'system':
                formatted += f"System: {content}\n"
            elif role == 'user':
                formatted += f"User: {content}\n"
            elif role == 'assistant':
                formatted += f"Assistant: {content}\n"

        formatted += "Assistant: "
        return formatted
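

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only). The stub swarm engine below is hypothetical:
# it provides just the attributes MambaInferenceEngine touches (config,
# tlm_manager.specialists, aggregator, process_request) so the inference
# engine can be exercised without the real swarm. The real swarm engine's
# construction and response format may differ.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    class _StubSpecialist:
        def __init__(self):
            self.model = torch.nn.Linear(8, 8)

    class _StubSwarmEngine:
        """Minimal stand-in exposing the attributes the inference engine expects."""

        def __init__(self):
            self.config = SimpleNamespace(device='cpu')
            self.tlm_manager = SimpleNamespace(
                specialists={'general': _StubSpecialist()})
            self.aggregator = torch.nn.Linear(8, 8)

        def process_request(self, prompt, max_tokens):
            # Echo a canned reply; the real engine would route to specialists.
            return {'success': True, 'response': f"[stub reply to: {prompt}]"}

    engine = MambaInferenceEngine(_StubSwarmEngine())
    print(engine.generate("Hello, swarm!", max_tokens=16))
    print(engine.chat_completion(
        [{'role': 'user', 'content': 'Hi there'}], max_tokens=16))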