Spaces:
Running
Running
import torch | |
import numpy as np | |
import os, sys | |
from monai.inferers import sliding_window_inference | |
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../"))) | |
from BasePredictor import BasePredictor | |
from utils import compute_masks | |
__all__ = ["Predictor"] | |
class Predictor(BasePredictor): | |
def __init__( | |
self, | |
model, | |
device, | |
input_path, | |
output_path, | |
make_submission=False, | |
exp_name=None, | |
algo_params=None, | |
): | |
super(Predictor, self).__init__( | |
model, | |
device, | |
input_path, | |
output_path, | |
make_submission, | |
exp_name, | |
algo_params, | |
) | |
self.hflip_tta = HorizontalFlip() | |
self.vflip_tta = VerticalFlip() | |
def _inference(self, img_data): | |
"""Conduct model prediction""" | |
img_data = img_data.to(self.device) | |
img_base = img_data | |
outputs_base = self._window_inference(img_base) | |
outputs_base = outputs_base.cpu().squeeze() | |
img_base.cpu() | |
if not self.use_tta: | |
pred_mask = outputs_base | |
return pred_mask | |
else: | |
# HorizontalFlip TTA | |
img_hflip = self.hflip_tta.apply_aug_image(img_data, apply=True) | |
outputs_hflip = self._window_inference(img_hflip) | |
outputs_hflip = self.hflip_tta.apply_deaug_mask(outputs_hflip, apply=True) | |
outputs_hflip = outputs_hflip.cpu().squeeze() | |
img_hflip = img_hflip.cpu() | |
# VertricalFlip TTA | |
img_vflip = self.vflip_tta.apply_aug_image(img_data, apply=True) | |
outputs_vflip = self._window_inference(img_vflip) | |
outputs_vflip = self.vflip_tta.apply_deaug_mask(outputs_vflip, apply=True) | |
outputs_vflip = outputs_vflip.cpu().squeeze() | |
img_vflip = img_vflip.cpu() | |
# Merge Results | |
pred_mask = torch.zeros_like(outputs_base) | |
pred_mask[0] = (outputs_base[0] + outputs_hflip[0] - outputs_vflip[0]) / 3 | |
pred_mask[1] = (outputs_base[1] - outputs_hflip[1] + outputs_vflip[1]) / 3 | |
pred_mask[2] = (outputs_base[2] + outputs_hflip[2] + outputs_vflip[2]) / 3 | |
return pred_mask | |
def _window_inference(self, img_data, aux=False): | |
"""Inference on RoI-sized window""" | |
outputs = sliding_window_inference( | |
img_data, | |
roi_size=512, | |
sw_batch_size=4, | |
predictor=self.model if not aux else self.model_aux, | |
padding_mode="constant", | |
mode="gaussian", | |
overlap=0.6, | |
) | |
return outputs | |
def _post_process(self, pred_mask): | |
"""Generate cell instance masks.""" | |
dP, cellprob = pred_mask[:2], self._sigmoid(pred_mask[-1]) | |
H, W = pred_mask.shape[-2], pred_mask.shape[-1] | |
if np.prod(H * W) < (5000 * 5000): | |
pred_mask = compute_masks( | |
dP, | |
cellprob, | |
use_gpu=True, | |
flow_threshold=0.4, | |
device=self.device, | |
cellprob_threshold=0.5, | |
)[0] | |
else: | |
print("\n[Whole Slide] Grid Prediction starting...") | |
roi_size = 2000 | |
# Get patch grid by roi_size | |
if H % roi_size != 0: | |
n_H = H // roi_size + 1 | |
new_H = roi_size * n_H | |
else: | |
n_H = H // roi_size | |
new_H = H | |
if W % roi_size != 0: | |
n_W = W // roi_size + 1 | |
new_W = roi_size * n_W | |
else: | |
n_W = W // roi_size | |
new_W = W | |
# Allocate values on the grid | |
pred_pad = np.zeros((new_H, new_W), dtype=np.uint32) | |
dP_pad = np.zeros((2, new_H, new_W), dtype=np.float32) | |
cellprob_pad = np.zeros((new_H, new_W), dtype=np.float32) | |
dP_pad[:, :H, :W], cellprob_pad[:H, :W] = dP, cellprob | |
for i in range(n_H): | |
for j in range(n_W): | |
print("Pred on Grid (%d, %d) processing..." % (i, j)) | |
dP_roi = dP_pad[ | |
:, | |
roi_size * i : roi_size * (i + 1), | |
roi_size * j : roi_size * (j + 1), | |
] | |
cellprob_roi = cellprob_pad[ | |
roi_size * i : roi_size * (i + 1), | |
roi_size * j : roi_size * (j + 1), | |
] | |
pred_mask = compute_masks( | |
dP_roi, | |
cellprob_roi, | |
use_gpu=True, | |
flow_threshold=0.4, | |
device=self.device, | |
cellprob_threshold=0.5, | |
)[0] | |
pred_pad[ | |
roi_size * i : roi_size * (i + 1), | |
roi_size * j : roi_size * (j + 1), | |
] = pred_mask | |
pred_mask = pred_pad[:H, :W] | |
return pred_mask | |
def _sigmoid(self, z): | |
return 1 / (1 + np.exp(-z)) | |
""" | |
Adapted from the following references: | |
[1] https://github.com/qubvel/ttach/blob/master/ttach/transforms.py | |
""" | |
def hflip(x): | |
"""flip batch of images horizontally""" | |
return x.flip(3) | |
def vflip(x): | |
"""flip batch of images vertically""" | |
return x.flip(2) | |
class DualTransform: | |
identity_param = None | |
def __init__( | |
self, name: str, params, | |
): | |
self.params = params | |
self.pname = name | |
def apply_aug_image(self, image, *args, **params): | |
raise NotImplementedError | |
def apply_deaug_mask(self, mask, *args, **params): | |
raise NotImplementedError | |
class HorizontalFlip(DualTransform): | |
"""Flip images horizontally (left -> right)""" | |
identity_param = False | |
def __init__(self): | |
super().__init__("apply", [False, True]) | |
def apply_aug_image(self, image, apply=False, **kwargs): | |
if apply: | |
image = hflip(image) | |
return image | |
def apply_deaug_mask(self, mask, apply=False, **kwargs): | |
if apply: | |
mask = hflip(mask) | |
return mask | |
class VerticalFlip(DualTransform): | |
"""Flip images vertically (up -> down)""" | |
identity_param = False | |
def __init__(self): | |
super().__init__("apply", [False, True]) | |
def apply_aug_image(self, image, apply=False, **kwargs): | |
if apply: | |
image = vflip(image) | |
return image | |
def apply_deaug_mask(self, mask, apply=False, **kwargs): | |
if apply: | |
mask = vflip(mask) | |
return mask | |