from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import warnings
import os


import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.modeling_utils import ModelMixin
import PIL
from einops import rearrange, repeat
import numpy as np
import torch
import torch.nn.init as init
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import is_compiled_module


class ControlnetPredictor(object):
    """Run a ControlNet backbone on its own and return its residual features.

    The predictor wraps a pretrained ``ControlNetModel`` so that the down-block
    and mid-block residuals can be extracted separately, which avoids repeating
    the extraction during training.
    """

    def __init__(self, controlnet_model_path: str, *args, **kwargs):
        """Load the ControlNet weights used for inference.

        Args:
            controlnet_model_path (str): controlnet model path, passed to
                ``ControlNetModel.from_pretrained``.
            *args: forwarded to the parent ``__init__``.
            **kwargs: forwarded to the parent ``__init__``.
        """
        super(ControlnetPredictor, self).__init__(*args, **kwargs)
        self.controlnet = ControlNetModel.from_pretrained(
            controlnet_model_path,
        )

    def prepare_image(
        self,
        image,  # b c t h w
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        """Convert condition frames into a batched tensor for the ControlNet.

        The frames are flattened from ``b c t h w`` to ``(b t) c h w``, divided
        by 255, resized bilinearly to ``(height, width)`` (each rounded down to
        a multiple of the image processor's ``vae_scale_factor``), optionally
        normalized, repeated along the batch axis and moved to
        ``device``/``dtype``.

        Args:
            image: numpy array laid out as ``b c t h w``; it is passed to
                ``torch.from_numpy``, so a torch tensor is not accepted here.
            width: target width, or ``None`` to use ``image.shape[-1]``.
            height: target height, or ``None`` to use ``image.shape[-2]``.
            batch_size: repeat factor used when the flattened batch has a
                single entry.
            num_images_per_prompt: repeat factor used otherwise.
            device: device the result is moved to.
            dtype: dtype the result is cast to.
            do_classifier_free_guidance: together with ``guess_mode`` decides
                whether the batch is duplicated.
            guess_mode: when ``False`` and classifier-free guidance is on, the
                batch is concatenated with itself.

        Returns:
            torch.Tensor: the prepared condition tensor, ``(n, c, h, w)``.
        """
        # NOTE(review): ``self.control_image_processor`` is not assigned in
        # ``__init__`` of this class; it has to be attached to the instance
        # before this method is called — confirm where that happens.
        if height is None:
            height = image.shape[-2]
        if width is None:
            width = image.shape[-1]
        width, height = (
            x - x % self.control_image_processor.vae_scale_factor
            for x in (width, height)
        )
        image = rearrange(image, "b c t h w-> (b t) c h w")
        image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0

        # Keep the result a tensor: a stray trailing comma here previously
        # wrapped it in a 1-tuple, which broke every tensor call below.
        image = torch.nn.functional.interpolate(
            image,
            size=(height, width),
            mode="bilinear",
        )

        do_normalize = self.control_image_processor.config.do_normalize
        if image.min() < 0:
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            # Negative values are treated as already normalized.
            do_normalize = False

        if do_normalize:
            image = self.control_image_processor.normalize(image)

        image_batch_size = image.shape[0]

        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt

        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        return image

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int,
        device: str,
        dtype: torch.dtype,
        timesteps: List[float],
        i: int,
        scheduler: KarrasDiffusionSchedulers,
        prompt_embeds: torch.Tensor,
        do_classifier_free_guidance: bool = False,
        # 2b co t ho wo
        latent_model_input: torch.Tensor = None,
        # b co t ho wo
        latents: torch.Tensor = None,
        # b c t h w
        image: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
        # b c t(1) hi wi
        controlnet_condition_frames: Optional[torch.FloatTensor] = None,
        # b c t ho wo
        controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None,
        # b c t(1) ho wo
        controlnet_condition_latents: Optional[torch.FloatTensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        return_dict: bool = True,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        latent_index: torch.LongTensor = None,
        vision_condition_latent_index: torch.LongTensor = None,
        **kwargs,
    ):
        """Run the ControlNet for denoising step ``i`` and return its residuals.

        Args:
            batch_size: prompt batch size; multiplied by
                ``num_videos_per_prompt`` when preparing condition images.
            device: device the condition tensors are moved to.
            dtype: dtype the condition latents are cast to.
            timesteps: full list of scheduler timesteps.
            i: index into ``timesteps`` of the step to evaluate.
            scheduler: scheduler whose ``scale_model_input`` is applied to
                ``latents`` in guess mode with classifier-free guidance.
            prompt_embeds: text embeddings; in guess mode with classifier-free
                guidance only the second chunk is used.
            do_classifier_free_guidance: whether classifier-free guidance is on.
            latent_model_input: model input used outside guess mode
                (``2b co t ho wo`` per the signature comment).
            latents: latents used in guess mode (``b co t ho wo``).
            image: condition frames (``b c t h w``), or a list of them for a
                ``MultiControlNetModel``.
            controlnet_condition_frames: optional leading condition frames;
                concatenated before ``image`` on axis 2 when given as numpy
                arrays.
            controlnet_latents: precomputed condition latents
                (``b c t ho wo``), used instead of ``image`` when
                ``controlnet_condition_latents`` is also given.
            controlnet_condition_latents: leading condition latents
                concatenated before ``controlnet_latents`` on axis 2.
            height: target condition height, or ``None`` to keep the input's.
            width: target condition width, or ``None`` to keep the input's.
            num_videos_per_prompt: number of videos generated per prompt.
            return_dict: accepted for signature compatibility; not used here.
            controlnet_conditioning_scale: scale(s) applied to the residuals.
            guess_mode: run the ControlNet on the conditional batch only.
            control_guidance_start: fraction(s) of the schedule at which the
                ControlNet starts to apply.
            control_guidance_end: fraction(s) of the schedule at which the
                ControlNet stops applying.
            latent_index: accepted for signature compatibility; not used here.
            vision_condition_latent_index: accepted for signature
                compatibility; not used here.
            **kwargs: ignored.

        Returns:
            tuple: ``(down_block_res_samples, mid_block_res_sample)`` as
            returned by the ControlNet with ``return_dict=False``.

        Raises:
            ValueError: if neither ``image`` nor ``controlnet_latents`` is set.
            NotImplementedError: if ``controlnet_latents`` is given for a
                ``MultiControlNetModel``.
        """
        # At least one condition source is required. The previous check was
        # inverted and only passed when both were missing.
        if image is None and controlnet_latents is None:
            raise ValueError("should set one of image and controlnet_latents")

        # Unwrap torch.compile'd modules so isinstance/config checks work.
        controlnet = (
            self.controlnet._orig_mod
            if is_compiled_module(self.controlnet)
            else self.controlnet
        )

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(
            control_guidance_end, list
        ):
            control_guidance_start = len(control_guidance_end) * [
                control_guidance_start
            ]
        elif not isinstance(control_guidance_end, list) and isinstance(
            control_guidance_start, list
        ):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(
            control_guidance_end, list
        ):
            mult = (
                len(controlnet.nets)
                if isinstance(controlnet, MultiControlNetModel)
                else 1
            )
            control_guidance_start, control_guidance_end = mult * [
                control_guidance_start
            ], mult * [control_guidance_end]

        if isinstance(controlnet, MultiControlNetModel) and isinstance(
            controlnet_conditioning_scale, float
        ):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
                controlnet.nets
            )

        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions

        # 4. Prepare image
        if isinstance(controlnet, ControlNetModel):
            if (
                controlnet_latents is not None
                and controlnet_condition_latents is not None
            ):
                if isinstance(controlnet_latents, np.ndarray):
                    controlnet_latents = torch.from_numpy(controlnet_latents)
                if isinstance(controlnet_condition_latents, np.ndarray):
                    controlnet_condition_latents = torch.from_numpy(
                        controlnet_condition_latents
                    )
                # TODO: concat with index
                controlnet_latents = torch.concat(
                    [controlnet_condition_latents, controlnet_latents], dim=2
                )
                if not guess_mode and do_classifier_free_guidance:
                    controlnet_latents = torch.concat([controlnet_latents] * 2, dim=0)
                controlnet_latents = rearrange(
                    controlnet_latents, "b c t h w->(b t) c h w"
                )
                controlnet_latents = controlnet_latents.to(device=device, dtype=dtype)
            else:
                # TODO: concat with index
                if controlnet_condition_frames is not None:
                    if isinstance(controlnet_condition_frames, np.ndarray):
                        image = np.concatenate(
                            [controlnet_condition_frames, image], axis=2
                        )
                image = self.prepare_image(
                    image=image,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_videos_per_prompt,
                    num_images_per_prompt=num_videos_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )
        elif isinstance(controlnet, MultiControlNetModel):
            images = []
            # TODO: support using controlnet_latent directly instead of frames
            if controlnet_latents is not None:
                raise NotImplementedError
            else:
                # Use a dedicated index name so the step index ``i`` passed by
                # the caller is not overwritten.
                for net_idx, image_ in enumerate(image):
                    if controlnet_condition_frames is not None and isinstance(
                        controlnet_condition_frames, list
                    ):
                        if isinstance(controlnet_condition_frames[net_idx], np.ndarray):
                            image_ = np.concatenate(
                                [controlnet_condition_frames[net_idx], image_],
                                axis=2,
                            )
                    image_ = self.prepare_image(
                        image=image_,
                        width=width,
                        height=height,
                        batch_size=batch_size * num_videos_per_prompt,
                        num_images_per_prompt=num_videos_per_prompt,
                        device=device,
                        dtype=controlnet.dtype,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        guess_mode=guess_mode,
                    )

                    images.append(image_)

                image = images
        else:
            # Explicit raise so the guard is not stripped under ``python -O``.
            raise AssertionError(f"unsupported controlnet type: {type(controlnet)}")

        # 7.1 Create tensor stating which controlnets to keep
        # ``step`` (not ``i``) is the loop variable: reusing ``i`` here used to
        # overwrite the requested step with the last index of ``timesteps``.
        num_steps = len(timesteps)
        controlnet_keep = []
        for step in range(num_steps):
            keeps = [
                1.0 - float(step / num_steps < s or (step + 1) / num_steps > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(
                keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
            )

        t = timesteps[i]

        # controlnet(s) inference
        if guess_mode and do_classifier_free_guidance:
            # Infer ControlNet only for the conditional batch.
            control_model_input = latents
            control_model_input = scheduler.scale_model_input(control_model_input, t)
            controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
        else:
            control_model_input = latent_model_input
            controlnet_prompt_embeds = prompt_embeds
        if isinstance(controlnet_keep[i], list):
            cond_scale = [
                c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])
            ]
        else:
            cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
        # Fold the time axis into the batch axis for the 2D ControlNet and
        # repeat the prompt embeddings once per frame to match.
        control_model_input_reshape = rearrange(
            control_model_input, "b c t h w -> (b t) c h w"
        )
        encoder_hidden_states_repeat = repeat(
            controlnet_prompt_embeds,
            "b n q->(b t) n q",
            t=control_model_input.shape[2],
        )

        down_block_res_samples, mid_block_res_sample = self.controlnet(
            control_model_input_reshape,
            t,
            encoder_hidden_states_repeat,
            controlnet_cond=image,
            controlnet_cond_latents=controlnet_latents,
            conditioning_scale=cond_scale,
            guess_mode=guess_mode,
            return_dict=False,
        )

        return down_block_res_samples, mid_block_res_sample


