ribesstefano committed on
Commit
ccc40da
·
1 Parent(s): e1370eb

Implemented cell lines one-hot and amino acid sequence count experiments

Browse files
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -117,6 +117,8 @@ def pytorch_model_objective(
117
  logger_save_dir: str = 'logs',
118
  logger_name: str = 'cv_model',
119
  enable_checkpointing: bool = False,
 
 
120
  ) -> float:
121
  """ Objective function for hyperparameter optimization.
122
 
@@ -139,6 +141,8 @@ def pytorch_model_objective(
139
  learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
140
  smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
141
  use_smote = trial.suggest_categorical('use_smote', [True, False])
 
 
142
  apply_scaling = True # trial.suggest_categorical('apply_scaling', [True, False])
143
  dropout = trial.suggest_float('dropout', *dropout_options)
144
  use_batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])
@@ -252,6 +256,8 @@ def hyperparameter_tuning_and_training(
252
  max_epochs: int = 100,
253
  study_filename: Optional[str] = None,
254
  force_study: bool = False,
 
 
255
  ) -> tuple:
256
  """ Hyperparameter tuning and training of a PROTAC model.
257
 
@@ -263,7 +269,7 @@ def hyperparameter_tuning_and_training(
263
  test_df (pd.DataFrame): The test set.
264
  kf (StratifiedKFold | StratifiedGroupKFold): The KFold object.
265
  groups (np.array): The groups for the StratifiedGroupKFold.
266
- split_type (str): The split type.
267
  n_models_for_test (int): The number of models to train for the test set.
268
  fast_dev_run (bool): Whether to run a fast development run.
269
  n_trials (int): The number of trials for the hyperparameter search.
@@ -322,6 +328,8 @@ def hyperparameter_tuning_and_training(
322
  active_label=active_label,
323
  max_epochs=max_epochs,
324
  disabled_embeddings=[],
 
 
325
  ),
326
  n_trials=n_trials,
327
  )
@@ -354,6 +362,8 @@ def hyperparameter_tuning_and_training(
354
  logger_save_dir=logger_save_dir,
355
  logger_name=f'{logger_name}_{split_type}_cv_model',
356
  enable_checkpointing=True,
 
 
357
  )
358
 
359
  # Retrain N models with the best hyperparameters (measure model uncertainty)
 
117
  logger_save_dir: str = 'logs',
118
  logger_name: str = 'cv_model',
119
  enable_checkpointing: bool = False,
120
+ use_cells_one_hot: bool = False,
121
+ use_amino_acid_count: bool = False,
122
  ) -> float:
123
  """ Objective function for hyperparameter optimization.
124
 
 
141
  learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
142
  smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
143
  use_smote = trial.suggest_categorical('use_smote', [True, False])
144
+ if use_cells_one_hot or use_amino_acid_count:
145
+ use_smote = False
146
  apply_scaling = True # trial.suggest_categorical('apply_scaling', [True, False])
147
  dropout = trial.suggest_float('dropout', *dropout_options)
148
  use_batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])
 
256
  max_epochs: int = 100,
257
  study_filename: Optional[str] = None,
258
  force_study: bool = False,
259
+ use_cells_one_hot: bool = False,
260
+ use_amino_acid_count: bool = False,
261
  ) -> tuple:
262
  """ Hyperparameter tuning and training of a PROTAC model.
263
 
 
269
  test_df (pd.DataFrame): The test set.
270
  kf (StratifiedKFold | StratifiedGroupKFold): The KFold object.
271
  groups (np.array): The groups for the StratifiedGroupKFold.
272
+ split_type (str): The split type of the current study. Used for reporting.
273
  n_models_for_test (int): The number of models to train for the test set.
274
  fast_dev_run (bool): Whether to run a fast development run.
275
  n_trials (int): The number of trials for the hyperparameter search.
 
328
  active_label=active_label,
329
  max_epochs=max_epochs,
330
  disabled_embeddings=[],
331
+ use_cells_one_hot=use_cells_one_hot,
332
+ use_amino_acid_count=use_amino_acid_count,
333
  ),
