PLA-Net / utils /ckpt_util.py
juliocesar-io's picture
Added initial app
799e642
raw
history blame
14.3 kB
import os
import copy
import torch
import shutil
from collections import OrderedDict
import logging
import numpy as np
def save_ckpt(
    model,
    optimizer,
    train_epoch_loss,
    val_epoch_loss,
    train_epoch_nap,
    val_epoch_nap,
    epoch,
    save_path,
    name_pre,
    name_post="best",
):
    """Write a training snapshot to ``<save_path>/<name_pre>_<name_post>.pth``.

    The saved dictionary holds the keys ``epoch``, ``model_state_dict``
    (every tensor moved to CPU first), ``optimizer_state_dict``,
    ``train_loss``, ``val_loss``, ``train_map`` and ``val_map``.

    Args:
        model: Object exposing ``state_dict()`` whose values support ``.cpu()``.
        optimizer: Object exposing ``state_dict()``.
        train_epoch_loss: Stored under ``"train_loss"``.
        val_epoch_loss: Stored under ``"val_loss"``.
        train_epoch_nap: Stored under ``"train_map"``.
        val_epoch_nap: Stored under ``"val_map"``.
        epoch: Stored under ``"epoch"``.
        save_path: Target directory; created (including parents) if missing.
        name_pre: First part of the file name.
        name_post: Second part of the file name, ``"best"`` by default.
    """
    model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
    state = {
        "epoch": epoch,
        "model_state_dict": model_cpu,
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": train_epoch_loss,
        "val_loss": val_epoch_loss,
        "train_map": train_epoch_nap,
        "val_map": val_epoch_nap,
    }

    if not os.path.exists(save_path):
        # makedirs (not mkdir) so a nested save_path works; exist_ok covers
        # another process creating the directory between the check and here.
        os.makedirs(save_path, exist_ok=True)
        print("Directory ", save_path, " is created.")

    filename = "{}/{}_{}.pth".format(save_path, name_pre, name_post)
    torch.save(state, filename)
    print("model has been saved as {}".format(filename))
def load_pretrained_models(
    model, pretrained_model, phase, ismax=True
):  # ismax means max best
    """Load model weights from a checkpoint file into ``model``.

    The checkpoint must be a mapping with a ``"state_dict"`` entry and an
    ``"epoch"`` entry; ``"best_value"`` is optional. (This is the layout
    written by ``change_ckpt_dict``; ``save_ckpt`` stores the weights under
    ``"model_state_dict"`` instead.)

    If exactly one of the model and the checkpoint uses the ``"module."``
    key prefix, the checkpoint keys are renamed to match the model before
    loading.

    Args:
        model: Module to receive the weights (modified in place).
        pretrained_model: Path to the checkpoint; a falsy value skips loading.
        phase: When equal to ``"train"`` the checkpoint epoch is returned,
            otherwise ``-1``.
        ismax: Chooses the initial ``best_value``: ``-inf`` when true,
            ``+inf`` when false.

    Returns:
        ``(model, best_value, epoch)``.

    Raises:
        ImportError: ``pretrained_model`` is set but is not an existing file.
            (Kept as ``ImportError`` for compatibility with existing callers.)
    """
    best_value = -np.inf if ismax else np.inf
    epoch = -1

    if not pretrained_model:
        logging.info("===> No pre-trained model")
        return model, best_value, epoch

    if not os.path.isfile(pretrained_model):
        raise ImportError(
            "===> No checkpoint found at '{}'".format(pretrained_model)
        )

    logging.info("===> Loading checkpoint '{}'".format(pretrained_model))
    # Load onto CPU so a checkpoint saved from GPU tensors can be read on a
    # CPU-only host; load_state_dict copies into the model's own tensors.
    checkpoint = torch.load(pretrained_model, map_location="cpu")

    try:
        best_value = checkpoint["best_value"]
    except KeyError:
        # No recorded best value: keep the +/-inf default chosen above.
        show_best_value = False
    else:
        show_best_value = not (best_value == -np.inf or best_value == np.inf)

    ckpt_model_state_dict = checkpoint["state_dict"]

    # rename ckpt (avoid name is not same because of multi-gpus)
    # Test for the full "module." prefix: looking only at the first letter
    # misclassifies ordinary keys that merely begin with "m".
    prefix = "module."
    is_model_multi_gpus = next(iter(model.state_dict()), "").startswith(prefix)
    is_ckpt_multi_gpus = next(iter(ckpt_model_state_dict), "").startswith(prefix)
    if is_model_multi_gpus != is_ckpt_multi_gpus:
        temp_dict = OrderedDict()
        for k, v in ckpt_model_state_dict.items():
            if is_ckpt_multi_gpus:
                # remove 'module.' (only from keys that actually carry it)
                name = k[len(prefix):] if k.startswith(prefix) else k
            else:
                name = prefix + k  # add 'module.'
            temp_dict[name] = v
        ckpt_model_state_dict = temp_dict

    # load params
    model.load_state_dict(ckpt_model_state_dict)

    if show_best_value:
        logging.info(
            "The pretrained_model is at checkpoint {}. \t "
            "Best value: {}".format(checkpoint["epoch"], best_value)
        )
    else:
        logging.info(
            "The pretrained_model is at checkpoint {}.".format(
                checkpoint["epoch"]
            )
        )

    if phase == "train":
        epoch = checkpoint["epoch"]
    else:
        epoch = -1
    return model, best_value, epoch
