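# Charm_15 / model.safetensors
# Commit 8d2806a ("Create model.safetensors") by GeminiFan207, verified; 10.3 kB.
# NOTE: the lines above were a pasted repository-page header; they are kept
# as comments so this file remains valid Python.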
import torch
from safetensors.torch import save_file, load_file
from typing import Dict, Optional, Tuple, List
import logging
import time
import json
from pathlib import Path
import sys
import yaml
from dataclasses import dataclass
import numpy as np
from tqdm import tqdm
# Root logging setup: INFO and above, timestamped records, mirrored both to
# stdout and to "transformer_builder.log" (relative to the working directory).
logging.basicConfig(
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler("transformer_builder.log"),
    ],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
@dataclass
class ModelConfig:
    """Hyper-parameters and output settings for a generated transformer checkpoint.

    Fields can be passed positionally in declaration order, so keep that order stable.
    """

    # --- architecture ---
    num_layers: int = 48          # number of transformer layers to generate
    hidden_size: int = 8_192      # model width; must be divisible by ``heads``
    heads: int = 64               # attention heads per layer
    seq_length: int = 4_096       # rows in the position / token-type embedding tables
    vocab_size: int = 50_000      # rows in the word-embedding table / output columns
    dtype: str = "float16"        # name of a ``torch`` dtype attribute
    ffn_multiplier: int = 4       # FFN inner width = hidden_size * ffn_multiplier
    # --- output / runtime ---
    save_path: str = "charm15_large.safetensors"
    # Evaluated once, when the class is defined.
    device: str = ("cuda" if torch.cuda.is_available() else "cpu")
    seed: Optional[int] = 42      # None skips RNG seeding
class TransformerModelBuilder:
"""Advanced class to build, validate, and save transformer model weights."""
def __init__(self, config: Optional[ModelConfig] = None):
"""Initialize with optional configuration."""
self.config = config or ModelConfig()
self.dtype = getattr(torch, self.config.dtype)
self.device = torch.device(self.config.device)
self.weights: Dict[str, torch.Tensor] = {}
self.metadata: Dict[str, any] = {}
self._validate_config()
self._setup_environment()
def _validate_config(self) -> None:
"""Validate configuration parameters."""
checks = [
(self.config.num_layers > 0, "Number of layers must be positive"),
(self.config.hidden_size % self.config.heads == 0,
"Hidden size must be divisible by number of heads"),
(self.config.seq_length > 0, "Sequence length must be positive"),
(self.config.vocab_size > 0, "Vocab size must be positive"),
(self.config.ffn_multiplier > 1, "FFN multiplier must be greater than 1")
]
for condition, message in checks:
if not condition:
raise ValueError(message)
def _setup_environment(self) -> None:
"""Setup random seed and device environment."""
if self.config.seed is not None:
torch.manual_seed(self.config.seed)
np.random.seed(self.config.seed)
logging.info(f"Using device: {self.device}")
if str(self.device) == "cuda":
logging.info(f"GPU Memory Available: {torch.cuda.memory_available() / 1024**3:.2f} GB")
def _scaled_init(self, *shape) -> torch.Tensor:
"""Create scaled random tensor for initialization."""
tensor = torch.randn(*shape, dtype=self.dtype, device=self.device)
fan_in = shape[-2] if len(shape) > 1 else shape[-1]
return tensor * (1.0 / fan_in ** 0.5)
def _create_attention_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
"""Create attention mechanism weights for a layer."""
weights = {}
prefix = f"layer_{layer_idx}.attention"
head_dim = self.config.hidden_size // self.config.heads
weights[f"{prefix}.query_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size)
weights[f"{prefix}.key_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size)
weights[f"{prefix}.value_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size)
weights[f"{prefix}.output_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size)
weights[f"{prefix}.head_bias"] = torch.zeros(self.config.heads, head_dim, dtype=self.dtype, device=self.device)
return weights
def _create_ffn_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
"""Create feed-forward network weights for a layer."""
weights = {}
prefix = f"layer_{layer_idx}.ffn"
intermediate_size = self.config.hidden_size * self.config.ffn_multiplier
weights[f"{prefix}.intermediate_weight"] = self._scaled_init(self.config.hidden_size, intermediate_size)
weights[f"{prefix}.intermediate_bias"] = torch.zeros(intermediate_size, dtype=self.dtype, device=self.device)
weights[f"{prefix}.output_weight"] = self._scaled_init(intermediate_size, self.config.hidden_size)
weights[f"{prefix}.output_bias"] = torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
return weights
def _create_norm_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
"""Create normalization layer weights."""
prefix = f"layer_{layer_idx}"
return {
f"{prefix}.norm_1_weight": torch.ones(self.config.hidden_size, dtype=self.dtype, device=self.device),
f"{prefix}.norm_1_bias": torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device),
f"{prefix}.norm_2_weight": torch.ones(self.config.hidden_size, dtype=self.dtype, device=self.device),
f"{prefix}.norm_2_bias": torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
}
def build_model(self) -> Dict[str, torch.Tensor]:
"""Build complete transformer model weights."""
start_time = time.time()
self.weights.clear()
try:
# Build transformer layers with progress bar
for i in tqdm(range(self.config.num_layers), desc="Building layers"):
self.weights.update(self._create_attention_block(i))
self.weights.update(self._create_ffn_block(i))
self.weights.update(self._create_norm_block(i))
# Embedding and output layers
logging.info("Building embedding and output layers")
self.weights["embedding.word_embeddings"] = self._scaled_init(
self.config.vocab_size, self.config.hidden_size
)
self.weights["embedding.position_embeddings"] = self._scaled_init(
self.config.seq_length, self.config.hidden_size
)
self.weights["embedding.token_type_embeddings"] = self._scaled_init(
self.config.seq_length, self.config.hidden_size
)
self.weights["output_layer.weight"] = self._scaled_init(
self.config.hidden_size, self.config.vocab_size
)
self.weights["output_layer.bias"] = torch.zeros(
self.config.vocab_size, dtype=self.dtype, device=self.device
)
# Store metadata
self.metadata = {
"build_time": time.time() - start_time,
"num_parameters": sum(t.numel() for t in self.weights.values()),
"config": vars(self.config)
}
logging.info(f"Model built with {self.metadata['num_parameters']:,} parameters "
f"in {self.metadata['build_time']:.2f} seconds")
return self.weights
except Exception as e:
logging.error(f"Model building failed: {str(e)}")
raise RuntimeError(f"Failed to build model: {str(e)}") from e
def save_model(self, save_path: Optional[str | Path] = None) -> None:
"""Save model weights and metadata to safetensors file."""
save_path = Path(save_path or self.config.save_path)
start_time = time.time()
try:
save_path.parent.mkdir(parents=True, exist_ok=True)
save_file(self.weights, str(save_path), metadata=self.metadata)
# Save config separately
config_path = save_path.with_suffix(".yaml")
with open(config_path, "w") as f:
yaml.dump(vars(self.config), f, default_flow_style=False)
elapsed = time.time() - start_time
logging.info(f"Model and config saved to {save_path} in {elapsed:.2f} seconds")
except Exception as e:
logging.error(f"Model saving failed: {str(e)}")
raise RuntimeError(f"Failed to save model: {str(e)}") from e
def validate_model(self, weights: Optional[Dict[str, torch.Tensor]] = None) -> bool:
"""Validate model weights for consistency."""
weights = weights or self.weights
all_valid = True
for name, tensor in weights.items():
if torch.isnan(tensor).any() or torch.isinf(tensor).any():
logging.warning(f"Invalid values detected in {name}")
all_valid = False
logging.debug(f"Validated {name}: shape={tensor.shape}")
return all_valid
@classmethod
def from_config_file(cls, config_path: str | Path) -> "TransformerModelBuilder":
"""Create builder from YAML config file."""
with open(config_path, "r") as f:
config_dict = yaml.safe_load(f)
return cls(ModelConfig(**config_dict))
def estimate_model_size(config: "ModelConfig") -> Tuple[int, float]:
    """Estimate parameter count and in-memory size for the model in *config*.

    The count is computed analytically and mirrors exactly the tensors that
    ``TransformerModelBuilder.build_model`` creates (keep the two in sync), so
    no weights are allocated and the global RNG state is not touched.

    Only the checks this computation needs are performed; full config
    validation happens in ``TransformerModelBuilder``.

    Returns:
        ``(num_params, size_gb)`` where ``size_gb`` is bytes / 1024**3 at
        ``config.dtype``.

    Raises:
        ValueError: If ``config.heads`` is not positive or ``config.dtype`` is
            not a torch dtype name.
    """
    hidden = config.hidden_size
    heads = config.heads
    if heads <= 0:
        raise ValueError("Number of heads must be positive")
    dtype = getattr(torch, config.dtype, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unsupported dtype: {config.dtype!r}")

    intermediate = hidden * config.ffn_multiplier
    # Q, K, V, output projections (hidden x hidden) + head_bias (heads x head_dim).
    attention = 4 * hidden * hidden + heads * (hidden // heads)
    # intermediate/output projections plus their biases.
    ffn = 2 * hidden * intermediate + intermediate + hidden
    # Two (weight, bias) normalisation pairs.
    norms = 4 * hidden
    per_layer = attention + ffn + norms

    # Word, position, and token-type embedding tables.
    embeddings = (config.vocab_size + 2 * config.seq_length) * hidden
    output = hidden * config.vocab_size + config.vocab_size

    num_params = config.num_layers * per_layer + embeddings + output
    bytes_per_param = torch.empty(0, dtype=dtype).element_size()
    size_gb = num_params * bytes_per_param / 1024**3
    return num_params, size_gb
def main():
    """Estimate, build, validate, and save the default model.

    Returns:
        Process exit code: 0 on success, 1 if validation fails or any step raises.
    """
    try:
        config = ModelConfig()
        # Run the estimate before constructing the builder: the builder seeds
        # the global RNGs in its constructor, so anything the estimate does to
        # RNG state must happen first for build_model() to start from the seed.
        num_params, size_gb = estimate_model_size(config)
        logging.info(f"Estimated model size: {num_params:,} parameters, {size_gb:.2f} GB")

        builder = TransformerModelBuilder(config)
        weights = builder.build_model()
        if not builder.validate_model(weights):
            logging.warning("Model validation failed")
            return 1
        logging.info("Model validation passed")
        builder.save_model()
        return 0
    except Exception:
        # Top-level boundary: log with traceback and report failure via exit code.
        logging.exception("Execution failed")
        return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())