# =============================================================================
# system/weight_manager.py
# =============================================================================
import torch
import torch.nn as nn
from typing import Dict, List, Optional
import os
from pathlib import Path
class WeightManager:
"""Manages hierarchical weight sharing and loading/saving"""
def __init__(self, config, tlm_manager):
self.config = config
self.tlm_manager = tlm_manager
# Track shared weights
self.shared_embeddings = None
self.shared_foundation_layers = {}
def setup_hierarchical_sharing(self):
"""Setup hierarchical weight sharing between specialists"""
print("Setting up hierarchical weight sharing...")
# Create shared embedding if enabled
if self.config.shared_embedding:
self.shared_embeddings = nn.Embedding(
self.config.vocab_size,
self.config.d_model
).to(self.config.device)
            # Alias the shared embedding module into every specialist so they
            # all read from and update the same embedding parameters
for specialist in self.tlm_manager.specialists.values():
specialist.model.embedding.token_embedding = self.shared_embeddings
# Setup foundation layer sharing
self._setup_foundation_sharing()
print("Hierarchical weight sharing setup complete!")
def _setup_foundation_sharing(self):
"""Setup sharing of foundation layers"""
num_shared_layers = self.config.n_layers // 2
# Group specialists by domain similarity
domain_groups = self._group_specialists_by_domain()
for group_name, specialist_ids in domain_groups.items():
if len(specialist_ids) > 1:
# Create shared foundation layers for this group
reference_specialist = self.tlm_manager.specialists[specialist_ids[0]]
shared_layers = reference_specialist.model.layers[:num_shared_layers]
                # Alias the reference specialist's foundation layers into the
                # other group members; the parameters are stored once and
                # gradients from every member flow into the same shared layers
for specialist_id in specialist_ids[1:]:
specialist = self.tlm_manager.specialists[specialist_id]
for i in range(num_shared_layers):
specialist.model.layers[i] = shared_layers[i]
self.shared_foundation_layers[group_name] = shared_layers
def _group_specialists_by_domain(self) -> Dict[str, List[int]]:
"""Group specialists by domain for weight sharing"""
domain_groups = {
'stem': [],
'programming': [],
'language': [],
'business': [],
'general': []
}
for specialist_id, specialist in self.tlm_manager.specialists.items():
domain_name = specialist.domain_info['name'].lower()
if any(x in domain_name for x in ['math', 'physics', 'chemistry', 'biology']):
domain_groups['stem'].append(specialist_id)
elif any(x in domain_name for x in ['python', 'javascript', 'systems']):
domain_groups['programming'].append(specialist_id)
elif any(x in domain_name for x in ['writing', 'translation']):
domain_groups['language'].append(specialist_id)
elif any(x in domain_name for x in ['business', 'legal']):
domain_groups['business'].append(specialist_id)
else:
domain_groups['general'].append(specialist_id)
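        # Keep only groups with more than one member; a singleton group has
        # nothing to share. Illustrative shape of the returned mapping
        # (hypothetical specialist IDs):
        #   {'stem': [0, 3, 7], 'programming': [1, 4]}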
return {k: v for k, v in domain_groups.items() if len(v) > 1}
def save_weights(self, save_path: str):
"""Save all weights with hierarchical structure"""
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
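        # Expected on-disk layout after saving (mirrors the code below):
        #   <save_path>/
        #       shared_embeddings.pt
        #       shared_foundation_<group>.pt        (one file per shared group)
        #       specialists/specialist_<id>.pt      (one file per specialist)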
# Save shared embeddings
if self.shared_embeddings is not None:
torch.save(
self.shared_embeddings.state_dict(),
save_path / "shared_embeddings.pt"
)
# Save shared foundation layers
for group_name, layers in self.shared_foundation_layers.items():
group_state = {}
for i, layer in enumerate(layers):
group_state[f"layer_{i}"] = layer.state_dict()
torch.save(group_state, save_path / f"shared_foundation_{group_name}.pt")
        # Save per-specialist weights (note: each state_dict also contains the
        # aliased shared modules, so shared weights are duplicated in these files)
specialists_path = save_path / "specialists"
specialists_path.mkdir(exist_ok=True)
for specialist_id, specialist in self.tlm_manager.specialists.items():
torch.save(
specialist.model.state_dict(),
specialists_path / f"specialist_{specialist_id}.pt"
)
print(f"Weights saved to {save_path}")
def load_weights(self, load_path: str):
"""Load weights with hierarchical structure"""
load_path = Path(load_path)
if not load_path.exists():
raise FileNotFoundError(f"Weight path {load_path} not found")
# Load shared embeddings
embeddings_path = load_path / "shared_embeddings.pt"
if embeddings_path.exists() and self.shared_embeddings is not None:
            self.shared_embeddings.load_state_dict(
                torch.load(embeddings_path, map_location=self.config.device)
            )
# Load shared foundation layers
for group_name in self.shared_foundation_layers.keys():
foundation_path = load_path / f"shared_foundation_{group_name}.pt"
if foundation_path.exists():
                group_state = torch.load(foundation_path, map_location=self.config.device)
for i, layer in enumerate(self.shared_foundation_layers[group_name]):
if f"layer_{i}" in group_state:
layer.load_state_dict(group_state[f"layer_{i}"])
# Load specialist weights
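        # Note: the shared embedding and foundation layers are aliased into
        # each specialist model, so these checkpoints also (re)load the shared
        # weights; the copies written by save_weights are identical, so the
        # final state stays consistent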
specialists_path = load_path / "specialists"
if specialists_path.exists():
for specialist_id, specialist in self.tlm_manager.specialists.items():
specialist_path = specialists_path / f"specialist_{specialist_id}.pt"
if specialist_path.exists():
                    specialist.model.load_state_dict(
                        torch.load(specialist_path, map_location=self.config.device)
                    )
print(f"Weights loaded from {load_path}")
def get_memory_usage(self) -> Dict[str, int]:
"""Get memory usage breakdown"""
usage = {}
# Shared embedding memory
if self.shared_embeddings is not None:
usage['shared_embeddings'] = sum(
p.numel() * p.element_size()
for p in self.shared_embeddings.parameters()
)
# Shared foundation layer memory
total_foundation = 0
for layers in self.shared_foundation_layers.values():
for layer in layers:
total_foundation += sum(
p.numel() * p.element_size()
for p in layer.parameters()
)