def load_pretrained_optimizer(
    pretrained_model, optimizer, scheduler, lr, use_ckpt_lr=True
):
    """Restore optimizer and scheduler state from a checkpoint file.

    Nothing is changed when ``pretrained_model`` is falsy or is not an
    existing file. Otherwise the ``"optimizer_state_dict"`` and
    ``"scheduler_state_dict"`` entries are loaded when present.

    Args:
        pretrained_model: Path to the checkpoint file.
        optimizer: Optimizer to restore in place.
        scheduler: LR scheduler to restore in place.
        lr: Learning rate returned when none is taken from the scheduler.
        use_ckpt_lr: When true and scheduler state was restored, return the
            scheduler's first learning rate instead of ``lr``.

    Returns:
        ``(optimizer, scheduler, lr)``.
    """
    if not pretrained_model or not os.path.isfile(pretrained_model):
        return optimizer, scheduler, lr

    # Load onto CPU so GPU-saved checkpoints can be read on a CPU-only host.
    checkpoint = torch.load(pretrained_model, map_location="cpu")

    if "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # Move optimizer state tensors to the GPU as before, but only when
        # CUDA exists; an unconditional .cuda() fails on CPU-only hosts.
        if torch.cuda.is_available():
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()

    if "scheduler_state_dict" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        if use_ckpt_lr:
            try:
                # get_last_lr() reports the rate the scheduler last computed;
                # get_lr() is only meaningful inside step() and can apply the
                # decay factor twice when called here. Fall back to get_lr()
                # for schedulers that lack get_last_lr().
                read_lr = getattr(scheduler, "get_last_lr", None) or scheduler.get_lr
                lr = read_lr()[0]
            except Exception:
                # Deliberate best effort: keep the caller-supplied lr when the
                # scheduler cannot report one.
                pass
    return optimizer, scheduler, lr
def save_checkpoint(state, is_best, save_path, postname):
    """Save ``state`` as ``<save_path>/<postname>_<epoch>.pth``.

    ``<epoch>`` is ``int(state["epoch"])``. When ``is_best`` is true the
    file is also copied to ``<save_path>/<postname>_best.pth``.

    Args:
        state: Mapping to serialize; must contain an ``"epoch"`` entry.
        is_best: Whether to refresh the ``_best`` copy as well.
        save_path: Target directory; created (including parents) if missing.
        postname: File-name prefix.
    """
    # Create the directory up front so the first save into a fresh run
    # directory does not fail.
    os.makedirs(save_path, exist_ok=True)
    filename = "{}/{}_{}.pth".format(save_path, postname, int(state["epoch"]))
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "{}/{}_best.pth".format(save_path, postname))
def change_ckpt_dict(model, optimizer, scheduler, opt):
    """Advance the scheduler by ``opt.epoch`` steps and save a checkpoint.

    ``opt.best_value`` is overwritten with the smaller of ``opt.test_value``
    and the previous ``opt.best_value``; the checkpoint is marked as best
    when ``opt.test_value`` was strictly smaller. The snapshot is passed to
    ``save_checkpoint`` together with ``opt.save_path`` and ``opt.post``.

    Args:
        model: Object exposing ``state_dict()``; tensors are saved on CPU.
        optimizer: Object exposing ``state_dict()``.
        scheduler: Object exposing ``step()`` and ``state_dict()``.
        opt: Options object providing ``epoch``, ``test_value``,
            ``best_value``, ``save_path`` and ``post``.
    """
    for _step in range(opt.epoch):
        scheduler.step()

    # Compare against the previous best before it is overwritten.
    improved = opt.test_value < opt.best_value
    opt.best_value = min(opt.test_value, opt.best_value)

    cpu_weights = {}
    for key, tensor in model.state_dict().items():
        cpu_weights[key] = tensor.cpu()

    snapshot = {
        "epoch": opt.epoch,
        "state_dict": cpu_weights,
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "best_value": opt.best_value,
    }
    save_checkpoint(snapshot, improved, opt.save_path, opt.post)
def load_models(model, device, model_dir="./PLA-Net/pretrained-models/BINARY_ada"):
    """Build four copies of ``model``, one per fold, with pretrained weights.

    For fold ``i`` in 1..4 the checkpoint ``<model_dir>/Fold<i>/Best_Model.pth``
    is loaded onto CPU, and the nine
    ``molecule_gcn.atom_encoder.atom_embedding_list.<k>.weight`` entries of
    its ``"model_state_dict"`` are transposed before being loaded into that
    fold's copy. The copy is then moved to ``device``.

    Args:
        model: Template model; it is deep-copied and not modified itself.
        device: Target passed to ``.to()`` for each loaded copy.
        model_dir: Directory holding the ``Fold1`` .. ``Fold4`` sub-folders.
            Defaults to the original hard-coded location.

    Returns:
        A 4-tuple of the loaded models, ordered by fold.
    """
    num_folds = 4
    num_atom_embeddings = 9
    key_template = "molecule_gcn.atom_encoder.atom_embedding_list.{}.weight"

    # All copies are made before any weights are read, as in the original.
    predictors = []
    for fold in range(1, num_folds + 1):
        print("------Copying model {}---------".format(fold))
        predictors.append(copy.deepcopy(model))

    # LOAD MODELS
    print("------- Loading weights----------")
    for fold, predictor in enumerate(predictors, start=1):
        ckpt_path = "{}/Fold{}/Best_Model.pth".format(model_dir, fold)
        ckpt = torch.load(ckpt_path, map_location="cpu")
        state_dict = ckpt["model_state_dict"]
        # NOTE(review): presumably the checkpoints store these weights
        # transposed relative to the current atom encoder -- confirm.
        for idx in range(num_atom_embeddings):
            key = key_template.format(idx)
            state_dict[key] = state_dict[key].t()
        predictor.load_state_dict(state_dict)
        predictor.to(device)

    return tuple(predictors)