#!/usr/bin/env python
"""
Demo showcasing parameter-efficient fine-tuning of Stable Diffusion via Dreambooth leveraging 🤗 PEFT (https://github.com/huggingface/peft)

The code in this repo is partly adapted from the following repositories:
https://huggingface.co/spaces/hysts/LoRA-SD-training
https://huggingface.co/spaces/multimodalart/dreambooth-training
"""
from __future__ import annotations

import os
import pathlib

import gradio as gr
import torch
from typing import List

from inference import InferencePipeline
from trainer import Trainer
from uploader import upload


# Markdown heading and one-line blurb rendered at the top of the app.
TITLE = "# LoRA + Dreambooth Training and Inference Demo 🎨"
DESCRIPTION = "Demo showcasing parameter-efficient fine-tuning of Stable Diffusion via Dreambooth leveraging 🤗 PEFT (https://github.com/huggingface/peft)."


# Identifier of the Space this demo was first published as; used as the
# fallback when the SPACE_ID environment variable is not set.
ORIGINAL_SPACE_ID = "smangrul/peft-lora-sd-dreambooth"
SPACE_ID = os.getenv("SPACE_ID", ORIGINAL_SPACE_ID)

# Markdown/HTML banner with a "Duplicate Space" badge that links to the
# duplication page of the current Space.
SHARED_UI_WARNING = f"""# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
"""

# "Settings" becomes a link to the Space's settings page only when SYSTEM is
# "spaces" and the Space id differs from the original one; otherwise it is
# plain text.
SETTINGS = (
    f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
    if os.getenv("SYSTEM") == "spaces" and SPACE_ID != ORIGINAL_SPACE_ID
    else "Settings"
)

# Banner text for the case where no GPU is available.
CUDA_NOT_AVAILABLE_WARNING = f"""# Attention - Running on CPU.
<center>
You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
"T4 small" is sufficient to run this demo.
</center>
"""


def show_warning(warning_text: str) -> gr.Blocks:
    """Wrap ``warning_text`` in a boxed Markdown element.

    Args:
        warning_text: Markdown source to display.

    Returns:
        The ``gr.Blocks`` container holding the boxed Markdown.
    """
    # Both context managers are entered left to right, same as nesting them.
    with gr.Blocks() as demo, gr.Box():
        gr.Markdown(warning_text)
    return demo


def update_output_files() -> dict:
    """Collect the files in ``results/`` to show in the output file list.

    Returns:
        A Gradio update whose value lists the ``*.pt`` files first and the
        ``*.json`` files second (each group sorted) as POSIX path strings,
        or ``None`` when neither kind of file is present.
    """
    results_dir = pathlib.Path("results")
    found = [*sorted(results_dir.glob("*.pt")), *sorted(results_dir.glob("*.json"))]
    posix_paths = [entry.as_posix() for entry in found]
    # An empty list is replaced by None in the update.
    return gr.update(value=posix_paths if posix_paths else None)


