Commit 5e01175
Parent(s): 6101de8
Started working on packaging the repository

Changed files:
- notebooks/plotting_dragradation_activity_performance.ipynb +0 -0
- notebooks/protac_degradation_predictor.ipynb +10 -2
- notebooks/protac_degradation_predictor.py +3 -2
- protac_degradation_predictor/__init__.py +7 -0
- protac_degradation_predictor/config.py +37 -0
- protac_degradation_predictor/data/PROTAC-DB.csv +0 -0
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_0_test_split_0.2.pkl → protac_degradation_predictor/data/cell2embedding.pkl +2 -2
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_1_test_split_0.2.pkl → protac_degradation_predictor/data/uniprot2embedding.h5 +2 -2
- protac_degradation_predictor/data_utils.py +46 -0
- protac_degradation_predictor/optuna_utils.py +318 -0
- protac_degradation_predictor/protac_dataset.py +193 -0
- protac_degradation_predictor/protac_degradation_predictor.py +88 -0
- protac_degradation_predictor/pytorch_models.py +471 -0
- protac_degradation_predictor/sklearn_models.py +243 -0
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_0_test_split_0.1.pkl +1 -1
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_1_test_split_0.1.pkl +1 -1
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_2_test_split_0.1.pkl +1 -1
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_2_test_split_0.2.pkl +0 -3
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_3_test_split_0.1.pkl +1 -1
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_3_test_split_0.2.pkl +0 -3
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_4_test_split_0.1.pkl +1 -1
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_4_test_split_0.2.pkl +0 -3
- setup.py +21 -0

notebooks/plotting_dragradation_activity_performance.ipynb
CHANGED
The diff for this file is too large to render. See raw diff.

notebooks/protac_degradation_predictor.ipynb
CHANGED
@@ -1719,8 +1719,16 @@
     }
    ],
    "source": [
-    "
-    "
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torchmetrics import (\n",
+    "    Accuracy,\n",
+    "    AUROC,\n",
+    "    Precision,\n",
+    "    Recall,\n",
+    "    F1Score,\n",
+    "    MetricCollection,\n",
+    ")\n",
     "\n",
     "# Generic function to fit and evaluate a classifier model (given as argument),\n",
     "# on train and val sets (and optionally a test set) given as dataframes\n",
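
For context, a minimal sketch (not part of the commit) of how the torchmetrics classes imported above are typically combined; the metric names and the binary task setting mirror their later use in pytorch_models.py:

import torch
from torchmetrics import Accuracy, F1Score, MetricCollection

metrics = MetricCollection({
    'acc': Accuracy(task='binary'),
    'f1_score': F1Score(task='binary'),
})
preds = torch.tensor([0.8, 0.2, 0.6])   # predicted probabilities
target = torch.tensor([1, 0, 1])        # ground-truth labels
metrics.update(preds, target)
print(metrics.compute())  # {'acc': ..., 'f1_score': ...}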

notebooks/protac_degradation_predictor.py
CHANGED
@@ -680,7 +680,7 @@ def train_model(
         hidden_dim (int): The hidden dimension of the model.
         batch_size (int): The batch size.
         learning_rate (float): The learning rate.
-        max_epochs (int):
+        max_epochs (int): The maximum number of epochs.
         smiles_emb_dim (int): The dimension of the SMILES embeddings.
         smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
         fast_dev_run (bool): Whether to run a fast development run.
@@ -985,6 +985,8 @@ def main(
     encoder = OrdinalEncoder()
     protac_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
     active_df = protac_df[protac_df[active_col].notna()].copy()
+    # Sort the groups so that samples with the highest tanimoto similarity,
+    # i.e., the "less similar" ones, are placed in the test set first
     tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index

     test_df = []
@@ -992,7 +994,6 @@
     # entries to the test_df if: 1) the test_df length + the group entries is less
     # than 20% of the active_df length, and 2) the percentage of True and False entries
     # in the active_col in test_df is roughly 50%.
-    # Start the loop from the groups containing the smallest number of entries.
     for group in tanimoto_groups:
        group_df = active_df[active_df['Tanimoto Group'] == group]
        if test_df == []:
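
The comments in the hunk above describe a group-wise test split; a minimal sketch of that loop under the stated conditions (tanimoto_groups, active_df, and active_col are assumed in scope, and the 20% size cap and rough 50/50 balance come from the comments):

import pandas as pd

test_rows = []
for group in tanimoto_groups:  # groups pre-sorted as in the diff above
    group_df = active_df[active_df['Tanimoto Group'] == group]
    candidate = pd.concat(test_rows + [group_df])
    # 1) keep the test set under 20% of the active data
    if len(candidate) > 0.2 * len(active_df):
        continue
    # 2) keep the True/False balance of active_col roughly 50%
    if not (0.4 <= candidate[active_col].mean() <= 0.6):
        continue
    test_rows.append(group_df)
test_df = pd.concat(test_rows)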

protac_degradation_predictor/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from .protac_degradation_predictor import (
+    PROTAC_Model,
+    train_model,
+)
+
+__version__ = "0.0.1"
+__author__ = "Stefano Ribes"
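
With this __init__.py in place, the main entry points become importable from the package root. A minimal usage sketch, assuming the package is installed (e.g. via the setup.py added in this commit) and that the in-progress intra-package imports resolve:

import protac_degradation_predictor as pdp

print(pdp.__version__)  # "0.0.1"
# PROTAC_Model and train_model are re-exported at the top level:
assert callable(pdp.train_model)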

protac_degradation_predictor/config.py
ADDED
@@ -0,0 +1,37 @@
+from dataclasses import dataclass, field
+
+@dataclass(frozen=True)
+class Config:
+    # Embeddings information
+    morgan_radius: int = 15
+    fingerprint_size: int = 224
+    protein_embedding_size: int = 1024
+    cell_embedding_size: int = 768
+
+    # Data information
+    dmax_threshold: float = 0.6
+    pdc50_threshold: float = 6.0
+    # A mutable dict cannot be a plain dataclass default, hence default_factory
+    e3_ligase2uniprot: dict = field(default_factory=lambda: {
+        'VHL': 'P40337',
+        'CRBN': 'Q96SW2',
+        'DCAF11': 'Q8TEB1',
+        'DCAF15': 'Q66K64',
+        'DCAF16': 'Q9NXF7',
+        'MDM2': 'Q00987',
+        'Mdm2': 'Q00987',
+        'XIAP': 'P98170',
+        'cIAP1': 'Q7Z460',
+        'IAP': 'P98170',  # I couldn't find the Uniprot ID for IAP, so it's XIAP instead
+        'Iap': 'P98170',  # I couldn't find the Uniprot ID for IAP, so it's XIAP instead
+        'AhR': 'P35869',
+        'RNF4': 'P78317',
+        'RNF114': 'Q9Y508',
+        'FEM1B': 'Q9UK73',
+        'Ubr1': 'Q8IWV7',
+    })
+
+    def __post_init__(self):
+        # object.__setattr__ is required because the dataclass is frozen
+        object.__setattr__(self, 'active_label', f'Active (Dmax {self.dmax_threshold}, pDC50 {self.pdc50_threshold})')
+
+
+config = Config()
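
A short sketch of how the module-level config singleton is meant to be consumed elsewhere in the package (the printed label assumes the default thresholds):

from protac_degradation_predictor.config import config

print(config.fingerprint_size)           # 224
print(config.e3_ligase2uniprot['VHL'])   # "P40337"
print(config.active_label)               # "Active (Dmax 0.6, pDC50 6.0)"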

protac_degradation_predictor/data/PROTAC-DB.csv
ADDED
The diff for this file is too large to render. See raw diff.

reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_0_test_split_0.2.pkl → protac_degradation_predictor/data/cell2embedding.pkl
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:627e8ce3842afeb6bb7d5caa5ec1ba034c36dc77fab70734e15dca340a7fd718
+size 3550864

reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_1_test_split_0.2.pkl → protac_degradation_predictor/data/uniprot2embedding.h5
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:19f4b8c73652392db7840962d1a7817c7e899716e2bb758e4947c8c2bb265336
+size 51089512

protac_degradation_predictor/data_utils.py
ADDED
@@ -0,0 +1,46 @@
+import os
+import pkg_resources
+import pickle
+from typing import Dict
+
+from config import config
+
+import h5py
+import numpy as np
+import pandas as pd
+from rdkit import Chem
+from rdkit.Chem import AllChem
+from joblib import Memory
+
+
+home_dir = os.path.expanduser('~')
+cachedir = os.path.join(home_dir, '.cache', 'protac_degradation_predictor')
+memory = Memory(cachedir, verbose=0)
+
+
+@memory.cache
+def load_protein2embedding() -> Dict[str, np.ndarray]:
+    embeddings_path = pkg_resources.resource_stream(__name__, 'data/uniprot2embedding.h5')
+    protein2embedding = {}
+    with h5py.File(embeddings_path, "r") as file:
+        for sequence_id in file.keys():
+            embedding = file[sequence_id][:]
+            protein2embedding[sequence_id] = np.array(embedding)
+    return protein2embedding
+
+
+@memory.cache
+def load_cell2embedding() -> Dict[str, np.ndarray]:
+    # resource_stream already returns an open binary stream, usable directly
+    embeddings_path = pkg_resources.resource_stream(__name__, 'data/cell2embedding.pkl')
+    cell2embedding = pickle.load(embeddings_path)
+    return cell2embedding
+
+
+def get_fingerprint(smiles: str) -> np.ndarray:
+    morgan_fpgen = AllChem.GetMorganGenerator(
+        radius=config.morgan_radius,
+        fpSize=config.fingerprint_size,
+        includeChirality=True,
+    )
+    return morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
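
A minimal usage sketch for these loaders, assuming the package (with its bundled data files) is installed; the SMILES string is a placeholder:

from protac_degradation_predictor.data_utils import (
    load_protein2embedding,
    load_cell2embedding,
    get_fingerprint,
)

protein2embedding = load_protein2embedding()  # {uniprot_id: np.ndarray}, disk-cached via joblib
cell2embedding = load_cell2embedding()        # {cell_line: np.ndarray}
fp = get_fingerprint('CCO')                   # 224-bit Morgan fingerprint (sizes come from config)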

protac_degradation_predictor/optuna_utils.py
ADDED
@@ -0,0 +1,318 @@
+import os
+from typing import Literal, List, Tuple, Optional, Dict
+
+from pytorch_models import train_model
+from sklearn_models import (
+    train_sklearn_model,
+    suggest_random_forest,
+    suggest_logistic_regression,
+    suggest_svc,
+    suggest_gradient_boosting,
+)
+
+import optuna
+from optuna.samplers import TPESampler
+import joblib
+import pandas as pd
+from sklearn.ensemble import (
+    RandomForestClassifier,
+    GradientBoostingClassifier,
+)
+from sklearn.linear_model import LogisticRegression
+from sklearn.svm import SVC
+
+
+def pytorch_model_objective(
+    trial: optuna.Trial,
+    protein2embedding: Dict,
+    cell2embedding: Dict,
+    smiles2fp: Dict,
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    hidden_dim_options: List[int] = [256, 512, 768],
+    batch_size_options: List[int] = [8, 16, 32],
+    learning_rate_options: Tuple[float, float] = (1e-5, 1e-3),
+    smote_k_neighbors_options: List[int] = list(range(3, 16)),
+    dropout_options: Tuple[float, float] = (0.1, 0.5),
+    fast_dev_run: bool = False,
+    active_label: str = 'Active',
+    disabled_embeddings: List[str] = [],
+    max_epochs: int = 100,
+) -> float:
+    """ Objective function for hyperparameter optimization.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+        train_df (pd.DataFrame): The training set.
+        val_df (pd.DataFrame): The validation set.
+        hidden_dim_options (List[int]): The hidden dimension options.
+        batch_size_options (List[int]): The batch size options.
+        learning_rate_options (Tuple[float, float]): The learning rate options.
+        smote_k_neighbors_options (List[int]): The SMOTE k neighbors options.
+        dropout_options (Tuple[float, float]): The dropout options.
+        fast_dev_run (bool): Whether to run a fast development run.
+        active_label (str): The active label column.
+        disabled_embeddings (List[str]): The list of disabled embeddings.
+    """
+    # Generate the hyperparameters
+    hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
+    batch_size = trial.suggest_categorical('batch_size', batch_size_options)
+    learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
+    join_embeddings = trial.suggest_categorical('join_embeddings', ['beginning', 'concat', 'sum'])
+    smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
+    use_smote = trial.suggest_categorical('use_smote', [True, False])
+    apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
+    dropout = trial.suggest_float('dropout', *dropout_options)
+
+    # Train the model with the current set of hyperparameters
+    _, _, metrics = train_model(
+        protein2embedding,
+        cell2embedding,
+        smiles2fp,
+        train_df,
+        val_df,
+        hidden_dim=hidden_dim,
+        batch_size=batch_size,
+        join_embeddings=join_embeddings,
+        learning_rate=learning_rate,
+        dropout=dropout,
+        max_epochs=max_epochs,
+        smote_k_neighbors=smote_k_neighbors,
+        apply_scaling=apply_scaling,
+        use_smote=use_smote,
+        use_logger=False,
+        fast_dev_run=fast_dev_run,
+        active_label=active_label,
+        disabled_embeddings=disabled_embeddings,
+    )
+
+    # Metrics is a dictionary containing at least the validation loss
+    val_loss = metrics['val_loss']
+    val_acc = metrics['val_acc']
+    val_roc_auc = metrics['val_roc_auc']
+
+    # Optuna aims to minimize the pytorch_model_objective
+    return val_loss - val_acc - val_roc_auc
+
+
+def hyperparameter_tuning_and_training(
+    protein2embedding: Dict,
+    cell2embedding: Dict,
+    smiles2fp: Dict,
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    test_df: Optional[pd.DataFrame] = None,
+    fast_dev_run: bool = False,
+    n_trials: int = 50,
+    logger_name: str = 'protac_hparam_search',
+    active_label: str = 'Active',
+    disabled_embeddings: List[str] = [],
+    study_filename: Optional[str] = None,
+) -> tuple:
+    """ Hyperparameter tuning and training of a PROTAC model.
+
+    Args:
+        train_df (pd.DataFrame): The training set.
+        val_df (pd.DataFrame): The validation set.
+        test_df (pd.DataFrame): The test set.
+        fast_dev_run (bool): Whether to run a fast development run.
+        n_trials (int): The number of hyperparameter optimization trials.
+        logger_name (str): The name of the logger.
+        active_label (str): The active label column.
+        disabled_embeddings (List[str]): The list of disabled embeddings.
+
+    Returns:
+        tuple: The trained model, the trainer, and the best metrics.
+    """
+    # Define the search space
+    hidden_dim_options = [256, 512, 768]
+    batch_size_options = [8, 16, 32]
+    learning_rate_options = (1e-5, 1e-3)  # min and max values for loguniform distribution
+    smote_k_neighbors_options = list(range(3, 16))
+
+    # Set the verbosity of Optuna
+    optuna.logging.set_verbosity(optuna.logging.WARNING)
+    # Create an Optuna study object
+    sampler = TPESampler(seed=42, multivariate=True)
+    study = optuna.create_study(direction='minimize', sampler=sampler)
+
+    study_loaded = False
+    if study_filename:
+        if os.path.exists(study_filename):
+            study = joblib.load(study_filename)
+            study_loaded = True
+            print(f'Loaded study from {study_filename}')
+
+    if not study_loaded:
+        study.optimize(
+            lambda trial: pytorch_model_objective(
+                trial=trial,
+                protein2embedding=protein2embedding,
+                cell2embedding=cell2embedding,
+                smiles2fp=smiles2fp,
+                train_df=train_df,
+                val_df=val_df,
+                hidden_dim_options=hidden_dim_options,
+                batch_size_options=batch_size_options,
+                learning_rate_options=learning_rate_options,
+                smote_k_neighbors_options=smote_k_neighbors_options,
+                fast_dev_run=fast_dev_run,
+                active_label=active_label,
+                disabled_embeddings=disabled_embeddings,
+            ),
+            n_trials=n_trials,
+        )
+        if study_filename:
+            joblib.dump(study, study_filename)
+
+    # Retrain the model with the best hyperparameters
+    model, trainer, metrics = train_model(
+        protein2embedding=protein2embedding,
+        cell2embedding=cell2embedding,
+        smiles2fp=smiles2fp,
+        train_df=train_df,
+        val_df=val_df,
+        test_df=test_df,
+        use_logger=True,
+        logger_name=logger_name,
+        fast_dev_run=fast_dev_run,
+        active_label=active_label,
+        disabled_embeddings=disabled_embeddings,
+        **study.best_params,
+    )
+
+    # Report the best hyperparameters found
+    metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
+
+    # Return the best metrics
+    return model, trainer, metrics
+
+
+def sklearn_model_objective(
+    trial: optuna.Trial,
+    protein2embedding: Dict,
+    cell2embedding: Dict,
+    smiles2fp: Dict,
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
+    active_label: str = 'Active',
+) -> float:
+    """ Objective function for hyperparameter optimization.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+        train_df (pd.DataFrame): The training set.
+        val_df (pd.DataFrame): The validation set.
+        model_type (str): The model type.
+        active_label (str): The active label column.
+    """
+    # Generate the hyperparameters
+    use_single_scaler = trial.suggest_categorical('use_single_scaler', [True, False])
+    if model_type == 'RandomForest':
+        clf = suggest_random_forest(trial)
+    elif model_type == 'SVC':
+        clf = suggest_svc(trial)
+    elif model_type == 'LogisticRegression':
+        clf = suggest_logistic_regression(trial)
+    elif model_type == 'GradientBoosting':
+        clf = suggest_gradient_boosting(trial)
+    else:
+        raise ValueError(f'Invalid model type: {model_type}. Available: RandomForest, SVC, LogisticRegression, GradientBoosting.')
+
+    # Train the model with the current set of hyperparameters
+    _, metrics = train_sklearn_model(
+        clf=clf,
+        protein2embedding=protein2embedding,
+        cell2embedding=cell2embedding,
+        smiles2fp=smiles2fp,
+        train_df=train_df,
+        val_df=val_df,
+        active_label=active_label,
+        use_single_scaler=use_single_scaler,
+    )
+
+    # Metrics is a dictionary containing the validation scores
+    val_acc = metrics['val_acc']
+    val_roc_auc = metrics['val_roc_auc']
+
+    # Optuna aims to minimize the sklearn_model_objective
+    return -val_acc - val_roc_auc
+
+
+def hyperparameter_tuning_and_training_sklearn(
+    protein2embedding: Dict,
+    cell2embedding: Dict,
+    smiles2fp: Dict,
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    test_df: Optional[pd.DataFrame] = None,
+    model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
+    active_label: str = 'Active',
+    n_trials: int = 50,
+    logger_name: str = 'protac_hparam_search',
+    study_filename: Optional[str] = None,
+) -> Tuple:
+    # Set the verbosity of Optuna
+    optuna.logging.set_verbosity(optuna.logging.WARNING)
+    # Create an Optuna study object
+    sampler = TPESampler(seed=42, multivariate=True)
+    study = optuna.create_study(direction='minimize', sampler=sampler)
+
+    study_loaded = False
+    if study_filename:
+        if os.path.exists(study_filename):
+            study = joblib.load(study_filename)
+            study_loaded = True
+            print(f'Loaded study from {study_filename}')
+
+    if not study_loaded:
+        study.optimize(
+            lambda trial: sklearn_model_objective(
+                trial=trial,
+                protein2embedding=protein2embedding,
+                cell2embedding=cell2embedding,
+                smiles2fp=smiles2fp,
+                train_df=train_df,
+                val_df=val_df,
+                model_type=model_type,
+                active_label=active_label,
+            ),
+            n_trials=n_trials,
+        )
+        if study_filename:
+            joblib.dump(study, study_filename)
+
+    # Retrain the model with the best hyperparameters
+    best_hyperparameters = {k.replace('model_', ''): v for k, v in study.best_params.items() if k.startswith('model_')}
+    if model_type == 'RandomForest':
+        clf = RandomForestClassifier(random_state=42, **best_hyperparameters)
+    elif model_type == 'SVC':
+        clf = SVC(random_state=42, probability=True, **best_hyperparameters)
+    elif model_type == 'LogisticRegression':
+        clf = LogisticRegression(random_state=42, max_iter=1000, **best_hyperparameters)
+    elif model_type == 'GradientBoosting':
+        clf = GradientBoostingClassifier(random_state=42, **best_hyperparameters)
+    else:
+        raise ValueError(f'Invalid model type: {model_type}. Available: RandomForest, SVC, LogisticRegression, GradientBoosting.')
+
+    model, metrics = train_sklearn_model(
+        clf=clf,
+        protein2embedding=protein2embedding,
+        cell2embedding=cell2embedding,
+        smiles2fp=smiles2fp,
+        train_df=train_df,
+        val_df=val_df,
+        test_df=test_df,
+        active_label=active_label,
+        use_single_scaler=study.best_params['use_single_scaler'],
+    )
+
+    # Report the best hyperparameters found
+    metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
+
+    # Return the best metrics
+    return model, metrics
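
A hypothetical driver for the PyTorch tuning entry point above, assuming the dataframes and embedding dictionaries are already prepared; the study filename is illustrative:

model, trainer, metrics = hyperparameter_tuning_and_training(
    protein2embedding,
    cell2embedding,
    smiles2fp,
    train_df,
    val_df,
    test_df=test_df,
    n_trials=50,
    active_label='Active (Dmax 0.6, pDC50 6.0)',
    study_filename='reports/study_example.pkl',  # reloaded instead of re-optimized if present
)
best_hparams = {k: v for k, v in metrics.items() if k.startswith('hparam_')}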

protac_degradation_predictor/protac_dataset.py
ADDED
@@ -0,0 +1,193 @@
+from typing import Literal, List, Tuple, Optional, Dict
+
+from torch.utils.data import Dataset
+import numpy as np
+from imblearn.over_sampling import SMOTE, ADASYN
+import pandas as pd
+from sklearn.preprocessing import StandardScaler
+
+
+class PROTAC_Dataset(Dataset):
+    def __init__(
+        self,
+        protac_df: pd.DataFrame,
+        protein2embedding: Dict,
+        cell2embedding: Dict,
+        smiles2fp: Dict,
+        use_smote: bool = False,
+        oversampler: Optional[SMOTE | ADASYN] = None,
+        active_label: str = 'Active',
+    ):
+        """ Initialize the PROTAC dataset
+
+        Args:
+            protac_df (pd.DataFrame): The PROTAC dataframe
+            protein2embedding (dict): Dictionary of protein embeddings
+            cell2embedding (dict): Dictionary of cell line embeddings
+            smiles2fp (dict): Dictionary of SMILES to fingerprint
+            use_smote (bool): Whether to use SMOTE for oversampling
+            active_label (str): The name of the active label column
+        """
+        # Filter out examples with NaN in active_col column
+        self.data = protac_df  # [~protac_df[active_col].isna()]
+        self.protein2embedding = protein2embedding
+        self.cell2embedding = cell2embedding
+        self.smiles2fp = smiles2fp
+        self.active_label = active_label
+        self.use_single_scaler = None
+
+        self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0]
+        self.protein_emb_dim = protein2embedding[list(protein2embedding.keys())[0]].shape[0]
+        self.cell_emb_dim = cell2embedding[list(cell2embedding.keys())[0]].shape[0]
+
+        # Look up the embeddings
+        self.data = pd.DataFrame({
+            'Smiles': self.data['Smiles'].apply(lambda x: smiles2fp[x].astype(np.float32)).tolist(),
+            'Uniprot': self.data['Uniprot'].apply(lambda x: protein2embedding[x].astype(np.float32)).tolist(),
+            'E3 Ligase Uniprot': self.data['E3 Ligase Uniprot'].apply(lambda x: protein2embedding[x].astype(np.float32)).tolist(),
+            'Cell Line Identifier': self.data['Cell Line Identifier'].apply(lambda x: cell2embedding[x].astype(np.float32)).tolist(),
+            self.active_label: self.data[self.active_label].astype(np.float32).tolist(),
+        })
+
+        # Apply SMOTE
+        self.use_smote = use_smote
+        self.oversampler = oversampler
+        if self.use_smote:
+            self.apply_smote()
+
+    def apply_smote(self):
+        # Prepare the dataset for SMOTE
+        features = []
+        labels = []
+        for _, row in self.data.iterrows():
+            features.append(np.hstack([
+                row['Smiles'],
+                row['Uniprot'],
+                row['E3 Ligase Uniprot'],
+                row['Cell Line Identifier'],
+            ]))
+            labels.append(row[self.active_label])
+
+        # Convert to numpy array
+        features = np.array(features).astype(np.float32)
+        labels = np.array(labels).astype(np.float32)
+
+        # Initialize SMOTE and fit
+        if self.oversampler is None:
+            oversampler = SMOTE(random_state=42)
+        else:
+            oversampler = self.oversampler
+        features_smote, labels_smote = oversampler.fit_resample(features, labels)
+
+        # Separate the features back into their respective embeddings
+        smiles_embs = features_smote[:, :self.smiles_emb_dim]
+        poi_embs = features_smote[:, self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim]
+        e3_embs = features_smote[:, self.smiles_emb_dim+self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim]
+        cell_embs = features_smote[:, -self.cell_emb_dim:]
+
+        # Reconstruct the dataframe with oversampled data
+        df_smote = pd.DataFrame({
+            'Smiles': list(smiles_embs),
+            'Uniprot': list(poi_embs),
+            'E3 Ligase Uniprot': list(e3_embs),
+            'Cell Line Identifier': list(cell_embs),
+            self.active_label: labels_smote
+        })
+        self.data = df_smote
+
+    def fit_scaling(self, use_single_scaler: bool = False, **scaler_kwargs) -> dict:
+        """ Fit the scalers for the data.
+
+        Args:
+            use_single_scaler (bool): Whether to use a single scaler for all features.
+            scaler_kwargs: Keyword arguments for the StandardScaler.
+
+        Returns:
+            dict: The fitted scalers.
+        """
+        if use_single_scaler:
+            self.use_single_scaler = True
+            scaler = StandardScaler(**scaler_kwargs)
+            embeddings = np.hstack([
+                np.array(self.data['Smiles'].tolist()),
+                np.array(self.data['Uniprot'].tolist()),
+                np.array(self.data['E3 Ligase Uniprot'].tolist()),
+                np.array(self.data['Cell Line Identifier'].tolist()),
+            ])
+            scaler.fit(embeddings)
+            return scaler
+        else:
+            self.use_single_scaler = False
+            scalers = {}
+            scalers['Smiles'] = StandardScaler(**scaler_kwargs)
+            scalers['Uniprot'] = StandardScaler(**scaler_kwargs)
+            scalers['E3 Ligase Uniprot'] = StandardScaler(**scaler_kwargs)
+            scalers['Cell Line Identifier'] = StandardScaler(**scaler_kwargs)
+
+            scalers['Smiles'].fit(np.stack(self.data['Smiles'].to_numpy()))
+            scalers['Uniprot'].fit(np.stack(self.data['Uniprot'].to_numpy()))
+            scalers['E3 Ligase Uniprot'].fit(np.stack(self.data['E3 Ligase Uniprot'].to_numpy()))
+            scalers['Cell Line Identifier'].fit(np.stack(self.data['Cell Line Identifier'].to_numpy()))
+
+            return scalers
+
+    def apply_scaling(self, scalers: dict, use_single_scaler: bool = False):
+        """ Apply scaling to the data.
+
+        Args:
+            scalers (dict): The scalers for each feature.
+            use_single_scaler (bool): Whether to use a single scaler for all features.
+        """
+        if self.use_single_scaler is None:
+            raise ValueError(
+                "The fit_scaling method must be called before apply_scaling.")
+        if use_single_scaler != self.use_single_scaler:
+            raise ValueError(
+                f"The use_single_scaler parameter must be the same as the one used in the fit_scaling method. Got {use_single_scaler}, previously {self.use_single_scaler}.")
+        if use_single_scaler:
+            embeddings = np.hstack([
+                np.array(self.data['Smiles'].tolist()),
+                np.array(self.data['Uniprot'].tolist()),
+                np.array(self.data['E3 Ligase Uniprot'].tolist()),
+                np.array(self.data['Cell Line Identifier'].tolist()),
+            ])
+            scaled_embeddings = scalers.transform(embeddings)
+            self.data = pd.DataFrame({
+                'Smiles': list(scaled_embeddings[:, :self.smiles_emb_dim]),
+                'Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim]),
+                'E3 Ligase Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim+self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim]),
+                'Cell Line Identifier': list(scaled_embeddings[:, -self.cell_emb_dim:]),
+                self.active_label: self.data[self.active_label]
+            })
+        else:
+            self.data['Smiles'] = self.data['Smiles'].apply(lambda x: scalers['Smiles'].transform(x[np.newaxis, :])[0])
+            self.data['Uniprot'] = self.data['Uniprot'].apply(lambda x: scalers['Uniprot'].transform(x[np.newaxis, :])[0])
+            self.data['E3 Ligase Uniprot'] = self.data['E3 Ligase Uniprot'].apply(lambda x: scalers['E3 Ligase Uniprot'].transform(x[np.newaxis, :])[0])
+            self.data['Cell Line Identifier'] = self.data['Cell Line Identifier'].apply(lambda x: scalers['Cell Line Identifier'].transform(x[np.newaxis, :])[0])
+
+    def get_numpy_arrays(self):
+        X = np.hstack([
+            np.array(self.data['Smiles'].tolist()),
+            np.array(self.data['Uniprot'].tolist()),
+            np.array(self.data['E3 Ligase Uniprot'].tolist()),
+            np.array(self.data['Cell Line Identifier'].tolist()),
+        ]).copy()
+        y = self.data[self.active_label].values.copy()
+        return X, y
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        elem = {
+            'smiles_emb': self.data['Smiles'].iloc[idx],
+            'poi_emb': self.data['Uniprot'].iloc[idx],
+            'e3_emb': self.data['E3 Ligase Uniprot'].iloc[idx],
+            'cell_emb': self.data['Cell Line Identifier'].iloc[idx],
+            'active': self.data[self.active_label].iloc[idx],
+        }
+        return elem
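
A short sketch of the intended dataset workflow (train_df and the embedding dictionaries are assumed in scope; the label column matches config.active_label):

from torch.utils.data import DataLoader

train_ds = PROTAC_Dataset(
    train_df,
    protein2embedding,
    cell2embedding,
    smiles2fp,
    use_smote=True,  # oversample the minority class at construction time
    active_label='Active (Dmax 0.6, pDC50 6.0)',
)
scalers = train_ds.fit_scaling(use_single_scaler=False)
train_ds.apply_scaling(scalers, use_single_scaler=False)
loader = DataLoader(train_ds, batch_size=32, shuffle=True)
batch = next(iter(loader))  # keys: 'smiles_emb', 'poi_emb', 'e3_emb', 'cell_emb', 'active'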

protac_degradation_predictor/protac_degradation_predictor.py
ADDED
@@ -0,0 +1,88 @@
+import pkg_resources
+import logging
+
+from pytorch_models import PROTAC_Model, load_model
+from data_utils import (
+    load_protein2embedding,
+    load_cell2embedding,
+    get_fingerprint,
+)
+from config import config
+
+import numpy as np
+import torch
+from torch import sigmoid
+
+package_name = 'protac_degradation_predictor'
+
+def get_protac_active_proba(
+    protac_smiles: str,
+    e3_ligase: str,
+    target_uniprot: str,
+    cell_line: str,
+    device: str = 'cpu',
+) -> float:
+    ckpt_path = pkg_resources.resource_stream(__name__, 'data/model.ckpt')
+    model = load_model(ckpt_path).to(device)
+    protein2embedding = load_protein2embedding()
+    cell2embedding = load_cell2embedding()
+
+    # Setup default embeddings
+    if e3_ligase not in config.e3_ligase2uniprot:
+        available_e3_ligases = ', '.join(list(config.e3_ligase2uniprot.keys()))
+        logging.warning(f"The E3 ligase {e3_ligase} is not in the database. Using the default E3 ligase. Available E3 ligases are: {available_e3_ligases}")
+    if target_uniprot not in protein2embedding:
+        logging.warning(f"The target protein {target_uniprot} is not in the database. Using the default target protein.")
+    if cell_line not in cell2embedding:
+        logging.warning(f"The cell line {cell_line} is not in the database. Using the default cell line.")
+
+    default_protein_emb = np.zeros(config.protein_embedding_size)
+    default_cell_emb = np.zeros(config.cell_embedding_size)
+
+    # Convert the E3 ligase to Uniprot ID
+    e3_ligase_uniprot = config.e3_ligase2uniprot.get(e3_ligase, '')
+
+    # Get the embeddings
+    poi_emb = protein2embedding.get(target_uniprot, default_protein_emb)
+    e3_emb = protein2embedding.get(e3_ligase_uniprot, default_protein_emb)
+    cell_emb = cell2embedding.get(cell_line, default_cell_emb)
+    smiles_emb = get_fingerprint(protac_smiles)
+
+    # Convert to float32 torch tensors with a batch dimension
+    poi_emb = torch.tensor(poi_emb, dtype=torch.float32).unsqueeze(0).to(device)
+    e3_emb = torch.tensor(e3_emb, dtype=torch.float32).unsqueeze(0).to(device)
+    cell_emb = torch.tensor(cell_emb, dtype=torch.float32).unsqueeze(0).to(device)
+    smiles_emb = torch.tensor(np.array(smiles_emb), dtype=torch.float32).unsqueeze(0).to(device)
+
+    return model(poi_emb, e3_emb, cell_emb, smiles_emb).item()
+
+
+def is_protac_active(
+    protac_smiles: str,
+    e3_ligase: str,
+    target_uniprot: str,
+    cell_line: str,
+    device: str = 'cpu',
+    proba_threshold: float = 0.5,
+) -> bool:
+    """ Predict whether a PROTAC is active or not.
+
+    Args:
+        protac_smiles (str): The SMILES of the PROTAC.
+        e3_ligase (str): The E3 ligase name, e.g., 'VHL' (mapped to a Uniprot ID via config).
+        target_uniprot (str): The Uniprot ID of the target protein.
+        cell_line (str): The cell line identifier.
+        device (str): The device to run the model on.
+        proba_threshold (float): The probability threshold.
+
+    Returns:
+        bool: Whether the PROTAC is active or not.
+    """
+    pred = get_protac_active_proba(
+        protac_smiles,
+        e3_ligase,
+        target_uniprot,
+        cell_line,
+        device,
+    )
+    # The model outputs a logit, so squash it before thresholding
+    return sigmoid(torch.tensor(pred)).item() > proba_threshold
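
A hypothetical end-to-end inference call; the SMILES is a placeholder and P00533 (EGFR) is used as an example target Uniprot ID:

protac_smiles = '...'  # placeholder PROTAC SMILES string
active = is_protac_active(
    protac_smiles,
    e3_ligase='VHL',
    target_uniprot='P00533',  # EGFR
    cell_line='HEK293',
    proba_threshold=0.5,
)
print('Active' if active else 'Inactive')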
protac_degradation_predictor/pytorch_models.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Literal, List, Tuple, Optional, Dict
|
| 3 |
+
|
| 4 |
+
from protac_dataset import PROTAC_Dataset
|
| 5 |
+
from config import Config
|
| 6 |
+
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import torch.optim as optim
|
| 13 |
+
import pytorch_lightning as pl
|
| 14 |
+
from torch.utils.data import Dataset, DataLoader
|
| 15 |
+
from torchmetrics import (
|
| 16 |
+
Accuracy,
|
| 17 |
+
AUROC,
|
| 18 |
+
Precision,
|
| 19 |
+
Recall,
|
| 20 |
+
F1Score,
|
| 21 |
+
MetricCollection,
|
| 22 |
+
)
|
| 23 |
+
from imblearn.over_sampling import SMOTE
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class PROTAC_Predictor(nn.Module):
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
hidden_dim: int,
|
| 31 |
+
smiles_emb_dim: int = Config.fingerprint_size,
|
| 32 |
+
poi_emb_dim: int = Config.protein_embedding_size,
|
| 33 |
+
e3_emb_dim: int = Config.protein_embedding_size,
|
| 34 |
+
cell_emb_dim: int = Config.cell_embedding_size,
|
| 35 |
+
dropout: float = 0.2,
|
| 36 |
+
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
|
| 37 |
+
disabled_embeddings: list = [],
|
| 38 |
+
):
|
| 39 |
+
""" Initialize the PROTAC model.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
hidden_dim (int): The hidden dimension of the model
|
| 43 |
+
smiles_emb_dim (int): The dimension of the SMILES embeddings
|
| 44 |
+
poi_emb_dim (int): The dimension of the POI embeddings
|
| 45 |
+
e3_emb_dim (int): The dimension of the E3 Ligase embeddings
|
| 46 |
+
cell_emb_dim (int): The dimension of the cell line embeddings
|
| 47 |
+
dropout (float): The dropout rate
|
| 48 |
+
join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings
|
| 49 |
+
disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
|
| 50 |
+
"""
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.poi_emb_dim = poi_emb_dim
|
| 53 |
+
self.e3_emb_dim = e3_emb_dim
|
| 54 |
+
self.cell_emb_dim = cell_emb_dim
|
| 55 |
+
self.smiles_emb_dim = smiles_emb_dim
|
| 56 |
+
self.hidden_dim = hidden_dim
|
| 57 |
+
self.join_embeddings = join_embeddings
|
| 58 |
+
self.disabled_embeddings = disabled_embeddings
|
| 59 |
+
# Set our init args as class attributes
|
| 60 |
+
self.__dict__.update(locals())
|
| 61 |
+
|
| 62 |
+
# Define "surrogate models" branches
|
| 63 |
+
if self.join_embeddings != 'beginning':
|
| 64 |
+
if 'poi' not in self.disabled_embeddings:
|
| 65 |
+
self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
|
| 66 |
+
if 'e3' not in self.disabled_embeddings:
|
| 67 |
+
self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
|
| 68 |
+
if 'cell' not in self.disabled_embeddings:
|
| 69 |
+
self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
|
| 70 |
+
if 'smiles' not in self.disabled_embeddings:
|
| 71 |
+
self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
|
| 72 |
+
|
| 73 |
+
# Define hidden dimension for joining layer
|
| 74 |
+
if self.join_embeddings == 'beginning':
|
| 75 |
+
joint_dim = smiles_emb_dim if 'smiles' not in self.disabled_embeddings else 0
|
| 76 |
+
joint_dim += poi_emb_dim if 'poi' not in self.disabled_embeddings else 0
|
| 77 |
+
joint_dim += e3_emb_dim if 'e3' not in self.disabled_embeddings else 0
|
| 78 |
+
joint_dim += cell_emb_dim if 'cell' not in self.disabled_embeddings else 0
|
| 79 |
+
elif self.join_embeddings == 'concat':
|
| 80 |
+
joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
|
| 81 |
+
elif self.join_embeddings == 'sum':
|
| 82 |
+
joint_dim = hidden_dim
|
| 83 |
+
|
| 84 |
+
self.fc0 = nn.Linear(joint_dim, joint_dim)
|
| 85 |
+
self.fc1 = nn.Linear(joint_dim, hidden_dim)
|
| 86 |
+
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
|
| 87 |
+
self.fc3 = nn.Linear(hidden_dim, 1)
|
| 88 |
+
|
| 89 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
|
| 93 |
+
embeddings = []
|
| 94 |
+
if self.join_embeddings == 'beginning':
|
| 95 |
+
if 'poi' not in self.disabled_embeddings:
|
| 96 |
+
embeddings.append(poi_emb)
|
| 97 |
+
if 'e3' not in self.disabled_embeddings:
|
| 98 |
+
embeddings.append(e3_emb)
|
| 99 |
+
if 'cell' not in self.disabled_embeddings:
|
| 100 |
+
embeddings.append(cell_emb)
|
| 101 |
+
if 'smiles' not in self.disabled_embeddings:
|
| 102 |
+
embeddings.append(smiles_emb)
|
| 103 |
+
x = torch.cat(embeddings, dim=1)
|
| 104 |
+
x = self.dropout(F.relu(self.fc0(x)))
|
| 105 |
+
else:
|
| 106 |
+
if 'poi' not in self.disabled_embeddings:
|
| 107 |
+
embeddings.append(self.poi_emb(poi_emb))
|
| 108 |
+
if 'e3' not in self.disabled_embeddings:
|
| 109 |
+
embeddings.append(self.e3_emb(e3_emb))
|
| 110 |
+
if 'cell' not in self.disabled_embeddings:
|
| 111 |
+
embeddings.append(self.cell_emb(cell_emb))
|
| 112 |
+
if 'smiles' not in self.disabled_embeddings:
|
| 113 |
+
embeddings.append(self.smiles_emb(smiles_emb))
|
| 114 |
+
if self.join_embeddings == 'concat':
|
| 115 |
+
x = torch.cat(embeddings, dim=1)
|
| 116 |
+
elif self.join_embeddings == 'sum':
|
| 117 |
+
if len(embeddings) > 1:
|
| 118 |
+
embeddings = torch.stack(embeddings, dim=1)
|
| 119 |
+
x = torch.sum(embeddings, dim=1)
|
| 120 |
+
else:
|
| 121 |
+
x = embeddings[0]
|
| 122 |
+
x = self.dropout(F.relu(self.fc1(x)))
|
| 123 |
+
x = self.dropout(F.relu(self.fc2(x)))
|
| 124 |
+
x = self.fc3(x)
|
| 125 |
+
return x
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class PROTAC_Model(pl.LightningModule):
|
| 130 |
+
|
| 131 |
+
def __init__(
|
| 132 |
+
self,
|
| 133 |
+
hidden_dim: int,
|
| 134 |
+
smiles_emb_dim: int = 224,
|
| 135 |
+
poi_emb_dim: int = 1024,
|
| 136 |
+
e3_emb_dim: int = 1024,
|
| 137 |
+
cell_emb_dim: int = 768,
|
| 138 |
+
batch_size: int = 32,
|
| 139 |
+
learning_rate: float = 1e-3,
|
| 140 |
+
dropout: float = 0.2,
|
| 141 |
+
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
|
| 142 |
+
train_dataset: PROTAC_Dataset = None,
|
| 143 |
+
val_dataset: PROTAC_Dataset = None,
|
| 144 |
+
test_dataset: PROTAC_Dataset = None,
|
| 145 |
+
disabled_embeddings: list = [],
|
| 146 |
+
apply_scaling: bool = False,
|
| 147 |
+
):
|
| 148 |
+
""" Initialize the PROTAC Pytorch Lightning model.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
hidden_dim (int): The hidden dimension of the model
|
| 152 |
+
smiles_emb_dim (int): The dimension of the SMILES embeddings
|
| 153 |
+
poi_emb_dim (int): The dimension of the POI embeddings
|
| 154 |
+
e3_emb_dim (int): The dimension of the E3 Ligase embeddings
|
| 155 |
+
cell_emb_dim (int): The dimension of the cell line embeddings
|
| 156 |
+
batch_size (int): The batch size
|
| 157 |
+
learning_rate (float): The learning rate
|
| 158 |
+
dropout (float): The dropout rate
|
| 159 |
+
join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings
|
| 160 |
+
train_dataset (PROTAC_Dataset): The training dataset
|
| 161 |
+
val_dataset (PROTAC_Dataset): The validation dataset
|
| 162 |
+
test_dataset (PROTAC_Dataset): The test dataset
|
| 163 |
+
disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
|
| 164 |
+
apply_scaling (bool): Whether to apply scaling to the embeddings
|
| 165 |
+
"""
|
| 166 |
+
super().__init__()
|
| 167 |
+
self.poi_emb_dim = poi_emb_dim
|
| 168 |
+
self.e3_emb_dim = e3_emb_dim
|
| 169 |
+
self.cell_emb_dim = cell_emb_dim
|
| 170 |
+
self.smiles_emb_dim = smiles_emb_dim
|
| 171 |
+
self.hidden_dim = hidden_dim
|
| 172 |
+
self.batch_size = batch_size
|
| 173 |
+
self.learning_rate = learning_rate
|
| 174 |
+
self.join_embeddings = join_embeddings
|
| 175 |
+
self.train_dataset = train_dataset
|
| 176 |
+
self.val_dataset = val_dataset
|
| 177 |
+
self.test_dataset = test_dataset
|
| 178 |
+
self.disabled_embeddings = disabled_embeddings
|
| 179 |
+
self.apply_scaling = apply_scaling
|
| 180 |
+
# Set our init args as class attributes
|
| 181 |
+
self.__dict__.update(locals()) # Add arguments as attributes
|
| 182 |
+
# Save the arguments passed to init
|
| 183 |
+
ignore_args_as_hyperparams = [
|
| 184 |
+
'train_dataset',
|
| 185 |
+
'test_dataset',
|
| 186 |
+
'val_dataset',
|
| 187 |
+
]
|
| 188 |
+
self.save_hyperparameters(ignore=ignore_args_as_hyperparams)
|
| 189 |
+
|
| 190 |
+
self.model = PROTAC_Predictor(
|
| 191 |
+
hidden_dim=hidden_dim,
|
| 192 |
+
smiles_emb_dim=smiles_emb_dim,
|
| 193 |
+
poi_emb_dim=poi_emb_dim,
|
| 194 |
+
e3_emb_dim=e3_emb_dim,
|
| 195 |
+
cell_emb_dim=cell_emb_dim,
|
| 196 |
+
dropout=dropout,
|
| 197 |
+
join_embeddings=join_embeddings,
|
| 198 |
+
disabled_embeddings=disabled_embeddings,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
stages = ['train_metrics', 'val_metrics', 'test_metrics']
|
| 202 |
+
self.metrics = nn.ModuleDict({s: MetricCollection({
|
| 203 |
+
'acc': Accuracy(task='binary'),
|
| 204 |
+
'roc_auc': AUROC(task='binary'),
|
| 205 |
+
'precision': Precision(task='binary'),
|
| 206 |
+
'recall': Recall(task='binary'),
|
| 207 |
+
'f1_score': F1Score(task='binary'),
|
| 208 |
+
'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
|
| 209 |
+
'hp_metric': Accuracy(task='binary'),
|
| 210 |
+
}, prefix=s.replace('metrics', '')) for s in stages})
|
| 211 |
+
|
| 212 |
+
# Misc settings
|
| 213 |
+
self.missing_dataset_error = \
|
| 214 |
+
'''Class variable `{0}` is None. If the model was loaded from a checkpoint, the dataset must be set manually:
|
| 215 |
+
|
| 216 |
+
model = {1}.load_from_checkpoint('checkpoint.ckpt')
|
| 217 |
+
model.{0} = my_{0}
|
| 218 |
+
'''
|
| 219 |
+
|
| 220 |
+
# Apply scaling in datasets
|
| 221 |
+
if self.apply_scaling:
|
| 222 |
+
use_single_scaler = True if self.join_embeddings == 'beginning' else False
|
| 223 |
+
self.scalers = self.train_dataset.fit_scaling(use_single_scaler)
|
| 224 |
+
self.train_dataset.apply_scaling(self.scalers, use_single_scaler)
|
| 225 |
+
self.val_dataset.apply_scaling(self.scalers, use_single_scaler)
|
| 226 |
+
if self.test_dataset:
|
| 227 |
+
self.test_dataset.apply_scaling(self.scalers, use_single_scaler)
|
| 228 |
+
|
| 229 |
+
def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
|
| 230 |
+
return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
|
| 231 |
+
|
| 232 |
+
def step(self, batch, batch_idx, stage):
|
| 233 |
+
poi_emb = batch['poi_emb']
|
| 234 |
+
e3_emb = batch['e3_emb']
|
| 235 |
+
cell_emb = batch['cell_emb']
|
| 236 |
+
smiles_emb = batch['smiles_emb']
|
| 237 |
+
y = batch['active'].float().unsqueeze(1)
|
| 238 |
+
|
| 239 |
+
y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
|
| 240 |
+
loss = F.binary_cross_entropy_with_logits(y_hat, y)
|
| 241 |
+
|
| 242 |
+
self.metrics[f'{stage}_metrics'].update(y_hat, y)
|
| 243 |
+
self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True)
|
| 244 |
+
self.log_dict(self.metrics[f'{stage}_metrics'], on_epoch=True)
|
| 245 |
+
|
| 246 |
+
return loss
|
| 247 |
+
|
| 248 |
+
def training_step(self, batch, batch_idx):
|
| 249 |
+
return self.step(batch, batch_idx, 'train')
|
| 250 |
+
|
| 251 |
+
def validation_step(self, batch, batch_idx):
|
| 252 |
+
return self.step(batch, batch_idx, 'val')
|
| 253 |
+
|
| 254 |
+
def test_step(self, batch, batch_idx):
|
| 255 |
+
return self.step(batch, batch_idx, 'test')
|
| 256 |
+
|
| 257 |
+
def configure_optimizers(self):
|
| 258 |
+
return optim.Adam(self.parameters(), lr=self.learning_rate)
|
| 259 |
+
|
| 260 |
+
def predict_step(self, batch, batch_idx):
|
| 261 |
+
poi_emb = batch['poi_emb']
|
| 262 |
+
e3_emb = batch['e3_emb']
|
| 263 |
+
cell_emb = batch['cell_emb']
|
| 264 |
+
smiles_emb = batch['smiles_emb']
|
| 265 |
+
|
| 266 |
+
if self.apply_scaling:
|
| 267 |
+
if self.join_embeddings == 'beginning':
|
| 268 |
+
embeddings = np.hstack([
|
| 269 |
+
np.array(smiles_emb.tolist()),
|
| 270 |
+
np.array(poi_emb.tolist()),
|
| 271 |
+
np.array(e3_emb.tolist()),
|
| 272 |
+
np.array(cell_emb.tolist()),
|
| 273 |
+
])
|
| 274 |
+
embeddings = self.scalers.transform(embeddings)
|
| 275 |
+
smiles_emb = embeddings[:, :self.smiles_emb_dim]
|
| 276 |
+
poi_emb = embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.poi_emb_dim]
|
| 277 |
+
e3_emb = embeddings[:, self.smiles_emb_dim+self.poi_emb_dim:self.smiles_emb_dim+2*self.poi_emb_dim]
|
| 278 |
+
cell_emb = embeddings[:, -self.cell_emb_dim:]
|
| 279 |
+
else:
|
| 280 |
+
poi_emb = self.scalers['Uniprot'].transform(poi_emb)
|
| 281 |
+
e3_emb = self.scalers['E3 Ligase Uniprot'].transform(e3_emb)
|
| 282 |
+
cell_emb = self.scalers['Cell Line Identifier'].transform(cell_emb)
|
| 283 |
+
smiles_emb = self.scalers['Smiles'].transform(smiles_emb)
|
| 284 |
+
|
| 285 |
+
y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
|
| 286 |
+
return torch.sigmoid(y_hat)
|
| 287 |
+
+    def train_dataloader(self):
+        if self.train_dataset is None:
+            format = 'train_dataset', self.__class__.__name__
+            raise ValueError(self.missing_dataset_error.format(*format))
+
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            # drop_last=True,
+        )
+
+    def val_dataloader(self):
+        if self.val_dataset is None:
+            format = 'val_dataset', self.__class__.__name__
+            raise ValueError(self.missing_dataset_error.format(*format))
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+        )
+
+    def test_dataloader(self):
+        if self.test_dataset is None:
+            format = 'test_dataset', self.__class__.__name__
+            raise ValueError(self.missing_dataset_error.format(*format))
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+        )
+
+
+def train_model(
+    protein2embedding: Dict,
+    cell2embedding: Dict,
+    smiles2fp: Dict,
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    test_df: Optional[pd.DataFrame] = None,
+    hidden_dim: int = 768,
+    batch_size: int = 8,
+    learning_rate: float = 2e-5,
+    dropout: float = 0.2,
+    max_epochs: int = 50,
+    smiles_emb_dim: int = 224,
+    join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
+    smote_k_neighbors: int = 5,
+    use_smote: bool = True,
+    apply_scaling: bool = False,
+    active_label: str = 'Active',
+    fast_dev_run: bool = False,
+    use_logger: bool = True,
+    logger_name: str = 'protac',
+    disabled_embeddings: List[str] = [],
+) -> tuple:
+    """ Train a PROTAC model using the given datasets and hyperparameters.
+
+    Args:
+        protein2embedding (dict): Dictionary of protein embeddings.
+        cell2embedding (dict): Dictionary of cell line embeddings.
+        smiles2fp (dict): Dictionary of SMILES to fingerprint.
+        train_df (pd.DataFrame): The training set. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>.
+        val_df (pd.DataFrame): The validation set. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>.
+        test_df (pd.DataFrame): The test set. If provided, the returned metrics will include test performance. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>.
+        hidden_dim (int): The hidden dimension of the model.
+        batch_size (int): The batch size.
+        learning_rate (float): The learning rate.
+        max_epochs (int): The maximum number of epochs.
+        smiles_emb_dim (int): The dimension of the SMILES embeddings.
+        smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
+        fast_dev_run (bool): Whether to run a fast development run.
+        disabled_embeddings (list): The list of disabled embeddings.
+
+    Returns:
+        tuple: The trained model, the trainer, and the metrics.
+    """
+    oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42)
+    train_ds = PROTAC_Dataset(
+        train_df,
+        protein2embedding,
+        cell2embedding,
+        smiles2fp,
+        use_smote=use_smote,
+        oversampler=oversampler if use_smote else None,
+        active_label=active_label,
+    )
+    val_ds = PROTAC_Dataset(
+        val_df,
+        protein2embedding,
+        cell2embedding,
+        smiles2fp,
+        active_label=active_label,
+    )
+    if test_df is not None:
+        test_ds = PROTAC_Dataset(
+            test_df,
+            protein2embedding,
+            cell2embedding,
+            smiles2fp,
+            active_label=active_label,
+        )
+    logger = pl.loggers.TensorBoardLogger(
+        save_dir='../logs',
+        name=logger_name,
+    )
+    callbacks = [
+        pl.callbacks.EarlyStopping(
+            monitor='train_loss',
+            patience=10,
+            mode='min',
+            verbose=False,
+        ),
+        pl.callbacks.EarlyStopping(
+            monitor='val_loss',
+            patience=5,
+            mode='min',
+            verbose=False,
+        ),
+        pl.callbacks.EarlyStopping(
+            monitor='val_acc',
+            patience=10,
+            mode='max',
+            verbose=False,
+        ),
+        # pl.callbacks.ModelCheckpoint(
+        #     monitor='val_acc',
+        #     mode='max',
+        #     verbose=True,
+        #     filename='{epoch}-{val_metrics_opt_score:.4f}',
+        # ),
+    ]
+    # Define Trainer
+    trainer = pl.Trainer(
+        logger=logger if use_logger else False,
+        callbacks=callbacks,
+        max_epochs=max_epochs,
+        fast_dev_run=fast_dev_run,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        devices=1,
+        num_nodes=1,
+    )
+    model = PROTAC_Model(
+        hidden_dim=hidden_dim,
+        smiles_emb_dim=smiles_emb_dim,
+        poi_emb_dim=1024,
+        e3_emb_dim=1024,
+        cell_emb_dim=768,
+        batch_size=batch_size,
+        join_embeddings=join_embeddings,
+        dropout=dropout,
+        learning_rate=learning_rate,
+        apply_scaling=apply_scaling,
+        train_dataset=train_ds,
+        val_dataset=val_ds,
+        test_dataset=test_ds if test_df is not None else None,
+        disabled_embeddings=disabled_embeddings,
+    )
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        trainer.fit(model)
+    metrics = trainer.validate(model, verbose=False)[0]
+    if test_df is not None:
+        test_metrics = trainer.test(model, verbose=False)[0]
+        metrics.update(test_metrics)
+    return model, trainer, metrics
+
+
+def load_model(
+    ckpt_path: str,
+) -> PROTAC_Model:
+    """ Load a PROTAC model from a checkpoint.
+
+    Args:
+        ckpt_path (str): The path to the checkpoint.
+
+    Returns:
+        PROTAC_Model: The loaded model.
+    """
+    model = PROTAC_Model.load_from_checkpoint(ckpt_path)
+    model.eval()
+    return model
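For a quick sanity check of the `train_model()` entry point above, a minimal sketch follows. The random embeddings, the toy dataframe, and the direct module import are illustrative assumptions, not part of this commit; only the embedding dimensions (1024 for proteins, 768 for cell lines, 224 for SMILES fingerprints) and the required dataframe columns come from the code above.

import numpy as np
import pandas as pd
from protac_degradation_predictor.pytorch_models import train_model

# Toy lookup tables matching the dimensions hard-coded in train_model().
protein2embedding = {'P04637': np.random.rand(1024), 'Q96SW2': np.random.rand(1024)}
cell2embedding = {'HEK293': np.random.rand(768)}
smiles2fp = {'CCO': np.random.rand(224), 'CCN': np.random.rand(224)}

# The columns follow the contract spelled out in the docstring.
df = pd.DataFrame({
    'Smiles': ['CCO', 'CCN'] * 4,
    'Uniprot': ['P04637'] * 8,
    'E3 Ligase Uniprot': ['Q96SW2'] * 8,
    'Cell Line Identifier': ['HEK293'] * 8,
    'Active': [True, False] * 4,
})

model, trainer, metrics = train_model(
    protein2embedding, cell2embedding, smiles2fp,
    train_df=df, val_df=df,
    use_smote=False,    # too few samples here for SMOTE's k neighbors
    use_logger=False,
    fast_dev_run=True,  # run a single batch just to validate the wiring
)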
protac_degradation_predictor/sklearn_models.py
ADDED
@@ -0,0 +1,243 @@
+from typing import Literal, List, Tuple, Optional, Dict
+
+from .protac_dataset import PROTAC_Dataset
+
+import pandas as pd
+from sklearn.base import ClassifierMixin
+from sklearn.ensemble import (
+    RandomForestClassifier,
+    GradientBoostingClassifier,
+)
+from sklearn.linear_model import LogisticRegression
+from sklearn.svm import SVC
+
+import torch
+import torch.nn as nn
+from torchmetrics import (
+    Accuracy,
+    AUROC,
+    Precision,
+    Recall,
+    F1Score,
+    MetricCollection,
+)
+import optuna
+
+
+def train_sklearn_model(
+    clf: ClassifierMixin,
+    protein2embedding: Dict,
+    cell2embedding: Dict,
+    smiles2fp: Dict,
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    test_df: Optional[pd.DataFrame] = None,
+    active_label: str = 'Active',
+    use_single_scaler: bool = True,
+) -> Tuple[ClassifierMixin, Dict]:
+    """ Train a classifier model on train and val sets and evaluate it on a test set.
+
+    Args:
+        clf: The classifier model to train and evaluate.
+        train_df (pd.DataFrame): The training set.
+        val_df (pd.DataFrame): The validation set.
+        test_df (Optional[pd.DataFrame]): The test set.
+
+    Returns:
+        Tuple[ClassifierMixin, Dict]: The trained model and the metrics.
+    """
+    # Initialize the datasets
+    train_ds = PROTAC_Dataset(
+        train_df,
+        protein2embedding,
+        cell2embedding,
+        smiles2fp,
+        active_label=active_label,
+        use_smote=False,
+    )
+    scaler = train_ds.fit_scaling(use_single_scaler=use_single_scaler)
+    train_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
+    val_ds = PROTAC_Dataset(
+        val_df,
+        protein2embedding,
+        cell2embedding,
+        smiles2fp,
+        active_label=active_label,
+        use_smote=False,
+    )
+    val_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
+    if test_df is not None:
+        test_ds = PROTAC_Dataset(
+            test_df,
+            protein2embedding,
+            cell2embedding,
+            smiles2fp,
+            active_label=active_label,
+            use_smote=False,
+        )
+        test_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
+
+    # Get the numpy arrays
+    X_train, y_train = train_ds.get_numpy_arrays()
+    X_val, y_val = val_ds.get_numpy_arrays()
+    if test_df is not None:
+        X_test, y_test = test_ds.get_numpy_arrays()
+
+    # Train the model
+    clf.fit(X_train, y_train)
+    # Define the metrics as a module dict
+    stages = ['train_metrics', 'val_metrics', 'test_metrics']
+    metrics = nn.ModuleDict({s: MetricCollection({
+        'acc': Accuracy(task='binary'),
+        'roc_auc': AUROC(task='binary'),
+        'precision': Precision(task='binary'),
+        'recall': Recall(task='binary'),
+        'f1_score': F1Score(task='binary'),
+        'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
+        'hp_metric': Accuracy(task='binary'),
+    }, prefix=s.replace('metrics', '')) for s in stages})
+
+    # Get the predictions
+    metrics_out = {}
+
+    y_pred = torch.tensor(clf.predict_proba(X_train)[:, 1])
+    y_true = torch.tensor(y_train)
+    metrics['train_metrics'].update(y_pred, y_true)
+    metrics_out.update(metrics['train_metrics'].compute())
+
+    y_pred = torch.tensor(clf.predict_proba(X_val)[:, 1])
+    y_true = torch.tensor(y_val)
+    metrics['val_metrics'].update(y_pred, y_true)
+    metrics_out.update(metrics['val_metrics'].compute())
+
+    if test_df is not None:
+        y_pred = torch.tensor(clf.predict_proba(X_test)[:, 1])
+        y_true = torch.tensor(y_test)
+        metrics['test_metrics'].update(y_pred, y_true)
+        metrics_out.update(metrics['test_metrics'].compute())
+
+    return clf, metrics_out
+
+
+def suggest_random_forest(
+    trial: optuna.Trial,
+) -> ClassifierMixin:
+    """ Suggest hyperparameters for a Random Forest classifier.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+
+    Returns:
+        ClassifierMixin: The Random Forest classifier with the suggested hyperparameters.
+    """
+    n_estimators = trial.suggest_int('model_n_estimators', 10, 1000)
+    max_depth = trial.suggest_int('model_max_depth', 2, 100)
+    min_samples_split = trial.suggest_int('model_min_samples_split', 2, 10)
+    min_samples_leaf = trial.suggest_int('model_min_samples_leaf', 1, 10)
+    max_features = trial.suggest_categorical('model_max_features', [None, 'sqrt', 'log2'])
+    criterion = trial.suggest_categorical('model_criterion', ['gini', 'entropy'])
+
+    clf = RandomForestClassifier(
+        n_estimators=n_estimators,
+        max_depth=max_depth,
+        min_samples_split=min_samples_split,
+        min_samples_leaf=min_samples_leaf,
+        max_features=max_features,
+        criterion=criterion,
+        random_state=42,
+    )
+
+    return clf
+
+
+def suggest_logistic_regression(
+    trial: optuna.Trial,
+) -> ClassifierMixin:
+    """ Suggest hyperparameters for a Logistic Regression classifier.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+
+    Returns:
+        ClassifierMixin: The Logistic Regression classifier with the suggested hyperparameters.
+    """
+    # Suggest values for the logistic regression hyperparameters
+    C = trial.suggest_loguniform('model_C', 1e-4, 1e2)
+    penalty = trial.suggest_categorical('model_penalty', ['l1', 'l2', 'elasticnet', None])
+    solver = trial.suggest_categorical('model_solver', ['newton-cholesky', 'lbfgs', 'liblinear', 'sag', 'saga'])
+
+    # Check solver compatibility
+    if penalty == 'l1' and solver not in ['liblinear', 'saga']:
+        raise optuna.exceptions.TrialPruned()
+    if penalty is None and solver not in ['newton-cholesky', 'lbfgs', 'sag']:
+        raise optuna.exceptions.TrialPruned()
+
+    # Configure the classifier with the trial's suggested parameters
+    clf = LogisticRegression(
+        C=C,
+        penalty=penalty,
+        solver=solver,
+        max_iter=1000,
+        random_state=42,
+    )
+
+    return clf
+
+
+def suggest_svc(
+    trial: optuna.Trial,
+) -> ClassifierMixin:
+    """ Suggest hyperparameters for an SVC classifier.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+
+    Returns:
+        ClassifierMixin: The SVC classifier with the suggested hyperparameters.
+    """
+    C = trial.suggest_loguniform('model_C', 1e-4, 1e2)
+    kernel = trial.suggest_categorical('model_kernel', ['linear', 'poly', 'rbf', 'sigmoid'])
+    gamma = trial.suggest_categorical('model_gamma', ['scale', 'auto'])
+    degree = trial.suggest_int('model_degree', 2, 5) if kernel == 'poly' else 3
+
+    clf = SVC(
+        C=C,
+        kernel=kernel,
+        gamma=gamma,
+        degree=degree,
+        probability=True,
+        random_state=42,
+    )
+
+    return clf
+
+
+def suggest_gradient_boosting(
+    trial: optuna.Trial,
+) -> ClassifierMixin:
+    """ Suggest hyperparameters for a Gradient Boosting classifier.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+
+    Returns:
+        ClassifierMixin: The Gradient Boosting classifier with the suggested hyperparameters.
+    """
+    n_estimators = trial.suggest_int('model_n_estimators', 50, 500)
+    learning_rate = trial.suggest_loguniform('model_learning_rate', 0.01, 1)
+    max_depth = trial.suggest_int('model_max_depth', 3, 10)
+    min_samples_split = trial.suggest_int('model_min_samples_split', 2, 10)
+    min_samples_leaf = trial.suggest_int('model_min_samples_leaf', 1, 10)
+    max_features = trial.suggest_categorical('model_max_features', ['sqrt', 'log2', None])
+
+    clf = GradientBoostingClassifier(
+        n_estimators=n_estimators,
+        learning_rate=learning_rate,
+        max_depth=max_depth,
+        min_samples_split=min_samples_split,
+        min_samples_leaf=min_samples_leaf,
+        max_features=max_features,
+        random_state=42,
+    )
+
+    return clf
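The suggest_* helpers above are meant to be paired with train_sklearn_model inside an Optuna objective. A sketch of that wiring follows; the embedding dictionaries and dataframes are assumed to be prepared as for the PyTorch models, and the 'val_acc' key follows the prefix scheme used in train_sklearn_model ('val_metrics' -> 'val_').

import optuna

def objective(trial: optuna.Trial) -> float:
    # Swap in suggest_logistic_regression, suggest_svc or
    # suggest_gradient_boosting to tune a different model family.
    clf = suggest_random_forest(trial)
    _, metrics = train_sklearn_model(
        clf,
        protein2embedding, cell2embedding, smiles2fp,  # assumed to be loaded
        train_df, val_df,                              # assumed to be prepared
    )
    return metrics['val_acc'].item()  # compute() returns tensors

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)
print(study.best_params)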
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_0_test_split_0.1.pkl
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9d2f036ce141fbeb81930cc9ce49dbd6effc76221b26b92ae0498af1c34289f3
 size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_1_test_split_0.1.pkl
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f8ce36b5f52f8f88105c3ec0c5b60f865e1b054aff8f9e96c21f1e037eaa65af
 size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_2_test_split_0.1.pkl
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:d22f4f54d46ca72b8585645fdfac43683a23dcc00d80fb8bd1f785d4eb4a9594
 size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_2_test_split_0.2.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f7298a04d0888f4de87041efd6b78c42e13c3f1630c43567d582bc7710a40847
-size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_3_test_split_0.1.pkl
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:54f869b328af6667567bd5cc805ce63fc5434ed1b77afc1e66d95b8f02e40642
 size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_3_test_split_0.2.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d2a7a9f6ed11e1b5b6f876dd927612d03f4780f9db3e65b9f1ebb8fbd853677f
-size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_4_test_split_0.1.pkl
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:27f0e76e7f89950199c843699c000eaad8628441c84aec394e20c23f701b1609
 size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_4_test_split_0.2.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:fd4c40033da16a1cee16fd998c82e0403a31db4d14ba0604160ea143bae03668
-size 45164
setup.py
ADDED
@@ -0,0 +1,21 @@
+import setuptools
+
+setuptools.setup(
+    name="protac_degradation_predictor",
+    version="0.0.1",
+    author="Stefano Ribes",
+    url="https://github.com/ribesstefano/PROTAC-Degradation-Predictor",
+    author_email="[email protected]",
+    description="A package to predict PROTAC-induced protein degradation.",
+    long_description=open("README.md").read(),
+    packages=setuptools.find_packages(),
+    install_requires=["torch", "pytorch_lightning", "scikit-learn", "imbalanced-learn", "pandas", "joblib", "h5py", "optuna", "torchmetrics"],
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.6",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+    ],
+    include_package_data=True,
+    package_data={"": ["data/*.h5", "data/*.pkl", "data/*.csv"]},
+)
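Since package_data ships the embedding files alongside the code, a quick post-install check can confirm they are resolvable at runtime. A sketch, assuming an editable install (`pip install -e .`) and Python 3.9+ for importlib.resources.files:

import importlib.resources as resources

# List the data files bundled by the package_data declaration above.
data_dir = resources.files('protac_degradation_predictor') / 'data'
print(sorted(p.name for p in data_dir.iterdir()))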