mtasic85 commited on
Commit
2b6c108
·
1 Parent(s): 295b3ab
scripts/convert_pth_to_safetensors.py CHANGED
@@ -1,18 +1,11 @@
1
  import os
2
 
3
  import torch
4
- from transformers import AutoModel
 
5
 
6
  checkpoint_dir = '../out/pretrain-core-3/hf'
7
  output_dir = '../out/pretrain-core-3/hf'
8
 
9
- # Load model
10
  state_dict = torch.load(os.path.join(checkpoint_dir, 'model.pth'))
11
-
12
- model = AutoModel.from_pretrained(
13
- checkpoint_dir,
14
- state_dict=state_dict,
15
- )
16
-
17
- # Save .safetensors files
18
- model.save_pretrained(output_dir)
 
1
  import os
2
 
3
  import torch
4
+ from safetensors.torch import save_file
5
+
6
 
7
  checkpoint_dir = '../out/pretrain-core-3/hf'
8
  output_dir = '../out/pretrain-core-3/hf'
9
 
 
10
  state_dict = torch.load(os.path.join(checkpoint_dir, 'model.pth'))
11
+ save_file(state_dict, os.path.join(output_dir, 'model.safetensors'))