Commit de956c8
Parent(s): 6a5a99e

Added model tests + checkpointing of the scaler object
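The heart of the scaler-checkpointing change is a pair of PyTorch Lightning hooks that pickle the fitted scalers into the checkpoint dictionary on save and restore them on load. Below is a minimal, self-contained sketch of that pattern; the class name, the toy layer, and the pre-fitted StandardScaler are illustrative stand-ins, not the project's actual PROTAC_Model.

import pickle

import pytorch_lightning as pl
import torch
from torch import nn
from sklearn.preprocessing import StandardScaler


class ScalerCheckpointDemo(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 1)
        # Stand-in for the scalers that PROTAC_Model fits on its training dataset.
        self.scalers = StandardScaler().fit([[0.0, 1.0], [2.0, 3.0]])

    def on_save_checkpoint(self, checkpoint):
        # Scalers are plain sklearn objects, so pickle them next to the model weights.
        checkpoint['scalers'] = pickle.dumps(self.scalers)

    def on_load_checkpoint(self, checkpoint):
        # Restore the scalers when the checkpoint carries them; leave them untouched otherwise.
        if 'scalers' in checkpoint:
            self.scalers = pickle.loads(checkpoint['scalers'])

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

Any save_checkpoint/load_from_checkpoint round trip then carries the scalers along with the weights, which is what the changes to pytorch_models.py below implement for PROTAC_Model.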
protac_degradation_predictor/__init__.py
CHANGED

@@ -5,6 +5,8 @@ from .data_utils import (
     is_active,
 )
 from .pytorch_models import (
+    PROTAC_Predictor,
+    PROTAC_Model,
     train_model,
 )
 from .sklearn_models import (
|
protac_degradation_predictor/optuna_utils.py
CHANGED

@@ -73,7 +73,7 @@ def pytorch_model_objective(
     dropout = trial.suggest_float('dropout', *dropout_options)

     # Start the CV over the folds
-    X = train_val_df.drop(columns=active_label)
+    X = train_val_df.copy().drop(columns=active_label)
     y = train_val_df[active_label].tolist()
     report = []
     for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):

@@ -108,11 +108,11 @@ def pytorch_model_objective(
         # At each fold, train and evaluate the Pytorch model
         # Train the model with the current set of hyperparameters
         _, _, metrics = train_model(
-            protein2embedding,
-            cell2embedding,
-            smiles2fp,
-            train_df,
-            val_df,
+            protein2embedding=protein2embedding,
+            cell2embedding=cell2embedding,
+            smiles2fp=smiles2fp,
+            train_df=train_df,
+            val_df=val_df,
             hidden_dim=hidden_dim,
             batch_size=batch_size,
             join_embeddings=join_embeddings,

@@ -223,7 +223,7 @@ def hyperparameter_tuning_and_training(
     test_report = []
     # Retrain N models with the best hyperparameters (measure model uncertainty)
     for i in range(n_models_for_test):
-        pl.seed_everything(42 + i)
+        pl.seed_everything(42 + i + 1)
         _, _, metrics = train_model(
             protein2embedding=protein2embedding,
             cell2embedding=cell2embedding,

@@ -235,9 +235,9 @@ def hyperparameter_tuning_and_training(
             active_label=active_label,
             max_epochs=max_epochs,
             disabled_embeddings=[],
-            logger_name=f'{logger_name}
+            logger_name=f'{logger_name}_best_model_n{i}',
             enable_checkpointing=True,
-            checkpoint_model_name=f'
+            checkpoint_model_name=f'best_model_n{i}_{split_type}',
             **study.best_params,
         )
         # Rename the keys in the metrics dictionary

@@ -245,6 +245,9 @@ def hyperparameter_tuning_and_training(
         metrics = {k.replace('train_', 'train_val_'): v for k, v in metrics.items()}
         metrics['model_type'] = 'Pytorch'
         metrics['test_model_id'] = i
+        metrics['test_len'] = len(test_df)
+        metrics['test_active_perc'] = test_df[active_label].sum() / len(test_df)
+        metrics['test_inactive_perc'] = (len(test_df) - test_df[active_label].sum()) / len(test_df)
         test_report.append(metrics.copy())
     test_report = pd.DataFrame(test_report)
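The retraining loop above estimates model uncertainty by fitting n_models_for_test models with distinct, reproducible seeds and collecting their metrics into a DataFrame. A small sketch of that pattern, with train_once() as a hypothetical stand-in for the repository's train_model(), looks like this:

import pandas as pd
import pytorch_lightning as pl


def train_once(seed: int) -> dict:
    # Stand-in: a real run would train a model and return its validation/test metrics.
    return {'test_acc': 0.80 + 0.01 * (seed % 3)}


test_report = []
for i in range(3):
    pl.seed_everything(42 + i + 1)  # one reproducible seed per retrained model
    metrics = train_once(42 + i + 1)
    metrics['test_model_id'] = i
    test_report.append(metrics)

test_report = pd.DataFrame(test_report)
print(test_report['test_acc'].mean(), test_report['test_acc'].std())

The spread of the per-model metrics (for example the standard deviation above) is the uncertainty estimate the diff's comment refers to.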
protac_degradation_predictor/pytorch_models.py
CHANGED

@@ -1,4 +1,6 @@
 import warnings
+import pickle
+import logging
 from typing import Literal, List, Tuple, Optional, Dict

 from .protac_dataset import PROTAC_Dataset

@@ -125,7 +127,6 @@ class PROTAC_Predictor(nn.Module):
         return x


-
 class PROTAC_Model(pl.LightningModule):

     def __init__(

@@ -218,13 +219,26 @@ class PROTAC_Model(pl.LightningModule):
         '''

         # Apply scaling in datasets
-
-
+        self.scalers = None
+        if self.apply_scaling and self.train_dataset is not None:
+            self.initialize_scalers()
+
+    def initialize_scalers(self):
+        """Initialize or reinitialize scalers based on dataset properties."""
+        if self.scalers is None:
+            use_single_scaler = self.join_embeddings == 'beginning'
             self.scalers = self.train_dataset.fit_scaling(use_single_scaler)
+            self.apply_scalers()
+
+    def apply_scalers(self):
+        """Apply scalers to all datasets."""
+        use_single_scaler = self.join_embeddings == 'beginning'
+        if self.train_dataset:
             self.train_dataset.apply_scaling(self.scalers, use_single_scaler)
+        if self.val_dataset:
             self.val_dataset.apply_scaling(self.scalers, use_single_scaler)
-
-
+        if self.test_dataset:
+            self.test_dataset.apply_scaling(self.scalers, use_single_scaler)

     def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
         return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)

@@ -316,6 +330,23 @@
             batch_size=self.batch_size,
             shuffle=False,
         )
+
+    def on_save_checkpoint(self, checkpoint):
+        """ Serialize the scalers to the checkpoint. """
+        checkpoint['scalers'] = pickle.dumps(self.scalers)
+
+    def on_load_checkpoint(self, checkpoint):
+        """Deserialize the scalers from the checkpoint."""
+        if 'scalers' in checkpoint:
+            self.scalers = pickle.loads(checkpoint['scalers'])
+        else:
+            self.scalers = None
+        if self.apply_scaling:
+            if self.scalers is not None:
+                # Re-apply scalers to ensure datasets are scaled
+                self.apply_scalers()
+            else:
+                logging.warning("Scalers not found in checkpoint. Consider re-fitting scalers if necessary.")


 def train_model(

@@ -421,7 +452,7 @@ def train_model(
         monitor='val_acc',
         mode='max',
         verbose=False,
-        filename=checkpoint_model_name + '-{epoch}-{
+        filename=checkpoint_model_name + '-{epoch}-{val_acc:.2f}-{val_roc_auc:.3f}',
     ))
     # Define Trainer
     trainer = pl.Trainer(

@@ -455,6 +486,9 @@
         warnings.simplefilter("ignore")
         trainer.fit(model)
     metrics = trainer.validate(model, verbose=False)[0]
+
+    # Add train metrics to metrics
+
     if test_df is not None:
         test_metrics = trainer.test(model, verbose=False)[0]
         metrics.update(test_metrics)

@@ -472,6 +506,15 @@ def load_model(
     Returns:
         PROTAC_Model: The loaded model.
     """
-
+    # NOTE: The `map_location` argument is automatically handled in newer versions
+    # of PyTorch Lightning, but we keep it here for compatibility with older ones.
+    model = PROTAC_Model.load_from_checkpoint(
+        ckpt_path,
+        map_location=torch.device('cpu') if not torch.cuda.is_available() else None,
+    )
+    # NOTE: The following is left as example for eventually re-applying scaling
+    # with other datasets...
+    # if model.apply_scaling:
+    #     model.apply_scalers()
     model.eval()
     return model
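With PROTAC_Model and PROTAC_Predictor now exported from the package and the scalers serialized into checkpoints, loading a trained model is a single call. A usage sketch follows; the checkpoint path is taken from the new tests, and whether that file exists in a given checkout is an assumption.

import torch

from protac_degradation_predictor import PROTAC_Model

model = PROTAC_Model.load_from_checkpoint(
    'data/test_model.ckpt',  # example path, matching the fixture used in tests/test_pytorch_model.py
    map_location=torch.device('cpu') if not torch.cuda.is_available() else None,
)
model.eval()
# on_load_checkpoint has already unpickled the scalers, if the checkpoint carried any.
print(model.apply_scaling, model.scalers is not None)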
src/run_experiments.py
CHANGED

@@ -207,10 +207,11 @@ def get_target_split_indices(active_df: pd.DataFrame, active_col: str, test_spli

 def main(
     active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
-    n_trials: int =
+    n_trials: int = 100,
     fast_dev_run: bool = False,
-    test_split: float = 0.
+    test_split: float = 0.1,
     cv_n_splits: int = 5,
+    max_epochs: int = 100,
     run_sklearn: bool = False,
 ):
     """ Train a PROTAC model using the given datasets and hyperparameters.

@@ -287,7 +288,7 @@ def main(
         n_models_for_test=3,
         fast_dev_run=fast_dev_run,
         n_trials=n_trials,
-        max_epochs=
+        max_epochs=max_epochs,
         logger_name=f'logs_{experiment_name}',
         active_label=active_col,
         study_filename=f'../reports/study_{experiment_name}.pkl',
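The new max_epochs parameter is simply threaded from main() into hyperparameter_tuning_and_training(). If the entry point is called directly, the updated signature looks like the sketch below; the import path is hypothetical, for illustration only, and the script's actual CLI wrapper is not shown in this diff.

from src.run_experiments import main  # hypothetical import path, assuming the repository root is on sys.path

main(
    active_col='Active (Dmax 0.6, pDC50 6.0)',
    n_trials=100,
    fast_dev_run=False,
    test_split=0.1,
    cv_n_splits=5,
    max_epochs=100,  # newly exposed knob, forwarded to hyperparameter_tuning_and_training
    run_sklearn=False,
)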
tests/test_pytorch_model.py
ADDED

@@ -0,0 +1,80 @@
+import pytest
+import os
+import sys
+import logging
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from protac_degradation_predictor import PROTAC_Model, PROTAC_Predictor
+
+import torch
+
+
+def test_protac_model():
+    model = PROTAC_Model(hidden_dim=128)
+    assert model.hidden_dim == 128
+    assert model.smiles_emb_dim == 224
+    assert model.poi_emb_dim == 1024
+    assert model.e3_emb_dim == 1024
+    assert model.cell_emb_dim == 768
+    assert model.batch_size == 32
+    assert model.learning_rate == 0.001
+    assert model.dropout == 0.2
+    assert model.join_embeddings == 'concat'
+    assert model.train_dataset is None
+    assert model.val_dataset is None
+    assert model.test_dataset is None
+    assert model.disabled_embeddings == []
+    assert model.apply_scaling == False
+
+def test_protac_predictor():
+    predictor = PROTAC_Predictor(hidden_dim=128)
+    assert predictor.hidden_dim == 128
+    assert predictor.smiles_emb_dim == 224
+    assert predictor.poi_emb_dim == 1024
+    assert predictor.e3_emb_dim == 1024
+    assert predictor.cell_emb_dim == 768
+    assert predictor.join_embeddings == 'concat'
+    assert predictor.disabled_embeddings == []
+
+def test_load_model(caplog):
+    caplog.set_level(logging.WARNING)
+
+    model = PROTAC_Model.load_from_checkpoint(
+        'data/test_model.ckpt',
+        map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
+    )
+    # apply_scaling: true
+    # batch_size: 8
+    # cell_emb_dim: 768
+    # disabled_embeddings: []
+    # dropout: 0.1498104322091649
+    # e3_emb_dim: 1024
+    # hidden_dim: 768
+    # join_embeddings: concat
+    # learning_rate: 4.881387978425994e-05
+    # poi_emb_dim: 1024
+    # smiles_emb_dim: 224
+    assert model.hidden_dim == 768
+    assert model.smiles_emb_dim == 224
+    assert model.poi_emb_dim == 1024
+    assert model.e3_emb_dim == 1024
+    assert model.cell_emb_dim == 768
+    assert model.batch_size == 8
+    assert model.learning_rate == 4.881387978425994e-05
+    assert model.dropout == 0.1498104322091649
+    assert model.join_embeddings == 'concat'
+    assert model.disabled_embeddings == []
+    assert model.apply_scaling == True
+
+
+def test_checkpoint_file():
+    checkpoint = torch.load(
+        'data/test_model.ckpt',
+        map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
+    )
+    print(checkpoint.keys())
+    print(checkpoint["hyper_parameters"])
+    print([k for k, v in checkpoint["state_dict"].items()])
+
+
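These tests can be run from the repository root with pytest tests/test_pytorch_model.py. Note that test_load_model and test_checkpoint_file expect a data/test_model.ckpt checkpoint to be present; that fixture does not appear among the files shown in this diff, so it is presumably provided separately.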