# %%
import argparse, os


import torch
import requests
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from torchvision import transforms as tfms
from diffusers import (
    StableDiffusionPipeline,
    DDIMScheduler,
    DiffusionPipeline,
    StableDiffusionXLPipeline,
)
from diffusers.image_processor import VaeImageProcessor
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import argparse
import PIL.Image as Image
from torchvision.utils import make_grid
import numpy
from diffusers.schedulers import DDIMScheduler
import torch.nn.functional as F
from models import attn_injection
from omegaconf import OmegaConf
from typing import List, Tuple

import omegaconf
import utils.exp_utils
import json

# Default device for every model and tensor in this script (no explicit
# index, i.e. the current CUDA device). The script assumes a GPU is present.
device = torch.device(type="cuda")


def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device):
    """Encode ``prompt`` with a single tokenizer / text-encoder pair.

    Args:
        prompt: text passed straight to ``tokenizer``. Callers in this file
            pass both plain strings and lists of strings.
        tokenizer: callable tokenizer exposing ``model_max_length``. It is
            asked to pad/truncate to that length and return PyTorch tensors.
        text_encoder: encoder called with ``output_hidden_states=True``.
        device: device the token ids are moved to before encoding.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``: the encoder's second-to-last
        hidden state and element ``[0]`` of its output. Both are replaced by
        zero tensors of the same shape when ``prompt == ""``.
    """
    token_ids = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids

    with torch.no_grad():
        encoder_output = text_encoder(
            token_ids.to(device),
            output_hidden_states=True,
        )

    pooled = encoder_output[0]
    penultimate_hidden = encoder_output.hidden_states[-2]

    # NOTE(review): only a bare "" string triggers zeroing. A list such as
    # ["", ""] compares unequal to "" and is encoded normally. Confirm this is
    # the intended unconditional embedding for the SDXL path.
    if prompt != "":
        return penultimate_hidden, pooled
    return torch.zeros_like(penultimate_hidden), torch.zeros_like(pooled)


def _encode_text_sdxl(model: StableDiffusionXLPipeline, prompt: str):
    """Encode ``prompt`` with both SDXL text encoders.

    Args:
        model: SDXL pipeline providing ``tokenizer``/``text_encoder`` and
            ``tokenizer_2``/``text_encoder_2``.
        prompt: a single prompt string or a list of prompt strings.

    Returns:
        ``(added_cond_kwargs, prompt_embeds)``.

        - ``prompt_embeds`` holds the penultimate hidden states of both
          encoders, concatenated along the last dimension.
        - ``added_cond_kwargs["text_embeds"]`` is element ``[0]`` of the
          second encoder's output.
        - ``added_cond_kwargs["time_ids"]`` holds the size/crop conditioning,
          with one row per prompt.
    """
    device = model._execution_device
    # The first encoder's pooled output is not used by SDXL conditioning.
    prompt_embeds, _ = _get_text_embeddings(
        prompt, model.tokenizer, model.text_encoder, device
    )
    prompt_embeds_2, pooled_prompt_embeds_2 = _get_text_embeddings(
        prompt, model.tokenizer_2, model.text_encoder_2, device
    )
    prompt_embeds = torch.cat((prompt_embeds, prompt_embeds_2), dim=-1)

    # A bare string is one prompt; len() of it would count characters and
    # produce more time-id rows than embedding rows.
    batch_size = 1 if isinstance(prompt, str) else len(prompt)

    text_encoder_projection_dim = model.text_encoder_2.config.projection_dim
    # NOTE(review): original/target size conditioning is fixed at 1024x1024
    # with a (0, 0) crop, although this script feeds 512x512 images; confirm.
    add_time_ids = model._get_add_time_ids(
        (1024, 1024), (0, 0), (1024, 1024), torch.float16, text_encoder_projection_dim
    ).to(device)
    # One copy of the time ids per prompt.
    add_time_ids = add_time_ids.repeat(batch_size, 1)
    added_cond_kwargs = {
        "text_embeds": pooled_prompt_embeds_2,
        "time_ids": add_time_ids,
    }
    return added_cond_kwargs, prompt_embeds


