import warnings import pickle import logging from typing import Literal, List, Tuple, Optional, Dict from .protac_dataset import PROTAC_Dataset, get_datasets from .config import config import pandas as pd import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import pytorch_lightning as pl from torch.utils.data import Dataset, DataLoader from torchmetrics import ( Accuracy, AUROC, Precision, Recall, F1Score, MetricCollection, ) from imblearn.over_sampling import SMOTE from sklearn.preprocessing import StandardScaler class PROTAC_Predictor(nn.Module): def __init__( self, hidden_dim: int, smiles_emb_dim: int = config.fingerprint_size, poi_emb_dim: int = config.protein_embedding_size, e3_emb_dim: int = config.protein_embedding_size, cell_emb_dim: int = config.cell_embedding_size, dropout: float = 0.2, join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum', use_batch_norm: bool = False, disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [], ): """ Initialize the PROTAC model. Args: hidden_dim (int): The hidden dimension of the model smiles_emb_dim (int): The dimension of the SMILES embeddings poi_emb_dim (int): The dimension of the POI embeddings e3_emb_dim (int): The dimension of the E3 Ligase embeddings cell_emb_dim (int): The dimension of the cell line embeddings dropout (float): The dropout rate join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles' """ super().__init__() # Set our init args as class attributes self.__dict__.update(locals()) # Define "surrogate models" branches # NOTE: The softmax is used to ensure that the embeddings are normalized # and can be summed on a "similar scale". if self.join_embeddings != 'beginning': if 'poi' not in self.disabled_embeddings: self.poi_fc = nn.Sequential( nn.Linear(poi_emb_dim, hidden_dim), nn.Softmax(dim=1), ) if 'e3' not in self.disabled_embeddings: self.e3_fc = nn.Sequential( nn.Linear(e3_emb_dim, hidden_dim), nn.Softmax(dim=1), ) if 'cell' not in self.disabled_embeddings: self.cell_fc = nn.Sequential( nn.Linear(cell_emb_dim, hidden_dim), nn.Softmax(dim=1), ) if 'smiles' not in self.disabled_embeddings: self.smiles_emb = nn.Sequential( nn.Linear(smiles_emb_dim, hidden_dim), nn.Softmax(dim=1), ) # Define hidden dimension for joining layer if self.join_embeddings == 'beginning': joint_dim = smiles_emb_dim if 'smiles' not in self.disabled_embeddings else 0 joint_dim += poi_emb_dim if 'poi' not in self.disabled_embeddings else 0 joint_dim += e3_emb_dim if 'e3' not in self.disabled_embeddings else 0 joint_dim += cell_emb_dim if 'cell' not in self.disabled_embeddings else 0 self.fc0 = nn.Linear(joint_dim, joint_dim) elif self.join_embeddings == 'concat': joint_dim = hidden_dim * (4 - len(self.disabled_embeddings)) elif self.join_embeddings == 'sum': joint_dim = hidden_dim self.fc1 = nn.Linear(joint_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, hidden_dim) self.fc3 = nn.Linear(hidden_dim, 1) self.bnorm = nn.BatchNorm1d(hidden_dim) self.dropout = nn.Dropout(p=dropout) def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb): embeddings = [] if self.join_embeddings == 'beginning': # TODO: Remove this if-branch if 'poi' not in self.disabled_embeddings: embeddings.append(poi_emb) if 'e3' not in self.disabled_embeddings: embeddings.append(e3_emb) if 'cell' not in self.disabled_embeddings: embeddings.append(cell_emb) if 'smiles' not in self.disabled_embeddings: embeddings.append(smiles_emb) x = torch.cat(embeddings, dim=1) x = self.dropout(F.relu(self.fc0(x))) else: if 'poi' not in self.disabled_embeddings: embeddings.append(self.poi_fc(poi_emb)) if torch.isnan(embeddings[-1]).any(): raise ValueError("NaN values found in POI embeddings.") if 'e3' not in self.disabled_embeddings: embeddings.append(self.e3_fc(e3_emb)) if torch.isnan(embeddings[-1]).any(): raise ValueError("NaN values found in E3 embeddings.") if 'cell' not in self.disabled_embeddings: embeddings.append(self.cell_fc(cell_emb)) if torch.isnan(embeddings[-1]).any(): raise ValueError("NaN values found in cell embeddings.") if 'smiles' not in self.disabled_embeddings: embeddings.append(self.smiles_emb(smiles_emb)) if torch.isnan(embeddings[-1]).any(): raise ValueError("NaN values found in SMILES embeddings.") if self.join_embeddings == 'concat': x = torch.cat(embeddings, dim=1) elif self.join_embeddings == 'sum': if len(embeddings) > 1: embeddings = torch.stack(embeddings, dim=1) x = torch.sum(embeddings, dim=1) else: x = embeddings[0] if torch.isnan(x).any(): raise ValueError("NaN values found in sum of softmax-ed embeddings.") x = F.relu(self.fc1(x)) x = self.bnorm(x) if self.use_batch_norm else self.self.dropout(x) x = self.fc3(x) return x class PROTAC_Model(pl.LightningModule): def __init__( self, hidden_dim: int, smiles_emb_dim: int = config.fingerprint_size, poi_emb_dim: int = config.protein_embedding_size, e3_emb_dim: int = config.protein_embedding_size, cell_emb_dim: int = config.cell_embedding_size, batch_size: int = 128, learning_rate: float = 1e-3, dropout: float = 0.2, use_batch_norm: bool = False, join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum', train_dataset: PROTAC_Dataset = None, val_dataset: PROTAC_Dataset = None, test_dataset: PROTAC_Dataset = None, disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [], apply_scaling: bool = True, ): """ Initialize the PROTAC Pytorch Lightning model. Args: hidden_dim (int): The hidden dimension of the model smiles_emb_dim (int): The dimension of the SMILES embeddings poi_emb_dim (int): The dimension of the POI embeddings e3_emb_dim (int): The dimension of the E3 Ligase embeddings cell_emb_dim (int): The dimension of the cell line embeddings batch_size (int): The batch size learning_rate (float): The learning rate dropout (float): The dropout rate join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings train_dataset (PROTAC_Dataset): The training dataset val_dataset (PROTAC_Dataset): The validation dataset test_dataset (PROTAC_Dataset): The test dataset disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles' apply_scaling (bool): Whether to apply scaling to the embeddings """ super().__init__() # Set our init args as class attributes self.__dict__.update(locals()) # Add arguments as attributes # Save the arguments passed to init ignore_args_as_hyperparams = [ 'train_dataset', 'test_dataset', 'val_dataset', ] self.save_hyperparameters(ignore=ignore_args_as_hyperparams) self.model = PROTAC_Predictor( hidden_dim=hidden_dim, smiles_emb_dim=smiles_emb_dim, poi_emb_dim=poi_emb_dim, e3_emb_dim=e3_emb_dim, cell_emb_dim=cell_emb_dim, dropout=dropout, join_embeddings=join_embeddings, use_batch_norm=use_batch_norm, disabled_embeddings=[], # NOTE: This is handled in the PROTAC_Dataset classes ) stages = ['train_metrics', 'val_metrics', 'test_metrics'] self.metrics = nn.ModuleDict({s: MetricCollection({ 'acc': Accuracy(task='binary'), 'roc_auc': AUROC(task='binary'), 'precision': Precision(task='binary'), 'recall': Recall(task='binary'), 'f1_score': F1Score(task='binary'), }, prefix=s.replace('metrics', '')) for s in stages}) # Misc settings self.missing_dataset_error = \ '''Class variable `{0}` is None. If the model was loaded from a checkpoint, the dataset must be set manually: model = {1}.load_from_checkpoint('checkpoint.ckpt') model.{0} = my_{0} ''' # Apply scaling in datasets self.scalers = None if self.apply_scaling and self.train_dataset is not None: self.initialize_scalers() def initialize_scalers(self): """Initialize or reinitialize scalers based on dataset properties.""" if self.scalers is None: use_single_scaler = self.join_embeddings == 'beginning' self.scalers = self.train_dataset.fit_scaling(use_single_scaler) self.apply_scalers() def apply_scalers(self): """Apply scalers to all datasets.""" use_single_scaler = self.join_embeddings == 'beginning' if self.train_dataset: self.train_dataset.apply_scaling(self.scalers, use_single_scaler) if self.val_dataset: self.val_dataset.apply_scaling(self.scalers, use_single_scaler) if self.test_dataset: self.test_dataset.apply_scaling(self.scalers, use_single_scaler) def scale_tensor( self, tensor: torch.Tensor, scaler: StandardScaler, alpha: float = 1e-10, ) -> torch.Tensor: """Scale a tensor using a scaler. This is done to avoid using numpy arrays (and stay on the same device). Args: tensor (torch.Tensor): The tensor to scale. scaler (StandardScaler): The scaler to use. Returns: torch.Tensor: The scaled tensor. """ tensor = tensor.float() if scaler.with_mean: tensor -= torch.tensor(scaler.mean_, dtype=tensor.dtype, device=tensor.device) if scaler.with_std: tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device) + alpha return tensor def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True): if not prescaled_embeddings: if self.apply_scaling: if self.join_embeddings == 'beginning': embeddings = self.scale_tensor( torch.hstack([smiles_emb, poi_emb, e3_emb, cell_emb]), self.scalers, ) smiles_emb = embeddings[:, :self.smiles_emb_dim] poi_emb = embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.poi_emb_dim] e3_emb = embeddings[:, self.smiles_emb_dim+self.poi_emb_dim:self.smiles_emb_dim+2*self.poi_emb_dim] cell_emb = embeddings[:, -self.cell_emb_dim:] else: poi_emb = self.scale_tensor(poi_emb, self.scalers['Uniprot']) e3_emb = self.scale_tensor(e3_emb, self.scalers['E3 Ligase Uniprot']) cell_emb = self.scale_tensor(cell_emb, self.scalers['Cell Line Identifier']) smiles_emb = self.scale_tensor(smiles_emb, self.scalers['Smiles']) if torch.isnan(poi_emb).any(): raise ValueError("NaN values found in POI embeddings.") if torch.isnan(e3_emb).any(): raise ValueError("NaN values found in E3 embeddings.") if torch.isnan(cell_emb).any(): raise ValueError("NaN values found in cell embeddings.") if torch.isnan(smiles_emb).any(): raise ValueError("NaN values found in SMILES embeddings.") return self.model(poi_emb, e3_emb, cell_emb, smiles_emb) def step(self, batch, batch_idx, stage): poi_emb = batch['poi_emb'] e3_emb = batch['e3_emb'] cell_emb = batch['cell_emb'] smiles_emb = batch['smiles_emb'] y = batch['active'].float().unsqueeze(1) y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb) loss = F.binary_cross_entropy_with_logits(y_hat, y) self.metrics[f'{stage}_metrics'].update(y_hat, y) self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True) self.log_dict(self.metrics[f'{stage}_metrics'], on_epoch=True) return loss def training_step(self, batch, batch_idx): return self.step(batch, batch_idx, 'train') def validation_step(self, batch, batch_idx): return self.step(batch, batch_idx, 'val') def test_step(self, batch, batch_idx): return self.step(batch, batch_idx, 'test') def configure_optimizers(self): optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) return { 'optimizer': optimizer, 'lr_scheduler': optim.lr_scheduler.ReduceLROnPlateau( optimizer=optimizer, mode='min', factor=0.1, patience=0, ), 'interval': 'step', # or 'epoch' 'frequency': 1, 'monitor': 'val_loss', } def predict_step(self, batch, batch_idx): poi_emb = batch['poi_emb'] e3_emb = batch['e3_emb'] cell_emb = batch['cell_emb'] smiles_emb = batch['smiles_emb'] y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb) return torch.sigmoid(y_hat) def train_dataloader(self): if self.train_dataset is None: format = 'train_dataset', self.__class__.__name__ raise ValueError(self.missing_dataset_error.format(*format)) return DataLoader( self.train_dataset, batch_size=self.batch_size, shuffle=True, # drop_last=True, ) def val_dataloader(self): if self.val_dataset is None: format = 'val_dataset', self.__class__.__name__ raise ValueError(self.missing_dataset_error.format(*format)) return DataLoader( self.val_dataset, batch_size=self.batch_size, shuffle=False, ) def test_dataloader(self): if self.test_dataset is None: format = 'test_dataset', self.__class__.__name__ raise ValueError(self.missing_dataset_error.format(*format)) return DataLoader( self.test_dataset, batch_size=self.batch_size, shuffle=False, ) def on_save_checkpoint(self, checkpoint): """ Serialize the scalers to the checkpoint. """ checkpoint['scalers'] = pickle.dumps(self.scalers) def on_load_checkpoint(self, checkpoint): """Deserialize the scalers from the checkpoint.""" if 'scalers' in checkpoint: self.scalers = pickle.loads(checkpoint['scalers']) else: self.scalers = None if self.apply_scaling: if self.scalers is not None: # Re-apply scalers to ensure datasets are scaled self.apply_scalers() else: logging.warning("Scalers not found in checkpoint. Consider re-fitting scalers if necessary.") # TODO: Use some sort of **kwargs to pass all the parameters to the model... def train_model( protein2embedding: Dict, cell2embedding: Dict, smiles2fp: Dict, train_df: pd.DataFrame, val_df: pd.DataFrame, test_df: Optional[pd.DataFrame] = None, hidden_dim: int = 768, batch_size: int = 128, learning_rate: float = 2e-5, dropout: float = 0.2, max_epochs: int = 50, use_batch_norm: bool = False, smiles_emb_dim: int = config.fingerprint_size, poi_emb_dim: int = config.protein_embedding_size, e3_emb_dim: int = config.protein_embedding_size, cell_emb_dim: int = config.cell_embedding_size, join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum', smote_k_neighbors:int = 5, use_smote: bool = True, apply_scaling: bool = True, active_label: str = 'Active', fast_dev_run: bool = False, use_logger: bool = True, logger_save_dir: str = '../logs', logger_name: str = 'protac', enable_checkpointing: bool = False, checkpoint_model_name: str = 'protac', disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [], return_predictions: bool = False, ) -> tuple: """ Train a PROTAC model using the given datasets and hyperparameters. Args: protein2embedding (dict): Dictionary of protein embeddings. cell2embedding (dict): Dictionary of cell line embeddings. smiles2fp (dict): Dictionary of SMILES to fingerprint. train_df (pd.DataFrame): The training set. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', . val_df (pd.DataFrame): The validation set. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', . test_df (pd.DataFrame): The test set. If provided, the returned metrics will include test performance. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', . hidden_dim (int): The hidden dimension of the model. batch_size (int): The batch size. learning_rate (float): The learning rate. max_epochs (int): The maximum number of epochs. smiles_emb_dim (int): The dimension of the SMILES embeddings. smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler. fast_dev_run (bool): Whether to run a fast development run. disabled_embeddings (list): The list of disabled embeddings. return_predictions (bool): Whether to return the predictions after the model, trainer, and metrics. Returns: tuple: The trained model, the trainer, and the metrics over the validation and test sets. """ train_ds, val_ds, test_ds = get_datasets( train_df, val_df, test_df, protein2embedding, cell2embedding, smiles2fp, use_smote=use_smote, smote_k_neighbors=smote_k_neighbors, active_label=active_label, disabled_embeddings=disabled_embeddings, ) loggers = [ pl.loggers.TensorBoardLogger( save_dir=logger_save_dir, version=logger_name, name=logger_name, ), pl.loggers.CSVLogger( save_dir=logger_save_dir, version=logger_name, name=logger_name, ), ] callbacks = [ pl.callbacks.EarlyStopping( monitor='train_loss', patience=10, mode='min', verbose=False, ), pl.callbacks.EarlyStopping( monitor='train_acc', patience=10, mode='max', verbose=False, ), pl.callbacks.EarlyStopping( monitor='val_loss', patience=5, # Original: 5 mode='min', verbose=False, ), pl.callbacks.EarlyStopping( monitor='val_acc', patience=10, mode='max', verbose=False, ), ] if use_logger: callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='step')) if enable_checkpointing: callbacks.append(pl.callbacks.ModelCheckpoint( monitor='val_acc', mode='max', verbose=False, filename=checkpoint_model_name + '-{epoch}-{val_acc:.2f}-{val_roc_auc:.3f}', )) # Define Trainer trainer = pl.Trainer( logger=loggers if use_logger else False, callbacks=callbacks, max_epochs=max_epochs, # val_check_interval=0.5, fast_dev_run=fast_dev_run, enable_model_summary=False, enable_checkpointing=enable_checkpointing, enable_progress_bar=False, devices=1, num_nodes=1, ) model = PROTAC_Model( hidden_dim=hidden_dim, smiles_emb_dim=smiles_emb_dim, poi_emb_dim=poi_emb_dim, e3_emb_dim=e3_emb_dim, cell_emb_dim=cell_emb_dim, batch_size=batch_size, join_embeddings=join_embeddings, dropout=dropout, use_batch_norm=use_batch_norm, learning_rate=learning_rate, apply_scaling=apply_scaling, train_dataset=train_ds, val_dataset=val_ds, test_dataset=test_ds if test_df is not None else None, disabled_embeddings=disabled_embeddings, ) with warnings.catch_warnings(): warnings.simplefilter("ignore") trainer.fit(model) metrics = {} # Add train metrics train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m} metrics.update(train_metrics) # Add validation metrics val_metrics = trainer.validate(model, verbose=False)[0] val_metrics = {m: v for m, v in val_metrics.items() if 'val' in m} metrics.update(val_metrics) # Add test metrics to metrics if test_df is not None: test_metrics = trainer.test(model, verbose=False)[0] test_metrics = {m: v for m, v in test_metrics.items() if 'test' in m} metrics.update(test_metrics) # Return predictions if return_predictions: val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False) val_pred = trainer.predict(model, val_dl) val_pred = torch.concat(trainer.predict(model, val_dl)).squeeze() if test_df is not None: test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False) test_pred = torch.concat(trainer.predict(model, test_dl)).squeeze() return model, trainer, metrics, val_pred, test_pred return model, trainer, metrics, val_pred return model, trainer, metrics def evaluate_model( model: PROTAC_Model, trainer: pl.Trainer, val_ds: PROTAC_Dataset, test_ds: Optional[PROTAC_Dataset] = None, batch_size: int = 128, ) -> tuple: """ Evaluate a PROTAC model using the given datasets. """ ret = {} val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False) val_metrics = trainer.validate(model, val_dl, verbose=False)[0] val_metrics = {m: v for m, v in val_metrics.items() if 'val' in m} # Get predictions on validation set val_pred = torch.cat(trainer.predict(model, val_dl)).squeeze() ret['val_metrics'] = val_metrics ret['val_pred'] = val_pred if test_ds is not None: test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False) test_metrics = trainer.test(model, test_dl, verbose=False)[0] test_metrics = {m: v for m, v in test_metrics.items() if 'test' in m} # Get predictions on test set test_pred = torch.cat(trainer.predict(model, test_dl)).squeeze() ret['test_metrics'] = test_metrics ret['test_pred'] = test_pred return ret def load_model( ckpt_path: str, ) -> PROTAC_Model: """ Load a PROTAC model from a checkpoint. Args: ckpt_path (str): The path to the checkpoint. Returns: PROTAC_Model: The loaded model. """ # NOTE: The `map_locat` argument is automatically handled in newer versions # of PyTorch Lightning, but we keep it here for compatibility with older ones. model = PROTAC_Model.load_from_checkpoint( ckpt_path, map_location=torch.device('cpu') if not torch.cuda.is_available() else None, ) # NOTE: The following is left as example for eventually re-applying scaling # with other datasets... # if model.apply_scaling: # model.apply_scalers() return model.eval()