import os
import random
import torch
from PIL import Image
from torchstain.base.normalizers.he_normalizer import HENormalizer
from torchstain.torch.utils import cov, percentile
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image


def preprocessor(pretrained=False, normalizer=None):
    if pretrained:
        # ImageNet statistics for torchvision pretrained backbones
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            # stain normalization runs on the PIL image, before tensor conversion
            transforms.Lambda(lambda x: x) if normalizer is None else normalizer,
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    return preprocess
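
# Minimal usage sketch (illustration only, not part of the original file; the path
# is a placeholder): without a normalizer, preprocessor() is a plain
# resize/crop/normalize pipeline.
#
#   preprocess = preprocessor(pretrained=True)  # ImageNet mean/std
#   x = preprocess(Image.open("tile.png").convert("RGB"))  # -> [3, 224, 224] float tensor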
"""
Source code ported from: https://github.com/schaugf/HEnorm_python
Original implementation: https://github.com/mitkovetta/staining-normalization
"""
class TorchMacenkoNormalizer(HENormalizer):
    def __init__(self):
        super().__init__()
        # reference H&E stain matrix (columns: hematoxylin, eosin) and reference
        # maximum stain concentrations from the Macenko method
        self.HERef = torch.tensor(
            [[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]]
        )
        self.maxCRef = torch.tensor([1.9705, 1.0308])
        # torch.lstsq is deprecated since 1.9.0; prefer torch.linalg.lstsq when available
        self.updated_lstsq = hasattr(torch.linalg, "lstsq")

    def __convert_rgb2od(self, I, Io, beta):
        # [C, H, W] -> [H, W, C] -> flat list of RGB pixels
        I = I.permute(1, 2, 0)
        # calculate optical density: OD = -log((I + 1) / Io)
        OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io)
        # remove transparent pixels (any channel below the beta threshold)
        ODhat = OD[~torch.any(OD < beta, dim=1)]
        return OD, ODhat

    def __find_HE(self, ODhat, eigvecs, alpha):
        # project onto the plane spanned by the eigenvectors corresponding to the
        # two largest eigenvalues
        That = torch.matmul(ODhat, eigvecs)
        phi = torch.atan2(That[:, 1], That[:, 0])
        # robust extreme angles via the alpha and (100 - alpha) percentiles
        minPhi = percentile(phi, alpha)
        maxPhi = percentile(phi, 100 - alpha)
        vMin = torch.matmul(
            eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))
        ).unsqueeze(1)
        vMax = torch.matmul(
            eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))
        ).unsqueeze(1)
        # a heuristic to make the vector corresponding to hematoxylin first and the
        # one corresponding to eosin second
        HE = torch.where(
            vMin[0] > vMax[0],
            torch.cat((vMin, vMax), dim=1),
            torch.cat((vMax, vMin), dim=1),
        )
        return HE

    def __find_concentration(self, OD, HE):
        # rows correspond to channels (RGB), columns to OD values
        Y = OD.T
        # determine concentrations of the individual stains by least squares: HE @ C = Y
        if not self.updated_lstsq:
            return torch.lstsq(Y, HE)[0][:2]
        return torch.linalg.lstsq(HE, Y)[0]

    def __compute_matrices(self, I, Io, alpha, beta):
        OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)
        # eigenvectors of the OD covariance; keep the two largest (eigh sorts ascending)
        _, eigvecs = torch.linalg.eigh(cov(ODhat.T))
        eigvecs = eigvecs[:, [1, 2]]
        HE = self.__find_HE(ODhat, eigvecs, alpha)
        C = self.__find_concentration(OD, HE)
        # 99th-percentile stain concentrations, robust to outliers
        maxC = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
        return HE, C, maxC

    def fit(self, I, Io=240, alpha=1, beta=0.15):
        HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)
        self.HERef = HE
        self.maxCRef = maxC

    def normalize(
        self, I, Io=240, alpha=1, beta=0.15, stains=True, form="chw", dtype="int"
    ):
        """Normalize the staining appearance of H&E stained images.

        Example use:
            see test.py

        Input:
            I: RGB input image: tensor of shape [C, H, W] and type uint8
            Io: (optional) transmitted light intensity
            alpha: percentile
            beta: transparency threshold
            stains: if True, also return the H & E components
            form, dtype: output layout/type hints; this implementation always
                returns Inorm as a CHW float tensor in [0, 1]

        Output:
            Inorm: normalized image
            H: hematoxylin image
            E: eosin image

        Reference:
            A method for normalizing histology slides for quantitative analysis.
            M. Macenko et al., ISBI 2009
        """
        c, h, w = I.shape
        HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta)
        # normalize stain concentrations against the reference maxima
        C *= (self.maxCRef / maxC).unsqueeze(-1)
        # recreate the image using the reference mixing matrix
        Inorm = Io * torch.exp(-torch.matmul(self.HERef, C))
        Inorm = torch.clip(Inorm, 0, 255)
        Inorm = Inorm.reshape(c, h, w).float() / 255.0
        Inorm = torch.clip(Inorm, 0.0, 1.0)
        H, E = None, None
        if stains:
            # unmix the hematoxylin component
            H = torch.mul(
                Io,
                torch.exp(
                    torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0))
                ),
            )
            H[H > 255] = 255
            H = H.T.reshape(h, w, c).int()
            # unmix the eosin component
            E = torch.mul(
                Io,
                torch.exp(
                    torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0))
                ),
            )
            E[E > 255] = 255
            E = E.T.reshape(h, w, c).int()
        return Inorm, H, E
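
# A minimal usage sketch for TorchMacenkoNormalizer (illustration only; the file
# names are placeholders and this is not taken from test.py): fit() learns the
# stain matrix from a reference tile, normalize() maps a new tile onto it.
#
#   to_255 = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x * 255)])
#   normalizer = TorchMacenkoNormalizer()
#   normalizer.fit(to_255(Image.open("reference_tile.png")))
#   Inorm, H, E = normalizer.normalize(to_255(Image.open("input_tile.png")))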


class MacenkoNormalizer:
    def __init__(self, target_path=None, prob=1):
        self.transform_before_macenko = transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 255)]
        )
        self.normalizer = TorchMacenkoNormalizer()
        # with no target_path, fall back to the built-in reference matrices
        if target_path is not None:
            ext = os.path.splitext(target_path)[1].lower()
            if ext in [".jpg", ".jpeg", ".png"]:
                # fit the reference stain matrix from a target image
                target = Image.open(target_path)
                self.normalizer.fit(self.transform_before_macenko(target))
            elif ext == ".pt":
                # load precomputed reference statistics
                target = torch.load(target_path)
                self.normalizer.HERef = target["HERef"]
                self.normalizer.maxCRef = target["maxCRef"]
            else:
                raise ValueError(f"Invalid extension: {ext}")
        self.prob = prob

    def __call__(self, image):
        # NOTE: assumption, not in the original file — apply normalization only
        # with probability `prob`, otherwise return the image unchanged
        if random.random() > self.prob:
            return image
        t_to_transform = self.transform_before_macenko(image)
        try:
            image_macenko, _, _ = self.normalizer.normalize(
                I=t_to_transform, stains=False, form="chw", dtype="float"
            )
            # fall back to the original image if normalization produced NaNs
            if torch.any(torch.isnan(image_macenko)):
                return image
            return to_pil_image(image_macenko)
        except Exception as e:
            # numerical failures (degenerate tiles) are skipped silently;
            # anything else is reported
            if "kthvalue()" not in str(e) and "linalg.eigh" not in str(e):
                print(str(e))
            return image
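

# Illustrative end-to-end sketch (assumption: the paths below are placeholders,
# not part of this repository). Builds a preprocessing pipeline that applies
# Macenko normalization before tensor conversion and ImageNet normalization.
if __name__ == "__main__":
    normalizer = MacenkoNormalizer(target_path="reference_tile.png", prob=1)
    preprocess = preprocessor(pretrained=True, normalizer=normalizer)
    img = Image.open("input_tile.png").convert("RGB")
    x = preprocess(img)  # float tensor of shape [3, 224, 224]
    print(x.shape, x.dtype)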