def _encode_text_sdxl_with_negative(
    model: StableDiffusionXLPipeline, prompt: List[str]
):
    """Encode ``prompt`` together with an equal-sized batch of empty prompts.

    Returns:
        ``(added_cond_kwargs, prompt_embeds)`` with the same structure as
        ``_encode_text_sdxl`` but with the batch doubled. The unconditional
        (empty-prompt) rows come first, followed by the conditional rows.
        This matches the ``noise_pred.chunk(2)`` split used by the samplers
        in this file.
    """
    cond_kwargs, cond_embeds = _encode_text_sdxl(model, prompt)
    uncond_kwargs, uncond_embeds = _encode_text_sdxl(model, [""] * len(prompt))

    merged_kwargs = {
        key: torch.cat((uncond_kwargs[key], cond_kwargs[key]))
        for key in ("text_embeds", "time_ids")
    }
    return merged_kwargs, torch.cat((uncond_embeds, cond_embeds))


# Sample function (regular DDIM)
@torch.no_grad()
def sample(
    pipe,
    prompt,
    start_step=0,
    start_latents=None,
    intermediate_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):
    """Denoise one shared start latent into one image per prompt with DDIM.

    Args:
        pipe: a ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``
            whose scheduler drives the denoising steps.
        prompt: a prompt string or a list of prompts; one image per prompt.
        start_step: index into ``pipe.scheduler.timesteps`` to start from.
        start_latents: latent repeated once per prompt. When ``None``, a
            random ``(1, 4, 64, 64)`` latent scaled by
            ``scheduler.init_noise_sigma`` is drawn.
        intermediate_latents: optional output of ``invert``. When given,
            batch element 0 is overwritten with an inverted latent before
            every step, so it acts as a reconstruction anchor.
        guidance_scale: classifier-free guidance weight.
        num_inference_steps: number of scheduler timesteps.
        num_images_per_prompt: forwarded to ``pipe._encode_prompt`` (SD only).
        do_classifier_free_guidance: run unconditional + conditional passes.
            NOTE(review): the SDXL encoding path always builds a doubled
            (uncond + cond) batch, so ``False`` presumably only works for SD.
        negative_prompt: a string broadcast to every prompt, or a list with
            one entry per prompt. Only the SD path uses it; the SDXL path
            always uses empty negative prompts.
        device: device for the scheduler timesteps and random latents.

    Returns:
        List of PIL images, one per prompt.

    Raises:
        TypeError: if ``pipe`` is neither an SD nor an SDXL pipeline.
    """
    # A bare string is a single prompt (len() would count characters).
    if isinstance(prompt, str):
        prompt = [prompt]
    # One negative prompt per positive prompt, as _encode_prompt expects.
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * len(prompt)

    # Encode prompt; the result holds [uncond, cond] rows when CFG is on.
    if isinstance(pipe, StableDiffusionPipeline):
        text_embeddings = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )
        added_cond_kwargs = None
    elif isinstance(pipe, StableDiffusionXLPipeline):
        added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
            pipe, prompt
        )
    else:
        raise TypeError(f"Unsupported pipeline type: {type(pipe).__name__}")

    # Set num inference steps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Create a random starting point if we don't have one already
    if start_latents is None:
        start_latents = torch.randn(1, 4, 64, 64, device=device)
        start_latents *= pipe.scheduler.init_noise_sigma

    # Every prompt starts from the same latent.
    latents = start_latents.repeat(len(prompt), 1, 1, 1)

    for i in tqdm(range(start_step, num_inference_steps)):
        if intermediate_latents is not None:
            # Element 0 is the reconstruction branch: pin it to the inverted
            # trajectory so its features can be shared with the other rows.
            # NOTE(review): the anchor index is ``1 - i`` (1, 0, -1, -2, ...
            # for i = 0, 1, 2, 3), while style_image_with_inversion picks the
            # start latent as ``inverted_latents[-(start_step + 1)]``. Confirm
            # this offset is intended and not a typo for ``-(i + 1)``.
            latents[0] = intermediate_latents[1 - i]
        t = pipe.scheduler.timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        # Perform guidance (unconditional half first, see encoding above)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    # Post-processing
    images = pipe.decode_latents(latents)
    images = pipe.numpy_to_pil(images)

    return images


