# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
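
"""Mask-prompted style attention control.

The attention editors in this module hook into the U-Net self-attention layers
and blend the style, content, and target branches of a batched denoising run,
optionally restricting the style transfer to regions given by a source mask and
a style mask.
"""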

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange

from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None



class AttentionBase:
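    """Base class for attention editors.

    Keeps track of the current denoising step and attention layer; subclasses
    override ``forward`` to modify the attention output.
    """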
    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def after_step(self):
        pass

    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
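        # Delegate to forward(), then advance the layer counter; once every
        # attention layer in the network has been visited, start a new step
        # and fire the after_step() hook.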
        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            # after step
            self.after_step()
        return out

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
        return out

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0


class MaskPromptedStyleAttentionControl(AttentionBase):
    def __init__(self, start_step=4, start_layer=10, style_attn_step=35, layer_idx=None, step_idx=None, total_steps=50, style_guidance=0.1, 
                 only_masked_region=False, guidance=0.0, 
                 style_mask=None, source_mask=None, de_bug=False):
        """
        MaskPromptedSAC
        Args:
            start_step: the step to start mutual self-attention control
            start_layer: the layer to start mutual self-attention control
            layer_idx: list of the layers to apply mutual self-attention control
            step_idx: list the steps to apply mutual self-attention control
            total_steps: the total number of steps
            thres: the thereshold for mask thresholding
            ref_token_idx: the token index list for cross-attention map aggregation
            cur_token_idx: the token index list for cross-attention map aggregation
            mask_save_dir: the path to save the mask image
        """

        super().__init__()
        self.total_steps = total_steps
        self.total_layers = 16
        self.start_step = start_step
        self.start_layer = start_layer
        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
        self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
        print("using MaskPromptStyleAttentionControl")
        print("MaskedSAC at denoising steps: ", self.step_idx)
        print("MaskedSAC at U-Net layers: ", self.layer_idx)
        
        self.de_bug = de_bug
        self.style_guidance = style_guidance
        self.only_masked_region = only_masked_region
        self.style_attn_step = style_attn_step
        self.self_attns = []
        self.cross_attns = []
        self.guidance = guidance
        self.style_mask = style_mask
        self.source_mask = source_mask


    def after_step(self):
        self.self_attns = []
        self.cross_attns = []

    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, q_mask,k_mask, **kwargs):
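        """Recompute per-head attention for one branch, optionally masking query rows
        (``q_mask``) and/or key columns (``k_mask``) before the softmax so that attention
        is restricted to the selected region. When a precomputed ``attn`` is given it is
        used directly instead of the recomputed similarity."""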
        B = q.shape[0] // num_heads
        H = W = int(np.sqrt(q.shape[1]))
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
        k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
        v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)

        sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
        
        if q_mask is not None:
            sim = sim.masked_fill(q_mask.unsqueeze(0)==0, -torch.finfo(sim.dtype).max)
            
        if k_mask is not None:
            sim = sim.masked_fill(k_mask.permute(1,0).unsqueeze(0)==0, -torch.finfo(sim.dtype).max)
        
        attn = sim.softmax(-1) if attn is None else attn

        if len(attn) == 2 * len(v):
            v = torch.cat([v] * 2)
        out = torch.einsum("h i j, h j d -> h i d", attn, v)
        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
        return out
    
    def attn_batch_fg_bg(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, q_mask,k_mask, **kwargs):
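        """Variant of ``attn_batch`` that builds two attention maps, one restricted to the
        foreground (mask == 1) and one to the background (mask == 0), and stacks their
        outputs along the batch dimension as (foreground, background)."""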
        B = q.shape[0] // num_heads
        H = W = int(np.sqrt(q.shape[1]))
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
        k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
        v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
        sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
        # Build foreground/background similarity maps by masking query rows and key
        # columns cumulatively.
        sim_fg = sim_bg = sim
        if q_mask is not None:
            sim_fg = sim_fg.masked_fill(q_mask.unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
            sim_bg = sim_bg.masked_fill(q_mask.unsqueeze(0) == 1, -torch.finfo(sim.dtype).max)
        if k_mask is not None:
            sim_fg = sim_fg.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
            sim_bg = sim_bg.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 1, -torch.finfo(sim.dtype).max)
        sim = torch.cat([sim_fg, sim_bg])
        attn = sim.softmax(-1)

        if len(attn) == 2 * len(v):
            v = torch.cat([v] * 2)
        out = torch.einsum("h i j, h j d -> h i d", attn, v)
        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
        return out
   
    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):

        """
        Attention forward function
        """
        
        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)

        B = q.shape[0] // num_heads // 2
        H = W = int(np.sqrt(q.shape[1]))
        
        if self.style_mask is not None and self.source_mask is not None:
            height, width = self.style_mask.shape[-2:]
            mask_style = self.style_mask    # mask over the style image
            mask_source = self.source_mask  # mask over the source image
            # Downsample the pixel-space masks to the current attention resolution.
            scale = int(np.sqrt(height * width / q.shape[1]))
            spatial_mask_source = F.interpolate(mask_source, (height // scale, width // scale)).reshape(-1, 1)
            spatial_mask_style = F.interpolate(mask_style, (height // scale, width // scale)).reshape(-1, 1)
        else:
            spatial_mask_source = None
            spatial_mask_style = None

        if spatial_mask_style is None or spatial_mask_source is None:
            out_s, out_c, out_t = self.style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)
        elif self.only_masked_region:
            out_s, out_c, out_t = self.mask_prompted_style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)
        else:
            out_s, out_c, out_t = self.separate_mask_prompted_style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)

        out = torch.cat([out_s, out_c, out_t], dim=0)
        return out
    

    def style_attn_ctrl(self,q,k,v,sim,attn,is_cross,place_in_unet,num_heads,spatial_mask_source,spatial_mask_style,**kwargs):
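        """Unmasked style attention: the style, content, and target queries all attend to the
        first branch's keys/values. Before ``style_attn_step`` the target reuses the content
        query's output; afterwards it uses its own query and is extrapolated away from the
        content-query output by ``style_guidance``."""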
        if self.de_bug:
            import pdb; pdb.set_trace()
        
        qs, qc, qt = q.chunk(3)

        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None,k_mask=None,**kwargs)
        out_c = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None,k_mask=None,**kwargs)

        if self.cur_step < self.style_attn_step:
            out_t = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None,k_mask=None,**kwargs)
        else:
            out_t = self.attn_batch(qt, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None,k_mask=None,**kwargs)
            if self.style_guidance>=0:
                out_t = out_c + (out_t - out_c) * self.style_guidance
        return out_s,out_c,out_t

    def mask_prompted_style_attn_ctrl(self,q,k,v,sim,attn,is_cross,place_in_unet,num_heads,spatial_mask_source,spatial_mask_style,**kwargs):
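        """Masked style attention: the target attends to the style branch only through the
        source/style masks, and outside the source mask it falls back to the plain
        content-branch output."""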
        qs, qc, qt = q.chunk(3)
        
        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        out_c = self.attn_batch(qc, k[num_heads:2 * num_heads], v[num_heads:2 * num_heads], sim[num_heads:2 * num_heads], attn[num_heads:2 * num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)

        if self.de_bug:
            import pdb; pdb.set_trace()

        if self.cur_step < self.style_attn_step:
            out_t = out_c
        else:
            out_t_fg = self.attn_batch(qt, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_c_fg = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_t = out_t_fg
            if self.style_guidance >= 0:
                out_t = out_c_fg + (out_t_fg - out_c_fg) * self.style_guidance
            # Styled output inside the source mask, plain content output elsewhere.
            out_t = out_t * spatial_mask_source + out_c * (1 - spatial_mask_source)

        if self.de_bug:
            import pdb; pdb.set_trace()
        
        # print(torch.sum(out_t* (1 - spatial_mask_source) - out_c * (1 - spatial_mask_source)))
        return out_s,out_c,out_t

    def separate_mask_prompted_style_attn_ctrl(self,q,k,v,sim,attn,is_cross,place_in_unet,num_heads,spatial_mask_source,spatial_mask_style,**kwargs):
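        """Masked style attention with separate foreground/background attention maps
        (via ``attn_batch_fg_bg``), blended with the source mask to avoid query confusion
        between regions."""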
        
        if self.de_bug:
            import pdb; pdb.set_trace()
        # To prevent query confusion, render fg and bg according to mask.
        qs, qc, qt = q.chunk(3)
        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        if self.cur_step < self.style_attn_step:
            out_c = self.attn_batch_fg_bg(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_c_fg, out_c_bg = out_c.chunk(2)
            out_t = out_c_fg * spatial_mask_source + out_c_bg * (1 - spatial_mask_source)
        else:
            out_t = self.attn_batch_fg_bg(qt, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_c = self.attn_batch_fg_bg(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_t_fg, out_t_bg = out_t.chunk(2)
            out_c_fg, out_c_bg = out_c.chunk(2)
            if self.style_guidance >= 0:
                out_t_fg = out_c_fg + (out_t_fg - out_c_fg) * self.style_guidance
                out_t_bg = out_c_bg + (out_t_bg - out_c_bg) * self.style_guidance
            out_t = out_t_fg * spatial_mask_source + out_t_bg * (1 - spatial_mask_source)

        # Note: out_c has a doubled (fg, bg) batch here and cannot be concatenated directly,
        # so the blended target output is also returned in the content slot.
        return out_s, out_t, out_t
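

if __name__ == "__main__":
    # Minimal illustrative sketch of the editor call contract (not part of the pipeline code):
    # each hooked self-attention layer passes per-head q/k/v of shape (batch * heads, tokens, dim),
    # the pre-softmax similarity `sim`, and the softmaxed `attn`; the editor returns the merged
    # output of shape (batch, tokens, heads * dim). The tensor sizes below are arbitrary examples.
    heads, tokens, dim = 8, 64, 40
    q = k = v = torch.randn(3 * heads, tokens, dim)  # e.g. a (style, content, target) batch
    sim = torch.einsum("b i d, b j d -> b i j", q, k) * dim ** -0.5
    attn = sim.softmax(-1)
    editor = AttentionBase()
    out = editor(q, k, v, sim, attn, is_cross=False, place_in_unet="up", num_heads=heads, scale=dim ** -0.5)
    print(out.shape)  # torch.Size([3, 64, 320])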