|
import os |
|
import random |
|
|
|
import torch |
|
from PIL import Image |
|
from torchstain.base.normalizers.he_normalizer import HENormalizer |
|
from torchstain.torch.utils import cov, percentile |
|
from torchvision import transforms |
|
from torchvision.transforms.functional import to_pil_image |
|
|
|
|
|
def preprocessor(pretrained=False, normalizer=None):
    """Build the image preprocessing pipeline.

    The pipeline resizes to 256, center-crops to 224, applies the optional
    ``normalizer``, converts the result to a tensor and standardises every
    channel with a fixed mean / standard deviation.

    Args:
        pretrained: When true, standardise with mean (0.485, 0.456, 0.406)
            and std (0.229, 0.224, 0.225) -- the values commonly used with
            ImageNet-pretrained models. Otherwise use 0.5 for both on every
            channel.
        normalizer: Optional callable run between the center crop and
            ``ToTensor``. ``None`` inserts an identity step instead.

    Returns:
        A ``transforms.Compose`` with five stages.
    """
    if pretrained:
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    # Identity test, not ``== None``: a normalizer object may define its own
    # ``__eq__``. The identity placeholder keeps the pipeline the same length
    # whether or not a normalizer is supplied.
    if normalizer is None:
        stain_step = transforms.Lambda(lambda x: x)
    else:
        stain_step = normalizer

    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            stain_step,
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )

    return preprocess
|
|
|
|
|
""" |
|
Source code ported from: https://github.com/schaugf/HEnorm_python |
|
Original implementation: https://github.com/mitkovetta/staining-normalization |
|
""" |
|
|
|
|
|
class TorchMacenkoNormalizer(HENormalizer):
    """Macenko stain normalizer for H&E images, implemented with torch.

    An instance starts with a built-in reference (stain matrix ``HERef`` and
    maximum concentrations ``maxCRef``); ``fit`` replaces that reference with
    one estimated from a target image.
    """

    def __init__(self):
        super().__init__()

        # Built-in reference: 3x2 stain matrix (column 0 is used for the
        # hematoxylin image, column 1 for the eosin image) and the matching
        # pair of maximum concentrations.
        self.HERef = torch.tensor(
            [[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]]
        )
        self.maxCRef = torch.tensor([1.9705, 1.0308])

        # Decides between torch.linalg.lstsq and the legacy torch.lstsq.
        self.updated_lstsq = hasattr(torch.linalg, "lstsq")

    def __convert_rgb2od(self, I, Io, beta):
        """Return optical density for all pixels and for the non-transparent ones."""
        # Channels last, then flatten to one row per pixel.
        pixels = I.permute(1, 2, 0)
        pixels = pixels.reshape((-1, pixels.shape[-1])).float()

        # The +1 offset keeps zero intensities away from log(0).
        OD = -torch.log((pixels + 1) / Io)

        # A pixel is dropped as soon as any of its channels is below beta.
        transparent = torch.any(OD < beta, dim=1)
        ODhat = OD[~transparent]

        return OD, ODhat

    def __find_HE(self, ODhat, eigvecs, alpha):
        """Estimate the two stain vectors from the filtered optical densities."""
        # Project onto the plane spanned by the two selected eigenvectors and
        # measure the angle of every projected point.
        projected = torch.matmul(ODhat, eigvecs)
        angles = torch.atan2(projected[:, 1], projected[:, 0])

        low_angle = percentile(angles, alpha)
        high_angle = percentile(angles, 100 - alpha)

        def direction(angle):
            # Map an in-plane angle back to a column vector in OD space.
            unit = torch.stack((torch.cos(angle), torch.sin(angle)))
            return torch.matmul(eigvecs, unit).unsqueeze(1)

        vMin = direction(low_angle)
        vMax = direction(high_angle)

        # The vector with the larger first component becomes column 0.
        min_first = torch.cat((vMin, vMax), dim=1)
        max_first = torch.cat((vMax, vMin), dim=1)
        HE = torch.where(vMin[0] > vMax[0], min_first, max_first)

        return HE

    def __find_concentration(self, OD, HE):
        """Solve the least-squares system HE @ C = OD.T for the concentrations C."""
        Y = OD.T

        if self.updated_lstsq:
            return torch.linalg.lstsq(HE, Y)[0]

        # Legacy API: argument order is swapped and the solution is padded,
        # so only the first two rows are kept.
        return torch.lstsq(Y, HE)[0][:2]

    def __compute_matrices(self, I, Io, alpha, beta):
        """Return the stain matrix, concentrations and their 99th percentiles."""
        OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

        # Keep the eigenvectors stored at indices 1 and 2 (eigh orders
        # eigenvalues ascending).
        covariance = cov(ODhat.T)
        eigvecs = torch.linalg.eigh(covariance)[1][:, [1, 2]]

        HE = self.__find_HE(ODhat, eigvecs, alpha)
        C = self.__find_concentration(OD, HE)

        maxC = torch.stack([percentile(C[row, :], 99) for row in (0, 1)])

        return HE, C, maxC

    def fit(self, I, Io=240, alpha=1, beta=0.15):
        """Replace the reference stain matrix and concentrations using image ``I``."""
        HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)
        self.HERef, self.maxCRef = HE, maxC

    def normalize(
        self, I, Io=240, alpha=1, beta=0.15, stains=True, form="chw", dtype="int"
    ):
        """Normalize the staining appearance of an H&E stained image.

        Input:
            I: RGB input image: tensor of shape [C, H, W] with intensities
                on a 0-255 scale
            Io: (optional) transmitted light intensity
            alpha: percentile used when picking the extreme stain angles
            beta: transparency threshold on optical density
            stains: if true, also return the H and E component images
            form: accepted but not read by this implementation
            dtype: accepted but not read by this implementation

        Output:
            Inorm: normalized image, float tensor reshaped to [C, H, W] and
                scaled/clipped to [0, 1]
            H: hematoxylin image, int tensor reshaped to [H, W, C]
                (None when ``stains`` is false)
            E: eosin image, int tensor reshaped to [H, W, C]
                (None when ``stains`` is false)

        Reference:
            A method for normalizing histology slides for quantitative
            analysis. M. Macenko et al., ISBI 2009
        """
        c, h, w = I.shape

        HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta)

        # Rescale the concentrations (in place) to the reference maxima.
        scale = (self.maxCRef / maxC).unsqueeze(-1)
        C *= scale

        # Rebuild the image from the reference stain matrix.
        Inorm = Io * torch.exp(-torch.matmul(self.HERef, C))
        Inorm = torch.clip(Inorm, 0, 255)
        Inorm = torch.clip(Inorm.reshape(c, h, w).float() / 255.0, 0.0, 1.0)

        if not stains:
            return Inorm, None, None

        # One image per stain: column 0 -> H, column 1 -> E.
        components = []
        for idx in (0, 1):
            stain_vec = -self.HERef[:, idx].unsqueeze(-1)
            density = torch.matmul(stain_vec, C[idx, :].unsqueeze(0))
            component = torch.mul(Io, torch.exp(density))
            component[component > 255] = 255
            components.append(component.T.reshape(h, w, c).int())
        H, E = components

        return Inorm, H, E
