import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
from leffa.transform import LeffaTransform
from leffa.model import LeffaModel
from leffa.inference import LeffaInference
from utils.garment_agnostic_mask_predictor import AutoMasker
from utils.densepose_predictor import DensePosePredictor
from utils.utils import resize_and_center
import random
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import pipeline
import gradio as gr

# Download checkpoints
snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")

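# Garment-agnostic mask predictor (DensePose + SCHP checkpoints) that masks the clothing
# region of the person image for virtual try-on.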
mask_predictor = AutoMasker(
    densepose_path="./ckpts/densepose",
    schp_path="./ckpts/schp",
)

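# DensePose predictor producing IUV and segmentation maps used to condition Leffa.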
densepose_predictor = DensePosePredictor(
    config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
    weights_path="./ckpts/densepose/model_final_162be9.pkl",
)

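# Virtual try-on model: Stable Diffusion inpainting backbone with Leffa try-on weights.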
vt_model = LeffaModel(
    pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
    pretrained_model="./ckpts/virtual_tryon.pth",
)
vt_inference = LeffaInference(model=vt_model)

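# Pose transfer model: SDXL inpainting backbone with Leffa pose-transfer weights.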
pt_model = LeffaModel(
    pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
    pretrained_model="./ckpts/pose_transfer.pth",
)
pt_inference = LeffaInference(model=pt_model)

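# Korean-to-English prompt translator and FLUX.1-dev text-to-image setup for fashion generation.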
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
base_model = "black-forest-labs/FLUX.1-dev"
model_lora_repo = "Motas/Flux_Fashion_Photography_Style"
clothes_lora_repo = "prithivMLmods/Canopus-Clothing-Flux-LoRA"

MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider and randomized seeds (assumed value)
note = ""  # Markdown note rendered below the tabs; left empty because the original text is not provided

fashion_pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
fashion_pipe.to("cuda")

@spaces.GPU()
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
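    """Generate a fashion image with FLUX.1-dev, translating Korean prompts and applying the mode-specific LoRA."""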
    # Detect Korean text and translate the prompt to English
    def contains_korean(text):
        return any(ord('가') <= ord(char) <= ord('힣') for char in text)
    
    if contains_korean(prompt):
        translated = translator(prompt)[0]['translation_text']
        actual_prompt = translated
    else:
        actual_prompt = prompt

    # Select the LoRA weights and trigger word for the chosen generation mode.
    # unload_lora_weights() is called defensively so adapters from earlier calls do not stack.
    fashion_pipe.unload_lora_weights()
    if mode == "Generate Model":
        fashion_pipe.load_lora_weights(model_lora_repo)
        trigger_word = "fashion photography, professional model"
    else:
        fashion_pipe.load_lora_weights(clothes_lora_repo)
        trigger_word = "upper clothing, fashion item"

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device="cuda").manual_seed(seed)

    progress(0, desc="Starting fashion generation...")
    # Per-step progress is reported automatically via gr.Progress(track_tqdm=True).

    image = fashion_pipe(
        prompt=f"{actual_prompt} {trigger_word}",
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]

    progress(1, desc="Completed!")
    return image, seed
    
def leffa_predict(src_image_path, ref_image_path, control_type):
    """Run Leffa virtual try-on or pose transfer and return the generated image as a NumPy array."""
    assert control_type in ["virtual_tryon", "pose_transfer"], \
        f"Invalid control type: {control_type}"
    src_image = Image.open(src_image_path)
    ref_image = Image.open(ref_image_path)
    src_image = resize_and_center(src_image, 768, 1024)
    ref_image = resize_and_center(ref_image, 768, 1024)

    src_image_array = np.array(src_image)
    ref_image_array = np.array(ref_image)

    # Mask
    if control_type == "virtual_tryon":
        src_image = src_image.convert("RGB")
        mask = mask_predictor(src_image, "upper")["mask"]
    elif control_type == "pose_transfer":
        mask = Image.fromarray(np.ones_like(src_image_array) * 255)

    # DensePose
    src_image_iuv_array = densepose_predictor.predict_iuv(src_image_array)
    src_image_seg_array = densepose_predictor.predict_seg(src_image_array)
    src_image_iuv = Image.fromarray(src_image_iuv_array)
    src_image_seg = Image.fromarray(src_image_seg_array)
    if control_type == "virtual_tryon":
        densepose = src_image_seg
    elif control_type == "pose_transfer":
        densepose = src_image_iuv

    # Leffa
    transform = LeffaTransform()

    data = {
        "src_image": [src_image],
        "ref_image": [ref_image],
        "mask": [mask],
        "densepose": [densepose],
    }
    data = transform(data)
    if control_type == "virtual_tryon":
        inference = vt_inference
    elif control_type == "pose_transfer":
        inference = pt_inference
    output = inference(data)
    gen_image = output["generated_image"][0]
    # gen_image.save("gen_image.png")
    return np.array(gen_image)


def leffa_predict_vt(src_image_path, ref_image_path):
    return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")


def leffa_predict_pt(src_image_path, ref_image_path):
    return leffa_predict(src_image_path, ref_image_path, "pose_transfer")



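# Gradio UI: three tabs for fashion generation, virtual try-on, and pose transfer.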
with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.red)) as demo:
    gr.Markdown("# 🎭 Fashion Studio & Virtual Try-on")
    
    with gr.Tabs():
        # Fashion generation tab
        with gr.Tab("Fashion Generation"):
            with gr.Column():
                mode = gr.Radio(
                    choices=["Generate Model", "Generate Clothes"],
                    label="Generation Mode",
                    value="Generate Model"
                )
                
                prompt = gr.TextArea(
                    label="Fashion Description (Korean or English)",
                    placeholder="Describe the fashion model or clothing..."
                )
                
                with gr.Row():
                    with gr.Column():
                        result = gr.Image(label="Generated Result")
                        generate_button = gr.Button("Generate Fashion")
                
                with gr.Accordion("Advanced Options", open=False):
                    with gr.Group():
                        with gr.Row():
                            with gr.Column():
                                cfg_scale = gr.Slider(
                                    label="CFG Scale",
                                    minimum=1,
                                    maximum=20,
                                    step=0.5,
                                    value=7.0
                                )
                                steps = gr.Slider(
                                    label="Steps",
                                    minimum=1,
                                    maximum=100,
                                    step=1,
                                    value=30
                                )
                                lora_scale = gr.Slider(
                                    label="LoRA Scale",
                                    minimum=0,
                                    maximum=1,
                                    step=0.01,
                                    value=0.85
                                )
                        
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=1536,
                                step=64,
                                value=512
                            )
                            height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=1536,
                                step=64,
                                value=768
                            )
                        
                        with gr.Row():
                            randomize_seed = gr.Checkbox(
                                label="Randomize seed",
                                value=True
                            )
                            seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=42
                            )

        # Virtual try-on tab
        with gr.Tab("Virtual Try-on"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Person Image")
                    vt_src_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Person Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=vt_src_image,
                        examples_per_page=5,
                        examples=["./ckpts/examples/person1/01350_00.jpg",
                                "./ckpts/examples/person1/01376_00.jpg",
                                "./ckpts/examples/person1/01416_00.jpg",
                                "./ckpts/examples/person1/05976_00.jpg",
                                "./ckpts/examples/person1/06094_00.jpg"]
                    )

                with gr.Column():
                    gr.Markdown("#### Garment Image")
                    vt_ref_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Garment Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=vt_ref_image,
                        examples_per_page=5,
                        examples=["./ckpts/examples/garment/01449_00.jpg",
                                "./ckpts/examples/garment/01486_00.jpg",
                                "./ckpts/examples/garment/01853_00.jpg",
                                "./ckpts/examples/garment/02070_00.jpg",
                                "./ckpts/examples/garment/03553_00.jpg"]
                    )

                with gr.Column():
                    gr.Markdown("#### Generated Image")
                    vt_gen_image = gr.Image(
                        label="Generated Image",
                        width=512,
                        height=512,
                    )
                    vt_gen_button = gr.Button("Try-on")

        # Pose transfer tab
        with gr.Tab("Pose Transfer"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Person Image")
                    pt_ref_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Person Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=pt_ref_image,
                        examples_per_page=5,
                        examples=["./ckpts/examples/person1/01350_00.jpg",
                                "./ckpts/examples/person1/01376_00.jpg",
                                "./ckpts/examples/person1/01416_00.jpg",
                                "./ckpts/examples/person1/05976_00.jpg",
                                "./ckpts/examples/person1/06094_00.jpg"]
                    )

                with gr.Column():
                    gr.Markdown("#### Target Pose Person Image")
                    pt_src_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Target Pose Person Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=pt_src_image,
                        examples_per_page=5,
                        examples=["./ckpts/examples/person2/01850_00.jpg",
                                "./ckpts/examples/person2/01875_00.jpg",
                                "./ckpts/examples/person2/02532_00.jpg",
                                "./ckpts/examples/person2/02902_00.jpg",
                                "./ckpts/examples/person2/05346_00.jpg"]
                    )

                with gr.Column():
                    gr.Markdown("#### Generated Image")
                    pt_gen_image = gr.Image(
                        label="Generated Image",
                        width=512,
                        height=512,
                    )
                    pose_transfer_gen_button = gr.Button("Generate")

    gr.Markdown(note)

    # Event handlers
    generate_button.click(
        generate_fashion,
        inputs=[prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed]
    )
    
    vt_gen_button.click(
        fn=leffa_predict_vt,
        inputs=[vt_src_image, vt_ref_image],
        outputs=[vt_gen_image]
    )
    
    pose_transfer_gen_button.click(
        fn=leffa_predict_pt,
        inputs=[pt_src_image, pt_ref_image],
        outputs=[pt_gen_image]
    )

demo.launch(share=True, server_port=7860)