import spaces
import argparse, os

import torch
import requests
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from torchvision import transforms as tfms
from diffusers import (
    StableDiffusionPipeline,
    DDIMScheduler,
    DiffusionPipeline,
    StableDiffusionXLPipeline,
)
from diffusers.image_processor import VaeImageProcessor
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import argparse
import PIL.Image as Image
from torchvision.utils import make_grid
import numpy
from diffusers.schedulers import DDIMScheduler
import torch.nn.functional as F
from models import attn_injection
from omegaconf import OmegaConf
from typing import List, Tuple

import omegaconf
import utils.exp_utils
import json

# Device used for model placement and tensor creation throughout this module.
# `sample`, `sample_disentangled` and `invert` bind it as their default
# `device` argument when they are defined, so reassigning it later does not
# change those defaults.
device = "cuda"


def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device):
    """Encode ``prompt`` with a single tokenizer / text-encoder pair.

    The prompt is padded/truncated to ``tokenizer.model_max_length`` and run
    through ``text_encoder`` without gradients.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``: the penultimate hidden state
        and element 0 of the encoder output. When ``prompt`` is exactly the
        string ``""`` both are replaced by zero tensors of the same shape.
        A *list* of empty strings does not trigger that replacement.
    """
    token_ids = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)

    with torch.no_grad():
        encoder_out = text_encoder(token_ids, output_hidden_states=True)

    pooled = encoder_out[0]
    hidden = encoder_out.hidden_states[-2]

    if prompt != "":
        return hidden, pooled
    # Empty string prompt: all-zero embeddings stand in for the encoder output.
    return torch.zeros_like(hidden), torch.zeros_like(pooled)


def _encode_text_sdxl(model: StableDiffusionXLPipeline, prompt: str):
    """Encode ``prompt`` with both SDXL text encoders.

    Args:
        model: SDXL pipeline providing ``tokenizer``/``text_encoder`` and
            ``tokenizer_2``/``text_encoder_2``.
        prompt: a single prompt string or a list of prompt strings.

    Returns:
        ``(added_cond_kwargs, prompt_embeds)``.
        - ``prompt_embeds`` concatenates (on the last dim) the penultimate
          hidden states of both encoders.
        - ``added_cond_kwargs["text_embeds"]`` is the second encoder's pooled
          output.
        - ``added_cond_kwargs["time_ids"]`` is built from the hard-coded sizes
          ``(1024, 1024)`` / ``(0, 0)`` / ``(1024, 1024)``, with one row per
          prompt.
    """
    device = model._execution_device
    # Only the hidden states of the first encoder are used; its pooled output is not.
    prompt_embeds, _ = _get_text_embeddings(
        prompt, model.tokenizer, model.text_encoder, device
    )
    prompt_embeds_2, pooled_prompt_embeds_2 = _get_text_embeddings(
        prompt, model.tokenizer_2, model.text_encoder_2, device
    )
    prompt_embeds = torch.cat((prompt_embeds, prompt_embeds_2), dim=-1)

    text_encoder_projection_dim = model.text_encoder_2.config.projection_dim
    add_time_ids = model._get_add_time_ids(
        (1024, 1024), (0, 0), (1024, 1024), torch.float16, text_encoder_projection_dim
    ).to(device)
    # One row of time ids per prompt. A bare string is a single prompt; taking
    # len() of it would count characters and mismatch the embedding batch.
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    add_time_ids = add_time_ids.repeat(batch_size, 1)
    added_cond_kwargs = {
        "text_embeds": pooled_prompt_embeds_2,
        "time_ids": add_time_ids,
    }
    return added_cond_kwargs, prompt_embeds