# Sample function (regular DDIM), but disentangle the content and style
@torch.no_grad()
def sample_disentangled(
    pipe,
    prompt,
    start_step=0,
    start_latents=None,
    intermediate_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    use_content_anchor=True,
    negative_prompt="",
    device=device,
):
    """DDIM sampling with a content-anchor row and a freshly generated row.

    All rows start from ``start_latents``, except two:

    - Row 0: when ``use_content_anchor`` is set, it is overwritten with an
      inverted latent before every step.
    - Row 1: replaced by a fresh random latent.

    Args:
        pipe: a ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``.
        prompt: list of at least two prompts (a bare string counts as one).
        start_step: index into ``pipe.scheduler.timesteps`` to start from.
        start_latents: required start latent, e.g. one entry of ``invert``'s
            output with a leading batch dimension.
        intermediate_latents: output of ``invert``; required when
            ``use_content_anchor`` is true.
        guidance_scale: classifier-free guidance weight.
        num_inference_steps: number of scheduler timesteps.
        num_images_per_prompt: forwarded to ``pipe._encode_prompt`` (SD only).
        do_classifier_free_guidance: run unconditional + conditional passes.
        use_content_anchor: pin row 0 to the inverted trajectory.
        negative_prompt: a string broadcast to every prompt, or a per-prompt
            list. Only the SD path uses it.
        device: device for the scheduler timesteps and random latents.

    Returns:
        List of PIL images, one per prompt.

    Raises:
        ValueError: if ``start_latents`` is missing, fewer than two prompts
            are given, or anchoring is requested without
            ``intermediate_latents``.
        TypeError: if ``pipe`` is neither an SD nor an SDXL pipeline.
    """
    if isinstance(prompt, str):
        prompt = [prompt]
    if start_latents is None:
        raise ValueError("start_latents is required (e.g. a latent from invert()).")
    if len(prompt) < 2:
        raise ValueError(
            "sample_disentangled needs at least two prompts "
            "(content anchor + generated image)."
        )
    if use_content_anchor and intermediate_latents is None:
        raise ValueError(
            "use_content_anchor=True requires intermediate_latents from invert()."
        )
    # One negative prompt per positive prompt, as _encode_prompt expects.
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * len(prompt)

    # Encode prompt; the result holds [uncond, cond] rows when CFG is on.
    if isinstance(pipe, StableDiffusionPipeline):
        text_embeddings = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )
        added_cond_kwargs = None
    elif isinstance(pipe, StableDiffusionXLPipeline):
        added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
            pipe, prompt
        )
    else:
        raise TypeError(f"Unsupported pipeline type: {type(pipe).__name__}")

    # Set num inference steps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Fresh noise for the generative row, shaped like a single start latent so
    # the spatial size follows the input instead of being fixed at 64x64.
    generative_latent = torch.randn((1, *start_latents.shape[-3:]), device=device)
    generative_latent *= pipe.scheduler.init_noise_sigma

    latents = start_latents.repeat(len(prompt), 1, 1, 1)
    # Row 1 is generated from scratch instead of from the inverted latent.
    latents[1] = generative_latent

    for i in tqdm(range(start_step, num_inference_steps), desc="Stylizing"):
        if use_content_anchor:
            # Row 0 follows the inverted trajectory (reconstruction anchor).
            # NOTE(review): the anchor index is ``1 - i`` (1, 0, -1, -2, ...
            # for i = 0, 1, 2, 3), while style_image_with_inversion picks the
            # start latent as ``inverted_latents[-(start_step + 1)]``. Confirm
            # this offset is intended and not a typo for ``-(i + 1)``.
            latents[0] = intermediate_latents[1 - i]
        t = pipe.scheduler.timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        # Perform guidance (unconditional half first, see encoding above)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    # Decode with the VAE temporarily in float32; it is put back to float16
    # afterwards for SDXL.
    pipe.vae.to(dtype=torch.float32)
    latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
    latents = 1 / pipe.vae.config.scaling_factor * latents
    images = pipe.vae.decode(latents, return_dict=False)[0]
    # Map decoder output from [-1, 1] to [0, 1].
    images = (images / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    images = pipe.numpy_to_pil(images)
    if isinstance(pipe, StableDiffusionXLPipeline):
        pipe.vae.to(dtype=torch.float16)

    return images


## Inversion
@torch.no_grad()
def invert(
    pipe,
    start_latents,
    prompt,
    guidance_scale=3.5,
    num_inference_steps=50,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):
    """Run DDIM inversion from clean image latents towards noise.

    Args:
        pipe: a ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``
            whose scheduler provides ``timesteps`` and ``alphas_cumprod``.
        start_latents: VAE latents of the input image (cast to float16 on the
            SDXL path).
        prompt: a single prompt string describing the input image.
        guidance_scale: classifier-free guidance weight.
        num_inference_steps: number of scheduler timesteps; must be >= 3.
        num_images_per_prompt: forwarded to ``pipe._encode_prompt`` (SD only).
        do_classifier_free_guidance: run unconditional + conditional passes.
        negative_prompt: negative prompt for the SD path.
        device: device for the scheduler timesteps.

    Returns:
        ``torch.cat`` of the ``num_inference_steps - 2`` latents produced by
        the inversion loop, ordered by increasing timestep.

    Raises:
        ValueError: if ``num_inference_steps < 3`` (no latents would be kept).
        TypeError: if ``pipe`` is neither an SD nor an SDXL pipeline.
    """
    if num_inference_steps < 3:
        raise ValueError(
            f"num_inference_steps must be >= 3 for inversion, got {num_inference_steps}"
        )

    # Encode prompt
    if isinstance(pipe, StableDiffusionPipeline):
        text_embeddings = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )
        added_cond_kwargs = None
        latents = start_latents.clone().detach()
    elif isinstance(pipe, StableDiffusionXLPipeline):
        added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
            pipe, [prompt]
        )
        latents = start_latents.clone().detach().half()
    else:
        raise TypeError(f"Unsupported pipeline type: {type(pipe).__name__}")

    # Set num inference steps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Walk the schedule backwards: from the least to the most noisy timestep.
    timesteps = reversed(pipe.scheduler.timesteps)
    # Spacing between consecutive inference timesteps on the training schedule.
    step_size = pipe.scheduler.config.num_train_timesteps // num_inference_steps

    # We'll keep a list of the inverted latents as the process goes on
    intermediate_latents = []

    # The first and the last reversed timesteps are skipped.
    for i in tqdm(range(1, num_inference_steps - 1), desc="DDIM Inversion"):
        t = timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        # Perform guidance (unconditional half first)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        current_t = max(0, t.item() - step_size)  # previous (less noisy) level
        next_t = t  # level being inverted to
        alpha_t = pipe.scheduler.alphas_cumprod[current_t]
        alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]

        # Inverted DDIM step: estimate the clean latent from the current level,
        # then re-noise it to the next level with the same noise prediction.
        latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (
            alpha_t_next.sqrt() / alpha_t.sqrt()
        ) + (1 - alpha_t_next).sqrt() * noise_pred

        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)


