
import warnings
import pickle
import logging
from typing import Literal, List, Tuple, Optional, Dict

from .protac_dataset import PROTAC_Dataset, get_datasets
from .config import config

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from torchmetrics import (
    Accuracy,
    AUROC,
    Precision,
    Recall,
    F1Score,
    MetricCollection,
)
from imblearn.over_sampling import SMOTE
from sklearn.preprocessing import StandardScaler


class PROTAC_Predictor(nn.Module):

    def __init__(
        self,
        hidden_dim: int,
        smiles_emb_dim: int = config.fingerprint_size,
        poi_emb_dim: int = config.protein_embedding_size,
        e3_emb_dim: int = config.protein_embedding_size,
        cell_emb_dim: int = config.cell_embedding_size,
        dropout: float = 0.2,
        join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
        use_batch_norm: bool = False,
        disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
    ):
""" Initialize the PROTAC model. | |
Args: | |
hidden_dim (int): The hidden dimension of the model | |
smiles_emb_dim (int): The dimension of the SMILES embeddings | |
poi_emb_dim (int): The dimension of the POI embeddings | |
e3_emb_dim (int): The dimension of the E3 Ligase embeddings | |
cell_emb_dim (int): The dimension of the cell line embeddings | |
dropout (float): The dropout rate | |
join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings | |
disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles' | |
""" | |
        super().__init__()
        self.poi_emb_dim = poi_emb_dim
        self.e3_emb_dim = e3_emb_dim
        self.cell_emb_dim = cell_emb_dim
        self.smiles_emb_dim = smiles_emb_dim
        self.hidden_dim = hidden_dim
        self.join_embeddings = join_embeddings
        self.use_batch_norm = use_batch_norm
        self.disabled_embeddings = disabled_embeddings
        # Set our init args as class attributes
        self.__dict__.update(locals())

        # Define "surrogate model" branches
        # NOTE: The softmax is used to ensure that the embeddings are normalized
        # and can be summed on a "similar scale".
        if self.join_embeddings != 'beginning':
            if 'poi' not in self.disabled_embeddings:
                self.poi_fc = nn.Sequential(
                    nn.Linear(poi_emb_dim, hidden_dim),
                    nn.Softmax(dim=1),
                )
            if 'e3' not in self.disabled_embeddings:
                self.e3_fc = nn.Sequential(
                    nn.Linear(e3_emb_dim, hidden_dim),
                    nn.Softmax(dim=1),
                )
            if 'cell' not in self.disabled_embeddings:
                self.cell_fc = nn.Sequential(
                    nn.Linear(cell_emb_dim, hidden_dim),
                    nn.Softmax(dim=1),
                )
            if 'smiles' not in self.disabled_embeddings:
                self.smiles_emb = nn.Sequential(
                    nn.Linear(smiles_emb_dim, hidden_dim),
                    nn.Softmax(dim=1),
                )

        # Define hidden dimension for the joining layer
        if self.join_embeddings == 'beginning':
            joint_dim = smiles_emb_dim if 'smiles' not in self.disabled_embeddings else 0
            joint_dim += poi_emb_dim if 'poi' not in self.disabled_embeddings else 0
            joint_dim += e3_emb_dim if 'e3' not in self.disabled_embeddings else 0
            joint_dim += cell_emb_dim if 'cell' not in self.disabled_embeddings else 0
            self.fc0 = nn.Linear(joint_dim, joint_dim)
        elif self.join_embeddings == 'concat':
            joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
        elif self.join_embeddings == 'sum':
            joint_dim = hidden_dim

        self.fc1 = nn.Linear(joint_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

        self.bnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
        embeddings = []
        if self.join_embeddings == 'beginning':
            # TODO: Remove this if-branch
            if 'poi' not in self.disabled_embeddings:
                embeddings.append(poi_emb)
            if 'e3' not in self.disabled_embeddings:
                embeddings.append(e3_emb)
            if 'cell' not in self.disabled_embeddings:
                embeddings.append(cell_emb)
            if 'smiles' not in self.disabled_embeddings:
                embeddings.append(smiles_emb)
            x = torch.cat(embeddings, dim=1)
            x = self.dropout(F.relu(self.fc0(x)))
        else:
            if 'poi' not in self.disabled_embeddings:
                embeddings.append(self.poi_fc(poi_emb))
            if 'e3' not in self.disabled_embeddings:
                embeddings.append(self.e3_fc(e3_emb))
            if 'cell' not in self.disabled_embeddings:
                embeddings.append(self.cell_fc(cell_emb))
            if 'smiles' not in self.disabled_embeddings:
                embeddings.append(self.smiles_emb(smiles_emb))
            if self.join_embeddings == 'concat':
                x = torch.cat(embeddings, dim=1)
            elif self.join_embeddings == 'sum':
                if len(embeddings) > 1:
                    embeddings = torch.stack(embeddings, dim=1)
                    x = torch.sum(embeddings, dim=1)
                else:
                    x = embeddings[0]
        x = F.relu(self.fc1(x))
        # Regularize with batch norm or dropout.
        # NOTE: fc2 is defined in __init__ but is not applied in this forward pass.
        x = self.bnorm(x) if self.use_batch_norm else self.dropout(x)
        x = self.fc3(x)
        return x


class PROTAC_Model(pl.LightningModule):

    def __init__(
        self,
        hidden_dim: int,
        smiles_emb_dim: int = config.fingerprint_size,
        poi_emb_dim: int = config.protein_embedding_size,
        e3_emb_dim: int = config.protein_embedding_size,
        cell_emb_dim: int = config.cell_embedding_size,
        batch_size: int = 128,
        learning_rate: float = 1e-3,
        dropout: float = 0.2,
        use_batch_norm: bool = False,
        join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
        train_dataset: PROTAC_Dataset = None,
        val_dataset: PROTAC_Dataset = None,
        test_dataset: PROTAC_Dataset = None,
        disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
        apply_scaling: bool = True,
    ):
""" Initialize the PROTAC Pytorch Lightning model. | |
Args: | |
hidden_dim (int): The hidden dimension of the model | |
smiles_emb_dim (int): The dimension of the SMILES embeddings | |
poi_emb_dim (int): The dimension of the POI embeddings | |
e3_emb_dim (int): The dimension of the E3 Ligase embeddings | |
cell_emb_dim (int): The dimension of the cell line embeddings | |
batch_size (int): The batch size | |
learning_rate (float): The learning rate | |
dropout (float): The dropout rate | |
join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings | |
train_dataset (PROTAC_Dataset): The training dataset | |
val_dataset (PROTAC_Dataset): The validation dataset | |
test_dataset (PROTAC_Dataset): The test dataset | |
disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles' | |
apply_scaling (bool): Whether to apply scaling to the embeddings | |
""" | |
        super().__init__()
        self.poi_emb_dim = poi_emb_dim
        self.e3_emb_dim = e3_emb_dim
        self.cell_emb_dim = cell_emb_dim
        self.smiles_emb_dim = smiles_emb_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.join_embeddings = join_embeddings
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.disabled_embeddings = disabled_embeddings
        self.apply_scaling = apply_scaling
        # Set our init args as class attributes
        self.__dict__.update(locals())  # Add arguments as attributes
        # Save the arguments passed to init
        ignore_args_as_hyperparams = [
            'train_dataset',
            'test_dataset',
            'val_dataset',
        ]
        self.save_hyperparameters(ignore=ignore_args_as_hyperparams)

        self.model = PROTAC_Predictor(
            hidden_dim=hidden_dim,
            smiles_emb_dim=smiles_emb_dim,
            poi_emb_dim=poi_emb_dim,
            e3_emb_dim=e3_emb_dim,
            cell_emb_dim=cell_emb_dim,
            dropout=dropout,
            join_embeddings=join_embeddings,
            use_batch_norm=use_batch_norm,
            disabled_embeddings=[],  # NOTE: This is handled in the PROTAC_Dataset classes
        )

        stages = ['train_metrics', 'val_metrics', 'test_metrics']
        self.metrics = nn.ModuleDict({s: MetricCollection({
            'acc': Accuracy(task='binary'),
            'roc_auc': AUROC(task='binary'),
            'precision': Precision(task='binary'),
            'recall': Recall(task='binary'),
            'f1_score': F1Score(task='binary'),
            # 'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
            # 'hp_metric': Accuracy(task='binary'),
        }, prefix=s.replace('metrics', '')) for s in stages})

        # Misc settings
        self.missing_dataset_error = \
            '''Class variable `{0}` is None. If the model was loaded from a checkpoint, the dataset must be set manually:
            model = {1}.load_from_checkpoint('checkpoint.ckpt')
            model.{0} = my_{0}
            '''

        # Apply scaling in datasets
        self.scalers = None
        if self.apply_scaling and self.train_dataset is not None:
            self.initialize_scalers()

    def initialize_scalers(self):
        """Initialize or reinitialize scalers based on dataset properties."""
        if self.scalers is None:
            use_single_scaler = self.join_embeddings == 'beginning'
            self.scalers = self.train_dataset.fit_scaling(use_single_scaler)
            self.apply_scalers()

    def apply_scalers(self):
        """Apply scalers to all datasets."""
        use_single_scaler = self.join_embeddings == 'beginning'
        if self.train_dataset:
            self.train_dataset.apply_scaling(self.scalers, use_single_scaler)
        if self.val_dataset:
            self.val_dataset.apply_scaling(self.scalers, use_single_scaler)
        if self.test_dataset:
            self.test_dataset.apply_scaling(self.scalers, use_single_scaler)

    def scale_tensor(
        self,
        tensor: torch.Tensor,
        scaler: StandardScaler,
    ) -> torch.Tensor:
        """Scale a tensor using a fitted scaler. This is done to avoid going
        through numpy arrays (and to stay on the same device).

        Args:
            tensor (torch.Tensor): The tensor to scale.
            scaler (StandardScaler): The scaler to use.

        Returns:
            torch.Tensor: The scaled tensor.
        """
        tensor = tensor.float()
        if scaler.with_mean:
            tensor -= torch.tensor(scaler.mean_, dtype=tensor.dtype, device=tensor.device)
        if scaler.with_std:
            tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device)
        return tensor

    def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True):
        if not prescaled_embeddings:
            if self.apply_scaling:
                if self.join_embeddings == 'beginning':
                    # A single scaler is fit on the horizontally-stacked embeddings.
                    embeddings = self.scale_tensor(
                        torch.hstack([smiles_emb, poi_emb, e3_emb, cell_emb]),
                        self.scalers,
                    )
                    smiles_emb = embeddings[:, :self.smiles_emb_dim]
                    poi_emb = embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim + self.poi_emb_dim]
                    e3_emb = embeddings[:, self.smiles_emb_dim + self.poi_emb_dim:self.smiles_emb_dim + self.poi_emb_dim + self.e3_emb_dim]
                    cell_emb = embeddings[:, -self.cell_emb_dim:]
                else:
                    poi_emb = self.scale_tensor(poi_emb, self.scalers['Uniprot'])
                    e3_emb = self.scale_tensor(e3_emb, self.scalers['E3 Ligase Uniprot'])
                    cell_emb = self.scale_tensor(cell_emb, self.scalers['Cell Line Identifier'])
                    smiles_emb = self.scale_tensor(smiles_emb, self.scalers['Smiles'])
        return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)

    def step(self, batch, batch_idx, stage):
        poi_emb = batch['poi_emb']
        e3_emb = batch['e3_emb']
        cell_emb = batch['cell_emb']
        smiles_emb = batch['smiles_emb']
        y = batch['active'].float().unsqueeze(1)

        y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
        loss = F.binary_cross_entropy_with_logits(y_hat, y)

        self.metrics[f'{stage}_metrics'].update(y_hat, y)
        self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True)
        self.log_dict(self.metrics[f'{stage}_metrics'], on_epoch=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'test')

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        return {
            'optimizer': optimizer,
            # NOTE: ReduceLROnPlateau requires a monitored metric, so the
            # scheduler is wrapped in a Lightning lr_scheduler config dict.
            'lr_scheduler': {
                'scheduler': optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer=optimizer,
                    mode='min',
                    factor=0.1,
                    patience=0,
                ),
                'interval': 'epoch',
                'frequency': 1,
                'monitor': 'val_loss',
            },
        }

    def predict_step(self, batch, batch_idx):
        poi_emb = batch['poi_emb']
        e3_emb = batch['e3_emb']
        cell_emb = batch['cell_emb']
        smiles_emb = batch['smiles_emb']
        y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
        return torch.sigmoid(y_hat)

    def train_dataloader(self):
        if self.train_dataset is None:
            format_args = 'train_dataset', self.__class__.__name__
            raise ValueError(self.missing_dataset_error.format(*format_args))
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            # drop_last=True,
        )

    def val_dataloader(self):
        if self.val_dataset is None:
            format_args = 'val_dataset', self.__class__.__name__
            raise ValueError(self.missing_dataset_error.format(*format_args))
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
        )

    def test_dataloader(self):
        if self.test_dataset is None:
            format_args = 'test_dataset', self.__class__.__name__
            raise ValueError(self.missing_dataset_error.format(*format_args))
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
        )

    def on_save_checkpoint(self, checkpoint):
        """Serialize the scalers to the checkpoint."""
        checkpoint['scalers'] = pickle.dumps(self.scalers)

    def on_load_checkpoint(self, checkpoint):
        """Deserialize the scalers from the checkpoint."""
        if 'scalers' in checkpoint:
            self.scalers = pickle.loads(checkpoint['scalers'])
        else:
            self.scalers = None
        if self.apply_scaling:
            if self.scalers is not None:
                # Re-apply scalers to ensure datasets are scaled
                self.apply_scalers()
            else:
                logging.warning("Scalers not found in checkpoint. Consider re-fitting scalers if necessary.")


# TODO: Use some sort of **kwargs to pass all the parameters to the model...
def train_model(
    protein2embedding: Dict,
    cell2embedding: Dict,
    smiles2fp: Dict,
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: Optional[pd.DataFrame] = None,
    hidden_dim: int = 768,
    batch_size: int = 128,
    learning_rate: float = 2e-5,
    dropout: float = 0.2,
    max_epochs: int = 50,
    use_batch_norm: bool = False,
    smiles_emb_dim: int = config.fingerprint_size,
    poi_emb_dim: int = config.protein_embedding_size,
    e3_emb_dim: int = config.protein_embedding_size,
    cell_emb_dim: int = config.cell_embedding_size,
    join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
    smote_k_neighbors: int = 5,
    use_smote: bool = True,
    apply_scaling: bool = True,
    active_label: str = 'Active',
    fast_dev_run: bool = False,
    use_logger: bool = True,
    logger_save_dir: str = '../logs',
    logger_name: str = 'protac',
    enable_checkpointing: bool = False,
    checkpoint_model_name: str = 'protac',
    disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
    return_predictions: bool = False,
) -> tuple:
""" Train a PROTAC model using the given datasets and hyperparameters. | |
Args: | |
protein2embedding (dict): Dictionary of protein embeddings. | |
cell2embedding (dict): Dictionary of cell line embeddings. | |
smiles2fp (dict): Dictionary of SMILES to fingerprint. | |
train_df (pd.DataFrame): The training set. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>. | |
val_df (pd.DataFrame): The validation set. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>. | |
test_df (pd.DataFrame): The test set. If provided, the returned metrics will include test performance. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>. | |
hidden_dim (int): The hidden dimension of the model. | |
batch_size (int): The batch size. | |
learning_rate (float): The learning rate. | |
max_epochs (int): The maximum number of epochs. | |
smiles_emb_dim (int): The dimension of the SMILES embeddings. | |
smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler. | |
fast_dev_run (bool): Whether to run a fast development run. | |
disabled_embeddings (list): The list of disabled embeddings. | |
Returns: | |
tuple: The trained model, the trainer, and the metrics. | |
""" | |
    train_ds, val_ds, test_ds = get_datasets(
        train_df,
        val_df,
        test_df,
        protein2embedding,
        cell2embedding,
        smiles2fp,
        use_smote=use_smote,
        smote_k_neighbors=smote_k_neighbors,
        active_label=active_label,
        disabled_embeddings=disabled_embeddings,
    )
    loggers = [
        pl.loggers.TensorBoardLogger(
            save_dir=logger_save_dir,
            version=logger_name,
            name=logger_name,
        ),
        pl.loggers.CSVLogger(
            save_dir=logger_save_dir,
            version=logger_name,
            name=logger_name,
        ),
    ]
    callbacks = [
        pl.callbacks.EarlyStopping(
            monitor='train_loss',
            patience=10,
            mode='min',
            verbose=False,
        ),
        pl.callbacks.EarlyStopping(
            monitor='train_acc',
            patience=10,
            mode='max',
            verbose=False,
        ),
        pl.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,
            mode='min',
            verbose=False,
        ),
        pl.callbacks.EarlyStopping(
            monitor='val_acc',
            patience=10,
            mode='max',
            verbose=False,
        ),
    ]
    if use_logger:
        callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='step'))
    if enable_checkpointing:
        callbacks.append(pl.callbacks.ModelCheckpoint(
            monitor='val_acc',
            mode='max',
            verbose=False,
            filename=checkpoint_model_name + '-{epoch}-{val_acc:.2f}-{val_roc_auc:.3f}',
        ))
    # Define Trainer
    trainer = pl.Trainer(
        logger=loggers if use_logger else False,
        callbacks=callbacks,
        max_epochs=max_epochs,
        # val_check_interval=0.5,
        fast_dev_run=fast_dev_run,
        enable_model_summary=False,
        enable_checkpointing=enable_checkpointing,
        enable_progress_bar=False,
        devices=1,
        num_nodes=1,
    )
    model = PROTAC_Model(
        hidden_dim=hidden_dim,
        smiles_emb_dim=smiles_emb_dim,
        poi_emb_dim=poi_emb_dim,
        e3_emb_dim=e3_emb_dim,
        cell_emb_dim=cell_emb_dim,
        batch_size=batch_size,
        join_embeddings=join_embeddings,
        dropout=dropout,
        use_batch_norm=use_batch_norm,
        learning_rate=learning_rate,
        apply_scaling=apply_scaling,
        train_dataset=train_ds,
        val_dataset=val_ds,
        test_dataset=test_ds if test_df is not None else None,
        disabled_embeddings=disabled_embeddings,
    )

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        trainer.fit(model)

    metrics = {}
    # Add train metrics
    train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
    metrics.update(train_metrics)
    # Add validation metrics
    val_metrics = trainer.validate(model, verbose=False)[0]
    val_metrics = {m: v for m, v in val_metrics.items() if 'val' in m}
    metrics.update(val_metrics)
    # Add test metrics
    if test_df is not None:
        test_metrics = trainer.test(model, verbose=False)[0]
        test_metrics = {m: v for m, v in test_metrics.items() if 'test' in m}
        metrics.update(test_metrics)

    # Optionally return the (sigmoid) predictions on the validation/test sets
    if return_predictions:
        val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
        val_pred = torch.cat(trainer.predict(model, val_dl)).squeeze()
        if test_df is not None:
            test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
            test_pred = torch.cat(trainer.predict(model, test_dl)).squeeze()
            return model, trainer, metrics, val_pred, test_pred
        return model, trainer, metrics, val_pred
    return model, trainer, metrics


def evaluate_model(
    model: PROTAC_Model,
    trainer: pl.Trainer,
    val_ds: PROTAC_Dataset,
    test_ds: Optional[PROTAC_Dataset] = None,
    batch_size: int = 128,
) -> dict:
    """Evaluate a PROTAC model using the given datasets.

    Returns a dictionary with the validation metrics and predictions, plus the
    test metrics and predictions when a test set is provided.
    """
    ret = {}

    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    val_metrics = trainer.validate(model, val_dl, verbose=False)[0]
    val_metrics = {m: v for m, v in val_metrics.items() if 'val' in m}
    # Get predictions on the validation set
    val_pred = torch.cat(trainer.predict(model, val_dl)).squeeze()
    ret['val_metrics'] = val_metrics
    ret['val_pred'] = val_pred

    if test_ds is not None:
        test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
        test_metrics = trainer.test(model, test_dl, verbose=False)[0]
        test_metrics = {m: v for m, v in test_metrics.items() if 'test' in m}
        # Get predictions on the test set
        test_pred = torch.cat(trainer.predict(model, test_dl)).squeeze()
        ret['test_metrics'] = test_metrics
        ret['test_pred'] = test_pred

    return ret


def load_model(
    ckpt_path: str,
) -> PROTAC_Model:
    """Load a PROTAC model from a checkpoint.

    Args:
        ckpt_path (str): The path to the checkpoint.

    Returns:
        PROTAC_Model: The loaded model.
    """
    # NOTE: The `map_location` argument is automatically handled in newer
    # versions of PyTorch Lightning, but we keep it here for compatibility
    # with older ones.
    model = PROTAC_Model.load_from_checkpoint(
        ckpt_path,
        map_location=torch.device('cpu') if not torch.cuda.is_available() else None,
    )
    # NOTE: The following is left as an example for eventually re-applying
    # scaling with other datasets...
    # if model.apply_scaling:
    #     model.apply_scalers()
    model.eval()
    return model
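

if __name__ == '__main__':
    # Minimal smoke-test sketch for PROTAC_Predictor. It only checks output
    # shapes on random inputs; the embedding sizes below are illustrative
    # assumptions, not the values used in training (those come from `config`).
    # Run with `python -m <package>.<this_module>` so the relative imports
    # above resolve.
    torch.manual_seed(0)
    batch_size = 4
    predictor = PROTAC_Predictor(
        hidden_dim=64,
        smiles_emb_dim=32,
        poi_emb_dim=16,
        e3_emb_dim=16,
        cell_emb_dim=8,
        join_embeddings='sum',
    )
    predictor.eval()
    with torch.no_grad():
        logits = predictor(
            poi_emb=torch.randn(batch_size, 16),
            e3_emb=torch.randn(batch_size, 16),
            cell_emb=torch.randn(batch_size, 8),
            smiles_emb=torch.randn(batch_size, 32),
        )
    # One logit per example; a sigmoid turns it into the predicted probability
    # of the compound being active.
    print(logits.shape)  # Expected: torch.Size([4, 1])
    print(torch.sigmoid(logits))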