# =============================================================================
# routing/tlm_manager.py
# =============================================================================
import torch
import torch.nn as nn
from typing import List, Dict, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

from core.model import MambaModel
from core.config import MambaConfig
from utils.domain_configs import DomainConfigs


class SpecialistTLM:
    """Individual specialist Mamba TLM."""

    def __init__(self, specialist_id: int, config: MambaConfig, domain_info: Dict):
        self.specialist_id = specialist_id
        self.config = config
        self.domain_info = domain_info
        self.model = MambaModel(config)
        self.device = config.device

        # Move the model to its configured device
        self.model.to(self.device)

    def encode(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Encode input and return a pooled hidden-state representation."""
        self.model.eval()
        with torch.no_grad():
            # Get token embeddings
            x = self.model.embedding(input_ids)

            # Pass through the Mamba layers
            for layer in self.model.layers:
                x = layer(x)

            # Apply the final norm
            x = self.model.norm_f(x)

            # Return the mean-pooled representation
            return x.mean(dim=1)  # [batch, d_model]

    def get_memory_usage(self) -> int:
        """Get model memory usage in bytes."""
        return sum(p.numel() * p.element_size() for p in self.model.parameters())


class TLMManager:
    """Manages the pool of specialist Mamba TLMs (config.num_specialists of them)."""

    def __init__(self, config: MambaConfig):
        self.config = config
        self.device = config.device

        # Create domain configurations
        self.domain_configs = DomainConfigs.get_domain_configs(config.num_specialists)

        # Shared components, created before the specialists so that
        # _apply_weight_sharing can reference them during initialization
        self.shared_embedding = None
        if config.shared_embedding:
            self.shared_embedding = nn.Embedding(config.vocab_size, config.d_model)
            self.shared_embedding.to(self.device)

        # Initialize specialists
        self.specialists: Dict[int, SpecialistTLM] = {}
        self._initialize_specialists()

        # Thread pool for parallel processing
        self.executor = ThreadPoolExecutor(max_workers=min(32, config.num_specialists))

    def _initialize_specialists(self):
        """Initialize all specialist TLMs."""
        total = len(self.domain_configs)
        print(f"Initializing {total} specialist TLMs...")

        for i, domain_config in enumerate(self.domain_configs, start=1):
            specialist_id = domain_config["id"]

            # Create specialist-specific config
            specialist_config = DomainConfigs.create_specialist_config(
                self.config, specialist_id
            )

            # Create specialist TLM
            specialist = SpecialistTLM(
                specialist_id=specialist_id,
                config=specialist_config,
                domain_info=domain_config,
            )
            self.specialists[specialist_id] = specialist

            if i % 10 == 0:
                print(f"Initialized {i}/{total} specialists")

        print("All specialists initialized!")

        # Apply weight sharing if enabled
        if self.config.hierarchical_sharing:
            self._apply_weight_sharing()

    def _apply_weight_sharing(self):
        """Apply hierarchical weight sharing between specialists."""
        print("Applying hierarchical weight sharing...")

        # Share the token embedding across all specialists
        if self.shared_embedding is not None:
            for specialist in self.specialists.values():
                specialist.model.embedding.token_embedding = self.shared_embedding

        # Group specialists by domain similarity and share their lower layers
        domain_groups = self._group_domains_by_similarity()
        for group in domain_groups:
            if len(group) > 1:
                # Use the first specialist's lower layers as the group's shared layers
                reference_specialist = self.specialists[group[0]]
                shared_layers = reference_specialist.model.layers[: self.config.n_layers // 2]

                for specialist_id in group[1:]:
                    specialist = self.specialists[specialist_id]
                    for i, layer in enumerate(shared_layers):
                        specialist.model.layers[i] = layer

    def _group_domains_by_similarity(self) -> List[List[int]]:
        """Group domains by similarity for weight sharing."""
        # Simple grouping based on domain categories
        groups = {
            'stem': [],
            'programming': [],
            'language': [],
            'business': [],
            'other': [],
        }

        for domain_config in self.domain_configs:
            domain_name = domain_config["name"].lower()
            specialist_id = domain_config["id"]

            if any(x in domain_name for x in ['math', 'physics', 'chemistry', 'biology']):
                groups['stem'].append(specialist_id)
            elif any(x in domain_name for x in ['python', 'javascript', 'systems']):
                groups['programming'].append(specialist_id)
            elif any(x in domain_name for x in ['writing', 'translation']):
                groups['language'].append(specialist_id)
            elif any(x in domain_name for x in ['business', 'legal']):
                groups['business'].append(specialist_id)
            else:
                groups['other'].append(specialist_id)

        return [group for group in groups.values() if len(group) > 1]

    def encode_parallel(self, routing_results: List[Dict]) -> Dict[int, List[Dict]]:
        """
        Encode chunks in parallel using the appropriate specialists.

        Args:
            routing_results: List of routing results from the router. Each entry
                carries the chunk 'text', its 'chunk_id', and 'specialists', a
                list of (specialist_id, confidence) pairs.

        Returns:
            Dict mapping each chunk_id to the list of specialist encodings
            produced for that chunk.
        """
        futures = []

        for chunk_info in routing_results:
            chunk_text = chunk_info['text']
            specialists = chunk_info['specialists']
            chunk_id = chunk_info['chunk_id']

            # Create an encoding task for each relevant specialist
            for specialist_id, confidence in specialists:
                if specialist_id in self.specialists:
                    future = self.executor.submit(
                        self._encode_chunk,
                        chunk_text,
                        specialist_id,
                        confidence,
                        chunk_id,
                    )
                    futures.append(future)

        # Collect results, skipping specialists that failed to encode
        encoded_results = []
        for future in as_completed(futures):
            try:
                result = future.result()
                if result is not None:
                    encoded_results.append(result)
            except Exception as e:
                print(f"Error in specialist encoding: {e}")

        # Group results by chunk_id
        grouped_results = {}
        for result in encoded_results:
            grouped_results.setdefault(result['chunk_id'], []).append(result)

        return grouped_results
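
    # Illustrative example (hypothetical values, structure inferred from
    # _encode_chunk below): a routing result such as
    #     {'chunk_id': 0, 'text': 'Solve x^2 = 4', 'specialists': [(3, 0.9), (7, 0.4)]}
    # yields grouped_results[0] == [result_from_3, result_from_7], where each
    # element is a dict with 'chunk_id', 'specialist_id', 'confidence',
    # 'encoding', and 'domain' keys.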

    def _encode_chunk(self, text: str, specialist_id: int, confidence: float,
                      chunk_id: int) -> Optional[Dict]:
        """Encode a single chunk with a specific specialist."""
        try:
            specialist = self.specialists[specialist_id]

            # Tokenize text (placeholder -- integrate with the actual tokenizer;
            # random token ids stand in for the real encoding of `text`)
            input_ids = torch.randint(0, 1000, (1, 100)).to(specialist.device)

            # Encode with the specialist
            encoding = specialist.encode(input_ids)

            return {
                'chunk_id': chunk_id,
                'specialist_id': specialist_id,
                'confidence': confidence,
                'encoding': encoding,
                'domain': specialist.domain_info['name'],
            }
        except Exception as e:
            print(f"Error encoding chunk {chunk_id} with specialist {specialist_id}: {e}")
            return None

    def get_active_specialists(self) -> List[int]:
        """Get the list of currently active specialist IDs."""
        return list(self.specialists.keys())

    def get_specialist_info(self, specialist_id: int) -> Optional[Dict]:
        """Get information about a specific specialist, or None if it is unknown."""
        if specialist_id in self.specialists:
            specialist = self.specialists[specialist_id]
            return {
                'id': specialist_id,
                'domain': specialist.domain_info,
                'params': specialist.model.get_num_params(),
                'memory': specialist.get_memory_usage(),
            }
        return None

    def get_total_parameters(self) -> int:
        """Get the total parameter count across all specialists."""
        total = 0
        for specialist in self.specialists.values():
            total += specialist.model.get_num_params()
        return total
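

# =============================================================================
# Usage sketch (hypothetical, for illustration): wiring the manager to router
# output. It assumes MambaConfig's default constructor provides the fields this
# module reads (device, vocab_size, d_model, num_specialists, shared_embedding,
# hierarchical_sharing, n_layers) and that the router emits dicts shaped like
# the routing_results documented in encode_parallel.
# =============================================================================
if __name__ == "__main__":
    config = MambaConfig()        # assumed default constructor; adjust to the real signature
    manager = TLMManager(config)  # builds num_specialists Mamba TLMs, so this can be heavy

    routing_results = [
        {
            'chunk_id': 0,
            'text': "Compute the eigenvalues of a 2x2 matrix.",
            'specialists': [(0, 0.82), (1, 0.35)],  # (specialist_id, confidence) pairs
        },
    ]

    grouped = manager.encode_parallel(routing_results)
    for chunk_id, results in grouped.items():
        for r in results:
            print(chunk_id, r['specialist_id'], r['domain'], tuple(r['encoding'].shape))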