# Currently supports SD1.5 UNets only (see module_mapping_sd15).

import functools
import torch
import einops

from comfy import model_management, utils
from comfy.ldm.modules.attention import optimized_attention


# Index -> dotted path (relative to the diffusion model) of every SD1.5
# attention module that gets replaced by an AttentionSharingUnit. Each listed
# transformer block contributes its attn1 followed by its attn2; the integer
# key is simply the position in the list below.
module_mapping_sd15 = dict(
    enumerate(
        [
            # 0-11: input blocks
            "input_blocks.1.1.transformer_blocks.0.attn1",
            "input_blocks.1.1.transformer_blocks.0.attn2",
            "input_blocks.2.1.transformer_blocks.0.attn1",
            "input_blocks.2.1.transformer_blocks.0.attn2",
            "input_blocks.4.1.transformer_blocks.0.attn1",
            "input_blocks.4.1.transformer_blocks.0.attn2",
            "input_blocks.5.1.transformer_blocks.0.attn1",
            "input_blocks.5.1.transformer_blocks.0.attn2",
            "input_blocks.7.1.transformer_blocks.0.attn1",
            "input_blocks.7.1.transformer_blocks.0.attn2",
            "input_blocks.8.1.transformer_blocks.0.attn1",
            "input_blocks.8.1.transformer_blocks.0.attn2",
            # 12-29: output blocks
            "output_blocks.3.1.transformer_blocks.0.attn1",
            "output_blocks.3.1.transformer_blocks.0.attn2",
            "output_blocks.4.1.transformer_blocks.0.attn1",
            "output_blocks.4.1.transformer_blocks.0.attn2",
            "output_blocks.5.1.transformer_blocks.0.attn1",
            "output_blocks.5.1.transformer_blocks.0.attn2",
            "output_blocks.6.1.transformer_blocks.0.attn1",
            "output_blocks.6.1.transformer_blocks.0.attn2",
            "output_blocks.7.1.transformer_blocks.0.attn1",
            "output_blocks.7.1.transformer_blocks.0.attn2",
            "output_blocks.8.1.transformer_blocks.0.attn1",
            "output_blocks.8.1.transformer_blocks.0.attn2",
            "output_blocks.9.1.transformer_blocks.0.attn1",
            "output_blocks.9.1.transformer_blocks.0.attn2",
            "output_blocks.10.1.transformer_blocks.0.attn1",
            "output_blocks.10.1.transformer_blocks.0.attn2",
            "output_blocks.11.1.transformer_blocks.0.attn1",
            "output_blocks.11.1.transformer_blocks.0.attn2",
            # 30-31: middle block
            "middle_block.1.transformer_blocks.0.attn1",
            "middle_block.1.transformer_blocks.0.attn2",
        ]
    )
)


def compute_cond_mark(cond_or_uncond, sigmas):
    """Expand per-chunk cond/uncond flags into one flag per batch entry.

    Every value of ``cond_or_uncond`` is repeated ``sigmas.shape[0]`` times,
    chunk after chunk, and the result is returned as a 1-D tensor converted
    to the dtype and device of ``sigmas``.
    """
    per_chunk = int(sigmas.shape[0])
    flags = torch.tensor(list(cond_or_uncond), dtype=torch.float32)
    return flags.repeat_interleave(per_chunk).to(sigmas)


