ribesstefano committed
Commit 8aec0bb · Parent(s): 1ee75b1

Started working on cell line one-hot encoding experiments

protac_degradation_predictor/protac_dataset.py CHANGED
@@ -1,5 +1,7 @@
 from typing import Literal, List, Tuple, Optional, Dict
 from collections import defaultdict
+import random
+import logging
 
 from .data_utils import (
     get_fingerprint,
@@ -24,15 +26,16 @@ class PROTAC_Dataset(Dataset):
     def __init__(
         self,
         protac_df: pd.DataFrame,
-        protein2embedding: Dict,
-        cell2embedding: Dict,
-        smiles2fp: Dict,
+        protein2embedding: Dict[str, np.ndarray],
+        cell2embedding: Dict[str, np.ndarray],
+        smiles2fp: Dict[str, np.ndarray],
         use_smote: bool = False,
         oversampler: Optional[SMOTE | ADASYN] = None,
         active_label: str = 'Active',
         disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
         scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
         use_single_scaler: Optional[bool] = None,
+        shuffle_embedding_prob: float = 0.0,
     ):
         """ Initialize the PROTAC dataset
 
@@ -47,6 +50,7 @@ class PROTAC_Dataset(Dataset):
             disabled_embeddings (list): The list of embeddings to disable, i.e., return a zero vector
             scaler (StandardScaler | dict): The scaler to use for the embeddings
             use_single_scaler (bool): Whether to use a single scaler for all features
+            shuffle_embedding_prob (float): The probability of shuffling the embeddings. Used for testing whether embeddings act as "barcodes". Defaults to 0.0, i.e., no shuffling.
         """
         # Filter out examples with NaN in active_label column
         self.data = protac_df  # [~protac_df[active_label].isna()]
@@ -84,6 +88,23 @@ class PROTAC_Dataset(Dataset):
         self.oversampler = oversampler
         if self.use_smote:
             self.apply_smote()
+
+        # Always set the attribute, so that __getitem__ can read it safely
+        self.shuffle_embedding_prob = shuffle_embedding_prob
+        if shuffle_embedding_prob > 0.0:
+            # Set random seed for reproducible shuffling
+            random.seed(42)
+            if self.protein_emb_dim != self.cell_emb_dim:
+                logging.warning('Protein and cell embeddings have different dimensions. Shuffling will be on POI and E3 embeddings only.')
+
+    def get_smiles_emb_dim(self):
+        return self.smiles_emb_dim
+
+    def get_protein_emb_dim(self):
+        return self.protein_emb_dim
+
+    def get_cell_emb_dim(self):
+        return self.cell_emb_dim
 
     def apply_smote(self):
         # Prepare the dataset for SMOTE
@@ -269,6 +290,17 @@ class PROTAC_Dataset(Dataset):
         else:
             cell_emb = self.data['Cell Line Identifier'].iloc[idx]
 
+        # Shuffle the embeddings if the probability is met
+        if random.random() < self.shuffle_embedding_prob:
+            if self.protein_emb_dim == self.cell_emb_dim:
+                # Randomly shuffle the embeddings for POI, E3, and cell
+                embeddings = np.vstack([poi_emb, e3_emb, cell_emb])
+                np.random.shuffle(embeddings)
+                poi_emb, e3_emb, cell_emb = embeddings
+            else:
+                # Swap POI and E3 embeddings only, because of different dimensions
+                poi_emb, e3_emb = e3_emb, poi_emb
+
         elem = {
             'smiles_emb': smiles_emb,
             'poi_emb': poi_emb,
@@ -293,6 +325,7 @@ def get_datasets(
         scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
         use_single_scaler: Optional[bool] = None,
         apply_scaling: bool = False,
+        shuffle_embedding_prob: float = 0.0,
 ) -> Tuple[PROTAC_Dataset, PROTAC_Dataset, Optional[PROTAC_Dataset]]:
     """ Get the datasets for training the PROTAC model.
 
@@ -323,6 +356,7 @@ def get_datasets(
         disabled_embeddings=disabled_embeddings,
         scaler=scaler,
         use_single_scaler=use_single_scaler,
+        shuffle_embedding_prob=shuffle_embedding_prob,
     )
     val_ds = PROTAC_Dataset(
         val_df,
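
The shuffling hook introduced above is a control experiment: if validation accuracy survives randomly permuting the POI, E3, and cell embeddings across their slots, the model is likely memorizing embeddings as sample "barcodes" rather than using their content. A minimal standalone sketch of the mechanism in __getitem__, with dummy embeddings (names and shapes are illustrative only):

import random
import numpy as np

random.seed(42)

def maybe_shuffle(poi_emb, e3_emb, cell_emb, shuffle_embedding_prob=0.5):
    # Mirrors PROTAC_Dataset.__getitem__: permute with the given probability
    if random.random() < shuffle_embedding_prob:
        if poi_emb.shape == cell_emb.shape:
            # All dimensions match: permute POI, E3, and cell embeddings
            embeddings = np.vstack([poi_emb, e3_emb, cell_emb])
            np.random.shuffle(embeddings)  # shuffles rows in place
            poi_emb, e3_emb, cell_emb = embeddings
        else:
            # Mismatched dimensions: only POI and E3 can be swapped
            poi_emb, e3_emb = e3_emb, poi_emb
    return poi_emb, e3_emb, cell_emb

poi, e3, cell = (np.random.rand(4) for _ in range(3))
print(maybe_shuffle(poi, e3, cell))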
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -23,7 +23,8 @@ from torchmetrics import (
     MetricCollection,
 )
 from imblearn.over_sampling import SMOTE
-from sklearn.preprocessing import StandardScaler
+from sklearn.preprocessing import StandardScaler, OneHotEncoder
+from sklearn.feature_extraction.text import CountVectorizer
 
 
 class PROTAC_Predictor(nn.Module):
@@ -402,9 +403,9 @@ class PROTAC_Model(pl.LightningModule):
 
 # TODO: Use some sort of **kwargs to pass all the parameters to the model...
 def train_model(
-        protein2embedding: Dict,
-        cell2embedding: Dict,
-        smiles2fp: Dict,
+        protein2embedding: Dict[str, np.ndarray],
+        cell2embedding: Dict[str, np.ndarray],
+        smiles2fp: Dict[str, np.ndarray],
         train_df: pd.DataFrame,
         val_df: pd.DataFrame,
         test_df: Optional[pd.DataFrame] = None,
@@ -414,10 +415,6 @@ def train_model(
         dropout: float = 0.2,
         max_epochs: int = 50,
         use_batch_norm: bool = False,
-        smiles_emb_dim: int = config.fingerprint_size,
-        poi_emb_dim: int = config.protein_embedding_size,
-        e3_emb_dim: int = config.protein_embedding_size,
-        cell_emb_dim: int = config.cell_embedding_size,
         join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
         smote_k_neighbors: int = 5,
         use_smote: bool = True,
@@ -431,29 +428,66 @@ def train_model(
         checkpoint_model_name: str = 'protac',
         disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
         return_predictions: bool = False,
+        shuffle_embedding_prob: float = 0.0,
+        use_cells_one_hot: bool = False,
+        use_amino_acid_count: bool = False,
 ) -> tuple:
     """ Train a PROTAC model using the given datasets and hyperparameters.
 
     Args:
-        protein2embedding (dict): Dictionary of protein embeddings.
-        cell2embedding (dict): Dictionary of cell line embeddings.
-        smiles2fp (dict): Dictionary of SMILES to fingerprint.
-        train_df (pd.DataFrame): The training set. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>.
-        val_df (pd.DataFrame): The validation set. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>.
-        test_df (pd.DataFrame): The test set. If provided, the returned metrics will include test performance. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>.
-        hidden_dim (int): The hidden dimension of the model.
-        batch_size (int): The batch size.
-        learning_rate (float): The learning rate.
-        max_epochs (int): The maximum number of epochs.
-        smiles_emb_dim (int): The dimension of the SMILES embeddings.
-        smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
-        fast_dev_run (bool): Whether to run a fast development run.
-        disabled_embeddings (list): The list of disabled embeddings.
-        return_predictions (bool): Whether to return the predictions after the model, trainer, and metrics.
+        protein2embedding (dict): A dictionary mapping protein identifiers to embeddings.
+        cell2embedding (dict): A dictionary mapping cell line identifiers to embeddings.
+        smiles2fp (dict): A dictionary mapping SMILES strings to fingerprints.
+        train_df (pd.DataFrame): The training dataframe.
+        val_df (pd.DataFrame): The validation dataframe.
+        test_df (Optional[pd.DataFrame]): The test dataframe.
+        hidden_dim (int): The hidden dimension of the model.
+        batch_size (int): The batch size.
+        learning_rate (float): The learning rate.
+        dropout (float): The dropout rate.
+        max_epochs (int): The maximum number of epochs.
+        use_batch_norm (bool): Whether to use batch normalization.
+        join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings.
+        smote_k_neighbors (int): The number of neighbors to use in SMOTE.
+        use_smote (bool): Whether to use SMOTE.
+        apply_scaling (bool): Whether to apply scaling to the embeddings.
+        active_label (str): The name of the active label. Default: 'Active'.
+        fast_dev_run (bool): Whether to run a fast development run (see the PyTorch Lightning documentation).
+        use_logger (bool): Whether to use a logger.
+        logger_save_dir (str): The directory to save the logs.
+        logger_name (str): The name of the logger.
+        enable_checkpointing (bool): Whether to enable checkpointing.
+        checkpoint_model_name (str): The name of the model for checkpointing.
+        disabled_embeddings (list): The list of disabled embeddings. Can contain 'poi', 'e3', 'cell', 'smiles'.
+        return_predictions (bool): Whether to return predictions on the validation and test sets.
+        shuffle_embedding_prob (float): The probability of shuffling the embeddings. Used for testing whether embeddings act as "barcodes". Defaults to 0.0, i.e., no shuffling.
+        use_cells_one_hot (bool): Whether to replace the cell line embeddings with one-hot encodings.
+        use_amino_acid_count (bool): Whether to replace the protein embeddings with character count vectors.
 
     Returns:
         tuple: The trained model, the trainer, and the metrics over the validation and test sets.
     """
+    if use_cells_one_hot:
+        # Get one-hot encoded embeddings for cell lines
+        # NOTE: OneHotEncoder expects a 2D array, hence the reshape
+        onehotenc = OneHotEncoder(sparse_output=False)
+        cell_embeddings = onehotenc.fit_transform(
+            np.array(list(cell2embedding.keys())).reshape(-1, 1)
+        )
+        cell2embedding = {k: v for k, v in zip(cell2embedding.keys(), cell_embeddings)}
+
+    if use_amino_acid_count:
+        # Get count-vectorized embeddings for proteins
+        # NOTE: Check that protein2embedding is keyed by strings
+        if not all(isinstance(k, str) for k in protein2embedding.keys()):
+            raise ValueError("All keys in `protein2embedding` must be strings.")
+        countvec = CountVectorizer(ngram_range=(1, 1), analyzer='char')
+        # NOTE: fit_transform returns a sparse matrix, densify it before use
+        protein_embeddings = countvec.fit_transform(
+            list(protein2embedding.keys())
+        ).toarray()
+        protein2embedding = {k: v for k, v in zip(protein2embedding.keys(), protein_embeddings)}
+
     train_ds, val_ds, test_ds = get_datasets(
         train_df,
         val_df,
@@ -465,7 +499,14 @@ def train_model(
         smote_k_neighbors=smote_k_neighbors,
         active_label=active_label,
         disabled_embeddings=disabled_embeddings,
+        shuffle_embedding_prob=shuffle_embedding_prob,
     )
+    # NOTE: The embedding dimensions should already match across all sets
+    smiles_emb_dim = train_ds.get_smiles_emb_dim()
+    poi_emb_dim = train_ds.get_protein_emb_dim()
+    e3_emb_dim = train_ds.get_protein_emb_dim()
+    cell_emb_dim = train_ds.get_cell_emb_dim()
+
     loggers = [
         pl.loggers.TensorBoardLogger(
             save_dir=logger_save_dir,
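
For reference, a self-contained sketch of the two alternative encodings that train_model can now swap in for the learned embeddings; the cell line identifiers and amino acid strings below are invented for illustration:

import numpy as np
from sklearn.preprocessing import OneHotEncoder
from sklearn.feature_extraction.text import CountVectorizer

# One-hot encode cell line identifiers (OneHotEncoder requires 2D input)
cell_lines = ['HEK293T', 'HeLa', 'MCF-7']  # hypothetical identifiers
onehot = OneHotEncoder(sparse_output=False).fit_transform(
    np.array(cell_lines).reshape(-1, 1)
)
cell2embedding = dict(zip(cell_lines, onehot))
print(cell2embedding['HeLa'])  # [0. 1. 0.]

# Character 1-gram counts as crude protein features; fit_transform returns
# a sparse matrix, so densify it with .toarray() before use
sequences = ['MKTAYIAKQR', 'MSDNE']  # hypothetical amino acid strings
counts = CountVectorizer(ngram_range=(1, 1), analyzer='char').fit_transform(sequences).toarray()
protein2embedding = dict(zip(sequences, counts))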
src/run_experiments.py CHANGED
@@ -238,10 +238,15 @@ def main(
     """ Train a PROTAC model using the given datasets and hyperparameters.
 
     Args:
-        use_ored_activity (bool): Whether to use the 'Active - OR' column.
-        n_trials (int): The number of hyperparameter optimization trials.
-        n_splits (int): The number of cross-validation splits.
+        active_col (str): The column containing the active/inactive information. Must be in the format 'Active (Dmax N, pDC50 M)'.
+        n_trials (int): The number of hyperparameter tuning trials to run.
         fast_dev_run (bool): Whether to run a fast development run.
+        test_split (float): The percentage of the active PROTACs to use as the test set.
+        cv_n_splits (int): The number of cross-validation splits to use.
+        max_epochs (int): The maximum number of epochs to train the model.
+        run_sklearn (bool): Whether to run sklearn models.
+        force_study (bool): Whether to force the creation of a new Optuna study.
+        experiments (str): The type of experiments to run.
     """
     pl.seed_everything(42)
 
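The 'Active (Dmax N, pDC50 M)' format documented above is also what the experiment scripts sanitize into study and report file names:

# Mirrors the active_name sanitization used in the experiment scripts
active_col = 'Active (Dmax 0.6, pDC50 6.0)'
active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
print(active_name)  # Active_Dmax_0.6_pDC50_6.0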
 
src/run_experiments_cells_onehot.py ADDED
@@ -0,0 +1,137 @@
+import os
+import sys
+from collections import defaultdict
+import warnings
+import logging
+from typing import Literal
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+import protac_degradation_predictor as pdp
+
+import pytorch_lightning as pl
+from rdkit import Chem
+from rdkit.Chem import AllChem
+from rdkit import DataStructs
+from jsonargparse import CLI
+import pandas as pd
+from tqdm import tqdm
+import numpy as np
+from sklearn.preprocessing import OrdinalEncoder
+from sklearn.model_selection import (
+    StratifiedKFold,
+    StratifiedGroupKFold,
+)
+
+# Ignore UserWarning from Matplotlib
+warnings.filterwarnings("ignore", ".*FixedLocator.*")
+# Ignore UserWarning from PyTorch Lightning
+warnings.filterwarnings("ignore", ".*does not have many workers.*")
+
+
+root = logging.getLogger()
+root.setLevel(logging.DEBUG)
+
+handler = logging.StreamHandler(sys.stdout)
+handler.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+root.addHandler(handler)
+
+def main(
+    active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
+    n_trials: int = 100,
+    fast_dev_run: bool = False,
+    test_split: float = 0.1,
+    cv_n_splits: int = 5,
+    max_epochs: int = 100,
+    force_study: bool = False,
+    experiments: str | Literal['all', 'standard', 'e3_ligase', 'similarity', 'target'] = 'all',
+):
+    """ Run experiments with the cell line one-hot encoding model.
+
+    Args:
+        active_col (str): Name of the column containing the active values.
+        n_trials (int): Number of hyperparameter optimization trials.
+        fast_dev_run (bool): Whether to run a fast development run.
+        test_split (float): Percentage of data to use for testing.
+        cv_n_splits (int): Number of cross-validation splits.
+        max_epochs (int): Maximum number of epochs to train the model.
+        force_study (bool): Whether to force the creation of a new study.
+        experiments (str): Type of experiments to run. Options are 'all', 'standard', 'e3_ligase', 'similarity', 'target'.
+    """
+
+    # Make directory ../reports if it does not exist
+    if not os.path.exists('../reports'):
+        os.makedirs('../reports')
+
+    # Load embedding dictionaries
+    protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
+    cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
+
+    studies_dir = '../data/studies'
+    train_val_perc = f'{int((1 - test_split) * 100)}'
+    test_perc = f'{int(test_split * 100)}'
+    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
+
+    if experiments == 'all':
+        experiments = ['standard', 'similarity', 'target']
+
+    # Cross-Validation Training
+    reports = defaultdict(list)
+    for split_type in experiments:
+
+        train_val_filename = f'{split_type}_train_val_{train_val_perc}split_{active_name}.csv'
+        test_filename = f'{split_type}_test_{test_perc}split_{active_name}.csv'
+
+        train_val_df = pd.read_csv(os.path.join(studies_dir, train_val_filename))
+        test_df = pd.read_csv(os.path.join(studies_dir, test_filename))
+
+        # Get SMILES and precompute fingerprints dictionary
+        unique_smiles = pd.concat([train_val_df, test_df])['Smiles'].unique().tolist()
+        smiles2fp = {s: np.array(pdp.get_fingerprint(s)) for s in unique_smiles}
+
+        # Get the CV object
+        if split_type == 'standard':
+            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+            group = None
+        elif split_type == 'e3_ligase':
+            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+            group = train_val_df['E3 Group'].to_numpy()
+        elif split_type == 'similarity':
+            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+            group = train_val_df['Tanimoto Group'].to_numpy()
+        elif split_type == 'target':
+            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+            group = train_val_df['Uniprot Group'].to_numpy()
+
+        # Start the experiment
+        experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
+        optuna_reports = pdp.hyperparameter_tuning_and_training(
+            protein2embedding=protein2embedding,
+            cell2embedding=cell2embedding,
+            smiles2fp=smiles2fp,
+            train_val_df=train_val_df,
+            test_df=test_df,
+            kf=kf,
+            groups=group,
+            split_type=split_type,
+            n_models_for_test=3,
+            fast_dev_run=fast_dev_run,
+            n_trials=n_trials,
+            max_epochs=max_epochs,
+            logger_save_dir='../logs',
+            logger_name=f'logs_{experiment_name}',
+            active_label=active_col,
+            study_filename=f'../reports/study_cellsonehot_{experiment_name}.pkl',
+            force_study=force_study,
+        )
+
+        # Save the reports to file
+        for report_name, report in optuna_reports.items():
+            report.to_csv(f'../reports/cellsonehot_{report_name}_{experiment_name}.csv', index=False)
+            reports[report_name].append(report.copy())
+
+
+if __name__ == '__main__':
+    cli = CLI(main)
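
Because the script exposes main through jsonargparse's CLI, every keyword argument doubles as a command-line flag (e.g. --n_trials 50 --experiments standard). A hypothetical direct call, assuming src/ is the working directory and the study CSVs exist under ../data/studies:

from run_experiments_cells_onehot import main

main(
    active_col='Active (Dmax 0.6, pDC50 6.0)',  # example values only
    n_trials=50,
    cv_n_splits=5,
    experiments='standard',
)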
src/{run_xgboost_experiments.py → run_experiments_xgboost.py} RENAMED
File without changes