Debito committed on
Commit 2ee6fe0 · verified · 1 Parent(s): 43c0029

Upload 3 files

Files changed (3)
  1. routing/aggregator.py +134 -0
  2. routing/router.py +157 -0
  3. routing/tlm_manager.py +244 -0
routing/aggregator.py ADDED
@@ -0,0 +1,134 @@
+ # =============================================================================
+ # routing/aggregator.py
+ # =============================================================================
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Dict, List, Tuple
+ from core.config import MambaConfig
+
+ class AttentionAggregator(nn.Module):
+     """Attention-based aggregator for combining specialist outputs"""
+
+     def __init__(self, config: MambaConfig):
+         super().__init__()
+         self.config = config
+         self.d_model = config.d_model
+         self.num_specialists = config.num_specialists
+
+         # Attention mechanism for combining specialist outputs
+         self.specialist_attention = nn.MultiheadAttention(
+             embed_dim=self.d_model,
+             num_heads=8,
+             dropout=0.1,
+             batch_first=True
+         )
+
+         # Project specialist confidence scores
+         self.confidence_proj = nn.Linear(1, self.d_model)
+
+         # Output layers
+         self.output_layers = nn.Sequential(
+             nn.Linear(self.d_model, self.d_model * 2),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(self.d_model * 2, self.d_model),
+             nn.LayerNorm(self.d_model)
+         )
+
+         # Final language modeling head
+         self.lm_head = nn.Linear(self.d_model, config.vocab_size, bias=False)
+
+     def forward(self, specialist_outputs: Dict[int, List[Dict]]) -> torch.Tensor:
+         """
+         Aggregate specialist outputs into final representation
+
+         Args:
+             specialist_outputs: Dict mapping chunk_id to list of specialist results
+
+         Returns:
+             aggregated_logits: [batch, seq_len, vocab_size]
+         """
+         batch_outputs = []
+
+         for chunk_id in sorted(specialist_outputs.keys()):
+             chunk_results = specialist_outputs[chunk_id]
+
+             if not chunk_results:
+                 continue
+
+             # Stack specialist encodings
+             encodings = []
+             confidences = []
+
+             for result in chunk_results:
+                 if result is not None:
+                     # Drop the batch dim from SpecialistTLM.encode ([1, d_model])
+                     # so stacking below yields [num_specialists, d_model]
+                     encodings.append(result['encoding'].squeeze(0))
+                     confidences.append(result['confidence'])
+
+             if not encodings:
+                 continue
+
+             # Stack tensors
+             specialist_encodings = torch.stack(encodings)  # [num_specialists, d_model]
+             confidence_scores = torch.tensor(confidences, device=encodings[0].device)
+
+             # Project confidence scores
+             confidence_embeddings = self.confidence_proj(
+                 confidence_scores.unsqueeze(-1)
+             )  # [num_specialists, d_model]
+
+             # Add confidence information to encodings
+             enhanced_encodings = specialist_encodings + confidence_embeddings
+
+             # Apply attention to combine specialist outputs
+             # Use self-attention to let specialists communicate
+             aggregated, _ = self.specialist_attention(
+                 enhanced_encodings.unsqueeze(0),  # [1, num_specialists, d_model]
+                 enhanced_encodings.unsqueeze(0),
+                 enhanced_encodings.unsqueeze(0)
+             )
+
+             # Pool the attended representations
+             chunk_representation = aggregated.mean(dim=1)  # [1, d_model]
+
+             # Apply output layers
+             chunk_output = self.output_layers(chunk_representation)
+             batch_outputs.append(chunk_output)
+
+         if not batch_outputs:
+             # Return dummy output if no valid results
+             return torch.zeros(1, 1, self.config.vocab_size,
+                                device=self.lm_head.weight.device)
+
+         # Concatenate chunk outputs
+         final_representation = torch.cat(batch_outputs, dim=0)  # [num_chunks, d_model]
+
+         # Generate logits
+         logits = self.lm_head(final_representation)  # [num_chunks, vocab_size]
+
+         return logits.unsqueeze(0)  # [1, num_chunks, vocab_size]
+
+     def generate_response(self, specialist_outputs: Dict[int, List[Dict]],
+                           max_tokens: int = 100) -> str:
+         """Generate text response from specialist outputs"""
+         # Get aggregated logits
+         logits = self.forward(specialist_outputs)
+
+         # Simple greedy decoding (can be improved with better generation)
+         # NOTE: the logits are not re-computed per step, so this placeholder
+         # keeps selecting the same argmax token until max_tokens or EOS.
+         generated_ids = []
+         current_logits = logits[0, -1, :]  # Use last chunk's logits
+
+         for _ in range(max_tokens):
+             # Get next token
+             next_token = torch.argmax(current_logits, dim=-1)
+             generated_ids.append(next_token.item())
+
+             # Break on EOS token (assuming token 0 is EOS)
+             if next_token.item() == 0:
+                 break
+
+         # Convert to text (placeholder - should use proper tokenizer)
+         # This is simplified - integrate with actual tokenizer for real text
+         response = f"Generated response with {len(generated_ids)} tokens"
+
+         return response
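
For orientation (not part of the commit), here is a minimal usage sketch of `AttentionAggregator.forward`, assuming the `specialist_outputs` structure produced by `TLMManager.encode_parallel` below (per-chunk lists of dicts with `encoding` and `confidence` keys) and assuming `MambaConfig` accepts the fields shown:

```python
# Hypothetical usage sketch; the MambaConfig field names/values are assumptions.
import torch
from core.config import MambaConfig
from routing.aggregator import AttentionAggregator

config = MambaConfig(d_model=512, num_specialists=100, vocab_size=32000)
aggregator = AttentionAggregator(config)

# Two chunks; each encoded by one or two specialists.
# Encodings are [1, d_model], matching SpecialistTLM.encode output.
specialist_outputs = {
    0: [
        {'encoding': torch.randn(1, config.d_model), 'confidence': 0.8},
        {'encoding': torch.randn(1, config.d_model), 'confidence': 0.4},
    ],
    1: [
        {'encoding': torch.randn(1, config.d_model), 'confidence': 0.6},
    ],
}

logits = aggregator(specialist_outputs)
print(logits.shape)  # torch.Size([1, 2, 32000]) -> [1, num_chunks, vocab_size]
```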
routing/router.py ADDED
@@ -0,0 +1,157 @@
+ # =============================================================================
+ # routing/router.py
+ # =============================================================================
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from typing import List, Dict, Tuple, Optional
+ from collections import defaultdict
+ import re
+ from utils.domain_configs import DomainConfigs
+
+ class TopicRouter(nn.Module):
+     def __init__(self, config, domain_configs: List[Dict]):
+         super().__init__()
+         self.config = config
+         self.domain_configs = domain_configs
+         self.num_specialists = len(domain_configs)
+
+         # Build keyword mappings
+         self.keyword_to_domains = defaultdict(list)
+         self.domain_keywords = {}
+
+         for domain in domain_configs:
+             domain_id = domain["id"]
+             keywords = domain["keywords"]
+             self.domain_keywords[domain_id] = keywords
+
+             for keyword in keywords:
+                 self.keyword_to_domains[keyword.lower()].append(domain_id)
+
+         # Neural router for complex routing decisions
+         self.neural_router = nn.Sequential(
+             nn.Linear(config.d_model, 512),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(512, 256),
+             nn.ReLU(),
+             nn.Linear(256, self.num_specialists)
+         )
+
+         # Text similarity threshold
+         self.similarity_threshold = 0.1
+
+     def keyword_based_routing(self, text: str) -> Dict[int, float]:
+         """Route based on keyword matching"""
+         text_lower = text.lower()
+         domain_scores = defaultdict(float)
+
+         # Count keyword matches for each domain
+         for domain_id, keywords in self.domain_keywords.items():
+             for keyword in keywords:
+                 if keyword in text_lower:
+                     # Weight by keyword frequency and length
+                     count = text_lower.count(keyword)
+                     weight = len(keyword) / 10.0  # Longer keywords get higher weight
+                     domain_scores[domain_id] += count * weight
+
+         # Normalize scores
+         total_score = sum(domain_scores.values())
+         if total_score > 0:
+             domain_scores = {k: v / total_score for k, v in domain_scores.items()}
+
+         return dict(domain_scores)
+
+     def neural_routing(self, embeddings: torch.Tensor) -> torch.Tensor:
+         """Neural network based routing"""
+         # Use mean pooling of embeddings
+         pooled = embeddings.mean(dim=1)  # [batch, d_model]
+         scores = self.neural_router(pooled)  # [batch, num_specialists]
+         return torch.softmax(scores, dim=-1)
+
+     def route_text(self, text: str, embeddings: Optional[torch.Tensor] = None,
+                    max_specialists: int = 10) -> List[Tuple[int, float]]:
+         """
+         Route text to appropriate specialists
+
+         Args:
+             text: Input text to route
+             embeddings: Text embeddings [1, seq_len, d_model]
+             max_specialists: Maximum number of specialists to activate
+
+         Returns:
+             List of (specialist_id, confidence) tuples
+         """
+         # Keyword-based routing
+         keyword_scores = self.keyword_based_routing(text)
+
+         # Neural routing (if embeddings provided)
+         neural_scores = {}
+         if embeddings is not None:
+             neural_weights = self.neural_routing(embeddings)
+             neural_scores = {i: float(neural_weights[0, i])
+                              for i in range(self.num_specialists)}
+
+         # Combine scores
+         final_scores = {}
+         for i in range(self.num_specialists):
+             keyword_score = keyword_scores.get(i, 0.0)
+             neural_score = neural_scores.get(i, 0.0)
+
+             # Weighted combination
+             final_scores[i] = 0.7 * keyword_score + 0.3 * neural_score
+
+         # Sort by score and take top specialists
+         sorted_specialists = sorted(final_scores.items(),
+                                     key=lambda x: x[1],
+                                     reverse=True)
+
+         # Filter by threshold and limit
+         active_specialists = []
+         for specialist_id, score in sorted_specialists:
+             if score > self.similarity_threshold and len(active_specialists) < max_specialists:
+                 active_specialists.append((specialist_id, score))
+
+         # Ensure at least one specialist is active
+         if not active_specialists and sorted_specialists:
+             active_specialists = [sorted_specialists[0]]
+
+         return active_specialists
+
+     def chunk_and_route(self, text: str, chunk_size: int = 512) -> List[Dict]:
+         """
+         Split text into chunks and route each chunk
+
+         Returns:
+             List of dicts with 'text', 'specialists', 'chunk_id'
+         """
+         # Simple sentence-based chunking
+         sentences = re.split(r'[.!?]+', text)
+         chunks = []
+         current_chunk = ""
+         chunk_id = 0
+
+         for sentence in sentences:
+             if len(current_chunk) + len(sentence) > chunk_size and current_chunk:
+                 # Route current chunk
+                 specialists = self.route_text(current_chunk)
+                 chunks.append({
+                     'text': current_chunk.strip(),
+                     'specialists': specialists,
+                     'chunk_id': chunk_id
+                 })
+                 current_chunk = sentence
+                 chunk_id += 1
+             else:
+                 current_chunk += sentence + ". "
+
+         # Handle last chunk
+         if current_chunk.strip():
+             specialists = self.route_text(current_chunk)
+             chunks.append({
+                 'text': current_chunk.strip(),
+                 'specialists': specialists,
+                 'chunk_id': chunk_id
+             })
+
+         return chunks
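
A minimal sketch of the routing path (not part of the commit), assuming hypothetical domain entries and a `MambaConfig` with a `d_model` field. With no embeddings supplied, `route_text` falls back to pure keyword scoring (the 0.3 neural term is zero), where each hit contributes `count * len(keyword) / 10` before normalization:

```python
# Hypothetical usage sketch; the domain entries and MambaConfig fields are assumptions.
from core.config import MambaConfig
from routing.router import TopicRouter

domain_configs = [
    {"id": 0, "name": "Mathematics", "keywords": ["theorem", "integral", "equation"]},
    {"id": 1, "name": "Python Programming", "keywords": ["python", "function", "class"]},
]
config = MambaConfig(d_model=512)
router = TopicRouter(config, domain_configs)

# e.g. "theorem" scores 1 * len("theorem") / 10 = 0.7 before normalization.
chunks = router.chunk_and_route(
    "State the theorem and sketch its proof. Then write a python function for it.",
    chunk_size=40,
)
for chunk in chunks:
    print(chunk['chunk_id'], chunk['text'][:30], chunk['specialists'])
```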
routing/tlm_manager.py ADDED
@@ -0,0 +1,244 @@
+ # =============================================================================
+ # routing/tlm_manager.py
+ # =============================================================================
+ import torch
+ import torch.nn as nn
+ from typing import List, Dict, Tuple, Optional
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ import asyncio
+ from core.model import MambaModel
+ from core.config import MambaConfig
+ from utils.domain_configs import DomainConfigs
+
+ class SpecialistTLM:
+     """Individual Specialist Mamba TLM"""
+     def __init__(self, specialist_id: int, config: MambaConfig, domain_info: Dict):
+         self.specialist_id = specialist_id
+         self.config = config
+         self.domain_info = domain_info
+         self.model = MambaModel(config)
+         self.device = config.device
+
+         # Move to device
+         self.model.to(self.device)
+
+     def encode(self, input_ids: torch.Tensor) -> torch.Tensor:
+         """Encode input and return hidden states"""
+         self.model.eval()
+         with torch.no_grad():
+             # Get embeddings
+             x = self.model.embedding(input_ids)
+
+             # Pass through Mamba layers
+             for layer in self.model.layers:
+                 x = layer(x)
+
+             # Apply final norm
+             x = self.model.norm_f(x)
+
+             # Return pooled representation
+             return x.mean(dim=1)  # [batch, d_model]
+
+     def get_memory_usage(self) -> int:
+         """Get model memory usage in bytes"""
+         return sum(p.numel() * p.element_size() for p in self.model.parameters())
+
+ class TLMManager:
47
+ """Manages 100 specialist Mamba TLMs"""
48
+
49
+ def __init__(self, config: MambaConfig):
50
+ self.config = config
51
+ self.device = config.device
52
+
53
+ # Create domain configurations
54
+ self.domain_configs = DomainConfigs.get_domain_configs(config.num_specialists)
55
+
56
+ # Initialize specialists
57
+ self.specialists = {}
58
+ self._initialize_specialists()
59
+
60
+ # Shared components
61
+ self.shared_embedding = None
62
+ if config.shared_embedding:
63
+ self.shared_embedding = nn.Embedding(config.vocab_size, config.d_model)
64
+ self.shared_embedding.to(self.device)
65
+
66
+ # Thread pool for parallel processing
67
+ self.executor = ThreadPoolExecutor(max_workers=min(32, config.num_specialists))
68
+
69
+ def _initialize_specialists(self):
70
+ """Initialize all specialist TLMs"""
71
+ print("Initializing 100 specialist TLMs...")
72
+
73
+ for domain_config in self.domain_configs:
74
+ specialist_id = domain_config["id"]
75
+
76
+ # Create specialist-specific config
77
+ specialist_config = DomainConfigs.create_specialist_config(
78
+ self.config, specialist_id
79
+ )
80
+
81
+ # Create specialist TLM
82
+ specialist = SpecialistTLM(
83
+ specialist_id=specialist_id,
84
+ config=specialist_config,
85
+ domain_info=domain_config
86
+ )
87
+
88
+ self.specialists[specialist_id] = specialist
89
+
90
+ if specialist_id % 10 == 0:
91
+ print(f"Initialized {specialist_id + 1}/100 specialists")
92
+
93
+ print("All specialists initialized!")
94
+
95
+ # Apply weight sharing if enabled
96
+ if self.config.hierarchical_sharing:
97
+ self._apply_weight_sharing()
98
+
99
+ def _apply_weight_sharing(self):
100
+ """Apply hierarchical weight sharing between specialists"""
101
+ print("Applying hierarchical weight sharing...")
102
+
103
+ # Share embedding layers
104
+ if self.shared_embedding is not None:
105
+ for specialist in self.specialists.values():
106
+ specialist.model.embedding.token_embedding = self.shared_embedding
107
+
108
+ # Group specialists by domain similarity and share lower layers
109
+ domain_groups = self._group_domains_by_similarity()
110
+
111
+ for group in domain_groups:
112
+ if len(group) > 1:
113
+ # Use first specialist's weights as shared weights for the group
114
+ reference_specialist = self.specialists[group[0]]
115
+ shared_layers = reference_specialist.model.layers[:self.config.n_layers//2]
116
+
117
+ for specialist_id in group[1:]:
118
+ specialist = self.specialists[specialist_id]
119
+ for i, layer in enumerate(shared_layers):
120
+ specialist.model.layers[i] = layer
121
+
122
+ def _group_domains_by_similarity(self) -> List[List[int]]:
123
+ """Group domains by similarity for weight sharing"""
124
+ # Simple grouping based on domain categories
125
+ groups = {
126
+ 'stem': [],
127
+ 'programming': [],
128
+ 'language': [],
129
+ 'business': [],
130
+ 'other': []
131
+ }
132
+
133
+ for domain_config in self.domain_configs:
134
+ domain_name = domain_config["name"].lower()
135
+ specialist_id = domain_config["id"]
136
+
137
+ if any(x in domain_name for x in ['math', 'physics', 'chemistry', 'biology']):
138
+ groups['stem'].append(specialist_id)
139
+ elif any(x in domain_name for x in ['python', 'javascript', 'systems']):
140
+ groups['programming'].append(specialist_id)
141
+ elif any(x in domain_name for x in ['writing', 'translation']):
142
+ groups['language'].append(specialist_id)
143
+ elif any(x in domain_name for x in ['business', 'legal']):
144
+ groups['business'].append(specialist_id)
145
+ else:
146
+ groups['other'].append(specialist_id)
147
+
148
+ return [group for group in groups.values() if len(group) > 1]
149
+
+     def encode_parallel(self, routing_results: List[Dict]) -> Dict[int, List[Dict]]:
+         """
+         Encode chunks in parallel using appropriate specialists
+
+         Args:
+             routing_results: List of routing results from router
+
+         Returns:
+             Dict mapping chunk_id to the list of specialist encoding results
+         """
+         futures = []
+
+         for chunk_info in routing_results:
+             chunk_text = chunk_info['text']
+             specialists = chunk_info['specialists']
+             chunk_id = chunk_info['chunk_id']
+
+             # Create encoding task for each relevant specialist
+             for specialist_id, confidence in specialists:
+                 if specialist_id in self.specialists:
+                     future = self.executor.submit(
+                         self._encode_chunk,
+                         chunk_text,
+                         specialist_id,
+                         confidence,
+                         chunk_id
+                     )
+                     futures.append(future)
+
+         # Collect results (skip failed encodings, which return None)
+         encoded_results = []
+         for future in as_completed(futures):
+             try:
+                 result = future.result()
+                 if result is not None:
+                     encoded_results.append(result)
+             except Exception as e:
+                 print(f"Error in specialist encoding: {e}")
+
+         # Group results by chunk_id
+         grouped_results = {}
+         for result in encoded_results:
+             chunk_id = result['chunk_id']
+             if chunk_id not in grouped_results:
+                 grouped_results[chunk_id] = []
+             grouped_results[chunk_id].append(result)
+
+         return grouped_results
+
+     def _encode_chunk(self, text: str, specialist_id: int, confidence: float,
+                       chunk_id: int) -> Optional[Dict]:
+         """Encode a single chunk with a specific specialist"""
+         try:
+             specialist = self.specialists[specialist_id]
+
+             # Tokenize text (simplified - should use proper tokenizer)
+             # This is a placeholder - integrate with actual tokenizer
+             input_ids = torch.randint(0, 1000, (1, 100)).to(self.device)
+
+             # Encode with specialist
+             encoding = specialist.encode(input_ids)
+
+             return {
+                 'chunk_id': chunk_id,
+                 'specialist_id': specialist_id,
+                 'confidence': confidence,
+                 'encoding': encoding,
+                 'domain': specialist.domain_info['name']
+             }
+
+         except Exception as e:
+             print(f"Error encoding chunk {chunk_id} with specialist {specialist_id}: {e}")
+             return None
+
+     def get_active_specialists(self) -> List[int]:
+         """Get list of currently active specialist IDs"""
+         return list(self.specialists.keys())
+
+     def get_specialist_info(self, specialist_id: int) -> Optional[Dict]:
+         """Get information about a specific specialist"""
+         if specialist_id in self.specialists:
+             specialist = self.specialists[specialist_id]
+             return {
+                 'id': specialist_id,
+                 'domain': specialist.domain_info,
+                 'params': specialist.model.get_num_params(),
+                 'memory': specialist.get_memory_usage()
+             }
+         return None
+
+     def get_total_parameters(self) -> int:
+         """Get total parameters across all specialists"""
+         total = 0
+         for specialist in self.specialists.values():
+             total += specialist.model.get_num_params()
+         return total
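
Taken together, the three modules form a route → encode → aggregate pipeline. Below is a minimal end-to-end sketch (not part of the commit), under the assumption that `MambaConfig` exposes the fields the modules reference (`d_model`, `vocab_size`, `num_specialists`, `n_layers`, `device`, `shared_embedding`, `hierarchical_sharing`) and that `DomainConfigs` supplies keyword-annotated domains:

```python
# Hypothetical end-to-end wiring; all MambaConfig field names/values are assumptions
# inferred from how the modules above use the config object.
from core.config import MambaConfig
from routing.router import TopicRouter
from routing.tlm_manager import TLMManager
from routing.aggregator import AttentionAggregator

config = MambaConfig(
    d_model=512, vocab_size=32000, num_specialists=100,
    n_layers=8, device="cpu",
    shared_embedding=True, hierarchical_sharing=True,
)

manager = TLMManager(config)                               # builds the specialist pool
router = TopicRouter(config, manager.domain_configs)       # keyword + neural routing
aggregator = AttentionAggregator(config).to(config.device)

text = "Explain the chain rule. Then implement it as a python function."
routing_results = router.chunk_and_route(text)                 # [{'text', 'specialists', 'chunk_id'}, ...]
specialist_outputs = manager.encode_parallel(routing_results)  # {chunk_id: [encoding dicts]}
response = aggregator.generate_response(specialist_outputs, max_tokens=50)
print(response)
```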