import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

BLOCKS = {
    'content': ['down_blocks'],
    'style': ["up_blocks"],

}

controlnet_BLOCKS = {
    'content': [],
    'style': ["down_blocks"],
}


def resize_width_height(width, height, min_short_side=512, max_long_side=1024):

    if width < height:

        if width < min_short_side:
            scale_factor = min_short_side / width
            new_width = min_short_side
            new_height = int(height * scale_factor)
        else:
            new_width, new_height = width, height
    else:

        if height < min_short_side:
            scale_factor = min_short_side / height
            new_width = int(width * scale_factor)
            new_height = min_short_side
        else:
            new_width, new_height = width, height

    if max(new_width, new_height) > max_long_side:
        scale_factor = max_long_side / max(new_width, new_height)
        new_width = int(new_width * scale_factor)
        new_height = int(new_height * scale_factor)
    return new_width, new_height

def resize_content(content_image):
    max_long_side = 1024
    min_short_side = 1024

    new_width, new_height = resize_width_height(content_image.size[0], content_image.size[1],
                                                min_short_side=min_short_side, max_long_side=max_long_side)
    height = new_height // 16 * 16
    width = new_width // 16 * 16
    content_image = content_image.resize((width, height))

    return width,height,content_image

attn_maps = {}
def hook_fn(name):
    def forward_hook(module, input, output):
        if hasattr(module.processor, "attn_map"):
            attn_maps[name] = module.processor.attn_map
            del module.processor.attn_map

    return forward_hook

def register_cross_attention_hook(unet):
    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            module.register_forward_hook(hook_fn(name))

    return unet

def upscale(attn_map, target_size):
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1,0)
    temp_size = None

    for i in range(0,5):
        scale = 2 ** i
        if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
            temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
            break

    assert temp_size is not None, "temp_size cannot is None"

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):

    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    for name, attn_map in attn_maps.items():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        attn_map = upscale(attn_map, image_size) 
        net_attn_maps.append(attn_map) 

    net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)

    return net_attn_maps

def attnmaps2images(net_attn_maps):

    #total_attn_scores = 0
    images = []

    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()
        #total_attn_scores += attn_map.mean().item()

        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
        normalized_attn_map = normalized_attn_map.astype(np.uint8)
        #print("norm: ", normalized_attn_map.shape)
        image = Image.fromarray(normalized_attn_map)

        #image = fix_save_attn_map(attn_map)
        images.append(image)

    #print(total_attn_scores)
    return images
def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")

def get_generator(seed, device):

    if seed is not None:
        if isinstance(seed, list):
            generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
        else:
            generator = torch.Generator(device).manual_seed(seed)
    else:
        generator = None

    return generator