ribesstefano committed on
Commit de956c8 · 1 Parent(s): 6a5a99e

Added model tests + checkpointing of the scaler object

protac_degradation_predictor/__init__.py CHANGED
@@ -5,6 +5,8 @@ from .data_utils import (
     is_active,
 )
 from .pytorch_models import (
+    PROTAC_Predictor,
+    PROTAC_Model,
     train_model,
 )
 from .sklearn_models import (
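
The two classes are now part of the package's public surface. A minimal usage sketch of the new import path (mirroring what the added tests below exercise; `hidden_dim=128` is simply the value used there):

from protac_degradation_predictor import PROTAC_Model, PROTAC_Predictor

# The bare predictor network and the LightningModule that wraps it,
# both reachable from the package root after this change.
predictor = PROTAC_Predictor(hidden_dim=128)
model = PROTAC_Model(hidden_dim=128)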
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -73,7 +73,7 @@ def pytorch_model_objective(
     dropout = trial.suggest_float('dropout', *dropout_options)
 
     # Start the CV over the folds
-    X = train_val_df.drop(columns=active_label)
+    X = train_val_df.copy().drop(columns=active_label)
     y = train_val_df[active_label].tolist()
     report = []
     for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
@@ -108,11 +108,11 @@ def pytorch_model_objective(
         # At each fold, train and evaluate the Pytorch model
         # Train the model with the current set of hyperparameters
         _, _, metrics = train_model(
-            protein2embedding,
-            cell2embedding,
-            smiles2fp,
-            train_df,
-            val_df,
+            protein2embedding=protein2embedding,
+            cell2embedding=cell2embedding,
+            smiles2fp=smiles2fp,
+            train_df=train_df,
+            val_df=val_df,
             hidden_dim=hidden_dim,
             batch_size=batch_size,
             join_embeddings=join_embeddings,
@@ -223,7 +223,7 @@ def hyperparameter_tuning_and_training(
     test_report = []
     # Retrain N models with the best hyperparameters (measure model uncertainty)
     for i in range(n_models_for_test):
-        pl.seed_everything(42 + i)
+        pl.seed_everything(42 + i + 1)
         _, _, metrics = train_model(
             protein2embedding=protein2embedding,
             cell2embedding=cell2embedding,
@@ -235,9 +235,9 @@
             active_label=active_label,
             max_epochs=max_epochs,
             disabled_embeddings=[],
-            logger_name=f'{logger_name}_best_model_{i}',
+            logger_name=f'{logger_name}_best_model_n{i}',
             enable_checkpointing=True,
-            checkpoint_model_name=f'best_model_{split_type}_{i}',
+            checkpoint_model_name=f'best_model_n{i}_{split_type}',
             **study.best_params,
         )
         # Rename the keys in the metrics dictionary
@@ -245,6 +245,9 @@
         metrics = {k.replace('train_', 'train_val_'): v for k, v in metrics.items()}
         metrics['model_type'] = 'Pytorch'
         metrics['test_model_id'] = i
+        metrics['test_len'] = len(test_df)
+        metrics['test_active_perc'] = test_df[active_label].sum() / len(test_df)
+        metrics['test_inactive_perc'] = (len(test_df) - test_df[active_label].sum()) / len(test_df)
         test_report.append(metrics.copy())
     test_report = pd.DataFrame(test_report)
 
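The three new `test_*` entries record the size and class balance of the held-out test set next to each retrained model's metrics, which makes the per-model reports self-describing. A minimal sketch of the computation on a toy frame (the column name `Active` is illustrative, not the repository's `active_label`):

import pandas as pd

test_df = pd.DataFrame({'Active': [True, False, True, True]})
active_label = 'Active'

test_len = len(test_df)                                                    # 4
test_active_perc = test_df[active_label].sum() / test_len                  # 0.75
test_inactive_perc = (test_len - test_df[active_label].sum()) / test_len   # 0.25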
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -1,4 +1,6 @@
 import warnings
+import pickle
+import logging
 from typing import Literal, List, Tuple, Optional, Dict
 
 from .protac_dataset import PROTAC_Dataset
@@ -125,7 +127,6 @@ class PROTAC_Predictor(nn.Module):
         return x
 
 
-
 class PROTAC_Model(pl.LightningModule):
 
     def __init__(
@@ -218,13 +219,26 @@ class PROTAC_Model(pl.LightningModule):
         '''
 
         # Apply scaling in datasets
-        if self.apply_scaling:
-            use_single_scaler = True if self.join_embeddings == 'beginning' else False
+        self.scalers = None
+        if self.apply_scaling and self.train_dataset is not None:
+            self.initialize_scalers()
+
+    def initialize_scalers(self):
+        """Initialize or reinitialize scalers based on dataset properties."""
+        if self.scalers is None:
+            use_single_scaler = self.join_embeddings == 'beginning'
             self.scalers = self.train_dataset.fit_scaling(use_single_scaler)
+            self.apply_scalers()
+
+    def apply_scalers(self):
+        """Apply scalers to all datasets."""
+        use_single_scaler = self.join_embeddings == 'beginning'
+        if self.train_dataset:
             self.train_dataset.apply_scaling(self.scalers, use_single_scaler)
+        if self.val_dataset:
             self.val_dataset.apply_scaling(self.scalers, use_single_scaler)
-            if self.test_dataset:
-                self.test_dataset.apply_scaling(self.scalers, use_single_scaler)
+        if self.test_dataset:
+            self.test_dataset.apply_scaling(self.scalers, use_single_scaler)
 
     def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
         return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
@@ -316,6 +330,23 @@
             batch_size=self.batch_size,
             shuffle=False,
         )
+
+    def on_save_checkpoint(self, checkpoint):
+        """ Serialize the scalers to the checkpoint. """
+        checkpoint['scalers'] = pickle.dumps(self.scalers)
+
+    def on_load_checkpoint(self, checkpoint):
+        """Deserialize the scalers from the checkpoint."""
+        if 'scalers' in checkpoint:
+            self.scalers = pickle.loads(checkpoint['scalers'])
+        else:
+            self.scalers = None
+        if self.apply_scaling:
+            if self.scalers is not None:
+                # Re-apply scalers to ensure datasets are scaled
+                self.apply_scalers()
+            else:
+                logging.warning("Scalers not found in checkpoint. Consider re-fitting scalers if necessary.")
 
 
 def train_model(
@@ -421,7 +452,7 @@ def train_model(
         monitor='val_acc',
         mode='max',
         verbose=False,
-        filename=checkpoint_model_name + '-{epoch}-{val_metrics_opt_score:.4f}',
+        filename=checkpoint_model_name + '-{epoch}-{val_acc:.2f}-{val_roc_auc:.3f}',
     ))
     # Define Trainer
     trainer = pl.Trainer(
@@ -455,6 +486,9 @@
         warnings.simplefilter("ignore")
         trainer.fit(model)
     metrics = trainer.validate(model, verbose=False)[0]
+
+    # Add train metrics to metrics
+
     if test_df is not None:
         test_metrics = trainer.test(model, verbose=False)[0]
         metrics.update(test_metrics)
@@ -472,6 +506,15 @@ def load_model(
     Returns:
         PROTAC_Model: The loaded model.
     """
-    model = PROTAC_Model.load_from_checkpoint(ckpt_path)
+    # NOTE: The `map_location` argument is automatically handled in newer versions
+    # of PyTorch Lightning, but we keep it here for compatibility with older ones.
+    model = PROTAC_Model.load_from_checkpoint(
+        ckpt_path,
+        map_location=torch.device('cpu') if not torch.cuda.is_available() else None,
+    )
+    # NOTE: The following is left as an example for eventually re-applying scaling
+    # with other datasets...
+    # if model.apply_scaling:
+    #     model.apply_scalers()
     model.eval()
     return model
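
The `on_save_checkpoint` / `on_load_checkpoint` pair is the "checkpointing of the scaler object" from the commit message: the fitted scalers are pickled into the Lightning checkpoint dictionary on save and unpickled on load, so a reloaded model can reproduce the training-time feature scaling without the original datasets. A stripped-down sketch of the same pattern with a standalone LightningModule and a scikit-learn StandardScaler (the module and data here are illustrative, not the repository's classes):

import pickle
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from sklearn.preprocessing import StandardScaler

class ScaledRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)
        self.scaler = None  # fitted separately, e.g. on the training features

    def fit_scaler(self, features: np.ndarray):
        self.scaler = StandardScaler().fit(features)

    def on_save_checkpoint(self, checkpoint):
        # Serialize the non-tensor scaler next to the model weights.
        checkpoint['scaler'] = pickle.dumps(self.scaler)

    def on_load_checkpoint(self, checkpoint):
        # Restore the scaler if the checkpoint carries one.
        self.scaler = pickle.loads(checkpoint['scaler']) if 'scaler' in checkpoint else None

    def forward(self, x):
        return self.linear(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

Calling `trainer.save_checkpoint(path)` runs `on_save_checkpoint`, and `ScaledRegressor.load_from_checkpoint(path)` runs `on_load_checkpoint`, so the scaler survives the save/load round trip in the same way `PROTAC_Model.scalers` does above.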
src/run_experiments.py CHANGED
@@ -207,10 +207,11 @@ def get_target_split_indices(active_df: pd.DataFrame, active_col: str, test_spli
 
 def main(
     active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
-    n_trials: int = 50,
+    n_trials: int = 100,
     fast_dev_run: bool = False,
-    test_split: float = 0.2,
+    test_split: float = 0.1,
     cv_n_splits: int = 5,
+    max_epochs: int = 100,
     run_sklearn: bool = False,
 ):
     """ Train a PROTAC model using the given datasets and hyperparameters.
@@ -287,7 +288,7 @@ def main(
         n_models_for_test=3,
         fast_dev_run=fast_dev_run,
         n_trials=n_trials,
-        max_epochs=10,
+        max_epochs=max_epochs,
         logger_name=f'logs_{experiment_name}',
         active_label=active_col,
         study_filename=f'../reports/study_{experiment_name}.pkl',
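
With the new defaults, a full run performs 100 Optuna trials, holds out 10% of the data for testing, and trains each model for up to 100 epochs. A minimal invocation sketch, assuming `main` is called directly from `src/` rather than through a CLI wrapper:

from run_experiments import main

# Explicitly passing the new defaults shown in the diff above.
main(
    n_trials=100,
    test_split=0.1,
    cv_n_splits=5,
    max_epochs=100,
)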
tests/test_pytorch_model.py ADDED
@@ -0,0 +1,80 @@
+import pytest
+import os
+import sys
+import logging
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from protac_degradation_predictor import PROTAC_Model, PROTAC_Predictor
+
+import torch
+
+
+def test_protac_model():
+    model = PROTAC_Model(hidden_dim=128)
+    assert model.hidden_dim == 128
+    assert model.smiles_emb_dim == 224
+    assert model.poi_emb_dim == 1024
+    assert model.e3_emb_dim == 1024
+    assert model.cell_emb_dim == 768
+    assert model.batch_size == 32
+    assert model.learning_rate == 0.001
+    assert model.dropout == 0.2
+    assert model.join_embeddings == 'concat'
+    assert model.train_dataset is None
+    assert model.val_dataset is None
+    assert model.test_dataset is None
+    assert model.disabled_embeddings == []
+    assert model.apply_scaling == False
+
+def test_protac_predictor():
+    predictor = PROTAC_Predictor(hidden_dim=128)
+    assert predictor.hidden_dim == 128
+    assert predictor.smiles_emb_dim == 224
+    assert predictor.poi_emb_dim == 1024
+    assert predictor.e3_emb_dim == 1024
+    assert predictor.cell_emb_dim == 768
+    assert predictor.join_embeddings == 'concat'
+    assert predictor.disabled_embeddings == []
+
+def test_load_model(caplog):
+    caplog.set_level(logging.WARNING)
+
+    model = PROTAC_Model.load_from_checkpoint(
+        'data/test_model.ckpt',
+        map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
+    )
+    # apply_scaling: true
+    # batch_size: 8
+    # cell_emb_dim: 768
+    # disabled_embeddings: []
+    # dropout: 0.1498104322091649
+    # e3_emb_dim: 1024
+    # hidden_dim: 768
+    # join_embeddings: concat
+    # learning_rate: 4.881387978425994e-05
+    # poi_emb_dim: 1024
+    # smiles_emb_dim: 224
+    assert model.hidden_dim == 768
+    assert model.smiles_emb_dim == 224
+    assert model.poi_emb_dim == 1024
+    assert model.e3_emb_dim == 1024
+    assert model.cell_emb_dim == 768
+    assert model.batch_size == 8
+    assert model.learning_rate == 4.881387978425994e-05
+    assert model.dropout == 0.1498104322091649
+    assert model.join_embeddings == 'concat'
+    assert model.disabled_embeddings == []
+    assert model.apply_scaling == True
+
+
+def test_checkpoint_file():
+    checkpoint = torch.load(
+        'data/test_model.ckpt',
+        map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
+    )
+    print(checkpoint.keys())
+    print(checkpoint["hyper_parameters"])
+    print([k for k, v in checkpoint["state_dict"].items()])
+
+pytest.main()
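
The new test module can be run on its own; note that `test_load_model` and `test_checkpoint_file` expect a pre-trained checkpoint at `data/test_model.ckpt` (its hyperparameters are documented in the inline comments above). A sketch of invoking it programmatically, mirroring the module-level `pytest.main()` call:

import pytest

# Run only the new model tests, verbosely; requires data/test_model.ckpt.
pytest.main(['tests/test_pytorch_model.py', '-v'])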