def _encode_text_sdxl_with_negative(
    model: StableDiffusionXLPipeline, prompt: List[str]
):
    """Encode ``prompt`` together with a same-sized batch of empty prompts.

    Every returned tensor stacks the unconditional (empty-prompt) batch first
    and the conditional batch second. This is the order in which the sampling
    loops split predictions with ``noise_pred.chunk(2)``.

    Returns:
        ``(added_cond_kwargs, prompt_embeds)``, structured like the output of
        ``_encode_text_sdxl`` but with the batch dimension doubled.
    """
    cond_kwargs, cond_embeds = _encode_text_sdxl(model, prompt)
    uncond_kwargs, uncond_embeds = _encode_text_sdxl(model, [""] * len(prompt))

    def _stack(key):
        # Unconditional rows first, conditional rows second.
        return torch.cat((uncond_kwargs[key], cond_kwargs[key]))

    stacked_kwargs = {
        "text_embeds": _stack("text_embeds"),
        "time_ids": _stack("time_ids"),
    }
    return stacked_kwargs, torch.cat((uncond_embeds, cond_embeds))


# Sample function (regular DDIM)
@torch.no_grad()
def sample(
    pipe,
    prompt,
    start_step=0,
    start_latents=None,
    intermediate_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):
    """Plain DDIM sampling for a batch of prompts, with batch item 0 anchored.

    Every prompt starts from ``start_latents``. At each step batch item 0 (the
    reconstruction branch) is overwritten with the matching latent from
    ``intermediate_latents``, so it follows the inversion trajectory.

    Args:
        pipe: ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``.
        prompt: list of prompts, one per batch element.
        start_step: index into the scheduler timesteps at which to start.
        start_latents: initial latent with a leading batch dim of 1. If
            ``None``, a random 1x4x64x64 latent is drawn.
        intermediate_latents: trajectory returned by ``invert``. If ``None``,
            batch item 0 is not anchored.
        guidance_scale: classifier-free guidance weight.
        num_inference_steps: number of scheduler timesteps.
        num_images_per_prompt: forwarded to ``pipe._encode_prompt`` (SD only).
        do_classifier_free_guidance: whether to run the unconditional branch.
        negative_prompt: ignored; an empty negative prompt is always used.
        device: device for the timesteps and the random start latent.

    Returns:
        List of PIL images, one per prompt.

    Raises:
        TypeError: if ``pipe`` is neither supported pipeline type.
    """
    # The negative_prompt argument is deliberately overridden with empty strings.
    negative_prompt = [""] * len(prompt)

    # Encode prompt
    if isinstance(pipe, StableDiffusionPipeline):
        text_embeddings = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )
        added_cond_kwargs = None
    elif isinstance(pipe, StableDiffusionXLPipeline):
        added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
            pipe, prompt
        )
    else:
        raise TypeError(
            "sample() supports StableDiffusionPipeline and "
            f"StableDiffusionXLPipeline, got {type(pipe).__name__}"
        )

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Create a random starting point if we don't have one already
    if start_latents is None:
        start_latents = torch.randn(1, 4, 64, 64, device=device)
        start_latents *= pipe.scheduler.init_noise_sigma

    latents = start_latents.clone().repeat(len(prompt), 1, 1, 1)

    num_anchor_latents = (
        len(intermediate_latents) if intermediate_latents is not None else 0
    )
    for i in tqdm(range(start_step, num_inference_steps)):
        if num_anchor_latents > 0:
            # For a trajectory from invert() with the same num_inference_steps,
            # entry -i is the latent produced at timesteps[i]. Step 0 has no
            # exact match, so the noisiest stored latent (-1) is reused; indices
            # past the start of the trajectory clamp to entry 0.
            anchor_idx = min(max(i, 1), num_anchor_latents)
            latents[0] = intermediate_latents[-anchor_idx]

        t = pipe.scheduler.timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        # Perform guidance (unconditional half first, see the prompt encoders)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    # Post-processing
    images = pipe.decode_latents(latents)
    return pipe.numpy_to_pil(images)


