ribesstefano committed
Commit b86d3ec · 1 Parent(s): 74a86c6

Added majority voting evaluation

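The commit adds an ensemble-style evaluation: the models trained during cross-validation (and the N best models retrained afterwards) each predict on the held-out test set, and the per-sample majority vote across models is scored with torchmetrics. A minimal, self-contained sketch of the voting step, using dummy tensors and assuming the per-model predictions have already been binarized:

import torch
from torchmetrics import Accuracy

# Hypothetical hard 0/1 predictions from three models on six test samples.
preds = torch.tensor([
    [1, 0, 1, 1, 0, 1],
    [1, 1, 1, 0, 0, 1],
    [0, 0, 1, 1, 0, 0],
])
# torch.mode along dim=0 picks the most frequent value per column,
# i.e. the per-sample majority vote across models.
majority, _ = torch.mode(preds, dim=0)  # tensor([1, 0, 1, 1, 0, 1])
y = torch.tensor([1, 0, 1, 1, 0, 1])    # dummy ground-truth labels
print(Accuracy(task='binary')(majority, y).item())  # 1.0 on this toy input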
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -2,7 +2,7 @@ import os
 from typing import Literal, List, Tuple, Optional, Dict
 import logging
 
-from .pytorch_models import train_model
+from .pytorch_models import train_model, PROTAC_Model
 from .sklearn_models import (
     train_sklearn_model,
     suggest_random_forest,
@@ -11,6 +11,7 @@ from .sklearn_models import (
     suggest_gradient_boosting,
 )
 
+import torch
 import optuna
 from optuna.samplers import TPESampler
 import joblib
@@ -27,6 +28,56 @@ from sklearn.model_selection import (
 )
 import numpy as np
 import pytorch_lightning as pl
+from torchmetrics import (
+    Accuracy,
+    AUROC,
+    Precision,
+    Recall,
+    F1Score,
+)
+
+
+def get_dataframe_stats(
+    train_df = None,
+    val_df = None,
+    test_df = None,
+    active_label = 'Active',
+) -> Dict:
+    """ Get some statistics from the dataframes.
+
+    Args:
+        train_df (pd.DataFrame): The training set.
+        val_df (pd.DataFrame): The validation set.
+        test_df (pd.DataFrame): The test set.
+    """
+    stats = {}
+    if train_df is not None:
+        stats['train_len'] = len(train_df)
+        stats['train_active_perc'] = train_df[active_label].sum() / len(train_df)
+        stats['train_inactive_perc'] = (len(train_df) - train_df[active_label].sum()) / len(train_df)
+    if val_df is not None:
+        stats['val_len'] = len(val_df)
+        stats['val_active_perc'] = val_df[active_label].sum() / len(val_df)
+        stats['val_inactive_perc'] = (len(val_df) - val_df[active_label].sum()) / len(val_df)
+    if test_df is not None:
+        stats['test_len'] = len(test_df)
+        stats['test_active_perc'] = test_df[active_label].sum() / len(test_df)
+        stats['test_inactive_perc'] = (len(test_df) - test_df[active_label].sum()) / len(test_df)
+    if train_df is not None and val_df is not None:
+        leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
+        leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
+        stats['num_leaking_uniprot_train_val'] = len(leaking_uniprot)
+        stats['num_leaking_smiles_train_val'] = len(leaking_smiles)
+        stats['perc_leaking_uniprot_train_val'] = len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df)
+        stats['perc_leaking_smiles_train_val'] = len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df)
+    if train_df is not None and test_df is not None:
+        leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(test_df['Uniprot'])))
+        leaking_smiles = list(set(train_df['Smiles']).intersection(set(test_df['Smiles'])))
+        stats['num_leaking_uniprot_train_test'] = len(leaking_uniprot)
+        stats['num_leaking_smiles_train_test'] = len(leaking_smiles)
+        stats['perc_leaking_uniprot_train_test'] = len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df)
+        stats['perc_leaking_smiles_train_test'] = len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df)
+    return stats
 
 
 def pytorch_model_objective(
@@ -77,15 +128,15 @@ def pytorch_model_objective(
     X = train_val_df.copy().drop(columns=active_label)
     y = train_val_df[active_label].tolist()
     report = []
+    val_preds = []
+    test_preds = []
     for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
         logging.info(f'Fold {k + 1}/{kf.get_n_splits()}')
         # Get the train and val sets
         train_df = train_val_df.iloc[train_index]
         val_df = train_val_df.iloc[val_index]
 
-        # Check for data leakage and get some statistics
-        leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
-        leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
+        # Get some statistics from the dataframes
         stats = {
             'model_type': 'Pytorch',
             'fold': k,
@@ -93,22 +144,15 @@ def pytorch_model_objective(
             'val_len': len(val_df),
             'train_perc': len(train_df) / len(train_val_df),
             'val_perc': len(val_df) / len(train_val_df),
-            'train_active_perc': train_df[active_label].sum() / len(train_df),
-            'train_inactive_perc': (len(train_df) - train_df[active_label].sum()) / len(train_df),
-            'val_active_perc': val_df[active_label].sum() / len(val_df),
-            'val_inactive_perc': (len(val_df) - val_df[active_label].sum()) / len(val_df),
-            'num_leaking_uniprot': len(leaking_uniprot),
-            'num_leaking_smiles': len(leaking_smiles),
-            'train_leaking_uniprot_perc': len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df),
-            'train_leaking_smiles_perc': len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df),
         }
+        stats.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
         if groups is not None:
             stats['train_unique_groups'] = len(np.unique(groups[train_index]))
             stats['val_unique_groups'] = len(np.unique(groups[val_index]))
 
         # At each fold, train and evaluate the Pytorch model
         # Train the model with the current set of hyperparameters
-        _, trainer, metrics = train_model(
+        ret = train_model(
             protein2embedding=protein2embedding,
             cell2embedding=cell2embedding,
             smiles2fp=smiles2fp,
@@ -127,22 +171,47 @@ def pytorch_model_objective(
             use_logger=False,
             fast_dev_run=fast_dev_run,
             active_label=active_label,
+            return_predictions=True,
             disabled_embeddings=disabled_embeddings,
         )
+        if test_df is not None:
+            _, trainer, metrics, val_pred, test_pred = ret
+            test_preds.append(test_pred)
+            logging.info(f'Test predictions: {test_pred}')
+        else:
+            _, trainer, metrics, val_pred = ret
         train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
         stats.update(metrics)
        stats.update(train_metrics)
         report.append(stats.copy())
+        val_preds.append(val_pred)
+
+    # Save the report in the trial
+    trial.set_user_attr('report', report)
+
+    # Get the majority vote for the test predictions
+    if test_df is not None:
+        # Get the majority vote for the test predictions
+        test_preds = torch.stack(test_preds)
+        test_preds, _ = torch.mode(test_preds, dim=0)
+        y = torch.tensor(test_df[active_label].tolist())
+        # Measure the test accuracy and ROC AUC
+        majority_vote_metrics = {
+            'test_acc': Accuracy(task='binary')(test_preds, y).item(),
+            'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
+            'test_precision': Precision(task='binary')(test_preds, y).item(),
+            'test_recall': Recall(task='binary')(test_preds, y).item(),
+            'test_f1': F1Score(task='binary')(test_preds, y).item(),
+        }
+        majority_vote_metrics.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
+        trial.set_user_attr('majority_vote_metrics', majority_vote_metrics)
+        logging.info(f'Majority vote metrics: {majority_vote_metrics}')
 
     # Get the average validation accuracy and ROC AUC across the folds
-    val_acc = np.mean([r['val_acc'] for r in report])
     val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
 
-    # Save the report in the trial
-    trial.set_user_attr('report', report)
-
     # Optuna aims to minimize the pytorch_model_objective
-    return - val_acc - val_roc_auc
+    return - val_roc_auc
 
 
 def hyperparameter_tuning_and_training(
@@ -162,6 +231,7 @@ def hyperparameter_tuning_and_training(
     active_label: str = 'Active',
     max_epochs: int = 100,
     study_filename: Optional[str] = None,
+    force_study: bool = False,
 ) -> tuple:
     """ Hyperparameter tuning and training of a PROTAC model.
 
@@ -181,10 +251,11 @@ def hyperparameter_tuning_and_training(
     pl.seed_everything(42)
 
     # Define the search space
-    hidden_dim_options = [256, 512, 768]
-    batch_size_options = [8, 16, 32]
+    hidden_dim_options = [32, 64, 128, 256, 512, 768]
+    batch_size_options = [4, 8, 16, 32, 64, 128]
     learning_rate_options = (1e-5, 1e-3) # min and max values for loguniform distribution
     smote_k_neighbors_options = list(range(3, 16))
+    dropout_options = (0.1, 0.9)
 
     # Set the verbosity of Optuna
     optuna.logging.set_verbosity(optuna.logging.WARNING)
@@ -193,13 +264,13 @@ def hyperparameter_tuning_and_training(
     study = optuna.create_study(direction='minimize', sampler=sampler)
 
     study_loaded = False
-    if study_filename:
+    if study_filename and not force_study:
         if os.path.exists(study_filename):
            study = joblib.load(study_filename)
            study_loaded = True
            logging.info(f'Loaded study from {study_filename}')
 
-    if not study_loaded:
+    if not study_loaded or force_study:
         study.optimize(
             lambda trial: pytorch_model_objective(
                 trial=trial,
@@ -214,6 +285,7 @@ def hyperparameter_tuning_and_training(
                 batch_size_options=batch_size_options,
                 learning_rate_options=learning_rate_options,
                 smote_k_neighbors_options=smote_k_neighbors_options,
+                dropout_options=dropout_options,
                 fast_dev_run=fast_dev_run,
                 active_label=active_label,
                 max_epochs=max_epochs,
@@ -228,9 +300,11 @@ def hyperparameter_tuning_and_training(
 
     # Retrain N models with the best hyperparameters (measure model uncertainty)
     test_report = []
+    test_preds = []
+    dfs_stats = get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label)
     for i in range(n_models_for_test):
         pl.seed_everything(42 + i + 1)
-        _, trainer, metrics = train_model(
+        _, trainer, metrics, test_pred = train_model(
             protein2embedding=protein2embedding,
             cell2embedding=cell2embedding,
             smiles2fp=smiles2fp,
@@ -245,29 +319,52 @@ def hyperparameter_tuning_and_training(
             logger_name=f'{logger_name}_best_model_n{i}',
             enable_checkpointing=True,
             checkpoint_model_name=f'best_model_n{i}_{split_type}',
+            return_predictions=True,
             **study.best_params,
         )
         # Rename the keys in the metrics dictionary
         metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
         metrics['model_type'] = 'Pytorch'
         metrics['test_model_id'] = i
-        metrics['test_len'] = len(test_df)
-        metrics['test_active_perc'] = test_df[active_label].sum() / len(test_df)
-        metrics['test_inactive_perc'] = (len(test_df) - test_df[active_label].sum()) / len(test_df)
+        metrics.update(dfs_stats)
 
         # Add the training metrics
-        train_metrics = {m.replace('train_', 'train_val_'): v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
+        train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
         logging.info(f'Training metrics: {train_metrics}')
         logging.info(f'Training trainer.logged_metrics: {trainer.logged_metrics}')
         logging.info(f'Training trainer.callback_metrics: {trainer.callback_metrics}')
 
         metrics.update(train_metrics)
-
         test_report.append(metrics.copy())
+        test_preds.append(test_pred)
     test_report = pd.DataFrame(test_report)
 
+    # Get the majority vote for the test predictions
+    test_preds = torch.stack(test_preds)
+    test_preds, _ = torch.mode(test_preds, dim=0)
+    y = torch.tensor(test_df[active_label].tolist())
+    # Measure the test accuracy and ROC AUC
+    majority_vote_metrics = {
+        'cv_models': False,
+        'test_acc': Accuracy(task='binary')(test_preds, y).item(),
+        'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
+        'test_precision': Precision(task='binary')(test_preds, y).item(),
+        'test_recall': Recall(task='binary')(test_preds, y).item(),
+        'test_f1': F1Score(task='binary')(test_preds, y).item(),
+    }
+    majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
+    majority_vote_metrics_cv = study.best_trial.user_attrs['majority_vote_metrics']
+    majority_vote_metrics_cv['cv_models'] = True
+    majority_vote_report = pd.DataFrame([
+        majority_vote_metrics,
+        majority_vote_metrics_cv,
+    ])
+    majority_vote_report['model_type'] = 'Pytorch'
+    majority_vote_report['split_type'] = split_type
+
     # Ablation study: disable embeddings at a time
     ablation_report = []
+    dfs_stats = get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label)
     for disabled_embeddings in [['e3'], ['poi'], ['cell'], ['smiles'], ['e3', 'cell'], ['poi', 'e3', 'cell']]:
         logging.info('-' * 100)
         logging.info(f'Ablation study with disabled embeddings: {disabled_embeddings}')
@@ -291,9 +388,10 @@ def hyperparameter_tuning_and_training(
         metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
         metrics['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
         metrics['model_type'] = 'Pytorch'
+        metrics.update(dfs_stats)
 
         # Add the training metrics
-        train_metrics = {m.replace('train_', 'train_val_'): v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
+        train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
         metrics.update(train_metrics)
 
         ablation_report.append(metrics.copy())
@@ -304,7 +402,14 @@ def hyperparameter_tuning_and_training(
     report['split_type'] = split_type
 
     # Return the reports
-    return cv_report, hparam_report, test_report, ablation_report
+    ret = {
+        'cv_report': cv_report,
+        'hparam_report': hparam_report,
+        'test_report': test_report,
+        'ablation_report': ablation_report,
+        'majority_vote_report': majority_vote_report,
+    }
+    return ret
 
 
 def sklearn_model_objective(
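The new get_dataframe_stats helper centralizes the length, class-balance, and train/val and train/test leakage statistics that were previously computed inline. A small usage sketch (the import path is assumed from the repository layout; the toy frames only carry the columns the helper inspects):

import pandas as pd
from protac_degradation_predictor.optuna_utils import get_dataframe_stats

train = pd.DataFrame({
    'Active': [1, 0, 1, 0],
    'Uniprot': ['P1', 'P2', 'P3', 'P4'],
    'Smiles': ['C', 'CC', 'CCC', 'CCCC'],
})
val = pd.DataFrame({
    'Active': [1, 0],
    'Uniprot': ['P1', 'P9'],
    'Smiles': ['CO', 'CC'],
})
stats = get_dataframe_stats(train_df=train, val_df=val)
# stats['num_leaking_uniprot_train_val'] == 1   ('P1' occurs in both splits)
# stats['num_leaking_smiles_train_val'] == 1    ('CC' occurs in both splits)
# stats['perc_leaking_uniprot_train_val'] == 0.25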
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -315,26 +315,6 @@ class PROTAC_Model(pl.LightningModule):
         e3_emb = batch['e3_emb']
         cell_emb = batch['cell_emb']
         smiles_emb = batch['smiles_emb']
-
-        if self.apply_scaling:
-            if self.join_embeddings == 'beginning':
-                embeddings = np.hstack([
-                    np.array(smiles_emb.tolist()),
-                    np.array(poi_emb.tolist()),
-                    np.array(e3_emb.tolist()),
-                    np.array(cell_emb.tolist()),
-                ])
-                embeddings = self.scalers.transform(embeddings)
-                smiles_emb = embeddings[:, :self.smiles_emb_dim]
-                poi_emb = embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.poi_emb_dim]
-                e3_emb = embeddings[:, self.smiles_emb_dim+self.poi_emb_dim:self.smiles_emb_dim+2*self.poi_emb_dim]
-                cell_emb = embeddings[:, -self.cell_emb_dim:]
-            else:
-                poi_emb = self.scalers['Uniprot'].transform(poi_emb)
-                e3_emb = self.scalers['E3 Ligase Uniprot'].transform(e3_emb)
-                cell_emb = self.scalers['Cell Line Identifier'].transform(cell_emb)
-                smiles_emb = self.scalers['Smiles'].transform(smiles_emb)
-
         y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
         return torch.sigmoid(y_hat)
 
@@ -416,6 +396,7 @@ def train_model(
     enable_checkpointing: bool = False,
     checkpoint_model_name: str = 'protac',
     disabled_embeddings: List[str] = [],
+    return_predictions: bool = False,
 ) -> tuple:
     """ Train a PROTAC model using the given datasets and hyperparameters.
 
@@ -540,12 +521,18 @@ def train_model(
         warnings.simplefilter("ignore")
         trainer.fit(model)
     metrics = trainer.validate(model, verbose=False)[0]
-
-    # Add train metrics to metrics
-
+    # Add test metrics to metrics
     if test_df is not None:
         test_metrics = trainer.test(model, verbose=False)[0]
         metrics.update(test_metrics)
+    if return_predictions:
+        val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
+        val_pred = torch.concat(trainer.predict(model, val_dl)).squeeze()
+        if test_df is not None:
+            test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
+            test_pred = torch.concat(trainer.predict(model, test_dl)).squeeze()
+            return model, trainer, metrics, val_pred, test_pred
+        return model, trainer, metrics, val_pred
    return model, trainer, metrics
 
 
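Note that predict_step returns torch.sigmoid(y_hat), so the tensors collected through return_predictions are probabilities in [0, 1]. A hedged sketch of one way to turn several models' outputs into hard majority votes (the explicit 0.5 threshold is an assumption added here, not something the diff performs before calling torch.mode):

import torch

# test_preds: list of per-model probability tensors, one entry per trained model.
probs = torch.stack(test_preds)         # shape: (n_models, n_test_samples)
votes = (probs > 0.5).int()             # binarize each model's probabilities
majority, _ = torch.mode(votes, dim=0)  # per-sample majority vote across models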
src/run_experiments.py CHANGED
@@ -3,6 +3,7 @@ import sys
 from collections import defaultdict
 import warnings
 import logging
+from typing import Literal
 
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 
@@ -214,6 +215,8 @@ def main(
     cv_n_splits: int = 5,
     max_epochs: int = 100,
     run_sklearn: bool = False,
+    force_study: bool = False,
+    experiments: str | Literal['all', 'random', 'e3_ligase', 'tanimoto', 'uniprot'] = 'all',
 ):
     """ Train a PROTAC model using the given datasets and hyperparameters.
 
@@ -244,10 +247,15 @@ def main(
     ## Get the test sets
     test_indeces = {}
     active_df = protac_df[protac_df[active_col].notna()].copy()
-    test_indeces['random'] = get_random_split_indices(active_df, test_split)
-    test_indeces['e3_ligase'] = get_e3_ligase_split_indices(active_df)
-    test_indeces['tanimoto'] = get_tanimoto_split_indices(active_df, active_col, test_split)
-    test_indeces['uniprot'] = get_target_split_indices(active_df, active_col, test_split)
+
+    if experiments == 'random' or experiments == 'all':
+        test_indeces['random'] = get_random_split_indices(active_df, test_split)
+    if experiments == 'uniprot' or experiments == 'all':
+        test_indeces['uniprot'] = get_target_split_indices(active_df, active_col, test_split)
+    if experiments == 'e3_ligase' or experiments == 'all':
+        test_indeces['e3_ligase'] = get_e3_ligase_split_indices(active_df)
+    if experiments == 'tanimoto' or experiments == 'all':
+        test_indeces['tanimoto'] = get_tanimoto_split_indices(active_df, active_col, test_split)
 
     # Make directory ../reports if it does not exist
     if not os.path.exists('../reports'):
@@ -296,22 +304,18 @@ def main(
             logger_name=f'logs_{experiment_name}',
             active_label=active_col,
             study_filename=f'../reports/study_{experiment_name}.pkl',
+            force_study=force_study,
         )
-        cv_report, hparam_report, test_report, ablation_report = optuna_reports
 
         # Save the reports to file
-        for report, filename in zip([cv_report, hparam_report, test_report, ablation_report], ['cv_train', 'hparams', 'test', 'ablation']):
-            report.to_csv(f'../reports/report_{filename}_{experiment_name}.csv', index=False)
-
-        reports['cv'].append(cv_report.copy())
-        reports['hparam'].append(hparam_report.copy())
-        reports['test'].append(test_report.copy())
-        reports['ablation'].append(ablation_report.copy())
+        for report_name, report in optuna_reports.items():
+            report.to_csv(f'../reports/report_{report_name}_{experiment_name}.csv', index=False)
+            reports[report_name].append(report.copy())
 
     # Save the reports to file after concatenating them
-    for key, report in reports.items():
+    for report_name, report in reports.items():
         report = pd.concat(report)
-        report.to_csv(f'../reports/report_{key}_{active_name}_test_split_{test_split}.csv', index=False)
+        report.to_csv(f'../reports/report_{report_name}_{active_name}_test_split_{test_split}.csv', index=False)
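With the new arguments, a single split experiment can be re-run from scratch. A hypothetical invocation (how main is wired to a command line is not shown in the diff, so this sketch calls it directly; the import path is an assumption):

from run_experiments import main  # assumed import; adjust to how the script is run

main(
    experiments='tanimoto',  # run only the Tanimoto-split experiment instead of 'all'
    force_study=True,        # re-run the Optuna study even if a saved study file exists
)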