import os
import copy
import shutil
import logging
from collections import OrderedDict

import numpy as np
import torch
def save_ckpt(
    model,
    optimizer,
    train_epoch_loss,
    val_epoch_loss,
    train_epoch_nap,
    val_epoch_nap,
    epoch,
    save_path,
    name_pre,
    name_post="best",
):
    """Save model and optimizer state together with the epoch losses and mAP values."""
    # Move all parameters to CPU so the checkpoint can be loaded on any device.
    model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
    state = {
        "epoch": epoch,
        "model_state_dict": model_cpu,
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": train_epoch_loss,
        "val_loss": val_epoch_loss,
        "train_map": train_epoch_nap,
        "val_map": val_epoch_nap,
    }

    if not os.path.exists(save_path):
        os.mkdir(save_path)
        print("Directory ", save_path, " is created.")

    filename = "{}/{}_{}.pth".format(save_path, name_pre, name_post)
    torch.save(state, filename)
    print("Model has been saved as {}".format(filename))
def load_pretrained_models(model, pretrained_model, phase, ismax=True):
    """Load weights from `pretrained_model` into `model`.

    `ismax=True` means a larger tracked metric is better, so the initial best
    value is -inf; otherwise it is +inf.
    """
    best_value = -np.inf if ismax else np.inf
    epoch = -1

    if pretrained_model:
        if os.path.isfile(pretrained_model):
            logging.info("===> Loading checkpoint '{}'".format(pretrained_model))
            checkpoint = torch.load(pretrained_model)
            try:
                best_value = checkpoint["best_value"]
                show_best_value = best_value != -np.inf and best_value != np.inf
            except KeyError:
                show_best_value = False

            model_dict = model.state_dict()
            ckpt_model_state_dict = checkpoint["state_dict"]

            # Rename checkpoint keys when only one of the two state dicts was
            # saved from a DataParallel (multi-GPU) model.
            is_model_multi_gpus = list(model_dict)[0].startswith("module.")
            is_ckpt_multi_gpus = list(ckpt_model_state_dict)[0].startswith("module.")

            if is_model_multi_gpus != is_ckpt_multi_gpus:
                temp_dict = OrderedDict()
                for k, v in ckpt_model_state_dict.items():
                    if is_ckpt_multi_gpus:
                        name = k[7:]  # remove the 'module.' prefix
                    else:
                        name = "module." + k  # add the 'module.' prefix
                    temp_dict[name] = v
                ckpt_model_state_dict = temp_dict

            # Load the (possibly renamed) parameters on top of the current state.
            model_dict.update(ckpt_model_state_dict)
            model.load_state_dict(model_dict)

            if show_best_value:
                logging.info(
                    "The pretrained_model is at checkpoint {}. \t "
                    "Best value: {}".format(checkpoint["epoch"], best_value)
                )
            else:
                logging.info(
                    "The pretrained_model is at checkpoint {}.".format(
                        checkpoint["epoch"]
                    )
                )

            epoch = checkpoint["epoch"] if phase == "train" else -1
        else:
            raise FileNotFoundError(
                "===> No checkpoint found at '{}'".format(pretrained_model)
            )
    else:
        logging.info("===> No pre-trained model")

    return model, best_value, epoch
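# Hedged usage sketch: resuming is optional, so passing an empty path simply
# returns the untouched model together with the initial best value and epoch.
# The nn.Linear model here is a placeholder, not the PLA-Net architecture.
def _example_load_pretrained_models():
    import torch.nn as nn

    model = nn.Linear(4, 2)
    model, best_value, epoch = load_pretrained_models(
        model, pretrained_model="", phase="train", ismax=True
    )
    return model, best_value, epoch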
def load_pretrained_optimizer(
    pretrained_model, optimizer, scheduler, lr, use_ckpt_lr=True
):
    """Restore optimizer/scheduler state from a checkpoint, optionally reusing its learning rate."""
    if pretrained_model:
        if os.path.isfile(pretrained_model):
            checkpoint = torch.load(pretrained_model)
            if "optimizer_state_dict" in checkpoint:
                optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                # Move the restored optimizer state back to the GPU when one is available.
                if torch.cuda.is_available():
                    for state in optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()
            if "scheduler_state_dict" in checkpoint:
                scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
                if use_ckpt_lr:
                    try:
                        lr = scheduler.get_last_lr()[0]
                    except Exception:
                        pass

    return optimizer, scheduler, lr
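# Hedged usage sketch: with an empty checkpoint path the optimizer, scheduler
# and learning rate are returned unchanged; with a real path the saved states
# would be restored. All objects here are placeholders.
def _example_load_pretrained_optimizer():
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 2)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)
    return load_pretrained_optimizer(
        pretrained_model="", optimizer=optimizer, scheduler=scheduler, lr=1e-3
    )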
def save_checkpoint(state, is_best, save_path, postname):
    """Save a checkpoint and, if it is the best so far, copy it to `<postname>_best.pth`."""
    filename = "{}/{}_{}.pth".format(save_path, postname, int(state["epoch"]))
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "{}/{}_best.pth".format(save_path, postname))
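# Hedged usage sketch: a minimal state dict is enough to exercise the naming
# scheme ("<postname>_<epoch>.pth" plus an optional "<postname>_best.pth" copy).
def _example_save_checkpoint():
    state = {"epoch": 3, "state_dict": {}}
    save_checkpoint(state, is_best=True, save_path=".", postname="demo")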
def change_ckpt_dict(model, optimizer, scheduler, opt):
    """Advance the scheduler to `opt.epoch` and save the current (CPU) model state."""
    for _ in range(opt.epoch):
        scheduler.step()

    # A lower test value is better here, so the best value is the running minimum.
    is_best = opt.test_value < opt.best_value
    opt.best_value = min(opt.test_value, opt.best_value)

    model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
    save_checkpoint(
        {
            "epoch": opt.epoch,
            "state_dict": model_cpu,
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "best_value": opt.best_value,
        },
        is_best,
        opt.save_path,
        opt.post,
    )
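# Hedged usage sketch: `opt` only needs the attributes change_ckpt_dict reads
# (epoch, test_value, best_value, save_path, post), so a Namespace stands in
# for the project's argument parser here.
def _example_change_ckpt_dict():
    from argparse import Namespace

    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 2)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)
    opt = Namespace(
        epoch=1, test_value=0.40, best_value=0.55, save_path=".", post="demo"
    )
    change_ckpt_dict(model, optimizer, scheduler, opt)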
def load_models(model, device):
    """Build the 4-fold PLA-Net ensemble by copying `model` and loading each fold's weights."""
    fold_models = []
    for fold in range(1, 5):
        print("------Copying model {}---------".format(fold))
        fold_models.append(copy.deepcopy(model))

    test_model_path = "./PLA-Net/pretrained-models/BINARY_ada"

    # LOAD MODELS
    print("------- Loading weights----------")
    for fold, prop_predictor in enumerate(fold_models, start=1):
        ckpt_path = "{}/Fold{}/Best_Model.pth".format(test_model_path, fold)
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

        # The checkpoints store the atom-embedding matrices transposed with
        # respect to the current model definition, so transpose them back
        # before loading the state dict.
        state_dict = ckpt["model_state_dict"]
        for i in range(9):
            key = "molecule_gcn.atom_encoder.atom_embedding_list.{}.weight".format(i)
            state_dict[key] = state_dict[key].t()

        prop_predictor.load_state_dict(state_dict)
        prop_predictor.to(device)

    return fold_models[0], fold_models[1], fold_models[2], fold_models[3]
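# Hedged usage sketch: load_models expects the four fold checkpoints under
# ./PLA-Net/pretrained-models/BINARY_ada and a `model` instance matching the
# stored architecture, so this only illustrates the call; it will not run
# without those files.
def _example_load_models(model):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fold1, fold2, fold3, fold4 = load_models(model, device)
    return fold1, fold2, fold3, fold4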