# -*- coding: utf-8 -*-
# CellViT inference method for patch-wise inference on a test set,
# without merging patches back into WSIs.
#
# The aim is to calculate the metrics as defined for the PanNuke dataset.
#
# @ Fabian Hörst, [email protected]
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import argparse
import inspect
import os
import sys
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
parentdir = os.path.dirname(parentdir)
sys.path.insert(0, parentdir)
from base_ml.base_experiment import BaseExperiment
BaseExperiment.seed_run(1232)
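# Note: the two os.path.dirname() hops above put the parent directories of this file
# on sys.path so that the project packages (base_ml, cell_segmentation, models, utils)
# can be imported when this file is run as a script; the run is seeded before the
# remaining project imports so that dataset ordering and lazy initialisation stay
# reproducible.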
import json
from pathlib import Path
from typing import List, Tuple, Union
import albumentations as A
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import yaml
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw
from skimage.color import rgba2rgb
from sklearn.metrics import accuracy_score
from tabulate import tabulate
from torch.utils.data import DataLoader
from torchmetrics.functional import dice
from torchmetrics.functional.classification import binary_jaccard_index
from torchvision import transforms
from cell_segmentation.datasets.dataset_coordinator import select_dataset
from cell_segmentation.utils.metrics import (
    cell_detection_scores,
    cell_type_detection_scores,
    get_fast_pq,
    remap_label,
    binarize,
)
from cell_segmentation.utils.post_proc_cellvit import calculate_instances
from cell_segmentation.utils.tools import cropping_center, pair_coordinates
from models.segmentation.cell_segmentation.cellvit import CellViT, DataclassHVStorage
from utils.logger import Logger
class InferenceCellViT:
    def __init__(
        self,
        run_dir: Union[Path, str],
        gpu: int,
        magnification: int = 40,
        checkpoint_name: str = "model_best.pth",
    ) -> None:
        """Inference for CellViT (patch-wise, PanNuke-style metrics)
        Args:
            run_dir (Union[Path, str]): logging directory with checkpoints and configs
            gpu (int): CUDA GPU device to use for inference
            magnification (int, optional): Dataset magnification. Defaults to 40.
            checkpoint_name (str, optional): Name of the model checkpoint to load. Defaults to model_best.pth
        """
        self.run_dir = Path(run_dir)
        # use the requested GPU if CUDA is available, otherwise fall back to CPU
        self.device = f"cuda:{gpu}" if torch.cuda.is_available() else "cpu"
        self.run_conf: dict = None
        self.logger: Logger = None
        self.magnification = magnification
        self.checkpoint_name = checkpoint_name
        self.__load_run_conf()
        # the dataset configuration is needed by setup_patch_inference and unpack_masks;
        # this assumes the run config stores the dataset folder under data.dataset_path
        self.__load_dataset_setup(dataset_path=self.run_conf["data"]["dataset_path"])
        self.__instantiate_logger()
        self.__setup_amp()
        self.num_classes = self.run_conf["data"]["num_nuclei_classes"]
    def __load_run_conf(self) -> None:
        """Load the config.yaml file with the run setup
        Be careful with loading and usage, since original None values in the run configuration are not stored when dumped to a yaml file.
        If you want to check whether a key is undefined, first check whether the key exists in the dict at all.
        """
        with open((self.run_dir / "config.yaml").resolve(), "r") as run_config_file:
            yaml_config = yaml.safe_load(run_config_file)
            self.run_conf = dict(yaml_config)
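        # For reference, the config keys read elsewhere in this script (dataset_path
        # is assumed, see __init__; all other keys appear in the code below):
        #   data:            dataset, dataset_path, num_nuclei_classes,
        #                    num_tissue_classes, val_folds / test_folds
        #   model:           pretrained_encoder, input_chanels (sic, optional)
        #   training:        mixed_precision (optional)
        #   transformations: normalize.mean / normalize.std (optional)
        #   logging:         level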
    def __load_dataset_setup(self, dataset_path: Union[Path, str]) -> None:
        """Load the configuration of the cell segmentation dataset.
        The dataset must have a dataset_config.yaml file in its dataset path with the following entries:
            * tissue_types: describing the present tissue types with corresponding integer
            * nuclei_types: describing the present nuclei types with corresponding integer
        Args:
            dataset_path (Union[Path, str]): Path to dataset folder
        """
        dataset_config_path = Path(dataset_path) / "dataset_config.yaml"
        with open(dataset_config_path, "r") as dataset_config_file:
            yaml_config = yaml.safe_load(dataset_config_file)
            self.dataset_config = dict(yaml_config)
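        # Minimal dataset_config.yaml sketch (class names and integers are
        # illustrative; use the mapping of your own dataset):
        #   tissue_types:
        #     Breast: 0
        #     Colon: 1
        #   nuclei_types:
        #     Background: 0
        #     Neoplastic: 1
        #     Inflammatory: 2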
    def __instantiate_logger(self) -> None:
        """Instantiate logger
        The logger uses a plain message formatter. Logs are stored in the run directory under the filename: inference.log
        """
        logger = Logger(
            level=self.run_conf["logging"]["level"].upper(),
            log_dir=Path(self.run_dir).resolve(),
            comment="inference",
            use_timestamp=False,
            formatter="%(message)s",
        )
        self.logger = logger.create_logger()
    def __check_eval_model(self) -> None:
        """Check if there is a best model pytorch file"""
        assert (self.run_dir / "checkpoints" / self.checkpoint_name).is_file()
    def __setup_amp(self) -> None:
        """Setup automated mixed precision (amp) for inference."""
        self.mixed_precision = self.run_conf["training"].get("mixed_precision", False)
    def get_model(self, model_type: str) -> CellViT:
        """Return the trained model for inference
        Args:
            model_type (str): Name of the model. Currently only "CellViT" is implemented here.
        Returns:
            CellViT: Model
        """
        implemented_models = [
            "CellViT",
        ]
        if model_type not in implemented_models:
            raise NotImplementedError(
                f"Unknown model type. Please select one of {implemented_models}"
            )
        if model_type == "CellViT":
            model_class = CellViT
            model = model_class(
                model256_path=self.run_conf["model"].get("pretrained_encoder"),
                num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
                num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
                in_channels=self.run_conf["model"].get("input_chanels", 3),  # config key spelling kept as in the run configs
                # further architecture hyperparameters (embed_dim, depth, num_heads,
                # extract_layers, regression_loss) are left at their defaults here;
                # pass them from self.run_conf["model"] if your CellViT variant needs them
            )
            return model
    def setup_patch_inference(
        self, test_folds: List[int] = None
    ) -> Tuple[CellViT, DataLoader, dict]:
        """Setup patch inference by loading the model checkpoint and defining a patch-wise dataloader
        Args:
            test_folds (List[int], optional): Test folds to use. Otherwise the folds defined in config.yaml (in run_dir) are loaded. Defaults to None.
        Returns:
            Tuple[CellViT, DataLoader, dict]:
                CellViT: Best model loaded from checkpoint
                DataLoader: Inference DataLoader
                dict: Dataset configuration. Keys are:
                    * "tissue_types": describing the present tissue types with corresponding integer
                    * "nuclei_types": describing the present nuclei types with corresponding integer
        """
        # load model weights; the checkpoint layout (keys "arch" and "model_state_dict")
        # is assumed to match the CellViT training scripts
        checkpoint = torch.load(
            self.run_dir / "checkpoints" / self.checkpoint_name, map_location="cpu"
        )
        model = self.get_model(model_type=checkpoint.get("arch", "CellViT"))
        model.load_state_dict(checkpoint["model_state_dict"])
        # get dataset folds
        if test_folds is None:
            if "test_folds" in self.run_conf["data"]:
                if self.run_conf["data"]["test_folds"] is None:
                    self.logger.info(
                        "There was no test set provided. We now use the validation dataset for testing"
                    )
                    self.run_conf["data"]["test_folds"] = self.run_conf["data"][
                        "val_folds"
                    ]
            else:
                self.logger.info(
                    "There was no test set provided. We now use the validation dataset for testing"
                )
                self.run_conf["data"]["test_folds"] = self.run_conf["data"]["val_folds"]
        else:
            self.run_conf["data"]["test_folds"] = test_folds
        self.logger.info(
            f"Performing Inference on test set: {self.run_conf['data']['test_folds']}"
        )
        # build the test-time transforms from the run configuration
        transform_settings = self.run_conf["transformations"]
        if "normalize" in transform_settings:
            mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5))
            std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5))
        else:
            mean = (0.5, 0.5, 0.5)
            std = (0.5, 0.5, 0.5)
        inference_transforms = A.Compose([A.Normalize(mean=mean, std=std)])
        inference_dataset = select_dataset(
            dataset_name=self.run_conf["data"]["dataset"],
            split="test",
            dataset_config=self.run_conf["data"],
            transforms=inference_transforms,
        )
        inference_dataloader = DataLoader(
            inference_dataset,
            batch_size=1,
            num_workers=12,
            pin_memory=False,
            shuffle=False,
        )
        return model, inference_dataloader, self.dataset_config
    def run_patch_inference(
        self,
        model: CellViT,
        inference_dataloader: DataLoader,
        dataset_config: dict,
        generate_plots: bool = False,
    ) -> None:
        """Run patch inference with the given setup
        Args:
            model (CellViT): Model to use for inference
            inference_dataloader (DataLoader): Inference Dataloader. Must return a batch with the following structure:
                * Images (torch.Tensor)
                * Masks (dict)
                * Tissue types as str
                * Image name as str
            dataset_config (dict): Dataset configuration. Required keys are:
                * "tissue_types": describing the present tissue types with corresponding integer
                * "nuclei_types": describing the present nuclei types with corresponding integer
            generate_plots (bool, optional): If inference plots should be generated. Defaults to False.
        """
        # put model in eval mode
        model.to(device=self.device)
        model.eval()
        # setup score tracker
        image_names = []  # image names as str
        binary_dice_scores = []  # binary dice scores per image
        binary_jaccard_scores = []  # binary jaccard scores per image
        pq_scores = []  # pq-scores per image
        dq_scores = []  # dq-scores per image
        sq_scores = []  # sq-scores per image
        cell_type_pq_scores = []  # pq-scores per cell type and image
        cell_type_dq_scores = []  # dq-scores per cell type and image
        cell_type_sq_scores = []  # sq-scores per cell type and image
        tissue_pred = []  # tissue predictions for each image
        tissue_gt = []  # ground truth tissue image class
        tissue_types_inf = []  # string repr of ground truth tissue image class
        paired_all_global = []  # unique matched index pairs
        unpaired_true_all_global = []  # indices must exist in `true_inst_type_all` and be unique
        unpaired_pred_all_global = []  # indices must exist in `pred_inst_type_all` and be unique
        true_inst_type_all_global = []  # each index is 1 independent data point
        pred_inst_type_all_global = []  # each index is 1 independent data point
        # for detection scores
        true_idx_offset = 0
        pred_idx_offset = 0
        inference_loop = tqdm.tqdm(
            enumerate(inference_dataloader), total=len(inference_dataloader)
        )
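        # The loop below accumulates per-image metrics and, for the detection scores,
        # concatenates the per-image instance lists into global arrays. Because pair
        # indices from each batch are local to that batch, the running
        # true_idx_offset / pred_idx_offset shift them so that they index correctly
        # into the concatenated global arrays.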
        with torch.no_grad():
            for batch_idx, batch in inference_loop:
                batch_metrics = self.inference_step(
                    model, batch, generate_plots=generate_plots
                )
                # unpack batch_metrics
                image_names = image_names + batch_metrics["image_names"]
                # dice scores
                binary_dice_scores = (
                    binary_dice_scores + batch_metrics["binary_dice_scores"]
                )
                binary_jaccard_scores = (
                    binary_jaccard_scores + batch_metrics["binary_jaccard_scores"]
                )
                # pq scores
                pq_scores = pq_scores + batch_metrics["pq_scores"]
                dq_scores = dq_scores + batch_metrics["dq_scores"]
                sq_scores = sq_scores + batch_metrics["sq_scores"]
                tissue_types_inf = tissue_types_inf + batch_metrics["tissue_types"]
                cell_type_pq_scores = (
                    cell_type_pq_scores + batch_metrics["cell_type_pq_scores"]
                )
                cell_type_dq_scores = (
                    cell_type_dq_scores + batch_metrics["cell_type_dq_scores"]
                )
                cell_type_sq_scores = (
                    cell_type_sq_scores + batch_metrics["cell_type_sq_scores"]
                )
                tissue_pred.append(batch_metrics["tissue_pred"])
                tissue_gt.append(batch_metrics["tissue_gt"])
                # detection scores
                true_idx_offset = (
                    true_idx_offset + true_inst_type_all_global[-1].shape[0]
                    if batch_idx != 0
                    else 0
                )
                pred_idx_offset = (
                    pred_idx_offset + pred_inst_type_all_global[-1].shape[0]
                    if batch_idx != 0
                    else 0
                )
                true_inst_type_all_global.append(batch_metrics["true_inst_type_all"])
                pred_inst_type_all_global.append(batch_metrics["pred_inst_type_all"])
                # increment the pairing index statistic
                batch_metrics["paired_all"][:, 0] += true_idx_offset
                batch_metrics["paired_all"][:, 1] += pred_idx_offset
                paired_all_global.append(batch_metrics["paired_all"])
                batch_metrics["unpaired_true_all"] += true_idx_offset
                batch_metrics["unpaired_pred_all"] += pred_idx_offset
                unpaired_true_all_global.append(batch_metrics["unpaired_true_all"])
                unpaired_pred_all_global.append(batch_metrics["unpaired_pred_all"])
        # assemble batches to datasets (global)
        tissue_types_inf = [t.lower() for t in tissue_types_inf]
        paired_all = np.concatenate(paired_all_global, axis=0)
        unpaired_true_all = np.concatenate(unpaired_true_all_global, axis=0)
        unpaired_pred_all = np.concatenate(unpaired_pred_all_global, axis=0)
        true_inst_type_all = np.concatenate(true_inst_type_all_global, axis=0)
        pred_inst_type_all = np.concatenate(pred_inst_type_all_global, axis=0)
        paired_true_type = true_inst_type_all[paired_all[:, 0]]
        paired_pred_type = pred_inst_type_all[paired_all[:, 1]]
        unpaired_true_type = true_inst_type_all[unpaired_true_all]
        unpaired_pred_type = pred_inst_type_all[unpaired_pred_all]
        binary_dice_scores = np.array(binary_dice_scores)
        binary_jaccard_scores = np.array(binary_jaccard_scores)
        pq_scores = np.array(pq_scores)
        dq_scores = np.array(dq_scores)
        sq_scores = np.array(sq_scores)
        tissue_detection_accuracy = accuracy_score(
            y_true=np.concatenate(tissue_gt), y_pred=np.concatenate(tissue_pred)
        )
        f1_d, prec_d, rec_d = cell_detection_scores(
            paired_true=paired_true_type,
            paired_pred=paired_pred_type,
            unpaired_true=unpaired_true_type,
            unpaired_pred=unpaired_pred_type,
        )
        dataset_metrics = {
            "Binary-Cell-Dice-Mean": float(np.nanmean(binary_dice_scores)),
            "Binary-Cell-Jaccard-Mean": float(np.nanmean(binary_jaccard_scores)),
            "Tissue-Multiclass-Accuracy": tissue_detection_accuracy,
            "bPQ": float(np.nanmean(pq_scores)),
            "bDQ": float(np.nanmean(dq_scores)),
            "bSQ": float(np.nanmean(sq_scores)),
            "mPQ": float(np.nanmean([np.nanmean(pq) for pq in cell_type_pq_scores])),
            "mDQ": float(np.nanmean([np.nanmean(dq) for dq in cell_type_dq_scores])),
            "mSQ": float(np.nanmean([np.nanmean(sq) for sq in cell_type_sq_scores])),
            "f1_detection": float(f1_d),
            "precision_detection": float(prec_d),
            "recall_detection": float(rec_d),
        }
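        # Reminder on the reported aggregates: panoptic quality factorises as
        # PQ = DQ * SQ (detection quality times segmentation quality). bPQ/bDQ/bSQ
        # average the class-agnostic scores over images, whereas mPQ/mDQ/mSQ first
        # average over nuclei types within an image and then over images.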
        # calculate tissue metrics
        tissue_types = dataset_config["tissue_types"]
        tissue_metrics = {}
        for tissue in tissue_types.keys():
            tissue = tissue.lower()
            tissue_ids = np.where(np.asarray(tissue_types_inf) == tissue)
            tissue_metrics[f"{tissue}"] = {}
            tissue_metrics[f"{tissue}"]["Dice"] = float(
                np.nanmean(binary_dice_scores[tissue_ids])
            )
            tissue_metrics[f"{tissue}"]["Jaccard"] = float(
                np.nanmean(binary_jaccard_scores[tissue_ids])
            )
            tissue_metrics[f"{tissue}"]["mPQ"] = float(
                np.nanmean(
                    [np.nanmean(pq) for pq in np.array(cell_type_pq_scores)[tissue_ids]]
                )
            )
            tissue_metrics[f"{tissue}"]["bPQ"] = float(
                np.nanmean(pq_scores[tissue_ids])
            )
        # calculate nuclei metrics
        nuclei_types = dataset_config["nuclei_types"]
        nuclei_metrics_d = {}
        nuclei_metrics_pq = {}
        nuclei_metrics_dq = {}
        nuclei_metrics_sq = {}
        for nuc_name, nuc_type in nuclei_types.items():
            if nuc_name.lower() == "background":
                continue
            nuclei_metrics_pq[nuc_name] = np.nanmean(
                [pq[nuc_type] for pq in cell_type_pq_scores]
            )
            nuclei_metrics_dq[nuc_name] = np.nanmean(
                [dq[nuc_type] for dq in cell_type_dq_scores]
            )
            nuclei_metrics_sq[nuc_name] = np.nanmean(
                [sq[nuc_type] for sq in cell_type_sq_scores]
            )
            f1_cell, prec_cell, rec_cell = cell_type_detection_scores(
                paired_true_type,
                paired_pred_type,
                unpaired_true_type,
                unpaired_pred_type,
                nuc_type,
            )
            nuclei_metrics_d[nuc_name] = {
                "f1_cell": f1_cell,
                "prec_cell": prec_cell,
                "rec_cell": rec_cell,
            }
        # print final results
        # binary
        self.logger.info(f"{20*'*'} Binary Dataset metrics {20*'*'}")
        [self.logger.info(f"{f'{k}:': <25} {v}") for k, v in dataset_metrics.items()]
        # tissue metrics (Dice, Jaccard, mPQ and bPQ per tissue type)
        self.logger.info(f"{20*'*'} Tissue metrics {20*'*'}")
        flattened_tissue = []
        for key in tissue_metrics:
            flattened_tissue.append(
                [
                    key,
                    tissue_metrics[key]["Dice"],
                    tissue_metrics[key]["Jaccard"],
                    tissue_metrics[key]["mPQ"],
                    tissue_metrics[key]["bPQ"],
                ]
            )
        self.logger.info(
            tabulate(
                flattened_tissue, headers=["Tissue", "Dice", "Jaccard", "mPQ", "bPQ"]
            )
        )
        # nuclei types
        self.logger.info(f"{20*'*'} Nuclei Type Metrics {20*'*'}")
        flattened_nuclei_type = []
        for key in nuclei_metrics_pq:
            flattened_nuclei_type.append(
                [
                    key,
                    nuclei_metrics_dq[key],
                    nuclei_metrics_sq[key],
                    nuclei_metrics_pq[key],
                ]
            )
        self.logger.info(
            tabulate(flattened_nuclei_type, headers=["Nuclei Type", "DQ", "SQ", "PQ"])
        )
        # nuclei detection metrics
        self.logger.info(f"{20*'*'} Nuclei Detection Metrics {20*'*'}")
        flattened_detection = []
        for key in nuclei_metrics_d:
            flattened_detection.append(
                [
                    key,
                    nuclei_metrics_d[key]["prec_cell"],
                    nuclei_metrics_d[key]["rec_cell"],
                    nuclei_metrics_d[key]["f1_cell"],
                ]
            )
        self.logger.info(
            tabulate(
                flattened_detection,
                headers=["Nuclei Type", "Precision", "Recall", "F1"],
            )
        )
        # collect per-image metrics and save all results
        image_metrics = {}
        for idx, image_name in enumerate(image_names):
            image_metrics[image_name] = {
                "Dice": float(binary_dice_scores[idx]),
                "Jaccard": float(binary_jaccard_scores[idx]),
                "bPQ": float(pq_scores[idx]),
            }
        all_metrics = {
            "dataset": dataset_metrics,
            "tissue_metrics": tissue_metrics,
            "image_metrics": image_metrics,
            "nuclei_metrics_pq": nuclei_metrics_pq,
            "nuclei_metrics_d": nuclei_metrics_d,
        }
        # saving
        with open(str(self.run_dir / "inference_results.json"), "w") as outfile:
            json.dump(all_metrics, outfile, indent=2)
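        # Resulting inference_results.json layout (values are illustrative):
        #   {
        #     "dataset": {"bPQ": 0.62, "mPQ": 0.45, "f1_detection": 0.80, ...},
        #     "tissue_metrics": {"breast": {"Dice": ..., "Jaccard": ..., "mPQ": ..., "bPQ": ...}, ...},
        #     "image_metrics": {"<image_name>": {"Dice": ..., "Jaccard": ..., "bPQ": ...}, ...},
        #     "nuclei_metrics_pq": {"<nuclei type>": ...},
        #     "nuclei_metrics_d": {"<nuclei type>": {"f1_cell": ..., "prec_cell": ..., "rec_cell": ...}}
        #   }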
    def inference_step(
        self,
        model: CellViT,
        batch: tuple,
        generate_plots: bool = False,
    ) -> dict:
        """Inference step for a patch-wise batch
        Args:
            model (CellViT): Model to use for inference
            batch (tuple): Batch with the following structure:
                * Images (torch.Tensor)
                * Masks (dict)
                * Tissue types as str
                * Image name as str
            generate_plots (bool, optional): If inference plots should be generated. Defaults to False.
        Returns:
            dict: Batch metrics as returned by calculate_step_metric
        """
        # unpack batch, for shapes compare the train_step method
        imgs = batch[0].to(self.device)
        masks = batch[1]
        tissue_types = list(batch[2])
        image_names = list(batch[3])
        model.zero_grad()
        if self.mixed_precision:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                predictions = model.forward(imgs)
        else:
            predictions = model.forward(imgs)
        predictions = self.unpack_predictions(predictions=predictions, model=model)
        gt = self.unpack_masks(masks=masks, tissue_types=tissue_types, model=model)
        # scores
        batch_metrics, scores = self.calculate_step_metric(predictions, gt, image_names)
        batch_metrics["tissue_types"] = tissue_types
        if generate_plots:
            self.plot_results(
                imgs=imgs,
                predictions=predictions,
                ground_truth=gt,
                img_names=image_names,
                num_nuclei_classes=self.num_classes,
                outdir=Path(self.run_dir / "inference_predictions"),
                scores=scores,
            )
        return batch_metrics
    def run_single_image_inference(
        self,
        model: CellViT,
        image: np.ndarray,
        generate_plots: bool = True,
    ) -> None:
        """Run inference on a single RGB image patch and plot the predictions
        Args:
            model (CellViT): Model to use for inference
            image (np.ndarray): RGB input patch of shape (H, W, 3)
            generate_plots (bool, optional): Kept for API symmetry; plots are always generated here. Defaults to True.
        """
        # set image transforms
        transform_settings = self.run_conf["transformations"]
        if "normalize" in transform_settings:
            mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5))
            std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5))
        else:
            mean = (0.5, 0.5, 0.5)
            std = (0.5, 0.5, 0.5)
        transforms = A.Compose([A.Normalize(mean=mean, std=std)])
        transformed_img = transforms(image=image)["image"]
        image = torch.from_numpy(transformed_img).permute(2, 0, 1).unsqueeze(0).float()
        imgs = image.to(self.device)
        model.to(device=self.device)
        model.eval()
        model.zero_grad()
        with torch.no_grad():
            predictions = model.forward(imgs)
        predictions = self.unpack_predictions(predictions=predictions, model=model)
        image_output = self.plot_results(
            imgs=imgs,
            predictions=predictions,
            num_nuclei_classes=self.num_classes,
            outdir=Path(self.run_dir),
        )
        return image_output
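    # Example (sketch, paths are placeholders) of single-patch inference with the
    # helpers above; assumes a 256 x 256 RGB patch on disk:
    #   inf = InferenceCellViT(run_dir="./results/my_run", gpu=0)
    #   model, _, _ = inf.setup_patch_inference()
    #   patch = np.array(Image.open("patch.png").convert("RGB"))
    #   inf.run_single_image_inference(model, patch)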
    def unpack_predictions(
        self, predictions: dict, model: CellViT
    ) -> DataclassHVStorage:
        """Unpack the given predictions. Main focus lays on reshaping and postprocessing predictions, e.g. separating instances
        Args:
            predictions (dict): Dictionary with the following keys:
                * tissue_types: Logit tissue prediction output. Shape: (batch_size, num_tissue_classes)
                * nuclei_binary_map: Logit output for binary nuclei prediction branch. Shape: (batch_size, 2, H, W)
                * hv_map: Logit output for hv-prediction. Shape: (batch_size, 2, H, W)
                * nuclei_type_map: Logit output for nuclei instance-prediction. Shape: (batch_size, num_nuclei_classes, H, W)
            model (CellViT): Current model
        Returns:
            DataclassHVStorage: Processed network output
        """
        predictions["tissue_types"] = predictions["tissue_types"].to(self.device)
        predictions["nuclei_binary_map"] = F.softmax(
            predictions["nuclei_binary_map"], dim=1
        )  # shape: (batch_size, 2, H, W)
        predictions["nuclei_type_map"] = F.softmax(
            predictions["nuclei_type_map"], dim=1
        )  # shape: (batch_size, num_nuclei_classes, H, W)
        (
            predictions["instance_map"],
            predictions["instance_types"],
        ) = model.calculate_instance_map(
            predictions, magnification=self.magnification
        )  # shape: (batch_size, H', W')
        predictions["instance_types_nuclei"] = model.generate_instance_nuclei_map(
            predictions["instance_map"], predictions["instance_types"]
        ).permute(0, 3, 1, 2).to(
            self.device
        )  # shape: (batch_size, num_nuclei_classes, H, W)
        predictions = DataclassHVStorage(
            nuclei_binary_map=predictions["nuclei_binary_map"],  # (B, 2, 256, 256)
            hv_map=predictions["hv_map"],  # (B, 2, 256, 256)
            nuclei_type_map=predictions["nuclei_type_map"],  # (B, num_nuclei_classes, 256, 256)
            tissue_types=predictions["tissue_types"],  # (B, num_tissue_classes)
            instance_map=predictions["instance_map"],  # (B, 256, 256)
            instance_types=predictions["instance_types"],  # list of B dicts with per-instance results
            instance_types_nuclei=predictions["instance_types_nuclei"],  # (B, num_nuclei_classes, 256, 256)
            batch_size=predictions["tissue_types"].shape[0],
        )
        return predictions
    def unpack_masks(
        self, masks: dict, tissue_types: list, model: CellViT
    ) -> DataclassHVStorage:
        """Unpack the given ground-truth masks and bring them into the same format as the predictions
        Args:
            masks (dict): Ground-truth masks with keys nuclei_binary_map, nuclei_type_map, hv_map and instance_map
            tissue_types (list): Tissue type per image as str
            model (CellViT): Current model (not used here, kept for API symmetry)
        Returns:
            DataclassHVStorage: Processed ground truth
        """
        # get ground truth values, perform one hot encoding for segmentation maps
        gt_nuclei_binary_map_onehot = (
            F.one_hot(masks["nuclei_binary_map"], num_classes=2)
        ).type(
            torch.float32
        )  # background, nuclei; shape: (B, 256, 256, 2)
        nuclei_type_maps = torch.squeeze(masks["nuclei_type_map"]).type(torch.int64)  # (B, 256, 256)
        gt_nuclei_type_maps_onehot = F.one_hot(
            nuclei_type_maps, num_classes=self.num_classes
        ).type(
            torch.float32
        )  # background + nuclei types; shape: (B, 256, 256, num_nuclei_classes)
        # assemble ground truth dictionary
        gt = {
            "nuclei_type_map": gt_nuclei_type_maps_onehot.permute(0, 3, 1, 2).to(
                self.device
            ),  # (B, 256, 256, num_nuclei_classes) -> (B, num_nuclei_classes, 256, 256)
            "nuclei_binary_map": gt_nuclei_binary_map_onehot.permute(0, 3, 1, 2).to(
                self.device
            ),  # (B, 256, 256, 2) -> (B, 2, 256, 256)
            "hv_map": masks["hv_map"].to(self.device),  # already channel-first: (B, 2, 256, 256)
            "instance_map": masks["instance_map"].to(
                self.device
            ),  # shape: (B, 256, 256) -> each instance has one integer
            "instance_types_nuclei": (
                gt_nuclei_type_maps_onehot * masks["instance_map"][..., None]
            )
            .permute(0, 3, 1, 2)
            .to(
                self.device
            ),  # (B, num_nuclei_classes, 256, 256) -> instance ids per pixel, split by nuclei class
            "tissue_types": torch.Tensor(
                [self.dataset_config["tissue_types"][t] for t in tissue_types]
            )
            .type(torch.LongTensor)
            .to(self.device),  # shape: (B,)
        }
        gt["instance_types"] = calculate_instances(
            gt["nuclei_type_map"], gt["instance_map"]
        )
        gt = DataclassHVStorage(**gt, batch_size=gt["tissue_types"].shape[0])
        return gt
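    # Note on the one-hot layout used above: F.one_hot appends the class axis last,
    # so an int64 map of shape (B, H, W) becomes (B, H, W, num_classes); the
    # permute(0, 3, 1, 2) calls then move it to channel-first (B, num_classes, H, W)
    # to match the network outputs.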
    def calculate_step_metric(
        self,
        predictions: DataclassHVStorage,
        gt: DataclassHVStorage,
        image_names: List[str],
    ) -> Tuple[dict, list]:
        """Calculate the metrics for the validation step
        Args:
            predictions (DataclassHVStorage): Processed network output
            gt (DataclassHVStorage): Ground truth values
            image_names (list(str)): List with image names
        Returns:
            Tuple[dict, list]:
                * dict: Dictionary with metrics. Structure not fixed yet
                * list with cell_dice, cell_jaccard and pq for each image
        """
        predictions = predictions.get_dict()
        gt = gt.get_dict()
        # preparation and device movement
        predictions["tissue_types_classes"] = F.softmax(
            predictions["tissue_types"], dim=-1
        )
        pred_tissue = (
            torch.argmax(predictions["tissue_types_classes"], dim=-1)
            .detach()
            .cpu()
            .numpy()
            .astype(np.uint8)
        )
        predictions["instance_map"] = predictions["instance_map"].detach().cpu()
        predictions["instance_types_nuclei"] = (
            predictions["instance_types_nuclei"].detach().cpu().numpy().astype("int32")
        )  # shape: (batch_size, num_nuclei_classes, H, W)
        instance_maps_gt = gt["instance_map"].detach().cpu()  # (batch_size, H, W)
        gt["tissue_types"] = gt["tissue_types"].detach().cpu().numpy().astype(np.uint8)
        gt["nuclei_binary_map"] = torch.argmax(gt["nuclei_binary_map"], dim=1).type(
            torch.uint8
        )
        gt["instance_types_nuclei"] = (
            gt["instance_types_nuclei"].detach().cpu().numpy().astype("int32")
        )  # shape: (batch_size, num_nuclei_classes, H, W); author note (translated): differs in shape from the predictions above
        # segmentation scores
        binary_dice_scores = []  # binary dice scores per image
        binary_jaccard_scores = []  # binary jaccard scores per image
        pq_scores = []  # pq-scores per image
        dq_scores = []  # dq-scores per image
        sq_scores = []  # sq-scores per image
        cell_type_pq_scores = []  # pq-scores per cell type and image
        cell_type_dq_scores = []  # dq-scores per cell type and image
        cell_type_sq_scores = []  # sq-scores per cell type and image
        scores = []  # all scores in one list
        # detection scores
        paired_all = []  # unique matched index pairs
        unpaired_true_all = []  # indices must exist in `true_inst_type_all` and be unique
        unpaired_pred_all = []  # indices must exist in `pred_inst_type_all` and be unique
        true_inst_type_all = []  # each index is 1 independent data point
        pred_inst_type_all = []  # each index is 1 independent data point
        # for detection scores
        true_idx_offset = 0
        pred_idx_offset = 0
        for i in range(len(pred_tissue)):
            # binary dice score: score for cell detection per image, without background
            pred_binary_map = torch.argmax(predictions["nuclei_binary_map"][i], dim=0)
            target_binary_map = gt["nuclei_binary_map"][i]
            cell_dice = (
                dice(preds=pred_binary_map, target=target_binary_map, ignore_index=0)
                .detach()
                .cpu()
            )
            binary_dice_scores.append(float(cell_dice))
            # binary jaccard index
            cell_jaccard = (
                binary_jaccard_index(
                    preds=pred_binary_map,
                    target=target_binary_map,
                )
                .detach()
                .cpu()
            )
            binary_jaccard_scores.append(float(cell_jaccard))
            # pq values
            if len(np.unique(instance_maps_gt[i])) == 1:
                dq, sq, pq = np.nan, np.nan, np.nan
            else:
                remapped_instance_pred = binarize(
                    predictions["instance_types_nuclei"][i][1:].transpose(1, 2, 0)
                )  # (H, W) instance map built from the non-background class channels
                remapped_gt = remap_label(instance_maps_gt[i])  # (H, W)
                [dq, sq, pq], _ = get_fast_pq(
                    true=remapped_gt, pred=remapped_instance_pred
                )  # ground-truth and predicted instance maps must have the same shape, here (256, 256)
            pq_scores.append(pq)
            dq_scores.append(dq)
            sq_scores.append(sq)
            scores.append(
                [
                    cell_dice.detach().cpu().numpy(),
                    cell_jaccard.detach().cpu().numpy(),
                    pq,
                ]
            )
            # pq values per class (class 0 being background -> should be skipped in the future)
            nuclei_type_pq = []
            nuclei_type_dq = []
            nuclei_type_sq = []
            for j in range(0, self.num_classes):
                pred_nuclei_instance_class = remap_label(
                    predictions["instance_types_nuclei"][i][j, ...]
                )
                target_nuclei_instance_class = remap_label(
                    gt["instance_types_nuclei"][i][j, ...]
                )
                # if ground truth is empty, skip from calculation
                if len(np.unique(target_nuclei_instance_class)) == 1:
                    pq_tmp = np.nan
                    dq_tmp = np.nan
                    sq_tmp = np.nan
                else:
                    [dq_tmp, sq_tmp, pq_tmp], _ = get_fast_pq(
                        pred_nuclei_instance_class,
                        target_nuclei_instance_class,
                        match_iou=0.5,
                    )
                nuclei_type_pq.append(pq_tmp)
                nuclei_type_dq.append(dq_tmp)
                nuclei_type_sq.append(sq_tmp)
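            # remap_label (from the HoVer-Net utilities used here) renumbers the
            # instance ids in a map to consecutive integers starting at 1, which is
            # the labelling that get_fast_pq expects for both of its inputs.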
            # detection scores
            true_centroids = np.array(
                [v["centroid"] for k, v in gt["instance_types"][i].items()]
            )
            true_instance_type = np.array(
                [v["type"] for k, v in gt["instance_types"][i].items()]
            )
            pred_centroids = np.array(
                [v["centroid"] for k, v in predictions["instance_types"][i].items()]
            )
            pred_instance_type = np.array(
                [v["type"] for k, v in predictions["instance_types"][i].items()]
            )
            if true_centroids.shape[0] == 0:
                true_centroids = np.array([[0, 0]])
                true_instance_type = np.array([0])
            if pred_centroids.shape[0] == 0:
                pred_centroids = np.array([[0, 0]])
                pred_instance_type = np.array([0])
            if self.magnification == 40:
                pairing_radius = 12
            else:
                pairing_radius = 6
            paired, unpaired_true, unpaired_pred = pair_coordinates(
                true_centroids, pred_centroids, pairing_radius
            )
            true_idx_offset = (
                true_idx_offset + true_inst_type_all[-1].shape[0] if i != 0 else 0
            )
            pred_idx_offset = (
                pred_idx_offset + pred_inst_type_all[-1].shape[0] if i != 0 else 0
            )
            true_inst_type_all.append(true_instance_type)
            pred_inst_type_all.append(pred_instance_type)
            # increment the pairing index statistic
            if paired.shape[0] != 0:  # ! sanity
                paired[:, 0] += true_idx_offset
                paired[:, 1] += pred_idx_offset
                paired_all.append(paired)
            unpaired_true += true_idx_offset
            unpaired_pred += pred_idx_offset
            unpaired_true_all.append(unpaired_true)
            unpaired_pred_all.append(unpaired_pred)
            cell_type_pq_scores.append(nuclei_type_pq)
            cell_type_dq_scores.append(nuclei_type_dq)
            cell_type_sq_scores.append(nuclei_type_sq)
        paired_all = np.concatenate(paired_all, axis=0)
        unpaired_true_all = np.concatenate(unpaired_true_all, axis=0)
        unpaired_pred_all = np.concatenate(unpaired_pred_all, axis=0)
        true_inst_type_all = np.concatenate(true_inst_type_all, axis=0)
        pred_inst_type_all = np.concatenate(pred_inst_type_all, axis=0)
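        # pair_coordinates (used in the loop above) matches predicted and ground-truth
        # centroids one-to-one within the pairing radius (12 px at 40x magnification,
        # 6 px otherwise); the concatenated paired / unpaired index arrays built here
        # are what cell_detection_scores and cell_type_detection_scores consume later.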
        batch_metrics = {
            "image_names": image_names,
            "binary_dice_scores": binary_dice_scores,
            "binary_jaccard_scores": binary_jaccard_scores,
            "pq_scores": pq_scores,
            "dq_scores": dq_scores,
            "sq_scores": sq_scores,
            "cell_type_pq_scores": cell_type_pq_scores,
            "cell_type_dq_scores": cell_type_dq_scores,
            "cell_type_sq_scores": cell_type_sq_scores,
            "tissue_pred": pred_tissue,
            "tissue_gt": gt["tissue_types"],
            "paired_all": paired_all,
            "unpaired_true_all": unpaired_true_all,
            "unpaired_pred_all": unpaired_pred_all,
            "true_inst_type_all": true_inst_type_all,
            "pred_inst_type_all": pred_inst_type_all,
        }
        return batch_metrics, scores
    def plot_results(
        self,
        imgs: Union[torch.Tensor, np.ndarray],
        predictions: DataclassHVStorage,
        num_nuclei_classes: int,
        outdir: Union[Path, str],
        ground_truth: DataclassHVStorage = None,
        img_names: List = None,
        scores: List[List[float]] = None,
    ) -> None:
        """Generate example plots with image, binary prediction, hv-maps, instance map and nuclei-type map from the predictions
        Args:
            imgs (Union[torch.Tensor, np.ndarray]): Images to process. Shape: (batch_size, 3, H', W')
            predictions (DataclassHVStorage): Processed network output (see unpack_predictions)
            num_nuclei_classes (int): Number of total nuclei classes including background
            outdir (Union[Path, str]): Output directory where images should be stored
            ground_truth (DataclassHVStorage, optional): Ground truth values; accepted for compatibility with inference_step, currently not plotted. Defaults to None.
            img_names (List, optional): Names of images as list; accepted for compatibility, currently unused. Defaults to None.
            scores (List[List[float]], optional): List with scores for each image.
                Each list entry is a list with 3 scores: Dice, Jaccard and bPQ for the image. Currently not plotted. Defaults to None.
        """
        # TODO: plot the ground truth and scores as well
        outdir = Path(outdir)
        outdir.mkdir(exist_ok=True, parents=True)
        # move the channel axes to the back for plotting
        predictions.hv_map = predictions.hv_map.permute(0, 2, 3, 1)
        predictions.nuclei_binary_map = predictions.nuclei_binary_map.permute(0, 2, 3, 1)
        predictions.nuclei_type_map = predictions.nuclei_type_map.permute(0, 2, 3, 1)
        h = predictions.hv_map.shape[1]
        w = predictions.hv_map.shape[2]
        # convert to channel-last numpy and crop to the prediction size
        sample_images = (
            imgs.permute(0, 2, 3, 1).contiguous().cpu().numpy()
        )
        sample_images = cropping_center(sample_images, (h, w), True)
        pred_sample_binary_map = (
            predictions.nuclei_binary_map[:, :, :, 1].detach().cpu().numpy()
        )
        pred_sample_hv_map = predictions.hv_map.detach().cpu().numpy()
        pred_sample_instance_maps = predictions.instance_map.detach().cpu().numpy()
        pred_sample_type_maps = (
            torch.argmax(predictions.nuclei_type_map, dim=-1).detach().cpu().numpy()
        )
        # create colormaps
        hv_cmap = plt.get_cmap("jet")
        binary_cmap = plt.get_cmap("jet")
        instance_map = plt.get_cmap("viridis")
        cell_colors = ["#ffffff", "#ff0000", "#00ff00", "#1e00ff", "#feff00", "#ffbf00"]
        # invert the normalization of the sample images
        transform_settings = self.run_conf["transformations"]
        if "normalize" in transform_settings:
            mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5))
            std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5))
        else:
            mean = (0.5, 0.5, 0.5)
            std = (0.5, 0.5, 0.5)
        # undo Normalize(mean, std): x_orig = x * std + mean, expressed as another Normalize
        inv_normalize = transforms.Normalize(
            mean=[-mean[0] / std[0], -mean[1] / std[1], -mean[2] / std[2]],
            std=[1 / std[0], 1 / std[1], 1 / std[2]],
        )
        inv_samples = inv_normalize(torch.tensor(sample_images).permute(0, 3, 1, 2))
        sample_images = inv_samples.permute(0, 2, 3, 1).detach().cpu().numpy()
        for i in range(len(imgs)):
            fig, axs = plt.subplots(figsize=(6, 2), dpi=300)
            placeholder = np.zeros((h, 7 * w, 3))
            # orig image
            placeholder[:h, :w, :3] = sample_images[i]
            # binary prediction
            placeholder[:h, w : 2 * w, :3] = rgba2rgb(
                binary_cmap(pred_sample_binary_map[i])
            )
            # hv maps
            placeholder[:h, 2 * w : 3 * w, :3] = rgba2rgb(
                hv_cmap((pred_sample_hv_map[i, :, :, 0] + 1) / 2)
            )
            placeholder[:h, 3 * w : 4 * w, :3] = rgba2rgb(
                hv_cmap((pred_sample_hv_map[i, :, :, 1] + 1) / 2)
            )
            # instance predictions, min-max scaled to [0, 1] for the colormap
            placeholder[:h, 4 * w : 5 * w, :3] = rgba2rgb(
                instance_map(
                    (
                        pred_sample_instance_maps[i]
                        - np.min(pred_sample_instance_maps[i])
                    )
                    / (
                        np.max(pred_sample_instance_maps[i])
                        - np.min(pred_sample_instance_maps[i])
                        + 1e-10
                    )
                )
            )
            # type predictions
            placeholder[:h, 5 * w : 6 * w, :3] = rgba2rgb(
                binary_cmap(pred_sample_type_maps[i] / num_nuclei_classes)
            )
            # contours of the predicted instances
            pred_contours_polygon = [
                v["contour"] for v in predictions.instance_types[i].values()
            ]
            pred_contours_polygon = [
                list(zip(poly[:, 0], poly[:, 1])) for poly in pred_contours_polygon
            ]
            pred_contour_colors_polygon = [
                cell_colors[v["type"]]
                for v in predictions.instance_types[i].values()
            ]
            pred_cell_image = Image.fromarray(
                (sample_images[i] * 255).astype(np.uint8)
            ).convert("RGB")
            pred_drawing = ImageDraw.Draw(pred_cell_image)
            add_patch = lambda poly, color: pred_drawing.polygon(
                poly, outline=color, width=2
            )
            [
                add_patch(poly, c)
                for poly, c in zip(pred_contours_polygon, pred_contour_colors_polygon)
            ]
            pred_cell_image.save(outdir / f"raw_pred_{i}.png")
            placeholder[:h, 6 * w : 7 * w, :3] = (
                np.asarray(pred_cell_image) / 255
            )
            # plotting
            axs.imshow(placeholder)
            axs.set_xticks(np.arange(w / 2, 7 * w, w))
            axs.set_xticklabels(
                [
                    "Image",
                    "Binary-Cells",
                    "HV-Map-0",
                    "HV-Map-1",
                    "Instances",
                    "Nuclei-Pred",
                    "Contours",
                ],
                fontsize=6,
            )
            axs.xaxis.tick_top()
            axs.set_yticks([h / 2])
            axs.set_yticklabels(["Pred."], fontsize=6)
            axs.tick_params(axis="both", which="both", length=0)
            grid_x = np.arange(w, 6 * w, w)
            grid_y = np.arange(h, 2 * h, h)
            for x_seg in grid_x:
                axs.axvline(x_seg, color="black")
            for y_seg in grid_y:
                axs.axhline(y_seg, color="black")
            fig.suptitle("All predictions for input image")
            fig.tight_layout()
            fig.savefig(outdir / f"pred_img_{i}.png")
            plt.close()
# CLI
class InferenceCellViTParser:
    def __init__(self) -> None:
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
            description="Perform CellViT inference for a given run-directory with model checkpoints and logs",
        )
        parser.add_argument(
            "--run_dir",
            type=str,
            help="Logging directory of a training run.",
            default="./",
        )
        parser.add_argument(
            "--checkpoint_name",
            type=str,
            help="Name of the checkpoint file inside run_dir/checkpoints, "
            "e.g., 'model_best.pth', 'latest_checkpoint.pth' or an intermediate "
            "checkpoint such as 'checkpoint_100.pth'",
            default="model_best.pth",
        )
        parser.add_argument(
            "--gpu", type=int, help="Cuda-GPU ID for inference", default=0
        )
        parser.add_argument(
            "--magnification",
            type=int,
            help="Dataset Magnification. Either 20 or 40. Default: 40",
            choices=[20, 40],
            default=40,
        )
        parser.add_argument(
            "--plots",
            action="store_true",
            help="Generate inference plots in run_dir",
            default=False,
        )
        self.parser = parser
    def parse_arguments(self) -> dict:
        opt = self.parser.parse_args()
        return vars(opt)
if __name__ == "__main__":
    configuration_parser = InferenceCellViTParser()
    configuration = configuration_parser.parse_arguments()
    print(configuration)
    inf = InferenceCellViT(
        run_dir=configuration["run_dir"],
        checkpoint_name=configuration["checkpoint_name"],
        gpu=configuration["gpu"],
        magnification=configuration["magnification"],
    )
    model, dataloader, conf = inf.setup_patch_inference()
    inf.run_patch_inference(
        model, dataloader, conf, generate_plots=configuration["plots"]
    )
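# Example invocation (sketch; run_dir is a placeholder and must contain config.yaml
# and a checkpoints/ folder from a CellViT training run):
#   python path/to/this_script.py --run_dir ./results/cellvit_run \
#       --checkpoint_name model_best.pth --gpu 0 --magnification 40 --plots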