Commit ccc40da
Parent(s): e1370eb

Implemented cell lines one-hot and amino acid sequence count experiments
protac_degradation_predictor/optuna_utils.py
CHANGED
@@ -117,6 +117,8 @@ def pytorch_model_objective(
     logger_save_dir: str = 'logs',
     logger_name: str = 'cv_model',
     enable_checkpointing: bool = False,
+    use_cells_one_hot: bool = False,
+    use_amino_acid_count: bool = False,
 ) -> float:
     """ Objective function for hyperparameter optimization.

@@ -139,6 +141,8 @@ def pytorch_model_objective(
     learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
     smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
     use_smote = trial.suggest_categorical('use_smote', [True, False])
+    if use_cells_one_hot or use_amino_acid_count:
+        use_smote = False
     apply_scaling = True  # trial.suggest_categorical('apply_scaling', [True, False])
     dropout = trial.suggest_float('dropout', *dropout_options)
     use_batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])

@@ -252,6 +256,8 @@ def hyperparameter_tuning_and_training(
     max_epochs: int = 100,
     study_filename: Optional[str] = None,
     force_study: bool = False,
+    use_cells_one_hot: bool = False,
+    use_amino_acid_count: bool = False,
 ) -> tuple:
     """ Hyperparameter tuning and training of a PROTAC model.

@@ -263,7 +269,7 @@ def hyperparameter_tuning_and_training(
         test_df (pd.DataFrame): The test set.
         kf (StratifiedKFold | StratifiedGroupKFold): The KFold object.
         groups (np.array): The groups for the StratifiedGroupKFold.
-        split_type (str): The split type.
+        split_type (str): The split type of the current study. Used for reporting.
         n_models_for_test (int): The number of models to train for the test set.
         fast_dev_run (bool): Whether to run a fast development run.
         n_trials (int): The number of trials for the hyperparameter search.

@@ -322,6 +328,8 @@ def hyperparameter_tuning_and_training(
             active_label=active_label,
             max_epochs=max_epochs,
             disabled_embeddings=[],
+            use_cells_one_hot=use_cells_one_hot,
+            use_amino_acid_count=use_amino_acid_count,
         ),
         n_trials=n_trials,
     )

@@ -354,6 +362,8 @@ def hyperparameter_tuning_and_training(
         logger_save_dir=logger_save_dir,
         logger_name=f'{logger_name}_{split_type}_cv_model',
         enable_checkpointing=True,
+        use_cells_one_hot=use_cells_one_hot,
+        use_amino_acid_count=use_amino_acid_count,
     )

     # Retrain N models with the best hyperparameters (measure model uncertainty)
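The two new flags are threaded from `hyperparameter_tuning_and_training` down into the Optuna objective, where they force `use_smote = False`: SMOTE synthesizes minority-class samples by interpolating between neighbors, which would yield fractional values for one-hot and count features. A minimal, runnable sketch of the same override pattern; the dummy score and study setup below are illustrative, not the project's actual objective:

import optuna

def objective(trial: optuna.Trial,
              use_cells_one_hot: bool = False,
              use_amino_acid_count: bool = False) -> float:
    # Sample the oversampling choice as usual...
    use_smote = trial.suggest_categorical('use_smote', [True, False])
    # ...then override it when the run uses one-hot or count features,
    # since interpolated "one-hot" vectors would no longer be one-hot.
    if use_cells_one_hot or use_amino_acid_count:
        use_smote = False
    # Dummy score standing in for cross-validation performance.
    return 1.0 if use_smote else 0.5

study = optuna.create_study(direction='maximize')
study.optimize(lambda t: objective(t, use_cells_one_hot=True), n_trials=5)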
protac_degradation_predictor/pytorch_models.py
CHANGED
@@ -23,8 +23,7 @@ from torchmetrics import (
     MetricCollection,
 )
 from imblearn.over_sampling import SMOTE
-from sklearn.preprocessing import StandardScaler
-from sklearn.feature_extraction.text import CountVectorizer
+from sklearn.preprocessing import StandardScaler


 class PROTAC_Predictor(nn.Module):

@@ -429,8 +428,6 @@ def train_model(
     disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
     return_predictions: bool = False,
     shuffle_embedding_prob: float = 0.0,
-    use_cells_one_hot: bool = False,
-    use_amino_acid_count: bool = False,
 ) -> tuple:
     """ Train a PROTAC model using the given datasets and hyperparameters.

@@ -464,25 +461,6 @@ def train_model(
     Returns:
         tuple: The trained model, the trainer, and the metrics over the validation and test sets.
     """
-    if use_cells_one_hot:
-        # Get one-hot encoded embeddings for cell lines
-        onehotenc = OneHotEncoder(sparse_output=False)
-        cell_embeddings = onehotenc.fit_transform(
-            np.array(list(cell2embedding.keys()))
-        )
-        cell2embedding = {k: v for k, v in zip(cell2embedding.keys(), cell_embeddings)}
-
-    if use_amino_acid_count:
-        # Get count vectorized embeddings for proteins
-        # NOTE: Check that the protein2embedding is a dictionary of strings
-        if not all(isinstance(k, str) for k in protein2embedding.keys()):
-            raise ValueError("All keys in `protein2embedding` must be strings.")
-        countvec = CountVectorizer(ngram_range=(1,1), analyzer='char')
-        protein_embeddings = countvec.fit_transform(
-            list(protein2embedding.keys())
-        )
-        protein2embedding = {k: v for k, v in zip(protein2embedding.keys(), protein_embeddings)}
-
     train_ds, val_ds, test_ds = get_datasets(
         train_df,
         val_df,
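With this hunk, `train_model` no longer special-cases the two encodings: the in-function re-encoding is deleted and the encoding choice moves once, up front, into the experiment scripts below. The model then simply consumes whatever vectors `protein2embedding` and `cell2embedding` map to. A small sketch of the design idea with hypothetical dictionaries; that the model infers its input width from the dictionary values is an assumption here, not a claim about `PROTAC_Predictor`:

import numpy as np

# Hypothetical embedding dictionaries: a learned 768-d embedding and a
# 2-d one-hot encoding for the same two cell lines.
learned = {'HeLa': np.random.rand(768), 'HEK293': np.random.rand(768)}
one_hot = {'HeLa': np.array([1.0, 0.0]), 'HEK293': np.array([0.0, 1.0])}

def embedding_dim(cell2embedding: dict) -> int:
    # A consumer that reads its input width off the dictionary values
    # needs no flag to distinguish the two encodings.
    return len(next(iter(cell2embedding.values())))

print(embedding_dim(learned))  # 768
print(embedding_dim(one_hot))  # 2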
src/run_experiments_aminoacid_counts.py
ADDED
@@ -0,0 +1,156 @@
+import os
+import sys
+from collections import defaultdict
+import warnings
+import logging
+from typing import Literal
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+import protac_degradation_predictor as pdp
+
+import pytorch_lightning as pl
+from rdkit import Chem
+from rdkit.Chem import AllChem
+from rdkit import DataStructs
+from jsonargparse import CLI
+import pandas as pd
+from tqdm import tqdm
+import numpy as np
+from sklearn.model_selection import (
+    StratifiedKFold,
+    StratifiedGroupKFold,
+)
+from sklearn.feature_extraction.text import CountVectorizer
+
+# Ignore UserWarning from Matplotlib
+warnings.filterwarnings("ignore", ".*FixedLocator*")
+# Ignore UserWarning from PyTorch Lightning
+warnings.filterwarnings("ignore", ".*does not have many workers.*")
+
+root = logging.getLogger()
+root.setLevel(logging.DEBUG)
+
+handler = logging.StreamHandler(sys.stdout)
+handler.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+root.addHandler(handler)
+
+def main(
+    active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
+    n_trials: int = 100,
+    fast_dev_run: bool = False,
+    test_split: float = 0.1,
+    cv_n_splits: int = 5,
+    max_epochs: int = 100,
+    force_study: bool = False,
+    experiments: str | Literal['all', 'standard', 'e3_ligase', 'similarity', 'target'] = 'all',
+):
+    """ Run experiments with the amino acid sequence count model.
+
+    Args:
+        active_col (str): Name of the column containing the active values.
+        n_trials (int): Number of hyperparameter optimization trials.
+        fast_dev_run (bool): Whether to run a fast development run.
+        test_split (float): Percentage of data to use for testing.
+        cv_n_splits (int): Number of cross-validation splits.
+        max_epochs (int): Maximum number of epochs to train the model.
+        force_study (bool): Whether to force the creation of a new study.
+        experiments (str): Type of experiments to run. Options are 'all', 'standard', 'e3_ligase', 'similarity', 'target'.
+    """
+
+    # Make directory ../reports if it does not exist
+    if not os.path.exists('../reports'):
+        os.makedirs('../reports')
+
+    # Load embedding dictionaries
+    protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
+    cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
+
+    # Create a new protein2embedding dictionary with amino acid sequences
+    protac_df = pdp.load_curated_dataset()
+    # Create the dictionary mapping 'Uniprot' to 'POI Sequence'
+    protein2embedding = protac_df.set_index('Uniprot')['POI Sequence'].to_dict()
+    # Create the dictionary mapping 'E3 Ligase Uniprot' to 'E3 Ligase Sequence'
+    e32seq = protac_df.set_index('E3 Ligase Uniprot')['E3 Ligase Sequence'].to_dict()
+    # Merge the two dictionaries into a new protein2embedding dictionary
+    protein2embedding.update(e32seq)
+
+    # Get count vectorized embeddings for the protein sequences
+    # NOTE: Check that the protein2embedding values are amino acid sequence strings
+    if not all(isinstance(v, str) for v in protein2embedding.values()):
+        raise ValueError("All values in `protein2embedding` must be sequence strings.")
+    countvec = CountVectorizer(ngram_range=(1, 1), analyzer='char')
+    protein_embeddings = countvec.fit_transform(
+        list(protein2embedding.values())
+    ).toarray()
+    protein2embedding = {k: v for k, v in zip(protein2embedding.keys(), protein_embeddings)}
+
+    studies_dir = '../data/studies'
+    train_val_perc = f'{int((1 - test_split) * 100)}'
+    test_perc = f'{int(test_split * 100)}'
+    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
+
+    if experiments == 'all':
+        experiments = ['standard', 'similarity', 'target']
+
+    # Cross-Validation Training
+    reports = defaultdict(list)
+    for split_type in experiments:
+
+        train_val_filename = f'{split_type}_train_val_{train_val_perc}split_{active_name}.csv'
+        test_filename = f'{split_type}_test_{test_perc}split_{active_name}.csv'
+
+        train_val_df = pd.read_csv(os.path.join(studies_dir, train_val_filename))
+        test_df = pd.read_csv(os.path.join(studies_dir, test_filename))
+
+        # Get SMILES and precompute fingerprints dictionary
+        unique_smiles = pd.concat([train_val_df, test_df])['Smiles'].unique().tolist()
+        smiles2fp = {s: np.array(pdp.get_fingerprint(s)) for s in unique_smiles}
+
+        # Get the CV object
+        if split_type == 'standard':
+            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+            group = None
+        elif split_type == 'e3_ligase':
+            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+            group = train_val_df['E3 Group'].to_numpy()
+        elif split_type == 'similarity':
+            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+            group = train_val_df['Tanimoto Group'].to_numpy()
+        elif split_type == 'target':
+            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+            group = train_val_df['Uniprot Group'].to_numpy()
+
+        # Start the experiment
+        experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
+        optuna_reports = pdp.hyperparameter_tuning_and_training(
+            protein2embedding=protein2embedding,
+            cell2embedding=cell2embedding,
+            smiles2fp=smiles2fp,
+            train_val_df=train_val_df,
+            test_df=test_df,
+            kf=kf,
+            groups=group,
+            split_type=split_type,
+            n_models_for_test=3,
+            fast_dev_run=fast_dev_run,
+            n_trials=n_trials,
+            max_epochs=max_epochs,
+            logger_save_dir='../logs',
+            logger_name=f'logs_{experiment_name}',
+            active_label=active_col,
+            study_filename=f'../reports/study_aminoacidcnt_{experiment_name}.pkl',
+            force_study=force_study,
+            use_amino_acid_count=True,
+        )
+
+        # Save the reports to file
+        for report_name, report in optuna_reports.items():
+            report.to_csv(f'../reports/aminoacidcnt_{report_name}_{experiment_name}.csv', index=False)
+            reports[report_name].append(report.copy())
+
+
+if __name__ == '__main__':
+    cli = CLI(main)
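With `analyzer='char'` and unigrams only, `CountVectorizer` reduces to per-letter counting, which over a protein sequence is exactly an amino acid composition vector. A toy demonstration of this reduction; the two sequences below are made up, not real proteins:

from sklearn.feature_extraction.text import CountVectorizer

seqs = ['MKTAYIAKQR', 'MSDNGPQNQR']  # toy sequences
countvec = CountVectorizer(ngram_range=(1, 1), analyzer='char')
counts = countvec.fit_transform(seqs).toarray()

# One feature per distinct (lowercased) letter seen in the corpus;
# each row holds that letter's count in the corresponding sequence.
print(countvec.get_feature_names_out())
print(counts[0])  # amino acid counts for seqs[0]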
src/run_experiments_cells_onehot.py
CHANGED
@@ -22,6 +22,7 @@ from sklearn.model_selection import (
     StratifiedKFold,
     StratifiedGroupKFold,
 )
+from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder

 # Ignore UserWarning from Matplotlib
 warnings.filterwarnings("ignore", ".*FixedLocator*")

@@ -69,6 +70,13 @@ def main(
     protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
     cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')

+    # Get one-hot encoded embeddings for cell lines
+    onehotenc = OneHotEncoder(sparse_output=False)
+    cell_embeddings = onehotenc.fit_transform(
+        np.array(list(cell2embedding.keys())).reshape(-1, 1)
+    )
+    cell2embedding = {k: v for k, v in zip(cell2embedding.keys(), cell_embeddings)}
+
     studies_dir = '../data/studies'
     train_val_perc = f'{int((1 - test_split) * 100)}'
     test_perc = f'{int(test_split * 100)}'

@@ -125,6 +133,7 @@ def main(
         active_label=active_col,
         study_filename=f'../reports/study_cellsonehot_{experiment_name}.pkl',
         force_study=force_study,
+        use_cells_one_hot=True,
     )

     # Save the reports to file
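The `.reshape(-1, 1)` added here matters: scikit-learn encoders expect a 2-D `(n_samples, n_features)` array, so each cell line name becomes one sample with a single categorical feature. A toy run with made-up cell line names:

import numpy as np
from sklearn.preprocessing import OneHotEncoder

cell_lines = ['HeLa', 'HEK293', 'MCF7']  # illustrative names
enc = OneHotEncoder(sparse_output=False)
# reshape(-1, 1): one sample per cell line, one categorical feature.
embeddings = enc.fit_transform(np.array(cell_lines).reshape(-1, 1))
cell2embedding = dict(zip(cell_lines, embeddings))

# Categories are sorted, so the columns are ['HEK293', 'HeLa', 'MCF7'].
print(cell2embedding['HeLa'])  # [0. 1. 0.]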