import time
from builtins import print
import argparse
import torch

# os.environ["CUDA_VISIBLE_DEVICES"] = '3'


def get_time_str():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())


def _strip_prefix(state_dict, prefix):
    """Return a new dict with ``prefix`` removed from every key that starts with it.

    Keys that do not start with ``prefix`` are copied unchanged. Values are
    shared with ``state_dict``, not copied. If stripping makes two keys
    collide, the later one (in iteration order) wins.
    """
    prefix_len = len(prefix)
    return {
        (key[prefix_len:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def main(argv=None):
    """Convert a training checkpoint into a bare state-dict ``.bin`` file.

    Loads ``--ckpt_path`` onto CPU and takes its ``'module'`` entry. It
    optionally strips ``--rm_prefix`` from every parameter name that starts
    with it, then saves the resulting dict to ``--bin_path``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            parses ``sys.argv[1:]``, which is the normal CLI behaviour.

    Raises:
        KeyError: if the loaded checkpoint has no ``'module'`` entry.
    """
    total_parser = argparse.ArgumentParser("Pretrain Unsupervise.")
    total_parser.add_argument('--ckpt_path', required=True, type=str,
                              help="input checkpoint containing a 'module' entry")
    total_parser.add_argument('--bin_path', required=True, type=str,
                              help='output path for the extracted state dict')
    total_parser.add_argument('--rm_prefix', default=None, type=str,
                              help='optional key prefix to strip, e.g. "module."')
    args = total_parser.parse_args(argv)
    print('Argument parse success.')

    # map_location='cpu' lets the conversion run without the GPU(s) the
    # checkpoint was trained on. It also keeps CUDA device placement out of
    # the saved .bin file.
    checkpoint = torch.load(args.ckpt_path, map_location='cpu')
    try:
        state_dict = checkpoint['module']
    except KeyError:
        raise KeyError(
            f"checkpoint {args.ckpt_path!r} has no 'module' entry; "
            f"available keys: {list(checkpoint)}"
        ) from None

    if args.rm_prefix is not None:
        new_state_dict = _strip_prefix(state_dict, args.rm_prefix)
    else:
        new_state_dict = state_dict
    torch.save(new_state_dict, args.bin_path)


if __name__ == '__main__':
    main()