import os
import random

import torch
from PIL import Image
from torchstain.base.normalizers.he_normalizer import HENormalizer
from torchstain.torch.utils import cov, percentile
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image


def preprocessor(pretrained=False, normalizer=None):
    """Build the image preprocessing pipeline.

    The pipeline resizes to 256, center-crops to 224, optionally applies
    ``normalizer``, converts to a tensor and normalizes each channel.

    Args:
        pretrained: If true, use the per-channel mean/std
            ``(0.485, 0.456, 0.406)`` / ``(0.229, 0.224, 0.225)``
            (presumably the ImageNet statistics -- confirm against the model
            being used); otherwise use 0.5 for every channel.
        normalizer: Optional callable inserted between the crop and
            ``ToTensor``. It therefore receives whatever ``CenterCrop``
            returns and must return something ``ToTensor`` accepts. When
            ``None``, an identity transform is used instead.

    Returns:
        A ``transforms.Compose`` instance.
    """
    if pretrained:
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.Lambda(lambda x: x) if normalizer is None else normalizer,
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    return preprocess


"""
Source code ported from: https://github.com/schaugf/HEnorm_python
Original implementation: https://github.com/mitkovetta/staining-normalization
"""


class TorchMacenkoNormalizer(HENormalizer):
    """Macenko H&E stain normalizer operating on torch tensors.

    Attributes are initialised to a built-in reference stain matrix
    (``HERef``, shape 3x2) and reference maximum concentrations
    (``maxCRef``, shape 2); ``fit`` replaces both with values estimated from
    a target image.
    """

    def __init__(self):
        super().__init__()

        self.HERef = torch.tensor(
            [[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]]
        )
        self.maxCRef = torch.tensor([1.9705, 1.0308])

        # Avoid using deprecated torch.lstsq (since 1.9.0)
        self.updated_lstsq = hasattr(torch.linalg, "lstsq")

    def __convert_rgb2od(self, I, Io, beta):
        """Convert a [C, H, W] image to optical density.

        Returns:
            ``(OD, ODhat)`` where ``OD`` has one row per pixel and one column
            per channel, and ``ODhat`` keeps only the rows in which every
            channel is at least ``beta``.
        """
        I = I.permute(1, 2, 0)

        # calculate optical density; the +1 offset keeps the logarithm finite
        # for zero-intensity pixels
        OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io)

        # remove transparent pixels
        ODhat = OD[~torch.any(OD < beta, dim=1)]

        return OD, ODhat

    def __find_HE(self, ODhat, eigvecs, alpha):
        """Estimate the two stain vectors from the filtered optical densities.

        Returns:
            A tensor with one column per stain vector (two columns).
        """
        # project on the plane spanned by the eigenvectors corresponding to
        # the two largest eigenvalues
        That = torch.matmul(ODhat, eigvecs)
        phi = torch.atan2(That[:, 1], That[:, 0])

        # robust angular extremes of the projected pixels
        minPhi = percentile(phi, alpha)
        maxPhi = percentile(phi, 100 - alpha)

        vMin = torch.matmul(
            eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))
        ).unsqueeze(1)
        vMax = torch.matmul(
            eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))
        ).unsqueeze(1)

        # a heuristic to make the vector corresponding to hematoxylin first
        # and the one corresponding to eosin second
        HE = torch.where(
            vMin[0] > vMax[0],
            torch.cat((vMin, vMax), dim=1),
            torch.cat((vMax, vMin), dim=1),
        )

        return HE

    def __find_concentration(self, OD, HE):
        """Solve the least-squares problem ``HE @ C = OD.T`` for ``C``."""
        # rows correspond to channels (RGB), columns to OD values
        Y = OD.T

        # determine concentrations of the individual stains
        if not self.updated_lstsq:
            # legacy API takes (B, A) and pads the solution; keep the two
            # rows that hold the actual solution
            return torch.lstsq(Y, HE)[0][:2]

        return torch.linalg.lstsq(HE, Y)[0]

    def __compute_matrices(self, I, Io, alpha, beta):
        """Compute the stain matrix, concentrations and their 99th percentiles.

        Returns:
            ``(HE, C, maxC)``: the stain matrix, the per-pixel concentrations
            (one row per stain) and the 99th percentile of each row of ``C``.
        """
        OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

        # compute eigenvectors; torch.linalg.eigh returns eigenvalues in
        # ascending order, so columns 1 and 2 belong to the two largest
        _, eigvecs = torch.linalg.eigh(cov(ODhat.T))
        eigvecs = eigvecs[:, [1, 2]]

        HE = self.__find_HE(ODhat, eigvecs, alpha)

        C = self.__find_concentration(OD, HE)
        maxC = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])

        return HE, C, maxC

    def fit(self, I, Io=240, alpha=1, beta=0.15):
        """Estimate the reference stain matrix and concentrations from ``I``.

        Args:
            I: target image, tensor of shape [C, H, W].
            Io: transmitted light intensity.
            alpha: percentile used for the angular extremes.
            beta: optical-density threshold below which pixels are discarded.
        """
        HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)

        self.HERef = HE
        self.maxCRef = maxC

    def normalize(
        self, I, Io=240, alpha=1, beta=0.15, stains=True, form="chw", dtype="int"
    ):
        """Normalize staining appearance of H&E stained images

        Example use:
            see test.py

        Input:
            I: RGB input image: tensor of shape [C, H, W] with intensities on
                the 0-255 scale
            Io: (optional) transmitted light intensity
            alpha: percentile
            beta: transparency threshold
            stains: if true, return also H & E components
            form: accepted for interface compatibility; currently unused
                (the normalized image is always returned as [C, H, W])
            dtype: accepted for interface compatibility; currently unused
                (the normalized image is always returned as float)

        Output:
            Inorm: normalized image, float tensor of shape [C, H, W] with
                values in [0, 1]
            H: hematoxylin image, int tensor of shape [H, W, C], or None when
                ``stains`` is false
            E: eosin image, int tensor of shape [H, W, C], or None when
                ``stains`` is false

        Reference:
            A method for normalizing histology slides for quantitative analysis. M.
            Macenko et al., ISBI 2009
        """
        c, h, w = I.shape

        HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta)

        # normalize stain concentrations
        C *= (self.maxCRef / maxC).unsqueeze(-1)

        # recreate the image using reference mixing matrix
        Inorm = Io * torch.exp(-torch.matmul(self.HERef, C))
        # clipping to [0, 255] before the division already bounds the result
        # to [0, 1], so no second clip is needed
        Inorm = torch.clip(Inorm, 0, 255)
        Inorm = Inorm.reshape(c, h, w).float() / 255.0

        H, E = None, None

        if stains:
            H = torch.mul(
                Io,
                torch.exp(
                    torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0))
                ),
            )
            H[H > 255] = 255
            H = H.T.reshape(h, w, c).int()

            E = torch.mul(
                Io,
                torch.exp(
                    torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0))
                ),
            )
            E[E > 255] = 255
            E = E.T.reshape(h, w, c).int()

        return Inorm, H, E


