import os
import pickle

import numpy as np
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
from torch.utils.data import Dataset


def collate_target(elem_dicts):
    """Collate interactions that all share one protein target into a single batch."""
    batch = {'pids': [], 'targets': torch.Tensor(), 'mids': [], 'drugs': torch.Tensor()}
    labels = torch.Tensor()

    for i, elem_dict in enumerate(elem_dicts):
        # Collect the label and molecule id of every drug in the batch.
        labels = torch.cat((labels, torch.tensor(elem_dict['label'])), 0)
        batch['mids'].append(elem_dict['mid'])

        # Stack the drug embeddings row by row.
        drug = torch.tensor(elem_dict['drug']).float().unsqueeze(0)
        batch['drugs'] = drug if len(batch['drugs']) == 0 else torch.cat((batch['drugs'], drug), 0)

        # All elements share the same target, so its embedding is stored only once.
        batch['pids'].append(elem_dict['pid'])
        if i == 0:
            batch['targets'] = torch.tensor(elem_dict['target']).float()

    # Add a leading batch dimension: one target paired with all drugs of the batch.
    batch['drugs'] = batch['drugs'].unsqueeze(0)

    return batch, labels


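# Usage sketch (an assumption, not taken from the original code): ``collate_target``
# is intended as the ``collate_fn`` of a ``torch.utils.data.DataLoader`` wrapping the
# ``DrugRetrieval`` dataset below, for example:
#
#     loader = DataLoader(dataset, batch_size=32, collate_fn=collate_target)
#     batch, labels = next(iter(loader))
#     # batch['drugs']   -> shape (1, 32, drug_dim), all drugs of the batch
#     # batch['targets'] -> shape (1, target_dim), the single shared query target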
class DrugRetrieval(Dataset):
    """Pairs a single query target with every drug in the library for retrieval."""

    def __init__(self, data_path, query_target, query_embedding, drug_encoder='CDDD', target_encoder='SeqVec'):
        super(DrugRetrieval, self).__init__()
        self.data_path = data_path
        self.remove_batch = True

        assert os.path.exists(os.path.join(self.data_path, f'processed/{drug_encoder}_encoding.pickle')), \
            'Drug embeddings not available.'
        assert os.path.exists(f'data/Lenselink/processed/{target_encoder}_encoding_train.pickle'), \
            'Context target embeddings not available.'

        # Precomputed drug embeddings for the whole drug library.
        emb_dict = self.get_drug_embeddings(encoder_name=drug_encoder)
        self.drug_ids = list(emb_dict.keys())
        self.drug_embeddings = list(emb_dict.values())

        # Standardize the context (training) target embeddings and keep the fitted
        # scaler so the query embedding can be transformed consistently.
        self.target_scaler = StandardScaler()
        context = self.get_target_embeddings(encoder_name=target_encoder)
        self.context = self.standardize(embeddings=context)

        self.query_target = query_target
        self.query_embedding = self.target_scaler.transform([query_embedding.tolist()])

    def __getitem__(self, item):
        # Each item pairs the fixed query target with one drug from the library;
        # the label is a dummy value since no ground-truth interaction is known.
        return {
            'pid': self.query_target,
            'target': self.query_embedding,
            'mid': self.drug_ids[item],
            'drug': self.drug_embeddings[item],
            'label': [0],
        }

    def get_target_memory(self, exclude_pids=None):
        # Return all standardized context target embeddings as one tensor.
        # ``exclude_pids`` is kept for interface compatibility and is not used here.
        memory = list(self.context.values())
        return torch.tensor(np.stack(memory), dtype=torch.float32)

    def __len__(self):
        return len(self.drug_ids)

    def get_drug_embeddings(self, encoder_name):
        with open(os.path.join(self.data_path, f'processed/{encoder_name}_encoding.pickle'), 'rb') as handle:
            embeddings = pickle.load(handle)
        return embeddings

    def get_target_embeddings(self, encoder_name):
        with open(f'data/Lenselink/processed/{encoder_name}_encoding_train.pickle', 'rb') as handle:
            embeddings = pickle.load(handle)
        return embeddings

    def standardize(self, embeddings):
        # Fit the scaler on all context target embeddings, then rebuild the
        # dictionary with the standardized vectors.
        split_embeddings = []
        unique_ids = embeddings.keys()
        for unique_id in unique_ids:
            split_embeddings.append(embeddings[unique_id].tolist())

        self.target_scaler.fit(split_embeddings)
        scaled_embeddings = self.target_scaler.transform(split_embeddings)

        new_dict = {}
        for unique_id, emb in zip(unique_ids, scaled_embeddings):
            new_dict[unique_id] = emb
        return new_dict
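

# Minimal end-to-end sketch (an assumption, not part of the original module):
# the data path, query id, embedding size, and batch size below are hypothetical
# placeholders chosen only to illustrate the intended call sequence.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    query_embedding = np.random.randn(1024)       # hypothetical query target embedding
    dataset = DrugRetrieval(
        data_path='data/Lenselink',                # hypothetical dataset root
        query_target='P00533',                     # hypothetical query protein id
        query_embedding=query_embedding,
    )
    # Score the query target against the drug library in chunks of 128 drugs.
    loader = DataLoader(dataset, batch_size=128, collate_fn=collate_target)
    batch, labels = next(iter(loader))
    print(batch['drugs'].shape, batch['targets'].shape, len(batch['mids']))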