|
import torch |
|
from safetensors.torch import save_file, load_file |
|
from typing import Dict, Optional, Tuple, List |
|
import logging |
|
import time |
|
import json |
|
from pathlib import Path |
|
import sys |
|
import yaml |
|
from dataclasses import dataclass |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
|
|
# Root-logger setup, executed at import time: INFO and above is mirrored to
# stdout and to "transformer_builder.log" in the current working directory.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
_stdout_handler = logging.StreamHandler(sys.stdout)
_logfile_handler = logging.FileHandler("transformer_builder.log")

logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[_stdout_handler, _logfile_handler],
)
|
|
|
@dataclass
class ModelConfig:
    """Hyperparameters and output settings for one transformer weight build.

    The field order below is also the positional order of the generated
    constructor, so it must not be rearranged.
    """

    # --- architecture -----------------------------------------------------
    num_layers: int = 48
    hidden_size: int = 8192
    # Must divide hidden_size evenly (checked by TransformerModelBuilder).
    heads: int = 64
    seq_length: int = 4096
    vocab_size: int = 50000

    # Name of a torch dtype attribute; the builder resolves it via getattr(torch, ...).
    dtype: str = "float16"

    # FFN intermediate width is hidden_size * ffn_multiplier.
    ffn_multiplier: int = 4

    # --- output / runtime -------------------------------------------------
    save_path: str = "charm15_large.safetensors"

    # Evaluated once, when the class body runs: CUDA if available, else CPU.
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"

    # None means the builder leaves the RNGs unseeded.
    seed: Optional[int] = 42