class LoRALinearLayer(torch.nn.Module):
    """Low-rank (LoRA) delta applied on top of an existing linear layer.

    The effective weight is ``org.weight + up.weight @ down.weight`` and the
    bias, if any, is ``org.bias`` unchanged. ``org`` must be provided before
    ``forward`` is called (the ``None`` default is only a placeholder).
    """

    def __init__(self, in_features: int, out_features: int, rank: int = 256, org=None):
        super().__init__()
        self.down = torch.nn.Linear(in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, out_features, bias=False)
        # Held in a plain list so the original layer is NOT registered as a
        # submodule: its weights stay out of this module's parameters and
        # state_dict, and are not affected by `.half()` / `.to()` on it.
        self.org = [org]

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Apply the merged linear map to ``h``.

        All weights (original and LoRA factors) are cast to ``h``'s dtype and
        device first, so the layer works even when the LoRA factors were
        stored in a different precision or on another device than the
        activations.
        """
        org = self.org[0]
        org_weight = org.weight.to(h)
        org_bias = org.bias.to(h) if org.bias is not None else None
        # Casting the factors too avoids type promotion of the merged weight
        # (e.g. bf16 + fp32 -> fp32), which would make F.linear reject it.
        delta = torch.mm(self.up.weight.to(h), self.down.weight.to(h))
        return torch.nn.functional.linear(h, org_weight + delta, org_bias)


class AttentionSharingUnit(torch.nn.Module):
    """Replacement for one attention module that shares attention across
    ``frames`` images stored interleaved in the batch.

    For each frame, the wrapped module's q/k/v/out projections are used with
    a per-frame LoRA delta (``LoRALinearLayer``). When control is enabled, a
    per-frame projection of an external control signal is added to the hidden
    states first. The per-frame results are then mixed by a small
    self-attention over the frame axis (the ``temporal_*`` layers).
    ``forward`` returns the change relative to its input ``h``, not the new
    hidden states.
    """

    # `transformer_options` passed to the most recent BasicTransformerBlock.forward
    # call. Written by the wrapper installed in `hijack_transformer_block` and
    # read by every instance, so it is deliberately a shared class attribute.
    transformer_options: dict = {}

    def __init__(self, module, frames=2, use_control=True, rank=256):
        """Build per-frame adapters around ``module``.

        Args:
            module: the attention module being replaced. It must expose
                ``heads``, linear ``to_q`` / ``to_k`` / ``to_v``, and
                ``to_out``, whose item 0 is a linear layer and whose item 1
                is applied to its output.
            frames: number of images that share attention.
            use_control: create per-frame convolutions that inject the
                256-channel signal stored in ``control_signals``.
            rank: rank of every LoRA delta.
        """
        super().__init__()

        self.heads = module.heads
        self.frames = frames
        # Plain list so the wrapped module is not registered as a submodule
        # (its weights stay out of this unit's parameters / state_dict).
        self.original_module = [module]
        q_in_channels, q_out_channels = (
            module.to_q.in_features,
            module.to_q.out_features,
        )
        k_in_channels, k_out_channels = (
            module.to_k.in_features,
            module.to_k.out_features,
        )
        v_in_channels, v_out_channels = (
            module.to_v.in_features,
            module.to_v.out_features,
        )
        o_in_channels, o_out_channels = (
            module.to_out[0].in_features,
            module.to_out[0].out_features,
        )

        # Width used by the temporal layers and the control projections.
        hidden_size = k_out_channels

        # One LoRA adapter per frame for each projection.
        self.to_q_lora = torch.nn.ModuleList(
            [
                LoRALinearLayer(q_in_channels, q_out_channels, rank, module.to_q)
                for _ in range(self.frames)
            ]
        )
        self.to_k_lora = torch.nn.ModuleList(
            [
                LoRALinearLayer(k_in_channels, k_out_channels, rank, module.to_k)
                for _ in range(self.frames)
            ]
        )
        self.to_v_lora = torch.nn.ModuleList(
            [
                LoRALinearLayer(v_in_channels, v_out_channels, rank, module.to_v)
                for _ in range(self.frames)
            ]
        )
        self.to_out_lora = torch.nn.ModuleList(
            [
                LoRALinearLayer(o_in_channels, o_out_channels, rank, module.to_out[0])
                for _ in range(self.frames)
            ]
        )

        self.temporal_i = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_n = torch.nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6
        )
        self.temporal_q = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_k = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_v = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_o = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )

        self.control_convs = None

        if use_control:
            self.control_convs = torch.nn.ModuleList(
                [
                    torch.nn.Sequential(
                        torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
                        torch.nn.SiLU(),
                        torch.nn.Conv2d(256, hidden_size, kernel_size=1),
                    )
                    for _ in range(self.frames)
                ]
            )

        # Mapping {spatial size H*W: control feature map}, assigned from
        # outside (AttentionSharingPatcher.set_control). Must be set before
        # forward() when control is enabled.
        self.control_signals = None

    def forward(self, h, context=None, value=None):
        """Run frame-shared attention.

        Args:
            h: hidden states of rank 3 laid out as ``((b f), d, c)``, i.e. the
                ``frames`` images of each sample are adjacent in the batch.
            context: optional cross-attention context with the same batch
                layout. ``None`` means the (control-adjusted) hidden states
                are used as keys/values.
            value: accepted for signature compatibility with the replaced
                module; unused.

        Returns:
            ``new_hidden_states - h``, with the same shape as ``h``.

        Raises:
            RuntimeError: control is enabled but no control signals were set.
        """
        transformer_options = self.transformer_options

        modified_hidden_states = einops.rearrange(
            h, "(b f) d c -> f b d c", f=self.frames
        )

        if self.control_convs is not None:
            if self.control_signals is None:
                raise RuntimeError(
                    "AttentionSharingUnit has control enabled but no control "
                    "signals; call AttentionSharingPatcher.set_control() first."
                )
            # Signals are keyed by H*W of the feature map, matched against
            # the token axis (d) of this layer. The lookup is the same for
            # every frame, so it is done once.
            num_tokens = int(modified_hidden_states.shape[2])
            control_signal = self.control_signals[num_tokens].to(
                modified_hidden_states
            )
            control_outs = []
            for f in range(self.frames):
                control = self.control_convs[f](control_signal)
                control = einops.rearrange(control, "b c h w -> b (h w) c")
                control_outs.append(control)
            control_outs = torch.stack(control_outs, dim=0)
            modified_hidden_states = modified_hidden_states + control_outs.to(
                modified_hidden_states
            )

        if context is None:
            framed_context = modified_hidden_states
            # Overwrites only apply to an explicit context.
            framed_cond_mark = None
            cond_overwrite = []
        else:
            framed_context = einops.rearrange(
                context, "(b f) d c -> f b d c", f=self.frames
            )
            # Per-frame, per-sample cond/uncond flags from the sampler.
            framed_cond_mark = einops.rearrange(
                compute_cond_mark(
                    transformer_options["cond_or_uncond"],
                    transformer_options["sigmas"],
                ),
                "(b f) -> f b",
                f=self.frames,
            ).to(modified_hidden_states)
            cond_overwrite = transformer_options.get("cond_overwrite", [])

        attn_outs = []
        for f in range(self.frames):
            fcf = framed_context[f]

            frame_overwrite = cond_overwrite[f] if len(cond_overwrite) > f else None
            if frame_overwrite is not None:
                # Entries whose mark is 0 take the overwrite context; entries
                # whose mark is 1 keep their own context.
                cond_mark = framed_cond_mark[f][:, None, None]
                fcf = frame_overwrite.to(fcf) * (1.0 - cond_mark) + fcf * cond_mark

            q = self.to_q_lora[f](modified_hidden_states[f])
            k = self.to_k_lora[f](fcf)
            v = self.to_v_lora[f](fcf)
            o = optimized_attention(q, k, v, self.heads)
            o = self.to_out_lora[f](o)
            o = self.original_module[0].to_out[1](o)
            attn_outs.append(o)

        attn_outs = torch.stack(attn_outs, dim=0)
        modified_hidden_states = modified_hidden_states + attn_outs.to(
            modified_hidden_states
        )
        modified_hidden_states = einops.rearrange(
            modified_hidden_states, "f b d c -> (b f) d c", f=self.frames
        )

        # Temporal mixing: self-attention over the frame axis, independently
        # for every (sample, token) pair.
        x = modified_hidden_states
        x = self.temporal_n(x)
        x = self.temporal_i(x)
        d = x.shape[1]

        x = einops.rearrange(x, "(b f) d c -> (b d) f c", f=self.frames)

        q = self.temporal_q(x)
        k = self.temporal_k(x)
        v = self.temporal_v(x)

        x = optimized_attention(q, k, v, self.heads)
        x = self.temporal_o(x)
        x = einops.rearrange(x, "(b d) f c -> (b f) d c", d=d)

        modified_hidden_states = modified_hidden_states + x

        # Return a delta rather than the new states. Presumably the enclosing
        # transformer block adds this module's output to `h` as a residual,
        # which yields `modified_hidden_states` (TODO confirm against ComfyUI).
        return modified_hidden_states - h

    @classmethod
    def hijack_transformer_block(cls):
        """Wrap ``BasicTransformerBlock.forward`` so every call records its
        ``transformer_options`` on ``cls.transformer_options``.

        This patches the ComfyUI class globally. Calling it more than once
        stacks additional wrappers.
        """

        def register_get_transformer_options(func):
            # The `{}` default mirrors the wrapped method's signature; the
            # dict is only read here and stored, never mutated.
            @functools.wraps(func)
            def forward(self, x, context=None, transformer_options={}):
                cls.transformer_options = transformer_options
                return func(self, x, context, transformer_options)

            return forward

        from comfy.ldm.modules.attention import BasicTransformerBlock

        BasicTransformerBlock.forward = register_get_transformer_options(
            BasicTransformerBlock.forward
        )


# Import-time side effect: replace ComfyUI's BasicTransformerBlock.forward
# (class-wide, so for every model) with a wrapper that stores each call's
# transformer_options on AttentionSharingUnit.transformer_options, where
# AttentionSharingUnit.forward reads them.
AttentionSharingUnit.hijack_transformer_block()


class AdditionalAttentionCondsEncoder(torch.nn.Module):
    """Convolutional encoder turning a 3-channel image into 256-channel
    control features at four resolutions.

    ``blocks_0`` downsamples by 8 (three stride-2 convs); ``blocks_1`` to
    ``blocks_3`` each halve the resolution again. Calling the module returns
    a dict ``{H * W: feature_map}`` with one entry per stage, keyed by that
    stage's spatial size. AttentionSharingUnit uses the key to pick the
    signal matching its token count.
    """

    def __init__(self):
        super().__init__()

        self.blocks_0 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # H/8 x W/8 x 256 (64*64*256 for a 512x512 input)

        self.blocks_1 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # H/16 x W/16 x 256 (32*32*256 for a 512x512 input)

        self.blocks_2 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # H/32 x W/32 x 256 (16*16*256 for a 512x512 input)

        self.blocks_3 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # H/64 x W/64 x 256 (8*8*256 for a 512x512 input)

        # Plain list (not registered again); the blocks themselves are
        # registered through the attributes above.
        self.blks = [self.blocks_0, self.blocks_1, self.blocks_2, self.blocks_3]

    def forward(self, h):
        """Encode ``h`` (``(batch, 3, H, W)``) into ``{H_i * W_i: features_i}``.

        Implemented as ``forward`` (not an ``__call__`` override) so that
        ``encoder(img)`` goes through ``nn.Module.__call__`` and hooks run.
        """
        results = {}
        for b in self.blks:
            h = b(h)
            results[int(h.shape[2]) * int(h.shape[3])] = h
        return results


class HookerLayers(torch.nn.Module):
    """Thin container that registers a sequence of modules as ``layers``.

    Registering the units as children means they follow module-wide calls
    such as ``.half()`` and appear in ``state_dict`` as ``layers.<i>.…``.
    """

    def __init__(self, layer_list):
        super().__init__()
        container = torch.nn.ModuleList()
        container.extend(layer_list)
        self.layers = container


class AttentionSharingPatcher(torch.nn.Module):
    """Replace every attention module listed in ``module_mapping_sd15`` with
    an AttentionSharingUnit and own the units' weights.

    Args:
        unet: model wrapper exposing ``model.diffusion_model`` and
            ``add_object_patch`` (presumably a ComfyUI ModelPatcher — confirm
            with callers). Each attention module is patched via
            ``add_object_patch("diffusion_model." + path, unit)``.
        frames: number of images sharing attention, passed to every unit.
        use_control: whether the units take a control signal; if True an
            AdditionalAttentionCondsEncoder is created for ``set_control``.
        rank: LoRA rank passed to every unit.
    """

    def __init__(self, unet, frames=2, use_control=True, rank=256):
        super().__init__()
        model_management.unload_model_clones(unet)

        units = []
        for real_key in module_mapping_sd15.values():
            attn_module = utils.get_attr(unet.model.diffusion_model, real_key)
            u = AttentionSharingUnit(
                attn_module, frames=frames, use_control=use_control, rank=rank
            )
            units.append(u)
            unet.add_object_patch("diffusion_model." + real_key, u)

        self.hookers = HookerLayers(units)

        if use_control:
            self.kwargs_encoder = AdditionalAttentionCondsEncoder()
        else:
            self.kwargs_encoder = None

        # Only the units are converted to fp16; the control encoder stays in
        # float32 (set_control feeds it float32 CPU input).
        self.dtype = torch.float32
        if model_management.should_use_fp16(model_management.get_torch_device()):
            self.dtype = torch.float16
            self.hookers.half()

    def set_control(self, img):
        """Encode ``img`` and give the resulting signals to every unit.

        ``img`` is moved to CPU, cast to float32 and mapped with
        ``x * 2 - 1`` (assumes values in [0, 1] — TODO confirm) before being
        passed to the control encoder. The same signal dict is shared by all
        units.

        Raises:
            RuntimeError: if the patcher was created with ``use_control=False``.
        """
        if self.kwargs_encoder is None:
            raise RuntimeError(
                "set_control() requires an AttentionSharingPatcher created "
                "with use_control=True"
            )
        img = img.cpu().float() * 2.0 - 1.0
        # Inference only: without no_grad the signals (kept on every unit and
        # reused on each sampling step) would hold an autograd graph through
        # the encoder.
        with torch.no_grad():
            signals = self.kwargs_encoder(img)
        for m in self.hookers.layers:
            m.control_signals = signals