GeminiFan207 committed on
Commit
8d2806a
·
verified ·
1 Parent(s): cff74af

Create model.safetensors

Browse files
Files changed (1) hide show
  1. model.safetensors +238 -0
model.safetensors ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from safetensors.torch import save_file, load_file
3
+ from typing import Dict, Optional, Tuple, List
4
+ import logging
5
+ import time
6
+ import json
7
+ from pathlib import Path
8
+ import sys
9
+ import yaml
10
+ from dataclasses import dataclass
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
# Route log records to both stdout and a log file in the working directory.
logging.basicConfig(
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler("transformer_builder.log"),
    ],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
23
+
24
@dataclass
class ModelConfig:
    """Hyper-parameters and output settings consumed by the model builder."""

    # Number of transformer layers to generate weights for.
    num_layers: int = 48
    # Width of the hidden representation; must be divisible by ``heads``.
    hidden_size: int = 8192
    # Number of attention heads.
    heads: int = 64
    # Number of rows in the position / token-type embedding tables.
    seq_length: int = 4096
    # Number of rows in the word-embedding table and size of the output bias.
    vocab_size: int = 50000
    # Name of the ``torch`` dtype attribute used for every tensor.
    dtype: str = "float16"
    # The FFN intermediate size is ``hidden_size * ffn_multiplier``.
    ffn_multiplier: int = 4
    # Default destination of the saved safetensors file.
    save_path: str = "charm15_large.safetensors"
    # Target device, decided once when the class is defined.
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"
    # Seed for torch / numpy; ``None`` leaves the RNG state untouched.
    seed: Optional[int] = 42
