# =============================================================================
# routing/tlm_manager.py
# =============================================================================
import torch
import torch.nn as nn
from typing import List, Dict, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from core.model import MambaModel
from core.config import MambaConfig
from utils.domain_configs import DomainConfigs

class SpecialistTLM:
    """Individual Specialist Mamba TLM"""
    def __init__(self, specialist_id: int, config: MambaConfig, domain_info: Dict):
        self.specialist_id = specialist_id
        self.config = config
        self.domain_info = domain_info
        self.model = MambaModel(config)
        self.device = config.device
        
        # Move to device
        self.model.to(self.device)
        
    def encode(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Encode input and return hidden states"""
        self.model.eval()
        with torch.no_grad():
            # Get embeddings
            x = self.model.embedding(input_ids)
            
            # Pass through Mamba layers
            for layer in self.model.layers:
                x = layer(x)
            
            # Apply final norm
            x = self.model.norm_f(x)
            
            # Return pooled representation
            return x.mean(dim=1)  # [batch, d_model]
    
    def get_memory_usage(self) -> int:
        """Get model memory usage in bytes"""
        return sum(p.numel() * p.element_size() for p in self.model.parameters())

class TLMManager:
    """Manages 100 specialist Mamba TLMs"""
    
    def __init__(self, config: MambaConfig):
        self.config = config
        self.device = config.device
        
        # Create domain configurations
        self.domain_configs = DomainConfigs.get_domain_configs(config.num_specialists)
        
        # Shared components (created before the specialists so that
        # _apply_weight_sharing can reference them during initialization)
        self.shared_embedding = None
        if config.shared_embedding:
            self.shared_embedding = nn.Embedding(config.vocab_size, config.d_model)
            self.shared_embedding.to(self.device)
        
        # Initialize specialists
        self.specialists = {}
        self._initialize_specialists()
        
        # Thread pool for parallel processing
        self.executor = ThreadPoolExecutor(max_workers=min(32, config.num_specialists))
        
    def _initialize_specialists(self):
        """Initialize all specialist TLMs"""
        num_specialists = self.config.num_specialists
        print(f"Initializing {num_specialists} specialist TLMs...")
        
        for domain_config in self.domain_configs:
            specialist_id = domain_config["id"]
            
            # Create specialist-specific config
            specialist_config = DomainConfigs.create_specialist_config(
                self.config, specialist_id
            )
            
            # Create specialist TLM
            specialist = SpecialistTLM(
                specialist_id=specialist_id,
                config=specialist_config,
                domain_info=domain_config
            )
            
            self.specialists[specialist_id] = specialist
            
            if (specialist_id + 1) % 10 == 0:
                print(f"Initialized {specialist_id + 1}/{num_specialists} specialists")
        
        print("All specialists initialized!")
        
        # Apply weight sharing if enabled
        if self.config.hierarchical_sharing:
            self._apply_weight_sharing()
    
    def _apply_weight_sharing(self):
        """Apply hierarchical weight sharing between specialists"""
        print("Applying hierarchical weight sharing...")
        
        # Share embedding layers
        if self.shared_embedding is not None:
            for specialist in self.specialists.values():
                specialist.model.embedding.token_embedding = self.shared_embedding
        
        # Group specialists by domain similarity and share lower layers
        domain_groups = self._group_domains_by_similarity()
        
        for group in domain_groups:
            if len(group) > 1:
                # Use first specialist's weights as shared weights for the group
                reference_specialist = self.specialists[group[0]]
                shared_layers = reference_specialist.model.layers[:self.config.n_layers//2]
                
                for specialist_id in group[1:]:
                    specialist = self.specialists[specialist_id]
                    for i, layer in enumerate(shared_layers):
                        specialist.model.layers[i] = layer
    
    def _group_domains_by_similarity(self) -> List[List[int]]:
        """Group domains by similarity for weight sharing"""
        # Simple grouping based on domain categories
        groups = {
            'stem': [],
            'programming': [],
            'language': [],
            'business': [],
            'other': []
        }
        
        for domain_config in self.domain_configs:
            domain_name = domain_config["name"].lower()
            specialist_id = domain_config["id"]
            
            if any(x in domain_name for x in ['math', 'physics', 'chemistry', 'biology']):
                groups['stem'].append(specialist_id)
            elif any(x in domain_name for x in ['python', 'javascript', 'systems']):
                groups['programming'].append(specialist_id)
            elif any(x in domain_name for x in ['writing', 'translation']):
                groups['language'].append(specialist_id)
            elif any(x in domain_name for x in ['business', 'legal']):
                groups['business'].append(specialist_id)
            else:
                groups['other'].append(specialist_id)
        
        return [group for group in groups.values() if len(group) > 1]
    
    def encode_parallel(self, routing_results: List[Dict]) -> Dict[int, List[Dict]]:
        """
        Encode chunks in parallel using the appropriate specialists.

        Args:
            routing_results: List of routing results from the router.

        Returns:
            Dict mapping chunk_id to the list of specialist encoding results
            for that chunk.
        """
        futures = []
        
        for chunk_info in routing_results:
            chunk_text = chunk_info['text']
            specialists = chunk_info['specialists']
            chunk_id = chunk_info['chunk_id']
            
            # Create encoding task for each relevant specialist
            for specialist_id, confidence in specialists:
                if specialist_id in self.specialists:
                    future = self.executor.submit(
                        self._encode_chunk,
                        chunk_text,
                        specialist_id,
                        confidence,
                        chunk_id
                    )
                    futures.append(future)
        
        # Collect results, skipping chunks whose encoding failed (None)
        encoded_results = []
        for future in as_completed(futures):
            try:
                result = future.result()
                if result is not None:
                    encoded_results.append(result)
            except Exception as e:
                print(f"Error in specialist encoding: {e}")
        # Group results by chunk_id
        grouped_results = {}
        for result in encoded_results:
            chunk_id = result['chunk_id']
            if chunk_id not in grouped_results:
                grouped_results[chunk_id] = []
            grouped_results[chunk_id].append(result)
        
        return grouped_results
    
    def _encode_chunk(self, text: str, specialist_id: int, confidence: float,
                      chunk_id: int) -> Optional[Dict]:
        """Encode a single chunk with a specific specialist"""
        try:
            specialist = self.specialists[specialist_id]
            
            # Tokenize text (simplified - should use proper tokenizer)
            # This is a placeholder - integrate with actual tokenizer
            input_ids = torch.randint(0, 1000, (1, 100)).to(self.device)
            
            # Encode with specialist
            encoding = specialist.encode(input_ids)
            
            return {
                'chunk_id': chunk_id,
                'specialist_id': specialist_id,
                'confidence': confidence,
                'encoding': encoding,
                'domain': specialist.domain_info['name']
            }
            
        except Exception as e:
            print(f"Error encoding chunk {chunk_id} with specialist {specialist_id}: {e}")
            return None
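
    # -------------------------------------------------------------------------
    # Illustrative sketch (not part of the original design): how the random-ID
    # placeholder in _encode_chunk could be replaced by real tokenization. The
    # `tokenizer` argument is assumed to follow the Hugging Face AutoTokenizer
    # call convention; a `max_seq_len` field on the config is also an
    # assumption and falls back to 512 if absent.
    # -------------------------------------------------------------------------
    def _tokenize_example(self, text: str, tokenizer) -> torch.Tensor:
        """Hedged example of turning raw text into input_ids for a specialist."""
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=getattr(self.config, "max_seq_len", 512),
        )
        return encoded["input_ids"].to(self.device)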
    
    def get_active_specialists(self) -> List[int]:
        """Get list of currently active specialist IDs"""
        return list(self.specialists.keys())
    
    def get_specialist_info(self, specialist_id: int) -> Optional[Dict]:
        """Get information about a specific specialist"""
        if specialist_id in self.specialists:
            specialist = self.specialists[specialist_id]
            return {
                'id': specialist_id,
                'domain': specialist.domain_info,
                'params': specialist.model.get_num_params(),
                'memory': specialist.get_memory_usage()
            }
        return None
    
    def get_total_parameters(self) -> int:
        """Get total parameters across all specialists"""
        total = 0
        for specialist in self.specialists.values():
            total += specialist.model.get_num_params()
        return total
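

# =============================================================================
# Example usage (illustrative sketch, not part of the original module)
# =============================================================================
# A minimal sketch of how TLMManager might be driven, assuming MambaConfig can
# be constructed with defaults and exposes the fields referenced above
# (device, num_specialists, shared_embedding, hierarchical_sharing,
# vocab_size, d_model, n_layers). The hand-built routing result below mirrors
# the shape consumed by encode_parallel; in the real pipeline it would come
# from the router.
if __name__ == "__main__":
    config = MambaConfig()  # assumption: default fields are valid for a small run
    manager = TLMManager(config)

    # One chunk routed to two specialists (IDs assumed to exist) with
    # confidence scores attached.
    routing_results = [
        {
            "chunk_id": 0,
            "text": "Compute the derivative of x^2.",
            "specialists": [(0, 0.9), (1, 0.4)],
        }
    ]

    grouped = manager.encode_parallel(routing_results)
    for chunk_id, results in grouped.items():
        for r in results:
            print(chunk_id, r["specialist_id"], r["confidence"], tuple(r["encoding"].shape))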