# =============================================================================
# system/weight_manager.py
# =============================================================================
import torch
import torch.nn as nn
from typing import Dict, List
from pathlib import Path


class WeightManager:
    """Manages hierarchical weight sharing and loading/saving."""

    def __init__(self, config, tlm_manager):
        self.config = config
        self.tlm_manager = tlm_manager

        # Track shared weights
        self.shared_embeddings = None
        self.shared_foundation_layers = {}

    def setup_hierarchical_sharing(self):
        """Set up hierarchical weight sharing between specialists."""
        print("Setting up hierarchical weight sharing...")

        # Create a shared embedding if enabled
        if self.config.shared_embedding:
            self.shared_embeddings = nn.Embedding(
                self.config.vocab_size, self.config.d_model
            ).to(self.config.device)

            # Point every specialist at the same embedding module
            for specialist in self.tlm_manager.specialists.values():
                specialist.model.embedding.token_embedding = self.shared_embeddings

        # Set up foundation-layer sharing
        self._setup_foundation_sharing()

        print("Hierarchical weight sharing setup complete!")

    def _setup_foundation_sharing(self):
        """Share the lower (foundation) layers within each domain group."""
        num_shared_layers = self.config.n_layers // 2

        # Group specialists by domain similarity
        domain_groups = self._group_specialists_by_domain()

        for group_name, specialist_ids in domain_groups.items():
            if len(specialist_ids) > 1:
                # The first specialist in the group donates its foundation layers
                reference_specialist = self.tlm_manager.specialists[specialist_ids[0]]
                shared_layers = reference_specialist.model.layers[:num_shared_layers]

                # Point the other specialists in the group at the same modules
                for specialist_id in specialist_ids[1:]:
                    specialist = self.tlm_manager.specialists[specialist_id]
                    for i in range(num_shared_layers):
                        specialist.model.layers[i] = shared_layers[i]

                self.shared_foundation_layers[group_name] = shared_layers

    def _group_specialists_by_domain(self) -> Dict[str, List[int]]:
        """Group specialists by domain for weight sharing."""
        domain_groups = {
            'stem': [],
            'programming': [],
            'language': [],
            'business': [],
            'general': []
        }

        for specialist_id, specialist in self.tlm_manager.specialists.items():
            domain_name = specialist.domain_info['name'].lower()

            if any(x in domain_name for x in ['math', 'physics', 'chemistry', 'biology']):
                domain_groups['stem'].append(specialist_id)
            elif any(x in domain_name for x in ['python', 'javascript', 'systems']):
                domain_groups['programming'].append(specialist_id)
            elif any(x in domain_name for x in ['writing', 'translation']):
                domain_groups['language'].append(specialist_id)
            elif any(x in domain_name for x in ['business', 'legal']):
                domain_groups['business'].append(specialist_id)
            else:
                domain_groups['general'].append(specialist_id)

        # Only groups with at least two members benefit from sharing
        return {k: v for k, v in domain_groups.items() if len(v) > 1}
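
    # Example (hypothetical specialist names): domains "Mathematics", "Physics",
    # "Python", and "Legal" would yield {'stem': [math_id, physics_id]}; the
    # singleton 'programming' and 'business' groups are filtered out above, so
    # those specialists keep fully independent foundation layers.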

    def save_weights(self, save_path: str):
        """Save all weights with hierarchical structure."""
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        # Save shared embeddings
        if self.shared_embeddings is not None:
            torch.save(
                self.shared_embeddings.state_dict(),
                save_path / "shared_embeddings.pt"
            )

        # Save shared foundation layers, one file per domain group
        for group_name, layers in self.shared_foundation_layers.items():
            group_state = {}
            for i, layer in enumerate(layers):
                group_state[f"layer_{i}"] = layer.state_dict()
            torch.save(group_state, save_path / f"shared_foundation_{group_name}.pt")

        # Save specialist-specific weights (note: each specialist's state_dict
        # also contains copies of the shared tensors, so those are stored
        # redundantly on disk)
        specialists_path = save_path / "specialists"
        specialists_path.mkdir(exist_ok=True)

        for specialist_id, specialist in self.tlm_manager.specialists.items():
            torch.save(
                specialist.model.state_dict(),
                specialists_path / f"specialist_{specialist_id}.pt"
            )

        print(f"Weights saved to {save_path}")

    def load_weights(self, load_path: str):
        """Load weights with hierarchical structure."""
        load_path = Path(load_path)

        if not load_path.exists():
            raise FileNotFoundError(f"Weight path {load_path} not found")

        # Load shared embeddings; map_location avoids device mismatches when a
        # checkpoint was saved on a different device
        embeddings_path = load_path / "shared_embeddings.pt"
        if embeddings_path.exists() and self.shared_embeddings is not None:
            self.shared_embeddings.load_state_dict(
                torch.load(embeddings_path, map_location=self.config.device)
            )

        # Load shared foundation layers
        for group_name in self.shared_foundation_layers.keys():
            foundation_path = load_path / f"shared_foundation_{group_name}.pt"
            if foundation_path.exists():
                group_state = torch.load(foundation_path, map_location=self.config.device)
                for i, layer in enumerate(self.shared_foundation_layers[group_name]):
                    if f"layer_{i}" in group_state:
                        layer.load_state_dict(group_state[f"layer_{i}"])

        # Load specialist weights
        specialists_path = load_path / "specialists"
        if specialists_path.exists():
            for specialist_id, specialist in self.tlm_manager.specialists.items():
                specialist_path = specialists_path / f"specialist_{specialist_id}.pt"
                if specialist_path.exists():
                    specialist.model.load_state_dict(
                        torch.load(specialist_path, map_location=self.config.device)
                    )

        print(f"Weights loaded from {load_path}")

    def get_memory_usage(self) -> Dict[str, int]:
        """Get a memory-usage breakdown in bytes."""
        usage = {}

        # Shared embedding memory
        if self.shared_embeddings is not None:
            usage['shared_embeddings'] = sum(
                p.numel() * p.element_size()
                for p in self.shared_embeddings.parameters()
            )

        # Shared foundation-layer memory
        total_foundation = 0
        for layers in self.shared_foundation_layers.values():
            for layer in layers:
                total_foundation += sum(
                    p.numel() * p.element_size() for p in layer.parameters()
                )
        usage['shared_foundation'] = total_foundation

        return usage
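

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The real
    # `config` and TLM manager live elsewhere in this repo; the stand-ins below
    # are hypothetical and expose only the attributes WeightManager actually
    # touches (vocab_size, d_model, n_layers, shared_embedding, device, and a
    # `specialists` dict of objects carrying `.model` and `.domain_info`).
    from types import SimpleNamespace

    config = SimpleNamespace(
        vocab_size=1000, d_model=64, n_layers=4,
        shared_embedding=True, device="cpu",
    )

    class ToyModel(nn.Module):
        """Tiny stand-in exposing the `embedding.token_embedding` and
        `layers` attributes that WeightManager rewires."""
        def __init__(self):
            super().__init__()
            self.embedding = SimpleNamespace(token_embedding=None)
            self.layers = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))

    specialists = {
        i: SimpleNamespace(model=ToyModel(), domain_info={'name': name})
        for i, name in enumerate(["Mathematics", "Physics", "Python", "Legal"])
    }
    tlm_manager = SimpleNamespace(specialists=specialists)

    wm = WeightManager(config, tlm_manager)
    wm.setup_hierarchical_sharing()

    # "Mathematics" and "Physics" land in the 'stem' group, so they now share
    # their two lowest layers (n_layers // 2 == 2) as the very same modules.
    assert specialists[0].model.layers[0] is specialists[1].model.layers[0]

    print(wm.get_memory_usage())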