# =============================================================================
# system/inference.py
# =============================================================================
import time
from typing import Dict, List

import torch

class MambaInferenceEngine:
    """Optimized inference engine for Mamba swarm"""

    def __init__(self, swarm_engine):
        self.swarm_engine = swarm_engine
        self.config = swarm_engine.config

        # Inference optimizations
        self.use_half_precision = True
        self.use_torch_compile = hasattr(torch, 'compile')

        # Apply optimizations
        self._optimize_models()

    def _optimize_models(self):
        """Apply inference optimizations"""
        if self.use_half_precision and self.config.device != 'cpu':
            # Convert to half precision for faster inference
            for specialist in self.swarm_engine.tlm_manager.specialists.values():
                specialist.model = specialist.model.half()
            self.swarm_engine.aggregator = self.swarm_engine.aggregator.half()

        if self.use_torch_compile:
            try:
                # Compile models for faster inference (PyTorch 2.0+)
                for specialist in self.swarm_engine.tlm_manager.specialists.values():
                    specialist.model = torch.compile(specialist.model)
                self.swarm_engine.aggregator = torch.compile(self.swarm_engine.aggregator)
                print("Models compiled for faster inference")
            except Exception as e:
                print(f"Could not compile models: {e}")

    def generate(self, prompt: str, max_tokens: int = 100,
                 temperature: float = 0.7, top_k: int = 50) -> Dict:
        """
        Generate a text response with advanced sampling

        Args:
            prompt: Input text prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_k: Top-k sampling parameter

        Returns:
            Dict with generated text and metadata
        """
        start_time = time.time()

        # Process through the swarm; sampling itself is delegated to the
        # swarm engine, so temperature/top_k are recorded as metadata only
        result = self.swarm_engine.process_request(prompt, max_tokens)

        if not result['success']:
            return result

        # Measure elapsed time once so the throughput figure matches the
        # reported inference_time; note the count is the requested
        # max_tokens, not the number of tokens actually generated
        inference_time = time.time() - start_time
        result.update({
            'temperature': temperature,
            'top_k': top_k,
            'inference_time': inference_time,
            'tokens_per_second': max_tokens / inference_time if inference_time > 0 else 0.0
        })

        return result
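
    # Example usage (hypothetical; assumes `swarm` is an already-constructed
    # engine exposing `config`, `tlm_manager`, and `process_request`):
    #
    #   engine = MambaInferenceEngine(swarm)
    #   out = engine.generate("Explain state space models.", max_tokens=64)
    #   if out['success']:
    #       print(out['response'], out['tokens_per_second'])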

    def stream_generate(self, prompt: str, max_tokens: int = 100):
        """
        Stream generation token by token (placeholder implementation)
        """
        # This would implement streaming generation
        # For now, yield the full response in one chunk; use .get() so a
        # failed generation does not raise a KeyError
        result = self.generate(prompt, max_tokens)
        yield result.get('response', '')
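
    # A minimal sketch of true token-by-token streaming, assuming the swarm
    # engine exposed an incremental step API returning the next token or
    # None (hypothetical; no such method exists in this module):
    #
    #   def stream_generate(self, prompt, max_tokens=100):
    #       for _ in range(max_tokens):
    #           token = self.swarm_engine.step(prompt)  # hypothetical API
    #           if token is None:
    #               break
    #           prompt += token
    #           yield token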

    def chat_completion(self, messages: List[Dict], max_tokens: int = 100) -> Dict:
        """
        Chat completion interface similar to the OpenAI API

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_tokens: Maximum tokens to generate

        Returns:
            Chat completion response
        """
        # Convert messages to a single prompt
        prompt = self._format_chat_prompt(messages)

        # Generate response
        result = self.generate(prompt, max_tokens)

        if result['success']:
            # Format as a chat completion; whitespace-split token counts
            # are an approximation, not tokenizer-accurate
            return {
                'choices': [{
                    'message': {
                        'role': 'assistant',
                        'content': result['response']
                    },
                    'finish_reason': 'stop'
                }],
                'usage': {
                    'prompt_tokens': len(prompt.split()),
                    'completion_tokens': len(result['response'].split()),
                    'total_tokens': len(prompt.split()) + len(result['response'].split())
                },
                'model': 'mamba-swarm-70m',
                'inference_time': result.get('inference_time', 0)
            }
        else:
            return {
                'error': result.get('error', 'Unknown error'),
                'success': False
            }
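
    # Example usage (hypothetical):
    #
    #   messages = [
    #       {'role': 'system', 'content': 'You are a helpful assistant.'},
    #       {'role': 'user', 'content': 'What is a state space model?'},
    #   ]
    #   reply = engine.chat_completion(messages, max_tokens=128)
    #   print(reply['choices'][0]['message']['content'])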

    def _format_chat_prompt(self, messages: List[Dict]) -> str:
        """Format chat messages into a single prompt"""
        formatted = ""

        for message in messages:
            role = message.get('role', 'user')
            content = message.get('content', '')

            if role == 'system':
                formatted += f"System: {content}\n"
            elif role == 'user':
                formatted += f"User: {content}\n"
            elif role == 'assistant':
                formatted += f"Assistant: {content}\n"

        formatted += "Assistant: "
        return formatted
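
# A minimal smoke test sketch. The swarm engine constructor and its import
# path are assumptions here, inferred only from the attributes this module
# touches (config, tlm_manager, aggregator, process_request):
#
#   from system.swarm_engine import MambaSwarmEngine  # hypothetical path
#
#   if __name__ == "__main__":
#       swarm = MambaSwarmEngine()
#       engine = MambaInferenceEngine(swarm)
#       out = engine.generate("Hello, Mamba swarm!", max_tokens=16)
#       print(out)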