import os
import random

import torch
from PIL import Image
from torchstain.base.normalizers.he_normalizer import HENormalizer
from torchstain.torch.utils import cov, percentile
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image


def preprocessor(pretrained=False, normalizer=None):
    """Build the resize / crop / optional stain-normalize / normalize pipeline.

    Args:
        pretrained: if True, use the commonly used ImageNet channel mean/std;
            otherwise use mean = std = 0.5 per channel (maps [0, 1] to [-1, 1]).
        normalizer: optional callable applied to the center-cropped image
            before ``ToTensor`` (a PIL image when the pipeline input is PIL),
            e.g. ``MacenkoNormalizer``. When None an identity step is used.

    Returns:
        A ``transforms.Compose`` pipeline ending in a normalized tensor.
    """
    if pretrained:
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    # An identity step keeps the stage layout unchanged when no normalizer
    # is supplied.
    stain_step = transforms.Lambda(lambda x: x) if normalizer is None else normalizer

    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            stain_step,
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )

    return preprocess


"""
TorchMacenkoNormalizer below is ported from: https://github.com/schaugf/HEnorm_python
Original implementation: https://github.com/mitkovetta/staining-normalization
"""


class TorchMacenkoNormalizer(HENormalizer):
    """Macenko H&E stain normalizer operating on torch tensors.

    Holds a reference stain matrix ``HERef`` (3x2: rows are RGB channels,
    columns are the hematoxylin and eosin optical-density vectors) and the
    reference 99th-percentile stain concentrations ``maxCRef`` (2 values).
    ``fit`` replaces both with values estimated from a target image;
    ``normalize`` maps an image's stain concentrations onto the reference.
    """

    def __init__(self):
        super().__init__()

        # Default reference stain vectors (columns: hematoxylin, eosin) and
        # reference maximum concentrations; overwritten by ``fit``.
        self.HERef = torch.tensor(
            [[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]]
        )
        self.maxCRef = torch.tensor([1.9705, 1.0308])

        # Avoid using deprecated torch.lstsq (since 1.9.0)
        self.updated_lstsq = hasattr(torch.linalg, "lstsq")

    def __convert_rgb2od(self, I, Io, beta):
        """Convert a [C, H, W] image to optical density.

        Returns:
            OD: [H*W, C] optical density of every pixel.
            ODhat: rows of ``OD`` with no channel below ``beta`` (transparent
                pixels removed).
        """
        I = I.permute(1, 2, 0)

        # calculate optical density; +1 avoids log(0) for zero intensities
        OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io)

        # remove transparent pixels
        ODhat = OD[~torch.any(OD < beta, dim=1)]

        return OD, ODhat

    def __find_HE(self, ODhat, eigvecs, alpha):
        """Estimate the [3, 2] stain matrix (hematoxylin, eosin columns)."""
        # project on the plane spanned by the eigenvectors corresponding to the
        # two largest eigenvalues
        That = torch.matmul(ODhat, eigvecs)
        phi = torch.atan2(That[:, 1], That[:, 0])

        # robust angular extremes: alpha-th and (100 - alpha)-th percentiles
        minPhi = percentile(phi, alpha)
        maxPhi = percentile(phi, 100 - alpha)

        vMin = torch.matmul(
            eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))
        ).unsqueeze(1)
        vMax = torch.matmul(
            eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))
        ).unsqueeze(1)

        # a heuristic to make the vector corresponding to hematoxylin first and
        # the one corresponding to eosin second
        HE = torch.where(
            vMin[0] > vMax[0],
            torch.cat((vMin, vMax), dim=1),
            torch.cat((vMax, vMin), dim=1),
        )

        return HE

    def __find_concentration(self, OD, HE):
        """Solve HE @ C = OD.T in the least-squares sense; returns C as [2, N]."""
        # rows correspond to channels (RGB), columns to OD values
        Y = OD.T

        # determine concentrations of the individual stains
        if not self.updated_lstsq:
            # legacy API: torch.lstsq(B, A) solves A X = B; keep the 2 solution rows
            return torch.lstsq(Y, HE)[0][:2]

        return torch.linalg.lstsq(HE, Y)[0]

    def __compute_matrices(self, I, Io, alpha, beta):
        """Return the stain matrix, per-pixel concentrations and their 99th percentiles."""
        OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

        # compute eigenvectors; eigh returns eigenvalues in ascending order, so
        # columns 1 and 2 belong to the two largest eigenvalues
        _, eigvecs = torch.linalg.eigh(cov(ODhat.T))
        eigvecs = eigvecs[:, [1, 2]]

        HE = self.__find_HE(ODhat, eigvecs, alpha)

        C = self.__find_concentration(OD, HE)
        maxC = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])

        return HE, C, maxC

    def __stain_image(self, stain_vec, conc, Io, h, w, c):
        """Reconstruct a single-stain image as an int tensor of shape [H, W, C]."""
        # outer product [C, 1] x [1, H*W] -> optical density of this stain alone
        od = torch.matmul(-stain_vec.unsqueeze(-1), conc.unsqueeze(0))
        # clamp (like the masked assignment it replaces) leaves NaN untouched
        img = torch.clamp(Io * torch.exp(od), max=255)
        return img.T.reshape(h, w, c).int()

    def fit(self, I, Io=240, alpha=1, beta=0.15):
        """Set the reference stain matrix and max concentrations from image ``I``."""
        HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)

        self.HERef = HE
        self.maxCRef = maxC

    def normalize(
        self, I, Io=240, alpha=1, beta=0.15, stains=True, form="chw", dtype="int"
    ):
        """Normalize staining appearance of H&E stained images.

        Input:
            I: RGB input image: tensor of shape [C, H, W] with intensities in
                [0, 255] (converted to float internally)
            Io: (optional) transmitted light intensity
            alpha: percentile used for the robust stain-angle extremes
            beta: transparency threshold in optical density
            stains: if true, return also H & E components
            form, dtype: accepted but currently unused; ``Inorm`` is always a
                float tensor of shape [C, H, W]

        Output:
            Inorm: normalized image, float tensor [C, H, W] in [0, 1]
            H: hematoxylin image, int tensor [H, W, C] (None if not ``stains``)
            E: eosin image, int tensor [H, W, C] (None if not ``stains``)

        Reference:
            A method for normalizing histology slides for quantitative analysis.
            M. Macenko et al., ISBI 2009
        """

        c, h, w = I.shape

        HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta)

        # The default references are created on the CPU and loaded ones may
        # have another dtype; align them with C so the arithmetic below works
        # for inputs on any device.
        HERef = self.HERef.to(C)
        maxCRef = self.maxCRef.to(C)

        # normalize stain concentrations
        C *= (maxCRef / maxC).unsqueeze(-1)

        # recreate the image using reference mixing matrix
        Inorm = Io * torch.exp(-torch.matmul(HERef, C))
        Inorm = torch.clip(Inorm, 0, 255)

        Inorm = Inorm.reshape(c, h, w).float() / 255.0
        Inorm = torch.clip(Inorm, 0.0, 1.0)

        H, E = None, None

        if stains:
            H = self.__stain_image(HERef[:, 0], C[0, :], Io, h, w, c)
            E = self.__stain_image(HERef[:, 1], C[1, :], Io, h, w, c)

        return Inorm, H, E


