Commit 33f1644
Parent(s): 4bf0ec2

Added XGBoost Optuna training + Added Ablation studies with zeroed input vectors

Files changed:
- protac_degradation_predictor/__init__.py (+3, -1)
- protac_degradation_predictor/optuna_utils.py (+12, -0)
- protac_degradation_predictor/optuna_utils_xgboost.py (+323, -0)
- protac_degradation_predictor/protac_dataset.py (+31, -4)
- reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_random.csv (+29, -0)
- reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_tanimoto.csv (+29, -0)
- reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_uniprot.csv (+29, -0)
- src/run_xgboost_experiments.py (+329, -0)
protac_degradation_predictor/__init__.py
CHANGED

@@ -17,7 +17,9 @@ from .sklearn_models import (
 )
 from .optuna_utils import (
     hyperparameter_tuning_and_training,
-
+)
+from .optuna_utils_xgboost import (
+    xgboost_hyperparameter_tuning_and_training,
 )
 from .protac_degradation_predictor import (
     get_protac_active_proba,
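After this change the XGBoost entry point is re-exported at the package root alongside the existing PyTorch one, so callers can reach it without importing the submodule directly. A minimal check of the exported names (nothing here beyond what the diff above declares):

import protac_degradation_predictor as pdp

# Both tuning entry points are importable from the package root after this commit.
print(callable(pdp.hyperparameter_tuning_and_training))          # PyTorch models
print(callable(pdp.xgboost_hyperparameter_tuning_and_training))  # XGBoost models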
protac_degradation_predictor/optuna_utils.py
CHANGED

@@ -234,6 +234,18 @@ def pytorch_model_objective(
 
     # Optuna aims to minimize the pytorch_model_objective
     return - val_roc_auc
+    # # Get the majority vote for the test predictions
+    # if test_df is not None and not fast_dev_run:
+    #     majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
+    #     majority_vote_metrics.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
+    #     trial.set_user_attr('majority_vote_metrics', majority_vote_metrics)
+    #     logging.info(f'Majority vote metrics: {majority_vote_metrics}')
+
+    # # Get the average validation accuracy and ROC AUC accross the folds
+    # val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
+
+    # # Optuna aims to minimize the pytorch_model_objective
+    # return - val_roc_auc
 
 
 def hyperparameter_tuning_and_training(
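The surviving objective keeps the convention of returning the negated validation ROC AUC so that a study created with direction='minimize' effectively maximizes AUC; the lines added here only preserve an older majority-vote variant as comments. A small generic sketch of the two equivalent ways to express this in Optuna (not tied to this repository; the toy objective below is purely illustrative):

import optuna

def objective(trial: optuna.Trial) -> float:
    x = trial.suggest_float('x', 0.0, 1.0)
    val_roc_auc = 1.0 - (x - 0.7) ** 2   # stand-in for a model's validation ROC AUC
    return -val_roc_auc                  # negate because the study minimizes

study = optuna.create_study(direction='minimize', sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective, n_trials=20)
print('best negated AUC:', study.best_value)

# Equivalent formulation: return the AUC itself and maximize it instead.
study_max = optuna.create_study(direction='maximize')
study_max.optimize(lambda t: 1.0 - (t.suggest_float('x', 0.0, 1.0) - 0.7) ** 2, n_trials=20)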
protac_degradation_predictor/optuna_utils_xgboost.py
ADDED

@@ -0,0 +1,323 @@
+from typing import Optional, Dict
+import logging
+import os
+
+from .optuna_utils import get_majority_vote_metrics, get_dataframe_stats
+from .protac_dataset import get_datasets
+
+import optuna
+import xgboost as xgb
+import pandas as pd
+import numpy as np
+from sklearn.model_selection import StratifiedKFold
+from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
+import xgboost as xgb
+import pandas as pd
+import numpy as np
+from sklearn.model_selection import StratifiedKFold
+from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
+import joblib
+from optuna.samplers import TPESampler
+import torch
+
+
+xgb.set_config(verbosity=0)
+
+
+def train_and_evaluate_xgboost(
+        protein2embedding: Dict,
+        cell2embedding: Dict,
+        smiles2fp: Dict,
+        train_df: pd.DataFrame,
+        val_df: pd.DataFrame,
+        params: dict,
+        test_df: Optional[pd.DataFrame] = None,
+        active_label: str = 'Active',
+        num_boost_round: int = 100,
+        shuffle_train_data: bool = False,
+) -> tuple:
+    """
+    Train and evaluate an XGBoost model with the given parameters.
+
+    Args:
+        train_df (pd.DataFrame): The training and validation data.
+        test_df (pd.DataFrame): The test data.
+        params (dict): Hyperparameters for the XGBoost model.
+        active_label (str): The active label column.
+        num_boost_round (int): Maximum number of epochs.
+
+    Returns:
+        tuple: The trained model, test predictions, and metrics.
+    """
+    # Get datasets and their numpy arrays
+    train_ds, val_ds, test_ds = get_datasets(
+        protein2embedding=protein2embedding,
+        cell2embedding=cell2embedding,
+        smiles2fp=smiles2fp,
+        train_df=train_df,
+        val_df=val_df,
+        test_df=test_df,
+        disabled_embeddings=[],
+        active_label=active_label,
+        apply_scaling=False,
+    )
+    X_train, y_train = train_ds.get_numpy_arrays()
+    X_val, y_val = val_ds.get_numpy_arrays()
+
+    # Shuffle the training data
+    if shuffle_train_data:
+        idx = np.random.permutation(len(X_train))
+        X_train, y_train = X_train[idx], y_train[idx]
+
+    # Setup training and validation data in XGBoost data format
+    dtrain = xgb.DMatrix(X_train, label=y_train)
+    dval = xgb.DMatrix(X_val, label=y_val)
+    evallist = [(dval, 'eval'), (dtrain, 'train')]
+
+    # Setup test data
+    if test_df is not None:
+        X_test, y_test = test_ds.get_numpy_arrays()
+        dtest = xgb.DMatrix(X_test, label=y_test)
+        evallist.append((dtest, 'test'))
+
+    model = xgb.train(
+        params,
+        dtrain,
+        num_boost_round=num_boost_round,
+        evals=evallist,
+        early_stopping_rounds=10,
+        verbose_eval=False,
+    )
+
+    # Evaluate model
+    val_pred = model.predict(dval)
+    val_pred_binary = (val_pred > 0.5).astype(int)
+    metrics = {
+        'val_accuracy': accuracy_score(y_val, val_pred_binary),
+        'val_roc_auc': roc_auc_score(y_val, val_pred),
+        'val_precision': precision_score(y_val, val_pred_binary),
+        'val_recall': recall_score(y_val, val_pred_binary),
+        'val_f1_score': f1_score(y_val, val_pred_binary),
+    }
+    preds = {'val_pred': val_pred}
+
+    if test_df is not None:
+        test_pred = model.predict(dtest)
+        test_pred_binary = (test_pred > 0.5).astype(int)
+        metrics.update({
+            'test_accuracy': accuracy_score(y_test, test_pred_binary),
+            'test_roc_auc': roc_auc_score(y_test, test_pred),
+            'test_precision': precision_score(y_test, test_pred_binary),
+            'test_recall': recall_score(y_test, test_pred_binary),
+            'test_f1_score': f1_score(y_test, test_pred_binary),
+        })
+        preds.update({'test_pred': test_pred})
+
+    return model, preds, metrics
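The helper above follows the standard xgb.train early-stopping pattern: the splits are wrapped in DMatrix objects, the validation set drives early stopping, and predicted probabilities are thresholded at 0.5 for the classification metrics. A minimal, self-contained sketch of that same pattern on synthetic data (illustrative only, not part of the commit; shapes and parameter values are arbitrary):

import numpy as np
import xgboost as xgb
from sklearn.metrics import roc_auc_score

# Synthetic binary-classification data standing in for the PROTAC feature matrices.
rng = np.random.default_rng(42)
X_train, y_train = rng.normal(size=(500, 32)), rng.integers(0, 2, 500)
X_val, y_val = rng.normal(size=(100, 32)), rng.integers(0, 2, 100)

dtrain = xgb.DMatrix(X_train, label=y_train)
dval = xgb.DMatrix(X_val, label=y_val)

params = {'objective': 'binary:logistic', 'eval_metric': 'auc', 'eta': 0.05, 'max_depth': 5}
model = xgb.train(
    params,
    dtrain,
    num_boost_round=100,
    evals=[(dval, 'eval'), (dtrain, 'train')],
    early_stopping_rounds=10,   # stop once the 'eval' AUC stops improving
    verbose_eval=False,
)

val_pred = model.predict(dval)                   # probabilities in [0, 1]
val_pred_binary = (val_pred > 0.5).astype(int)   # hard labels at a 0.5 threshold
print('val ROC AUC:', roc_auc_score(y_val, val_pred))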
+def xgboost_model_objective(
+        trial: optuna.Trial,
+        protein2embedding: Dict,
+        cell2embedding: Dict,
+        smiles2fp: Dict,
+        train_val_df: pd.DataFrame,
+        kf: StratifiedKFold,
+        groups: Optional[np.array] = None,
+        active_label: str = 'Active',
+        num_boost_round: int = 100,
+) -> float:
+    """ Objective function for hyperparameter optimization with XGBoost.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+        train_val_df (pd.DataFrame): The training and validation data.
+        kf (StratifiedKFold): Stratified K-Folds cross-validator.
+        test_df (Optional[pd.DataFrame]): The test data.
+        active_label (str): The active label column.
+        num_boost_round (int): Maximum number of epochs.
+        use_logger (bool): Whether to use logging.
+    """
+    # Suggest hyperparameters to be used across the CV folds
+    params = {
+        'booster': 'gbtree',
+        'tree_method': 'hist',  # if torch.cuda.is_available() else 'hist',
+        'objective': 'binary:logistic',
+        'eval_metric': 'auc',
+        'eta': trial.suggest_float('eta', 1e-4, 1e-1, log=True),
+        'max_depth': trial.suggest_int('max_depth', 3, 10),
+        'min_child_weight': trial.suggest_float('min_child_weight', 1e-3, 10.0, log=True),
+        'gamma': trial.suggest_float('gamma', 1e-4, 1e-1, log=True),
+        'subsample': trial.suggest_float('subsample', 0.5, 1.0),
+        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
+    }
+
+    X = train_val_df.copy().drop(columns=active_label)
+    y = train_val_df[active_label].tolist()
+    report = []
+    val_preds = []
+
+    for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
+        logging.info(f'Fold {k + 1}/{kf.get_n_splits()}')
+        train_df = train_val_df.iloc[train_index]
+        val_df = train_val_df.iloc[val_index]
+
+        # Get some statistics from the dataframes
+        stats = {
+            'model_type': 'XGBoost',
+            'fold': k,
+            'train_len': len(train_df),
+            'val_len': len(val_df),
+            'train_perc': len(train_df) / len(train_val_df),
+            'val_perc': len(val_df) / len(train_val_df),
+        }
+        stats.update(get_dataframe_stats(train_df, val_df, active_label=active_label))
+        if groups is not None:
+            stats['train_unique_groups'] = len(np.unique(groups[train_index]))
+            stats['val_unique_groups'] = len(np.unique(groups[val_index]))
+
+        _, preds, metrics = train_and_evaluate_xgboost(
+            protein2embedding=protein2embedding,
+            cell2embedding=cell2embedding,
+            smiles2fp=smiles2fp,
+            train_df=train_df,
+            val_df=val_df,
+            params=params,
+            active_label=active_label,
+            num_boost_round=num_boost_round,
+        )
+        stats.update(metrics)
+        report.append(stats.copy())
+        val_preds.append(preds['val_pred'])
+
+    # Save the report in the trial
+    trial.set_user_attr('report', report)
+    trial.set_user_attr('val_preds', val_preds)
+    trial.set_user_attr('params', params)
+
+    # Get the average validation metrics across the folds
+    mean_val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
+    logging.info(f'\tMean val ROC AUC: {mean_val_roc_auc:.4f}')
+
+    # Optuna aims to minimize the objective, so return the negative ROC AUC
+    return -mean_val_roc_auc
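The objective evaluates one sampled hyperparameter set across all CV folds and reports the mean validation ROC AUC (negated, since the study minimizes). The groups argument only has an effect when the cross-validator is group-aware. A small scikit-learn-only illustration of the difference on synthetic data (not from the commit):

import numpy as np
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold

y = np.array([0, 1] * 10)              # 20 samples with balanced labels
groups = np.repeat(np.arange(5), 4)    # 5 groups of 4 samples each
X = np.zeros((20, 3))                  # features are irrelevant for the split itself

# StratifiedKFold ignores `groups`: one group can end up in both folds.
skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)
# StratifiedGroupKFold keeps each group entirely on one side of the split.
sgkf = StratifiedGroupKFold(n_splits=2, shuffle=True, random_state=42)

for name, cv in [('StratifiedKFold', skf), ('StratifiedGroupKFold', sgkf)]:
    for k, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
        overlap = set(groups[train_idx]) & set(groups[val_idx])
        print(f'{name} fold {k}: {len(overlap)} groups appear on both sides')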
+def xgboost_hyperparameter_tuning_and_training(
+        protein2embedding: Dict,
+        cell2embedding: Dict,
+        smiles2fp: Dict,
+        train_val_df: pd.DataFrame,
+        test_df: pd.DataFrame,
+        kf: StratifiedKFold,
+        groups: Optional[np.array] = None,
+        split_type: str = 'random',
+        n_models_for_test: int = 3,
+        n_trials: int = 50,
+        active_label: str = 'Active',
+        num_boost_round: int = 100,
+        study_filename: Optional[str] = None,
+        force_study: bool = False,
+) -> dict:
+    """ Hyperparameter tuning and training of an XGBoost model.
+
+    Args:
+        train_val_df (pd.DataFrame): The training and validation data.
+        test_df (pd.DataFrame): The test data.
+        kf (StratifiedKFold): Stratified K-Folds cross-validator.
+        groups (Optional[np.array]): Group labels for the samples used while splitting the dataset into train/test set.
+        split_type (str): Type of the data split.
+        n_models_for_test (int): Number of models to train for testing.
+        fast_dev_run (bool): Whether to run a fast development run.
+        n_trials (int): Number of trials for hyperparameter optimization.
+        logger_save_dir (str): Directory to save logs.
+        logger_name (str): Name of the logger.
+        active_label (str): The active label column.
+        num_boost_round (int): Maximum number of epochs.
+        study_filename (Optional[str]): File name to save/load the Optuna study.
+        force_study (bool): Whether to force the study optimization even if the study file exists.
+
+    Returns:
+        dict: A dictionary containing reports from the CV and test.
+    """
+    # Set the verbosity of Optuna
+    optuna.logging.set_verbosity(optuna.logging.WARNING)
+
+    # Create an Optuna study object
+    sampler = TPESampler(seed=42)
+    study = optuna.create_study(direction='minimize', sampler=sampler)
+
+    study_loaded = False
+    if study_filename and not force_study:
+        if os.path.exists(study_filename):
+            study = joblib.load(study_filename)
+            study_loaded = True
+            logging.info(f'Loaded study from {study_filename}')
+
+    if not study_loaded or force_study:
+        study.optimize(
+            lambda trial: xgboost_model_objective(
+                trial=trial,
+                protein2embedding=protein2embedding,
+                cell2embedding=cell2embedding,
+                smiles2fp=smiles2fp,
+                train_val_df=train_val_df,
+                kf=kf,
+                groups=groups,
+                active_label=active_label,
+                num_boost_round=num_boost_round,
+            ),
+            n_trials=n_trials,
+        )
+        if study_filename:
+            joblib.dump(study, study_filename)
+
+    cv_report = pd.DataFrame(study.best_trial.user_attrs['report'])
+    hparam_report = pd.DataFrame([study.best_params])
+
+    # Retrain N models with the best hyperparameters (measure model uncertainty)
+    best_models = []
+    test_report = []
+    test_preds = []
+    for i in range(n_models_for_test):
+        logging.info(f'Training best model {i + 1}/{n_models_for_test}')
+        model, preds, metrics = train_and_evaluate_xgboost(
+            protein2embedding=protein2embedding,
+            cell2embedding=cell2embedding,
+            smiles2fp=smiles2fp,
+            train_df=train_val_df,
+            val_df=test_df,
+            params=study.best_trial.user_attrs['params'],
+            active_label=active_label,
+            num_boost_round=num_boost_round,
+            shuffle_train_data=True,
+        )
+        metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
+        metrics['model_type'] = 'XGBoost'
+        metrics['test_model_id'] = i
+        metrics.update(get_dataframe_stats(
+            train_val_df,
+            test_df=test_df,
+            active_label=active_label,
+        ))
+        test_report.append(metrics.copy())
+        test_preds.append(torch.tensor(preds['val_pred']))
+        best_models.append(model)
+    test_report = pd.DataFrame(test_report)
+
+    # Get the majority vote for the test predictions
+    majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
+    majority_vote_report = pd.DataFrame([majority_vote_metrics])
+    majority_vote_report['model_type'] = 'XGBoost'
+
+    # Add a column with the split_type to all reports
+    for report in [cv_report, hparam_report, test_report, majority_vote_report]:
+        report['split_type'] = split_type
+
+    # Return the reports
+    return {
+        'cv_report': cv_report,
+        'hparam_report': hparam_report,
+        'test_report': test_report,
+        'majority_vote_report': majority_vote_report,
+    }
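The function above caches the Optuna study on disk via study_filename/force_study, so repeated runs can skip the expensive search. A standalone sketch of that caching pattern (generic toy objective; the file name is a hypothetical example, not one used by the repository):

import os
import joblib
import optuna

study_filename = 'example_study.pkl'      # hypothetical path, mirroring study_filename above

if os.path.exists(study_filename):
    study = joblib.load(study_filename)   # reuse a previous search instead of re-optimizing
else:
    study = optuna.create_study(direction='minimize', sampler=optuna.samplers.TPESampler(seed=42))
    study.optimize(lambda t: t.suggest_float('x', -1.0, 1.0) ** 2, n_trials=10)
    joblib.dump(study, study_filename)    # persist for the next run

print(len(study.trials), study.best_params)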
protac_degradation_predictor/protac_dataset.py
CHANGED

@@ -42,7 +42,11 @@ class PROTAC_Dataset(Dataset):
            cell2embedding (dict): Dictionary of cell line embeddings
            smiles2fp (dict): Dictionary of SMILES to fingerprint
            use_smote (bool): Whether to use SMOTE for oversampling
-
+            oversampler (SMOTE | ADASYN): The oversampler to use
+            active_label (str): The column containing the active/inactive information
+            disabled_embeddings (list): The list of embeddings to disable, i.e., return a zero vector
+            scaler (StandardScaler | dict): The scaler to use for the embeddings
+            use_single_scaler (bool): Whether to use a single scaler for all features
         """
         # Filter out examples with NaN in active_label column
         self.data = protac_df  # [~protac_df[active_label].isna()]

@@ -124,7 +128,7 @@
         self.data = df_smote
 
     def fit_scaling(self, use_single_scaler: bool = False, **scaler_kwargs) -> dict:
-        """ Fit the scalers for the data.
+        """ Fit the scalers for the data and save them in the dataset class.
 
         Args:
             use_single_scaler (bool): Whether to use a single scaler for all features.

@@ -288,8 +292,25 @@ def get_datasets(
     disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
     scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
     use_single_scaler: Optional[bool] = None,
+    apply_scaling: bool = False,
 ) -> Tuple[PROTAC_Dataset, PROTAC_Dataset, Optional[PROTAC_Dataset]]:
-    """ Get the datasets for training the PROTAC model.
+    """ Get the datasets for training the PROTAC model.
+
+    Args:
+        train_df (pd.DataFrame): The training data.
+        val_df (pd.DataFrame): The validation data.
+        test_df (pd.DataFrame): The test data.
+        protein2embedding (dict): Dictionary of protein embeddings.
+        cell2embedding (dict): Dictionary of cell line embeddings.
+        smiles2fp (dict): Dictionary of SMILES to fingerprint.
+        use_smote (bool): Whether to use SMOTE for oversampling.
+        smote_k_neighbors (int): The number of neighbors to use for SMOTE.
+        active_label (str): The active label column.
+        disabled_embeddings (list): The list of embeddings to disable.
+        scaler (StandardScaler | dict): The scaler to use for the embeddings.
+        use_single_scaler (bool): Whether to use a single scaler for all features.
+        apply_scaling (bool): Whether to apply scaling to the data now. Defaults to False (the Pytorch Lightning model does that).
+    """
     oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42)
     train_ds = PROTAC_Dataset(
         train_df,

@@ -313,6 +334,10 @@ def get_datasets(
         scaler=train_ds.scaler if train_ds.scaler is not None else scaler,
         use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler,
     )
+    train_scalers = None
+    if apply_scaling:
+        train_scalers = train_ds.fit_scaling(use_single_scaler=use_single_scaler)
+        val_ds.apply_scaling(train_scalers, use_single_scaler=use_single_scaler)
     if test_df is not None:
         test_ds = PROTAC_Dataset(
             test_df,

@@ -321,9 +346,11 @@ def get_datasets(
             smiles2fp,
             active_label=active_label,
             disabled_embeddings=disabled_embeddings,
-            scaler=
+            scaler=train_scalers if apply_scaling else scaler,
             use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler,
         )
+        if apply_scaling:
+            test_ds.apply_scaling(train_ds.scaler, use_single_scaler=use_single_scaler)
     else:
         test_ds = None
     return train_ds, val_ds, test_ds
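The new apply_scaling flag lets get_datasets perform feature scaling itself, which the XGBoost path needs because it is not wrapped in the PyTorch Lightning pipeline. The intent reflected in the diff is the usual leakage-free pattern: fit scalers on the training split only and reuse them for validation and test. A generic scikit-learn sketch of that pattern, not tied to the PROTAC_Dataset API:

import numpy as np
from sklearn.preprocessing import StandardScaler

# Illustrative feature matrices standing in for the concatenated PROTAC embeddings.
X_train = np.random.rand(100, 8)
X_val = np.random.rand(20, 8)
X_test = np.random.rand(20, 8)

scaler = StandardScaler().fit(X_train)    # fit on the training split only
X_train_scaled = scaler.transform(X_train)
X_val_scaled = scaler.transform(X_val)    # reuse the training statistics
X_test_scaled = scaler.transform(X_test)  # no information leaks from val/test into the scaler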
reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_random.csv
ADDED

@@ -0,0 +1,29 @@
+test_loss,test_acc,test_f1_score,test_precision,test_recall,test_roc_auc,train_len,train_active_perc,train_inactive_perc,train_avg_tanimoto_dist,test_len,test_active_perc,test_inactive_perc,test_avg_tanimoto_dist,num_leaking_uniprot_train_test,num_leaking_smiles_train_test,perc_leaking_uniprot_train_test,perc_leaking_smiles_train_test,majority_vote,model_type,disabled_embeddings,test_f1,split_type
[28 data rows not reproduced here: for each zeroed-embedding configuration (disabled e3, poi, cell, smiles, e3 cell, poi e3, poi e3 cell) there are three per-model test rows followed by one majority-vote row; model_type is Pytorch and split_type is random throughout.]
reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_tanimoto.csv
ADDED

@@ -0,0 +1,29 @@
+test_loss,test_acc,test_f1_score,test_precision,test_recall,test_roc_auc,train_len,train_active_perc,train_inactive_perc,train_avg_tanimoto_dist,test_len,test_active_perc,test_inactive_perc,test_avg_tanimoto_dist,num_leaking_uniprot_train_test,num_leaking_smiles_train_test,perc_leaking_uniprot_train_test,perc_leaking_smiles_train_test,majority_vote,model_type,disabled_embeddings,test_f1,split_type
[28 data rows not reproduced here: same structure as the random-split report (three per-model rows plus one majority-vote row per zeroed-embedding configuration), with split_type tanimoto.]
reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_uniprot.csv
ADDED

@@ -0,0 +1,29 @@
+test_loss,test_acc,test_f1_score,test_precision,test_recall,test_roc_auc,train_len,train_active_perc,train_inactive_perc,train_avg_tanimoto_dist,test_len,test_active_perc,test_inactive_perc,test_avg_tanimoto_dist,num_leaking_uniprot_train_test,num_leaking_smiles_train_test,perc_leaking_uniprot_train_test,perc_leaking_smiles_train_test,majority_vote,model_type,disabled_embeddings,test_f1,split_type
[28 data rows not reproduced here: same structure as the random-split report (three per-model rows plus one majority-vote row per zeroed-embedding configuration), with split_type uniprot.]
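The three reports share one schema, so they can be concatenated and compared directly. A hedged analysis sketch (assumed to run from the repository root; the aggregation choice below is illustrative, not something the commit itself performs):

import pandas as pd

splits = ['random', 'tanimoto', 'uniprot']
frames = []
for split in splits:
    path = f'reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_{split}.csv'
    frames.append(pd.read_csv(path))
report = pd.concat(frames, ignore_index=True)

# Keep only the per-model rows (majority-vote rows have majority_vote == True;
# pandas parses the True/False column as booleans) and compare mean test ROC AUC
# per zeroed-embedding configuration and split.
per_model = report[~report['majority_vote']]
print(per_model.groupby(['split_type', 'disabled_embeddings'])['test_roc_auc'].mean())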
src/run_xgboost_experiments.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
import warnings
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Literal
|
| 7 |
+
|
| 8 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 9 |
+
|
| 10 |
+
import protac_degradation_predictor as pdp
|
| 11 |
+
from protac_degradation_predictor.optuna_utils import get_dataframe_stats
|
| 12 |
+
|
| 13 |
+
import pytorch_lightning as pl
|
| 14 |
+
from rdkit import Chem
|
| 15 |
+
from rdkit.Chem import AllChem
|
| 16 |
+
from rdkit import DataStructs
|
| 17 |
+
from jsonargparse import CLI
|
| 18 |
+
import pandas as pd
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
import numpy as np
|
| 21 |
+
from sklearn.preprocessing import OrdinalEncoder
|
| 22 |
+
from sklearn.model_selection import (
|
| 23 |
+
StratifiedKFold,
|
| 24 |
+
StratifiedGroupKFold,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# Ignore UserWarning from Matplotlib
|
| 28 |
+
warnings.filterwarnings("ignore", ".*FixedLocator*")
|
| 29 |
+
# Ignore UserWarning from PyTorch Lightning
|
| 30 |
+
warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
root = logging.getLogger()
|
| 34 |
+
root.setLevel(logging.DEBUG)
|
| 35 |
+
|
| 36 |
+
handler = logging.StreamHandler(sys.stdout)
|
| 37 |
+
handler.setLevel(logging.DEBUG)
|
| 38 |
+
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
| 39 |
+
handler.setFormatter(formatter)
|
| 40 |
+
root.addHandler(handler)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
|
| 44 |
+
""" Get the indices of the test set using a random split.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
|
| 48 |
+
test_split (float): The percentage of the active PROTACs to use as the test set.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
pd.Index: The indices of the test set.
|
| 52 |
+
"""
|
| 53 |
+
test_df = active_df.sample(frac=test_split, random_state=42)
|
| 54 |
+
return test_df.index
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_e3_ligase_split_indices(active_df: pd.DataFrame) -> pd.Index:
|
| 58 |
+
""" Get the indices of the test set using the E3 ligase split.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
pd.Index: The indices of the test set.
|
| 65 |
+
"""
|
| 66 |
+
encoder = OrdinalEncoder()
|
| 67 |
+
active_df['E3 Group'] = encoder.fit_transform(active_df[['E3 Ligase']]).astype(int)
|
| 68 |
+
test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')]
|
| 69 |
+
return test_df.index
|
| 70 |
+
|
| 71 |
+
|
def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple:
    """ Get the SMILES-to-fingerprint dictionary and the average Tanimoto distance per compound.

    Args:
        protac_df (pd.DataFrame): The DataFrame containing the PROTACs.

    Returns:
        tuple: The SMILES-to-fingerprint dictionary and the DataFrame with an
            added 'Avg Tanimoto' column (the average Tanimoto distance of each
            compound to all other unique compounds).
    """
    unique_smiles = protac_df['Smiles'].unique().tolist()

    smiles2fp = {}
    for smiles in tqdm(unique_smiles, desc='Precomputing fingerprints'):
        smiles2fp[smiles] = pdp.get_fingerprint(smiles)

    # # Get the pair-wise Tanimoto similarity between the PROTAC fingerprints
    # tanimoto_matrix = defaultdict(list)
    # for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
    #     fp1 = smiles2fp[smiles1]
    #     # TODO: Use BulkTanimotoSimilarity for better performance
    #     for j, smiles2 in enumerate(protac_df['Smiles'].unique()[i:]):
    #         fp2 = smiles2fp[smiles2]
    #         tanimoto_dist = 1 - DataStructs.TanimotoSimilarity(fp1, fp2)
    #         tanimoto_matrix[smiles1].append(tanimoto_dist)
    # avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
    # protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)

    tanimoto_matrix = defaultdict(list)
    fps = list(smiles2fp.values())

    # Compute all-against-all Tanimoto similarity using BulkTanimotoSimilarity
    for i, (smiles1, fp1) in enumerate(tqdm(zip(unique_smiles, fps), desc='Computing Tanimoto similarity', total=len(fps))):
        similarities = DataStructs.BulkTanimotoSimilarity(fp1, fps[i:])  # Only compute from i to the end, avoiding duplicates
        for j, similarity in enumerate(similarities):
            distance = 1 - similarity
            tanimoto_matrix[smiles1].append(distance)  # Store as distance
            if j != 0:  # Skip the self-comparison and mirror the value into the other compound's list
                tanimoto_matrix[unique_smiles[i + j]].append(distance)

    # Calculate the average Tanimoto distance for each unique SMILES
    avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
    protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)

    smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}

    return smiles2fp, protac_df

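# A minimal, self-contained sketch of the computation above (hypothetical helper,
# not called anywhere in this script): it mirrors how the upper-triangle
# BulkTanimotoSimilarity results are turned into per-SMILES average Tanimoto
# distances, but uses plain Morgan fingerprints instead of pdp.get_fingerprint.
def _example_avg_tanimoto_distance(smiles_list=('CCO', 'CCN', 'c1ccccc1')):
    fps = [
        AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(s), 2, nBits=1024)
        for s in smiles_list
    ]
    distances = defaultdict(list)
    for i, fp1 in enumerate(fps):
        # Similarities of fps[i] against fps[i:], i.e., the upper triangle (self included)
        for j, sim in enumerate(DataStructs.BulkTanimotoSimilarity(fp1, fps[i:])):
            distances[smiles_list[i]].append(1 - sim)
            if j != 0:  # Mirror the distance into the other compound's list
                distances[smiles_list[i + j]].append(1 - sim)
    return {s: float(np.mean(d)) for s, d in distances.items()}
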
def get_tanimoto_split_indices(
    active_df: pd.DataFrame,
    active_col: str,
    test_split: float,
    n_bins_tanimoto: int = 200,
) -> pd.Index:
    """ Get the indices of the test set using the Tanimoto-based split.

    Args:
        active_df (pd.DataFrame): The DataFrame of labelled PROTAC entries.
        active_col (str): The column containing the active/inactive labels.
        test_split (float): The fraction of entries to hold out as the test set.
        n_bins_tanimoto (int): The number of bins used to group the average Tanimoto distances.

    Returns:
        pd.Index: The indices of the test set.
    """
    tanimoto_groups = pd.cut(active_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
    encoder = OrdinalEncoder()
    active_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
    # Sort the groups so that the samples with the highest average Tanimoto
    # distance, i.e., the least similar compounds, are placed in the test set first
    tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index

    test_df = []
    # For each group, count the active and inactive entries, then add the group
    # to the test set if: 1) the test set length plus the group size stays below
    # a test_split fraction of the labelled data, and 2) neither the active nor
    # the inactive fraction of the resulting test set exceeds roughly 60%.
    for group in tanimoto_groups:
        group_df = active_df[active_df['Tanimoto Group'] == group]
        if test_df == []:
            test_df.append(group_df)
            continue

        num_entries = len(group_df)
        num_active_group = group_df[active_col].sum()
        num_inactive_group = num_entries - num_active_group

        tmp_test_df = pd.concat(test_df)
        num_entries_test = len(tmp_test_df)
        num_active_test = tmp_test_df[active_col].sum()
        num_inactive_test = num_entries_test - num_active_test

        # Check if the group entries can be added to the test set
        if num_entries_test + num_entries < test_split * len(active_df):
            # Add anything at the beginning
            if num_entries_test + num_entries < test_split / 2 * len(active_df):
                test_df.append(group_df)
                continue
            # Be more selective and make sure that the active/inactive balance is kept
            if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
                if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
                    test_df.append(group_df)
    test_df = pd.concat(test_df)
    return test_df.index

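# The target-based split below reuses the same greedy filling strategy as the
# Tanimoto split above: groups (here, Uniprot targets) are added to the test
# set while it stays below a test_split fraction of the data and neither class
# would exceed roughly 60% of the resulting test set.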
def get_target_split_indices(active_df: pd.DataFrame, active_col: str, test_split: float) -> pd.Index:
    """ Get the indices of the test set using the target-based split.

    Args:
        active_df (pd.DataFrame): The DataFrame of labelled PROTAC entries.
        active_col (str): The column containing the active/inactive labels.
        test_split (float): The fraction of entries to hold out as the test set.

    Returns:
        pd.Index: The indices of the test set.
    """
    encoder = OrdinalEncoder()
    active_df['Uniprot Group'] = encoder.fit_transform(active_df[['Uniprot']]).astype(int)

    test_df = []
    # For each target, count the active and inactive entries, then add the group
    # to the test set if: 1) the test set length plus the group size stays below
    # a test_split fraction of the labelled data, and 2) neither the active nor
    # the inactive fraction of the resulting test set exceeds roughly 60%.
    # Start the loop from the targets with the smallest number of entries.
    for group in reversed(active_df['Uniprot'].value_counts().index):
        group_df = active_df[active_df['Uniprot'] == group]
        if test_df == []:
            test_df.append(group_df)
            continue

        num_entries = len(group_df)
        num_active_group = group_df[active_col].sum()
        num_inactive_group = num_entries - num_active_group

        tmp_test_df = pd.concat(test_df)
        num_entries_test = len(tmp_test_df)
        num_active_test = tmp_test_df[active_col].sum()
        num_inactive_test = num_entries_test - num_active_test

        # Check if the group entries can be added to the test set
        if num_entries_test + num_entries < test_split * len(active_df):
            # Add anything at the beginning
            if num_entries_test + num_entries < test_split / 2 * len(active_df):
                test_df.append(group_df)
                continue
            # Be more selective and make sure that the active/inactive balance is kept
            if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
                if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
                    test_df.append(group_df)
    test_df = pd.concat(test_df)
    return test_df.index

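# Note that the split helpers above also add integer group columns to active_df
# ('E3 Group', 'Tanimoto Group', 'Uniprot Group'); these are reused below as the
# grouping variable when building the cross-validation folds.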
def main(
    active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
    n_trials: int = 100,
    test_split: float = 0.1,
    cv_n_splits: int = 5,
    num_boost_round: int = 100,
    force_study: bool = False,
    experiments: str | Literal['all', 'random', 'e3_ligase', 'tanimoto', 'uniprot'] = 'all',
):
    """ Train XGBoost PROTAC degradation models using the given datasets and hyperparameters.

    Args:
        active_col (str): The label column to predict; its name encodes the Dmax and pDC50 thresholds.
        n_trials (int): The number of hyperparameter optimization trials.
        test_split (float): The fraction of entries to hold out as the test set.
        cv_n_splits (int): The number of cross-validation splits.
        num_boost_round (int): The number of XGBoost boosting rounds.
        force_study (bool): Whether to force re-running the Optuna study from scratch.
        experiments (str): Which test split(s) to run: 'all', 'random', 'e3_ligase', 'tanimoto', or 'uniprot'.
    """
    pl.seed_everything(42)

    # Sanitize the label column name for use in file names
    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')

    # Parse the Dmax and pDC50 thresholds from the active_col name
    Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
    pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())
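    # For example, with the default active_col 'Active (Dmax 0.6, pDC50 6.0)',
    # the parsing above yields Dmax_threshold = 0.6 and pDC50_threshold = 6.0.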

    # Load the PROTAC dataset
    protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
    # Map the E3 ligase Iap to IAP
    protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
    # Label each entry as active/inactive according to the parsed thresholds
    protac_df[active_col] = protac_df.apply(
        lambda x: pdp.is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1
    )
    smiles2fp, protac_df = get_smiles2fp_and_avg_tanimoto(protac_df)

    # Get the test sets
    test_indices = {}
    active_df = protac_df[protac_df[active_col].notna()].copy()

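    # Each hold-out split probes a different kind of generalization: 'random' is
    # a standard random hold-out, 'uniprot' holds out unseen target proteins,
    # 'e3_ligase' holds out PROTACs recruiting E3 ligases other than VHL/CRBN,
    # and 'tanimoto' holds out the compounds least similar to the training set.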
    if experiments == 'random' or experiments == 'all':
        test_indices['random'] = get_random_split_indices(active_df, test_split)
    if experiments == 'uniprot' or experiments == 'all':
        test_indices['uniprot'] = get_target_split_indices(active_df, active_col, test_split)
    if experiments == 'e3_ligase' or experiments == 'all':
        test_indices['e3_ligase'] = get_e3_ligase_split_indices(active_df)
    if experiments == 'tanimoto' or experiments == 'all':
        test_indices['tanimoto'] = get_tanimoto_split_indices(active_df, active_col, test_split)

    # Make the directory ../reports if it does not exist
    if not os.path.exists('../reports'):
        os.makedirs('../reports')

    # Load the embedding dictionaries
    protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
    cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')

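    # 'random' and 'e3_ligase' use plain stratified folds, while 'tanimoto' and
    # 'uniprot' use StratifiedGroupKFold so that whole similarity/target groups
    # never get split across training and validation folds.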
    # Cross-Validation Training
    reports = defaultdict(list)
    for split_type, indices in test_indices.items():
        test_df = active_df.loc[indices].copy()
        train_val_df = active_df[~active_df.index.isin(test_df.index)].copy()

        # Get the CV object
        if split_type == 'random':
            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = None
        elif split_type == 'e3_ligase':
            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['E3 Group'].to_numpy()
        elif split_type == 'tanimoto':
            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['Tanimoto Group'].to_numpy()
        elif split_type == 'uniprot':
            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['Uniprot Group'].to_numpy()

        # Start the experiment
        experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
        optuna_reports = pdp.xgboost_hyperparameter_tuning_and_training(
            protein2embedding=protein2embedding,
            cell2embedding=cell2embedding,
            smiles2fp=smiles2fp,
            train_val_df=train_val_df,
            test_df=test_df,
            kf=kf,
            groups=group,
            split_type=split_type,
            n_models_for_test=3,
            n_trials=n_trials,
            active_label=active_col,
            num_boost_round=num_boost_round,
            study_filename=f'../reports/study_xgboost_{experiment_name}.pkl',
            force_study=force_study,
        )

        # Save the reports to file
        for report_name, report in optuna_reports.items():
            report.to_csv(f'../reports/xgboost_{report_name}_{experiment_name}.csv', index=False)
            reports[report_name].append(report.copy())


if __name__ == '__main__':
    cli = CLI(main)
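# Example invocation (hypothetical values; jsonargparse's CLI exposes the
# parameters of `main` as command-line options):
#
#   python run_xgboost_experiments.py --n_trials 50 --experiments tanimoto
#
# The per-split CSV reports are written under ../reports/ (see the to_csv call above).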