GeminiFan207 committed on
Commit ced2d25 · verified · 1 Parent(s): ec8f409

Update model.safetensors

Files changed (1)
  1. model.safetensors +221 -113
model.safetensors CHANGED
@@ -1,29 +1,35 @@
 import torch
 from safetensors.torch import save_file, load_file
-from typing import Dict, Optional, Tuple, List
 import logging
 import time
 import json
 from pathlib import Path
 import sys
-import yaml
-from dataclasses import dataclass
 import numpy as np
 from tqdm import tqdm
 
-# Configure logging with file output
 logging.basicConfig(
     level=logging.INFO,
-    format="%(asctime)s - %(levelname)s - %(message)s",
     handlers=[
         logging.StreamHandler(sys.stdout),
-        logging.FileHandler("transformer_builder.log")
     ]
 )
 
 @dataclass
 class ModelConfig:
-    """Configuration class for transformer model parameters."""
     num_layers: int = 48
     hidden_size: int = 8192
     heads: int = 64
@@ -31,83 +37,106 @@ class ModelConfig:
     vocab_size: int = 50000
     dtype: str = "float16"
     ffn_multiplier: int = 4
-    save_path: str = "charm15_large.safetensors"
     device: str = "cuda" if torch.cuda.is_available() else "cpu"
     seed: Optional[int] = 42
 
-class TransformerModelBuilder:
-    """Advanced class to build, validate, and save transformer model weights."""
-
     def __init__(self, config: Optional[ModelConfig] = None):
-        """Initialize with optional configuration."""
         self.config = config or ModelConfig()
         self.dtype = getattr(torch, self.config.dtype)
         self.device = torch.device(self.config.device)
-        self.weights: Dict[str, torch.Tensor] = {}
-        self.metadata: Dict[str, any] = {}
 
         self._validate_config()
         self._setup_environment()
 
     def _validate_config(self) -> None:
         """Validate configuration parameters."""
         checks = [
             (self.config.num_layers > 0, "Number of layers must be positive"),
-            (self.config.hidden_size % self.config.heads == 0,
-             "Hidden size must be divisible by number of heads"),
             (self.config.seq_length > 0, "Sequence length must be positive"),
             (self.config.vocab_size > 0, "Vocab size must be positive"),
-            (self.config.ffn_multiplier > 1, "FFN multiplier must be greater than 1")
         ]
-
         for condition, message in checks:
             if not condition:
                 raise ValueError(message)
 
     def _setup_environment(self) -> None:
-        """Setup random seed and device environment."""
         if self.config.seed is not None:
             torch.manual_seed(self.config.seed)
             np.random.seed(self.config.seed)
-        logging.info(f"Using device: {self.device}")
-        if str(self.device) == "cuda":
-            logging.info(f"GPU Memory Available: {torch.cuda.memory_available() / 1024**3:.2f} GB")
 
-    def _scaled_init(self, *shape) -> torch.Tensor:
-        """Create scaled random tensor for initialization."""
-        tensor = torch.randn(*shape, dtype=self.dtype, device=self.device)
-        fan_in = shape[-2] if len(shape) > 1 else shape[-1]
-        return tensor * (1.0 / fan_in ** 0.5)
 
     def _create_attention_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
-        """Create attention mechanism weights for a layer."""
         weights = {}
         prefix = f"layer_{layer_idx}.attention"
         head_dim = self.config.hidden_size // self.config.heads
 
-        weights[f"{prefix}.query_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size)
-        weights[f"{prefix}.key_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size)
-        weights[f"{prefix}.value_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size)
-        weights[f"{prefix}.output_weight"] = self._scaled_init(self.config.hidden_size, self.config.hidden_size)
-        weights[f"{prefix}.head_bias"] = torch.zeros(self.config.heads, head_dim, dtype=self.dtype, device=self.device)
-
         return weights
 
     def _create_ffn_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
-        """Create feed-forward network weights for a layer."""
         weights = {}
         prefix = f"layer_{layer_idx}.ffn"
         intermediate_size = self.config.hidden_size * self.config.ffn_multiplier
 
-        weights[f"{prefix}.intermediate_weight"] = self._scaled_init(self.config.hidden_size, intermediate_size)
         weights[f"{prefix}.intermediate_bias"] = torch.zeros(intermediate_size, dtype=self.dtype, device=self.device)
-        weights[f"{prefix}.output_weight"] = self._scaled_init(intermediate_size, self.config.hidden_size)
         weights[f"{prefix}.output_bias"] = torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
-
         return weights
 
     def _create_norm_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
-        """Create normalization layer weights."""
         prefix = f"layer_{layer_idx}"
         return {
             f"{prefix}.norm_1_weight": torch.ones(self.config.hidden_size, dtype=self.dtype, device=self.device),
@@ -116,98 +145,177 @@ class TransformerModelBuilder:
             f"{prefix}.norm_2_bias": torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
         }
 
-    def build_model(self) -> Dict[str, torch.Tensor]:
-        """Build complete transformer model weights."""
         start_time = time.time()
-        self.weights.clear()
 
         try:
-            # Build transformer layers with progress bar
-            for i in tqdm(range(self.config.num_layers), desc="Building layers"):
-                self.weights.update(self._create_attention_block(i))
-                self.weights.update(self._create_ffn_block(i))
-                self.weights.update(self._create_norm_block(i))
-
-            # Embedding and output layers
-            logging.info("Building embedding and output layers")
-            self.weights["embedding.word_embeddings"] = self._scaled_init(
-                self.config.vocab_size, self.config.hidden_size
-            )
-            self.weights["embedding.position_embeddings"] = self._scaled_init(
-                self.config.seq_length, self.config.hidden_size
-            )
-            self.weights["embedding.token_type_embeddings"] = self._scaled_init(
-                self.config.seq_length, self.config.hidden_size
-            )
-            self.weights["output_layer.weight"] = self._scaled_init(
-                self.config.hidden_size, self.config.vocab_size
-            )
-            self.weights["output_layer.bias"] = torch.zeros(
-                self.config.vocab_size, dtype=self.dtype, device=self.device
-            )
-
-            # Store metadata
-            self.metadata = {
-                "build_time": time.time() - start_time,
-                "num_parameters": sum(t.numel() for t in self.weights.values()),
-                "config": vars(self.config)
             }
-            logging.info(f"Model built with {self.metadata['num_parameters']:,} parameters "
-                         f"in {self.metadata['build_time']:.2f} seconds")
-            return self.weights
-
         except Exception as e:
-            logging.error(f"Model building failed: {str(e)}")
-            raise RuntimeError(f"Failed to build model: {str(e)}") from e
 
-    def save_model(self, save_path: Optional[str | Path] = None) -> None:
-        """Save model weights and metadata to safetensors file."""
-        save_path = Path(save_path or self.config.save_path)
         start_time = time.time()
 
         try:
-            save_path.parent.mkdir(parents=True, exist_ok=True)
-            save_file(self.weights, str(save_path), metadata=self.metadata)
-
-            # Save config separately
-            config_path = save_path.with_suffix(".yaml")
-            with open(config_path, "w") as f:
-                yaml.dump(vars(self.config), f, default_flow_style=False)
-
-            elapsed = time.time() - start_time
-            logging.info(f"Model and config saved to {save_path} in {elapsed:.2f} seconds")
         except Exception as e:
-            logging.error(f"Model saving failed: {str(e)}")
-            raise RuntimeError(f"Failed to save model: {str(e)}") from e
 
-    def validate_model(self, weights: Optional[Dict[str, torch.Tensor]] = None) -> bool:
-        """Validate model weights for consistency."""
-        weights = weights or self.weights
-        all_valid = True
-
-        for name, tensor in weights.items():
-            if torch.isnan(tensor).any() or torch.isinf(tensor).any():
-                logging.warning(f"Invalid values detected in {name}")
-                all_valid = False
-            logging.debug(f"Validated {name}: shape={tensor.shape}")
-
-        return all_valid
 
     @classmethod
-    def from_config_file(cls, config_path: str | Path) -> "TransformerModelBuilder":
-        """Create builder from YAML config file."""
-        with open(config_path, "r") as f:
             config_dict = yaml.safe_load(f)
         return cls(ModelConfig(**config_dict))
 
 def estimate_model_size(config: ModelConfig) -> Tuple[int, float]:
-    """Estimate model size in parameters and GB."""
-    builder = TransformerModelBuilder(config)
-    weights = builder.build_model()
-    num_params = sum(t.numel() for t in weights.values())
-    size_gb = sum(t.element_size() * t.numel() for t in weights.values()) / 1024**3
-    return num_params, size_gb
 
 def main():
     """Main execution flow with size estimation and validation."""
     try:
@@ -1,29 +1,35 @@
 import torch
 from safetensors.torch import save_file, load_file
+from typing import Dict, Optional, Tuple, List, Union, Any
 import logging
 import time
 import json
+import yaml
+import os
 from pathlib import Path
 import sys
+import shutil
+from dataclasses import dataclass, asdict
 import numpy as np
 from tqdm import tqdm
+import multiprocessing as mp
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import hashlib
+from torch.nn.init import xavier_uniform_, kaiming_uniform_
 
+# Configure logging with rotation and detailed output
 logging.basicConfig(
     level=logging.INFO,
+    format="%(asctime)s - %(levelname)s - [%(processName)s:%(threadName)s] - %(message)s",
     handlers=[
         logging.StreamHandler(sys.stdout),
+        logging.FileHandler("transformer_shard_builder.log", mode="a")
     ]
 )
 
 @dataclass
 class ModelConfig:
+    """Configuration for transformer model parameters and sharding."""
     num_layers: int = 48
     hidden_size: int = 8192
     heads: int = 64
@@ -31,83 +37,106 @@ class ModelConfig:
     vocab_size: int = 50000
     dtype: str = "float16"
     ffn_multiplier: int = 4
+    total_shards: int = 278
+    base_path: str = "model_shards"
     device: str = "cuda" if torch.cuda.is_available() else "cpu"
     seed: Optional[int] = 42
+    init_method: str = "xavier"  # Options: "xavier", "kaiming", "normal"
+    shard_compression: bool = True
+    validation_threshold: float = 1e-5
+
+class TransformerShardBuilder:
+    """Advanced class to build, shard, validate, and save a large transformer model."""
 
     def __init__(self, config: Optional[ModelConfig] = None):
+        """Initialize with configuration and setup environment."""
         self.config = config or ModelConfig()
         self.dtype = getattr(torch, self.config.dtype)
         self.device = torch.device(self.config.device)
+        self.base_path = Path(self.config.base_path)
+        self.weights: Dict[int, Dict[str, torch.Tensor]] = {}  # Shard-indexed weights
+        self.metadata: Dict[str, Any] = {}
 
         self._validate_config()
         self._setup_environment()
+        self._calculate_sharding()
 
     def _validate_config(self) -> None:
         """Validate configuration parameters."""
         checks = [
             (self.config.num_layers > 0, "Number of layers must be positive"),
+            (self.config.hidden_size % self.config.heads == 0, "Hidden size must be divisible by heads"),
             (self.config.seq_length > 0, "Sequence length must be positive"),
             (self.config.vocab_size > 0, "Vocab size must be positive"),
+            (self.config.total_shards > 0, "Total shards must be positive"),
+            (self.config.ffn_multiplier > 1, "FFN multiplier must be greater than 1"),
+            (self.config.init_method in ["xavier", "kaiming", "normal"], "Invalid initialization method")
         ]
         for condition, message in checks:
             if not condition:
                 raise ValueError(message)
+        if self.config.num_layers < self.config.total_shards:
+            raise ValueError("Number of layers must be >= total shards")
 
     def _setup_environment(self) -> None:
+        """Setup random seed, device, and directories."""
         if self.config.seed is not None:
             torch.manual_seed(self.config.seed)
             np.random.seed(self.config.seed)
+        self.base_path.mkdir(parents=True, exist_ok=True)
+        logging.info(f"Environment setup: device={self.device}, base_path={self.base_path}")
+        if self.device.type == "cuda":
+            logging.info(f"CUDA Memory: {torch.cuda.mem_get_info()[0] / 1024**3:.2f} GB free")  # torch.cuda.memory_available() does not exist; mem_get_info() returns (free, total) bytes
+
+    def _calculate_sharding(self) -> None:
+        """Calculate layer distribution across shards."""
+        self.layers_per_shard = self.config.num_layers // self.config.total_shards
+        self.remaining_layers = self.config.num_layers % self.config.total_shards
+        logging.info(f"Sharding: {self.layers_per_shard} layers/shard, {self.remaining_layers} extra")
 
+    def _initialize_tensor(self, *shape) -> torch.Tensor:
+        """Initialize tensor based on configured method."""
+        tensor = torch.empty(*shape, dtype=self.dtype, device=self.device)
+        if self.config.init_method == "xavier":
+            if len(shape) > 1:
+                xavier_uniform_(tensor)
+            else:
+                tensor.normal_(0, 1.0 / self.config.hidden_size ** 0.5)
+        elif self.config.init_method == "kaiming":
+            if len(shape) > 1:
+                kaiming_uniform_(tensor, a=0, mode="fan_in", nonlinearity="relu")
+            else:
+                tensor.normal_(0, 1.0 / self.config.hidden_size ** 0.5)
+        else:  # normal
+            tensor.normal_(0, 1.0 / self.config.hidden_size ** 0.5)
+        return tensor
 
     def _create_attention_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
+        """Create attention weights for a layer."""
         weights = {}
         prefix = f"layer_{layer_idx}.attention"
         head_dim = self.config.hidden_size // self.config.heads
 
+        for name in ["query_weight", "key_weight", "value_weight", "output_weight"]:
+            weights[f"{prefix}.{name}"] = self._initialize_tensor(self.config.hidden_size, self.config.hidden_size)
+            weights[f"{prefix}.{name}_bias"] = torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
+        weights[f"{prefix}.head_scale"] = torch.ones(self.config.heads, head_dim, dtype=self.dtype, device=self.device)
         return weights
 
     def _create_ffn_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
+        """Create FFN weights for a layer."""
         weights = {}
         prefix = f"layer_{layer_idx}.ffn"
         intermediate_size = self.config.hidden_size * self.config.ffn_multiplier
 
+        weights[f"{prefix}.intermediate_weight"] = self._initialize_tensor(self.config.hidden_size, intermediate_size)
         weights[f"{prefix}.intermediate_bias"] = torch.zeros(intermediate_size, dtype=self.dtype, device=self.device)
+        weights[f"{prefix}.output_weight"] = self._initialize_tensor(intermediate_size, self.config.hidden_size)
         weights[f"{prefix}.output_bias"] = torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
         return weights
 
     def _create_norm_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
+        """Create normalization weights."""
         prefix = f"layer_{layer_idx}"
         return {
             f"{prefix}.norm_1_weight": torch.ones(self.config.hidden_size, dtype=self.dtype, device=self.device),
@@ -116,98 +145,177 @@ class TransformerModelBuilder:
             f"{prefix}.norm_2_bias": torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
         }
 
+    def _create_embedding_output(self) -> Dict[str, torch.Tensor]:
+        """Create embedding and output layers for first shard."""
+        weights = {
+            "embedding.word_embeddings": self._initialize_tensor(self.config.vocab_size, self.config.hidden_size),
+            "embedding.position_embeddings": self._initialize_tensor(self.config.seq_length, self.config.hidden_size),
+            "embedding.token_type_embeddings": self._initialize_tensor(self.config.seq_length, self.config.hidden_size),
+            "output_layer.weight": self._initialize_tensor(self.config.hidden_size, self.config.vocab_size),
+            "output_layer.bias": torch.zeros(self.config.vocab_size, dtype=self.dtype, device=self.device)
+        }
+        return weights
+
+    def build_shard(self, shard_idx: int) -> Dict[str, torch.Tensor]:
+        """Build weights for a specific shard."""
+        weights = {}
+        start_time = time.time()
+
+        start_layer = (shard_idx - 1) * self.layers_per_shard
+        end_layer = start_layer + self.layers_per_shard
+        if shard_idx == self.config.total_shards:
+            end_layer += self.remaining_layers
+
+        for i in tqdm(range(start_layer, end_layer), desc=f"Shard {shard_idx} layers"):
+            weights.update(self._create_attention_block(i))
+            weights.update(self._create_ffn_block(i))
+            weights.update(self._create_norm_block(i))
+
+        if shard_idx == 1:
+            weights.update(self._create_embedding_output())
+
+        elapsed = time.time() - start_time
+        self.metadata[f"shard_{shard_idx}"] = {"build_time": elapsed, "num_layers": end_layer - start_layer}
+        logging.info(f"Shard {shard_idx} built with {len(weights)} tensors in {elapsed:.2f}s")
+        return weights
+
+    def save_shard(self, shard_idx: int, weights: Dict[str, torch.Tensor]) -> None:
+        """Save a single shard with metadata."""
+        shard_path = self.base_path / f"model_{shard_idx}_of_{self.config.total_shards}.safetensors"
         start_time = time.time()
 
         try:
+            shard_metadata = {
+                "shard_idx": shard_idx,
+                "total_shards": self.config.total_shards,
+                "config": asdict(self.config),
+                **self.metadata.get(f"shard_{shard_idx}", {})
             }
+            save_file(weights, str(shard_path), metadata={k: json.dumps(v) for k, v in shard_metadata.items()})  # safetensors metadata values must be strings
+            elapsed = time.time() - start_time
+            logging.info(f"Shard {shard_idx} saved to {shard_path} in {elapsed:.2f}s")
         except Exception as e:
+            logging.error(f"Shard {shard_idx} save failed: {str(e)}")
+            raise RuntimeError(f"Failed to save shard {shard_idx}: {str(e)}") from e
 
+    def build_and_save_all_shards(self, parallel: bool = True) -> None:
+        """Build and save all shards, optionally in parallel."""
         start_time = time.time()
 
+        if parallel and mp.cpu_count() > 1:
+            with ThreadPoolExecutor(max_workers=min(mp.cpu_count(), self.config.total_shards)) as executor:
+                futures = {
+                    executor.submit(self.build_shard, i): i
+                    for i in range(1, self.config.total_shards + 1)
+                }
+                for future in as_completed(futures):
+                    shard_idx = futures[future]
+                    try:
+                        weights = future.result()
+                        self.save_shard(shard_idx, weights)
+                    except Exception as e:
+                        logging.error(f"Parallel shard {shard_idx} failed: {str(e)}")
+        else:
+            for shard_idx in tqdm(range(1, self.config.total_shards + 1), desc="Building shards"):
+                weights = self.build_shard(shard_idx)
+                self.save_shard(shard_idx, weights)
+
+        total_time = time.time() - start_time
+        self.metadata["total_build_time"] = total_time
+        logging.info(f"All {self.config.total_shards} shards completed in {total_time:.2f}s")
+
+    def validate_shard(self, shard_idx: int) -> bool:
+        """Validate a shard's weights after loading."""
+        shard_path = self.base_path / f"model_{shard_idx}_of_{self.config.total_shards}.safetensors"
         try:
+            weights = load_file(str(shard_path), device="cpu")  # Load to CPU for validation
+            all_valid = True
+            for name, tensor in weights.items():
+                if torch.isnan(tensor).any() or torch.isinf(tensor).any():
+                    logging.warning(f"Invalid values in {name} (shard {shard_idx})")
+                    all_valid = False
+                elif torch.max(torch.abs(tensor)) > self.config.validation_threshold:
+                    logging.warning(f"Large values in {name} (shard {shard_idx})")
+            return all_valid
         except Exception as e:
+            logging.error(f"Validation failed for shard {shard_idx}: {str(e)}")
+            return False
 
+    def compute_checksum(self, shard_idx: int) -> str:
+        """Compute SHA256 checksum of a shard file."""
+        shard_path = self.base_path / f"model_{shard_idx}_of_{self.config.total_shards}.safetensors"
+        sha256 = hashlib.sha256()
+        with open(shard_path, "rb") as f:
+            for chunk in iter(lambda: f.read(4096), b""):
+                sha256.update(chunk)
+        return sha256.hexdigest()
+
+    def export_metadata(self, output_path: str | Path = "model_metadata.json") -> None:
+        """Export metadata to JSON file."""
+        output_path = Path(output_path)
+        with open(output_path, "w") as f:
+            json.dump(self.metadata, f, indent=2)
+        logging.info(f"Metadata exported to {output_path}")
 
     @classmethod
+    def from_yaml(cls, yaml_path: str | Path) -> "TransformerShardBuilder":
+        """Initialize from YAML config file."""
+        with open(yaml_path, "r") as f:
             config_dict = yaml.safe_load(f)
         return cls(ModelConfig(**config_dict))
 
 def estimate_model_size(config: ModelConfig) -> Tuple[int, float]:
+    """Estimate total model size in parameters and GB."""
+    builder = TransformerShardBuilder(config)
+    params = 0
+    bytes_size = 0
+    for shard in range(1, config.total_shards + 1):
+        weights = builder.build_shard(shard)
+        params += sum(t.numel() for t in weights.values())
+        bytes_size += sum(t.element_size() * t.numel() for t in weights.values())
+    return params, bytes_size / 1024**3
 
+def main():
+    """Main execution flow with comprehensive functionality."""
+    try:
+        # Custom configuration
+        config = ModelConfig(
+            num_layers=48,
+            hidden_size=8192,
+            heads=64,
+            seq_length=4096,
+            vocab_size=50000,
+            total_shards=278,
+            base_path="model_shards_large"
+        )
+        builder = TransformerShardBuilder(config)
+
+        # Size estimation
+        num_params, size_gb = estimate_model_size(config)
+        logging.info(f"Estimated size: {num_params:,} parameters, {size_gb:.2f} GB")
+
+        # Build and save all shards
+        builder.build_and_save_all_shards(parallel=True)
+
+        # Validate all shards
+        logging.info("Validating shards...")
+        for shard in tqdm(range(1, config.total_shards + 1), desc="Validating"):
+            if builder.validate_shard(shard):
+                checksum = builder.compute_checksum(shard)
+                logging.info(f"Shard {shard} validated, checksum: {checksum[:8]}...")
+            else:
+                logging.warning(f"Shard {shard} validation failed")
+
+        # Export metadata
+        builder.export_metadata()
+
+        return 0
+    except Exception as e:
+        logging.error(f"Execution failed: {str(e)}")
+        return 1
+
+if __name__ == "__main__":
+    sys.exit(main())
 def main():
     """Main execution flow with size estimation and validation."""
     try:
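
For reference, a minimal sketch (not part of the commit) of how the shards written by this script could be read back and merged into a single state dict, assuming the model_{i}_of_{total}.safetensors naming and the shard directory used above; the merge_shards helper and the example values below are illustrative only.

from pathlib import Path
from safetensors.torch import load_file

def merge_shards(shard_dir: str, num_shards: int) -> dict:
    """Load each shard produced by the builder and merge all tensors into one dict."""
    merged = {}
    for idx in range(1, num_shards + 1):
        shard_path = Path(shard_dir) / f"model_{idx}_of_{num_shards}.safetensors"
        # load_file returns a dict mapping tensor names to torch.Tensor objects
        merged.update(load_file(str(shard_path), device="cpu"))
    return merged

# Hypothetical usage with the defaults from ModelConfig:
# state_dict = merge_shards("model_shards", 278)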