# =============================================================================
# system/weight_manager.py
# =============================================================================
import torch
import torch.nn as nn
from typing import Dict, List
from pathlib import Path

class WeightManager:
    """Manages hierarchical weight sharing and loading/saving"""
    
    def __init__(self, config, tlm_manager):
        self.config = config
        self.tlm_manager = tlm_manager
        
        # Track shared weights
        self.shared_embeddings = None
        self.shared_foundation_layers = {}
        
    def setup_hierarchical_sharing(self):
        """Setup hierarchical weight sharing between specialists"""
        print("Setting up hierarchical weight sharing...")
        
        # Create shared embedding if enabled
        if self.config.shared_embedding:
            self.shared_embeddings = nn.Embedding(
                self.config.vocab_size, 
                self.config.d_model
            ).to(self.config.device)
            
            # Share embedding across all specialists
            for specialist in self.tlm_manager.specialists.values():
                specialist.model.embedding.token_embedding = self.shared_embeddings
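                # The same nn.Embedding instance is assigned to every
                # specialist, replacing whatever embedding each model was
                # initialized with; the token embedding is therefore tied by
                # reference and trained jointly, not duplicated per specialist.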
        
        # Setup foundation layer sharing
        self._setup_foundation_sharing()
        
        print("Hierarchical weight sharing setup complete!")
    
    def _setup_foundation_sharing(self):
        """Setup sharing of foundation layers"""
        num_shared_layers = self.config.n_layers // 2
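        # The lower half of each specialist's transformer stack is treated as
        # a shareable "foundation"; the upper half stays specialist-specific.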
        
        # Group specialists by domain similarity
        domain_groups = self._group_specialists_by_domain()
        
        for group_name, specialist_ids in domain_groups.items():
            if len(specialist_ids) > 1:
                # Create shared foundation layers for this group
                reference_specialist = self.tlm_manager.specialists[specialist_ids[0]]
                shared_layers = reference_specialist.model.layers[:num_shared_layers]
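                # Slicing yields the layer modules themselves (assuming
                # model.layers is an nn.ModuleList), not copies; assigning
                # them below ties the foundation parameters across the group.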
                
                # Share with other specialists in the group
                for specialist_id in specialist_ids[1:]:
                    specialist = self.tlm_manager.specialists[specialist_id]
                    for i in range(num_shared_layers):
                        specialist.model.layers[i] = shared_layers[i]
                
                self.shared_foundation_layers[group_name] = shared_layers
    
    def _group_specialists_by_domain(self) -> Dict[str, List[int]]:
        """Group specialists by domain for weight sharing"""
        domain_groups = {
            'stem': [],
            'programming': [],
            'language': [],
            'business': [],
            'general': []
        }
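        # Grouping is a simple keyword heuristic on each domain's name; a
        # domain that matches no bucket falls into 'general'.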
        
        for specialist_id, specialist in self.tlm_manager.specialists.items():
            domain_name = specialist.domain_info['name'].lower()
            
            if any(x in domain_name for x in ['math', 'physics', 'chemistry', 'biology']):
                domain_groups['stem'].append(specialist_id)
            elif any(x in domain_name for x in ['python', 'javascript', 'systems']):
                domain_groups['programming'].append(specialist_id)
            elif any(x in domain_name for x in ['writing', 'translation']):
                domain_groups['language'].append(specialist_id)
            elif any(x in domain_name for x in ['business', 'legal']):
                domain_groups['business'].append(specialist_id)
            else:
                domain_groups['general'].append(specialist_id)
        
        return {k: v for k, v in domain_groups.items() if len(v) > 1}
    
    def save_weights(self, save_path: str):
        """Save all weights with hierarchical structure"""
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)
        
        # Save shared embeddings
        if self.shared_embeddings is not None:
            torch.save(
                self.shared_embeddings.state_dict(),
                save_path / "shared_embeddings.pt"
            )
        
        # Save shared foundation layers
        for group_name, layers in self.shared_foundation_layers.items():
            group_state = {}
            for i, layer in enumerate(layers):
                group_state[f"layer_{i}"] = layer.state_dict()
            torch.save(group_state, save_path / f"shared_foundation_{group_name}.pt")
        
        # Save specialist-specific weights
        specialists_path = save_path / "specialists"
        specialists_path.mkdir(exist_ok=True)
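        # Each specialist's full state_dict also contains the tied embedding
        # and foundation tensors, so some weights are written more than once;
        # load_weights restores them consistently either way.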
        
        for specialist_id, specialist in self.tlm_manager.specialists.items():
            torch.save(
                specialist.model.state_dict(),
                specialists_path / f"specialist_{specialist_id}.pt"
            )
        
        print(f"Weights saved to {save_path}")
    
    def load_weights(self, load_path: str):
        """Load weights with hierarchical structure"""
        load_path = Path(load_path)
        
        if not load_path.exists():
            raise FileNotFoundError(f"Weight path {load_path} not found")
        
        # Load shared embeddings
        embeddings_path = load_path / "shared_embeddings.pt"
        if embeddings_path.exists() and self.shared_embeddings is not None:
            self.shared_embeddings.load_state_dict(
                torch.load(embeddings_path, map_location=self.config.device)
            )
        
        # Load shared foundation layers
        for group_name in self.shared_foundation_layers.keys():
            foundation_path = load_path / f"shared_foundation_{group_name}.pt"
            if foundation_path.exists():
                group_state = torch.load(foundation_path, map_location=self.config.device)
                for i, layer in enumerate(self.shared_foundation_layers[group_name]):
                    if f"layer_{i}" in group_state:
                        layer.load_state_dict(group_state[f"layer_{i}"])
        
        # Load specialist weights
        specialists_path = load_path / "specialists"
        if specialists_path.exists():
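            # Because the embedding and foundation modules are tied by
            # reference, loading a specialist's state_dict also writes into
            # the shared modules; the per-specialist copies were saved from
            # the same tensors, so the result stays consistent.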
            for specialist_id, specialist in self.tlm_manager.specialists.items():
                specialist_path = specialists_path / f"specialist_{specialist_id}.pt"
                if specialist_path.exists():
                    specialist.model.load_state_dict(
                        torch.load(specialist_path, map_location=self.config.device)
                    )
        
        print(f"Weights loaded from {load_path}")
    
    def get_memory_usage(self) -> Dict[str, int]:
        """Get memory usage breakdown"""
        usage = {}
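        # Sizes are reported in bytes (numel * element_size per parameter
        # tensor); shared modules are counted once here, even though every
        # specialist also references them through its own model.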
        
        # Shared embedding memory
        if self.shared_embeddings is not None:
            usage['shared_embeddings'] = sum(
                p.numel() * p.element_size() 
                for p in self.shared_embeddings.parameters()
            )
        
        # Shared foundation layer memory
        total_foundation = 0
        for layers in self.shared_foundation_layers.values():
            for layer in layers:
                total_foundation += sum(
                    p.numel() * p.element_size()
                    for p in layer.parameters()
                )
        usage['shared_foundation_layers'] = total_foundation
        
        return usage