File size: 4,071 Bytes
799e642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112

import os 
import torch
import numpy as np
from tqdm import tqdm
from data.dataset import transform_molecule_pg
import pandas as pd
import torch.nn.functional as F
import torch

def load_model(model, fold, args):
    """Load the best checkpoint of one cross-validation fold into ``model``.

    The checkpoint is read from
    ``<args.target_checkpoint_path>/Fold<fold>/Best_Model.pth`` and must be a
    mapping containing a ``'model_state_dict'`` entry.

    Args:
        model: Module whose parameters/buffers are overwritten in place via
            ``load_state_dict``.
        fold: Fold number used to build the ``Fold<fold>`` directory name.
        args: Namespace providing ``target_checkpoint_path``.

    Returns:
        The same ``model`` object, now holding the fold's weights.
    """
    fold_dir = os.path.join(args.target_checkpoint_path, f'Fold{fold}')
    checkpoint_path = os.path.join(fold_dir, 'Best_Model.pth')

    def keep_on_cpu(storage, location):
        # Returning the deserialised storage unchanged keeps every tensor on
        # the CPU, whatever device it was saved from.
        return storage

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=keep_on_cpu)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

@torch.no_grad()
def test_gcn(model, device, loader, args):
    """Run four-fold ensemble inference and append per-molecule results to a CSV.

    For every batch yielded by ``loader``, ``model`` is evaluated once with
    the weights of each fold checkpoint ``Fold1`` .. ``Fold4`` (see
    ``load_model``). The per-fold softmax outputs are averaged, and one row
    per molecule is appended to ``args.output_file`` with the columns
    ``target``, ``smiles``, ``interaction_probability`` (mean probability of
    class index 1) and ``interaction_class`` (argmax of the mean
    probabilities). The CSV header is written only when the file is created.

    Args:
        model: Network to evaluate. Its weights are overwritten with each
            fold's checkpoint in turn.
        device: Device the batch graphs are moved to.
        loader: Iterable of batches. ``batch[0]`` is the molecule graph batch
            and, when ``args.use_prot`` is set, ``batch[1]`` is the protein
            graph batch.
        args: Namespace read for ``use_prot``, ``feature``, ``num_features``
            (only when ``feature == 'simple'``), ``target``, ``output_file``
            and, via ``load_model``, ``target_checkpoint_path``.

    Returns:
        None. Results are only written to ``args.output_file``.
    """
    save_path = args.output_file
    # fold -> snapshot of the model's state_dict right after loading that fold,
    # so each checkpoint is read from disk once instead of once per batch.
    fold_states = {}

    for batch in tqdm(loader, desc="Iteration"):
        batch_mol = batch[0].to(device)
        if args.use_prot:
            batch_prot = batch[1].to(device)
            smiles = list(batch_mol['smiles'])
        else:
            batch_prot = None
            # NOTE(review): this path reads the label field 'y', not 'smiles',
            # into the "smiles" column. Kept unchanged because the batch
            # structure of the non-protein loader is not visible here -- confirm.
            smiles = list(batch_mol['y'])

        if args.feature == 'simple':
            # keep only the first `num_features` node/edge feature columns
            num_features = args.num_features
            batch_mol.x = batch_mol.x[:, :num_features]
            batch_mol.edge_attr = batch_mol.edge_attr[:, :num_features]

        if batch_mol.x.shape[0] == 1:
            # single-node batch: skipped, nothing predicted or written
            continue

        fold_probs = []
        for fold in range(1, 5):
            if fold in fold_states:
                model.load_state_dict(fold_states[fold])
            else:
                model = load_model(model, fold, args)
                # clone: state_dict() returns live references that the next
                # fold's load_state_dict would overwrite in place
                fold_states[fold] = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
            model.eval()
            if args.use_prot:
                logits = model(batch_mol, batch_prot)
            else:
                logits = model(batch_mol)
            # float64 keeps the same precision as the previous
            # tolist() -> np.array conversion before averaging
            fold_probs.append(F.softmax(logits, dim=1).double().cpu().numpy())

        mean_probs = np.mean(fold_probs, axis=0)
        save_df = pd.DataFrame({
            'target': [args.target] * len(batch[0].y),
            'smiles': smiles,
            'interaction_probability': mean_probs[:, 1].tolist(),
            'interaction_class': mean_probs.argmax(axis=1).tolist(),
        })

        print("Saving results to csv file: ", save_path)
        # Only emit the header when creating the file; writing it on every
        # append interleaved header rows with the data.
        write_header = not os.path.exists(save_path)
        save_df.to_csv(save_path, mode='a', header=write_header, index=False)
                
                
                
def get_dataset_inference(
    dataset, use_prot=False, target=None, args=None, advs=False, saliency=False
):
    """Featurise every SMILES in ``dataset`` for inference.

    Each entry of ``dataset["Smiles"]`` is converted with
    ``transform_molecule_pg`` using a placeholder label of 0 (inference has no
    ground truth).

    Args:
        dataset: Container indexable by ``"Smiles"`` (e.g. a DataFrame).
        use_prot: When true, also build one protein graph from
            ``target["Fasta"]`` and pair it with every molecule.
        target: Container indexable by ``"Fasta"`` holding exactly one
            sequence (``.item()`` is called on it); only used if ``use_prot``.
        args, advs, saliency: Forwarded unchanged to ``transform_molecule_pg``
            for the molecules.

    Returns:
        A list with one element per SMILES: the molecule graph, or
        ``[molecule_graph, protein_graph]`` when ``use_prot`` is true. The
        same protein graph object is shared by every pair.
    """
    DEFAULT_LABEL = 0
    smiles_column = dataset["Smiles"]
    n_molecules = len(smiles_column)

    prot_graph = None
    if use_prot:
        # built once and reused for every molecule
        prot_graph = transform_molecule_pg(
            target["Fasta"].item(), label=None, is_prot=use_prot
        )

    total_dataset = []
    for mol in tqdm(smiles_column, total=n_molecules):
        mol_graph = transform_molecule_pg(
            mol, DEFAULT_LABEL, args, advs, saliency=saliency
        )
        total_dataset.append([mol_graph, prot_graph] if use_prot else mol_graph)
    return total_dataset