import json
from typing import Optional, Sequence

import torch
import torch.distributed as ptdist
from monai.data import (
    CacheDataset,
    PersistentDataset,
    partition_dataset,
)
from monai.data.utils import pad_list_data_collate
from monai.transforms import (
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    RandSpatialCropSamplesd,
    ScaleIntensityRanged,
    Spacingd,
    SpatialPadd,
    ToTensord,
    Transform,
)


class PermuteImage(Transform):
    """Move the depth axis to the front: (C, H, W, D) -> (D, C, H, W)."""

    def __call__(self, data):
        # After EnsureChannelFirstd/ToTensord the image is (C, H, W, D); the
        # model expects depth-leading samples. Adjust the order as needed.
        data["image"] = data["image"].permute(3, 0, 1, 2)
        return data
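
# A quick shape check for PermuteImage (the tensor sizes below are arbitrary
# placeholders, not values this module depends on):
#
#   >>> sample = {"image": torch.zeros(1, 96, 96, 32)}   # (C, H, W, D)
#   >>> PermuteImage()(sample)["image"].shape
#   torch.Size([32, 1, 96, 96])                          # (D, C, H, W)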


class CTDataset:
    def __init__(
        self,
        json_path: str,
        img_size: int,
        depth: int,
        mask_patch_size: int,
        patch_size: int,
        downsample_ratio: Sequence[float],
        cache_dir: str,
        batch_size: int = 1,
        val_batch_size: int = 1,
        num_workers: int = 4,
        cache_num: int = 0,
        cache_rate: float = 0.0,
        dist: bool = False,
    ):
        super().__init__()
        self.json_path = json_path
        self.img_size = img_size
        self.depth = depth
        self.mask_patch_size = mask_patch_size
        self.patch_size = patch_size
        self.cache_dir = cache_dir
        self.downsample_ratio = downsample_ratio
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers
        self.cache_num = cache_num
        self.cache_rate = cache_rate
        self.dist = dist

        with open(json_path, "r") as f:
            data_list = json.load(f)

        # Missing splits default to empty lists so setup() cannot fail with
        # an AttributeError on an unpopulated attribute.
        self.train_list = data_list.get("train", [])
        self.val_list = data_list.get("validation", [])
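
        # The datalist loaded above is assumed to follow the MONAI
        # decathlon-style layout; the paths here are illustrative
        # placeholders only:
        #
        #   {
        #       "train":      [{"image": "imagesTr/ct_0001.nii.gz"}, ...],
        #       "validation": [{"image": "imagesVal/ct_0101.nii.gz"}, ...]
        #   }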

    def val_transforms(self):
        # Validation reuses the training pipeline, random crops included.
        return self.train_transforms()

    def train_transforms(self):
        transforms = Compose(
            [
                LoadImaged(keys=["image"]),
                EnsureChannelFirstd(keys=["image"]),
                Orientationd(keys=["image"], axcodes="RAS"),
                Spacingd(
                    keys=["image"],
                    pixdim=self.downsample_ratio,
                    mode=("bilinear"),
                ),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-175,
                    a_max=250,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image"], source_key="image"),
                RandSpatialCropSamplesd(
                    keys=["image"],
                    roi_size=(self.img_size, self.img_size, self.depth),
                    random_size=False,
                    num_samples=1,
                ),
                SpatialPadd(
                    keys=["image"],
                    spatial_size=(self.img_size, self.img_size, self.depth),
                ),
                # Optional intensity augmentations, currently disabled:
                # RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
                # RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
                ToTensord(keys=["image"]),
                PermuteImage(),
            ]
        )

        return transforms
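
    # A rough sanity check of the pipeline above, assuming a valid NIfTI path
    # (RandSpatialCropSamplesd makes Compose return a list of sample dicts):
    #
    #   >>> t = CTDataset(...).train_transforms()
    #   >>> out = t({"image": "ct_0001.nii.gz"})   # placeholder path
    #   >>> out[0]["image"].shape
    #   torch.Size([depth, 1, img_size, img_size])  # after PermuteImage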

    def setup(self, stage: Optional[str] = None):
        # Assign train/validation splits for use in the dataloaders
        # (stage=None builds these as well).
        if stage in [None, "train"]:
            if self.dist:
                train_partition = partition_dataset(
                    data=self.train_list,
                    num_partitions=ptdist.get_world_size(),
                    shuffle=True,
                    even_divisible=True,
                    drop_last=False,
                )[ptdist.get_rank()]
                valid_partition = partition_dataset(
                    data=self.val_list,
                    num_partitions=ptdist.get_world_size(),
                    shuffle=False,
                    even_divisible=True,
                    drop_last=False,
                )[ptdist.get_rank()]
                # self.cache_num //= ptdist.get_world_size()
            else:
                train_partition = self.train_list
                valid_partition = self.val_list

            # Cache in memory when a cache budget is given; otherwise fall
            # back to the disk-backed PersistentDataset below.
            if self.cache_num > 0 or self.cache_rate > 0:
                train_ds = CacheDataset(
                    train_partition,
                    cache_num=self.cache_num,
                    cache_rate=self.cache_rate,
                    num_workers=self.num_workers,
                    transform=self.train_transforms(),
                )
                valid_ds = CacheDataset(
                    valid_partition,
                    cache_num=self.cache_num // 4,
                    cache_rate=self.cache_rate,
                    num_workers=self.num_workers,
                    transform=self.val_transforms(),
                )
            else:
                train_ds = PersistentDataset(
                    train_partition,
                    transform=self.train_transforms(),
                    cache_dir=self.cache_dir,
                )
                valid_ds = PersistentDataset(
                    valid_partition,
                    transform=self.val_transforms(),
                    cache_dir=self.cache_dir,
                )

            return {"train": train_ds, "validation": valid_ds}

        # stage=None already returned above, so this branch only runs for an
        # explicit stage="test"; the test split reuses the validation list.
        if stage == "test":
            if self.cache_num > 0 or self.cache_rate > 0:
                test_ds = CacheDataset(
                    self.val_list,
                    cache_num=self.cache_num // 4,
                    cache_rate=self.cache_rate,
                    num_workers=self.num_workers,
                    transform=self.val_transforms(),
                )
            else:
                test_ds = PersistentDataset(
                    self.val_list,
                    transform=self.val_transforms(),
                    cache_dir=self.cache_dir,
                )

            return {"test": test_ds}

        return {"train": None, "validation": None}

    def train_dataloader(self, train_ds):
        # pad_list_data_collate flattens the per-sample lists produced by
        # RandSpatialCropSamplesd and pads tensors to a common spatial shape.
        return torch.utils.data.DataLoader(
            train_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True,
            collate_fn=pad_list_data_collate,
        )

    def val_dataloader(self, valid_ds):
        return torch.utils.data.DataLoader(
            valid_ds,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            collate_fn=pad_list_data_collate,
        )

    def test_dataloader(self, test_ds):
        return torch.utils.data.DataLoader(
            test_ds,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            collate_fn=pad_list_data_collate,
        )
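

# A minimal end-to-end usage sketch; every literal below (paths, sizes,
# spacings) is an illustrative assumption, not a value shipped with this
# module:
#
#   ds = CTDataset(
#       json_path="dataset.json",
#       img_size=96,
#       depth=32,
#       mask_patch_size=16,
#       patch_size=4,
#       downsample_ratio=(1.5, 1.5, 2.0),
#       cache_dir="./persistent_cache",
#   )
#   splits = ds.setup(stage="train")
#   train_loader = ds.train_dataloader(splits["train"])
#   val_loader = ds.val_dataloader(splits["validation"])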