334
  n_trials=n_trials,
335
  )
 
362
  logger_save_dir=logger_save_dir,
363
  logger_name=f'{logger_name}_{split_type}_cv_model',
364
  enable_checkpointing=True,
365
+ use_cells_one_hot=use_cells_one_hot,
366
+ use_amino_acid_count=use_amino_acid_count,
367
  )
368
 
369
  # Retrain N models with the best hyperparameters (measure model uncertainty)
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -23,8 +23,7 @@ from torchmetrics import (
23
  MetricCollection,
24
  )
25
  from imblearn.over_sampling import SMOTE
26
- from sklearn.preprocessing import StandardScaler, OneHotEncoder
27
- from sklearn.feature_extraction.text import CountVectorizer
28
 
29
 
30
  class PROTAC_Predictor(nn.Module):
@@ -429,8 +428,6 @@ def train_model(
429
  disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
430
  return_predictions: bool = False,
431
  shuffle_embedding_prob: float = 0.0,
432
- use_cells_one_hot: bool = False,
433
- use_amino_acid_count: bool = False,
434
  ) -> tuple:
435
  """ Train a PROTAC model using the given datasets and hyperparameters.
436
 
@@ -464,25 +461,6 @@ def train_model(
464
  Returns:
465
  tuple: The trained model, the trainer, and the metrics over the validation and test sets.
466
  """
467
- if use_cells_one_hot:
468
- # Get one-hot encoded embeddings for cell lines
469
- onehotenc = OneHotEncoder(sparse_output=False)
470
- cell_embeddings = onehotenc.fit_transform(
471
- np.array(list(cell2embedding.keys()))
472
- )
473
- cell2embedding = {k: v for k, v in zip(cell2embedding.keys(), cell_embeddings)}
474
-
475
- if use_amino_acid_count:
476
- # Get count vectorized embeddings for proteins
477
- # NOTE: Check that the protein2embedding is a dictionary of strings
478
- if not all(isinstance(k, str) for k in protein2embedding.keys()):
479
- raise ValueError("All keys in `protein2embedding` must be strings.")
480
- countvec = CountVectorizer(ngram_range=(1,1), analyzer='char')
481
- protein_embeddings = countvec.fit_transform(
482
- list(protein2embedding.keys())
483
- )
484
- protein2embedding = {k: v for k, v in zip(protein2embedding.keys(), protein_embeddings)}
485
-
486
  train_ds, val_ds, test_ds = get_datasets(
487
  train_df,
488
  val_df,
 
23
  MetricCollection,
24
  )
25
  from imblearn.over_sampling import SMOTE
26
+ from sklearn.preprocessing import StandardScaler
 
27
 
28
 
29
  class PROTAC_Predictor(nn.Module):
 
428
  disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
429
  return_predictions: bool = False,
430
  shuffle_embedding_prob: float = 0.0,
 
 
431
  ) -> tuple:
