################################################################################
# Copyright (C) 2023 Xingqian Xu - All Rights Reserved                         #
#                                                                              #
# Please visit Prompt-Free-Diffusion's arXiv paper for more details, link at   #
# arxiv.org/abs/2305.16223                                                     #
#                                                                              #
################################################################################

import gradio as gr
import os.path as osp
from PIL import Image
import numpy as np
import time

import torch
import torchvision.transforms as tvtrans
from lib.cfg_helper import model_cfg_bank
from lib.model_zoo import get_model

from collections import OrderedDict
from lib.model_zoo.ddim import DDIMSampler

from huggingface_hub import hf_hub_download

# Images sampled per "Run" click; the result gallery shows these plus the
# control map, hence its grid of n_sample_image + 1.
n_sample_image = 1

# ControlNet tag -> (default preprocessor type, local path of the weights).
# Each weight file is fetched at import time from the
# 'shi-labs/prompt-free-diffusion' Hugging Face Hub repo, where it lives under
# 'pretrained/controlnet/' (the same relative layout as a local checkout).
controlnet_path = OrderedDict(
    (tag, (method, hf_hub_download(
        'shi-labs/prompt-free-diffusion', 'pretrained/controlnet/' + filename)))
    for tag, method, filename in (
        ('canny'             , 'canny'   , 'control_sd15_canny_slimmed.safetensors'),
        ('canny_v11p'        , 'canny'   , 'control_v11p_sd15_canny_slimmed.safetensors'),
        ('depth'             , 'depth'   , 'control_sd15_depth_slimmed.safetensors'),
        ('hed'               , 'hed'     , 'control_sd15_hed_slimmed.safetensors'),
        ('mlsd'              , 'mlsd'    , 'control_sd15_mlsd_slimmed.safetensors'),
        ('mlsd_v11p'         , 'mlsd'    , 'control_v11p_sd15_mlsd_slimmed.safetensors'),
        ('normal'            , 'normal'  , 'control_sd15_normal_slimmed.safetensors'),
        ('openpose'          , 'openpose', 'control_sd15_openpose_slimmed.safetensors'),
        ('openpose_v11p'     , 'openpose', 'control_v11p_sd15_openpose_slimmed.safetensors'),
        ('scribble'          , 'scribble', 'control_sd15_scribble_slimmed.safetensors'),
        ('softedge_v11p'     , 'scribble', 'control_v11p_sd15_softedge_slimmed.safetensors'),
        ('seg'               , 'none'    , 'control_sd15_seg_slimmed.safetensors'),
        ('lineart_v11p'      , 'none'    , 'control_v11p_sd15_lineart_slimmed.safetensors'),
        ('lineart_anime_v11p', 'none'    , 'control_v11p_sd15s2_lineart_anime_slimmed.safetensors'),
    )
)

# Choices for the "Preprocess Type" dropdown; the selected value is forwarded
# to the ControlNet preprocessor as its ``type`` argument.
preprocess_method = [
    'canny', 'depth', 'hed', 'mlsd', 'normal',
    'openpose', 'openpose_withface', 'openpose_withfacehand',
    'scribble', 'none',
]

# Diffuser tag -> local path of its weights. Files are fetched at import time
# from the 'shi-labs/prompt-free-diffusion' Hugging Face Hub repo, under
# 'pretrained/pfd/diffuser/'.
diffuser_path = OrderedDict(
    (tag, hf_hub_download(
        'shi-labs/prompt-free-diffusion', 'pretrained/pfd/diffuser/' + filename))
    for tag, filename in (
        ('SD-v1.5'             , 'SD-v1-5.safetensors'),
        ('OpenJouney-v4'       , 'OpenJouney-v4.safetensors'),
        ('Deliberate-v2.0'     , 'Deliberate-v2-0.safetensors'),
        ('RealisticVision-v2.0', 'RealisticVision-v2-0.safetensors'),
        ('Anything-v4'         , 'Anything-v4.safetensors'),
        ('Oam-v3'              , 'AbyssOrangeMix-v3.safetensors'),
        ('Oam-v2'              , 'AbyssOrangeMix-v2.safetensors'),
    )
)

