|
|
|
|
|
|
|
|
|
import pandas as pd |
|
|
|
protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv') |
|
protac_df.head() |
|
|
|
|
|
|
|
nan_active = protac_df[protac_df['Active'].isna()]['Article DOI'].unique() |
|
nan_active |
|
|
|
|
|
|
|
protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP') |
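
# Optional check: after the normalization above, 'IAP' should appear with a single spelling.
print(protac_df['E3 Ligase'].value_counts())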
|
|
|
|
|
protac_df.columns |
|
|
|
|
|
cells = sorted(protac_df['Cell Type'].dropna().unique().tolist()) |
|
print(f'Number of non-cleaned cell lines: {len(cells)}') |
|
|
|
|
|
cells = sorted(protac_df['Cell Line Identifier'].dropna().unique().tolist()) |
|
print(f'Number of cleaned cell lines: {len(cells)}') |
|
|
|
|
|
unlabeled_df = protac_df[protac_df['Active'].isna()] |
|
print(f'Number of compounds without an Active label: {len(unlabeled_df)}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import urllib.request |
|
|
|
download_link = "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/embeddings/UP000005640_9606/per-protein.h5" |
|
embeddings_path = "../data/uniprot2embedding.h5" |
|
if not os.path.exists(embeddings_path): |
|
|
|
print(f'Downloading embeddings from {download_link}') |
|
urllib.request.urlretrieve(download_link, embeddings_path) |
|
|
|
|
|
import h5py |
|
import numpy as np |
|
from tqdm.auto import tqdm |
|
|
|
protein_embeddings = {} |
|
with h5py.File(embeddings_path, "r") as file:
|
print(f"number of entries: {len(file.items()):,}") |
|
uniprots = protac_df['Uniprot'].unique().tolist() |
|
uniprots += protac_df['E3 Ligase Uniprot'].unique().tolist() |
|
    for i, sequence_id in tqdm(enumerate(uniprots), total=len(uniprots), desc='Loading protein embeddings'):
|
try: |
|
embedding = file[sequence_id][:] |
|
protein_embeddings[sequence_id] = np.array(embedding) |
|
if i < 10: |
|
print( |
|
f"\tid: {sequence_id}, " |
|
f"\tembeddings shape: {embedding.shape}, " |
|
f"\tembeddings mean: {np.array(embedding).mean()}" |
|
) |
|
except KeyError: |
|
print(f'KeyError for {sequence_id}') |
|
protein_embeddings[sequence_id] = np.zeros((1024,)) |
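
# Optional sanity check: every loaded entry should share the per-protein embedding size
# (1024-d, matching the zero-vector fallback used above for missing IDs).
print(f'Protein embedding shapes: {set(emb.shape for emb in protein_embeddings.values())}')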
|
|
|
|
|
|
|
|
|
|
|
import pickle |
|
|
|
cell2embedding_filepath = '../data/cell2embedding.pkl' |
|
with open(cell2embedding_filepath, 'rb') as f: |
|
cell2embedding = pickle.load(f) |
|
print(f'Loaded {len(cell2embedding)} cell lines') |
|
|
|
|
|
emb_shape = cell2embedding[list(cell2embedding.keys())[0]].shape |
|
|
|
for cell_line in protac_df['Cell Line Identifier'].unique(): |
|
if cell_line not in cell2embedding: |
|
cell2embedding[cell_line] = np.zeros(emb_shape) |
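
# Optional: report how many cell lines ended up with an all-zero (i.e. missing) embedding.
n_zero_cell_emb = sum(
    1 for c in protac_df['Cell Line Identifier'].dropna().unique()
    if not cell2embedding[c].any()
)
print(f'Cell lines with a zero-filled embedding: {n_zero_cell_emb}')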
|
|
|
|
|
|
|
|
|
|
|
from rdkit import Chem |
|
from rdkit.Chem import AllChem |
|
from rdkit.Chem import Draw |
|
|
|
morgan_radius = 15 |
|
n_bits = 1024 |
|
|
|
|
|
rdkit_fpgen = AllChem.GetRDKitFPGenerator(maxPath=5, fpSize=512) |
|
morgan_fpgen = AllChem.GetMorganGenerator(radius=morgan_radius, fpSize=n_bits) |
|
|
|
smiles2fp = {} |
|
for smiles in tqdm(protac_df['Smiles'].unique().tolist(), desc='Precomputing fingerprints'): |
|
|
|
morgan_fp = morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles)) |
|
|
|
|
|
smiles2fp[smiles] = morgan_fp |
|
|
|
|
|
print(f'Number of unique SMILES: {len(smiles2fp)}') |
|
print(f'Number of unique fingerprints: {len(set([tuple(fp) for fp in smiles2fp.values()]))}') |
|
|
|
overlapping_smiles = [] |
|
unique_fps = set() |
|
for smiles, fp in smiles2fp.items(): |
|
if tuple(fp) in unique_fps: |
|
overlapping_smiles.append(smiles) |
|
else: |
|
unique_fps.add(tuple(fp)) |
|
print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)}') |
|
print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])}') |
|
|
|
|
|
|
|
from rdkit import DataStructs |
|
from collections import defaultdict |
|
|
|
tanimoto_matrix = defaultdict(list)
unique_smiles_list = protac_df['Smiles'].unique()
for i, smiles1 in enumerate(tqdm(unique_smiles_list, desc='Computing Tanimoto similarity')):
    fp1 = smiles2fp[smiles1]
    # Compute the upper triangle only and credit each similarity to both compounds,
    # so that every compound's average covers all pairwise comparisons.
    for j, smiles2 in enumerate(unique_smiles_list):
        if j < i:
            continue
        fp2 = smiles2fp[smiles2]
        tanimoto_sim = DataStructs.TanimotoSimilarity(fp1, fp2)
        tanimoto_matrix[smiles1].append(tanimoto_sim)
        if i != j:
            tanimoto_matrix[smiles2].append(tanimoto_sim)
avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
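
# Note (optional): RDKit's bulk routine computes all similarities of one query fingerprint
# against a list in a single call, which is much faster than the nested Python loop above.
# A quick sketch against the first compound:
_fps = [smiles2fp[s] for s in protac_df['Smiles'].unique()]
_sims_to_first = DataStructs.BulkTanimotoSimilarity(_fps[0], _fps[1:])
print(f'Mean Tanimoto of the first compound to all others: {np.mean(_sims_to_first):.3f}')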
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()} |
|
|
|
|
|
|
|
|
|
|
|
|
|
active_col = 'Active - OR' |
|
|
|
|
|
from sklearn.preprocessing import StandardScaler |
|
|
|
|
|
|
|
|
|
|
|
from imblearn.over_sampling import SMOTE, ADASYN |
|
from sklearn.preprocessing import LabelEncoder |
|
import pandas as pd |
|
import numpy as np |
|
|
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
class PROTAC_Dataset(Dataset): |
|
def __init__( |
|
self, |
|
protac_df, |
|
protein_embeddings=protein_embeddings, |
|
cell2embedding=cell2embedding, |
|
smiles2fp=smiles2fp, |
|
use_smote=False, |
|
oversampler=None, |
|
use_ored_activity=False, |
|
): |
|
""" Initialize the PROTAC dataset |
|
|
|
Args: |
|
protac_df (pd.DataFrame): The PROTAC dataframe |
|
protein_embeddings (dict): Dictionary of protein embeddings |
|
cell2embedding (dict): Dictionary of cell line embeddings |
|
smiles2fp (dict): Dictionary of SMILES to fingerprint |
|
            use_smote (bool): Whether to use SMOTE for oversampling
            oversampler: Oversampler instance to use (defaults to SMOTE with random_state=42 when None)
            use_ored_activity (bool): Whether to use the 'Active - OR' column
|
""" |
|
|
|
self.data = protac_df |
|
self.protein_embeddings = protein_embeddings |
|
self.cell2embedding = cell2embedding |
|
self.smiles2fp = smiles2fp |
|
|
|
self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0] |
|
self.protein_emb_dim = protein_embeddings[list( |
|
protein_embeddings.keys())[0]].shape[0] |
|
self.cell_emb_dim = cell2embedding[list( |
|
cell2embedding.keys())[0]].shape[0] |
|
|
|
self.active_label = 'Active - OR' if use_ored_activity else 'Active' |
|
|
|
self.use_smote = use_smote |
|
self.oversampler = oversampler |
|
|
|
if self.use_smote: |
|
self.apply_smote() |
|
|
|
def apply_smote(self): |
|
|
|
features = [] |
|
labels = [] |
|
        for _, row in self.data.iterrows():
            # Use the dictionaries passed to the dataset rather than the module-level globals
            smiles_emb = self.smiles2fp[row['Smiles']]
            poi_emb = self.protein_embeddings[row['Uniprot']]
            e3_emb = self.protein_embeddings[row['E3 Ligase Uniprot']]
            cell_emb = self.cell2embedding[row['Cell Line Identifier']]
|
features.append(np.hstack([ |
|
smiles_emb.astype(np.float32), |
|
poi_emb.astype(np.float32), |
|
e3_emb.astype(np.float32), |
|
cell_emb.astype(np.float32), |
|
])) |
|
labels.append(row[self.active_label]) |
|
|
|
|
|
features = np.array(features).astype(np.float32) |
|
labels = np.array(labels).astype(np.float32) |
|
|
|
|
|
if self.oversampler is None: |
|
oversampler = SMOTE(random_state=42) |
|
else: |
|
oversampler = self.oversampler |
|
features_smote, labels_smote = oversampler.fit_resample(features, labels) |
|
|
|
|
|
smiles_embs = features_smote[:, :self.smiles_emb_dim] |
|
poi_embs = features_smote[:, |
|
self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim] |
|
e3_embs = features_smote[:, self.smiles_emb_dim + |
|
self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim] |
|
cell_embs = features_smote[:, -self.cell_emb_dim:] |
|
|
|
|
|
df_smote = pd.DataFrame({ |
|
'Smiles': list(smiles_embs), |
|
'Uniprot': list(poi_embs), |
|
'E3 Ligase Uniprot': list(e3_embs), |
|
'Cell Line Identifier': list(cell_embs), |
|
self.active_label: labels_smote |
|
}) |
|
self.data = df_smote |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if self.use_smote: |
|
|
|
elem = { |
|
'smiles_emb': self.data['Smiles'].iloc[idx], |
|
'poi_emb': self.data['Uniprot'].iloc[idx], |
|
'e3_emb': self.data['E3 Ligase Uniprot'].iloc[idx], |
|
'cell_emb': self.data['Cell Line Identifier'].iloc[idx], |
|
'active': self.data[self.active_label].iloc[idx], |
|
} |
|
else: |
|
elem = { |
|
'smiles_emb': self.smiles2fp[self.data['Smiles'].iloc[idx]].astype(np.float32), |
|
'poi_emb': self.protein_embeddings[self.data['Uniprot'].iloc[idx]].astype(np.float32), |
|
'e3_emb': self.protein_embeddings[self.data['E3 Ligase Uniprot'].iloc[idx]].astype(np.float32), |
|
'cell_emb': self.cell2embedding[self.data['Cell Line Identifier'].iloc[idx]].astype(np.float32), |
|
'active': 1. if self.data[self.active_label].iloc[idx] else 0., |
|
} |
|
return elem |
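
# Minimal sanity check (optional): build a dataset without SMOTE on the rows that have a
# plain 'Active' label and all required fields, then inspect the shapes of one element.
_labelled_df = protac_df[protac_df['Active'].notna()].dropna(
    subset=['Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier'])
_ds = PROTAC_Dataset(_labelled_df)
_elem = _ds[0]
print({k: (v.shape if hasattr(v, 'shape') else v) for k, v in _elem.items()})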
|
|
|
|
|
import warnings |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
import pytorch_lightning as pl |
|
from torchmetrics import ( |
|
Accuracy, |
|
AUROC, |
|
Precision, |
|
Recall, |
|
F1Score, |
|
) |
|
from torchmetrics import MetricCollection |
|
|
|
|
|
warnings.filterwarnings("ignore", ".*does not have many workers.*") |
|
|
|
class PROTAC_Model(pl.LightningModule): |
|
|
|
def __init__( |
|
self, |
|
hidden_dim, |
|
smiles_emb_dim=1024, |
|
poi_emb_dim=1024, |
|
e3_emb_dim=1024, |
|
cell_emb_dim=768, |
|
batch_size=32, |
|
learning_rate=1e-3, |
|
dropout=0.2, |
|
train_dataset=None, |
|
val_dataset=None, |
|
test_dataset=None, |
|
disabled_embeddings=[], |
|
): |
|
super().__init__() |
|
self.poi_emb_dim = poi_emb_dim |
|
self.e3_emb_dim = e3_emb_dim |
|
self.cell_emb_dim = cell_emb_dim |
|
self.smiles_emb_dim = smiles_emb_dim |
|
self.hidden_dim = hidden_dim |
|
self.batch_size = batch_size |
|
self.learning_rate = learning_rate |
|
self.train_dataset = train_dataset |
|
self.val_dataset = val_dataset |
|
self.test_dataset = test_dataset |
|
self.disabled_embeddings = disabled_embeddings |
|
|
|
self.__dict__.update(locals()) |
|
|
|
ignore_args_as_hyperparams = [ |
|
'train_dataset', |
|
'test_dataset', |
|
'val_dataset', |
|
] |
|
self.save_hyperparameters(ignore=ignore_args_as_hyperparams) |
|
|
|
if 'poi' not in self.disabled_embeddings: |
|
self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if 'e3' not in self.disabled_embeddings: |
|
self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if 'cell' not in self.disabled_embeddings: |
|
self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if 'smiles' not in self.disabled_embeddings: |
|
self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.fc1 = nn.Linear( |
|
hidden_dim * (4 - len(self.disabled_embeddings)), hidden_dim) |
|
self.fc2 = nn.Linear(hidden_dim, hidden_dim) |
|
self.fc3 = nn.Linear(hidden_dim, 1) |
|
|
|
self.dropout = nn.Dropout(p=dropout) |
|
|
|
stages = ['train_metrics', 'val_metrics', 'test_metrics'] |
|
self.metrics = nn.ModuleDict({s: MetricCollection({ |
|
'acc': Accuracy(task='binary'), |
|
'roc_auc': AUROC(task='binary'), |
|
'precision': Precision(task='binary'), |
|
'recall': Recall(task='binary'), |
|
'f1_score': F1Score(task='binary'), |
|
'opt_score': Accuracy(task='binary') + F1Score(task='binary'), |
|
'hp_metric': Accuracy(task='binary'), |
|
}, prefix=s.replace('metrics', '')) for s in stages}) |
|
|
|
|
|
self.missing_dataset_error = \ |
|
'''Class variable `{0}` is None. If the model was loaded from a checkpoint, the dataset must be set manually: |
|
|
|
model = {1}.load_from_checkpoint('checkpoint.ckpt') |
|
model.{0} = my_{0} |
|
''' |
|
|
|
def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb): |
|
embeddings = [] |
|
if 'poi' not in self.disabled_embeddings: |
|
embeddings.append(self.poi_emb(poi_emb)) |
|
if 'e3' not in self.disabled_embeddings: |
|
embeddings.append(self.e3_emb(e3_emb)) |
|
if 'cell' not in self.disabled_embeddings: |
|
embeddings.append(self.cell_emb(cell_emb)) |
|
if 'smiles' not in self.disabled_embeddings: |
|
embeddings.append(self.smiles_emb(smiles_emb)) |
|
x = torch.cat(embeddings, dim=1) |
|
x = self.dropout(F.gelu(self.fc1(x))) |
|
x = self.dropout(F.gelu(self.fc2(x))) |
|
x = self.fc3(x) |
|
return x |
|
|
|
def step(self, batch, batch_idx, stage): |
|
poi_emb = batch['poi_emb'] |
|
e3_emb = batch['e3_emb'] |
|
cell_emb = batch['cell_emb'] |
|
smiles_emb = batch['smiles_emb'] |
|
y = batch['active'].float().unsqueeze(1) |
|
|
|
y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb) |
|
loss = F.binary_cross_entropy_with_logits(y_hat, y) |
|
|
|
self.metrics[f'{stage}_metrics'].update(y_hat, y) |
|
self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True) |
|
self.log_dict(self.metrics[f'{stage}_metrics'], on_epoch=True) |
|
|
|
return loss |
|
|
|
def training_step(self, batch, batch_idx): |
|
return self.step(batch, batch_idx, 'train') |
|
|
|
def validation_step(self, batch, batch_idx): |
|
return self.step(batch, batch_idx, 'val') |
|
|
|
def test_step(self, batch, batch_idx): |
|
return self.step(batch, batch_idx, 'test') |
|
|
|
def configure_optimizers(self): |
|
return optim.Adam(self.parameters(), lr=self.learning_rate) |
|
|
|
def predict_step(self, batch, batch_idx): |
|
poi_emb = batch['poi_emb'] |
|
e3_emb = batch['e3_emb'] |
|
cell_emb = batch['cell_emb'] |
|
smiles_emb = batch['smiles_emb'] |
|
|
|
y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb) |
|
return torch.sigmoid(y_hat) |
|
|
|
def train_dataloader(self): |
|
if self.train_dataset is None: |
|
format = 'train_dataset', self.__class__.__name__ |
|
raise ValueError(self.missing_dataset_error.format(*format)) |
|
return DataLoader( |
|
self.train_dataset, |
|
batch_size=self.batch_size, |
|
shuffle=True, |
|
|
|
) |
|
|
|
def val_dataloader(self): |
|
if self.val_dataset is None: |
|
format = 'val_dataset', self.__class__.__name__ |
|
raise ValueError(self.missing_dataset_error.format(*format)) |
|
return DataLoader( |
|
self.val_dataset, |
|
batch_size=self.batch_size, |
|
shuffle=False, |
|
) |
|
|
|
def test_dataloader(self): |
|
if self.test_dataset is None: |
|
format = 'test_dataset', self.__class__.__name__ |
|
raise ValueError(self.missing_dataset_error.format(*format)) |
|
return DataLoader( |
|
self.test_dataset, |
|
batch_size=self.batch_size, |
|
shuffle=False, |
|
) |
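
# Quick shape check with random tensors (optional; dimensions follow the constructor defaults above).
_model = PROTAC_Model(hidden_dim=768)
_logits = _model(
    torch.randn(2, 1024),  # poi_emb
    torch.randn(2, 1024),  # e3_emb
    torch.randn(2, 768),   # cell_emb
    torch.randn(2, 1024),  # smiles_emb
)
print(_logits.shape)  # expected: torch.Size([2, 1])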
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_indices = {}
|
|
|
|
|
|
|
|
|
|
|
active_df = protac_df[protac_df[active_col].notna()].copy() |
|
|
|
|
|
unique_smiles = active_df['Smiles'].value_counts() == 1 |
|
unique_uniprot = active_df['Uniprot'].value_counts() == 1 |
|
print(f'Number of unique SMILES: {unique_smiles.sum()}') |
|
print(f'Number of unique Uniprot: {unique_uniprot.sum()}') |
|
|
|
|
|
n = int(0.05 * len(active_df)) // 2 |
|
unique_smiles = unique_smiles[unique_smiles].sample(n=n, random_state=42) |
|
|
|
unique_indices = active_df[
    active_df['Smiles'].isin(unique_smiles.index) &
    active_df['Uniprot'].isin(unique_uniprot[unique_uniprot].index)
].index
|
print(f'Number of unique indices: {len(unique_indices)} ({len(unique_indices) / len(active_df):.1%})') |
|
|
|
test_indices['random'] = unique_indices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
active_df = protac_df[protac_df[active_col].notna()].copy() |
|
|
|
unique_uniprot = active_df['Uniprot'].value_counts() == 1 |
|
print(f'Number of unique Uniprot: {unique_uniprot.sum()}') |
|
|
|
|
|
|
|
# Entries whose target protein appears exactly once in the labelled data
unique_indices = active_df[active_df['Uniprot'].isin(unique_uniprot[unique_uniprot].index)].index
|
|
|
|
|
test_indices['uniprot'] = unique_indices
|
print(f'Number of unique indices: {len(unique_indices)} ({len(unique_indices) / len(active_df):.1%})') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
active_df = protac_df[protac_df[active_col].notna()] |
|
|
|
|
|
|
|
|
|
|
|
unique_smiles = active_df['Smiles'].value_counts() == 1 |
|
unique_uniprot = active_df['Uniprot'].value_counts() == 1 |
|
unique_smiles_uniprot = active_df.groupby(['Smiles', 'Uniprot']).size() == 1 |
|
|
|
|
|
unique_smiles_idx = active_df['Smiles'].map(unique_smiles) |
|
unique_uniprot_idx = active_df['Uniprot'].map(unique_uniprot) |
|
unique_smiles_uniprot_idx = active_df.set_index(['Smiles', 'Uniprot']).index.map(unique_smiles_uniprot) |
|
|
|
|
|
|
|
unique_samples = active_df[unique_smiles_idx & unique_uniprot_idx].index |
|
test_df = active_df.loc[unique_samples] |
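
# Report the size and class balance of the held-out test set (optional).
print(f'Test set size: {len(test_df)} ({len(test_df) / len(active_df):.1%} of labelled data)')
print(f'Test set active fraction: {test_df[active_col].mean():.1%}')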
|
|
|
warnings.filterwarnings("ignore", ".*FixedLocator.*")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from sklearn.model_selection import ( |
|
StratifiedKFold, |
|
StratifiedGroupKFold, |
|
) |
|
from sklearn.preprocessing import OrdinalEncoder |
|
|
|
|
|
|
|
|
|
n_bins_tanimoto = 60 if active_col == 'Active' else 400 |
|
n_splits = 5 |
|
|
|
|
|
active_df = protac_df[protac_df[active_col].notna()] |
|
train_val_df = active_df[~active_df.index.isin(test_df.index)].copy() |
|
|
|
|
|
|
|
|
|
|
|
split_types = [
    'random',
    'uniprot',
    'tanimoto',
]
for group_type in split_types:
|
if group_type == 'random': |
|
kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42) |
|
groups = None |
|
elif group_type == 'uniprot': |
|
|
|
kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42) |
|
encoder = OrdinalEncoder() |
|
groups = encoder.fit_transform(train_val_df['Uniprot'].values.reshape(-1, 1)) |
|
print(f'Number of unique groups: {len(encoder.categories_[0])}') |
|
elif group_type == 'tanimoto': |
|
|
|
kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42) |
|
tanimoto_groups = pd.cut(train_val_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy() |
|
encoder = OrdinalEncoder() |
|
groups = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)) |
|
print(f'Number of unique groups: {len(encoder.categories_[0])}') |
|
|
|
|
|
X = train_val_df.drop(columns=active_col) |
|
y = train_val_df[active_col].tolist() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
stats = [] |
|
for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)): |
|
train_df = train_val_df.iloc[train_index] |
|
val_df = train_val_df.iloc[val_index] |
|
stat = { |
|
'fold': k, |
|
'train_len': len(train_df), |
|
'val_len': len(val_df), |
|
'train_perc': len(train_df) / len(train_val_df), |
|
'val_perc': len(val_df) / len(train_val_df), |
|
'train_active (%)': train_df[active_col].sum() / len(train_df) * 100, |
|
'train_inactive (%)': (len(train_df) - train_df[active_col].sum()) / len(train_df) * 100, |
|
'val_active (%)': val_df[active_col].sum() / len(val_df) * 100, |
|
'val_inactive (%)': (len(val_df) - val_df[active_col].sum()) / len(val_df) * 100, |
|
'num_leaking_uniprot': len(set(train_df['Uniprot']).intersection(set(val_df['Uniprot']))), |
|
'num_leaking_smiles': len(set(train_df['Smiles']).intersection(set(val_df['Smiles']))), |
|
} |
|
if group_type != 'random': |
|
stat['train_unique_groups'] = len(np.unique(groups[train_index])) |
|
stat['val_unique_groups'] = len(np.unique(groups[val_index])) |
|
stats.append(stat) |
|
print('-' * 120) |
|
|
|
|
|
|
|
|
|
|
|
import warnings |
|
|
|
|
|
pl.seed_everything(42) |
|
|
|
|
|
def train_model( |
|
train_df, |
|
val_df, |
|
test_df=None, |
|
hidden_dim=768, |
|
batch_size=8, |
|
learning_rate=2e-5, |
|
max_epochs=50, |
|
smiles_emb_dim=1024, |
|
smote_n_neighbors=5, |
|
    use_ored_activity=(active_col != 'Active'),
|
fast_dev_run=False, |
|
disabled_embeddings=[], |
|
) -> tuple: |
|
""" Train a PROTAC model using the given datasets and hyperparameters. |
|
|
|
Args: |
|
train_df (pd.DataFrame): The training set. |
|
val_df (pd.DataFrame): The validation set. |
|
test_df (pd.DataFrame): The test set. |
|
hidden_dim (int): The hidden dimension of the model. |
|
batch_size (int): The batch size. |
|
learning_rate (float): The learning rate. |
|
max_epochs (int): The maximum number of epochs. |
|
smiles_emb_dim (int): The dimension of the SMILES embeddings. |
|
smote_n_neighbors (int): The number of neighbors for the SMOTE oversampler. |
|
use_ored_activity (bool): Whether to use the ORED activity column. |
|
fast_dev_run (bool): Whether to run a fast development run. |
|
disabled_embeddings (list): The list of disabled embeddings. |
|
|
|
Returns: |
|
tuple: The trained model, the trainer, and the metrics. |
|
""" |
|
oversampler = SMOTE(k_neighbors=smote_n_neighbors, random_state=42) |
|
train_ds = PROTAC_Dataset( |
|
train_df, |
|
protein_embeddings, |
|
cell2embedding, |
|
smiles2fp, |
|
use_smote=True, |
|
oversampler=oversampler, |
|
use_ored_activity=use_ored_activity, |
|
) |
|
val_ds = PROTAC_Dataset( |
|
val_df, |
|
protein_embeddings, |
|
cell2embedding, |
|
smiles2fp, |
|
use_ored_activity=use_ored_activity, |
|
) |
|
if test_df is not None: |
|
test_ds = PROTAC_Dataset( |
|
test_df, |
|
protein_embeddings, |
|
cell2embedding, |
|
smiles2fp, |
|
use_ored_activity=use_ored_activity, |
|
) |
|
logger = pl.loggers.TensorBoardLogger( |
|
save_dir='../logs', |
|
name='protac', |
|
) |
|
    callbacks = [
        pl.callbacks.EarlyStopping(
            monitor='train_loss',
            patience=10,
            mode='min',  # stop once the monitored loss stops decreasing
            verbose=True,
        ),
|
|
|
|
|
|
|
|
|
|
|
|
|
] |
|
|
|
trainer = pl.Trainer( |
|
logger=logger, |
|
callbacks=callbacks, |
|
max_epochs=max_epochs, |
|
fast_dev_run=fast_dev_run, |
|
enable_model_summary=False, |
|
enable_checkpointing=False, |
|
) |
|
model = PROTAC_Model( |
|
hidden_dim=hidden_dim, |
|
smiles_emb_dim=smiles_emb_dim, |
|
poi_emb_dim=1024, |
|
e3_emb_dim=1024, |
|
cell_emb_dim=768, |
|
batch_size=batch_size, |
|
learning_rate=learning_rate, |
|
train_dataset=train_ds, |
|
val_dataset=val_ds, |
|
test_dataset=test_ds if test_df is not None else None, |
|
disabled_embeddings=disabled_embeddings, |
|
) |
|
with warnings.catch_warnings(): |
|
warnings.simplefilter("ignore") |
|
trainer.fit(model) |
|
metrics = trainer.validate(model, verbose=False)[0] |
|
if test_df is not None: |
|
test_metrics = trainer.test(model, verbose=False)[0] |
|
metrics.update(test_metrics) |
|
return model, trainer, metrics |
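
# Example (optional): once train/val splits are defined further below, a quick smoke test of the
# whole pipeline can be run with fast_dev_run=True, which pushes a single batch through
# training and validation just to check the wiring:
# _, _, quick_metrics = train_model(train_df, val_df, fast_dev_run=True)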
|
|
|
|
|
|
|
|
|
|
|
import optuna |
|
import pandas as pd |
|
|
|
|
|
def objective( |
|
trial, |
|
train_df, |
|
val_df, |
|
hidden_dim_options, |
|
batch_size_options, |
|
learning_rate_options, |
|
max_epochs_options, |
|
fast_dev_run=False, |
|
) -> float: |
|
|
|
hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options) |
|
batch_size = trial.suggest_categorical('batch_size', batch_size_options) |
|
    learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
|
max_epochs = trial.suggest_categorical('max_epochs', max_epochs_options) |
|
|
|
|
|
_, _, metrics = train_model( |
|
train_df, |
|
val_df, |
|
hidden_dim=hidden_dim, |
|
batch_size=batch_size, |
|
learning_rate=learning_rate, |
|
max_epochs=max_epochs, |
|
fast_dev_run=fast_dev_run, |
|
) |
|
|
|
|
|
val_loss = metrics['val_loss'] |
|
val_acc = metrics['val_acc'] |
|
val_roc_auc = metrics['val_roc_auc'] |
|
|
|
|
|
return val_loss - val_acc - val_roc_auc |
|
|
|
|
|
def hyperparameter_tuning_and_training( |
|
train_df, |
|
val_df, |
|
test_df, |
|
fast_dev_run=False, |
|
n_trials=20, |
|
) -> tuple: |
|
""" Hyperparameter tuning and training of a PROTAC model. |
|
|
|
Args: |
|
train_df (pd.DataFrame): The training set. |
|
val_df (pd.DataFrame): The validation set. |
|
test_df (pd.DataFrame): The test set. |
|
fast_dev_run (bool): Whether to run a fast development run. |
|
|
|
Returns: |
|
tuple: The trained model, the trainer, and the best metrics. |
|
""" |
|
|
|
hidden_dim_options = [256, 512, 768] |
|
batch_size_options = [8, 16, 32] |
|
learning_rate_options = (1e-5, 1e-3) |
|
max_epochs_options = [10, 20, 50] |
|
|
|
|
|
study = optuna.create_study(direction='minimize') |
|
study.optimize(lambda trial: objective( |
|
trial, |
|
train_df, |
|
val_df, |
|
hidden_dim_options, |
|
batch_size_options, |
|
learning_rate_options, |
|
max_epochs_options, |
|
fast_dev_run=fast_dev_run,), |
|
n_trials=n_trials, |
|
) |
|
|
|
|
|
best_params = study.best_params |
|
best_hidden_dim = best_params['hidden_dim'] |
|
best_batch_size = best_params['batch_size'] |
|
best_learning_rate = best_params['learning_rate'] |
|
best_max_epochs = best_params['max_epochs'] |
|
|
|
|
|
model, trainer, metrics = train_model( |
|
train_df, |
|
val_df, |
|
test_df, |
|
hidden_dim=best_hidden_dim, |
|
batch_size=best_batch_size, |
|
learning_rate=best_learning_rate, |
|
max_epochs=best_max_epochs, |
|
fast_dev_run=fast_dev_run, |
|
) |
|
|
|
|
|
return model, trainer, metrics |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
n_splits = 5 |
|
report = [] |
|
active_df = protac_df[protac_df[active_col].notna()] |
|
train_val_df = active_df[~active_df.index.isin(unique_samples)] |
|
|
|
|
|
if not os.path.exists('../reports'): |
|
os.makedirs('../reports') |
|
|
|
for group_type in ['random', 'uniprot', 'tanimoto']: |
|
print(f'Starting CV for group type: {group_type}') |
|
|
|
if group_type == 'random': |
|
kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42) |
|
groups = None |
|
elif group_type == 'uniprot': |
|
|
|
kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42) |
|
encoder = OrdinalEncoder() |
|
groups = encoder.fit_transform(train_val_df['Uniprot'].values.reshape(-1, 1)) |
|
elif group_type == 'tanimoto': |
|
|
|
kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42) |
|
tanimoto_groups = pd.cut(train_val_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy() |
|
encoder = OrdinalEncoder() |
|
groups = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)) |
|
|
|
X = train_val_df.drop(columns=active_col) |
|
y = train_val_df[active_col].tolist() |
|
for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)): |
|
train_df = train_val_df.iloc[train_index] |
|
val_df = train_val_df.iloc[val_index] |
|
stats = { |
|
'fold': k, |
|
'group_type': group_type, |
|
'train_len': len(train_df), |
|
'val_len': len(val_df), |
|
'train_perc': len(train_df) / len(train_val_df), |
|
'val_perc': len(val_df) / len(train_val_df), |
|
'train_active_perc': train_df[active_col].sum() / len(train_df), |
|
'train_inactive_perc': (len(train_df) - train_df[active_col].sum()) / len(train_df), |
|
'val_active_perc': val_df[active_col].sum() / len(val_df), |
|
'val_inactive_perc': (len(val_df) - val_df[active_col].sum()) / len(val_df), |
|
'test_active_perc': test_df[active_col].sum() / len(test_df), |
|
'test_inactive_perc': (len(test_df) - test_df[active_col].sum()) / len(test_df), |
|
'num_leaking_uniprot': len(set(train_df['Uniprot']).intersection(set(val_df['Uniprot']))), |
|
'num_leaking_smiles': len(set(train_df['Smiles']).intersection(set(val_df['Smiles']))), |
|
} |
|
if group_type != 'random': |
|
stats['train_unique_groups'] = len(np.unique(groups[train_index])) |
|
stats['val_unique_groups'] = len(np.unique(groups[val_index])) |
|
|
|
|
|
model, trainer, metrics = hyperparameter_tuning_and_training( |
|
train_df, |
|
val_df, |
|
test_df, |
|
fast_dev_run=False, |
|
n_trials=50, |
|
) |
|
stats.update(metrics) |
|
del model |
|
del trainer |
|
report.append(stats) |
|
report = pd.DataFrame(report) |
|
report.to_csv( |
|
f'../reports/cv_report_hparam_search_{n_splits}-splits.csv', index=False, |
|
) |
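
# Optional: summarize the cross-validation results per split strategy. The exact metric columns
# depend on what trainer.validate()/test() logged (assumed here to include 'val_acc' and 'val_roc_auc').
print(report.groupby('group_type')[['val_acc', 'val_roc_auc']].mean())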
|
|
|
|