import random
from typing import Optional

import numpy as np
import torch
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import ProcessGroupMesh
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

DP_AXIS, SP_AXIS = 0, 1


class ZeroSeqParallelPlugin(LowLevelZeroPlugin):
    def __init__(
        self,
        sp_size: int = 1,
        stage: int = 2,
        precision: str = "fp16",
        initial_scale: float = 2**32,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
        reduce_bucket_size_in_m: int = 12,
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = True,
        cpu_offload: bool = False,
        master_weights: bool = True,
        verbose: bool = False,
    ) -> None:
        super().__init__(
            stage=stage,
            precision=precision,
            initial_scale=initial_scale,
            min_scale=min_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            max_scale=max_scale,
            max_norm=max_norm,
            norm_type=norm_type,
            reduce_bucket_size_in_m=reduce_bucket_size_in_m,
            communication_dtype=communication_dtype,
            overlap_communication=overlap_communication,
            cpu_offload=cpu_offload,
            master_weights=master_weights,
            verbose=verbose,
        )
        self.sp_size = sp_size
        assert self.world_size % sp_size == 0, "world_size must be divisible by sp_size"
        self.dp_size = self.world_size // sp_size
        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.sp_size)
        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
        self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
        self.dp_rank = self.pg_mesh.coordinate(DP_AXIS)
        self.sp_rank = self.pg_mesh.coordinate(SP_AXIS)

    def __del__(self):
        """Destroy the prcess groups in ProcessGroupMesh"""
        self.pg_mesh.destroy_mesh_process_groups()

    def prepare_dataloader(
        self,
        dataset,
        batch_size,
        shuffle=False,
        seed=1024,
        drop_last=False,
        pin_memory=False,
        num_workers=0,
        distributed_sampler_cls=None,
        **kwargs,
    ):
        _kwargs = kwargs.copy()
        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
        sampler = distributed_sampler_cls(dataset, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle)

        # Deterministic dataloader
        def seed_worker(worker_id):
            worker_seed = seed
            np.random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            random.seed(worker_seed)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            worker_init_fn=seed_worker,
            drop_last=drop_last,
            pin_memory=pin_memory,
            num_workers=num_workers,
            **_kwargs,
        )