"""Build, shard, save, and validate randomly initialised transformer weights.

The module defines :class:`ModelConfig` (model/sharding hyper-parameters),
:class:`TransformerShardBuilder` (creates per-shard weight dictionaries and
writes them as ``.safetensors`` files), :func:`estimate_model_size`, and a
script entry point :func:`main`.
"""

import hashlib
import json
import logging
import multiprocessing as mp
import os
import shutil
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, asdict, replace
from pathlib import Path
from typing import Dict, Optional, Tuple, List, Union, Any

import numpy as np
import torch
import yaml
from safetensors.torch import save_file, load_file
from torch.nn.init import xavier_uniform_, kaiming_uniform_
from tqdm import tqdm

# Configure logging: everything goes to stdout and is appended to a log file.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - [%(processName)s:%(threadName)s] - %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler("transformer_shard_builder.log", mode="a"),
    ],
)


@dataclass
class ModelConfig:
    """Configuration for transformer model parameters and sharding."""

    # NOTE(review): the defaults below (48 layers, 278 shards) do not satisfy
    # TransformerShardBuilder._validate_config, which requires
    # num_layers >= total_shards. Callers must override one of them.
    num_layers: int = 48
    hidden_size: int = 8192
    heads: int = 64
    seq_length: int = 4096
    vocab_size: int = 50000
    dtype: str = "float16"  # name of a torch dtype attribute, e.g. "float32"
    ffn_multiplier: int = 4
    total_shards: int = 278
    base_path: str = "model_shards"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: Optional[int] = 42
    init_method: str = "xavier"  # Options: "xavier", "kaiming", "normal"
    shard_compression: bool = True
    validation_threshold: float = 1e-5