class MacenkoNormalizer:
    """PIL-in / PIL-out transform applying Macenko stain normalization.

    Args:
        target_path: reference to normalize towards. ``.jpg``/``.jpeg``/``.png``
            images are fitted with ``TorchMacenkoNormalizer.fit``; ``.pt`` files
            must contain a mapping with ``"HERef"`` and ``"maxCRef"`` tensors.
            When None, the default reference built into
            ``TorchMacenkoNormalizer`` is kept.
        prob: probability of applying normalization on each call; with the
            default of 1 every image is processed.

    Raises:
        ValueError: if ``target_path`` has an unsupported extension.
    """

    def __init__(self, target_path=None, prob=1):
        # ToTensor scales pixels to [0, 1]; rescale to 0-255 intensities
        self.transform_before_macenko = transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 255)]
        )
        self.normalizer = TorchMacenkoNormalizer()
        self.prob = prob

        if target_path is None:
            # Previously crashed in os.path.splitext; keep the default reference.
            return

        ext = os.path.splitext(target_path)[1].lower()
        if ext in (".jpg", ".jpeg", ".png"):
            # Force 3-channel RGB (alpha/palette/grayscale would yield a wrong
            # channel count) and close the file once the pixels are loaded.
            with Image.open(target_path) as img:
                target = img.convert("RGB")
            self.normalizer.fit(self.transform_before_macenko(target))
        elif ext == ".pt":
            # NOTE(review): torch.load unpickles the file - only load trusted
            # references. map_location keeps the tensors on the CPU, matching
            # the CPU tensors produced by ToTensor in __call__.
            target = torch.load(target_path, map_location="cpu")
            self.normalizer.HERef = target["HERef"]
            self.normalizer.maxCRef = target["maxCRef"]
        else:
            raise ValueError(f"Invalid extension: {ext}")

    def __call__(self, image):
        """Return a stain-normalized PIL image, or ``image`` unchanged.

        ``image`` is returned as-is when normalization is skipped by ``prob``,
        when the result contains NaN, or when normalization raises.
        """
        # Only consult the RNG when it can matter, so prob >= 1 leaves the
        # global random state untouched.
        if self.prob < 1 and random.random() >= self.prob:
            return image

        t_to_transform = self.transform_before_macenko(image)
        try:
            image_macenko, _, _ = self.normalizer.normalize(
                I=t_to_transform, stains=False, form="chw", dtype="float"
            )
            if torch.any(torch.isnan(image_macenko)):
                return image
            return to_pil_image(image_macenko)
        except Exception as e:
            # kthvalue / linalg.eigh errors are silenced (presumably raised for
            # images with too few non-transparent pixels); others are printed.
            message = str(e)
            if "kthvalue()" not in message and "linalg.eigh" not in message:
                print(message)
            return image