pranavajay commited on
Commit
590e9a6
·
verified ·
1 Parent(s): 2d6b3e3

Update rp.py

Browse files
Files changed (1) hide show
  1. rp.py +9 -4
rp.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  from safetensors.torch import load_file, save_file
3
 
4
- def reduce_key_size(input_file, output_file, reduction_factor=0.30):
5
  # Load the model
6
  model_data = load_file(input_file)
7
 
@@ -15,12 +15,17 @@ def reduce_key_size(input_file, output_file, reduction_factor=0.30):
15
  # Resize the tensor (this could vary depending on your requirements)
16
  if new_size > 0: # Ensure new size is positive
17
  reduced_tensor = original_tensor[:new_size]
18
- model_data[key] = reduced_tensor
 
 
 
 
 
19
 
20
  # Save the modified model
21
  save_file(model_data, output_file)
22
 
23
  # Usage example
24
- input_file = 'merged_model2.safetensors' # Replace with your input model file
25
- output_file = 'merged_model_216.safetensors' # Desired output file name
26
  reduce_key_size(input_file, output_file)
 
1
  import torch
2
  from safetensors.torch import load_file, save_file
3
 
4
+ def reduce_key_size(input_file, output_file, reduction_factor=0.50):
5
  # Load the model
6
  model_data = load_file(input_file)
7
 
 
15
  # Resize the tensor (this could vary depending on your requirements)
16
  if new_size > 0: # Ensure new size is positive
17
  reduced_tensor = original_tensor[:new_size]
18
+
19
+ # Convert to FP8 (assuming your environment supports FP8)
20
+ # Note: PyTorch does not have built-in FP8 support; you may need to use a custom implementation
21
+ # Here's an example of converting a tensor to float16, then quantizing it
22
+ fp8_tensor = torch.quantize_per_tensor(reduced_tensor.to(torch.float16), scale=1.0, zero_point=0, dtype=torch.qint8)
23
+ model_data[key] = fp8_tensor
24
 
25
  # Save the modified model
26
  save_file(model_data, output_file)
27
 
28
  # Usage example
29
+ input_file = 'merged_model_16.safetensors' # Replace with your input model file
30
+ output_file = 'merged_model_8.safetensors' # Desired output file name
31
  reduce_key_size(input_file, output_file)