mtasic85 commited on
Commit
0e08ab8
·
1 Parent(s): 4999d38

convert pth to safetensors

Browse files
scripts/convert_pth_to_safetensors.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from transformers import AutoModel
5
+
6
+ checkpoint_dir = '../out/pretrain-core-3/hf'
7
+ output_dir = '../out/pretrain-core-3/hf'
8
+
9
+ # Load model
10
+ state_dict = torch.load(os.path.join(checkpoint_dir, 'model.pth'))
11
+
12
+ model = AutoModel.from_pretrained(
13
+ checkpoint_dir,
14
+ state_dict=state_dict,
15
+ )
16
+
17
+ # Save .safetensors files
18
+ model.save_pretrained(output_dir)