37
+
38
+ class TransformerModelBuilder:
39
+ """Advanced class to build, validate, and save transformer model weights."""
40
+
41
+ def __init__(self, config: Optional[ModelConfig] = None):
42
+ """Initialize with optional configuration."""
43
+ self.config = config or ModelConfig()
44
+ self.dtype = getattr(torch, self.config.dtype)
45
+ self.device = torch.device(self.config.device)
46
+ self.weights: Dict[str, torch.Tensor] = {}
47
+ self.metadata: Dict[str, any] = {}
48
+
49
+ self._validate_config()
50
+ self._setup_environment()
51
+
52
+ def _validate_config(self) -> None:
53
+ """Validate configuration parameters."""
54
+ checks = [
55
+ (self.config.num_layers > 0, "Number of layers must be positive"),
56
+ (self.config.hidden_size % self.config.heads == 0,
57
+ "Hidden size must be divisible by number of heads"),
58
+ (self.config.seq_length > 0, "Sequence length must be positive"),
59
+ (self.config.vocab_size > 0, "Vocab size must be positive"),
60
+ (self.config.ffn_multiplier > 1, "FFN multiplier must be greater than 1")
61
+ ]
62
+
63
+ for condition, message in checks:
64
+ if not condition:
65
+ raise ValueError(message)
66
+
67
+ def _setup_environment(self) -> None:
68
+ """Setup random seed and device environment."""
69
+ if self.config.seed is not None:
70
+ torch.manual_seed(self.config.seed)
71
+ np.random.seed(self.config.seed)
72
+ logging.info(f"Using device: {self.device}")
73
+ if str(self.device) == "cuda":
74
+ logging.info(f"GPU Memory Available: {torch.cuda.memory_available() / 1024**3:.2f} GB")
75
+
76
+ def _scaled_init(self, *shape) -> torch.Tensor:
77
+ """Create scaled random tensor for initialization."""
78
+ tensor = torch.randn(*shape, dtype=self.dtype, device=self.device)
79
+ fan_in = shape[-2] if len(shape) > 1 else shape[-1]
80
+ return tensor * (1.0 / fan_in ** 0.5)
81
+
82
+ def _create_attention_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
83
+ """Create attention mechanism weights for a layer."""
84
+ weights = {}
85
+ prefix = f"layer_{layer_idx}.attention"
86
+ head_dim = self.config.hidden_size // self.config.heads
87
+
88
+ weights[f"{prefix}.query_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size)
89
+ weights[f"{prefix}.key_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size)
90
+ weights[f"{prefix}.value_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size)
91
+ weights[f"{prefix}.output_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size)
92
+ weights[f"{prefix}.head_bias"] = torch.zeros(self.config.heads, head_dim, dtype=self.dtype, device=self.device)
93
+
94
+ return weights
95
+
96
+ def _create_ffn_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
97
+ """Create feed-forward network weights for a layer."""
98
+ weights = {}
99
+ prefix = f"layer_{layer_idx}.ffn"
100
+ intermediate_size = self.config.hidden_size * self.config.ffn_multiplier
101
+
102
+ weights[f"{prefix}.intermediate_weight"] = self._scaled_init(self.config.hidden_size, intermediate_size)
103
+ weights[f"{prefix}.intermediate_bias"] = torch.zeros(intermediate_size, dtype=self.dtype, device=self.device)
104
+ weights[f"{prefix}.output_weight"] = self._scaled_init(intermediate_size, self.config.hidden_size)
105
+ weights[f"{prefix}.output_bias"] = torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
106
+
107
+ return weights
108
+
109
+ def _create_norm_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
110
+ """Create normalization layer weights."""
111
+ prefix = f"layer_{layer_idx}"
112
+ return {
113
+ f"{prefix}.norm_1_weight": torch.ones(self.config.hidden_size, dtype=self.dtype, device=self.device),
114
+ f"{prefix}.norm_1_bias": torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device),
115
+ f"{prefix}.norm_2_weight": torch.ones(self.config.hidden_size, dtype=self.dtype, device=self.device),
116
+ f"{prefix}.norm_2_bias": torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
117
+ }
118
+
119
+ def build_model(self) -> Dict[str, torch.Tensor]:
120
+ """Build complete transformer model weights."""
121
+ start_time = time.time()
122
+ self.weights.clear()
123
+
124
+ try:
125
+ # Build transformer layers with progress bar
126
+ for i in tqdm(range(self.config.num_layers), desc="Building layers"):
127
+ self.weights.update(self._create_attention_block(i))
128
+ self.weights.update(self._create_ffn_block(i))
129
+ self.weights.update(self._create_norm_block(i))
130
+
131
+ # Embedding and output layers
132
+ logging.info("Building embedding and output layers")
133
+ self.weights["embedding.word_embeddings"] = self._scaled_init(
134
+ self.config.vocab_size, self.config.hidden_size
135
+ )
136
+ self.weights["embedding.position_embeddings"] = self._scaled_init(
137
+ self.config.seq_length, self.config.hidden_size
138
+ )
139
+ self.weights["embedding.token_type_embeddings"] = self._scaled_init(
140
+ self.config.seq_length, self.config.hidden_size
141
+ )
142
+ self.weights["output_layer.weight"] = self._scaled_init(
143
+ self.config.hidden_size, self.config.vocab_size
144
+ )
145
+ self.weights["output_layer.bias"] = torch.zeros(
146
+ self.config.vocab_size, dtype=self.dtype, device=self.device
147
+ )
148
+
149
+ # Store metadata
150
+ self.metadata = {
151
+ "build_time": time.time() - start_time,
152
+ "num_parameters": sum(t.numel() for t in self.weights.values()),
153
+ "config": vars(self.config)
154
+ }
155
+ logging.info(f"Model built with {self.metadata['num_parameters']:,} parameters "
156
+ f"in {self.metadata['build_time']:.2f} seconds")
157
+ return self.weights
158
+
159
+ except Exception as e:
160
+ logging.error(f"Model building failed: {str(e)}")
161
+ raise RuntimeError(f"Failed to build model: {str(e)}") from e
162
+
163
+ def save_model(self, save_path: Optional[str | Path] = None) -> None:
164
+ """Save model weights and metadata to safetensors file."""
165
+ save_path = Path(save_path or self.config.save_path)
166
+ start_time = time.time()
167
+
168
+ try:
169
+ save_path.parent.mkdir(parents=True, exist_ok=True)
170
+ save_file(self.weights, str(save_path), metadata=self.metadata)
171
+
172
+ # Save config separately
173
+ config_path = save_path.with_suffix(".yaml")
174
+ with open(config_path, "w") as f:
175
+ yaml.dump(vars(self.config), f, default_flow_style=False)
176
+
177
+ elapsed = time.time() - start_time
178
+ logging.info(f"Model and config saved to {save_path} in {elapsed:.2f} seconds")
179
+ except Exception as e:
180
+ logging.error(f"Model saving failed: {str(e)}")
181
+ raise RuntimeError(f"Failed to save model: {str(e)}") from e
182
+
183
+ def validate_model(self, weights: Optional[Dict[str, torch.Tensor]] = None) -> bool:
184
+ """Validate model weights for consistency."""
185
+ weights = weights or self.weights
186
+ all_valid = True
187
+
188
+ for name, tensor in weights.items():
189
+ if torch.isnan(tensor).any() or torch.isinf(tensor).any():
190
+ logging.warning(f"Invalid values detected in {name}")
191
+ all_valid = False
192
+ logging.debug(f"Validated {name}: shape={tensor.shape}")
193
+
194
+ return all_valid
195
+
196
+ @classmethod
197
+ def from_config_file(cls, config_path: str | Path) -> "TransformerModelBuilder":
198
+ """Create builder from YAML config file."""
199
+ with open(config_path, "r") as f:
200
+ config_dict = yaml.safe_load(f)
201
+ return cls(ModelConfig(**config_dict))
202
+
203
def estimate_model_size(config: "ModelConfig") -> Tuple[int, float]:
    """Estimate model size in parameters and GB without allocating tensors.

    The count mirrors the tensors the builder creates: per layer, four
    ``hidden x hidden`` attention matrices plus a ``heads x head_dim`` bias,
    two feed-forward matrices with their biases, and two normalisation
    weight/bias pairs; globally, the word, position and token-type
    embeddings plus the output projection and its bias.

    Args:
        config: Model configuration. It is not validated here.

    Returns:
        ``(num_params, size_gb)`` where ``size_gb`` is the parameter count
        times the element size of ``config.dtype``, in GiB.
    """
    hidden = config.hidden_size
    intermediate = hidden * config.ffn_multiplier
    head_dim = hidden // config.heads

    attention = 4 * hidden * hidden + config.heads * head_dim
    ffn = hidden * intermediate + intermediate + intermediate * hidden + hidden
    norms = 4 * hidden
    per_layer = attention + ffn + norms

    embeddings = config.vocab_size * hidden + 2 * config.seq_length * hidden
    output_layer = hidden * config.vocab_size + config.vocab_size

    num_params = config.num_layers * per_layer + embeddings + output_layer

    # An empty tensor is enough to query the per-element byte width of the dtype.
    bytes_per_element = torch.empty(0, dtype=getattr(torch, config.dtype)).element_size()
    size_gb = num_params * bytes_per_element / 1024**3
    return num_params, size_gb
210
+
211
def main() -> int:
    """Estimate, build, validate and save the default model.

    Returns:
        ``0`` when the model was built, validated and saved; ``1`` when
        validation failed or any step raised an exception.
    """
    try:
        # Default configuration
        config = ModelConfig()
        builder = TransformerModelBuilder(config)

        # Estimate size
        num_params, size_gb = estimate_model_size(config)
        logging.info(f"Estimated model size: {num_params:,} parameters, {size_gb:.2f} GB")

        # Build and validate; bail out before saving anything invalid.
        weights = builder.build_model()
        if not builder.validate_model(weights):
            logging.warning("Model validation failed")
            return 1

        logging.info("Model validation passed")
        builder.save_model()
        return 0

    except Exception as e:
        # logging.exception records the traceback alongside the message, so
        # the failure location is not lost at this top-level boundary.
        logging.exception(f"Execution failed: {str(e)}")
        return 1
236
+
237
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())