# hyper-dti/src/dataset.py

import os
import pickle

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset


def collate_target(elem_dicts):
    """Collate interactions that share a single protein target into one batch."""
    batch = {'pids': [], 'targets': torch.Tensor(), 'mids': [], 'drugs': torch.Tensor()}
    labels = torch.Tensor()

    for i, elem_dict in enumerate(elem_dicts):
        # Cast labels to float so they concatenate cleanly with the empty float tensor.
        labels = torch.cat((labels, torch.tensor(elem_dict['label']).float()), 0)
        batch['mids'].append(elem_dict['mid'])
        drug = torch.tensor(elem_dict['drug']).float().unsqueeze(0)
        batch['drugs'] = drug if len(batch['drugs']) == 0 else torch.cat((batch['drugs'], drug), 0)
        batch['pids'].append(elem_dict['pid'])
        # Every element shares the same target, so its embedding is stored only once.
        if i == 0:
            batch['targets'] = torch.tensor(elem_dict['target']).float()

    # One target paired with its full set of drugs: (1, batch_size, drug_dim).
    batch['drugs'] = batch['drugs'].unsqueeze(0)
    return batch, labels
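

# Usage sketch (illustrative, not part of the original module): collate_target is
# meant to be passed as the `collate_fn` of a torch.utils.data.DataLoader, so that
# each batch holds one protein target together with its candidate drugs, e.g.:
#
#     loader = DataLoader(dataset, batch_size=32, collate_fn=collate_target)
#     batch, labels = next(iter(loader))
#     batch['drugs']    # shape (1, batch_size, drug_dim)
#     batch['targets']  # the shared target embedding, stored once per batch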


class DrugRetrieval(Dataset):
    """Rank all candidate drugs against a single query protein target."""

    def __init__(self, data_path, query_target, query_embedding, drug_encoder='CDDD', target_encoder='SeqVec'):
        super(DrugRetrieval, self).__init__()
        self.data_path = data_path
        self.remove_batch = True

        assert os.path.exists(os.path.join(self.data_path, f'processed/{drug_encoder}_encoding.pickle')), \
            'Drug embeddings not available.'
        assert os.path.exists(f'data/Lenselink/processed/{target_encoder}_encoding_train.pickle'), \
            'Context target embeddings not available.'

        # Drugs: pre-computed embeddings of all retrieval candidates
        emb_dict = self.get_drug_embeddings(encoder_name=drug_encoder)
        self.drug_ids = list(emb_dict.keys())
        self.drug_embeddings = list(emb_dict.values())

        # Context: target embeddings of the training set, standardized below
        self.target_scaler = StandardScaler()
        context = self.get_target_embeddings(encoder_name=target_encoder)
        self.context = self.standardize(embeddings=context)

        # Query target: scaled with the statistics fitted on the context set
        self.query_target = query_target
        self.query_embedding = self.target_scaler.transform([query_embedding.tolist()])

    def __getitem__(self, item):
        return {
            'pid': self.query_target,
            'target': self.query_embedding,
            'mid': self.drug_ids[item],
            'drug': self.drug_embeddings[item],
            'label': [0],
        }

    def get_target_memory(self, exclude_pids=None):
        memory = list(self.context.values())
        return torch.tensor(np.stack(memory), dtype=torch.float32)

    def __len__(self):
        return len(self.drug_ids)

    def get_drug_embeddings(self, encoder_name):
        with open(os.path.join(self.data_path, f'processed/{encoder_name}_encoding.pickle'), 'rb') as handle:
            embeddings = pickle.load(handle)
        return embeddings

    def get_target_embeddings(self, encoder_name):
        with open(f'data/Lenselink/processed/{encoder_name}_encoding_train.pickle', 'rb') as handle:
            embeddings = pickle.load(handle)
        return embeddings
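
    # Note (assumption, not verified here): the pickles above map an id to a 1-D
    # embedding vector, e.g. {'<mid or pid>': array of shape (embedding_dim,)};
    # standardize() and get_target_memory() rely on that layout.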

    def standardize(self, embeddings):
        # Fit the scaler on the context targets and return the scaled embeddings,
        # re-keyed by their original ids.
        unique_ids = list(embeddings.keys())
        split_embeddings = [embeddings[unique_id].tolist() for unique_id in unique_ids]
        self.target_scaler.fit(split_embeddings)
        scaled_embeddings = self.target_scaler.transform(split_embeddings)
        return dict(zip(unique_ids, scaled_embeddings))
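

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): the data path, query id, and
    # embedding dimension below are assumptions, not values fixed by this module.
    from torch.utils.data import DataLoader

    query_embedding = np.random.rand(1024)   # stand-in for a real SeqVec embedding
    dataset = DrugRetrieval(
        data_path='data/Lenselink',          # assumed dataset location
        query_target='P00533',               # hypothetical UniProt accession
        query_embedding=query_embedding,
    )
    loader = DataLoader(dataset, batch_size=32, collate_fn=collate_target)
    batch, labels = next(iter(loader))
    print(len(dataset), batch['drugs'].shape, labels.shape)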