# =============================================================================
# routing/tlm_manager.py
# =============================================================================
import torch
import torch.nn as nn
from typing import List, Dict, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

from core.model import MambaModel
from core.config import MambaConfig
from utils.domain_configs import DomainConfigs


class SpecialistTLM:
    """Individual specialist Mamba TLM wrapping a single MambaModel."""

    def __init__(self, specialist_id: int, config: MambaConfig, domain_info: Dict):
        self.specialist_id = specialist_id
        self.config = config
        self.domain_info = domain_info
        self.model = MambaModel(config)
        self.device = config.device
        # Move to device
        self.model.to(self.device)

    def encode(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Encode input and return a pooled hidden-state representation."""
        self.model.eval()
        with torch.no_grad():
            # Get embeddings
            x = self.model.embedding(input_ids)
            # Pass through Mamba layers
            for layer in self.model.layers:
                x = layer(x)
            # Apply final norm
            x = self.model.norm_f(x)
            # Mean-pool over the sequence dimension
            return x.mean(dim=1)  # [batch, d_model]

    def get_memory_usage(self) -> int:
        """Return parameter memory usage of the model in bytes."""
        return sum(p.numel() * p.element_size() for p in self.model.parameters())
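

# Example (sketch): a SpecialistTLM is normally constructed and owned by the
# TLMManager below, but it can also be used on its own, e.g.
#
#     specialist = SpecialistTLM(0, config, {"id": 0, "name": "mathematics"})
#     pooled = specialist.encode(input_ids)  # -> [batch, d_model] pooled states
#
# The domain_info dict is assumed to carry at least "id" and "name" keys,
# matching the entries returned by DomainConfigs.get_domain_configs(); the
# "mathematics" name here is purely illustrative.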


class TLMManager:
    """Manages the pool of specialist Mamba TLMs (100 in the default configuration)."""

    def __init__(self, config: MambaConfig):
        self.config = config
        self.device = config.device
        # Create domain configurations
        self.domain_configs = DomainConfigs.get_domain_configs(config.num_specialists)
        # Shared components are created first because _initialize_specialists
        # may apply weight sharing, which reads self.shared_embedding
        self.shared_embedding = None
        if config.shared_embedding:
            self.shared_embedding = nn.Embedding(config.vocab_size, config.d_model)
            self.shared_embedding.to(self.device)
        # Initialize specialists
        self.specialists = {}
        self._initialize_specialists()
        # Thread pool for parallel processing
        self.executor = ThreadPoolExecutor(max_workers=min(32, config.num_specialists))

    def _initialize_specialists(self):
        """Initialize all specialist TLMs."""
        print(f"Initializing {self.config.num_specialists} specialist TLMs...")
        for domain_config in self.domain_configs:
            specialist_id = domain_config["id"]
            # Create specialist-specific config
            specialist_config = DomainConfigs.create_specialist_config(
                self.config, specialist_id
            )
            # Create specialist TLM
            specialist = SpecialistTLM(
                specialist_id=specialist_id,
                config=specialist_config,
                domain_info=domain_config
            )
            self.specialists[specialist_id] = specialist
            if (specialist_id + 1) % 10 == 0:
                print(f"Initialized {specialist_id + 1}/{self.config.num_specialists} specialists")
        print("All specialists initialized!")
        # Apply weight sharing if enabled
        if self.config.hierarchical_sharing:
            self._apply_weight_sharing()

    def _apply_weight_sharing(self):
        """Apply hierarchical weight sharing between specialists."""
        print("Applying hierarchical weight sharing...")
        # Share embedding layers
        if self.shared_embedding is not None:
            for specialist in self.specialists.values():
                specialist.model.embedding.token_embedding = self.shared_embedding
        # Group specialists by domain similarity and share lower layers
        domain_groups = self._group_domains_by_similarity()
        for group in domain_groups:
            if len(group) > 1:
                # Use first specialist's weights as shared weights for the group
                reference_specialist = self.specialists[group[0]]
                shared_layers = reference_specialist.model.layers[:self.config.n_layers // 2]
                for specialist_id in group[1:]:
                    specialist = self.specialists[specialist_id]
                    for i, layer in enumerate(shared_layers):
                        specialist.model.layers[i] = layer
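
    # Note: after sharing, the same layer objects live inside several specialist
    # models, so per-specialist parameter counts (and get_total_parameters()
    # below) count the shared layers once per specialist rather than once
    # overall. get_num_params() is assumed to count every parameter reachable
    # from the model.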

    def _group_domains_by_similarity(self) -> List[List[int]]:
        """Group domains by similarity for weight sharing."""
        # Simple grouping based on domain categories
        groups = {
            'stem': [],
            'programming': [],
            'language': [],
            'business': [],
            'other': []
        }
        for domain_config in self.domain_configs:
            domain_name = domain_config["name"].lower()
            specialist_id = domain_config["id"]
            if any(x in domain_name for x in ['math', 'physics', 'chemistry', 'biology']):
                groups['stem'].append(specialist_id)
            elif any(x in domain_name for x in ['python', 'javascript', 'systems']):
                groups['programming'].append(specialist_id)
            elif any(x in domain_name for x in ['writing', 'translation']):
                groups['language'].append(specialist_id)
            elif any(x in domain_name for x in ['business', 'legal']):
                groups['business'].append(specialist_id)
            else:
                groups['other'].append(specialist_id)
        return [group for group in groups.values() if len(group) > 1]
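
    # Example (sketch): with domain names such as "mathematics", "physics",
    # "python" and "creative_writing" (illustrative only), this yields groups
    # like [[stem ids...], [programming ids...], ...]; categories with a single
    # member are dropped.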

    def encode_parallel(self, routing_results: List[Dict]) -> Dict[int, List[Dict]]:
        """
        Encode chunks in parallel using the appropriate specialists.

        Args:
            routing_results: List of routing results from the router.

        Returns:
            Dict mapping chunk_id to the list of specialist encoding results
            for that chunk.
        """
        futures = []
        for chunk_info in routing_results:
            chunk_text = chunk_info['text']
            specialists = chunk_info['specialists']
            chunk_id = chunk_info['chunk_id']
            # Create an encoding task for each relevant specialist
            for specialist_id, confidence in specialists:
                if specialist_id in self.specialists:
                    future = self.executor.submit(
                        self._encode_chunk,
                        chunk_text,
                        specialist_id,
                        confidence,
                        chunk_id
                    )
                    futures.append(future)
        # Collect results
        encoded_results = []
        for future in as_completed(futures):
            try:
                result = future.result()
                if result is not None:  # _encode_chunk returns None on failure
                    encoded_results.append(result)
            except Exception as e:
                print(f"Error in specialist encoding: {e}")
        # Group results by chunk_id
        grouped_results = {}
        for result in encoded_results:
            chunk_id = result['chunk_id']
            if chunk_id not in grouped_results:
                grouped_results[chunk_id] = []
            grouped_results[chunk_id].append(result)
        return grouped_results
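
    # Example (sketch) of the routing_results format consumed by encode_parallel:
    #     {"chunk_id": 0,
    #      "text": "Solve 3x + 1 = 7",
    #      "specialists": [(12, 0.83), (47, 0.11)]}  # (specialist_id, confidence) pairs
    # and of the returned mapping:
    #     {0: [{"chunk_id": 0, "specialist_id": 12, "confidence": 0.83,
    #           "encoding": <tensor [1, d_model]>, "domain": "..."}]}
    # The concrete specialist ids and confidences here are illustrative only.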

    def _encode_chunk(self, text: str, specialist_id: int, confidence: float,
                      chunk_id: int) -> Optional[Dict]:
        """Encode a single chunk with a specific specialist; returns None on failure."""
        try:
            specialist = self.specialists[specialist_id]
            # Tokenize text (simplified - should use a proper tokenizer)
            # This is a placeholder - integrate with the actual tokenizer
            input_ids = torch.randint(0, 1000, (1, 100)).to(self.device)
            # Encode with the specialist
            encoding = specialist.encode(input_ids)
            return {
                'chunk_id': chunk_id,
                'specialist_id': specialist_id,
                'confidence': confidence,
                'encoding': encoding,
                'domain': specialist.domain_info['name']
            }
        except Exception as e:
            print(f"Error encoding chunk {chunk_id} with specialist {specialist_id}: {e}")
            return None

    def get_active_specialists(self) -> List[int]:
        """Get the list of currently active specialist IDs."""
        return list(self.specialists.keys())

    def get_specialist_info(self, specialist_id: int) -> Optional[Dict]:
        """Get information about a specific specialist, or None if it does not exist."""
        if specialist_id in self.specialists:
            specialist = self.specialists[specialist_id]
            return {
                'id': specialist_id,
                'domain': specialist.domain_info,
                'params': specialist.model.get_num_params(),
                'memory': specialist.get_memory_usage()
            }
        return None

    def get_total_parameters(self) -> int:
        """Get total parameters across all specialists."""
        total = 0
        for specialist in self.specialists.values():
            total += specialist.model.get_num_params()
        return total