# Sample function (regular DDIM), but disentangle the content and style
@torch.no_grad()
def sample_disentangled(
    pipe,
    prompt,
    start_step=0,
    start_latents=None,
    intermediate_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    use_content_anchor=True,
    negative_prompt="",
    device=device,
):
    """DDIM sampling with a content-anchored branch and a freely generated branch.

    All batch items start from ``start_latents`` except item 1, which starts
    from fresh Gaussian noise. When ``use_content_anchor`` is set, item 0 is
    overwritten at every step with the matching latent from
    ``intermediate_latents``, so it follows the inversion trajectory. The final
    latents are decoded with the VAE temporarily in float32.

    Args:
        pipe: ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``.
        prompt: list of prompts, one per batch element. The generative branch
            (item 1) exists only when there are at least two.
        start_step: index into the scheduler timesteps at which to start.
        start_latents: initial latent with a leading batch dim of 1. If
            ``None``, the random generative latent is used for every item.
        intermediate_latents: trajectory returned by ``invert``; required for
            the content anchor.
        guidance_scale: classifier-free guidance weight.
        num_inference_steps: number of scheduler timesteps.
        num_images_per_prompt: forwarded to ``pipe._encode_prompt`` (SD only).
        do_classifier_free_guidance: whether to run the unconditional branch.
        use_content_anchor: pin batch item 0 to ``intermediate_latents``.
        negative_prompt: ignored; an empty negative prompt is always used.
        device: device for the timesteps and the random latent.

    Returns:
        List of PIL images, one per prompt.

    Raises:
        TypeError: if ``pipe`` is neither supported pipeline type.
    """
    # The negative_prompt argument is deliberately overridden with empty strings.
    negative_prompt = [""] * len(prompt)

    # Encode prompt
    if isinstance(pipe, StableDiffusionPipeline):
        text_embeddings = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )
        added_cond_kwargs = None
    elif isinstance(pipe, StableDiffusionXLPipeline):
        added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
            pipe, prompt
        )
    else:
        raise TypeError(
            "sample_disentangled() supports StableDiffusionPipeline and "
            f"StableDiffusionXLPipeline, got {type(pipe).__name__}"
        )

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Random latent for the free-generation branch (batch item 1). Its size
    # follows start_latents, so SDXL / non-512px latents fit the batch; a
    # 1x4x64x64 start latent yields exactly the previous draw.
    if start_latents is None:
        latent_shape = (1, 4, 64, 64)
    else:
        latent_shape = (1, *start_latents.shape[1:])
    generative_latent = torch.randn(latent_shape, device=device)
    generative_latent *= pipe.scheduler.init_noise_sigma

    if start_latents is None:
        start_latents = generative_latent
    latents = start_latents.clone().repeat(len(prompt), 1, 1, 1)
    if latents.shape[0] > 1:
        latents[1] = generative_latent[0]

    num_intermediate_latents = (
        len(intermediate_latents) if intermediate_latents is not None else 0
    )
    anchor_content = use_content_anchor and num_intermediate_latents > 0

    for i in range(start_step, num_inference_steps):
        if anchor_content:
            # For a trajectory from invert() with the same num_inference_steps,
            # entry -i is the latent produced at timesteps[i]. At i == 0, -0
            # would wrap to entry 0 (the least noisy latent), so the noisiest
            # entry (-1) is used instead. Indices past the start of the
            # trajectory clamp to entry 0.
            anchor_idx = min(max(i, 1), num_intermediate_latents)
            latents[0] = intermediate_latents[-anchor_idx]

        t = pipe.scheduler.timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        # Perform guidance (unconditional half first, see the prompt encoders)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    # Decode with the VAE in float32. SDXL's VAE is put back to fp16 afterwards,
    # even if decoding fails.
    pipe.vae.to(dtype=torch.float32)
    try:
        vae_dtype = next(iter(pipe.vae.post_quant_conv.parameters())).dtype
        latents = 1 / pipe.vae.config.scaling_factor * latents.to(vae_dtype)
        images = pipe.vae.decode(latents, return_dict=False)[0]
    finally:
        if isinstance(pipe, StableDiffusionXLPipeline):
            pipe.vae.to(dtype=torch.float16)

    images = (images / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    return pipe.numpy_to_pil(images)



## Inversion
@torch.no_grad()
def invert(
    pipe,
    start_latents,
    prompt,
    guidance_scale=3.5,
    num_inference_steps=50,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):
    """DDIM-invert ``start_latents`` and return the intermediate latents.

    The scheduler's timesteps are walked in reverse order. At each step the DDIM
    update is rearranged to map the current latent to the one at the next
    timestep of that reversed walk.

    Args:
        pipe: ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``.
        start_latents: VAE-scaled latents of the image to invert.
        prompt: prompt string conditioning the inversion.
        guidance_scale: classifier-free guidance weight.
        num_inference_steps: number of scheduler timesteps; must be >= 2.
        num_images_per_prompt: forwarded to ``pipe._encode_prompt`` (SD only).
        do_classifier_free_guidance: whether to run the unconditional branch.
        negative_prompt: forwarded to ``pipe._encode_prompt`` (SD only).
        device: device for the scheduler timesteps.

    Returns:
        The ``num_inference_steps - 1`` latents produced by the loop,
        concatenated along dim 0 in production order. The last reversed
        timestep is not visited.

    Raises:
        TypeError: if ``pipe`` is neither supported pipeline type.
        ValueError: if ``num_inference_steps`` < 2 (no latents would be produced).
    """
    if num_inference_steps < 2:
        raise ValueError(
            f"invert() needs num_inference_steps >= 2, got {num_inference_steps}"
        )

    # Encode prompt
    if isinstance(pipe, StableDiffusionPipeline):
        text_embeddings = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )
        added_cond_kwargs = None
        latents = start_latents.clone().detach()
    elif isinstance(pipe, StableDiffusionXLPipeline):
        added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
            pipe, [prompt]
        )
        # The SDXL path runs the latents in fp16.
        latents = start_latents.clone().detach().half()
    else:
        raise TypeError(
            "invert() supports StableDiffusionPipeline and "
            f"StableDiffusionXLPipeline, got {type(pipe).__name__}"
        )

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    # Walk the sampling schedule backwards (inversion direction).
    timesteps = list(reversed(pipe.scheduler.timesteps))
    # Offset from t back to the previous timestep of the walk. The original code
    # hard-coded 1000 // num_inference_steps, which is what this evaluates to
    # for a 1000-step training schedule.
    step_size = pipe.scheduler.config.num_train_timesteps // num_inference_steps

    intermediate_latents = []
    for t in timesteps[: num_inference_steps - 1]:
        # Expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        # Perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        current_t = max(0, t.item() - step_size)
        next_t = t
        alpha_t = pipe.scheduler.alphas_cumprod[current_t]
        alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]

        # Inverted DDIM step:
        # - estimate x0 from the current latent at alpha_t;
        # - re-noise it to alpha_t_next with the same noise prediction.
        latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (
            alpha_t_next.sqrt() / alpha_t.sqrt()
        ) + (1 - alpha_t_next).sqrt() * noise_pred

        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)







