Spaces:
Runtime error
Runtime error
import time | |
from builtins import print | |
import argparse | |
import torch | |
# os.environ["CUDA_VISIBLE_DEVICES"] = '3' | |
def get_time_str():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    now = time.localtime()
    return time.strftime("%Y-%m-%d %H:%M:%S", now)
def _strip_prefix(state_dict, prefix):
    """Return *state_dict* with *prefix* removed from every key that starts with it.

    Keys that do not start with *prefix* are kept unchanged. When *prefix* is
    ``None`` the input mapping itself is returned (no copy is made).
    """
    if prefix is None:
        return state_dict
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def main():
    """Extract the ``'module'`` entry of a checkpoint and save it as a new file.

    Command-line arguments:
        --ckpt_path: path of the checkpoint to read (required).
        --bin_path: path the extracted state dict is written to (required).
        --rm_prefix: optional prefix to strip from the state-dict keys.

    Raises:
        KeyError: if the loaded checkpoint has no ``'module'`` entry.
    """
    total_parser = argparse.ArgumentParser("Pretrain Unsupervise.")
    # Both paths are mandatory: without them torch.load/torch.save would fail
    # with an obscure error, so let argparse report the problem instead.
    total_parser.add_argument('--ckpt_path', required=True, type=str,
                              help='checkpoint file to convert')
    total_parser.add_argument('--bin_path', required=True, type=str,
                              help='output file for the extracted state dict')
    total_parser.add_argument('--rm_prefix', default=None, type=str,
                              help='prefix to remove from state-dict keys')
    args = total_parser.parse_args()
    print('Argument parse success.')

    # NOTE(review): torch.load unpickles the file; only run this on
    # checkpoints from a trusted source.
    # map_location='cpu' lets the conversion run on hosts without the device
    # the checkpoint was saved from.
    checkpoint = torch.load(args.ckpt_path, map_location='cpu')
    try:
        state_dict = checkpoint['module']
    except KeyError:
        raise KeyError(
            "checkpoint %r has no 'module' entry" % (args.ckpt_path,)
        ) from None

    new_state_dict = _strip_prefix(state_dict, args.rm_prefix)
    torch.save(new_state_dict, args.bin_path)
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()