# Context-encoder (SeeCoder) tag -> local path of its weights, fetched from the
# same Hub repo under 'pretrained/pfd/seecoder/'.
ctxencoder_path = OrderedDict(
    (tag, hf_hub_download(
        'shi-labs/prompt-free-diffusion', 'pretrained/pfd/seecoder/' + filename))
    for tag, filename in (
        ('SeeCoder'      , 'seecoder-v1-0.safetensors'),
        ('SeeCoder-PA'   , 'seecoder-pa-v1-0.safetensors'),
        ('SeeCoder-Anime', 'seecoder-anime-v1-0.safetensors'),
    )
)

##########
# helper #
##########

def highlight_print(info):
    """Print *info* inside a box of '#' characters, with a blank line above and below."""
    bar = '#' * (len(info) + 4)
    for line in ('', bar, '# ' + info + ' #', bar, ''):
        print(line)

def load_sd_from_file(target):
    """Load a model state dict from a checkpoint file, mapping tensors to the CPU.

    The format is chosen by file extension (case-insensitive):

    * ``.ckpt``        -- torch pickle; its ``'state_dict'`` entry is returned.
    * ``.pth``         -- torch pickle that is itself the state dict.
    * ``.safetensors`` -- loaded with ``safetensors.torch.load_file``.

    Args:
        target: path to the checkpoint file.

    Returns:
        The state dict. It is wrapped in an ``OrderedDict`` for ``.safetensors``;
        for the other formats it is whatever ``torch.load`` produced.

    Raises:
        ValueError: if the extension is not one of the supported ones.
    """
    ext = osp.splitext(target)[-1].lower()
    if ext == '.ckpt':
        # NOTE: torch.load unpickles arbitrary objects; only use on trusted files.
        return torch.load(target, map_location='cpu')['state_dict']
    if ext == '.pth':
        return torch.load(target, map_location='cpu')
    if ext == '.safetensors':
        from safetensors.torch import load_file as stload
        return OrderedDict(stload(target, device='cpu'))
    # Raise explicitly: an ``assert`` is stripped under ``python -O``, after
    # which the function would fail later with a confusing UnboundLocalError.
    raise ValueError(
        "File type must be .ckpt or .pth or .safetensors, got {!r}".format(target))

########
# main #
########