def style_image_with_inversion(
    pipe,
    input_image,
    input_image_prompt,
    style_prompt,
    num_steps=100,
    start_step=30,
    guidance_scale=3.5,
    disentangle=False,
    share_attn=False,
    share_cross_attn=False,
    share_resnet_layers=[0, 1],
    share_attn_layers=[],
    c2s_layers=[0, 1],
    share_key=True,
    share_query=True,
    share_value=False,
    use_adain=True,
    use_content_anchor=True,
    output_dir: str = None,
    resnet_mode: str = None,
    return_intermediate=False,
    intermediate_latents=None,
):
    """Stylize ``input_image`` via DDIM inversion plus feature sharing.

    Steps:
    1. Encode the image with the VAE.
    2. DDIM-invert it (skipped when ``intermediate_latents`` is given).
    3. Register the ``attn_injection`` processors on ``pipe``.
    4. Sample ``style_prompt`` from the inverted latent at ``start_step``.
    5. Unregister the processors, even if sampling raises, so a shared
       ``pipe`` is not left modified.

    Args:
        pipe: SD or SDXL pipeline whose processors are temporarily replaced.
        input_image: content image tensor. It is mapped with ``* 2 - 1`` before
            encoding, so presumably it is in [0, 1] (TODO confirm).
        input_image_prompt: prompt used for the inversion.
        style_prompt: list of prompts, one per sampled batch element.
        num_steps: number of DDIM steps for both inversion and sampling.
        start_step: sampling step to start from.
        guidance_scale: classifier-free guidance weight for sampling.
        disentangle: use ``sample_disentangled`` with attention mode "artist"
            instead of ``sample`` with "pnp".
        share_attn, share_cross_attn, share_resnet_layers, share_attn_layers,
            c2s_layers, share_key, share_query, share_value, use_adain,
            output_dir, resnet_mode: forwarded to
            ``attn_injection.register_attention_processors``.
        use_content_anchor: forwarded to ``sample_disentangled``.
        return_intermediate: also return the inverted latents.
        intermediate_latents: precomputed inversion trajectory to reuse.

    Returns:
        The sampler's list of images, or ``(images, inverted_latents)`` when
        ``return_intermediate`` is true.
    """
    with torch.no_grad():
        # Encode in float32. For SDXL the VAE goes back to fp16 afterwards,
        # even if encoding fails.
        pipe.vae.to(dtype=torch.float32)
        try:
            posterior = pipe.vae.encode(input_image.to(device) * 2 - 1)
            content_latent = (
                pipe.vae.config.scaling_factor * posterior.latent_dist.sample()
            )
        finally:
            if isinstance(pipe, StableDiffusionXLPipeline):
                pipe.vae.to(dtype=torch.float16)

    if intermediate_latents is None:
        inverted_latents = invert(
            pipe, content_latent, input_image_prompt, num_inference_steps=num_steps
        )
    else:
        inverted_latents = intermediate_latents

    def _fresh(layers):
        # Hand the callee its own copy of list arguments, so the default lists
        # in this signature can never be mutated across calls. Anything else
        # (e.g. None) is passed through unchanged.
        return list(layers) if isinstance(layers, list) else layers

    attn_injection.register_attention_processors(
        pipe,
        base_dir=output_dir,
        resnet_mode=resnet_mode,
        attn_mode="artist" if disentangle else "pnp",
        disentangle=disentangle,
        share_resblock=True,
        share_attn=share_attn,
        share_cross_attn=share_cross_attn,
        share_resnet_layers=_fresh(share_resnet_layers),
        share_attn_layers=_fresh(share_attn_layers),
        share_key=share_key,
        share_query=share_query,
        share_value=share_value,
        use_adain=use_adain,
        c2s_layers=_fresh(c2s_layers),
    )
    try:
        sampler_kwargs = dict(
            start_latents=inverted_latents[-(start_step + 1)][None],
            intermediate_latents=inverted_latents,
            start_step=start_step,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
        )
        if disentangle:
            final_im = sample_disentangled(
                pipe,
                style_prompt,
                use_content_anchor=use_content_anchor,
                **sampler_kwargs,
            )
        else:
            final_im = sample(pipe, style_prompt, **sampler_kwargs)
    finally:
        # Always restore the default processors, even when sampling fails.
        attn_injection.unset_attention_processors(
            pipe,
            unset_share_attn=True,
            unset_share_resblock=True,
        )

    if return_intermediate:
        return final_im, inverted_latents
    return final_im