class InflatedConv3d(nn.Conv2d):
    """2D convolution applied independently to every frame of a 5D input."""

    def forward(self, x):
        """Convolve each frame of ``x`` (``b c f h w``) and restore the layout.

        Args:
            x: input whose axis 2 is the frame axis.

        Returns:
            The convolved tensor, again laid out as ``b c f h w``.
        """
        num_frames = x.shape[2]

        # Fold frames into the batch axis so the parent 2D conv sees 4D input.
        frames_as_batch = rearrange(x, "b c f h w -> (b f) c h w")
        convolved = super().forward(frames_as_batch)

        # Split the batch axis back into (batch, frame).
        return rearrange(convolved, "(b f) c h w -> b c f h w", f=num_frames)


def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return it."""
    # Work on detached views: they share storage with the parameters, so the
    # in-place fill updates the module without being recorded by autograd.
    detached_params = (param.detach() for param in module.parameters())
    for tensor in detached_params:
        tensor.zero_()
    return module


class PoseGuider(ModelMixin):
    """Convolutional encoder that maps a conditioning video to an embedding.

    The network is a stack of ``InflatedConv3d`` layers (2D convolutions
    applied per frame) with SiLU activations. The output convolution is
    zero-initialized, so a freshly constructed model outputs zeros.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 64, 128),
    ):
        """Build the convolution stack.

        Args:
            conditioning_embedding_channels: channel count of the output.
            conditioning_channels: channel count of the conditioning input.
            block_out_channels: channel widths of the stages; each consecutive
                pair adds one same-width conv and one stride-2 conv.
        """
        super().__init__()
        self.conv_in = InflatedConv3d(
            conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
        )

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(
                InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
            )
            # The stride-2 conv changes the width and downsamples spatially.
            self.blocks.append(
                InflatedConv3d(
                    channel_in, channel_out, kernel_size=3, padding=1, stride=2
                )
            )

        # Zeroed output conv: the guider contributes nothing until trained.
        self.conv_out = zero_module(
            InflatedConv3d(
                block_out_channels[-1],
                conditioning_embedding_channels,
                kernel_size=3,
                padding=1,
            )
        )

    def forward(self, conditioning):
        """Encode ``conditioning`` (``b c f h w``) into the guidance embedding.

        Args:
            conditioning: conditioning video tensor with the frame axis at
                position 2.

        Returns:
            The embedding with ``conditioning_embedding_channels`` channels.
        """
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 64, 128),
    ):
        """Build a model and load weights from a checkpoint file.

        Weights are loaded with ``strict=False``; any missing or unexpected
        keys are reported on stdout instead of being silently dropped.

        Args:
            pretrained_model_path: path of a checkpoint readable by
                ``torch.load`` that holds a state dict.
            conditioning_embedding_channels: see ``__init__``.
            conditioning_channels: see ``__init__``.
            block_out_channels: see ``__init__``.

        Returns:
            An instance of ``cls`` with the checkpoint weights loaded.

        Raises:
            FileNotFoundError: if ``pretrained_model_path`` does not exist.
        """
        # Fail before doing any work; previously a missing path only printed a
        # message and the error surfaced later from ``torch.load``.
        if not os.path.exists(pretrained_model_path):
            raise FileNotFoundError(
                f"There is no model file in {pretrained_model_path}"
            )

        # NOTE(security): ``torch.load`` unpickles the file; only load
        # checkpoints from trusted sources.
        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        # Instantiate ``cls`` so subclasses get an instance of their own type.
        model = cls(
            conditioning_embedding_channels=conditioning_embedding_channels,
            conditioning_channels=conditioning_channels,
            block_out_channels=block_out_channels,
        )

        missing_keys, unexpected_keys = model.load_state_dict(
            state_dict, strict=False
        )
        print(
            f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ..."
        )
        if missing_keys or unexpected_keys:
            print(
                f"### missing keys: {len(missing_keys)}; \n"
                f"### unexpected keys: {len(unexpected_keys)};"
            )

        num_params = sum(p.numel() for p in model.parameters())
        print(f"### PoseGuider's Parameters: {num_params / 1e6} M")

        return model