def style_image_with_inversion(
    pipe,
    input_image,
    input_image_prompt,
    style_prompt,
    num_steps=100,
    start_step=30,
    guidance_scale=3.5,
    disentangle=False,
    share_attn=False,
    share_cross_attn=False,
    share_resnet_layers=(0, 1),
    share_attn_layers=(),
    c2s_layers=(0, 1),
    share_key=True,
    share_query=True,
    share_value=False,
    use_adain=True,
    use_content_anchor=True,
    output_dir: str = None,
    resnet_mode: str = None,
    return_intermediate=False,
    intermediate_latents=None,
):
    """Stylize ``input_image`` via DDIM inversion plus attention injection.

    Steps:

    1. Encode the image with the VAE.
    2. Invert it, unless ``intermediate_latents`` is supplied.
    3. Register the ``attn_injection`` processors.
    4. Sample one image per entry of ``style_prompt`` with ``sample`` or
       ``sample_disentangled``.
    5. Unset the processors again, even if sampling raises.

    Args:
        pipe: SD or SDXL pipeline.
        input_image: image tensor; presumably values in [0, 1] (it is mapped
            with ``* 2 - 1`` before encoding). TODO confirm.
        input_image_prompt: prompt used for inversion.
        style_prompt: list of prompts; one output image per entry.
        num_steps / start_step / guidance_scale: DDIM sampling controls.
        disentangle: use ``sample_disentangled`` and "artist" attention mode.
        share_resnet_layers / share_attn_layers / c2s_layers: layer indices
            forwarded to ``register_attention_processors``. ``None`` is
            forwarded unchanged.
        Remaining share_* / use_* / resnet_mode / output_dir arguments are
            forwarded to ``register_attention_processors`` or the sampler.
        return_intermediate: also return the inverted latents.
        intermediate_latents: precomputed ``invert`` output to reuse.

    Returns:
        The list of images, or ``(images, inverted_latents)`` when
        ``return_intermediate`` is true.
    """
    # Tuple defaults cannot be mutated across calls. Convert them (and any
    # tuple a caller passes) to fresh lists for the attention-injection code.
    share_resnet_layers, share_attn_layers, c2s_layers = (
        list(layers) if isinstance(layers, tuple) else layers
        for layers in (share_resnet_layers, share_attn_layers, c2s_layers)
    )

    with torch.no_grad():
        pipe.vae.to(dtype=torch.float32)
        try:
            latent_dist = pipe.vae.encode(input_image.to(device) * 2 - 1).latent_dist
            image_latents = pipe.vae.config.scaling_factor * latent_dist.sample()
        finally:
            # Restore the half-precision VAE for SDXL even if encoding fails.
            if isinstance(pipe, StableDiffusionXLPipeline):
                pipe.vae.to(dtype=torch.float16)

    if intermediate_latents is None:
        inverted_latents = invert(
            pipe, image_latents, input_image_prompt, num_inference_steps=num_steps
        )
    else:
        inverted_latents = intermediate_latents

    attn_injection.register_attention_processors(
        pipe,
        base_dir=output_dir,
        resnet_mode=resnet_mode,
        attn_mode="artist" if disentangle else "pnp",
        disentangle=disentangle,
        share_resblock=True,
        share_attn=share_attn,
        share_cross_attn=share_cross_attn,
        share_resnet_layers=share_resnet_layers,
        share_attn_layers=share_attn_layers,
        share_key=share_key,
        share_query=share_query,
        share_value=share_value,
        use_adain=use_adain,
        c2s_layers=c2s_layers,
    )
    try:
        sampler_kwargs = dict(
            start_latents=inverted_latents[-(start_step + 1)][None],
            intermediate_latents=inverted_latents,
            start_step=start_step,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
        )
        if disentangle:
            final_im = sample_disentangled(
                pipe,
                style_prompt,
                use_content_anchor=use_content_anchor,
                **sampler_kwargs,
            )
        else:
            final_im = sample(pipe, style_prompt, **sampler_kwargs)
    finally:
        # Always unset the injected processors so the shared pipeline is not
        # left modified when sampling raises.
        attn_injection.unset_attention_processors(
            pipe,
            unset_share_attn=True,
            unset_share_resblock=True,
        )

    if return_intermediate:
        return final_im, inverted_latents
    return final_im