class TransformerShardBuilder:
    """Build, shard, validate, and save the weights of a large transformer."""

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize with configuration and set up the environment.

        Args:
            config: Model/sharding configuration; ``ModelConfig()`` when omitted.

        Raises:
            ValueError: If ``config.dtype`` is not the name of a torch dtype or
                any other configuration check fails.
        """
        self.config = config or ModelConfig()
        dtype = getattr(torch, self.config.dtype, None)
        if not isinstance(dtype, torch.dtype):
            raise ValueError(f"Invalid torch dtype: {self.config.dtype!r}")
        self.dtype = dtype
        self.device = torch.device(self.config.device)
        self.base_path = Path(self.config.base_path)
        self.weights: Dict[int, Dict[str, torch.Tensor]] = {}  # Shard-indexed weights
        self.metadata: Dict[str, Any] = {}
        self._validate_config()
        self._setup_environment()
        self._calculate_sharding()

    def _validate_config(self) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: On the first failing check, with a descriptive message.
        """
        cfg = self.config
        checks = [
            (cfg.num_layers > 0, "Number of layers must be positive"),
            (cfg.hidden_size % cfg.heads == 0, "Hidden size must be divisible by heads"),
            (cfg.seq_length > 0, "Sequence length must be positive"),
            (cfg.vocab_size > 0, "Vocab size must be positive"),
            (cfg.total_shards > 0, "Total shards must be positive"),
            (cfg.ffn_multiplier > 1, "FFN multiplier must be greater than 1"),
            (cfg.init_method in ("xavier", "kaiming", "normal"), "Invalid initialization method"),
        ]
        for condition, message in checks:
            if not condition:
                raise ValueError(message)
        # With fewer layers than shards, layers_per_shard would be 0 and every
        # shard except the last would be empty.
        if cfg.num_layers < cfg.total_shards:
            raise ValueError("Number of layers must be >= total shards")

    def _setup_environment(self) -> None:
        """Seed the RNGs, create the output directory, and log the device."""
        if self.config.seed is not None:
            torch.manual_seed(self.config.seed)
            np.random.seed(self.config.seed)
        self.base_path.mkdir(parents=True, exist_ok=True)
        logging.info(f"Environment setup: device={self.device}, base_path={self.base_path}")
        if self.device.type == "cuda":
            # torch.cuda.mem_get_info returns (free_bytes, total_bytes).
            free_bytes, _total_bytes = torch.cuda.mem_get_info(self.device)
            logging.info(f"CUDA Memory: {free_bytes / 1024**3:.2f} GB free")

    def _calculate_sharding(self) -> None:
        """Calculate layer distribution across shards.

        Sets ``layers_per_shard`` (floor division) and ``remaining_layers``
        (the remainder, which ``build_shard`` assigns to the last shard).
        """
        self.layers_per_shard = self.config.num_layers // self.config.total_shards
        self.remaining_layers = self.config.num_layers % self.config.total_shards
        logging.info(f"Sharding: {self.layers_per_shard} layers/shard, {self.remaining_layers} extra")

    def _initialize_tensor(self, *shape) -> torch.Tensor:
        """Create a tensor of ``shape`` initialised per ``config.init_method``.

        Tensors with more than one dimension use Xavier or Kaiming uniform
        initialisation when configured; everything else (1-D tensors and the
        "normal" method) is drawn from N(0, 1/sqrt(hidden_size)).
        """
        tensor = torch.empty(*shape, dtype=self.dtype, device=self.device)
        method = self.config.init_method
        is_matrix = len(shape) > 1
        if is_matrix and method == "xavier":
            xavier_uniform_(tensor)
        elif is_matrix and method == "kaiming":
            kaiming_uniform_(tensor, a=0, mode="fan_in", nonlinearity="relu")
        else:
            tensor.normal_(0, 1.0 / self.config.hidden_size ** 0.5)
        return tensor

    def _create_attention_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
        """Create attention weights for layer ``layer_idx``.

        Returns:
            Query/key/value/output projection weights (hidden x hidden) with
            zero biases, plus a ``head_scale`` tensor of ones shaped
            (heads, hidden_size // heads).
        """
        hidden = self.config.hidden_size
        weights = {}
        prefix = f"layer_{layer_idx}.attention"
        head_dim = hidden // self.config.heads
        for name in ["query_weight", "key_weight", "value_weight", "output_weight"]:
            weights[f"{prefix}.{name}"] = self._initialize_tensor(hidden, hidden)
            weights[f"{prefix}.{name}_bias"] = torch.zeros(hidden, dtype=self.dtype, device=self.device)
        weights[f"{prefix}.head_scale"] = torch.ones(
            self.config.heads, head_dim, dtype=self.dtype, device=self.device
        )
        return weights

    def _create_ffn_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
        """Create feed-forward weights for layer ``layer_idx``.

        The intermediate width is ``hidden_size * ffn_multiplier``.
        """
        hidden = self.config.hidden_size
        prefix = f"layer_{layer_idx}.ffn"
        intermediate_size = hidden * self.config.ffn_multiplier
        weights = {}
        weights[f"{prefix}.intermediate_weight"] = self._initialize_tensor(hidden, intermediate_size)
        weights[f"{prefix}.intermediate_bias"] = torch.zeros(
            intermediate_size, dtype=self.dtype, device=self.device
        )
        weights[f"{prefix}.output_weight"] = self._initialize_tensor(intermediate_size, hidden)
        weights[f"{prefix}.output_bias"] = torch.zeros(hidden, dtype=self.dtype, device=self.device)
        return weights

    def _create_norm_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
        """Create the two normalization weight/bias pairs for ``layer_idx``.

        Weights are initialised to ones and biases to zeros.
        """
        hidden = self.config.hidden_size
        prefix = f"layer_{layer_idx}"
        return {
            f"{prefix}.norm_1_weight": torch.ones(hidden, dtype=self.dtype, device=self.device),
            f"{prefix}.norm_1_bias": torch.zeros(hidden, dtype=self.dtype, device=self.device),
            f"{prefix}.norm_2_weight": torch.ones(hidden, dtype=self.dtype, device=self.device),
            f"{prefix}.norm_2_bias": torch.zeros(hidden, dtype=self.dtype, device=self.device),
        }

    def _create_embedding_output(self) -> Dict[str, torch.Tensor]:
        """Create embedding and output layers (stored in the first shard)."""
        cfg = self.config
        weights = {
            "embedding.word_embeddings": self._initialize_tensor(cfg.vocab_size, cfg.hidden_size),
            "embedding.position_embeddings": self._initialize_tensor(cfg.seq_length, cfg.hidden_size),
            "embedding.token_type_embeddings": self._initialize_tensor(cfg.seq_length, cfg.hidden_size),
            "output_layer.weight": self._initialize_tensor(cfg.hidden_size, cfg.vocab_size),
            "output_layer.bias": torch.zeros(cfg.vocab_size, dtype=self.dtype, device=self.device),
        }
        return weights

    def build_shard(self, shard_idx: int) -> Dict[str, torch.Tensor]:
        """Build weights for a specific shard.

        Args:
            shard_idx: 1-based shard index in ``[1, config.total_shards]``.

        Returns:
            Mapping from tensor name to tensor for every layer in the shard;
            shard 1 additionally holds the embedding and output layers.

        Raises:
            ValueError: If ``shard_idx`` is outside the valid range.
        """
        if not 1 <= shard_idx <= self.config.total_shards:
            raise ValueError(
                f"shard_idx must be in [1, {self.config.total_shards}], got {shard_idx}"
            )
        weights = {}
        start_time = time.time()
        start_layer = (shard_idx - 1) * self.layers_per_shard
        end_layer = start_layer + self.layers_per_shard
        # Layers left over from the floor division all go to the last shard.
        if shard_idx == self.config.total_shards:
            end_layer += self.remaining_layers
        for i in tqdm(range(start_layer, end_layer), desc=f"Shard {shard_idx} layers"):
            weights.update(self._create_attention_block(i))
            weights.update(self._create_ffn_block(i))
            weights.update(self._create_norm_block(i))
        if shard_idx == 1:
            weights.update(self._create_embedding_output())
        elapsed = time.time() - start_time
        self.metadata[f"shard_{shard_idx}"] = {
            "build_time": elapsed,
            "num_layers": end_layer - start_layer,
        }
        logging.info(f"Shard {shard_idx} built with {len(weights)} tensors in {elapsed:.2f}s")
        return weights

    def save_shard(self, shard_idx: int, weights: Dict[str, torch.Tensor]) -> None:
        """Save a single shard with metadata.

        The file is written to
        ``<base_path>/model_<shard_idx>_of_<total_shards>.safetensors``.

        Raises:
            RuntimeError: If writing the shard fails (original error chained).
        """
        shard_path = self.base_path / f"model_{shard_idx}_of_{self.config.total_shards}.safetensors"
        start_time = time.time()
        try:
            shard_metadata = {
                "shard_idx": shard_idx,
                "total_shards": self.config.total_shards,
                "config": asdict(self.config),
                **self.metadata.get(f"shard_{shard_idx}", {}),
            }
            # safetensors only accepts str -> str metadata, so every
            # non-string value is stored as its JSON encoding.
            header = {
                key: value if isinstance(value, str) else json.dumps(value)
                for key, value in shard_metadata.items()
            }
            save_file(weights, str(shard_path), metadata=header)
            elapsed = time.time() - start_time
            logging.info(f"Shard {shard_idx} saved to {shard_path} in {elapsed:.2f}s")
        except Exception as e:
            logging.error(f"Shard {shard_idx} save failed: {str(e)}")
            raise RuntimeError(f"Failed to save shard {shard_idx}: {str(e)}") from e

    def build_and_save_all_shards(self, parallel: bool = True) -> None:
        """Build and save all shards, optionally in parallel.

        In parallel mode shards are built on a thread pool and saved on the
        calling thread as they complete; a failing shard is logged and skipped
        so the remaining shards are still produced. The indices of failed
        shards are stored in ``metadata["failed_shards"]``. In sequential mode
        the first failure propagates to the caller.

        Args:
            parallel: Use a thread pool when more than one CPU is available.
        """
        start_time = time.time()
        failed: List[int] = []
        if parallel and mp.cpu_count() > 1:
            workers = min(mp.cpu_count(), self.config.total_shards)
            with ThreadPoolExecutor(max_workers=workers) as executor:
                futures = {
                    executor.submit(self.build_shard, i): i
                    for i in range(1, self.config.total_shards + 1)
                }
                for future in as_completed(futures):
                    shard_idx = futures[future]
                    try:
                        weights = future.result()
                        self.save_shard(shard_idx, weights)
                    except Exception as e:
                        failed.append(shard_idx)
                        logging.error(f"Parallel shard {shard_idx} failed: {str(e)}")
        else:
            for shard_idx in tqdm(range(1, self.config.total_shards + 1), desc="Building shards"):
                weights = self.build_shard(shard_idx)
                self.save_shard(shard_idx, weights)
        total_time = time.time() - start_time
        self.metadata["total_build_time"] = total_time
        if failed:
            # Report the failures instead of claiming that every shard succeeded.
            self.metadata["failed_shards"] = sorted(failed)
            logging.error(
                f"{len(failed)} of {self.config.total_shards} shards failed "
                f"in {total_time:.2f}s: {sorted(failed)}"
            )
        else:
            logging.info(f"All {self.config.total_shards} shards completed in {total_time:.2f}s")

    def validate_shard(self, shard_idx: int) -> bool:
        """Validate a shard's weights after loading it from disk.

        Returns:
            ``False`` if the shard cannot be loaded or any tensor contains NaN
            or Inf; ``True`` otherwise. Tensors whose largest absolute value
            exceeds ``config.validation_threshold`` only produce a warning.
        """
        shard_path = self.base_path / f"model_{shard_idx}_of_{self.config.total_shards}.safetensors"
        try:
            weights = load_file(str(shard_path), device="cpu")  # Load to CPU for validation
            all_valid = True
            for name, tensor in weights.items():
                if torch.isnan(tensor).any() or torch.isinf(tensor).any():
                    logging.warning(f"Invalid values in {name} (shard {shard_idx})")
                    all_valid = False
                # NOTE(review): with the default threshold of 1e-5 this fires for
                # practically every tensor (e.g. norm weights are all ones);
                # confirm the intended threshold.
                elif torch.max(torch.abs(tensor)) > self.config.validation_threshold:
                    logging.warning(f"Large values in {name} (shard {shard_idx})")
            return all_valid
        except Exception as e:
            logging.error(f"Validation failed for shard {shard_idx}: {str(e)}")
            return False

    def compute_checksum(self, shard_idx: int) -> str:
        """Compute the SHA256 hex digest of a shard file.

        Raises:
            OSError: If the shard file cannot be opened or read.
        """
        shard_path = self.base_path / f"model_{shard_idx}_of_{self.config.total_shards}.safetensors"
        sha256 = hashlib.sha256()
        with open(shard_path, "rb") as f:
            # 1 MiB reads: shard files are large, so tiny reads waste syscalls.
            while chunk := f.read(1024 * 1024):
                sha256.update(chunk)
        return sha256.hexdigest()

    def export_metadata(self, output_path: Union[str, Path] = "model_metadata.json") -> None:
        """Export the collected build metadata to a JSON file."""
        output_path = Path(output_path)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(self.metadata, f, indent=2)
        logging.info(f"Metadata exported to {output_path}")

    @classmethod
    def from_yaml(cls, yaml_path: Union[str, Path]) -> "TransformerShardBuilder":
        """Initialize from a YAML config file.

        The file must contain a mapping whose keys are ``ModelConfig`` field
        names; an empty file yields the default configuration.
        """
        with open(yaml_path, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty document.
            config_dict = yaml.safe_load(f) or {}
        return cls(ModelConfig(**config_dict))


def estimate_model_size(config: ModelConfig) -> Tuple[int, float]:
    """Estimate total model size in parameters and GB.

    The shards are built on torch's ``meta`` device, so tensors carry shape
    and dtype but no storage: the count uses exactly the shapes the real
    build produces without allocating the model.

    Args:
        config: Configuration of the model to measure.

    Returns:
        ``(parameter_count, size_in_GiB)``.

    Raises:
        ValueError: If ``config`` fails validation.
    """
    builder = TransformerShardBuilder(replace(config, device="meta"))
    params = 0
    bytes_size = 0
    for shard in range(1, config.total_shards + 1):
        for tensor in builder.build_shard(shard).values():
            count = tensor.numel()
            params += count
            bytes_size += count * tensor.element_size()
    return params, bytes_size / 1024**3


def main():
    """Estimate, build, save, validate, and checksum all shards.

    Returns:
        Process exit code: 0 on success, 1 if any step raised.
    """
    try:
        # Custom configuration
        # NOTE(review): 48 layers with 278 shards fails _validate_config
        # ("Number of layers must be >= total shards"); confirm the intended
        # values before running.
        config = ModelConfig(
            num_layers=48,
            hidden_size=8192,
            heads=64,
            seq_length=4096,
            vocab_size=50000,
            total_shards=278,
            base_path="model_shards_large",
        )
        builder = TransformerShardBuilder(config)

        # Size estimation
        num_params, size_gb = estimate_model_size(config)
        logging.info(f"Estimated size: {num_params:,} parameters, {size_gb:.2f} GB")

        # Build and save all shards
        builder.build_and_save_all_shards(parallel=True)

        # Validate all shards
        logging.info("Validating shards...")
        for shard in tqdm(range(1, config.total_shards + 1), desc="Validating"):
            if builder.validate_shard(shard):
                checksum = builder.compute_checksum(shard)
                logging.info(f"Shard {shard} validated, checksum: {checksum[:8]}...")
            else:
                logging.warning(f"Shard {shard} validation failed")

        # Export metadata
        builder.export_metadata()
        return 0
    except Exception as e:
        logging.error(f"Execution failed: {str(e)}")
        return 1


if __name__ == "__main__":
    sys.exit(main())