class MacenkoNormalizer:
    """Image transform that applies Macenko stain normalization.

    Instances are callables intended to be used as a step in a
    ``transforms.Compose`` pipeline (see ``preprocessor``).
    """

    def __init__(self, target_path=None, prob=1):
        """Create the transform and load its stain reference.

        Args:
            target_path: Where to get the reference from.
                * ``None``: keep the built-in reference of
                  ``TorchMacenkoNormalizer``.
                * ``.jpg`` / ``.jpeg`` / ``.png``: fit the reference on that
                  image.
                * ``.pt``: load a mapping with the keys ``"HERef"`` and
                  ``"maxCRef"``.
            prob: Probability with which ``__call__`` applies the
                normalization; with the default of 1 it is always applied.

        Raises:
            ValueError: If ``target_path`` has an unsupported extension.
        """
        self.transform_before_macenko = transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 255)]
        )
        self.normalizer = TorchMacenkoNormalizer()
        self.prob = prob

        if target_path is None:
            # The signature advertises None as the default, so fall back to
            # the built-in reference instead of failing in os.path.splitext.
            return

        ext = os.path.splitext(target_path)[1].lower()
        if ext in (".jpg", ".jpeg", ".png"):
            # Close the file handle deterministically, and force three
            # channels so palette/alpha/grayscale files can be fitted too.
            with Image.open(target_path) as target:
                target_tensor = self.transform_before_macenko(target.convert("RGB"))
            self.normalizer.fit(target_tensor)
        elif ext == ".pt":
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code -- only load reference files from trusted sources.
            target = torch.load(target_path)
            self.normalizer.HERef = target["HERef"]
            self.normalizer.maxCRef = target["maxCRef"]
        else:
            raise ValueError(f"Invalid extension: {ext}")

    def __call__(self, image):
        """Return the stain-normalized image, or ``image`` itself on failure.

        Args:
            image: Any input accepted by ``transforms.ToTensor`` (e.g. a PIL
                image).

        Returns:
            A PIL image with normalized staining. The untouched input is
            returned instead when the transform is skipped because of
            ``prob``, when the result contains NaNs, or when normalization
            raises.
        """
        # Honour `prob`. The RNG is only consulted when prob < 1 so the
        # default configuration does not disturb the global random state.
        if self.prob < 1 and random.random() >= self.prob:
            return image

        t_to_transform = self.transform_before_macenko(image)
        try:
            image_macenko, _, _ = self.normalizer.normalize(
                I=t_to_transform, stains=False, form="chw", dtype="float"
            )
            if torch.any(torch.isnan(image_macenko)):
                return image
            return to_pil_image(image_macenko)
        except Exception as e:
            # Best effort: a failed normalization must not break the data
            # pipeline. kthvalue()/linalg.eigh errors are silenced (presumably
            # they occur on tiles with too few stained pixels -- confirm);
            # anything else is reported.
            message = str(e)
            if "kthvalue()" not in message and "linalg.eigh" not in message:
                print(message)
            return image