if __name__ == "__main__":
    # Parse the command line before loading the model, so that `--help` and
    # invalid arguments fail immediately instead of after a model download.
    parser = argparse.ArgumentParser(description="Stable Diffusion with OmegaConf")
    parser.add_argument(
        "--config", type=str, default="config.yaml", help="Path to the config file"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="dataset",
        choices=["dataset", "cli", "app"],
        help="Run mode: 'dataset' (batch over the config's annotation file), "
        "'cli' (single image) or 'app' (Gradio demo)",
    )
    parser.add_argument(
        "--image_dir", type=str, default="test.png", help="Path to the image"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="an impressionist painting",
        help="Stylization prompt",
    )
    args = parser.parse_args()
    config_dir = args.config
    mode = args.mode

    # Load a pipeline
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base"
    ).to(device)

    # pipe = DiffusionPipeline.from_pretrained(
    #     # "playgroundai/playground-v2-1024px-aesthetic",
    #     torch_dtype=torch.float16,
    #     use_safetensors=True,
    #     add_watermarker=False,
    #     variant="fp16",
    # )
    # pipe.to("cuda")

    # Set up a DDIM scheduler
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # File-name suffix for each output image; index j matches prompt_in[j].
    out_name = ["content_delegation", "style_delegation", "style_out"]

    if mode == "dataset":
        cfg = OmegaConf.load(config_dir)

        os.makedirs(cfg.out_path, exist_ok=True)
        base_output_path = os.path.join(cfg.out_path, cfg.exp_name)

        experiment_output_path = utils.exp_utils.make_unique_experiment_path(
            base_output_path
        )

        # Save the experiment configuration
        config_file_path = os.path.join(experiment_output_path, "config.yaml")
        omegaconf.OmegaConf.save(cfg, config_file_path)

        with open(cfg.annotation, encoding="utf-8") as f:
            annotation = json.load(f)
        with open(os.path.join(experiment_output_path, "annotation.json"), "w") as f:
            json.dump(annotation, f)

        # Both SD and SDXL are run at 512x512 in this script.
        resolution = 512
        for i, entry in enumerate(annotation):
            # Re-seed per entry so every image uses the same random state.
            utils.exp_utils.seed_all(cfg.seed)
            image_path = entry["image_path"]
            src_prompt = entry["source_prompt"]
            tgt_prompt = entry["target_prompt"]
            input_image = utils.exp_utils.get_processed_image(
                image_path, device, resolution
            )

            prompt_in = [
                src_prompt,  # reconstruction
                tgt_prompt,  # uncontrolled style
                "",  # controlled style
            ]

            imgs = style_image_with_inversion(
                pipe,
                input_image,
                src_prompt,
                style_prompt=prompt_in,
                num_steps=cfg.num_steps,
                start_step=cfg.start_step,
                guidance_scale=cfg.style_cfg_scale,
                disentangle=cfg.disentangle,
                resnet_mode=cfg.resnet_mode,
                share_attn=cfg.share_attn,
                share_cross_attn=cfg.share_cross_attn,
                share_resnet_layers=cfg.share_resnet_layers,
                share_attn_layers=cfg.share_attn_layers,
                share_key=cfg.share_key,
                share_query=cfg.share_query,
                share_value=cfg.share_value,
                use_content_anchor=cfg.use_content_anchor,
                use_adain=cfg.use_adain,
                output_dir=experiment_output_path,
            )

            for j, img in enumerate(imgs):
                out_file = f"{experiment_output_path}/out_{i}_{out_name[j]}.png"
                img.save(out_file)
                print(f"Image saved as {out_file}")
    elif mode == "cli":
        cfg = OmegaConf.load(config_dir)
        utils.exp_utils.seed_all(cfg.seed)
        image = utils.exp_utils.get_processed_image(args.image_dir, device, 512)
        tgt_prompt = args.prompt
        src_prompt = ""
        prompt_in = [
            "",  # reconstruction
            tgt_prompt,  # uncontrolled style
            "",  # controlled style
        ]
        out_dir = "./out"
        os.makedirs(out_dir, exist_ok=True)
        imgs = style_image_with_inversion(
            pipe,
            image,
            src_prompt,
            style_prompt=prompt_in,
            num_steps=cfg.num_steps,
            start_step=cfg.start_step,
            guidance_scale=cfg.style_cfg_scale,
            disentangle=cfg.disentangle,
            resnet_mode=cfg.resnet_mode,
            share_attn=cfg.share_attn,
            share_cross_attn=cfg.share_cross_attn,
            share_resnet_layers=cfg.share_resnet_layers,
            share_attn_layers=cfg.share_attn_layers,
            share_key=cfg.share_key,
            share_query=cfg.share_query,
            share_value=cfg.share_value,
            use_content_anchor=cfg.use_content_anchor,
            use_adain=cfg.use_adain,
            output_dir=out_dir,
        )
        # splitext only strips the final extension ("a.b.png" -> "a.b").
        image_base_name = os.path.splitext(os.path.basename(args.image_dir))[0]
        for j, img in enumerate(imgs):
            out_file = f"{out_dir}/{image_base_name}_out_{out_name[j]}.png"
            img.save(out_file)
            print(f"Image saved as {out_file}")
    elif mode == "app":
        # gradio
        import gradio as gr

        def style_transfer_app(
            prompt,
            image,
            cfg_scale=7.5,
            num_content_layers=4,
            num_style_layers=9,
            seed=0,
            progress=gr.Progress(track_tqdm=True),
        ):
            """Gradio callback: stylize ``image`` and return the controlled output."""
            utils.exp_utils.seed_all(seed)
            image = utils.exp_utils.process_image(image, device, 512)

            tgt_prompt = prompt
            src_prompt = ""
            prompt_in = [
                "",  # reconstruction
                tgt_prompt,  # uncontrolled style
                "",  # controlled style
            ]

            # A zero-layer slider value is passed on as None.
            share_resnet_layers = (
                list(range(num_content_layers)) if num_content_layers != 0 else None
            )
            share_attn_layers = (
                list(range(num_style_layers)) if num_style_layers != 0 else None
            )
            imgs = style_image_with_inversion(
                pipe,
                image,
                src_prompt,
                style_prompt=prompt_in,
                num_steps=50,
                start_step=0,
                guidance_scale=cfg_scale,
                disentangle=True,
                resnet_mode="hidden",
                share_attn=True,
                share_cross_attn=True,
                share_resnet_layers=share_resnet_layers,
                share_attn_layers=share_attn_layers,
                share_key=True,
                share_query=True,
                share_value=False,
                use_content_anchor=True,
                use_adain=True,
                output_dir="./",
            )

            # Index 2 is the "controlled style" output.
            return imgs[2]

        # load examples
        examples = []
        with open("data/example/annotation.json", encoding="utf-8") as f:
            annotation = json.load(f)
        for entry in annotation:
            image = utils.exp_utils.get_processed_image(
                entry["image_path"], device, 512
            )
            image = transforms.ToPILImage()(image[0])

            examples.append([entry["target_prompt"], image, None, None, None])

        text_input = gr.Textbox(
            value="An impressionist painting",
            label="Text Prompt",
            info="Describe the style you want to apply to the image, do not include the description of the image content itself",
            lines=2,
            placeholder="Enter a text prompt",
        )
        image_input = gr.Image(
            height="80%",
            width="80%",
            label="Content image (will be resized to 512x512)",
            interactive=True,
        )
        cfg_slider = gr.Slider(
            0,
            15,
            value=7.5,
            label="Classifier Free Guidance (CFG) Scale",
            info="higher values give more style, 7.5 should be good for most cases",
        )
        content_slider = gr.Slider(
            0,
            9,
            value=4,
            step=1,
            label="Number of content control layer",
            info="higher values make it more similar to original image. Default to control first 4 layers",
        )
        style_slider = gr.Slider(
            0,
            9,
            value=9,
            step=1,
            label="Number of style control layer",
            info="higher values make it more similar to target style. Default to control first 9 layers, usually not necessary to change.",
        )
        seed_slider = gr.Slider(
            0,
            100,
            value=0,
            step=1,
            label="Seed",
            info="Random seed for the model",
        )
        app = gr.Interface(
            fn=style_transfer_app,
            inputs=[
                text_input,
                image_input,
                cfg_slider,
                content_slider,
                style_slider,
                seed_slider,
            ],
            outputs=["image"],
            title="Artist Interactive Demo",
            examples=examples,
        )
        app.launch()