import argparse
import os

import torch
from transformers import AutoModel
# Default locations; both can be overridden on the command line with
# --checkpoint-dir / --output-dir. Relative paths resolve against the current
# working directory, not this file's location.
checkpoint_dir = '../out/pretrain-core-3/hf'
output_dir = '../out/pretrain-core-3/hf'


def load_checkpoint_state_dict(path):
    """Load the object saved with ``torch.save`` at *path* onto the CPU.

    ``map_location="cpu"`` lets a checkpoint that was saved from GPU tensors be
    loaded on a machine without that device. ``weights_only=True`` restricts
    unpickling to tensors and plain containers/primitives, so a tampered
    ``.pth`` file cannot run arbitrary code while being loaded.

    Args:
        path: Path to a file written by ``torch.save``.

    Returns:
        The deserialized object; expected to be a ``{name: tensor}`` state dict.
    """
    return torch.load(path, map_location="cpu", weights_only=True)


def convert_to_safetensors(src_dir, dst_dir):
    """Rebuild the model from ``<src_dir>/model.pth`` and save it as safetensors.

    Args:
        src_dir: Directory passed to ``AutoModel.from_pretrained`` (the model
            configuration is read from here) that also holds ``model.pth``.
        dst_dir: Directory that ``save_pretrained`` writes into. When it is the
            same as ``src_dir`` (the script's default), the written config and
            weight files land next to ``model.pth``.

    Returns:
        The model instance that was saved.
    """
    state_dict = load_checkpoint_state_dict(os.path.join(src_dir, 'model.pth'))
    # NOTE(review): AutoModel instantiates the bare backbone class for this
    # config. If the checkpoint also carries head weights (e.g. an LM head),
    # they will not appear in the saved output -- consider
    # AutoModelForCausalLM if that is intended; confirm.
    model = AutoModel.from_pretrained(src_dir, state_dict=state_dict)
    # Release our reference to the raw checkpoint tensors before serializing.
    del state_dict
    # Explicit: older transformers releases default to pytorch_model.bin.
    model.save_pretrained(dst_dir, safe_serialization=True)
    return model


def main(argv=None):
    """Command-line entry point; *argv* defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(
        description="Re-save <checkpoint-dir>/model.pth as .safetensors files.")
    parser.add_argument(
        '--checkpoint-dir', default=checkpoint_dir,
        help="directory with the model config and model.pth "
             "(default: %(default)s)")
    parser.add_argument(
        '--output-dir', default=output_dir,
        help="directory to write the safetensors files to "
             "(default: %(default)s)")
    args = parser.parse_args(argv)
    convert_to_safetensors(args.checkpoint_dir, args.output_dir)


if __name__ == '__main__':
    main()