import torch
import safetensors.torch
import concurrent.futures
import zlib
import logging
from typing import Dict, Tuple
from pathlib import Path

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)


class AdvancedModelParameters:
    def __init__(self, num_shards=2089, base_filename="charm15", hidden_size=16384, layers_per_shard=100):
        """Initialize model parameters for a massive transformer model."""
        self.num_shards = num_shards
        self.base_filename = base_filename
        self.hidden_size = hidden_size
        self.layers_per_shard = layers_per_shard
        self.ffn_multiplier = 4
        self.shape = (hidden_size, hidden_size)
        self.dtype = torch.float16
        self.base_path = Path("model_shards")
        self.base_path.mkdir(parents=True, exist_ok=True)

    def generate_layer_parameters(self, layer_idx: int) -> Dict[str, torch.Tensor]:
        """Generate parameters for a single transformer layer."""
        params = {}
        prefix = f"layer_{layer_idx}"
        # Attention weights (Q, K, V, O), scaled by 1/sqrt(hidden_size)
        for name in ["query_weight", "key_weight", "value_weight", "output_weight"]:
            params[f"{prefix}.attention.{name}"] = torch.randn(
                self.shape, dtype=self.dtype
            ) * (1.0 / self.hidden_size ** 0.5)
        # FFN weights
        intermediate_size = self.hidden_size * self.ffn_multiplier
        params[f"{prefix}.ffn.intermediate_weight"] = torch.randn(
            self.hidden_size, intermediate_size, dtype=self.dtype
        ) * (1.0 / self.hidden_size ** 0.5)
        params[f"{prefix}.ffn.output_weight"] = torch.randn(
            intermediate_size, self.hidden_size, dtype=self.dtype
        ) * (1.0 / intermediate_size ** 0.5)
        return params
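
    # For example (illustrative, derived from the code above): layer 0 yields
    # keys such as "layer_0.attention.query_weight" (16384 x 16384) and
    # "layer_0.ffn.intermediate_weight" (16384 x 65536).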

    def generate_shard_parameters(self, shard_index: int) -> Dict[str, torch.Tensor]:
        """Generate parameters for a single shard (shard indices are 1-based)."""
        params = {}
        start_layer = (shard_index - 1) * self.layers_per_shard
        end_layer = start_layer + self.layers_per_shard
        # Generate the layers belonging to this shard
        for layer_idx in range(start_layer, end_layer):
            params.update(self.generate_layer_parameters(layer_idx))
        # Add the embeddings and output layer to the first shard only
        if shard_index == 1:
            params["embedding.word_embeddings"] = torch.randn(
                50000, self.hidden_size, dtype=self.dtype
            ) * (1.0 / self.hidden_size ** 0.5)
            params["embedding.position_embeddings"] = torch.randn(
                4096, self.hidden_size, dtype=self.dtype
            ) * (1.0 / self.hidden_size ** 0.5)
            params["output_layer"] = torch.randn(
                self.hidden_size, 50000, dtype=self.dtype
            ) * (1.0 / self.hidden_size ** 0.5)
        return params

    def compress_tensor(self, tensor: torch.Tensor) -> bytes:
        """Apply zlib compression to the tensor's raw bytes."""
        tensor_bytes = tensor.numpy().tobytes()
        return zlib.compress(tensor_bytes, level=9)
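
    # Hedged companion sketch (not part of the original flow): the inverse of
    # compress_tensor. The zlib stream carries no shape or dtype information,
    # so a real loader would need that bookkeeping from elsewhere; here the
    # caller supplies the shape and float16 is assumed.
    def decompress_tensor(self, data: bytes, shape: Tuple[int, ...]) -> torch.Tensor:
        """Inverse of compress_tensor, assuming float16 and a known shape."""
        raw = zlib.decompress(data)
        return torch.frombuffer(bytearray(raw), dtype=torch.float16).reshape(shape)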

    def save_single_shard(self, shard_index: int) -> None:
        """Save a single model shard with compression."""
        params = self.generate_shard_parameters(shard_index)
        filename = self.base_path / f"{self.base_filename}_{shard_index}_of_{self.num_shards}.safetensors"
        # safetensors can only store tensors, so wrap each compressed byte
        # stream in a 1-D uint8 tensor rather than passing raw bytes
        compressed_data = {
            key: torch.frombuffer(bytearray(self.compress_tensor(value)), dtype=torch.uint8)
            for key, value in params.items()
        }
        # safetensors metadata values must be strings
        metadata = {
            "shard_index": str(shard_index),
            "total_shards": str(self.num_shards),
            "layers": str(self.layers_per_shard),
            "hidden_size": str(self.hidden_size),
        }
        safetensors.torch.save_file(compressed_data, str(filename), metadata=metadata)
        logging.info(f"[✔] Shard {shard_index}/{self.num_shards} saved: {filename}")

    def save_sharded_parameters(self) -> None:
        """Save all shards in parallel."""
        logging.info(f"Starting to save {self.num_shards} shards...")
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Consume the iterator so worker exceptions propagate instead of
            # being silently discarded by executor.map
            list(executor.map(self.save_single_shard, range(1, self.num_shards + 1)))
        logging.info("All shards saved successfully.")

    def estimate_parameters(self) -> Tuple[int, float]:
        """Estimate the total parameter count and float16 memory usage."""
        params_per_layer = (
            4 * (self.hidden_size * self.hidden_size) +  # Attention weights (Q, K, V, O)
            self.hidden_size * (self.hidden_size * self.ffn_multiplier) +  # FFN intermediate
            (self.hidden_size * self.ffn_multiplier) * self.hidden_size  # FFN output
        )
        params_per_shard = params_per_layer * self.layers_per_shard
        total_params = params_per_shard * self.num_shards
        # Add the embedding and output layers stored in the first shard
        total_params += (
            50000 * self.hidden_size +  # word_embeddings
            4096 * self.hidden_size +  # position_embeddings
            self.hidden_size * 50000  # output_layer
        )
        memory_gb = (total_params * 2) / 1024**3  # 2 bytes per float16 value
        return total_params, memory_gb
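

# Back-of-the-envelope check of estimate_parameters with the defaults used in
# main() below: per layer, attention contributes 4 * 16384**2 ≈ 1.07e9 weights
# and the FFN 2 * (16384 * 65536) ≈ 2.15e9, i.e. about 3.22e9 per layer.
# Across 2,089 shards of 100 layers that comes to roughly 6.7e14 parameters,
# well over a pebibyte at 2 bytes each, so treat these defaults as a
# storage-layout demo rather than something to materialize on real hardware.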
def main():
    """Main execution flow."""
    model_storage = AdvancedModelParameters(
        num_shards=2089,
        base_filename="charm15",
        hidden_size=16384,
        layers_per_shard=100,
    )
    # Estimate parameters before writing anything
    total_params, memory_gb = model_storage.estimate_parameters()
    logging.info(f"Estimated total parameters: {total_params:,}")
    logging.info(f"Estimated memory usage: {memory_gb:.2f} GB")
    # Save shards
    model_storage.save_sharded_parameters()


if __name__ == "__main__":
    main()