def create_training_demo(trainer: Trainer, pipe: InferencePipeline) -> gr.Blocks:
    """Build the training UI and wire its buttons to ``trainer`` and ``pipe``.

    The layout has a base-model dropdown, a "Training Data" box (concept
    images, concept prompt and usage guidelines), a "Training Parameters" box
    (optimisation and LoRA hyper-parameters) and a status area with the
    produced weight/config files.

    Args:
        trainer: Object whose ``run`` method is invoked with the form values
            when "Start Training" is clicked, and whose ``check_if_running``
            method feeds the status message.
        pipe: Inference pipeline whose ``clear`` method is called when
            "Start Training" is clicked.

    Returns:
        The ``gr.Blocks`` container holding the training UI.
    """
    with gr.Blocks() as demo:
        base_model = gr.Dropdown(
            choices=[
                "CompVis/stable-diffusion-v1-4",
                "runwayml/stable-diffusion-v1-5",
                "stabilityai/stable-diffusion-2-1-base",
            ],
            value="runwayml/stable-diffusion-v1-5",
            label="Base Model",
            visible=True,
        )
        # Single fixed choice, kept hidden; still forwarded to trainer.run.
        resolution = gr.Dropdown(choices=["512"], value="512", label="Resolution", visible=False)

        with gr.Row():
            with gr.Box():
                gr.Markdown("Training Data")
                concept_images = gr.Files(label="Images for your concept")
                concept_prompt = gr.Textbox(label="Concept Prompt", max_lines=1)
                gr.Markdown(
                    """
                    - Upload images of the style you are planning on training on.
                    - For a concept prompt, use a unique, made up word to avoid collisions.
                    - Guidelines for getting good results:
                        - Dreambooth for an `object` or `style`:
                            - 5-10 images of the object from different angles
                            - 500-800 iterations should be good enough. 
                            - Prior preservation is recommended.
                            - `class_prompt`:
                                - `a photo of object`
                                - `style`
                            - `concept_prompt`:
                                - `<concept prompt> object`
                                - `<concept prompt> style`
                                - `a photo of <concept prompt> object`
                                - `a photo of <concept prompt> style`
                        - Dreambooth for a `Person/Face`:
                            - 15-50 images of the person from different angles, lighting, and expressions. 
                            Have considerable photos with close up faces.
                            - 800-1200 iterations should be good enough.
                            - good defaults for hyperparams
                                - Model - `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1-base`
                                - Use/check Prior preservation.
                                - Number of class images to use - 200
                                - Prior Loss Weight - 1
                                - LoRA Rank for unet - 16
                                - LoRA Alpha for unet - 20
                                - lora dropout - 0
                                - LoRA Bias for unet - `all`
                                - LoRA Rank for CLIP - 16
                                - LoRA Alpha for CLIP - 17
                                - LoRA Bias for CLIP - `all`
                                - lora dropout for CLIP - 0
                                - Uncheck `FP16` and `8bit-Adam` (don't use them for faces)
                            - `class_prompt`: Use the gender related word of the person
                                - `man`
                                - `woman`
                                - `boy`
                                - `girl`
                            - `concept_prompt`: just the unique, made up word, e.g., `srm`
                            - Choose `all` for `lora_bias` and `text_encode_lora_bias`
                        - Dreambooth for a `Scene`:
                            - 15-50 images of the scene from different angles, lighting, and expressions.
                            - 800-1200 iterations should be good enough.
                            - Prior preservation is recommended.
                            - `class_prompt`:
                                - `scene`
                                - `landscape`
                                - `city`
                                - `beach`
                                - `mountain`
                            - `concept_prompt`:
                                - `<concept prompt> scene`
                                - `<concept prompt> landscape`
                        - Experiment with various values for lora dropouts, enabling/disabling fp16 and 8bit-Adam
                    """
                )
            with gr.Box():
                gr.Markdown("Training Parameters")
                num_training_steps = gr.Number(label="Number of Training Steps", value=1000, precision=0)
                learning_rate = gr.Number(label="Learning Rate", value=0.0001)
                gradient_checkpointing = gr.Checkbox(label="Whether to use gradient checkpointing", value=True)
                train_text_encoder = gr.Checkbox(label="Train Text Encoder", value=True)
                with_prior_preservation = gr.Checkbox(label="Prior Preservation", value=True)
                class_prompt = gr.Textbox(
                    label="Class Prompt", max_lines=1, placeholder='Example: "a photo of object"'
                )
                num_class_images = gr.Number(label="Number of class images to use", value=50, precision=0)
                prior_loss_weight = gr.Number(label="Prior Loss Weight", value=1.0, precision=1)
                # use_lora = gr.Checkbox(label="Whether to use LoRA", value=True)
                lora_r = gr.Number(label="LoRA Rank for unet", value=4, precision=0)
                # LoRA scales its update by alpha / rank, so the label states
                # the ratio in that order.
                lora_alpha = gr.Number(
                    label="LoRA Alpha for unet. scaling factor = lora_alpha/lora_r", value=4, precision=0
                )
                lora_dropout = gr.Number(label="lora dropout", value=0.00)
                lora_bias = gr.Dropdown(
                    choices=["none", "all", "lora_only"],
                    value="none",
                    label="LoRA Bias for unet. This enables bias params to be trainable based on the bias type",
                    visible=True,
                )
                lora_text_encoder_r = gr.Number(label="LoRA Rank for CLIP", value=4, precision=0)
                lora_text_encoder_alpha = gr.Number(
                    label="LoRA Alpha for CLIP. scaling factor = lora_alpha/lora_r", value=4, precision=0
                )
                lora_text_encoder_dropout = gr.Number(label="lora dropout for CLIP", value=0.00)
                lora_text_encoder_bias = gr.Dropdown(
                    choices=["none", "all", "lora_only"],
                    value="none",
                    label="LoRA Bias for CLIP. This enables bias params to be trainable based on the bias type",
                    visible=True,
                )
                gradient_accumulation = gr.Number(label="Number of Gradient Accumulation", value=1, precision=0)
                fp16 = gr.Checkbox(label="FP16", value=True)
                use_8bit_adam = gr.Checkbox(label="Use 8bit Adam", value=True)
                gr.Markdown(
                    """
                    - It will take about 20-30 minutes to train for 1000 steps with a T4 GPU.
                    - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
                    - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
                    """
                )

        run_button = gr.Button("Start Training")
        with gr.Box():
            with gr.Row():
                check_status_button = gr.Button("Check Training Status")
                with gr.Column():
                    with gr.Box():
                        gr.Markdown("Message")
                        training_status = gr.Markdown()
                    output_files = gr.Files(label="Trained Weight Files and Configs")

        # "Start Training" has two handlers: one calls pipe.clear, the other
        # starts the training run with the form values.
        run_button.click(fn=pipe.clear)

        # The order of `inputs` is the positional argument order passed to
        # trainer.run.
        run_button.click(
            fn=trainer.run,
            inputs=[
                base_model,
                resolution,
                num_training_steps,
                concept_images,
                concept_prompt,
                learning_rate,
                gradient_accumulation,
                fp16,
                use_8bit_adam,
                gradient_checkpointing,
                train_text_encoder,
                with_prior_preservation,
                prior_loss_weight,
                class_prompt,
                num_class_images,
                lora_r,
                lora_alpha,
                lora_bias,
                lora_dropout,
                lora_text_encoder_r,
                lora_text_encoder_alpha,
                lora_text_encoder_bias,
                lora_text_encoder_dropout,
            ],
            outputs=[
                training_status,
                output_files,
            ],
            queue=False,
        )
        # "Check Training Status" refreshes both the message and the file list.
        check_status_button.click(fn=trainer.check_if_running, inputs=None, outputs=training_status, queue=False)
        check_status_button.click(fn=update_output_files, inputs=None, outputs=output_files, queue=False)
    return demo


