|
import concurrent.futures
import json
import logging
import zlib
from pathlib import Path
from typing import Dict, Tuple

import safetensors.torch
import torch
|
|
|
|
|
# Configure the root logger: records at INFO and above go to a StreamHandler
# (stderr by default) formatted as "<time> - <LEVEL> - <message>".
# Note: basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
|
|
|
class AdvancedModelParameters:
    """Generate randomly initialised transformer weights and save them as safetensors shards.

    Shards are numbered from 1. Shard ``k`` holds transformer layers
    ``(k - 1) * layers_per_shard`` up to ``k * layers_per_shard - 1``. Shard 1
    also holds the word/position embedding tables and the output layer.

    Warning: the defaults describe an enormous model. Each default shard holds
    ``100 * 12 * 16384**2`` (about 3.2e11) float16 values, i.e. several hundred
    GB before compression. ``save_sharded_parameters`` also builds several
    shards concurrently. Use small sizes for experiments.
    """

    def __init__(
        self,
        num_shards=2089,
        base_filename="charm15",
        hidden_size=16384,
        layers_per_shard=100,
        vocab_size=50000,
        max_positions=4096,
        base_path="model_shards",
        compress=True,
        compression_level=9,
        max_workers=None,
    ):
        """Initialize model parameters for a massive transformer model.

        Args:
            num_shards: Number of shard files to write (shards are 1-based).
            base_filename: Prefix of every shard file name.
            hidden_size: Model width; attention matrices are
                ``hidden_size x hidden_size``.
            layers_per_shard: Transformer layers stored in each shard.
            vocab_size: Rows of the word embedding and columns of the output layer.
            max_positions: Rows of the position embedding table.
            base_path: Directory the shards are written to (created if missing).
            compress: If True, each tensor is stored zlib-compressed as a 1-D
                uint8 tensor, and its dtype/shape are recorded in the file
                metadata (see ``load_shard``). If False, tensors are stored
                as-is and any safetensors reader can load them.
            compression_level: zlib level used when ``compress`` is True.
            max_workers: Thread count used by ``save_sharded_parameters``;
                ``None`` keeps the ThreadPoolExecutor default.

        Raises:
            ValueError: If ``hidden_size`` is not positive or a count is negative.
        """
        if hidden_size <= 0:
            raise ValueError(f"hidden_size must be positive, got {hidden_size}")
        for name, value in (
            ("num_shards", num_shards),
            ("layers_per_shard", layers_per_shard),
            ("vocab_size", vocab_size),
            ("max_positions", max_positions),
        ):
            if value < 0:
                raise ValueError(f"{name} must be non-negative, got {value}")

        self.num_shards = num_shards
        self.base_filename = base_filename
        self.hidden_size = hidden_size
        self.layers_per_shard = layers_per_shard
        self.vocab_size = vocab_size
        self.max_positions = max_positions
        self.ffn_multiplier = 4
        self.shape = (hidden_size, hidden_size)
        self.dtype = torch.float16
        self.compress = compress
        self.compression_level = compression_level
        self.max_workers = max_workers
        self.base_path = Path(base_path)
        self.base_path.mkdir(parents=True, exist_ok=True)

    def _init_weight(self, rows: int, cols: int, fan_in: int) -> torch.Tensor:
        """Return a (rows, cols) N(0, 1) sample scaled by 1/sqrt(fan_in), in ``self.dtype``."""
        # In-place scaling avoids allocating a second full-size temporary.
        return torch.randn(rows, cols, dtype=self.dtype).mul_(1.0 / fan_in ** 0.5)

    def generate_layer_parameters(self, layer_idx: int) -> Dict[str, torch.Tensor]:
        """Generate the attention and feed-forward weights of one transformer layer.

        Keys are prefixed with ``layer_{layer_idx}``. The four attention matrices
        have shape ``self.shape``. The FFN up-projection is
        ``(hidden, hidden * ffn_multiplier)`` and the down-projection is its transpose shape.
        """
        prefix = f"layer_{layer_idx}"
        hidden = self.hidden_size
        intermediate = hidden * self.ffn_multiplier

        params = {
            f"{prefix}.attention.{name}": self._init_weight(*self.shape, fan_in=hidden)
            for name in ("query_weight", "key_weight", "value_weight", "output_weight")
        }
        params[f"{prefix}.ffn.intermediate_weight"] = self._init_weight(
            hidden, intermediate, fan_in=hidden
        )
        params[f"{prefix}.ffn.output_weight"] = self._init_weight(
            intermediate, hidden, fan_in=intermediate
        )
        return params

    def generate_shard_parameters(self, shard_index: int) -> Dict[str, torch.Tensor]:
        """Generate every tensor belonging to shard ``shard_index`` (1-based).

        Raises:
            ValueError: If ``shard_index`` is outside ``1..num_shards``. Shard 0
                or a negative index would otherwise produce negative layer indices.
        """
        if not 1 <= shard_index <= self.num_shards:
            raise ValueError(
                f"shard_index must be in 1..{self.num_shards}, got {shard_index}"
            )

        start_layer = (shard_index - 1) * self.layers_per_shard
        params = {}
        for layer_idx in range(start_layer, start_layer + self.layers_per_shard):
            params.update(self.generate_layer_parameters(layer_idx))

        # Model-global tensors live only in the first shard.
        if shard_index == 1:
            hidden = self.hidden_size
            params["embedding.word_embeddings"] = self._init_weight(
                self.vocab_size, hidden, fan_in=hidden
            )
            params["embedding.position_embeddings"] = self._init_weight(
                self.max_positions, hidden, fan_in=hidden
            )
            params["output_layer"] = self._init_weight(hidden, self.vocab_size, fan_in=hidden)

        return params

    def compress_tensor(self, tensor: torch.Tensor, level: int = 9) -> bytes:
        """Return the zlib-compressed raw bytes of ``tensor`` (native byte order, row-major).

        The tensor is detached, moved to CPU and made contiguous first, so
        grad-requiring or non-contiguous tensors are accepted.
        """
        flat = tensor.detach().cpu().contiguous().reshape(-1)
        # A uint8 view works for every dtype, including bfloat16, which NumPy
        # cannot represent. It also hands zlib the buffer without an extra copy.
        return zlib.compress(flat.view(torch.uint8).numpy(), level=level)

    @staticmethod
    def decompress_tensor(data: bytes, dtype: torch.dtype, shape: Tuple[int, ...]) -> torch.Tensor:
        """Inverse of ``compress_tensor``: rebuild a tensor of ``dtype``/``shape`` from ``data``.

        Raises:
            ValueError: If ``data`` is empty but ``shape`` has elements.
        """
        raw = bytearray(zlib.decompress(data))  # writable buffer for torch.frombuffer
        if not raw:
            # torch.frombuffer rejects zero-length buffers.
            empty = torch.empty(tuple(shape), dtype=dtype)
            if empty.numel() != 0:
                raise ValueError(f"empty payload for non-empty shape {tuple(shape)}")
            return empty
        return torch.frombuffer(raw, dtype=dtype).reshape(tuple(shape))

    def _shard_path(self, shard_index: int) -> Path:
        """Return the file path of shard ``shard_index``."""
        return self.base_path / f"{self.base_filename}_{shard_index}_of_{self.num_shards}.safetensors"

    def save_single_shard(self, shard_index: int) -> None:
        """Generate shard ``shard_index`` and write it under ``base_path``.

        safetensors can only serialise tensors, and its metadata must map str
        to str. When ``self.compress`` is set, each compressed payload is
        therefore stored as a 1-D uint8 tensor. Metadata key ``"tensor_info"``
        holds a JSON object ``{name: {"dtype": ..., "shape": [...]}}`` that
        ``load_shard`` uses to restore the tensors. The file is written to a
        temporary name and renamed, so a failed save leaves no partial shard.
        """
        params = self.generate_shard_parameters(shard_index)
        filename = self._shard_path(shard_index)

        tensors = {}
        tensor_info = {}
        # Pop entries while converting so each original tensor can be freed
        # once it has been compressed. This keeps peak memory near one copy.
        while params:
            key, value = params.popitem()
            if self.compress:
                tensor_info[key] = {
                    "dtype": str(value.dtype).split(".")[-1],
                    "shape": list(value.shape),
                }
                payload = self.compress_tensor(value, level=self.compression_level)
                tensors[key] = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
            else:
                tensors[key] = value.contiguous()

        metadata = {
            "shard_index": str(shard_index),
            "total_shards": str(self.num_shards),
            "layers": str(self.layers_per_shard),
            "hidden_size": str(self.hidden_size),
            "compression": "zlib" if self.compress else "none",
        }
        if self.compress:
            metadata["tensor_info"] = json.dumps(tensor_info)

        tmp_path = filename.with_name(filename.name + ".tmp")
        try:
            safetensors.torch.save_file(tensors, str(tmp_path), metadata=metadata)
            tmp_path.replace(filename)
        finally:
            # Only still present if saving or renaming failed.
            if tmp_path.exists():
                tmp_path.unlink()
        logging.info("[✔] Shard %d/%d saved: %s", shard_index, self.num_shards, filename)

    def load_shard(self, shard_index: int) -> Dict[str, torch.Tensor]:
        """Read shard ``shard_index`` back, decompressing tensors if it was saved compressed."""
        with safetensors.safe_open(str(self._shard_path(shard_index)), framework="pt") as handle:
            metadata = handle.metadata() or {}
            stored = {key: handle.get_tensor(key) for key in handle.keys()}

        if metadata.get("compression", "none") != "zlib":
            return stored

        info = json.loads(metadata["tensor_info"])
        return {
            key: self.decompress_tensor(
                tensor.numpy().tobytes(),
                getattr(torch, info[key]["dtype"]),
                tuple(info[key]["shape"]),
            )
            for key, tensor in stored.items()
        }

    def save_sharded_parameters(self) -> None:
        """Save all shards in parallel with a thread pool.

        Every shard is attempted. Each failure is logged with its traceback.

        Raises:
            RuntimeError: If any shard failed to save. The original code
                discarded worker exceptions and always reported success.
        """
        logging.info("Starting to save %d shards...", self.num_shards)
        failed = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(self.save_single_shard, index): index
                for index in range(1, self.num_shards + 1)
            }
            for future in concurrent.futures.as_completed(futures):
                index = futures[future]
                try:
                    future.result()
                except Exception:
                    logging.exception("Failed to save shard %d", index)
                    failed.append(index)

        if failed:
            raise RuntimeError(
                f"{len(failed)} of {self.num_shards} shards failed to save: {sorted(failed)}"
            )
        logging.info("All shards saved successfully.")

    def estimate_parameters(self) -> Tuple[int, float]:
        """Return ``(total_params, memory_gb)`` for the uncompressed model.

        ``memory_gb`` is ``total_params`` times the element size of
        ``self.dtype``, in GiB (1024**3 bytes). Embeddings and the output layer
        are counted only when at least one shard exists, because they are
        generated as part of shard 1.
        """
        hidden = self.hidden_size
        intermediate = hidden * self.ffn_multiplier
        # 4 attention matrices + FFN up- and down-projections.
        params_per_layer = 4 * hidden * hidden + 2 * hidden * intermediate
        total_params = params_per_layer * self.layers_per_shard * self.num_shards

        if self.num_shards >= 1:
            total_params += (
                self.vocab_size * hidden      # word embeddings
                + self.max_positions * hidden  # position embeddings
                + hidden * self.vocab_size     # output layer
            )

        bytes_per_param = torch.empty((), dtype=self.dtype).element_size()
        memory_gb = total_params * bytes_per_param / 1024**3
        return total_params, memory_gb
|
|
|
def main():
    """Entry point: build the default charm15 configuration, log its estimated size, then write every shard."""
    config = dict(
        num_shards=2089,
        base_filename="charm15",
        hidden_size=16384,
        layers_per_shard=100,
    )
    storage = AdvancedModelParameters(**config)

    # Report the size up front, before the (very long) save starts.
    n_params, size_gb = storage.estimate_parameters()
    for message in (
        f"Estimated total parameters: {n_params:,}",
        f"Estimated memory usage: {size_gb:.2f} GB",
    ):
        logging.info(message)

    storage.save_sharded_parameters()


if __name__ == "__main__":
    main()