import os
import pickle
import numpy as np
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
from torch.utils.data import Dataset


def collate_target(elem_dicts):
    """Collate retrieval samples that all share a single protein target.

    Drug embeddings are stacked into one tensor; the (identical) query target
    embedding is taken from the first element only.
    """
    batch = {'pids': [], 'targets': torch.Tensor(), 'mids': [], 'drugs': torch.Tensor()}
    labels = torch.Tensor()

    for i, elem_dict in enumerate(elem_dicts):

        # Labels are dummies in retrieval mode; keep everything float for torch.cat.
        labels = torch.cat((labels, torch.tensor(elem_dict['label']).float()), 0)

        batch['mids'].append(elem_dict['mid'])
        drug = torch.tensor(elem_dict['drug']).float().unsqueeze(0)
        batch['drugs'] = drug if len(batch['drugs']) == 0 else torch.cat((batch['drugs'], drug), 0)

        batch['pids'].append(elem_dict['pid'])
        if i == 0:
            # All elements carry the same query target, so it is stored once.
            batch['targets'] = torch.tensor(elem_dict['target']).float()

    # Add a leading set dimension: (1, num_drugs, drug_embedding_dim).
    batch['drugs'] = batch['drugs'].unsqueeze(0)

    return batch, labels
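

# Minimal sanity-check sketch (not part of the original module): collating two
# hand-built element dicts to show the shapes collate_target produces. The
# embedding sizes used here (8 for drugs, 4 for targets) are arbitrary
# assumptions for illustration.
def _collate_target_demo():
    elems = [
        {'pid': 'P1', 'target': np.zeros((1, 4)), 'mid': f'M{i}',
         'drug': np.ones(8), 'label': [0]}
        for i in range(2)
    ]
    batch, labels = collate_target(elems)
    # batch['drugs'] -> (1, 2, 8), batch['targets'] -> (1, 4), labels -> (2,)
    return batch, labels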


class DrugRetrieval(Dataset):
    """Pairs every drug in the embedding library with one query protein target,
    so a trained interaction model can rank the whole drug library for that target."""

    def __init__(self, data_path, query_target, query_embedding, drug_encoder='CDDD', target_encoder='SeqVec'):
        super(DrugRetrieval, self).__init__()
        self.data_path = data_path
        self.remove_batch = True

        assert os.path.exists(os.path.join(self.data_path, f'processed/{drug_encoder}_encoding.pickle')), 'Drug embeddings not available.'
        assert os.path.exists(f'data/Lenselink/processed/{target_encoder}_encoding_train.pickle'), 'Context target embeddings not available.'

        # Drugs
        emb_dict = self.get_drug_embeddings(encoder_name=drug_encoder)
        self.drug_ids = list(emb_dict.keys())
        self.drug_embeddings = list(emb_dict.values())
        
        # Context: standardize the training-set target embeddings and keep the
        # fitted scaler so the query embedding is scaled consistently below.
        self.target_scaler = StandardScaler()
        context = self.get_target_embeddings(encoder_name=target_encoder)
        self.context = self.standardize(embeddings=context)

        # Query target: apply the scaler fitted on the context targets.
        self.query_target = query_target
        self.query_embedding = self.target_scaler.transform([query_embedding.tolist()])

    def __getitem__(self, item):
        # Each item pairs the same query target with one drug from the library;
        # the label is a dummy because retrieval has no ground-truth interactions.
        return {
            'pid': self.query_target,
            'target': self.query_embedding,
            'mid': self.drug_ids[item],
            'drug': self.drug_embeddings[item],
            'label': [0],
        }

    def get_target_memory(self, exclude_pids=None):
        """Return all standardized context target embeddings as a single tensor;
        `exclude_pids` is not used here."""
        memory = list(self.context.values())
        return torch.tensor(np.stack(memory), dtype=torch.float32)

    def __len__(self):
        return len(self.drug_ids)

    def get_drug_embeddings(self, encoder_name):
        with open(os.path.join(self.data_path, f'processed/{encoder_name}_encoding.pickle'), 'rb') as handle:
            embeddings = pickle.load(handle)
        return embeddings

    def get_target_embeddings(self, encoder_name):
        with open(f'data/Lenselink/processed/{encoder_name}_encoding_train.pickle', 'rb') as handle:
            embeddings = pickle.load(handle)
        return embeddings

    def standardize(self, embeddings):
        # Fit the scaler on the context target embeddings and return a dict of
        # standardized embeddings keyed by the original protein ids.
        unique_ids = list(embeddings.keys())
        split_embeddings = [embeddings[unique_id].tolist() for unique_id in unique_ids]

        scaled_embeddings = self.target_scaler.fit_transform(split_embeddings)

        return {unique_id: emb for unique_id, emb in zip(unique_ids, scaled_embeddings)}
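

# Hypothetical end-to-end usage sketch (not part of the original module).
# It assumes the CDDD drug pickle and SeqVec target pickle exist at the paths
# checked in DrugRetrieval.__init__, and that the query embedding matches the
# dimensionality of the SeqVec context embeddings (1024 here is an assumption);
# the data path and UniProt id below are placeholders.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    query_embedding = np.random.rand(1024)          # placeholder query target embedding
    dataset = DrugRetrieval(data_path='data/Lenselink',
                            query_target='P00533',  # hypothetical UniProt id
                            query_embedding=query_embedding)
    loader = DataLoader(dataset, batch_size=32, collate_fn=collate_target)

    for batch, labels in loader:
        # A trained interaction model would score batch['drugs'] against
        # batch['targets'] here; labels are dummies in retrieval mode.
        print(batch['drugs'].shape, batch['targets'].shape)
        break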