432
  """ Train a PROTAC model using the given datasets and hyperparameters.
433
 
 
461
  Returns:
462
  tuple: The trained model, the trainer, and the metrics over the validation and test sets.
463
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  train_ds, val_ds, test_ds = get_datasets(
465
  train_df,
466
  val_df,
src/run_experiments_aminoacid_counts.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from collections import defaultdict
4
+ import warnings
5
+ import logging
6
+ from typing import Literal
7
+
8
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9
+
10
+ import protac_degradation_predictor as pdp
11
+
12
+ import pytorch_lightning as pl
13
+ from rdkit import Chem
14
+ from rdkit.Chem import AllChem
15
+ from rdkit import DataStructs
16
+ from jsonargparse import CLI
17
+ import pandas as pd
18
+ from tqdm import tqdm
19
+ import numpy as np
20
+ from sklearn.model_selection import (
21
+ StratifiedKFold,
22
+ StratifiedGroupKFold,
23
+ )
24
+ from sklearn.feature_extraction.text import CountVectorizer
25
+
26
+ # Ignore UserWarning from Matplotlib
27
+ warnings.filterwarnings("ignore", ".*FixedLocator*")
28
+ # Ignore UserWarning from PyTorch Lightning
29
+ warnings.filterwarnings("ignore", ".*does not have many workers.*")
30
+
31
+ root = logging.getLogger()
32
+ root.setLevel(logging.DEBUG)
33
+
34
+ handler = logging.StreamHandler(sys.stdout)
35
+ handler.setLevel(logging.DEBUG)
36
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
37
+ handler.setFormatter(formatter)
38
+ root.addHandler(handler)
39
+
40
def main(
    active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
    n_trials: int = 100,
    fast_dev_run: bool = False,
    test_split: float = 0.1,
    cv_n_splits: int = 5,
    max_epochs: int = 100,
    force_study: bool = False,
    experiments: str | Literal['all', 'standard', 'e3_ligase', 'similarity', 'target'] = 'all',
):
    """ Run experiments using amino acid count vectors as protein embeddings.

    Each protein (POI and E3 ligase) is represented by the per-character
    counts of its amino acid sequence instead of a pretrained embedding. For
    every requested split type, the train/validation and test CSV files are
    read from `../data/studies`, a hyperparameter search plus final training
    is run, and each returned report is written to `../reports`.

    Args:
        active_col (str): Name of the column containing the active values.
        n_trials (int): Number of hyperparameter optimization trials.
        fast_dev_run (bool): Whether to run a fast development run.
        test_split (float): Percentage of data to use for testing.
        cv_n_splits (int): Number of cross-validation splits.
        max_epochs (int): Maximum number of epochs to train the model.
        force_study (bool): Whether to force the creation of a new study.
        experiments (str): Type of experiments to run. Options are 'all', 'standard', 'e3_ligase', 'similarity', 'target'.
            'all' runs the 'standard', 'similarity' and 'target' splits.

    Raises:
        ValueError: If a protein identifier or sequence is not a string, or
            if an unknown split type is requested.
    """

    # Make directory ../reports if it does not exist
    os.makedirs('../reports', exist_ok=True)

    # Load the cell embedding dictionary. The pretrained protein embeddings
    # are not loaded: they are replaced by the amino acid counts built below.
    cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')

    # Map every Uniprot ID (POIs first, then E3 ligases) to its sequence
    protac_df = pdp.load_curated_dataset()
    uniprot2seq = protac_df.set_index('Uniprot')['POI Sequence'].to_dict()
    e32seq = protac_df.set_index('E3 Ligase Uniprot')['E3 Ligase Sequence'].to_dict()
    uniprot2seq.update(e32seq)

    # Both the identifiers and the sequences must be strings: the former are
    # the lookup keys, the latter are what gets count-vectorized.
    if not all(isinstance(k, str) for k in uniprot2seq):
        raise ValueError("All keys in `protein2embedding` must be strings.")
    if not all(isinstance(seq, str) for seq in uniprot2seq.values()):
        raise ValueError("All sequences in `protein2embedding` must be strings.")

    # Get count vectorized embeddings for proteins. The vectorizer is fitted
    # on the sequences (the dictionary values), not on the Uniprot IDs.
    countvec = CountVectorizer(ngram_range=(1, 1), analyzer='char')
    protein_embeddings = countvec.fit_transform(
        list(uniprot2seq.values())
    ).toarray()
    protein2embedding = dict(zip(uniprot2seq.keys(), protein_embeddings))

    studies_dir = '../data/studies'
    train_val_perc = f'{int((1 - test_split) * 100)}'
    test_perc = f'{int(test_split * 100)}'
    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')

    if experiments == 'all':
        experiments = ['standard', 'similarity', 'target']
    elif isinstance(experiments, str):
        # A single split name must not be iterated character by character
        experiments = [experiments]

    # Cross-Validation Training
    reports = defaultdict(list)
    for split_type in experiments:

        train_val_filename = f'{split_type}_train_val_{train_val_perc}split_{active_name}.csv'
        test_filename = f'{split_type}_test_{test_perc}split_{active_name}.csv'

        train_val_df = pd.read_csv(os.path.join(studies_dir, train_val_filename))
        test_df = pd.read_csv(os.path.join(studies_dir, test_filename))

        # Get SMILES and precompute fingerprints dictionary
        unique_smiles = pd.concat([train_val_df, test_df])['Smiles'].unique().tolist()
        smiles2fp = {s: np.array(pdp.get_fingerprint(s)) for s in unique_smiles}

        # Get the CV object
        if split_type == 'standard':
            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = None
        elif split_type == 'e3_ligase':
            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['E3 Group'].to_numpy()
        elif split_type == 'similarity':
            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['Tanimoto Group'].to_numpy()
        elif split_type == 'target':
            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['Uniprot Group'].to_numpy()
        else:
            # Fail loudly instead of reusing `kf`/`group` from a previous split
            raise ValueError(f"Unknown split type: {split_type!r}")

        # Start the experiment
        experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
        optuna_reports = pdp.hyperparameter_tuning_and_training(
            protein2embedding=protein2embedding,
            cell2embedding=cell2embedding,
            smiles2fp=smiles2fp,
            train_val_df=train_val_df,
            test_df=test_df,
            kf=kf,
            groups=group,
            split_type=split_type,
            n_models_for_test=3,
            fast_dev_run=fast_dev_run,
            n_trials=n_trials,
            max_epochs=max_epochs,
            logger_save_dir='../logs',
            logger_name=f'logs_{experiment_name}',
            active_label=active_col,
            study_filename=f'../reports/study_aminoacidcnt_{experiment_name}.pkl',
            force_study=force_study,
            use_amino_acid_count=True,
        )

        # Save the reports to file
        for report_name, report in optuna_reports.items():
            report.to_csv(f'../reports/aminoacidcnt_{report_name}_{experiment_name}.csv', index=False)
            reports[report_name].append(report.copy())
153
+
154
+
155
+ if __name__ == '__main__':
156
+ cli = CLI(main)
src/run_experiments_cells_onehot.py CHANGED
@@ -22,6 +22,7 @@ from sklearn.model_selection import (
22
  StratifiedKFold,
23
  StratifiedGroupKFold,
24
  )
 
25
 
26
  # Ignore UserWarning from Matplotlib
27
  warnings.filterwarnings("ignore", ".*FixedLocator*")
@@ -69,6 +70,13 @@ def main(
69
  protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
70
  cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
71
 
 
 
 
 
 
 
 
72
  studies_dir = '../data/studies'
73
  train_val_perc = f'{int((1 - test_split) * 100)}'
74
  test_perc = f'{int(test_split * 100)}'
@@ -125,6 +133,7 @@ def main(
125
  active_label=active_col,
126
  study_filename=f'../reports/study_cellsonehot_{experiment_name}.pkl',
127
  force_study=force_study,
 
128
  )
129
 
130
  # Save the reports to file
 
22
  StratifiedKFold,
23
  StratifiedGroupKFold,
24
  )
25
+ from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
26
 
27
  # Ignore UserWarning from Matplotlib
28
  warnings.filterwarnings("ignore", ".*FixedLocator*")
 
70
  protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
71
  cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
72
 
73
+ # Get one-hot encoded embeddings for cell lines
74
+ onehotenc = OneHotEncoder(sparse_output=False)
75
+ cell_embeddings = onehotenc.fit_transform(
76
+ np.array(list(cell2embedding.keys())).reshape(-1, 1)
77
+ )
78
+ cell2embedding = {k: v for k, v in zip(cell2embedding.keys(), cell_embeddings)}
79
+
80
  studies_dir = '../data/studies'
81
  train_val_perc = f'{int((1 - test_split) * 100)}'
82
  test_perc = f'{int(test_split * 100)}'
 
133
  active_label=active_col,
134
  study_filename=f'../reports/study_cellsonehot_{experiment_name}.pkl',
135
  force_study=force_study,
136
+ use_cells_one_hot=True,
137
  )
138
 
139
  # Save the reports to file