|
|
|
|
|
class MacenkoNormalizer:
    """Callable transform that applies Macenko stain normalization to an image.

    Args:
        target_path: Where the reference comes from. A ``.jpg``/``.jpeg``/
            ``.png`` file is loaded and used to fit the normalizer; a ``.pt``
            file must hold a mapping with ``"HERef"`` and ``"maxCRef"``
            entries. ``None`` keeps the built-in reference of
            ``TorchMacenkoNormalizer``.
        prob: Probability of normalizing an image in ``__call__``. With the
            default of 1 every image is normalized.

    Raises:
        ValueError: If ``target_path`` has an unsupported extension.
    """

    def __init__(self, target_path=None, prob=1):
        # ToTensor followed by a x255 rescale, applied to every image before
        # it is handed to the normalizer.
        self.transform_before_macenko = transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 255)]
        )
        self.normalizer = TorchMacenkoNormalizer()
        self.prob = prob

        # The declared default used to crash inside os.path.splitext(None);
        # treat it as "no custom target" and keep the built-in reference.
        if target_path is None:
            return

        ext = os.path.splitext(target_path)[1].lower()
        if ext in (".jpg", ".jpeg", ".png"):
            # Close the file handle deterministically, and force three
            # channels so palette / grayscale / RGBA files do not reach
            # fit() with the wrong channel count.
            with Image.open(target_path) as target:
                reference = self.transform_before_macenko(target.convert("RGB"))
            self.normalizer.fit(reference)
        elif ext == ".pt":
            # NOTE(security): torch.load unpickles the file -- only point
            # this at reference files from a trusted source.
            target = torch.load(target_path)
            self.normalizer.HERef = target["HERef"]
            self.normalizer.maxCRef = target["maxCRef"]
        else:
            raise ValueError(f"Invalid extension: {ext}")

    def __call__(self, image):
        """Return the stain-normalized image, or ``image`` itself as a fallback.

        ``image`` is returned unchanged when the random draw skips
        normalization, when the result contains NaNs, or when normalization
        raises. On success a PIL image is returned.
        """
        # Honour ``prob``. No random number is drawn when prob >= 1, so the
        # default configuration leaves the global RNG state untouched.
        if self.prob < 1 and random.random() >= self.prob:
            return image

        t_to_transform = self.transform_before_macenko(image)
        try:
            image_macenko, _, _ = self.normalizer.normalize(
                I=t_to_transform, stains=False, form="chw", dtype="float"
            )
            if torch.any(torch.isnan(image_macenko)):
                return image
            return to_pil_image(image_macenko)
        except Exception as e:
            # Deliberate best-effort: any failure falls back to the input.
            # Errors mentioning kthvalue() or linalg.eigh are expected often
            # enough to be silenced; everything else is printed.
            message = str(e)
            if "kthvalue()" not in message and "linalg.eigh" not in message:
                print(message)
            return image
|
|