# PLA-Net / utils/model.py
# Uploaded by juliocesar-io
# Commit 799e642: "Added initial app"
# File size: 4.07 kB
import os
import torch
import numpy as np
from tqdm import tqdm
from data.dataset import transform_molecule_pg
import pandas as pd
import torch.nn.functional as F
import torch
def load_model(model, fold, args):
    """Load the saved weights of one fold into ``model`` and return it.

    The checkpoint is read from
    ``<args.target_checkpoint_path>/Fold<fold>/Best_Model.pth`` and its
    ``'model_state_dict'`` entry is applied to ``model`` in place.

    Args:
        model: Module exposing ``load_state_dict``.
        fold: Fold identifier used to build the ``Fold<fold>`` directory name.
        args: Namespace providing ``target_checkpoint_path``.

    Returns:
        The same ``model`` object, now holding the checkpoint weights.
    """

    def _keep_storage(storage, _location):
        # Returning the storage unchanged leaves it wherever torch.load
        # first deserialized it instead of moving it to the saved device.
        return storage

    checkpoint_file = os.path.join(
        args.target_checkpoint_path, f'Fold{fold}', 'Best_Model.pth'
    )
    checkpoint = torch.load(checkpoint_file, map_location=_keep_storage)
    weights = checkpoint['model_state_dict']
    model.load_state_dict(weights)
    return model
@torch.no_grad()
def test_gcn(model, device, loader, args, num_folds=4):
    """Run fold-ensemble inference over ``loader`` and append results to a CSV.

    For every batch, the checkpoints of folds ``1..num_folds`` are loaded in
    turn (via ``load_model``), the softmax outputs of all folds are averaged,
    and one row per sample is appended to ``args.output_file`` with columns
    ``target``, ``smiles``, ``interaction_probability`` (averaged probability
    of class index 1) and ``interaction_class`` (argmax of the averaged
    probabilities).

    The CSV header is written only when the output file does not exist yet or
    is empty, so multi-batch runs do not interleave header lines with data.

    Args:
        model: Network to evaluate; its weights are overwritten per fold.
        device: Device the batch graphs are moved to.
        loader: Iterable of batches. ``batch[0]`` is the molecule graph and,
            when ``args.use_prot`` is true, ``batch[1]`` is the protein graph.
        args: Namespace providing ``use_prot``, ``feature``, ``num_features``
            (only read when ``feature == 'simple'``), ``target``,
            ``output_file`` and whatever ``load_model`` requires.
        num_folds: Number of fold checkpoints to ensemble (default 4).

    Raises:
        ValueError: If ``num_folds`` is smaller than 1.
    """
    if num_folds < 1:
        raise ValueError("num_folds must be at least 1")

    columns = ['target', 'smiles', 'interaction_probability', 'interaction_class']
    save_path = os.fspath(args.output_file)

    for batch in tqdm(loader, desc="Iteration"):
        batch_mol = batch[0].to(device)
        if args.use_prot:
            batch_prot = batch[1].to(device)
            smiles = list(batch_mol['smiles'])
        else:
            batch_prot = None
            # NOTE(review): without a protein the 'y' field is what ends up
            # in the 'smiles' column -- kept as in the original, confirm intent.
            smiles = list(batch_mol['y'])

        if args.feature == 'simple':
            # Only retain the leading node/edge feature columns.
            num_features = args.num_features
            batch_mol.x = batch_mol.x[:, :num_features]
            batch_mol.edge_attr = batch_mol.edge_attr[:, :num_features]

        rows = {name: [] for name in columns}

        # Batches whose molecule graph holds a single node are not run
        # through the model; they contribute no rows to the output.
        if batch_mol.x.shape[0] != 1:
            rows['target'] = [args.target] * len(batch[0].y)
            rows['smiles'] = smiles

            # NOTE: every fold checkpoint is re-read from disk for each batch.
            fold_probs = []
            for fold in range(1, num_folds + 1):
                model = load_model(model, fold, args)
                model.eval()
                if args.use_prot:
                    pred = model(batch_mol, batch_prot)
                else:
                    pred = model(batch_mol)
                pred = F.softmax(pred, dim=1)
                # Going through Python floats keeps the averaging in float64.
                fold_probs.append(np.array(pred.cpu().tolist()))

            mean_probs = np.mean(fold_probs, axis=0)
            rows['interaction_class'] = [int(c) for c in mean_probs.argmax(axis=1)]
            rows['interaction_probability'] = mean_probs[:, 1].tolist()

        # Appending with a header on every batch would scatter header lines
        # through the file; emit it only for a new or empty file.
        write_header = (
            not os.path.exists(save_path) or os.path.getsize(save_path) == 0
        )
        print("Saving results to csv file: ", save_path)
        pd.DataFrame(rows, columns=columns).to_csv(
            save_path, mode='a', header=write_header, index=False
        )
def get_dataset_inference(
    dataset, use_prot=False, target=None, args=None, advs=False, saliency=False
):
    """Convert the molecules of ``dataset`` into graph samples for inference.

    Every entry of ``dataset["Smiles"]`` is passed to
    ``transform_molecule_pg`` with a constant placeholder label of 0. When
    ``use_prot`` is true, a protein graph is built once from
    ``target["Fasta"].item()`` and the same object is paired with every
    molecule.

    Args:
        dataset: Mapping-like object with a sized, iterable ``"Smiles"`` entry.
        use_prot: Whether to pair each molecule with the protein graph.
        target: Object whose ``"Fasta"`` entry supports ``.item()``; only
            read when ``use_prot`` is true.
        args: Forwarded to ``transform_molecule_pg`` for the molecules.
        advs: Forwarded to ``transform_molecule_pg`` for the molecules.
        saliency: Forwarded to ``transform_molecule_pg`` as ``saliency``.

    Returns:
        A list with one item per molecule: ``[molecule_graph, protein_graph]``
        when ``use_prot`` is true, otherwise just ``molecule_graph``.
    """
    placeholder_label = 0

    protein_graph = None
    if use_prot:
        protein_graph = transform_molecule_pg(
            target["Fasta"].item(), label=None, is_prot=use_prot
        )

    smiles_column = dataset["Smiles"]
    samples = []
    for smiles in tqdm(smiles_column, total=len(smiles_column)):
        ligand_graph = transform_molecule_pg(
            smiles, placeholder_label, args, advs, saliency=saliency
        )
        if use_prot:
            samples.append([ligand_graph, protein_graph])
        else:
            samples.append(ligand_graph)
    return samples