import os
import pickle

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset


def collate_target(elem_dicts):
    """Collate interactions that share a single protein target into one batch."""
    batch = {'pids': [], 'targets': torch.Tensor(), 'mids': [], 'drugs': torch.Tensor()}
    labels = torch.Tensor()

    for i, elem_dict in enumerate(elem_dicts):
        # Cast labels to float so they match the dtype of the accumulator tensor.
        labels = torch.cat((labels, torch.tensor(elem_dict['label'], dtype=torch.float)), 0)
        batch['mids'].append(elem_dict['mid'])

        drug = torch.tensor(elem_dict['drug']).float().unsqueeze(0)
        batch['drugs'] = drug if len(batch['drugs']) == 0 else torch.cat((batch['drugs'], drug), 0)

        batch['pids'].append(elem_dict['pid'])
        # Every element carries the same query target, so it is stored only once.
        if i == 0:
            batch['targets'] = torch.tensor(elem_dict['target']).float()

    # Add a leading batch dimension: (1, num_drugs, drug_dim).
    batch['drugs'] = batch['drugs'].unsqueeze(0)
    return batch, labels


class DrugRetrieval(Dataset):
    """Dataset pairing every known drug with one query target for retrieval."""

    def __init__(self, data_path, query_target, query_embedding,
                 drug_encoder='CDDD', target_encoder='SeqVec'):
        super(DrugRetrieval, self).__init__()
        self.data_path = data_path
        self.remove_batch = True

        assert os.path.exists(os.path.join(self.data_path, f'processed/{drug_encoder}_encoding.pickle')), \
            'Drug embeddings not available.'
        assert os.path.exists(os.path.join(self.data_path, f'processed/{target_encoder}_encoding_train.pickle')), \
            'Context target embeddings not available.'

        # Drugs: all candidate molecules and their precomputed embeddings.
        emb_dict = self.get_embeddings(encoder_name=drug_encoder)
        self.drug_ids = list(emb_dict.keys())
        self.drug_embeddings = list(emb_dict.values())

        # Context: training-set target embeddings, standardized to zero mean and unit variance.
        # Passing the '_train' suffix explicitly keeps the loaded file consistent with the
        # assert above for any target encoder, not only SeqVec.
        self.target_scaler = StandardScaler()
        context = self.get_embeddings(encoder_name=target_encoder, suffix='_train')
        self.context = self.standardize(embeddings=context)

        # Query target, scaled with the statistics fitted on the context targets.
        self.query_target = query_target
        self.query_embedding = self.target_scaler.transform([np.asarray(query_embedding).tolist()])

    def __getitem__(self, item):
        return {
            'pid': self.query_target,
            'target': self.query_embedding,
            'mid': self.drug_ids[item],
            'drug': self.drug_embeddings[item],
            'label': [0],  # Dummy label; retrieval has no ground-truth interactions.
        }

    def get_target_memory(self, exclude_pids=None):
        """Return the context target embeddings, optionally excluding given protein ids."""
        exclude = set(exclude_pids) if exclude_pids is not None else set()
        memory = [emb for pid, emb in self.context.items() if pid not in exclude]
        return torch.tensor(np.stack(memory), dtype=torch.float32)

    def __len__(self):
        return len(self.drug_ids)

    def get_embeddings(self, encoder_name, suffix=''):
        with open(os.path.join(self.data_path, f'processed/{encoder_name}_encoding{suffix}.pickle'), 'rb') as handle:
            embeddings = pickle.load(handle)
        return embeddings

    def standardize(self, embeddings):
        """Fit the target scaler on all context embeddings and return them standardized."""
        unique_ids = list(embeddings.keys())
        split_embeddings = [embeddings[unique_id].tolist() for unique_id in unique_ids]
        scaled_embeddings = self.target_scaler.fit_transform(split_embeddings)
        return dict(zip(unique_ids, scaled_embeddings))
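
# Usage sketch (illustrative, not part of the original module): how the dataset and
# collate function are typically wired together with a torch DataLoader. The data
# path, query protein id, and embedding dimensionality below are hypothetical
# assumptions; substitute the values from your own pipeline.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    rng = np.random.default_rng(0)
    query_embedding = rng.standard_normal(1024)  # SeqVec embeddings are 1024-dimensional

    dataset = DrugRetrieval(
        data_path='data',          # hypothetical root containing processed/*.pickle
        query_target='P00533',     # hypothetical UniProt id for the query protein
        query_embedding=query_embedding,
    )
    loader = DataLoader(dataset, batch_size=32, collate_fn=collate_target)

    batch, labels = next(iter(loader))
    print(batch['targets'].shape)  # (1, 1024): the single query target
    print(batch['drugs'].shape)    # (1, batch_size, drug_dim): candidate drug embeddings
    print(labels.shape)            # (batch_size,): dummy labels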