import torch
from safetensors.torch import load_file, save_file


def reduce_key_size(input_file, output_file, reduction_factor=0.50):
    # Load the model
    model_data = load_file(input_file)

    # Iterate through all the tensors and reduce their size
    for key in model_data.keys():
        original_tensor = model_data[key]

        # Skip zero-dimensional tensors, which have no first axis to truncate
        if original_tensor.dim() == 0:
            continue

        # Calculate the new size along the first dimension
        new_size = int(original_tensor.size(0) * (1 - reduction_factor))

        # Truncate the tensor (this could vary depending on your requirements)
        if new_size > 0:  # Ensure the new size is positive
            reduced_tensor = original_tensor[:new_size].contiguous()

            # Store in FP8 where the runtime supports it. PyTorch 2.1+ exposes
            # float8_e4m3fn as a storage dtype; older versions do not, so fall
            # back to float16 there. Note that safetensors cannot serialize
            # PyTorch's quantized (qint8) tensors, so a plain dtype cast is
            # used rather than torch.quantize_per_tensor.
            if hasattr(torch, "float8_e4m3fn"):
                model_data[key] = reduced_tensor.to(torch.float8_e4m3fn)
            else:
                model_data[key] = reduced_tensor.to(torch.float16)

    # Save the modified model
    save_file(model_data, output_file)


# Usage example
input_file = 'merged_model_16.safetensors'   # Replace with your input model file
output_file = 'merged_model_8.safetensors'   # Desired output file name
reduce_key_size(input_file, output_file)
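
# Quick sanity check (a sketch, assuming the run above succeeded): reload the
# output file and confirm each tensor was truncated and stored at the reduced
# dtype before using it anywhere downstream.
checked = load_file(output_file)
for name, tensor in checked.items():
    print(f"{name}: shape={tuple(tensor.shape)}, dtype={tensor.dtype}")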