Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import numpy as np | |
from tqdm import tqdm | |
from data.dataset import transform_molecule_pg | |
import pandas as pd | |
import torch.nn.functional as F | |
import torch | |
def load_model(model, fold, args): | |
model_name = os.path.join(args.target_checkpoint_path, f'Fold{fold}','Best_Model.pth') | |
pre_model = torch.load(model_name, | |
map_location=lambda storage, loc: storage) | |
model.load_state_dict(pre_model['model_state_dict']) | |
return model | |
def test_gcn(model, device, loader,args): | |
for batch in tqdm(loader, desc="Iteration"): | |
save_dict = {'target': [], | |
'smiles': [], | |
'interaction_probability': [], | |
'interaction_class': []} | |
save_dict_temp = { | |
'Folder 1': [], | |
'Folder 2': [], | |
'Folder 3': [], | |
'Folder 4': []} | |
if args.use_prot: | |
batch_mol = batch[0].to(device) | |
batch_prot = batch[1].to(device) | |
smiles = batch_mol['smiles'] | |
smiles = [smi for smi in smiles] | |
else: | |
batch_mol = batch[0].to(device) | |
smiles = batch_mol['y'] | |
smiles = [smi for smi in smiles] | |
if args.feature == 'full': | |
pass | |
elif args.feature == 'simple': | |
# only retain the top two node/edge features | |
num_features = args.num_features | |
batch_mol.x = batch_mol.x[:, :num_features] | |
batch_mol.edge_attr = batch_mol.edge_attr[:, :num_features] | |
if batch_mol.x.shape[0] == 1: | |
pass | |
else: | |
target = [args.target]*len(batch[0].y) | |
save_dict['target'].extend(target) | |
save_dict['smiles'].extend(smiles) | |
for fold in range(1,5): | |
model = load_model(model, fold, args) | |
model.eval() | |
with torch.set_grad_enabled(False): | |
if args.use_prot: | |
pred = model(batch_mol,batch_prot) | |
else: | |
pred = model(batch_mol) | |
pred = F.softmax(pred,dim=1) | |
save_dict_temp[f'Folder {fold}'].extend(pred.cpu().tolist()) | |
for fold in range(1,5): | |
save_dict_temp[f'Folder {fold}'] = np.array(save_dict_temp[f'Folder {fold}']) | |
save_dict['interaction_probability'] = np.mean([save_dict_temp['Folder 1'], save_dict_temp['Folder 2'], save_dict_temp['Folder 3'], save_dict_temp['Folder 4']], axis = 0).tolist() | |
save_dict['interaction_class'] = [int(np.argmax(i)) for i in save_dict['interaction_probability']] | |
save_dict['interaction_probability'] = [x[1] for x in save_dict['interaction_probability']] | |
for fold in range(1,5): | |
save_dict_temp[f'Folder {fold}'] = save_dict_temp[f'Folder {fold}'].tolist() | |
save_df = pd.DataFrame(save_dict) | |
save_path = os.path.join(args.output_file) | |
print("Saving results to csv file: ", save_path) | |
save_df.to_csv(save_path, mode='a', header=True, index= False) | |
def get_dataset_inference( | |
dataset, use_prot=False, target=None, args=None, advs=False, saliency=False | |
): | |
DEFAULT_LABEL = 0 | |
total_dataset = [] | |
if use_prot: | |
prot_graph = transform_molecule_pg( | |
target["Fasta"].item(), label=None, is_prot=use_prot | |
) | |
for mol, label in tqdm( | |
zip(dataset["Smiles"], [DEFAULT_LABEL]*len(dataset["Smiles"])), total=len(dataset["Smiles"]) | |
): | |
if use_prot: | |
total_dataset.append( | |
[ | |
transform_molecule_pg(mol, label, args, advs, saliency=saliency), | |
prot_graph, | |
] | |
) | |
else: | |
total_dataset.append( | |
transform_molecule_pg(mol, label, args, advs, saliency=saliency) | |
) | |
return total_dataset | |