ribesstefano committed
Commit 6a5a99e · 1 Parent(s): 4d17fea

Refactored experiments + fixed bug in dataset when applying scaling to val and test sets

README.md CHANGED
@@ -1,5 +1,42 @@
 # PROTAC-Degradation-Predictor
+
 Predicting PROTAC protein degradation activity via machine learning.
 
+## Data Curation
+
+For the data curation code, please refer to the Jupyter notebook [`data_curation.ipynb`](notebooks/data_curation.ipynb).
+
+## Installing the Package
+
+To install the package, run the following command:
+
+```bash
+pip install .
+```
+
+## Running the Package
+
+After installing the package, it can be used as in the following example:
+
+```python
+import protac_degradation_predictor as pdp
+
+protac_smiles = 'CC(C)(C)OC(=O)N1CCN(CC1)C2=CC(=C(C=C2)C(=O)NC3=CC(=C(C=C3)F)Cl)C(=O)NC4=CC=C(C=C4)F'
+e3_ligase = 'VHL'
+target_uniprot = 'P04637'
+cell_line = 'HeLa'
+
+active_protac = pdp.is_protac_active(
+    protac_smiles,
+    e3_ligase,
+    target_uniprot,
+    cell_line,
+    device='gpu',  # Defaults to 'cpu'
+    proba_threshold=0.5,  # Default value
+)
+
+print(f'The given PROTAC is: {"active" if active_protac else "inactive"}')
+```
+
 > If you're coming from my [thesis repo](https://github.com/ribesstefano/Machine-Learning-for-Predicting-Targeted-Protein-Degradation), I just wanted to create a separate and "less generic" repo for fast prototyping new ideas.
 > Stefano.
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -21,6 +21,12 @@ from sklearn.ensemble import (
21
  )
22
  from sklearn.linear_model import LogisticRegression
23
  from sklearn.svm import SVC
24
 
25
 
26
  def pytorch_model_objective(
@@ -28,8 +34,9 @@ def pytorch_model_objective(
28
  protein2embedding: Dict,
29
  cell2embedding: Dict,
30
  smiles2fp: Dict,
31
- train_df: pd.DataFrame,
32
- val_df: pd.DataFrame,
 
33
  hidden_dim_options: List[int] = [256, 512, 768],
34
  batch_size_options: List[int] = [8, 16, 32],
35
  learning_rate_options: Tuple[float, float] = (1e-5, 1e-3),
@@ -55,7 +62,7 @@ def pytorch_model_objective(
55
  active_label (str): The active label column.
56
  disabled_embeddings (List[str]): The list of disabled embeddings.
57
  """
58
- # Generate the hyperparameters
59
  hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
60
  batch_size = trial.suggest_categorical('batch_size', batch_size_options)
61
  learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
@@ -65,49 +72,90 @@ def pytorch_model_objective(
65
  apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
66
  dropout = trial.suggest_float('dropout', *dropout_options)
67
 
68
- # Train the model with the current set of hyperparameters
69
- _, _, metrics = train_model(
70
- protein2embedding,
71
- cell2embedding,
72
- smiles2fp,
73
- train_df,
74
- val_df,
75
- hidden_dim=hidden_dim,
76
- batch_size=batch_size,
77
- join_embeddings=join_embeddings,
78
- learning_rate=learning_rate,
79
- dropout=dropout,
80
- max_epochs=max_epochs,
81
- smote_k_neighbors=smote_k_neighbors,
82
- apply_scaling=apply_scaling,
83
- use_smote=use_smote,
84
- use_logger=False,
85
- fast_dev_run=fast_dev_run,
86
- active_label=active_label,
87
- disabled_embeddings=disabled_embeddings,
88
- )
89
 
90
- # Metrics is a dictionary containing at least the validation loss
91
- val_loss = metrics['val_loss']
92
- val_acc = metrics['val_acc']
93
- val_roc_auc = metrics['val_roc_auc']
94
 
95
  # Optuna aims to minimize the pytorch_model_objective
96
- return val_loss - val_acc - val_roc_auc
97
 
98
 
99
  def hyperparameter_tuning_and_training(
100
  protein2embedding: Dict,
101
  cell2embedding: Dict,
102
  smiles2fp: Dict,
103
- train_df: pd.DataFrame,
104
- val_df: pd.DataFrame,
105
- test_df: Optional[pd.DataFrame] = None,
106
  fast_dev_run: bool = False,
107
  n_trials: int = 50,
108
  logger_name: str = 'protac_hparam_search',
109
  active_label: str = 'Active',
110
- disabled_embeddings: List[str] = [],
111
  study_filename: Optional[str] = None,
112
  ) -> tuple:
113
  """ Hyperparameter tuning and training of a PROTAC model.
@@ -125,6 +173,8 @@ def hyperparameter_tuning_and_training(
125
  Returns:
126
  tuple: The trained model, the trainer, and the best metrics.
127
  """
 
 
128
  # Define the search space
129
  hidden_dim_options = [256, 512, 768]
130
  batch_size_options = [8, 16, 32]
@@ -151,42 +201,87 @@ def hyperparameter_tuning_and_training(
151
  protein2embedding=protein2embedding,
152
  cell2embedding=cell2embedding,
153
  smiles2fp=smiles2fp,
154
- train_df=train_df,
155
- val_df=val_df,
 
156
  hidden_dim_options=hidden_dim_options,
157
  batch_size_options=batch_size_options,
158
  learning_rate_options=learning_rate_options,
159
  smote_k_neighbors_options=smote_k_neighbors_options,
160
  fast_dev_run=fast_dev_run,
161
  active_label=active_label,
162
- disabled_embeddings=disabled_embeddings,
 
163
  ),
164
  n_trials=n_trials,
165
  )
166
  if study_filename:
167
  joblib.dump(study, study_filename)
 
 
168
 
169
- # Retrain the model with the best hyperparameters
170
- model, trainer, metrics = train_model(
171
- protein2embedding=protein2embedding,
172
- cell2embedding=cell2embedding,
173
- smiles2fp=smiles2fp,
174
- train_df=train_df,
175
- val_df=val_df,
176
- test_df=test_df,
177
- use_logger=True,
178
- logger_name=logger_name,
179
- fast_dev_run=fast_dev_run,
180
- active_label=active_label,
181
- disabled_embeddings=disabled_embeddings,
182
- **study.best_params,
183
- )
184
 
185
- # Report the best hyperparameters found
186
- metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
187
 
188
- # Return the best metrics
189
- return model, trainer, metrics
190
 
191
 
192
  def sklearn_model_objective(
 
21
  )
22
  from sklearn.linear_model import LogisticRegression
23
  from sklearn.svm import SVC
24
+ from sklearn.model_selection import (
25
+ StratifiedKFold,
26
+ StratifiedGroupKFold,
27
+ )
28
+ import numpy as np
29
+ import pytorch_lightning as pl
30
 
31
 
32
  def pytorch_model_objective(
 
34
  protein2embedding: Dict,
35
  cell2embedding: Dict,
36
  smiles2fp: Dict,
37
+ train_val_df: pd.DataFrame,
38
+ kf: StratifiedKFold | StratifiedGroupKFold,
39
+ groups: Optional[np.array] = None,
40
  hidden_dim_options: List[int] = [256, 512, 768],
41
  batch_size_options: List[int] = [8, 16, 32],
42
  learning_rate_options: Tuple[float, float] = (1e-5, 1e-3),
 
62
  active_label (str): The active label column.
63
  disabled_embeddings (List[str]): The list of disabled embeddings.
64
  """
65
+ # Suggest hyperparameters to be used across the CV folds
66
  hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
67
  batch_size = trial.suggest_categorical('batch_size', batch_size_options)
68
  learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
 
72
  apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
73
  dropout = trial.suggest_float('dropout', *dropout_options)
74
 
75
+ # Start the CV over the folds
76
+ X = train_val_df.drop(columns=active_label)
77
+ y = train_val_df[active_label].tolist()
78
+ report = []
79
+ for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
80
+ logging.info(f'Fold {k + 1}/{kf.get_n_splits()}')
81
+ # Get the train and val sets
82
+ train_df = train_val_df.iloc[train_index]
83
+ val_df = train_val_df.iloc[val_index]
84
 
85
+ # Check for data leakage and get some statistics
86
+ leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
87
+ leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
88
+ stats = {
89
+ 'model_type': 'Pytorch',
90
+ 'fold': k,
91
+ 'train_len': len(train_df),
92
+ 'val_len': len(val_df),
93
+ 'train_perc': len(train_df) / len(train_val_df),
94
+ 'val_perc': len(val_df) / len(train_val_df),
95
+ 'train_active_perc': train_df[active_label].sum() / len(train_df),
96
+ 'train_inactive_perc': (len(train_df) - train_df[active_label].sum()) / len(train_df),
97
+ 'val_active_perc': val_df[active_label].sum() / len(val_df),
98
+ 'val_inactive_perc': (len(val_df) - val_df[active_label].sum()) / len(val_df),
99
+ 'num_leaking_uniprot': len(leaking_uniprot),
100
+ 'num_leaking_smiles': len(leaking_smiles),
101
+ 'train_leaking_uniprot_perc': len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df),
102
+ 'train_leaking_smiles_perc': len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df),
103
+ }
104
+ if groups is not None:
105
+ stats['train_unique_groups'] = len(np.unique(groups[train_index]))
106
+ stats['val_unique_groups'] = len(np.unique(groups[val_index]))
107
+
108
+ # At each fold, train and evaluate the Pytorch model
109
+ # Train the model with the current set of hyperparameters
110
+ _, _, metrics = train_model(
111
+ protein2embedding,
112
+ cell2embedding,
113
+ smiles2fp,
114
+ train_df,
115
+ val_df,
116
+ hidden_dim=hidden_dim,
117
+ batch_size=batch_size,
118
+ join_embeddings=join_embeddings,
119
+ learning_rate=learning_rate,
120
+ dropout=dropout,
121
+ max_epochs=max_epochs,
122
+ smote_k_neighbors=smote_k_neighbors,
123
+ apply_scaling=apply_scaling,
124
+ use_smote=use_smote,
125
+ use_logger=False,
126
+ fast_dev_run=fast_dev_run,
127
+ active_label=active_label,
128
+ disabled_embeddings=disabled_embeddings,
129
+ )
130
+ stats.update(metrics)
131
+ report.append(stats.copy())
132
+
133
+ # Get the average validation accuracy and ROC AUC across the folds
134
+ val_acc = np.mean([r['val_acc'] for r in report])
135
+ val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
136
+
137
+ # Save the report in the trial
138
+ trial.set_user_attr('report', report)
139
 
140
  # Optuna aims to minimize the pytorch_model_objective
141
+ return -val_acc - val_roc_auc
142
 
143
 
144
  def hyperparameter_tuning_and_training(
145
  protein2embedding: Dict,
146
  cell2embedding: Dict,
147
  smiles2fp: Dict,
148
+ train_val_df: pd.DataFrame,
149
+ test_df: pd.DataFrame,
150
+ kf: StratifiedKFold | StratifiedGroupKFold,
151
+ groups: Optional[np.array] = None,
152
+ split_type: str = 'random',
153
+ n_models_for_test: int = 3,
154
  fast_dev_run: bool = False,
155
  n_trials: int = 50,
156
  logger_name: str = 'protac_hparam_search',
157
  active_label: str = 'Active',
158
+ max_epochs: int = 100,
159
  study_filename: Optional[str] = None,
160
  ) -> tuple:
161
  """ Hyperparameter tuning and training of a PROTAC model.
 
173
  Returns:
174
tuple: The cross-validation report, the hyperparameter report, the test report, and the ablation report, as pandas DataFrames.
175
  """
176
+ pl.seed_everything(42)
177
+
178
  # Define the search space
179
  hidden_dim_options = [256, 512, 768]
180
  batch_size_options = [8, 16, 32]
 
201
  protein2embedding=protein2embedding,
202
  cell2embedding=cell2embedding,
203
  smiles2fp=smiles2fp,
204
+ train_val_df=train_val_df,
205
+ kf=kf,
206
+ groups=groups,
207
  hidden_dim_options=hidden_dim_options,
208
  batch_size_options=batch_size_options,
209
  learning_rate_options=learning_rate_options,
210
  smote_k_neighbors_options=smote_k_neighbors_options,
211
  fast_dev_run=fast_dev_run,
212
  active_label=active_label,
213
+ max_epochs=max_epochs,
214
+ disabled_embeddings=[],
215
  ),
216
  n_trials=n_trials,
217
  )
218
  if study_filename:
219
  joblib.dump(study, study_filename)
220
+ cv_report = pd.DataFrame(study.best_trial.user_attrs['report'])
221
+ hparam_report = pd.DataFrame([study.best_params])
222
 
223
+ test_report = []
224
+ # Retrain N models with the best hyperparameters (measure model uncertainty)
225
+ for i in range(n_models_for_test):
226
+ pl.seed_everything(42 + i)
227
+ _, _, metrics = train_model(
228
+ protein2embedding=protein2embedding,
229
+ cell2embedding=cell2embedding,
230
+ smiles2fp=smiles2fp,
231
+ train_df=train_val_df,
232
+ val_df=test_df,
233
+ use_logger=True,
234
+ fast_dev_run=fast_dev_run,
235
+ active_label=active_label,
236
+ max_epochs=max_epochs,
237
+ disabled_embeddings=[],
238
+ logger_name=f'{logger_name}_best_model_{i}',
239
+ enable_checkpointing=True,
240
+ checkpoint_model_name=f'best_model_{split_type}_{i}',
241
+ **study.best_params,
242
+ )
243
+ # Rename the keys in the metrics dictionary
244
+ metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
245
+ metrics = {k.replace('train_', 'train_val_'): v for k, v in metrics.items()}
246
+ metrics['model_type'] = 'Pytorch'
247
+ metrics['test_model_id'] = i
248
+ test_report.append(metrics.copy())
249
+ test_report = pd.DataFrame(test_report)
250
 
251
+ # Ablation study: disable one group of embeddings at a time
252
+ ablation_report = []
253
+ for disabled_embeddings in [['e3'], ['poi'], ['cell'], ['smiles'], ['e3', 'cell'], ['poi', 'e3', 'cell']]:
254
+ logging.info('-' * 100)
255
+ logging.info(f'Ablation study with disabled embeddings: {disabled_embeddings}')
256
+ logging.info('-' * 100)
257
+ _, _, metrics = train_model(
258
+ protein2embedding=protein2embedding,
259
+ cell2embedding=cell2embedding,
260
+ smiles2fp=smiles2fp,
261
+ train_df=train_val_df,
262
+ val_df=test_df,
263
+ fast_dev_run=fast_dev_run,
264
+ active_label=active_label,
265
+ max_epochs=max_epochs,
266
+ use_logger=True,
267
+ logger_name=f'{logger_name}_disabled-{"-".join(disabled_embeddings)}',
268
+ disabled_embeddings=disabled_embeddings,
269
+ **study.best_params,
270
+ )
271
+ # Rename the keys in the metrics dictionary
272
+ metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
273
+ metrics = {k.replace('train_', 'train_val_'): v for k, v in metrics.items()}
274
+ metrics['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
275
+ metrics['model_type'] = 'Pytorch'
276
+ ablation_report.append(metrics.copy())
277
+ ablation_report = pd.DataFrame(ablation_report)
278
 
279
+ # Add a column with the split_type to all reports
280
+ for report in [cv_report, hparam_report, test_report, ablation_report]:
281
+ report['split_type'] = split_type
282
+
283
+ # Return the reports
284
+ return cv_report, hparam_report, test_report, ablation_report
285
 
286
 
287
  def sklearn_model_objective(
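
Note on the refactor above: the objective now runs the whole cross-validation loop inside each trial, stores the per-fold statistics via `trial.set_user_attr('report', report)`, and returns `-val_acc - val_roc_auc` averaged over the folds, so the study must minimize. A minimal sketch of how a caller could wire this up; the `run_cv_search` wrapper, the `objective` callable, and the file names here are illustrative, not the package API:

```python
import joblib
import optuna
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold


def run_cv_search(train_val_df: pd.DataFrame, objective, n_trials: int = 50) -> pd.DataFrame:
    # Group-aware, stratified folds: 'Uniprot Group' keeps related targets in the same fold.
    kf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
    groups = train_val_df['Uniprot Group'].to_numpy()

    # Lower is better, since the objective returns -val_acc - val_roc_auc.
    study = optuna.create_study(direction='minimize')
    study.optimize(
        lambda trial: objective(trial, train_val_df, kf, groups),
        n_trials=n_trials,
    )
    joblib.dump(study, 'study.pkl')

    # Recover the per-fold report saved with trial.set_user_attr().
    return pd.DataFrame(study.best_trial.user_attrs['report'])
```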
protac_degradation_predictor/protac_dataset.py CHANGED
@@ -146,12 +146,15 @@ class PROTAC_Dataset(Dataset):
146
  scalers (dict): The scalers for each feature.
147
  use_single_scaler (bool): Whether to use a single scaler for all features.
148
  """
149
- if self.use_single_scaler is None:
150
- raise ValueError(
151
- "The fit_scaling method must be called before apply_scaling.")
152
- if use_single_scaler != self.use_single_scaler:
153
- raise ValueError(
154
- f"The use_single_scaler parameter must be the same as the one used in the fit_scaling method. Got {use_single_scaler}, previously {self.use_single_scaler}.")
155
  if use_single_scaler:
156
  embeddings = np.hstack([
157
  np.array(self.data['Smiles'].tolist()),
 
146
  scalers (dict): The scalers for each feature.
147
  use_single_scaler (bool): Whether to use a single scaler for all features.
148
  """
149
+ # TODO: The following check is WRONG: for val and test sets I must NOT
150
+ # run the fit_scaling method, but I must use the scalers from the
151
+ # training set.
152
+ # if self.use_single_scaler is None:
153
+ # raise ValueError(
154
+ # "The fit_scaling method must be called before apply_scaling.")
155
+ # if use_single_scaler != self.use_single_scaler:
156
+ # raise ValueError(
157
+ # f"The use_single_scaler parameter must be the same as the one used in the fit_scaling method. Got {use_single_scaler}, previously {self.use_single_scaler}.")
158
  if use_single_scaler:
159
  embeddings = np.hstack([
160
  np.array(self.data['Smiles'].tolist()),
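
The TODO above describes the intended behaviour behind the commit's bug fix: scalers are fit on the training split only and then reused, not refit, on the validation and test splits. A minimal sketch of that pattern with scikit-learn, independent of the `PROTAC_Dataset` internals (the array names are illustrative):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Illustrative feature matrices for the three splits.
X_train = np.random.rand(100, 16)
X_val = np.random.rand(20, 16)
X_test = np.random.rand(20, 16)

# Fit the scaler on the training data only...
scaler = StandardScaler().fit(X_train)

# ...and only transform the validation and test data with it,
# so no statistics from val/test leak into the preprocessing.
X_train_scaled = scaler.transform(X_train)
X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)
```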
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -2,7 +2,7 @@ import warnings
2
  from typing import Literal, List, Tuple, Optional, Dict
3
 
4
  from .protac_dataset import PROTAC_Dataset
5
- from .config import Config
6
 
7
  import pandas as pd
8
  import numpy as np
@@ -28,10 +28,10 @@ class PROTAC_Predictor(nn.Module):
28
  def __init__(
29
  self,
30
  hidden_dim: int,
31
- smiles_emb_dim: int = Config.fingerprint_size,
32
- poi_emb_dim: int = Config.protein_embedding_size,
33
- e3_emb_dim: int = Config.protein_embedding_size,
34
- cell_emb_dim: int = Config.cell_embedding_size,
35
  dropout: float = 0.2,
36
  join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
37
  disabled_embeddings: list = [],
@@ -131,10 +131,10 @@ class PROTAC_Model(pl.LightningModule):
131
  def __init__(
132
  self,
133
  hidden_dim: int,
134
- smiles_emb_dim: int = 224,
135
- poi_emb_dim: int = 1024,
136
- e3_emb_dim: int = 1024,
137
- cell_emb_dim: int = 768,
138
  batch_size: int = 32,
139
  learning_rate: float = 1e-3,
140
  dropout: float = 0.2,
@@ -330,7 +330,10 @@ def train_model(
330
  learning_rate: float = 2e-5,
331
  dropout: float = 0.2,
332
  max_epochs: int = 50,
333
- smiles_emb_dim: int = 224,
334
  join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
335
  smote_k_neighbors:int = 5,
336
  use_smote: bool = True,
@@ -339,6 +342,8 @@ def train_model(
339
  fast_dev_run: bool = False,
340
  use_logger: bool = True,
341
  logger_name: str = 'protac',
 
 
342
  disabled_embeddings: List[str] = [],
343
  ) -> tuple:
344
  """ Train a PROTAC model using the given datasets and hyperparameters.
@@ -410,13 +415,14 @@ def train_model(
410
  mode='max',
411
  verbose=False,
412
  ),
413
- # pl.callbacks.ModelCheckpoint(
414
- # monitor='val_acc',
415
- # mode='max',
416
- # verbose=True,
417
- # filename='{epoch}-{val_metrics_opt_score:.4f}',
418
- # ),
419
  ]
420
  # Define Trainer
421
  trainer = pl.Trainer(
422
  logger=logger if use_logger else False,
@@ -424,7 +430,7 @@ def train_model(
424
  max_epochs=max_epochs,
425
  fast_dev_run=fast_dev_run,
426
  enable_model_summary=False,
427
- enable_checkpointing=False,
428
  enable_progress_bar=False,
429
  devices=1,
430
  num_nodes=1,
@@ -432,9 +438,9 @@ def train_model(
432
  model = PROTAC_Model(
433
  hidden_dim=hidden_dim,
434
  smiles_emb_dim=smiles_emb_dim,
435
- poi_emb_dim=1024,
436
- e3_emb_dim=1024,
437
- cell_emb_dim=768,
438
  batch_size=batch_size,
439
  join_embeddings=join_embeddings,
440
  dropout=dropout,
 
2
  from typing import Literal, List, Tuple, Optional, Dict
3
 
4
  from .protac_dataset import PROTAC_Dataset
5
+ from .config import config
6
 
7
  import pandas as pd
8
  import numpy as np
 
28
  def __init__(
29
  self,
30
  hidden_dim: int,
31
+ smiles_emb_dim: int = config.fingerprint_size,
32
+ poi_emb_dim: int = config.protein_embedding_size,
33
+ e3_emb_dim: int = config.protein_embedding_size,
34
+ cell_emb_dim: int = config.cell_embedding_size,
35
  dropout: float = 0.2,
36
  join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
37
  disabled_embeddings: list = [],
 
131
  def __init__(
132
  self,
133
  hidden_dim: int,
134
+ smiles_emb_dim: int = config.fingerprint_size,
135
+ poi_emb_dim: int = config.protein_embedding_size,
136
+ e3_emb_dim: int = config.protein_embedding_size,
137
+ cell_emb_dim: int = config.cell_embedding_size,
138
  batch_size: int = 32,
139
  learning_rate: float = 1e-3,
140
  dropout: float = 0.2,
 
330
  learning_rate: float = 2e-5,
331
  dropout: float = 0.2,
332
  max_epochs: int = 50,
333
+ smiles_emb_dim: int = config.fingerprint_size,
334
+ poi_emb_dim: int = config.protein_embedding_size,
335
+ e3_emb_dim: int = config.protein_embedding_size,
336
+ cell_emb_dim: int = config.cell_embedding_size,
337
  join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
338
  smote_k_neighbors:int = 5,
339
  use_smote: bool = True,
 
342
  fast_dev_run: bool = False,
343
  use_logger: bool = True,
344
  logger_name: str = 'protac',
345
+ enable_checkpointing: bool = False,
346
+ checkpoint_model_name: str = 'protac',
347
  disabled_embeddings: List[str] = [],
348
  ) -> tuple:
349
  """ Train a PROTAC model using the given datasets and hyperparameters.
 
415
  mode='max',
416
  verbose=False,
417
  ),
418
  ]
419
+ if enable_checkpointing:
420
+ callbacks.append(pl.callbacks.ModelCheckpoint(
421
+ monitor='val_acc',
422
+ mode='max',
423
+ verbose=False,
424
+ filename=checkpoint_model_name + '-{epoch}-{val_metrics_opt_score:.4f}',
425
+ ))
426
  # Define Trainer
427
  trainer = pl.Trainer(
428
  logger=logger if use_logger else False,
 
430
  max_epochs=max_epochs,
431
  fast_dev_run=fast_dev_run,
432
  enable_model_summary=False,
433
+ enable_checkpointing=enable_checkpointing,
434
  enable_progress_bar=False,
435
  devices=1,
436
  num_nodes=1,
 
438
  model = PROTAC_Model(
439
  hidden_dim=hidden_dim,
440
  smiles_emb_dim=smiles_emb_dim,
441
+ poi_emb_dim=poi_emb_dim,
442
+ e3_emb_dim=e3_emb_dim,
443
+ cell_emb_dim=cell_emb_dim,
444
  batch_size=batch_size,
445
  join_embeddings=join_embeddings,
446
  dropout=dropout,
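
With `enable_checkpointing` and `checkpoint_model_name` now exposed by `train_model`, the best checkpoint written by the `ModelCheckpoint` callback can later be restored through the standard Lightning mechanism. A minimal sketch, assuming `PROTAC_Model` saves its hyperparameters (e.g. via `self.save_hyperparameters()`); the checkpoint path is illustrative and depends on the trainer's default directory:

```python
from protac_degradation_predictor.pytorch_models import PROTAC_Model

# Illustrative path: ModelCheckpoint writes under the trainer's default_root_dir,
# using the filename pattern built from checkpoint_model_name above.
ckpt_path = 'lightning_logs/version_0/checkpoints/best_model_random_0-epoch=12-val_metrics_opt_score=0.8412.ckpt'

# LightningModule.load_from_checkpoint restores hyperparameters and weights,
# provided they were saved with self.save_hyperparameters() at init time.
model = PROTAC_Model.load_from_checkpoint(ckpt_path)
model.eval()
```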
src/run_experiments.py CHANGED
@@ -27,6 +27,16 @@ warnings.filterwarnings("ignore", ".*FixedLocator*")
27
  warnings.filterwarnings("ignore", ".*does not have many workers.*")
28
 
29
 
30
  def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
31
  """ Get the indices of the test set using a random split.
32
 
@@ -263,120 +273,148 @@ def main(
263
  kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
264
  group = train_val_df['Uniprot Group'].to_numpy()
265
 
266
- # Start the CV over the folds
267
- X = train_val_df.drop(columns=active_col)
268
- y = train_val_df[active_col].tolist()
269
- for k, (train_index, val_index) in enumerate(kf.split(X, y, group)):
270
- print('-' * 100)
271
- print(f'Starting CV for group type: {split_type}, fold: {k}')
272
- print('-' * 100)
273
- train_df = train_val_df.iloc[train_index]
274
- val_df = train_val_df.iloc[val_index]
275
-
276
- leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
277
- leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
278
-
279
- stats = {
280
- 'fold': k,
281
- 'split_type': split_type,
282
- 'train_len': len(train_df),
283
- 'val_len': len(val_df),
284
- 'train_perc': len(train_df) / len(train_val_df),
285
- 'val_perc': len(val_df) / len(train_val_df),
286
- 'train_active_perc': train_df[active_col].sum() / len(train_df),
287
- 'train_inactive_perc': (len(train_df) - train_df[active_col].sum()) / len(train_df),
288
- 'val_active_perc': val_df[active_col].sum() / len(val_df),
289
- 'val_inactive_perc': (len(val_df) - val_df[active_col].sum()) / len(val_df),
290
- 'test_active_perc': test_df[active_col].sum() / len(test_df),
291
- 'test_inactive_perc': (len(test_df) - test_df[active_col].sum()) / len(test_df),
292
- 'num_leaking_uniprot': len(leaking_uniprot),
293
- 'num_leaking_smiles': len(leaking_smiles),
294
- 'train_leaking_uniprot_perc': len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df),
295
- 'train_leaking_smiles_perc': len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df),
296
- }
297
- if split_type != 'random':
298
- stats['train_unique_groups'] = len(np.unique(group[train_index]))
299
- stats['val_unique_groups'] = len(np.unique(group[val_index]))
300
-
301
- # At each fold, train and evaluate the Pytorch model
302
- if split_type != 'tanimoto' or run_sklearn:
303
- logging.info(f'Skipping Pytorch model training on fold {k} with split type {split_type} and test split {test_split}.')
304
- continue
305
- else:
306
- logging.info(f'Starting Pytorch model training on fold {k} with split type {split_type} and test split {test_split}.')
307
- # Train and evaluate the model
308
- model, trainer, metrics = pdp.hyperparameter_tuning_and_training(
309
- protein2embedding,
310
- cell2embedding,
311
- smiles2fp,
312
- train_df,
313
- val_df,
314
- test_df,
315
- fast_dev_run=fast_dev_run,
316
- n_trials=n_trials,
317
- logger_name=f'protac_{active_name}_{split_type}_fold_{k}_test_split_{test_split}',
318
- active_label=active_col,
319
- study_filename=f'../reports/study_{active_name}_{split_type}_fold_{k}_test_split_{test_split}.pkl',
320
- )
321
- hparams = {p.replace('hparam_', ''): v for p, v in stats.items() if p.startswith('hparam_')}
322
- stats.update(metrics)
323
- stats['model_type'] = 'Pytorch'
324
- report.append(stats.copy())
325
- del model
326
- del trainer
327
-
328
- # Ablation study: disable embeddings at a time
329
- for disabled_embeddings in [['e3'], ['poi'], ['cell'], ['smiles'], ['e3', 'cell'], ['poi', 'e3', 'cell']]:
330
- print('-' * 100)
331
- print(f'Ablation study with disabled embeddings: {disabled_embeddings}')
332
- print('-' * 100)
333
- stats['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
334
- model, trainer, metrics = pdp.train_model(
335
- protein2embedding,
336
- cell2embedding,
337
- smiles2fp,
338
- train_df,
339
- val_df,
340
- test_df,
341
- fast_dev_run=fast_dev_run,
342
- logger_name=f'protac_{active_name}_{split_type}_fold_{k}_disabled-{"-".join(disabled_embeddings)}',
343
- active_label=active_col,
344
- disabled_embeddings=disabled_embeddings,
345
- **hparams,
346
- )
347
- stats.update(metrics)
348
- report.append(stats.copy())
349
- del model
350
- del trainer
351
-
352
- # At each fold, train and evaluate sklearn models
353
- if run_sklearn:
354
- for model_type in ['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting']:
355
- logging.info(f'Starting sklearn model {model_type} training on fold {k} with split type {split_type} and test split {test_split}.')
356
- # Train and evaluate sklearn models
357
- model, metrics = pdp.hyperparameter_tuning_and_training_sklearn(
358
- protein2embedding=protein2embedding,
359
- cell2embedding=cell2embedding,
360
- smiles2fp=smiles2fp,
361
- train_df=train_df,
362
- val_df=val_df,
363
- test_df=test_df,
364
- model_type=model_type,
365
- active_label=active_col,
366
- n_trials=n_trials,
367
- study_filename=f'../reports/study_{active_name}_{split_type}_fold_{k}_test_split_{test_split}_{model_type.lower()}.pkl',
368
- )
369
- hparams = {p.replace('hparam_', ''): v for p, v in stats.items() if p.startswith('hparam_')}
370
- stats['model_type'] = model_type
371
- stats.update(metrics)
372
- report.append(stats.copy())
373
-
374
- # Save the report at the end of each split type
375
- report_df = pd.DataFrame(report)
376
- report_df.to_csv(
377
- f'../reports/cv_report_hparam_search_{cv_n_splits}-splits_{active_name}_test_split_{test_split}{"_sklearn" if run_sklearn else ""}.csv',
378
- index=False,
379
  )
380
 
381
 
382
  if __name__ == '__main__':
 
27
  warnings.filterwarnings("ignore", ".*does not have many workers.*")
28
 
29
 
30
+ root = logging.getLogger()
31
+ root.setLevel(logging.DEBUG)
32
+
33
+ handler = logging.StreamHandler(sys.stdout)
34
+ handler.setLevel(logging.DEBUG)
35
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
36
+ handler.setFormatter(formatter)
37
+ root.addHandler(handler)
38
+
39
+
40
  def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
41
  """ Get the indices of the test set using a random split.
42
 
 
273
  kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
274
  group = train_val_df['Uniprot Group'].to_numpy()
275
 
276
+ # Start the experiment
277
+ experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
278
+ reports = pdp.hyperparameter_tuning_and_training(
279
+ protein2embedding=protein2embedding,
280
+ cell2embedding=cell2embedding,
281
+ smiles2fp=smiles2fp,
282
+ train_val_df=train_val_df,
283
+ test_df=test_df,
284
+ kf=kf,
285
+ groups=group,
286
+ split_type=split_type,
287
+ n_models_for_test=3,
288
+ fast_dev_run=fast_dev_run,
289
+ n_trials=n_trials,
290
+ max_epochs=10,
291
+ logger_name=f'logs_{experiment_name}',
292
+ active_label=active_col,
293
+ study_filename=f'../reports/study_{experiment_name}.pkl',
294
  )
295
+ cv_report, hparam_report, test_report, ablation_report = reports
296
+
297
+ # Save the reports to file
298
+ for report, filename in zip([cv_report, hparam_report, test_report, ablation_report], ['cv_train', 'hparams', 'test', 'ablation']):
299
+ report.to_csv(f'../reports/report_{filename}_{experiment_name}.csv', index=False)
300
+
301
+
302
+
303
+
304
+ # # Start the CV over the folds
305
+ # X = train_val_df.drop(columns=active_col)
306
+ # y = train_val_df[active_col].tolist()
307
+ # for k, (train_index, val_index) in enumerate(kf.split(X, y, group)):
308
+ # print('-' * 100)
309
+ # print(f'Starting CV for group type: {split_type}, fold: {k}')
310
+ # print('-' * 100)
311
+ # train_df = train_val_df.iloc[train_index]
312
+ # val_df = train_val_df.iloc[val_index]
313
+
314
+ # leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
315
+ # leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
316
+
317
+ # stats = {
318
+ # 'fold': k,
319
+ # 'split_type': split_type,
320
+ # 'train_len': len(train_df),
321
+ # 'val_len': len(val_df),
322
+ # 'train_perc': len(train_df) / len(train_val_df),
323
+ # 'val_perc': len(val_df) / len(train_val_df),
324
+ # 'train_active_perc': train_df[active_col].sum() / len(train_df),
325
+ # 'train_inactive_perc': (len(train_df) - train_df[active_col].sum()) / len(train_df),
326
+ # 'val_active_perc': val_df[active_col].sum() / len(val_df),
327
+ # 'val_inactive_perc': (len(val_df) - val_df[active_col].sum()) / len(val_df),
328
+ # 'test_active_perc': test_df[active_col].sum() / len(test_df),
329
+ # 'test_inactive_perc': (len(test_df) - test_df[active_col].sum()) / len(test_df),
330
+ # 'num_leaking_uniprot': len(leaking_uniprot),
331
+ # 'num_leaking_smiles': len(leaking_smiles),
332
+ # 'train_leaking_uniprot_perc': len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df),
333
+ # 'train_leaking_smiles_perc': len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df),
334
+ # }
335
+ # if split_type != 'random':
336
+ # stats['train_unique_groups'] = len(np.unique(group[train_index]))
337
+ # stats['val_unique_groups'] = len(np.unique(group[val_index]))
338
+
339
+ # # At each fold, train and evaluate the Pytorch model
340
+ # if split_type != 'tanimoto' or run_sklearn:
341
+ # logging.info(f'Skipping Pytorch model training on fold {k} with split type {split_type} and test split {test_split}.')
342
+ # continue
343
+ # else:
344
+ # logging.info(f'Starting Pytorch model training on fold {k} with split type {split_type} and test split {test_split}.')
345
+ # # Train and evaluate the model
346
+ # model, trainer, metrics = pdp.hyperparameter_tuning_and_training(
347
+ # protein2embedding,
348
+ # cell2embedding,
349
+ # smiles2fp,
350
+ # train_df,
351
+ # val_df,
352
+ # test_df,
353
+ # fast_dev_run=fast_dev_run,
354
+ # n_trials=n_trials,
355
+ # logger_name=f'protac_{active_name}_{split_type}_fold_{k}_test_split_{test_split}',
356
+ # active_label=active_col,
357
+ # study_filename=f'../reports/study_{active_name}_{split_type}_fold_{k}_test_split_{test_split}.pkl',
358
+ # )
359
+ # hparams = {p.replace('hparam_', ''): v for p, v in stats.items() if p.startswith('hparam_')}
360
+ # stats.update(metrics)
361
+ # stats['model_type'] = 'Pytorch'
362
+ # report.append(stats.copy())
363
+ # del model
364
+ # del trainer
365
+
366
+ # # Ablation study: disable embeddings at a time
367
+ # for disabled_embeddings in [['e3'], ['poi'], ['cell'], ['smiles'], ['e3', 'cell'], ['poi', 'e3', 'cell']]:
368
+ # print('-' * 100)
369
+ # print(f'Ablation study with disabled embeddings: {disabled_embeddings}')
370
+ # print('-' * 100)
371
+ # stats['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
372
+ # model, trainer, metrics = pdp.train_model(
373
+ # protein2embedding,
374
+ # cell2embedding,
375
+ # smiles2fp,
376
+ # train_df,
377
+ # val_df,
378
+ # test_df,
379
+ # fast_dev_run=fast_dev_run,
380
+ # logger_name=f'protac_{active_name}_{split_type}_fold_{k}_disabled-{"-".join(disabled_embeddings)}',
381
+ # active_label=active_col,
382
+ # disabled_embeddings=disabled_embeddings,
383
+ # **hparams,
384
+ # )
385
+ # stats.update(metrics)
386
+ # report.append(stats.copy())
387
+ # del model
388
+ # del trainer
389
+
390
+ # # At each fold, train and evaluate sklearn models
391
+ # if run_sklearn:
392
+ # for model_type in ['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting']:
393
+ # logging.info(f'Starting sklearn model {model_type} training on fold {k} with split type {split_type} and test split {test_split}.')
394
+ # # Train and evaluate sklearn models
395
+ # model, metrics = pdp.hyperparameter_tuning_and_training_sklearn(
396
+ # protein2embedding=protein2embedding,
397
+ # cell2embedding=cell2embedding,
398
+ # smiles2fp=smiles2fp,
399
+ # train_df=train_df,
400
+ # val_df=val_df,
401
+ # test_df=test_df,
402
+ # model_type=model_type,
403
+ # active_label=active_col,
404
+ # n_trials=n_trials,
405
+ # study_filename=f'../reports/study_{active_name}_{split_type}_fold_{k}_test_split_{test_split}_{model_type.lower()}.pkl',
406
+ # )
407
+ # hparams = {p.replace('hparam_', ''): v for p, v in stats.items() if p.startswith('hparam_')}
408
+ # stats['model_type'] = model_type
409
+ # stats.update(metrics)
410
+ # report.append(stats.copy())
411
+
412
+ # # Save the report at the end of each split type
413
+ # report_df = pd.DataFrame(report)
414
+ # report_df.to_csv(
415
+ # f'../reports/cv_report_hparam_search_{cv_n_splits}-splits_{active_name}_test_split_{test_split}{"_sklearn" if run_sklearn else ""}.csv',
416
+ # index=False,
417
+ # )
418
 
419
 
420
  if __name__ == '__main__':
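
Once the experiments have run, each split type leaves four CSV reports under `../reports/`. A minimal sketch of collecting them for analysis, assuming the `report_test_*.csv` naming used above and the `test_acc` / `test_roc_auc` columns produced by the renamed metrics dictionary:

```python
import glob

import pandas as pd

# Gather all test reports written by run_experiments.py (one per experiment).
paths = sorted(glob.glob('../reports/report_test_*.csv'))
all_tests = pd.concat((pd.read_csv(p) for p in paths), ignore_index=True)

# Every report carries a 'split_type' column, so the retrained best models can
# be compared across split strategies directly.
print(all_tests.groupby('split_type')[['test_acc', 'test_roc_auc']].mean())
```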