class prompt_free_diffusion(object):
    """Stateful Prompt-Free Diffusion pipeline backing the Gradio callbacks.

    Owns one network built from the ``'pfd_seecoder_with_controlnet'`` config
    plus a DDIM sampler. The context encoder, diffuser and ControlNet weights
    are swapped in by tag. Tags are keys of the module-level
    ``ctxencoder_path``, ``diffuser_path`` and ``controlnet_path`` tables.
    """

    def __init__(self,
                 fp16=False,
                 tag_ctx=None,
                 tag_diffuser=None,
                 tag_ctl=None,):
        """Build the network, load the VAE and the initially selected weights.

        Args:
            fp16: if True, cast the network to half precision.
            tag_ctx: key of ``ctxencoder_path`` to load first.
            tag_diffuser: key of ``diffuser_path`` to load first.
            tag_ctl: key of ``controlnet_path`` to load first.
        """
        self.tag_ctx = tag_ctx
        self.tag_diffuser = tag_diffuser
        self.tag_ctl = tag_ctl
        self.strict_sd = True

        # Decide precision and device *before* any weights are loaded:
        # action_load_ctx() reads self.dtype / self.use_cuda when it builds the
        # SeeCoder-PA ``pe_layer``, so starting with tag_ctx='SeeCoder-PA'
        # previously raised AttributeError.
        self.dtype = torch.float16 if fp16 else torch.float32
        self.use_cuda = torch.cuda.is_available()

        cfgm = model_cfg_bank()('pfd_seecoder_with_controlnet')
        self.net = get_model()(cfgm)
        sdvae = hf_hub_download(
            'shi-labs/prompt-free-diffusion',
            'pretrained/pfd/vae/sd-v2-0-base-autokl.pth')
        # load_sd_from_file maps tensors to the CPU, so the checkpoint loads
        # regardless of the device it was saved from; the net is moved below.
        self.net.vae.load_state_dict(load_sd_from_file(sdvae))

        self.action_load_ctx(tag_ctx)
        self.action_load_diffuser(tag_diffuser)
        self.action_load_ctl(tag_ctl)

        if fp16:
            highlight_print('Running in FP16')
            self.net.ctx['image'].fp16 = True
            self.net = self.net.half()

        if self.use_cuda:
            self.net.to('cuda')

        self.net.eval()
        self.sampler = DDIMSampler(self.net)

        self.n_sample_image = n_sample_image
        self.ddim_steps = 50
        self.ddim_eta = 0.0
        self.image_latent_dim = 4

    def load_ctx(self, pretrained):
        """Load context-encoder weights from *pretrained*.

        Every entry of the current state dict whose key does not start with
        ``'ctx.'`` is merged in, so the strict load only replaces the
        ``ctx.*`` parameters.
        """
        sd = load_sd_from_file(pretrained)
        sd_extra = [(ki, vi) for ki, vi in self.net.state_dict().items()
                    if not ki.startswith('ctx.')]
        sd.update(OrderedDict(sd_extra))

        self.net.load_state_dict(sd, strict=True)
        print('Load context encoder from [{}] strict [{}].'.format(pretrained, True))

    def load_diffuser(self, pretrained):
        """Load diffuser weights from *pretrained*, keeping all non-``diffuser.*`` weights."""
        sd = load_sd_from_file(pretrained)
        # Checkpoints without 'diffuser.image.context_blocks.*' keys store those
        # blocks under 'diffuser.text.context_blocks.*'; rename them to match
        # this network.
        if not any(ki.startswith('diffuser.image.context_blocks.') for ki in sd):
            sd = OrderedDict(
                (ki.replace('diffuser.text.context_blocks.',
                            'diffuser.image.context_blocks.'), vi)
                for ki, vi in sd.items())
        sd_extra = [(ki, vi) for ki, vi in self.net.state_dict().items()
                    if not ki.startswith('diffuser.')]
        sd.update(OrderedDict(sd_extra))
        self.net.load_state_dict(sd, strict=True)
        print('Load diffuser from [{}] strict [{}].'.format(pretrained, True))

    def load_ctl(self, pretrained):
        """Load ControlNet weights from *pretrained* into ``self.net.ctl`` (strict)."""
        sd = load_sd_from_file(pretrained)
        self.net.ctl.load_state_dict(sd, strict=True)
        print('Load controlnet from [{}] strict [{}].'.format(pretrained, True))

    def action_load_ctx(self, tag):
        """Switch the context encoder to *tag* (a ``ctxencoder_path`` key); return *tag*.

        'SeeCoder-PA' attaches a freshly built ``PPE_MLP`` as the encoder's
        ``qtransformer.pe_layer`` (matching ``self.dtype`` / CUDA placement).
        Every other tag sets it to None.
        """
        pretrained = ctxencoder_path[tag]
        if tag == 'SeeCoder-PA':
            from lib.model_zoo.seecoder import PPE_MLP
            pe_layer = \
                PPE_MLP(freq_num=20, freq_max=None, out_channel=768, mlp_layer=3)
            if self.dtype == torch.float16:
                pe_layer = pe_layer.half()
            if self.use_cuda:
                pe_layer.to('cuda')
            pe_layer.eval()
            self.net.ctx['image'].qtransformer.pe_layer = pe_layer
        else:
            self.net.ctx['image'].qtransformer.pe_layer = None
        if pretrained is not None:
            self.load_ctx(pretrained)
        self.tag_ctx = tag
        return tag

    def action_load_diffuser(self, tag):
        """Switch the diffuser to *tag* (a ``diffuser_path`` key); return *tag*."""
        pretrained = diffuser_path[tag]
        if pretrained is not None:
            self.load_diffuser(pretrained)
        self.tag_diffuser = tag
        return tag

    def action_load_ctl(self, tag):
        """Switch the ControlNet to *tag* (a ``controlnet_path`` key); return *tag*."""
        pretrained = controlnet_path[tag][1]
        if pretrained is not None:
            self.load_ctl(pretrained)
        self.tag_ctl = tag
        return tag

    def action_autoset_hw(self, imctl):
        """Derive the output ``(height, width)`` from a control image.

        Each side is rounded down to a multiple of 64 and clamped to
        [512, 1536]. When there is no image, (512, 512) is returned.
        """
        if imctl is None:
            return 512, 512
        w, h = imctl.size
        w = min(max(w // 64 * 64, 512), 1536)
        h = min(max(h // 64 * 64, 512), 1536)
        return h, w

    def action_autoset_method(self, tag):
        """Return the default preprocessor type registered for ControlNet *tag*."""
        return controlnet_path[tag][0]

    def action_inference(
            self, im, imctl, ctl_method, do_preprocess,
            h, w, ugscale, seed,
            tag_ctx, tag_diffuser, tag_ctl,):
        """Generate images from a reference image and a control image.

        Args:
            im: reference PIL image fed to the context encoder.
            imctl: control PIL image. It may only be None when
                ``tag_ctl == 'none'``.
            ctl_method: preprocessor type passed to ``self.net.ctl.preprocess``.
            do_preprocess: run the preprocessor on ``imctl`` before use.
            h, w: output size; each is rounded down to a multiple of 64.
            ugscale: unconditional guidance (CFG) scale.
            seed: RNG seed. For negative values NumPy is seeded from the wall
                clock and torch with ``-seed + 100``.
            tag_ctx, tag_diffuser, tag_ctl: model tags. Weights are reloaded
                when a tag differs from the currently loaded one.

        Returns:
            A list of PIL images: the generated samples, followed by the
            control maps fed to the ControlNet (if any).

        Raises:
            gr.Error: if a required input image is missing.
        """
        # Validate before swapping any weights so a bad request leaves the
        # loaded models untouched and the user sees a readable message.
        if im is None:
            raise gr.Error('Please provide an input image.')
        if imctl is None and tag_ctl != 'none':
            raise gr.Error('Please provide a control image for the selected ControlNet.')

        if tag_ctx != self.tag_ctx:
            self.action_load_ctx(tag_ctx)
        if tag_diffuser != self.tag_diffuser:
            self.action_load_diffuser(tag_diffuser)
        if tag_ctl != self.tag_ctl:
            self.action_load_ctl(tag_ctl)

        n_samples = self.n_sample_image

        sampler = self.sampler
        device = self.net.device

        # UI widgets may deliver floats. PIL resizing, the latent shape and
        # np.random.seed all need Python ints.
        h = int(h) // 64 * 64
        w = int(w) // 64 * 64
        seed = int(seed)
        if imctl is not None:
            imctl = imctl.resize([w, h], Image.Resampling.BICUBIC)

        craw = tvtrans.ToTensor()(im)[None].to(device).to(self.dtype)
        c = self.net.ctx_encode(craw, which='image').repeat(n_samples, 1, 1)
        u = torch.zeros_like(c)

        if tag_ctx in ["SeeCoder-Anime"]:
            # The anime encoder uses a stored unconditional embedding that is
            # zero-padded along dim 1 to the length of the conditional one.
            u = torch.load('assets/anime_ug.pth', map_location='cpu')[None].to(device).to(self.dtype)
            pad = c.size(1) - u.size(1)
            u = torch.cat([u, torch.zeros_like(u[:, 0:1].repeat(1, pad, 1))], axis=1)

        if tag_ctl != 'none':
            ccraw = tvtrans.ToTensor()(imctl)[None].to(device).to(self.dtype)
            if do_preprocess:
                cc = self.net.ctl.preprocess(ccraw, type=ctl_method, size=[h, w])
                cc = cc.to(self.dtype)
            else:
                cc = ccraw
        else:
            cc = None

        shape = [n_samples, self.image_latent_dim, h//8, w//8]

        if seed < 0:
            np.random.seed(int(time.time()))
            torch.manual_seed(-seed + 100)
        else:
            np.random.seed(seed + 100)
            torch.manual_seed(seed)

        x, _ = sampler.sample(
            steps=self.ddim_steps,
            x_info={'type':'image',},
            c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u,
                    'unconditional_guidance_scale':ugscale,
                    'control':cc,},
            shape=shape,
            verbose=False,
            eta=self.ddim_eta)

        ccout = [tvtrans.ToPILImage()(i) for i in cc] if cc is not None else []
        imout = self.net.vae_decode(x, which='image')
        imout = [tvtrans.ToPILImage()(i) for i in imout]
        return imout + ccout

# Single shared pipeline instance used by all Gradio callbacks below.
pfd_inference = prompt_free_diffusion(
    fp16=True,
    tag_ctx='SeeCoder',
    tag_diffuser='Deliberate-v2.0',
    tag_ctl='canny',
)

#################
# sub interface #
#################

# Forwarded to gr.Examples(cache_examples=...) in interface().
cache_examples = True

def get_example():
    """Return the rows shown in the ``gr.Examples`` table.

    Each row gives, in order: reference image, control image, preprocess type,
    preprocess flag, height, width, CFG scale, seed, context-encoder tag,
    diffuser tag and ControlNet tag. A new list is built on every call.
    """
    photo = 'assets/examples/'
    anime = 'assets/examples-anime/'

    def row(image, control, method, height, width, scale, seed, ctx, diffuser, ctl):
        # Every example passes its control image through unpreprocessed (flag False).
        return [image, control, method, False, height, width, scale, seed, ctx, diffuser, ctl]

    return [
        row(photo + 'ghibli-input.jpg', photo + 'ghibli-canny.png',
            'canny', 768, 1024, 1.8, 23, 'SeeCoder', 'Deliberate-v2.0', 'canny'),
        row(photo + 'astronautridinghouse-input.jpg', photo + 'astronautridinghouse-canny.png',
            'canny', 512, 768, 2.0, 21, 'SeeCoder', 'Deliberate-v2.0', 'canny'),
        row(photo + 'grassland-input.jpg', photo + 'grassland-scribble.png',
            'scribble', 768, 512, 2.0, 41, 'SeeCoder', 'Deliberate-v2.0', 'scribble'),
        row(photo + 'jeep-input.jpg', photo + 'jeep-depth.png',
            'depth', 512, 768, 2.0, 30, 'SeeCoder', 'Deliberate-v2.0', 'depth'),
        row(photo + 'bedroom-input.jpg', photo + 'bedroom-mlsd.png',
            'mlsd', 512, 512, 2.0, 31, 'SeeCoder', 'Deliberate-v2.0', 'mlsd'),
        row(photo + 'nightstreet-input.jpg', photo + 'nightstreet-canny.png',
            'canny', 768, 512, 2.3, 20, 'SeeCoder', 'Deliberate-v2.0', 'canny'),
        row(photo + 'woodcar-input.jpg', photo + 'woodcar-depth.png',
            'depth', 768, 512, 2.0, 20, 'SeeCoder', 'Deliberate-v2.0', 'depth'),
        row(anime + 'miku.jpg', anime + 'miku-canny.png',
            'canny', 768, 576, 1.5, 22, 'SeeCoder-Anime', 'Anything-v4', 'canny'),
        row(anime + 'random0.jpg', anime + 'pose.png',
            'openpose', 768, 1536, 2.0, 41, 'SeeCoder-Anime', 'Oam-v2', 'openpose_v11p'),
        row(anime + 'random1.jpg', anime + 'pose.png',
            'openpose', 768, 1536, 2.5, 28, 'SeeCoder-Anime', 'Oam-v2', 'openpose_v11p'),
        row(anime + 'camping.jpg', anime + 'pose.png',
            'openpose', 768, 1536, 2.0, 35, 'SeeCoder-Anime', 'Anything-v4', 'openpose_v11p'),
        row(anime + 'hanfu_girl.jpg', anime + 'pose.png',
            'openpose', 768, 1536, 2.0, 20, 'SeeCoder-Anime', 'Anything-v4', 'openpose_v11p'),
    ]

def interface():
    """Lay out the demo widgets and wire them to ``pfd_inference``.

    Intended to be called within a ``gr.Blocks`` context.
    """
    with gr.Row():
        with gr.Column():
            ref_image = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
            with gr.Row():
                width_slider = gr.Slider(label="Width", minimum=512, maximum=1536, value=512, step=64, visible=True)
                height_slider = gr.Slider(label="Height", minimum=512, maximum=1536, value=512, step=64, visible=True)
            with gr.Row():
                cfg_scale = gr.Slider(label="CFGScale", minimum=0, maximum=10, value=2, step=0.01, visible=True)
                seed_box = gr.Number(20, label="Seed", precision=0)
            with gr.Row():
                ctx_dropdown = gr.Dropdown(label='Context Encoder', choices=list(ctxencoder_path), value='SeeCoder')
                diffuser_dropdown = gr.Dropdown(label='Diffuser', choices=list(diffuser_path), value='Deliberate-v2.0')
            run_button = gr.Button("Run")
        with gr.Column():
            ctl_image = gr.Image(label='Control Input', type='pil', elem_id='customized_imbox')
            preprocess_box = gr.Checkbox(label='Preprocess', value=False)
            with gr.Row():
                method_dropdown = gr.Dropdown(label='Preprocess Type', choices=preprocess_method, value='canny')
                ctl_dropdown = gr.Dropdown(label='ControlNet', choices=list(controlnet_path), value='canny')
        with gr.Column():
            gallery = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image + 1)

    # Picking a ControlNet preselects its registered preprocessor type.
    ctl_dropdown.change(
        pfd_inference.action_autoset_method,
        inputs=[ctl_dropdown],
        outputs=[method_dropdown],)

    # A new control image sets height/width from its size (rounded down to a
    # multiple of 64 and clamped to [512, 1536]).
    ctl_image.change(
        pfd_inference.action_autoset_hw,
        inputs=[ctl_image],
        outputs=[height_slider, width_slider],)

    # Model swaps are not bound to dropdown change events: action_inference
    # reloads any weights whose tag differs from the loaded one on each run.

    # Positional order must match action_inference's parameters.
    inference_inputs = [
        ref_image, ctl_image, method_dropdown, preprocess_box,
        height_slider, width_slider, cfg_scale, seed_box,
        ctx_dropdown, diffuser_dropdown, ctl_dropdown,
    ]

    run_button.click(
        pfd_inference.action_inference,
        inputs=list(inference_inputs),
        outputs=[gallery])

    gr.Examples(
        label='Examples',
        examples=get_example(),
        fn=pfd_inference.action_inference,
        inputs=list(inference_inputs),
        outputs=[gallery],
        cache_examples=cache_examples,)

#############
# Interface #
#############

# Stylesheet passed to gr.Blocks(css=...). '#customized_imbox' is the elem_id
# that interface() gives to both image inputs and the result gallery. These
# rules give each of them, and several nested elements inside them, a 450px
# minimum height. The nested selectors depend on the DOM Gradio renders, so
# re-check them after a Gradio upgrade.
# NOTE(review): no component in interface() uses the '#myinst' or '#maskinst'
# ids, so those rules look unused in this file.
css = """
    #customized_imbox {
        min-height: 450px;
    }
    #customized_imbox>div[data-testid="image"] {
        min-height: 450px;
    }
    #customized_imbox>div[data-testid="image"]>div {
        min-height: 450px;
    }
    #customized_imbox>div[data-testid="image"]>iframe {
        min-height: 450px;
    }
    #customized_imbox>div.unpadded_box {
        min-height: 450px;
    }
    #myinst {
        font-size: 0.8rem; 
        margin: 0rem;
        color: #6B7280;
    }
    #maskinst {
        text-align: justify;
        min-width: 1200px;
    }
    #maskinst>img {
        min-width:399px;
        max-width:450px;
        vertical-align: top;
        display: inline-block;
    }
    #maskinst:after {
        content: "";
        width: 100%;
        display: inline-block;
    }
"""

# Assemble the page (title banner followed by the demo UI) and serve it.
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
            <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
            <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
                Prompt-Free Diffusion
            </h1>
            </div>
            """)
    interface()

# To bind a specific address/port instead, e.g.:
#   demo.launch(server_name="0.0.0.0", server_port=7992)
demo.launch(debug=True)