def find_weight_files() -> List[str]:
    """Return every ``*.pt`` file found under this script's directory.

    The directory containing this file is searched recursively. Only regular
    files are reported, so a directory whose name happens to end in ``.pt``
    is never offered as a weight file.

    Returns:
        POSIX-style paths relative to this file's directory, in the sorted
        order of the underlying ``Path`` objects.
    """
    curr_dir = pathlib.Path(__file__).parent
    # rglob matches on name only, so filter out anything that is not a file.
    paths = sorted(path for path in curr_dir.rglob("*.pt") if path.is_file())
    return [path.relative_to(curr_dir).as_posix() for path in paths]


def reload_lora_weight_list() -> dict:
    """Rescan for weight files and return a Gradio update with the new choices."""
    weight_files = find_weight_files()
    return gr.update(choices=weight_files)


def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
    """Build the image-generation UI and wire it to ``pipe.run``.

    The left column holds the generation controls (base model, LoRA weight
    file, prompts, seed, step count and CFG scale); the right column shows
    the resulting image. Generation is triggered by submitting the prompt,
    clicking "Generate", or changing the seed.

    Args:
        pipe: Inference pipeline whose ``run`` method is used as the
            generation callback.

    Returns:
        The ``gr.Blocks`` container holding the inference UI.
    """
    model_choices = [
        "CompVis/stable-diffusion-v1-4",
        "runwayml/stable-diffusion-v1-5",
        "stabilityai/stable-diffusion-2-1-base",
    ]

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                base_model = gr.Dropdown(
                    choices=model_choices,
                    value="runwayml/stable-diffusion-v1-5",
                    label="Base Model",
                    visible=True,
                )
                reload_button = gr.Button("Reload Weight List")
                lora_weight_name = gr.Dropdown(
                    choices=find_weight_files(), value="lora/lora_disney.pt", label="LoRA Weight File"
                )
                prompt = gr.Textbox(label="Prompt", max_lines=1, placeholder='Example: "style of sks, baby lion"')
                negative_prompt = gr.Textbox(
                    label="Negative Prompt", max_lines=1, placeholder='Example: "blurry, botched, low quality"'
                )
                seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=1)
                with gr.Accordion("Other Parameters", open=False):
                    num_steps = gr.Slider(label="Number of Steps", minimum=0, maximum=1000, step=1, value=50)
                    guidance_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=50, step=0.1, value=7)

                run_button = gr.Button("Generate")

                gr.Markdown(
                    """
                - After training, you can press "Reload Weight List" button to load your trained model names.
                - Few repos to refer for ideas:
                    - https://huggingface.co/smangrul/smangrul
                    - https://huggingface.co/smangrul/painting-in-the-style-of-smangrul
                    - https://huggingface.co/smangrul/erenyeager
                """
                )
            with gr.Column():
                result = gr.Image(label="Result")

        reload_button.click(fn=reload_lora_weight_list, inputs=None, outputs=lora_weight_name)

        # Positional argument order for pipe.run, shared by every trigger.
        generation_inputs = [
            base_model,
            lora_weight_name,
            prompt,
            negative_prompt,
            seed,
            num_steps,
            guidance_scale,
        ]
        # Registered in the same order as before: prompt submit, button
        # click, seed change.
        for register in (prompt.submit, run_button.click, seed.change):
            register(
                fn=pipe.run,
                inputs=generation_inputs,
                outputs=result,
                queue=False,
            )
    return demo


