Debito's picture
Upload 3 files
2ee6fe0 verified
# =============================================================================
# routing/tlm_manager.py
# =============================================================================
import torch
import torch.nn as nn
from typing import List, Dict, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
from core.model import MambaModel
from core.config import MambaConfig
from utils.domain_configs import DomainConfigs
class SpecialistTLM:
    """A single domain-specialist Mamba TLM.

    Wraps one ``MambaModel`` built from ``config`` and placed on
    ``config.device``, together with the domain metadata it was created for.
    """

    def __init__(self, specialist_id: int, config: MambaConfig, domain_info: Dict):
        """Build the specialist's model and move it to ``config.device``.

        Args:
            specialist_id: Identifier of this specialist.
            config: Configuration used to construct the ``MambaModel``.
            domain_info: Domain metadata for this specialist (callers read
                at least its ``'name'`` key).
        """
        self.specialist_id = specialist_id
        self.config = config
        self.domain_info = domain_info
        self.device = config.device
        # Construct the network, then place its parameters on the target device.
        self.model = MambaModel(config)
        self.model.to(self.device)

    def encode(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Run the backbone without gradients and mean-pool over dimension 1.

        The input goes through the model's embedding, every layer in
        ``model.layers`` and the final ``norm_f``. The result is averaged over
        dim 1, presumably the sequence axis, giving ``[batch, d_model]``
        (TODO confirm against MambaModel).

        Side effect: the model is left in eval mode.
        """
        net = self.model
        net.eval()
        with torch.no_grad():
            hidden = net.embedding(input_ids)
            for block in net.layers:
                hidden = block(hidden)
            hidden = net.norm_f(hidden)
            return hidden.mean(dim=1)

    def get_memory_usage(self) -> int:
        """Return the number of bytes held by this model's parameters.

        Only ``parameters()`` are counted; buffers are not included.
        """
        total_bytes = 0
        for param in self.model.parameters():
            total_bytes += param.numel() * param.element_size()
        return total_bytes
class TLMManager:
    """Own the specialist Mamba TLMs and fan encoding work out to them.

    One ``SpecialistTLM`` is built per entry returned by
    ``DomainConfigs.get_domain_configs(config.num_specialists)``.

    - When ``config.shared_embedding`` is set, one embedding table is created
      for all specialists.
    - When ``config.hierarchical_sharing`` is set, that table and the lower
      half of the layers are shared within groups of similar domains.

    Encoding requests run on a thread pool. Call :meth:`close` when the
    manager is no longer needed.
    """

    def __init__(self, config: "MambaConfig"):
        """Build every specialist described by ``config``.

        Args:
            config: Global model configuration. Read directly or via helpers:
                ``device``, ``num_specialists``, ``shared_embedding``,
                ``vocab_size``, ``d_model``, ``hierarchical_sharing`` and
                ``n_layers``.
        """
        self.config = config
        self.device = config.device

        # Create domain configurations
        self.domain_configs = DomainConfigs.get_domain_configs(config.num_specialists)

        # Shared components must exist BEFORE the specialists are built:
        # _initialize_specialists() -> _apply_weight_sharing() reads
        # self.shared_embedding. Creating it afterwards raised AttributeError
        # whenever hierarchical_sharing was enabled.
        self.shared_embedding: Optional[nn.Embedding] = None
        if config.shared_embedding:
            self.shared_embedding = nn.Embedding(config.vocab_size, config.d_model)
            self.shared_embedding.to(self.device)

        # Initialize specialists
        self.specialists: Dict[int, "SpecialistTLM"] = {}
        self._initialize_specialists()

        # Thread pool for parallel encoding. Use at least one worker:
        # ThreadPoolExecutor rejects max_workers=0.
        self.executor = ThreadPoolExecutor(
            max_workers=max(1, min(32, config.num_specialists))
        )

    def _initialize_specialists(self):
        """Create one SpecialistTLM per domain config, then optionally share weights."""
        total = len(self.domain_configs)
        print(f"Initializing {total} specialist TLMs...")
        for done, domain_config in enumerate(self.domain_configs, start=1):
            specialist_id = domain_config["id"]
            # Create specialist-specific config
            specialist_config = DomainConfigs.create_specialist_config(
                self.config, specialist_id
            )
            self.specialists[specialist_id] = SpecialistTLM(
                specialist_id=specialist_id,
                config=specialist_config,
                domain_info=domain_config,
            )
            # Report progress every 10 specialists and at the end. Use the real
            # count, not the id, so the message stays accurate for any
            # numbering or total.
            if done % 10 == 0 or done == total:
                print(f"Initialized {done}/{total} specialists")
        print("All specialists initialized!")

        # Apply weight sharing if enabled
        if self.config.hierarchical_sharing:
            self._apply_weight_sharing()

    def _apply_weight_sharing(self):
        """Share the embedding table and lower layers between similar specialists.

        Within each similarity group, the first ``config.n_layers // 2``
        layers of the group's first specialist replace the corresponding
        layers of the other members. The same module objects are reused, so
        weights are truly shared.

        NOTE(review): assumes all specialists in a group have compatible
        layer shapes. Confirm that DomainConfigs.create_specialist_config
        does not change d_model/n_layers per specialist.
        """
        print("Applying hierarchical weight sharing...")

        # Share embedding layers
        if self.shared_embedding is not None:
            for specialist in self.specialists.values():
                specialist.model.embedding.token_embedding = self.shared_embedding

        n_shared = self.config.n_layers // 2
        for group in self._group_domains_by_similarity():
            if len(group) < 2:
                continue
            # The first specialist's lower layers become the group's shared layers.
            reference_layers = self.specialists[group[0]].model.layers[:n_shared]
            for specialist_id in group[1:]:
                target_layers = self.specialists[specialist_id].model.layers
                for i, layer in enumerate(reference_layers):
                    target_layers[i] = layer

    def _group_domains_by_similarity(self) -> List[List[int]]:
        """Bucket specialist ids by keyword matches on the domain name.

        Each domain goes to the first group (in table order) that has a
        keyword contained in its lower-cased name. Unmatched domains go to
        ``'other'``. Only groups with more than one member are returned, in
        the order stem, programming, language, business, other.
        """
        keyword_table = (
            ('stem', ('math', 'physics', 'chemistry', 'biology')),
            ('programming', ('python', 'javascript', 'systems')),
            ('language', ('writing', 'translation')),
            ('business', ('business', 'legal')),
        )
        groups: Dict[str, List[int]] = {name: [] for name, _ in keyword_table}
        groups['other'] = []

        for domain_config in self.domain_configs:
            domain_name = domain_config["name"].lower()
            target = next(
                (
                    name
                    for name, keywords in keyword_table
                    if any(keyword in domain_name for keyword in keywords)
                ),
                'other',
            )
            groups[target].append(domain_config["id"])

        return [group for group in groups.values() if len(group) > 1]

    def encode_parallel(self, routing_results: List[Dict]) -> Dict[int, List[Dict]]:
        """Encode routed chunks with their assigned specialists in parallel.

        Args:
            routing_results: Router output. Each item has ``'text'``,
                ``'chunk_id'`` and ``'specialists'``, an iterable of
                ``(specialist_id, confidence)`` pairs. Unknown specialist ids
                are skipped.

        Returns:
            Dict mapping ``chunk_id`` to the list of successful encoding
            results for that chunk. Results are in submission order, so the
            output is deterministic. Failed encodings are reported and
            omitted; a chunk with no successful encodings has no entry.
        """
        futures = []
        for chunk_info in routing_results:
            chunk_text = chunk_info['text']
            chunk_id = chunk_info['chunk_id']
            # Create an encoding task for each relevant specialist
            for specialist_id, confidence in chunk_info['specialists']:
                if specialist_id not in self.specialists:
                    continue
                futures.append(
                    self.executor.submit(
                        self._encode_chunk,
                        chunk_text,
                        specialist_id,
                        confidence,
                        chunk_id,
                    )
                )

        grouped_results: Dict[int, List[Dict]] = {}
        # Collect in submission order (not completion order) for stable output.
        for future in futures:
            try:
                result = future.result()
            except Exception as e:
                print(f"Error in specialist encoding: {e}")
                continue
            # _encode_chunk returns None after reporting its own failure.
            # Indexing that None used to crash the whole batch.
            if result is None:
                continue
            grouped_results.setdefault(result['chunk_id'], []).append(result)
        return grouped_results

    def _encode_chunk(self, text: str, specialist_id: int, confidence: float,
                      chunk_id: int) -> Optional[Dict]:
        """Encode one chunk with one specialist.

        Returns a result dict with ``chunk_id``, ``specialist_id``,
        ``confidence``, ``encoding`` and ``domain``, or ``None`` if anything
        fails (the error is printed).
        """
        try:
            specialist = self.specialists[specialist_id]
            # TODO: placeholder tokenization. ``text`` is ignored and random
            # ids in [0, 1000) are used; integrate the real tokenizer here.
            # Inputs go to the device the specialist's model was moved to.
            input_ids = torch.randint(0, 1000, (1, 100)).to(specialist.device)
            encoding = specialist.encode(input_ids)
            return {
                'chunk_id': chunk_id,
                'specialist_id': specialist_id,
                'confidence': confidence,
                'encoding': encoding,
                'domain': specialist.domain_info['name'],
            }
        except Exception as e:
            print(f"Error encoding chunk {chunk_id} with specialist {specialist_id}: {e}")
            return None

    def get_active_specialists(self) -> List[int]:
        """Return the ids of all managed specialists, in insertion order."""
        return list(self.specialists.keys())

    def get_specialist_info(self, specialist_id: int) -> Optional[Dict]:
        """Return id, domain info, parameter count and memory for one specialist.

        Returns ``None`` if ``specialist_id`` is unknown.
        """
        specialist = self.specialists.get(specialist_id)
        if specialist is None:
            return None
        return {
            'id': specialist_id,
            'domain': specialist.domain_info,
            'params': specialist.model.get_num_params(),
            'memory': specialist.get_memory_usage(),
        }

    def get_total_parameters(self) -> int:
        """Sum ``model.get_num_params()`` over all specialists.

        NOTE(review): layers and embeddings shared by _apply_weight_sharing
        are presumably counted once per specialist here, so this may exceed
        the number of unique parameters. Confirm whether that is intended.
        """
        return sum(
            specialist.model.get_num_params()
            for specialist in self.specialists.values()
        )

    def close(self, wait: bool = True) -> None:
        """Shut down the encoding thread pool.

        No further encoding is possible afterwards.
        """
        self.executor.shutdown(wait=wait)