# =============================================================================
# system/inference.py
# =============================================================================
import time
from typing import Dict, List

import torch


class MambaInferenceEngine:
    """Optimized inference engine for the Mamba swarm"""

    def __init__(self, swarm_engine):
        self.swarm_engine = swarm_engine
        self.config = swarm_engine.config

        # Inference optimizations
        self.use_half_precision = True
        self.use_torch_compile = hasattr(torch, 'compile')  # PyTorch 2.0+

        # Apply optimizations
        self._optimize_models()

    def _optimize_models(self):
        """Apply inference optimizations"""
        if self.use_half_precision and self.config.device != 'cpu':
            # Convert to half precision for faster inference
            for specialist in self.swarm_engine.tlm_manager.specialists.values():
                specialist.model = specialist.model.half()
            self.swarm_engine.aggregator = self.swarm_engine.aggregator.half()

        if self.use_torch_compile:
            try:
                # Compile models for faster inference (PyTorch 2.0+)
                for specialist in self.swarm_engine.tlm_manager.specialists.values():
                    specialist.model = torch.compile(specialist.model)
                self.swarm_engine.aggregator = torch.compile(self.swarm_engine.aggregator)
                print("Models compiled for faster inference")
            except Exception as e:
                print(f"Could not compile models: {e}")

    def generate(self, prompt: str, max_tokens: int = 100,
                 temperature: float = 0.7, top_k: int = 50) -> Dict:
        """
        Generate a text response through the swarm.

        Args:
            prompt: Input text prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature (recorded as metadata only)
            top_k: Top-k sampling parameter (recorded as metadata only)

        Returns:
            Dict with generated text and metadata
        """
        start_time = time.time()

        # Process through the swarm
        result = self.swarm_engine.process_request(prompt, max_tokens)

        if not result['success']:
            return result

        # Add inference metadata; compute the elapsed time once so the reported
        # throughput is consistent with the reported latency
        elapsed = time.time() - start_time
        result.update({
            'temperature': temperature,
            'top_k': top_k,
            'inference_time': elapsed,
            # Assumes the swarm produced max_tokens tokens
            'tokens_per_second': max_tokens / elapsed if elapsed > 0 else 0.0
        })

        return result

    def stream_generate(self, prompt: str, max_tokens: int = 100):
        """
        Stream generation token by token (placeholder implementation)
        """
        # True streaming is not implemented yet; generate the full response
        # and yield it as a single chunk
        result = self.generate(prompt, max_tokens)
        yield result.get('response', result.get('error', ''))

    def chat_completion(self, messages: List[Dict], max_tokens: int = 100) -> Dict:
        """
        Chat completion interface similar to the OpenAI API.

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_tokens: Maximum tokens to generate

        Returns:
            Chat completion response
        """
        # Convert the message list into a single prompt
        prompt = self._format_chat_prompt(messages)

        # Generate response
        result = self.generate(prompt, max_tokens)

        if not result['success']:
            return {
                'error': result.get('error', 'Unknown error'),
                'success': False
            }

        # Format as a chat completion; token counts are approximated by
        # whitespace splitting since no tokenizer is applied here
        prompt_tokens = len(prompt.split())
        completion_tokens = len(result['response'].split())
        return {
            'choices': [{
                'message': {
                    'role': 'assistant',
                    'content': result['response']
                },
                'finish_reason': 'stop'
            }],
            'usage': {
                'prompt_tokens': prompt_tokens,
                'completion_tokens': completion_tokens,
                'total_tokens': prompt_tokens + completion_tokens
            },
            'model': 'mamba-swarm-70m',
            'inference_time': result.get('inference_time', 0)
        }

    def _format_chat_prompt(self, messages: List[Dict]) -> str:
        """Format chat messages into a single prompt"""
        formatted = ""

        for message in messages:
            role = message.get('role', 'user')
            content = message.get('content', '')

            if role == 'system':
                formatted += f"System: {content}\n"
            elif role == 'user':
                formatted += f"User: {content}\n"
            elif role == 'assistant':
                formatted += f"Assistant: {content}\n"

        # Cue the model to respond as the assistant
        formatted += "Assistant: "
        return formatted