Training Models

Dataset Specification

From the repository's top-level directory, run the following commands to generate the datasets reported in the paper:

cd src
python get_studies_datasets.py

For training on custom datasets, please refer to the class PROTAC_Dataset in the file protac_dataset.py. The class expects a Pandas DataFrame, so please assemble a file that can be parsed into a Pandas DataFrame with the following columns:

| Column Name | Type | Description |
| --- | --- | --- |
| Smiles | str | The SMILES representation of the PROTAC molecule. |
| Uniprot | str | The Uniprot ID of the target protein. |
| E3 Ligase Uniprot | str | The Uniprot ID of the E3 ligase. |
| Cell Line Identifier | str | The cell line identifier as reported in Cellosaurus. |
| <active_label> | bool | The activity label of the PROTAC molecule to be predicted by the model. |

The column <active_label> defaults to "Active" in both the PROTAC_Dataset class and the hyperparameter_tuning_and_training function (see below for how to use it).
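As a minimal sketch of the column layout described above, the following snippet checks that a DataFrame contains the required columns and casts the label column to bool. The helper name check_protac_dataframe and the toy rows are hypothetical and are not part of the package; the values are placeholders, not real measurements.

```python
import pandas as pd

# Feature columns from the specification above; the label column is passed
# separately and defaults to "Active".
REQUIRED_COLUMNS = ["Smiles", "Uniprot", "E3 Ligase Uniprot", "Cell Line Identifier"]


def check_protac_dataframe(df: pd.DataFrame, active_label: str = "Active") -> pd.DataFrame:
    """Hypothetical helper: verify the required columns and cast the label to bool.

    Assumes the label column already holds booleans or numeric 0/1 values.
    """
    missing = [c for c in REQUIRED_COLUMNS + [active_label] if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    out = df.copy()
    out[active_label] = out[active_label].astype(bool)
    return out


# Toy rows with placeholder values, for illustration only.
toy_df = pd.DataFrame({
    "Smiles": ["CCO", "CCN"],
    "Uniprot": ["UNIPROT_ID_1", "UNIPROT_ID_2"],
    "E3 Ligase Uniprot": ["E3_UNIPROT_ID", "E3_UNIPROT_ID"],
    "Cell Line Identifier": ["CELL_LINE_A", "CELL_LINE_B"],
    "Active": [1, 0],
})
checked_df = check_protac_dataframe(toy_df)
```

Running a check like this before constructing the dataset surfaces a misspelled column name early, rather than as a KeyError deep inside training.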

Training on Custom Data

For training on custom datasets, please refer to the function hyperparameter_tuning_and_training in optuna_utils.py and the file run_experiments.py for inspiration on how to use the function.

A skeleton implementation looks as follows:

import protac_degradation_predictor as pdp
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Load train/val and test dataframes
train_val_df = pd.read_csv('path/to/custom_dataset.csv')
test_df = pd.read_csv('path/to/test_dataset.csv') # Load one of our test datasets

# NOTE: Avoid data leakage by removing leaking entries from the train/val
# dataframe. Do NOT remove or alter the test set, as that would impair comparison
# with our work. Data leakage can occur if the test set contains any combination
# of SMILES, Uniprot, E3 Ligase Uniprot, or Cell Line Identifier that is also
# present in the train/val set.

# Precompute Morgan fingerprints
unique_smiles = pd.concat([train_val_df, test_df])['Smiles'].unique().tolist()
smiles2fp = {s: np.array(pdp.get_fingerprint(s)) for s in unique_smiles}

# Load embedding dictionaries
protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')

# Set up the cross-validation object
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)