|
|
|
class TransformerModelBuilder: |
|
"""Advanced class to build, validate, and save transformer model weights.""" |
|
|
|
def __init__(self, config: Optional[ModelConfig] = None): |
|
"""Initialize with optional configuration.""" |
|
self.config = config or ModelConfig() |
|
self.dtype = getattr(torch, self.config.dtype) |
|
self.device = torch.device(self.config.device) |
|
self.weights: Dict[str, torch.Tensor] = {} |
|
self.metadata: Dict[str, any] = {} |
|
|
|
self._validate_config() |
|
self._setup_environment() |
|
|
|
def _validate_config(self) -> None: |
|
"""Validate configuration parameters.""" |
|
checks = [ |
|
(self.config.num_layers > 0, "Number of layers must be positive"), |
|
(self.config.hidden_size % self.config.heads == 0, |
|
"Hidden size must be divisible by number of heads"), |
|
(self.config.seq_length > 0, "Sequence length must be positive"), |
|
(self.config.vocab_size > 0, "Vocab size must be positive"), |
|
(self.config.ffn_multiplier > 1, "FFN multiplier must be greater than 1") |
|
] |
|
|
|
for condition, message in checks: |
|
if not condition: |
|
raise ValueError(message) |
|
|
|
def _setup_environment(self) -> None: |
|
"""Setup random seed and device environment.""" |
|
if self.config.seed is not None: |
|
torch.manual_seed(self.config.seed) |
|
np.random.seed(self.config.seed) |
|
logging.info(f"Using device: {self.device}") |
|
if str(self.device) == "cuda": |
|
logging.info(f"GPU Memory Available: {torch.cuda.memory_available() / 1024**3:.2f} GB") |
|
|
|
def _scaled_init(self, *shape) -> torch.Tensor: |
|
"""Create scaled random tensor for initialization.""" |
|
tensor = torch.randn(*shape, dtype=self.dtype, device=self.device) |
|
fan_in = shape[-2] if len(shape) > 1 else shape[-1] |
|
return tensor * (1.0 / fan_in ** 0.5) |
|
|
|
def _create_attention_block(self, layer_idx: int) -> Dict[str, torch.Tensor]: |
|
"""Create attention mechanism weights for a layer.""" |
|
weights = {} |
|
prefix = f"layer_{layer_idx}.attention" |
|
head_dim = self.config.hidden_size // self.config.heads |
|
|
|
weights[f"{prefix}.query_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size) |
|
weights[f"{prefix}.key_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size) |
|
weights[f"{prefix}.value_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size) |
|
weights[f"{prefix}.output_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size) |
|
weights[f"{prefix}.head_bias"] = torch.zeros(self.config.heads, head_dim, dtype=self.dtype, device=self.device) |
|
|
|
return weights |
|
|
|
def _create_ffn_block(self, layer_idx: int) -> Dict[str, torch.Tensor]: |
|
"""Create feed-forward network weights for a layer.""" |
|
weights = {} |
|
prefix = f"layer_{layer_idx}.ffn" |
|
intermediate_size = self.config.hidden_size * self.config.ffn_multiplier |
|
|
|
weights[f"{prefix}.intermediate_weight"] = self._scaled_init(self.config.hidden_size, intermediate_size) |
|
weights[f"{prefix}.intermediate_bias"] = torch.zeros(intermediate_size, dtype=self.dtype, device=self.device) |
|
weights[f"{prefix}.output_weight"] = self._scaled_init(intermediate_size, self.config.hidden_size) |
|
weights[f"{prefix}.output_bias"] = torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device) |
|
|
|
return weights |
|
|
|
def _create_norm_block(self, layer_idx: int) -> Dict[str, torch.Tensor]: |
|
"""Create normalization layer weights.""" |
|
prefix = f"layer_{layer_idx}" |
|
return { |
|
f"{prefix}.norm_1_weight": torch.ones(self.config.hidden_size, dtype=self.dtype, device=self.device), |
|
f"{prefix}.norm_1_bias": torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device), |
|
f"{prefix}.norm_2_weight": torch.ones(self.config.hidden_size, dtype=self.dtype, device=self.device), |
|
f"{prefix}.norm_2_bias": torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device) |
|
} |
|
|
|
def build_model(self) -> Dict[str, torch.Tensor]: |
|
"""Build complete transformer model weights.""" |
|
start_time = time.time() |
|
self.weights.clear() |
|
|
|
try: |
|
|
|
for i in tqdm(range(self.config.num_layers), desc="Building layers"): |
|
self.weights.update(self._create_attention_block(i)) |
|
self.weights.update(self._create_ffn_block(i)) |
|
self.weights.update(self._create_norm_block(i)) |
|
|
|
|
|
logging.info("Building embedding and output layers") |
|
self.weights["embedding.word_embeddings"] = self._scaled_init( |
|
self.config.vocab_size, self.config.hidden_size |
|
) |
|
self.weights["embedding.position_embeddings"] = self._scaled_init( |
|
self.config.seq_length, self.config.hidden_size |
|
) |
|
self.weights["embedding.token_type_embeddings"] = self._scaled_init( |
|
self.config.seq_length, self.config.hidden_size |
|
) |
|
self.weights["output_layer.weight"] = self._scaled_init( |
|
self.config.hidden_size, self.config.vocab_size |
|
) |
|
self.weights["output_layer.bias"] = torch.zeros( |
|
self.config.vocab_size, dtype=self.dtype, device=self.device |
|
) |
|
|
|
|
|
self.metadata = { |
|
"build_time": time.time() - start_time, |
|
"num_parameters": sum(t.numel() for t in self.weights.values()), |
|
"config": vars(self.config) |
|
} |
|
logging.info(f"Model built with {self.metadata['num_parameters']:,} parameters " |
|
f"in {self.metadata['build_time']:.2f} seconds") |
|
return self.weights |
|
|
|
except Exception as e: |
|
logging.error(f"Model building failed: {str(e)}") |
|
raise RuntimeError(f"Failed to build model: {str(e)}") from e |
|
|
|
def save_model(self, save_path: Optional[str | Path] = None) -> None: |
|
"""Save model weights and metadata to safetensors file.""" |
|
save_path = Path(save_path or self.config.save_path) |
|
start_time = time.time() |
|
|
|
try: |
|
save_path.parent.mkdir(parents=True, exist_ok=True) |
|
save_file(self.weights, str(save_path), metadata=self.metadata) |
|
|
|
|
|
config_path = save_path.with_suffix(".yaml") |
|
with open(config_path, "w") as f: |
|
yaml.dump(vars(self.config), f, default_flow_style=False) |
|
|
|
elapsed = time.time() - start_time |
|
logging.info(f"Model and config saved to {save_path} in {elapsed:.2f} seconds") |
|
except Exception as e: |
|
logging.error(f"Model saving failed: {str(e)}") |
|
raise RuntimeError(f"Failed to save model: {str(e)}") from e |
|
|
|
def validate_model(self, weights: Optional[Dict[str, torch.Tensor]] = None) -> bool: |
|
"""Validate model weights for consistency.""" |
|
weights = weights or self.weights |
|
all_valid = True |
|
|
|
for name, tensor in weights.items(): |
|
if torch.isnan(tensor).any() or torch.isinf(tensor).any(): |
|
logging.warning(f"Invalid values detected in {name}") |
|
all_valid = False |
|
logging.debug(f"Validated {name}: shape={tensor.shape}") |
|
|
|
return all_valid |
|
|
|
@classmethod |
|
def from_config_file(cls, config_path: str | Path) -> "TransformerModelBuilder": |
|
"""Create builder from YAML config file.""" |
|
with open(config_path, "r") as f: |
|
config_dict = yaml.safe_load(f) |
|
return cls(ModelConfig(**config_dict)) |
|
|
|
def estimate_model_size(config: "ModelConfig") -> Tuple[int, float]:
    """Compute the parameter count and storage size of a model analytically.

    The counts mirror the tensor shapes created by
    ``TransformerModelBuilder`` (attention, feed-forward, normalisation,
    embeddings, output layer), so the result equals what a real build would
    report - without allocating any weights or consuming random-number
    state. The two must be kept in sync if the layout changes.

    Args:
        config: Model configuration. It is not validated here; pass it to
            ``TransformerModelBuilder`` for validation.

    Returns:
        ``(num_parameters, size_in_gib)`` where the size assumes every
        tensor is stored in ``config.dtype``.
    """
    hidden = config.hidden_size
    intermediate = hidden * config.ffn_multiplier
    head_dim = hidden // config.heads

    # Four hidden x hidden projections plus the heads x head_dim bias.
    attention = 4 * hidden * hidden + config.heads * head_dim
    # Two projection matrices and their biases.
    ffn = hidden * intermediate + intermediate + intermediate * hidden + hidden
    # Two norms, each with a weight and a bias vector.
    norms = 4 * hidden
    per_layer = attention + ffn + norms

    # Word table, plus position and token-type tables (both seq_length rows).
    embeddings = config.vocab_size * hidden + 2 * config.seq_length * hidden
    output_layer = hidden * config.vocab_size + config.vocab_size

    num_params = config.num_layers * per_layer + embeddings + output_layer

    # An empty tensor is the cheapest way to ask torch for the element width.
    bytes_per_element = torch.empty(0, dtype=getattr(torch, config.dtype)).element_size()
    size_gb = num_params * bytes_per_element / 1024**3
    return num_params, size_gb
|
|
|
def main():
    """Estimate, build, validate and save the default model.

    Returns:
        0 when the model was built, validated and saved; 1 when validation
        failed or any step raised an exception.
    """
    try:
        cfg = ModelConfig()
        builder = TransformerModelBuilder(cfg)

        # Report the expected footprint before building.
        param_count, gigabytes = estimate_model_size(cfg)
        logging.info(f"Estimated model size: {param_count:,} parameters, {gigabytes:.2f} GB")

        built = builder.build_model()

        # Guard clause: never save weights that failed validation.
        if not builder.validate_model(built):
            logging.warning("Model validation failed")
            return 1

        logging.info("Model validation passed")
        builder.save_model()
        return 0

    except Exception as e:
        logging.error(f"Execution failed: {str(e)}")
        return 1
|
|
|
# Script entry point: the return value of main() becomes the exit status.
if __name__ == "__main__":
    raise SystemExit(main())