import gradio as gr
import torch
from diffusers.utils import load_image
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.models.controlnet_flux import FluxControlNetModel
import random
import numpy as np

import os
from huggingface_hub import login

# Authenticate with the Hugging Face Hub using the token in $hfapikey.
# Only call login() when a token is actually set: login(None) falls back to an
# interactive prompt, which blocks or fails on a headless server. Without a
# token, huggingface_hub can still use cached credentials or $HF_TOKEN.
_hf_token = os.getenv("hfapikey")
if _hf_token:
    login(token=_hf_token)

# Hub repository ids: the FLUX.1-dev base pipeline and the line-art
# ControlNet that is attached to it below.
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model = 'promeai/FLUX.1-controlnet-lineart-promeai'

# Prefer the GPU with bfloat16 weights; fall back to the CPU in float32.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.bfloat16
else:
    device, torch_dtype = "cpu", torch.float32

# Load the ControlNet first, build the base pipeline around it, and move the
# assembled pipeline onto the selected device.
controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch_dtype)
pipe = FluxControlNetPipeline.from_pretrained(
    base_model,
    controlnet=controlnet,
    torch_dtype=torch_dtype,
).to(device)

# Upper bound for seeds offered by the UI: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

def infer(
    prompt,
    control_image_path,
    controlnet_conditioning_scale,
    guidance_scale,
    num_inference_steps,
    seed,
    randomize_seed,
):
    """Generate one image from a text prompt, guided by a control image.

    Args:
        prompt: Text prompt describing the desired output.
        control_image_path: Filesystem path of the uploaded control image (the
            UI's ``gr.Image`` uses ``type="filepath"``); ``None`` or empty when
            nothing was uploaded.
        controlnet_conditioning_scale: Strength of the ControlNet conditioning,
            forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        seed: Seed for the random generator; replaced when ``randomize_seed``.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.

    Returns:
        ``(image, seed)``: the first generated image and the seed actually
        used, so the UI can show it for later reproduction.

    Raises:
        gr.Error: If no control image was provided.
    """
    if not control_image_path:
        # The ControlNet pipeline cannot run without a conditioning image;
        # report that in the UI instead of failing inside preprocessing.
        raise gr.Error("Please upload a control image (line art) first.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; the generator needs an int.
    seed = int(seed)

    # Use a private generator rather than torch.manual_seed(), which reseeds
    # the process-wide default RNG that concurrent requests also draw from.
    # Same seed -> same CPU random stream as before.
    generator = torch.Generator().manual_seed(seed)
    control_image = load_image(control_image_path)

    # Generate image
    result = pipe(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]

    return result, seed

# Stylesheet passed to gr.Blocks below: horizontally centers the element with
# id "col-container" (the main column) and caps its width at 640px.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Build the web UI: prompt + Generate button, a settings panel holding the
# control image and sampler parameters, the result image, and example prompts.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## Zero-shot Partial Style Transfer for Line Art Images, Powered by FLUX.1")

        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt",
                max_lines=1,
            )
            run_button = gr.Button("Generate", variant="primary")

        with gr.Accordion("Advanced Settings", open=True):
            # Delivered to infer() as a filesystem path (type="filepath").
            control_image = gr.Image(
                sources=['upload', 'webcam', 'clipboard'],
                type="filepath",
                label="Control Image (Line Art)"
            )
            controlnet_conditioning_scale = gr.Slider(
                label="ControlNet Conditioning Scale",
                minimum=0.0,
                maximum=1.0,
                value=0.6,
                step=0.1
            )
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1.0,
                maximum=10.0,
                value=3.5,
                step=0.1
            )
            num_inference_steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=1,
                maximum=100,
                value=28,
                step=1
            )
            # Also used as an output so the UI shows the seed actually used.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0
            )
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

        result = gr.Image(label="Result", show_label=False)

        # Clicking an example only fills in the prompt box.
        gr.Examples(
            examples=[
                "Shiba Inu wearing dinosaur costume riding skateboard",
                "Victorian style mansion interior with candlelight"
            ],
            inputs=[prompt]
        )

    # Start generation from either the button or Enter in the one-line prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            control_image,
            controlnet_conditioning_scale,
            guidance_scale,
            num_inference_steps,
            seed,
            randomize_seed
        ],
        outputs=[result, seed]
    )

def main():
    """Launch the Gradio server for the ``demo`` Blocks app."""
    demo.launch()


if __name__ == "__main__":
    main()