Commit ed339ed
Parent(s): 394ed39

Started renaming splits. Added datasets in paper as separate CSV.

Files changed:
- README.md +14 -0
- data/studies/e3_ligase_test_10split_Active_Dmax_0.6_pDC50_6.0.csv +0 -0
- data/studies/e3_ligase_train_val_90split_Active_Dmax_0.6_pDC50_6.0.csv +0 -0
- data/studies/similarity_test_10split_Active_Dmax_0.6_pDC50_6.0.csv +0 -0
- data/studies/similarity_train_val_90split_Active_Dmax_0.6_pDC50_6.0.csv +0 -0
- data/studies/standard_test_10split_Active_Dmax_0.6_pDC50_6.0.csv +0 -0
- data/studies/standard_train_val_90split_Active_Dmax_0.6_pDC50_6.0.csv +0 -0
- data/studies/target_test_10split_Active_Dmax_0.6_pDC50_6.0.csv +0 -0
- data/studies/target_train_val_90split_Active_Dmax_0.6_pDC50_6.0.csv +0 -0
- notebooks/predict_unknown_protacs.ipynb +0 -0
- protac_degradation_predictor/optuna_utils.py +15 -164
- src/README.md +70 -0
- src/get_studies_datasets.py +289 -0
README.md
CHANGED
@@ -1,3 +1,17 @@
+---
+
+title: PROTAC-Degradation-Predictor
+emoji: 🧬
+colorFrom: pink
+colorTo: green
+sdk: gradio
+sdk_version: 4.37.2
+app_file: app.py
+pinned: false
+license: mit
+
+---
+
 ![]()
 <a href="https://colab.research.google.com/github/ribesstefano/PROTAC-Degradation-Predictor/blob/main/notebooks/protac_degradation_predictor_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
 [](https://huggingface.co/spaces/ailab-bio/PROTAC-Degradation-Predictor)
|
data/studies/e3_ligase_test_10split_Active_Dmax_0.6_pDC50_6.0.csv
ADDED
The diff for this file is too large to render. See raw diff.

data/studies/e3_ligase_train_val_90split_Active_Dmax_0.6_pDC50_6.0.csv
ADDED
The diff for this file is too large to render. See raw diff.

data/studies/similarity_test_10split_Active_Dmax_0.6_pDC50_6.0.csv
ADDED
The diff for this file is too large to render. See raw diff.

data/studies/similarity_train_val_90split_Active_Dmax_0.6_pDC50_6.0.csv
ADDED
The diff for this file is too large to render. See raw diff.

data/studies/standard_test_10split_Active_Dmax_0.6_pDC50_6.0.csv
ADDED
The diff for this file is too large to render. See raw diff.

data/studies/standard_train_val_90split_Active_Dmax_0.6_pDC50_6.0.csv
ADDED
The diff for this file is too large to render. See raw diff.

data/studies/target_test_10split_Active_Dmax_0.6_pDC50_6.0.csv
ADDED
The diff for this file is too large to render. See raw diff.

data/studies/target_train_val_90split_Active_Dmax_0.6_pDC50_6.0.csv
ADDED
The diff for this file is too large to render. See raw diff.

notebooks/predict_unknown_protacs.ipynb
DELETED
The diff for this file is too large to render. See raw diff.
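The eight added CSVs follow the naming scheme produced by `src/get_studies_datasets.py` (added later in this commit). Roughly, as a sketch:

```python
# Sketch of the naming scheme behind the files added above.
split_type = 'e3_ligase'  # or 'similarity', 'standard', 'target'
active_name = 'Active_Dmax_0.6_pDC50_6.0'
train_val_filename = f'{split_type}_train_val_90split_{active_name}.csv'
test_filename = f'{split_type}_test_10split_{active_name}.csv'
```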
protac_degradation_predictor/optuna_utils.py
CHANGED
@@ -9,25 +9,11 @@ from .pytorch_models import (
 )
 from .protac_dataset import get_datasets
 
-from .sklearn_models import (
-    train_sklearn_model,
-    suggest_random_forest,
-    suggest_logistic_regression,
-    suggest_svc,
-    suggest_gradient_boosting,
-)
-
 import torch
 import optuna
 from optuna.samplers import TPESampler
 import joblib
 import pandas as pd
-from sklearn.ensemble import (
-    RandomForestClassifier,
-    GradientBoostingClassifier,
-)
-from sklearn.linear_model import LogisticRegression
-from sklearn.svm import SVC
 from sklearn.model_selection import (
     StratifiedKFold,
     StratifiedGroupKFold,
@@ -270,14 +256,23 @@ def hyperparameter_tuning_and_training(
     """ Hyperparameter tuning and training of a PROTAC model.
 
     Args:
-        train_df (pd.DataFrame): The training set.
-        val_df (pd.DataFrame): The validation set.
+        protein2embedding (Dict): The protein to embedding dictionary.
+        cell2embedding (Dict): The cell to embedding dictionary.
+        smiles2fp (Dict): The SMILES to fingerprint dictionary.
+        train_val_df (pd.DataFrame): The training and validation set.
         test_df (pd.DataFrame): The test set.
+        kf (StratifiedKFold | StratifiedGroupKFold): The KFold object.
+        groups (np.array): The groups for the StratifiedGroupKFold.
+        split_type (str): The split type.
+        n_models_for_test (int): The number of models to train for the test set.
         fast_dev_run (bool): Whether to run a fast development run.
-        n_trials (int): The number of hyperparameter
-        optimization trials.
+        n_trials (int): The number of trials for the hyperparameter search.
+        logger_save_dir (str): The logger save directory.
+        logger_name (str): The logger name.
         active_label (str): The active label column.
-        logger_name (str): The name of the logger.
+        max_epochs (int): The maximum number of epochs.
+        study_filename (str): The study filename.
+        force_study (bool): Whether to force re-running the study.
 
     Returns:
         tuple: The trained model, the trainer, and the best metrics.
@@ -507,148 +502,4 @@ def hyperparameter_tuning_and_training(
     }
     if not fast_dev_run:
         ret['majority_vote_report'] = majority_vote_report
-        return ret
-
-
-def sklearn_model_objective(
-    trial: optuna.Trial,
-    protein2embedding: Dict,
-    cell2embedding: Dict,
-    smiles2fp: Dict,
-    train_df: pd.DataFrame,
-    val_df: pd.DataFrame,
-    model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
-    active_label: str = 'Active',
-) -> float:
-    """ Objective function for hyperparameter optimization.
-
-    Args:
-        trial (optuna.Trial): The Optuna trial object.
-        train_df (pd.DataFrame): The training set.
-        val_df (pd.DataFrame): The validation set.
-        model_type (str): The model type.
-        hyperparameters (Dict): The hyperparameters for the model.
-        fast_dev_run (bool): Whether to run a fast development run.
-        active_label (str): The active label column.
-    """
-    # Generate the hyperparameters
-    use_single_scaler = trial.suggest_categorical('use_single_scaler', [True, False])
-    if model_type == 'RandomForest':
-        clf = suggest_random_forest(trial)
-    elif model_type == 'SVC':
-        clf = suggest_svc(trial)
-    elif model_type == 'LogisticRegression':
-        clf = suggest_logistic_regression(trial)
-    elif model_type == 'GradientBoosting':
-        clf = suggest_gradient_boosting(trial)
-    else:
-        raise ValueError(f'Invalid model type: {model_type}. Available: RandomForest, SVC, LogisticRegression, GradientBoosting.')
-
-    # Train the model with the current set of hyperparameters
-    _, metrics = train_sklearn_model(
-        clf=clf,
-        protein2embedding=protein2embedding,
-        cell2embedding=cell2embedding,
-        smiles2fp=smiles2fp,
-        train_df=train_df,
-        val_df=val_df,
-        active_label=active_label,
-        use_single_scaler=use_single_scaler,
-    )
-
-    # Metrics is a dictionary containing at least the validation loss
-    val_acc = metrics['val_acc']
-    val_roc_auc = metrics['val_roc_auc']
-
-    # Optuna aims to minimize the sklearn_model_objective
-    return - val_acc - val_roc_auc
-
-
-def hyperparameter_tuning_and_training_sklearn(
-    protein2embedding: Dict,
-    cell2embedding: Dict,
-    smiles2fp: Dict,
-    train_df: pd.DataFrame,
-    val_df: pd.DataFrame,
-    test_df: Optional[pd.DataFrame] = None,
-    model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
-    active_label: str = 'Active',
-    n_trials: int = 50,
-    logger_name: str = 'protac_hparam_search_sklearn',
-    study_filename: Optional[str] = None,
-) -> Tuple:
-    """ Hyperparameter tuning and training of a PROTAC model.
-
-    Args:
-        train_df (pd.DataFrame): The training set.
-        val_df (pd.DataFrame): The validation set.
-        test_df (pd.DataFrame): The test set.
-        model_type (str): The model type.
-        n_trials (int): The number of hyperparameter optimization trials.
-        logger_name (str): The name of the logger. Unused, for compatibility with hyperparameter_tuning_and_training.
-        active_label (str): The active label column.
-
-    Returns:
-        tuple: The trained model and the best metrics.
-    """
-    # Set the verbosity of Optuna
-    optuna.logging.set_verbosity(optuna.logging.WARNING)
-    # Create an Optuna study object
-    sampler = TPESampler(seed=42, multivariate=True)
-    study = optuna.create_study(direction='minimize', sampler=sampler)
-
-    study_loaded = False
-    if study_filename:
-        if os.path.exists(study_filename):
-            study = joblib.load(study_filename)
-            study_loaded = True
-            logging.info(f'Loaded study from {study_filename}')
-
-    if not study_loaded:
-        study.optimize(
-            lambda trial: sklearn_model_objective(
-                trial=trial,
-                protein2embedding=protein2embedding,
-                cell2embedding=cell2embedding,
-                smiles2fp=smiles2fp,
-                train_df=train_df,
-                val_df=val_df,
-                model_type=model_type,
-                active_label=active_label,
-            ),
-            n_trials=n_trials,
-        )
-        if study_filename:
-            joblib.dump(study, study_filename)
-
-    # Retrain the model with the best hyperparameters
-    best_hyperparameters = {k.replace('model_', ''): v for k, v in study.best_params.items() if k.startswith('model_')}
-    if model_type == 'RandomForest':
-        clf = RandomForestClassifier(random_state=42, **best_hyperparameters)
-    elif model_type == 'SVC':
-        clf = SVC(random_state=42, probability=True, **best_hyperparameters)
-    elif model_type == 'LogisticRegression':
-        clf = LogisticRegression(random_state=42, max_iter=1000, **best_hyperparameters)
-    elif model_type == 'GradientBoosting':
-        clf = GradientBoostingClassifier(random_state=42, **best_hyperparameters)
-    else:
-        raise ValueError(f'Invalid model type: {model_type}. Available: RandomForest, SVC, LogisticRegression, GradientBoosting.')
-
-    model, metrics = train_sklearn_model(
-        clf=clf,
-        protein2embedding=protein2embedding,
-        cell2embedding=cell2embedding,
-        smiles2fp=smiles2fp,
-        train_df=train_df,
-        val_df=val_df,
-        test_df=test_df,
-        active_label=active_label,
-        use_single_scaler=study.best_params['use_single_scaler'],
-    )
-
-    # Report the best hyperparameters found
-    metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
-
-    # Return the best metrics
-    return model, metrics
+    return ret
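A note on the new `kf` and `groups` arguments documented above: the following is a minimal sketch (not part of this commit) of how they might be assembled for a grouped split. Grouping by `Uniprot` and the `'Active'` label column are assumptions, purely for illustration:

```python
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold

train_val_df = pd.read_csv('path/to/train_val.csv')  # placeholder path

# One group label per row; StratifiedGroupKFold never shares a group
# between the train and validation folds.
groups = train_val_df['Uniprot'].to_numpy()
kf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

for train_idx, val_idx in kf.split(train_val_df, train_val_df['Active'], groups):
    print(len(train_idx), len(val_idx))
```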
src/README.md
ADDED
@@ -0,0 +1,70 @@
+# Training Models
+
+## Dataset Specification
+
+From the repository top-level directory, run the following command to get the datasets reported in the paper:
+
+```bash
+cd src
+python get_studies_datasets.py
+```
+
+For training on custom datasets, please refer to the class `PROTAC_Dataset` in the file [`protac_dataset.py`](../protac_degradation_predictor/protac_dataset.py). The class expects a Pandas DataFrame, so please assemble a file that can be parsed into a Pandas DataFrame with the following columns:
+
+| Column Name | Type | Description |
+| --- | --- | --- |
+| Smiles | str | The SMILES representation of the PROTAC molecule. |
+| Uniprot | str | The Uniprot ID of the target protein. |
+| E3 Ligase Uniprot | str | The Uniprot ID of the E3 ligase. |
+| Cell Line Identifier | str | The cell line identifier as reported in Cellosaurus. |
+| `<active_label>` | bool | The activity label of the PROTAC molecule to be predicted by the model. |
+
+The column `<active_label>` defaults to _"Active"_ in the `PROTAC_Dataset` class and in the `hyperparameter_tuning_and_training` function (see below for how to use it).
+
+## Training on Custom Data
+
+For training on custom datasets, please refer to the function `hyperparameter_tuning_and_training` in [`optuna_utils.py`](../protac_degradation_predictor/optuna_utils.py) and to the file [`run_experiments.py`](../src/run_experiments.py) for inspiration on how to use it.
+
+A skeleton implementation looks as follows:
+
+```python
+import protac_degradation_predictor as pdp
+import pandas as pd
+import numpy as np
+from sklearn.model_selection import StratifiedKFold
+
+# Load train/val and test dataframes
+train_val_df = pd.read_csv('path/to/custom_dataset.csv')
+test_df = pd.read_csv('path/to/test_dataset.csv')  # Load one of our test datasets
+
+# NOTE: Make sure to avoid data leakage by removing leaking data from the
+# train/val dataframe. Do NOT remove or alter the test set, as that would
+# impair comparison with our work. Data leakage can occur if the test set
+# contains any combination of SMILES, Uniprot, E3 Ligase Uniprot, or Cell
+# Line Identifier that is also present in the train/val set.
+
+# Precompute Morgan fingerprints
+unique_smiles = pd.concat([train_val_df, test_df])['Smiles'].unique().tolist()
+smiles2fp = {s: np.array(pdp.get_fingerprint(s)) for s in unique_smiles}
+
+# Load embedding dictionaries
+protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
+cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
+
+# Setup cross-validation object and experiment name
+kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
+experiment_name = 'my_experiment'
+pdp.hyperparameter_tuning_and_training(
+    protein2embedding=protein2embedding,
+    cell2embedding=cell2embedding,
+    smiles2fp=smiles2fp,
+    train_val_df=train_val_df,
+    test_df=test_df,
+    kf=kf,
+    n_models_for_test=3,
+    n_trials=100,
+    max_epochs=20,
+    logger_save_dir='../logs',
+    logger_name=f'logs_{experiment_name}',
+    study_filename=f'../reports/study_{experiment_name}.pkl',
+)
```
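To make the column specification above concrete, a custom dataset could be assembled as in the following minimal sketch (all values are illustrative placeholders, not real PROTAC data):

```python
import pandas as pd

# Hypothetical two-row dataset matching the PROTAC_Dataset column spec.
custom_df = pd.DataFrame({
    'Smiles': ['CC(=O)Nc1ccc(O)cc1', 'COc1ccc2[nH]ccc2c1'],  # placeholder SMILES
    'Uniprot': ['P00519', 'Q00987'],                          # target protein IDs
    'E3 Ligase Uniprot': ['P40337', 'Q96SW2'],                # VHL and CRBN accessions
    'Cell Line Identifier': ['CVCL_0030', 'CVCL_0004'],       # Cellosaurus IDs
    'Active': [True, False],                                  # default <active_label>
})
custom_df.to_csv('path/to/custom_dataset.csv', index=False)
```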
src/get_studies_datasets.py
ADDED
@@ -0,0 +1,289 @@
+import os
+import sys
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+import protac_degradation_predictor as pdp
+
+from collections import defaultdict
+import warnings
+import logging
+from typing import Literal
+
+from sklearn.preprocessing import OrdinalEncoder
+from tqdm import tqdm
+import pandas as pd
+import numpy as np
+import pytorch_lightning as pl
+from rdkit import DataStructs
+
+
+root = logging.getLogger()
+root.setLevel(logging.DEBUG)
+
+handler = logging.StreamHandler(sys.stdout)
+handler.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+root.addHandler(handler)
+
+
+def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
+    """ Get the indices of the test set using a random split.
+
+    Args:
+        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
+        test_split (float): The fraction of the active PROTACs to use as the test set.
+
+    Returns:
+        pd.Index: The indices of the test set.
+    """
+    test_df = active_df.sample(frac=test_split, random_state=42)
+    return test_df.index
+
+
+def get_e3_ligase_split_indices(active_df: pd.DataFrame) -> pd.Index:
+    """ Get the indices of the test set using the E3 ligase split.
+
+    Args:
+        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
+
+    Returns:
+        pd.Index: The indices of the test set.
+    """
+    encoder = OrdinalEncoder()
+    active_df['E3 Group'] = encoder.fit_transform(active_df[['E3 Ligase']]).astype(int)
+    test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')]
+    return test_df.index
+
+
+def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple:
+    """ Get the SMILES to fingerprint dictionary and the average Tanimoto distance.
+
+    Args:
+        protac_df (pd.DataFrame): The DataFrame containing the PROTACs.
+
+    Returns:
+        tuple: The SMILES to fingerprint dictionary and the DataFrame with an added 'Avg Tanimoto' column.
+    """
+    unique_smiles = protac_df['Smiles'].unique().tolist()
+
+    smiles2fp = {}
+    for smiles in tqdm(unique_smiles, desc='Precomputing fingerprints'):
+        smiles2fp[smiles] = pdp.get_fingerprint(smiles)
+
+    tanimoto_matrix = defaultdict(list)
+    fps = list(smiles2fp.values())
+
+    # Compute all-against-all Tanimoto similarity using BulkTanimotoSimilarity
+    for i, (smiles1, fp1) in enumerate(tqdm(zip(unique_smiles, fps), desc='Computing Tanimoto similarity', total=len(fps))):
+        similarities = DataStructs.BulkTanimotoSimilarity(fp1, fps[i:])  # Only compute from i to the end, avoiding duplicates
+        for j, similarity in enumerate(similarities):
+            distance = 1 - similarity
+            tanimoto_matrix[smiles1].append(distance)  # Store as distance
+            if j != 0:  # Skip the self-pair
+                tanimoto_matrix[unique_smiles[i + j]].append(distance)  # Symmetric filling
+
+    # Calculate the average Tanimoto distance for each unique SMILES
+    avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
+    protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
+
+    smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
+
+    return smiles2fp, protac_df
+
+
+def get_tanimoto_split_indices(
+    active_df: pd.DataFrame,
+    active_col: str,
+    test_split: float,
+    n_bins_tanimoto: int = 200,
+) -> pd.Index:
+    """ Get the indices of the test set using the Tanimoto-based split.
+
+    Args:
+        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
+        active_col (str): The column containing the active/inactive information.
+        test_split (float): The fraction of the active PROTACs to use as the test set.
+        n_bins_tanimoto (int): The number of bins to use for the Tanimoto distance.
+
+    Returns:
+        pd.Index: The indices of the test set.
+    """
+    tanimoto_groups = pd.cut(active_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
+    encoder = OrdinalEncoder()
+    active_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
+    # Sort the groups so that samples with the highest average Tanimoto distance,
+    # i.e., the least similar ones, are placed in the test set first
+    tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index
+
+    test_df = []
+    # For each group, get the number of active and inactive entries. Then, add those
+    # entries to the test_df if: 1) the test_df length plus the group entries stays
+    # below test_split of the active_df length, and 2) the percentage of True and
+    # False entries in the active_col of test_df stays roughly 50%.
+    for group in tanimoto_groups:
+        group_df = active_df[active_df['Tanimoto Group'] == group]
+        if test_df == []:
+            test_df.append(group_df)
+            continue
+
+        num_entries = len(group_df)
+        num_active_group = group_df[active_col].sum()
+        num_inactive_group = num_entries - num_active_group
+
+        tmp_test_df = pd.concat(test_df)
+        num_entries_test = len(tmp_test_df)
+        num_active_test = tmp_test_df[active_col].sum()
+        num_inactive_test = num_entries_test - num_active_test
+
+        # Check if the group entries can be added to the test_df
+        if num_entries_test + num_entries < test_split * len(active_df):
+            # Add anything at the beginning
+            if num_entries_test + num_entries < test_split / 2 * len(active_df):
+                test_df.append(group_df)
+                continue
+            # Be more selective and make sure that the percentage of active and
+            # inactive entries stays balanced
+            if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
+                if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
+                    test_df.append(group_df)
+    test_df = pd.concat(test_df)
+    return test_df.index
+
+
+def get_target_split_indices(active_df: pd.DataFrame, active_col: str, test_split: float) -> pd.Index:
+    """ Get the indices of the test set using the target-based split.
+
+    Args:
+        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
+        active_col (str): The column containing the active/inactive information.
+        test_split (float): The fraction of the active PROTACs to use as the test set.
+
+    Returns:
+        pd.Index: The indices of the test set.
+    """
+    encoder = OrdinalEncoder()
+    active_df['Uniprot Group'] = encoder.fit_transform(active_df[['Uniprot']]).astype(int)
+
+    test_df = []
+    # For each group, get the number of active and inactive entries. Then, add those
+    # entries to the test_df if: 1) the test_df length plus the group entries stays
+    # below test_split of the active_df length, and 2) the percentage of True and
+    # False entries in the active_col of test_df stays roughly 50%.
+    # Start the loop from the groups containing the smallest number of entries.
+    for group in reversed(active_df['Uniprot'].value_counts().index):
+        group_df = active_df[active_df['Uniprot'] == group]
+        if test_df == []:
+            test_df.append(group_df)
+            continue
+
+        num_entries = len(group_df)
+        num_active_group = group_df[active_col].sum()
+        num_inactive_group = num_entries - num_active_group
+
+        tmp_test_df = pd.concat(test_df)
+        num_entries_test = len(tmp_test_df)
+        num_active_test = tmp_test_df[active_col].sum()
+        num_inactive_test = num_entries_test - num_active_test
+
+        # Check if the group entries can be added to the test_df
+        if num_entries_test + num_entries < test_split * len(active_df):
+            # Add anything at the beginning
+            if num_entries_test + num_entries < test_split / 2 * len(active_df):
+                test_df.append(group_df)
+                continue
+            # Be more selective and make sure that the percentage of active and
+            # inactive entries stays balanced
+            if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
+                if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
+                    test_df.append(group_df)
+    test_df = pd.concat(test_df)
+    return test_df.index
+
+
+def main(
+    active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
+    test_split: float = 0.1,
+    studies: Literal['all', 'standard', 'e3_ligase', 'similarity', 'target'] = 'all',
+):
+    """ Get and save the datasets for the different studies.
+
+    Args:
+        active_col (str): The column containing the active/inactive information. It should be in the format 'Active (Dmax N, pDC50 M)', where N and M are the float threshold values for Dmax and pDC50, respectively.
+        test_split (float): The fraction of the active PROTACs to use as the test set.
+        studies (str): The type of studies to save datasets for. Options: 'all', 'standard', 'e3_ligase', 'similarity', 'target'.
+    """
+    pl.seed_everything(42)
+
+    # Filename-friendly version of the column to predict
+    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
+
+    # Parse the Dmax and pDC50 thresholds from the active_col
+    Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
+    pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())
+
+    # Load the PROTAC dataset
+    protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
+    # Map the E3 ligase Iap to IAP
+    protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
+    protac_df[active_col] = protac_df.apply(
+        lambda x: pdp.is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1
+    )
+    _, protac_df = get_smiles2fp_and_avg_tanimoto(protac_df)
+
+    ## Get the test sets
+    test_indices = {}
+    active_df = protac_df[protac_df[active_col].notna()].copy()
+
+    # Remove the legacy column 'Active - OR' if it exists
+    if 'Active - OR' in active_df.columns:
+        active_df.drop(columns='Active - OR', inplace=True)
+
+    if studies == 'standard' or studies == 'all':
+        test_indices['standard'] = get_random_split_indices(active_df, test_split)
+    if studies == 'target' or studies == 'all':
+        test_indices['target'] = get_target_split_indices(active_df, active_col, test_split)
+    if studies == 'e3_ligase' or studies == 'all':
+        test_indices['e3_ligase'] = get_e3_ligase_split_indices(active_df)
+    if studies == 'similarity' or studies == 'all':
+        test_indices['similarity'] = get_tanimoto_split_indices(active_df, active_col, test_split)
+
+    # Make a directory for the studies datasets if it does not exist
+    data_dir = '../data/studies'
+    if not os.path.exists(data_dir):
+        os.makedirs(data_dir)
+
+    # Save the train/val and test sets for each split type
+    for split_type, indices in test_indices.items():
+        test_df = active_df.loc[indices].copy()
+        train_val_df = active_df[~active_df.index.isin(test_df.index)].copy()
+
+        # Save the datasets
+        train_val_perc = f'{int((1 - test_split) * 100)}'
+        test_perc = f'{int(test_split * 100)}'
+
+        train_val_filename = f'{data_dir}/{split_type}_train_val_{train_val_perc}split_{active_name}.csv'
+        test_filename = f'{data_dir}/{split_type}_test_{test_perc}split_{active_name}.csv'
+
+        print()
+        print(f'Saving train_val dataset as: {train_val_filename}')
+        print(f'Saving test dataset as: {test_filename}')
+
+        train_val_df.to_csv(train_val_filename, index=False)
+        test_df.to_csv(test_filename, index=False)
+
+
+if __name__ == '__main__':
+    main()