def create_upload_demo() -> gr.Blocks:
    """Build the upload UI: model name, access token and an "Upload" button.

    Clicking "Upload" calls ``upload`` with the model name and the token and
    renders its return value in the "Message" box.

    Returns:
        The ``gr.Blocks`` container holding the upload UI.
    """
    with gr.Blocks() as demo:
        model_name = gr.Textbox(label="Model Name")
        # The token grants write access, so mask it instead of echoing it in
        # clear text on screen.
        hf_token = gr.Textbox(label="Hugging Face Token (with write permission)", type="password")
        upload_button = gr.Button("Upload")
        with gr.Box():
            gr.Markdown("Message")
            result = gr.Markdown()
        gr.Markdown(
            """
            - You can upload your trained model to your private Model repo (i.e. https://huggingface.co/{your_username}/{model_name}).
            - You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
            """
        )

        # Register the handler while this Blocks context is still open, so
        # the function does not rely on the caller having an outer Blocks
        # context active.
        upload_button.click(fn=upload, inputs=[model_name, hf_token], outputs=result)

    return demo


# Objects shared by the tabs below: the training tab receives both, the test
# tab receives only the inference pipeline.
pipe = InferencePipeline()
trainer = Trainer()

# Assemble the whole app: optional warning banners, title and description,
# then one tab per workflow.
with gr.Blocks(css="style.css") as demo:
    # Shown when the IS_SHARED_UI environment variable is set to a non-empty value.
    if os.getenv("IS_SHARED_UI"):
        show_warning(SHARED_UI_WARNING)
    # Shown when torch reports that CUDA is not available.
    if not torch.cuda.is_available():
        show_warning(CUDA_NOT_AVAILABLE_WARNING)

    gr.Markdown(TITLE)
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        with gr.TabItem("Train"):
            create_training_demo(trainer, pipe)
        with gr.TabItem("Test"):
            create_inference_demo(pipe)
        with gr.TabItem("Upload"):
            create_upload_demo()

# Configure the queue with default_enabled=False, then start the server and
# request a share link.
demo.queue(default_enabled=False).launch(share=True)