if __name__ == "__main__":

    # pipe = DiffusionPipeline.from_pretrained(
    #     # "playgroundai/playground-v2-1024px-aesthetic",
    #     torch_dtype=torch.float16,
    #     use_safetensors=True,
    #     add_watermarker=False,
    #     variant="fp16",
    # )
    # pipe.to("cuda")

    # Command-line options; `--mode app` launches the Gradio demo.
    parser = argparse.ArgumentParser(description="Stable Diffusion with OmegaConf")
    parser.add_argument(
        "--config", type=str, default="config.yaml", help="Path to the config file"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="dataset",
        choices=["dataset", "cli", "app"],
        help="Run mode: dataset, cli, or app (interactive Gradio demo)",
    )
    parser.add_argument(
        "--image_dir", type=str, default="test.png", help="Path to the image"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="an impressionist painting",
        help="Stylization prompt",
    )
    # mode = "single_control_content"
    args = parser.parse_args()
    config_dir = args.config
    mode = args.mode
    # mode = "dataset"
    out_name = ["content_delegation", "style_delegation", "style_out"]

    if mode == "app":
        # gradio is only needed for the interactive demo, so it is imported
        # here rather than at module level.
        import gradio as gr

        # Stable Diffusion 2.1 base, with its scheduler replaced by a DDIM
        # scheduler built from the same config.
        sd_model_id = "stabilityai/stable-diffusion-2-1-base"
        pipe = StableDiffusionPipeline.from_pretrained(sd_model_id)
        pipe = pipe.to(device)
        base_scheduler_config = pipe.scheduler.config
        pipe.scheduler = DDIMScheduler.from_config(base_scheduler_config)

        @spaces.GPU()
        def style_transfer_app(
            prompt,
            image,
            cfg_scale=7.5,
            num_content_layers=4,
            num_style_layers=9,
            seed=0,
            progress=gr.Progress(track_tqdm=True),
        ):
            """Gradio callback: stylize ``image`` according to the text ``prompt``.

            Args:
                prompt: target style description.
                image: content image from the ``gr.Image`` input, resized to 512.
                cfg_scale: classifier-free guidance scale.
                num_content_layers: passed as
                    ``share_resnet_layers=[0, ..., n-1]`` (``None`` when 0).
                num_style_layers: passed as
                    ``share_attn_layers=[0, ..., n-1]`` (``None`` when 0).
                seed: random seed given to ``utils.exp_utils.seed_all``.
                progress: Gradio progress tracker created with ``track_tqdm=True``.

            Returns:
                Batch item 2 of the sampled images (the controlled-style output).

            Raises:
                gr.Error: if no content image was provided.
            """
            if image is None:
                raise gr.Error("Please upload a content image.")
            # Slider values may arrive as floats; range() and seeding need ints.
            num_content_layers = int(num_content_layers)
            num_style_layers = int(num_style_layers)
            seed = int(seed)

            utils.exp_utils.seed_all(seed)
            image = utils.exp_utils.process_image(image, device, 512)

            tgt_prompt = prompt
            src_prompt = ""
            prompt_in = [
                "",  # reconstruction
                tgt_prompt,  # uncontrolled style
                "",  # controlled style
            ]

            share_resnet_layers = (
                list(range(num_content_layers)) if num_content_layers != 0 else None
            )
            share_attn_layers = (
                list(range(num_style_layers)) if num_style_layers != 0 else None
            )
            imgs = style_image_with_inversion(
                pipe,
                image,
                src_prompt,
                style_prompt=prompt_in,
                num_steps=50,
                start_step=0,
                guidance_scale=cfg_scale,
                disentangle=True,
                resnet_mode="hidden",
                share_attn=True,
                share_cross_attn=True,
                share_resnet_layers=share_resnet_layers,
                share_attn_layers=share_attn_layers,
                share_key=True,
                share_query=True,
                share_value=False,
                use_content_anchor=True,
                use_adain=True,
                output_dir="./",
            )

            return imgs[2]

        # load examples
        # Each row supplies the two inputs wired to gr.Examples
        # (text prompt, content image).
        examples = []
        with open("data/example/annotation.json", encoding="utf-8") as annotation_file:
            annotation = json.load(annotation_file)
        for entry in annotation:
            image = utils.exp_utils.get_processed_image(
                entry["image_path"], device, 512
            )
            image = transforms.ToPILImage()(image[0])

            examples.append([entry["target_prompt"], image])
 
        # Gradio UI layout:
        # - a header (title, subtitle, badge links);
        # - one row with two columns: inputs plus advanced sliders on the left,
        #   the output image and clickable examples on the right.
        # Components are created inside the `with` contexts in display order.
        with gr.Blocks() as app:
        
            with gr.Column():
                gr.Markdown("# Artist: Aesthetically Controllable Text-Driven Stylization without Training")
                gr.Markdown("## Interactive Demo, HF space version")
                # Badge links: project page, code, arXiv paper, HF papers page.
                gr.HTML("""
                <div style="display:flex;column-gap:4px;">
                    <a href='https://diffusionartist.github.io/'>
                        <img src='https://img.shields.io/badge/Project-Page-green'>
                    </a>
                    <a href='https://github.com/songrise/Artist'>
                        <img src='https://img.shields.io/badge/Code-github-blue'>
                    </a>
                    <a href='https://arxiv.org/abs/2407.15842'>
                        <img src='https://img.shields.io/badge/Paper-Arxiv-red'>
                    </a>
                    <a href='https://huggingface.co/papers/2407.15842'>
                        <img src='https://img.shields.io/badge/Papers-HF-ffd21f'>
                    </a>
                    
                </div>
                """)
                with gr.Row():
                    # Left column: content image, style prompt, advanced settings.
                    with gr.Column():
                        image_input = gr.Image(
                            label="Content image (will be resized to 512x512)",
                            interactive=True,
                        )
                        text_input = gr.Textbox(
                            value="An impressionist painting",
                            label="Text Prompt",
                            info="Describe the style you want to apply to the image, do not include the description of the image content itself",
                            lines=2,
                            placeholder="Enter a text prompt",
                        )
                        # These sliders feed style_transfer_app's cfg_scale,
                        # num_content_layers, num_style_layers and seed arguments
                        # (wired positionally in submit_btn.click).
                        with gr.Accordion("Advanced settings"):
                            with gr.Column():
                                cfg_slider = gr.Slider(
                                    0,
                                    15,
                                    value=7.5,
                                    label="Classifier Free Guidance (CFG) Scale",
                                    info="higher values give more style, 7.5 should be good for most cases",
                                )
                                content_slider = gr.Slider(
                                    0,
                                    9,
                                    value=4,
                                    step=1,
                                    label="Number of content control layer",
                                    info="higher values make it more similar to original image. Default to control first 4 layers",
                                )
                                style_slider = gr.Slider(
                                    0,
                                    9,
                                    value=9,
                                    step=1,
                                    label="Number of style control layer",
                                    info="higher values make it more similar to target style. Default to control first 9 layers, usually not necessary to change.",
                                )
                                seed_slider = gr.Slider(
                                    0,
                                    100,
                                    value=0,
                                    step=1,
                                    label="Seed",
                                    info="Random seed for the model",
                                )
                        submit_btn = gr.Button("Submit")
                    # Right column: result image and example inputs.
                    with gr.Column():
                        image_output= gr.Image(format="png")
                        # Examples fill only the text prompt and content image inputs.
                        gr.Examples(
                            examples = examples,
                            fn = style_transfer_app,
                            inputs = [text_input, image_input],
                            outputs = [image_output],
                            cache_examples=False
                        )
            # Inputs are passed positionally as
            # style_transfer_app(prompt, image, cfg_scale, num_content_layers,
            #                    num_style_layers, seed).
            submit_btn.click(
                fn=style_transfer_app,
                inputs=[
                    text_input,
                    image_input,
                    cfg_slider,
                    content_slider,
                    style_slider,
                    seed_slider,
                ],
                outputs=[image_output],
                show_api=False
            )
            
        
        # Launch the Gradio app.
        app.launch(show_api=False, show_error=True)