"""Gradio demo for MUG: masked-image reconstruction and captioning from the same masked-encoder latent."""
import gradio as gr
import numpy as np
import torch
from collections import OrderedDict
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from model import mae_vit_base_patch16

# BERT tokenizer provides the vocabulary for the captioning head.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Fetch the MUG ViT-B/16 checkpoint (LAION-pretrained, 5 epochs) from the Hugging Face Hub.
ckpt = torch.load(hf_hub_download('tennant/MUG', 'laion_mug_vit_base_5ep.pth'), map_location='cpu')

# Strip the 'image_encoder.model.' prefix so the keys match the standalone MAE model.
new_dict = OrderedDict()
for k, v in ckpt.items():
    if k.startswith('image_encoder.model.'):
        k = k[len('image_encoder.model.'):]
    new_dict[k] = v

# ViT-B/16 MAE backbone with MUG's multimodal captioning head (uni_dim/uni_heads size those blocks).
model = mae_vit_base_patch16(uni_dim=768, uni_heads=12, less_u=True)

# strict=False: the checkpoint can carry extra weights; print what was missing/unexpected.
msg = model.load_state_dict(new_dict, strict=False)
print(msg)
if torch.cuda.is_available():
    model.cuda()
model.eval()

@torch.no_grad()
def visual_recon(x, model, mask_ratio=0.75):
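    """Run MAE-style masked reconstruction.

    Returns the masked input, a composite of visible patches plus reconstruction,
    the raw reconstruction (all as NHWC tensors), and the encoder latent.
    """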
    # Per-patch statistics of the target: the decoder predicts per-patch-normalized
    # pixels, so mean/var are needed to restore the original scale below.
    target = model.patchify(x)
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)

    # Encode only the visible patches, reconstruct all patches, then undo the normalization.
    latent, mask, ids_restore, _ = model.forward_encoder(x, mask_ratio=mask_ratio)
    y, _ = model.forward_decoder(latent, ids_restore)
    y = y * (var + 1.e-6) ** .5 + mean
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y)
    
    # Expand the binary patch mask to pixel space: 1 marks removed patches, 0 kept ones.
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0] ** 2 * 3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)
    mask = torch.einsum('nchw->nhwc', mask)
    
    x = torch.einsum('nchw->nhwc', x)
    
    return x * (1 - mask), x * (1 - mask) + y * mask, y, latent

@torch.no_grad()
def caption_next_word(latent, model, tokenizer, prefix='a photo of a'):
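    """Greedily predict the next caption word for a single image latent."""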
    assert latent.shape[0] == 1, 'can only caption one image at a time'
    
    # Tokenize the prefix and drop the trailing [SEP] so the model predicts the next token.
    x_l = torch.tensor(tokenizer([prefix])['input_ids'])[:, :-1]
    if torch.cuda.is_available():
        x_l = x_l.cuda()

    # Embed the prefix tokens, then fuse them with the image latent via cross-attention blocks.
    x_l = model.embed_text(x_l)

    for cross_attn1, cross_attn2 in model.multimodal_layers:
        x_l = cross_attn1(x_l, latent)
        x_l = cross_attn2(x_l, latent)

    pred = model.to_logits(x_l)
    # Suppress BERT special tokens so they are never sampled: [MASK]=103, [CLS]=101, [UNK]=100, [PAD]=0.
    # [SEP]=102 stays available so the model can terminate the caption.
    pred[:, :, 103] = -100
    pred[:, :, 101] = -100
    pred[:, :, 100] = -100
    pred[:, :, 0] = -100
    # Greedy decoding: take the most likely token at the last position.
    next_word = pred.argmax(dim=-1)[0, -1]
    next_word = tokenizer.decode(next_word)
    
    return next_word

def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
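    """Autoregressively extend the prefix one word at a time, until [SEP] or max_len words."""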
    words = prefix.split()
    while len(words) < max_len:
        next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
        if next_word == '[SEP]':  # end-of-caption token; don't include it in the output
            break
        words.append(next_word)
    return ' '.join(words)


def gr_caption(x, mask_ratio=0.75, max_len=20, prefix='a'):
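    """Gradio callback: normalize the input, reconstruct, caption, and stack the three views."""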
    # Normalize with ImageNet statistics and convert the HWC image to an NCHW float tensor.
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    x = np.array(x) / 255.
    x = (x - imagenet_mean) / imagenet_std

    x = torch.tensor(x).float()
    x = x.unsqueeze(0)
    x = torch.einsum('nhwc->nchw', x)
    if torch.cuda.is_available():
        x = x.cuda()
        
    def unnorm_pix(img):
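        """Undo the ImageNet normalization and return a clipped HWC numpy image."""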
        img = img.squeeze(0).cpu().detach().numpy()
        img = img * imagenet_std + imagenet_mean
        return np.clip(img, a_min=0., a_max=1.)

    masked, masked_recon, recon, latent = visual_recon(x, model, mask_ratio=mask_ratio)
    caption_from_model = caption(max_len, latent, model, tokenizer, prefix=prefix)

    # Stack masked input | composite | full reconstruction side by side for display.
    masked, masked_recon, recon = map(unnorm_pix, (masked, masked_recon, recon))
    return_img = np.concatenate([masked, masked_recon, recon], axis=1)
    
    return return_img, caption_from_model


# Note: gr.Image(shape=...) follows the Gradio 3.x API.
demo = gr.Interface(gr_caption,
                    inputs=[gr.Image(value='cat.jpeg', shape=(224, 224)),
                            gr.Number(value=0.75, label='mask ratio'),
                            gr.Number(value=20, label='max length'),
                            gr.Textbox(value='a photo of a', label='caption prefix')],
                    outputs=[gr.Image(shape=(224, 224 * 3)),
                             'text'],
                    # examples=[['cat.jpeg', 0.75, 20, 'a photo of a']],
                    )
demo.launch()
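# launch() serves the interface locally and prints its URL; share=True would create a temporary public link.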