emmas96 commited on
Commit
e1fadab
·
1 Parent(s): 30763dd

fix target context to Lenselink due to trained checkpoint

Browse files
Files changed (1) hide show
  1. src/dataset.py +10 -5
src/dataset.py CHANGED
@@ -39,16 +39,16 @@ class DrugRetrieval(Dataset):
39
  self.remove_batch = True
40
 
41
  assert os.path.exists(os.path.join(self.data_path, f'processed/{drug_encoder}_encoding.pickle')), 'Drug embeddings not available.'
42
- assert os.path.exists(os.path.join(self.data_path, f'processed/{target_encoder}_encoding_train.pickle')), 'Context target embeddings not available.'
43
 
44
  # Drugs
45
- emb_dict = self.get_embeddings(encoder_name=drug_encoder)
46
  self.drug_ids = list(emb_dict.keys())
47
  self.drug_embeddings = list(emb_dict.values())
48
 
49
  # Context
50
  self.target_scaler = StandardScaler()
51
- context = self.get_embeddings(encoder_name=target_encoder)
52
  self.context = self.standardize(embeddings=context)
53
 
54
  # Query target
@@ -71,8 +71,13 @@ class DrugRetrieval(Dataset):
71
  def __len__(self):
72
  return len(self.drug_ids)
73
 
74
- def get_embeddings(self, encoder_name):
75
- with open(os.path.join(self.data_path, f'processed/{encoder_name}_encoding{"_train" if encoder_name == "SeqVec" else ""}.pickle'), 'rb') as handle:
 
 
 
 
 
76
  embeddings = pickle.load(handle)
77
  return embeddings
78
 
 
39
  self.remove_batch = True
40
 
41
  assert os.path.exists(os.path.join(self.data_path, f'processed/{drug_encoder}_encoding.pickle')), 'Drug embeddings not available.'
42
+ assert os.path.exists(f'data/Lenselink/processed/{target_encoder}_encoding_train.pickle')), 'Context target embeddings not available.'
43
 
44
  # Drugs
45
+ emb_dict = self.get_drug_embeddings(encoder_name=drug_encoder)
46
  self.drug_ids = list(emb_dict.keys())
47
  self.drug_embeddings = list(emb_dict.values())
48
 
49
  # Context
50
  self.target_scaler = StandardScaler()
51
+ context = self.get_target_embeddings(encoder_name=target_encoder)
52
  self.context = self.standardize(embeddings=context)
53
 
54
  # Query target
 
71
  def __len__(self):
72
  return len(self.drug_ids)
73
 
74
+ def get_drug_embeddings(self, encoder_name):
75
+ with open(os.path.join(self.data_path, f'processed/{encoder_name}_encoding.pickle'), 'rb') as handle:
76
+ embeddings = pickle.load(handle)
77
+ return embeddings
78
+
79
+ def get_target_embeddings(self, encoder_name):
80
+ with open(f'data/Lenselink/processed/{encoder_name}_encoding_train.pickle'), 'rb') as handle:
81
  embeddings = pickle.load(handle)
82
  return embeddings
83