import torch
from safetensors.torch import load_file, save_file

def reduce_key_size(input_file, output_file, reduction_factor=0.50):
    # Load the model
    model_data = load_file(input_file)

    # Iterate through all the tensors and reduce their size
    for key in model_data.keys():
        original_tensor = model_data[key]

        # Skip scalar tensors, which have no first dimension to truncate
        if original_tensor.ndim == 0:
            continue

        # Calculate the new size along the first dimension
        new_size = int(original_tensor.size(0) * (1 - reduction_factor))

        # Truncate the tensor along dim 0 (note: this discards weights, so
        # the output will not load into the original architecture unless the
        # model definition is changed to match)
        if new_size > 0:  # Ensure the new size is positive
            reduced_tensor = original_tensor[:new_size]
            
            # Cast to FP8 (torch.float8_e4m3fn requires PyTorch >= 2.1,
            # and safetensors >= 0.4 to serialize float8 tensors; quantized
            # dtypes such as qint8 cannot be saved by safetensors)
            fp8_tensor = reduced_tensor.to(torch.float8_e4m3fn)
            model_data[key] = fp8_tensor
    
    # Save the modified model
    save_file(model_data, output_file)
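
# If this PyTorch build predates the float8 dtypes (added in PyTorch 2.1),
# a portable fallback is to cast everything to float16 instead. A minimal
# sketch under that assumption (reuses the imports above; no truncation):
def reduce_key_size_fp16(input_file, output_file):
    model_data = load_file(input_file)
    for key in model_data.keys():
        # float16 halves a float32 checkpoint, though the file stays twice
        # the size of an FP8 one
        model_data[key] = model_data[key].to(torch.float16)
    save_file(model_data, output_file)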

# Usage example
input_file = 'merged_model_16.safetensors'  # Replace with your input model file
output_file = 'merged_model_8.safetensors'  # Desired output file name
reduce_key_size(input_file, output_file)
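
# Quick sanity check (assumes the two files above exist): confirm the dtypes
# were converted and measure the on-disk saving.
import os

reduced = load_file(output_file)
for name, tensor in list(reduced.items())[:5]:
    print(f"{name}: dtype={tensor.dtype}, shape={tuple(tensor.shape)}")

original_bytes = os.path.getsize(input_file)
reduced_bytes = os.path.getsize(output_file)
print(f"{original_bytes:,} -> {reduced_bytes:,} bytes "
      f"({reduced_bytes / original_bytes:.1%} of the original)")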