import spaces
import gradio as gr
import numpy as np
import random
from PIL import Image
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetPipeline,
    EulerDiscreteScheduler,
    AutoencoderKL
)
from transformers import DPTForDepthEstimation, DPTImageProcessor
from diffusers.utils import load_image
from gradio_imageslider import ImageSlider
import boto3
from io import BytesIO
from datetime import datetime
import json

device = "cuda"
base_model_id = "SG161222/RealVisXL_V4.0"
controlnet_model_id = "diffusers/controlnet-depth-sdxl-1.0"
vae_model_id = "madebyollin/sdxl-vae-fp16-fix"
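
# RealVisXL V4.0 as the SDXL base, the official depth ControlNet, and an fp16-safe VAE (currently unused below).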


if torch.cuda.is_available():

    # Load the depth ControlNet and the RealVisXL (SDXL) base pipeline.
    # The fp16 variants and the fixed VAE are left commented out; everything runs in float32.
    controlnet = ControlNetModel.from_pretrained(
        controlnet_model_id, 
        # variant="fp16",
        use_safetensors=True,
        torch_dtype=torch.float32
    )
    # vae = AutoencoderKL.from_pretrained(vae_model_id, torch_dtype=torch.float16)
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        base_model_id, 
        controlnet=controlnet, 
        # vae=vae,
        # variant="fp16",
        use_safetensors=True,
        torch_dtype=torch.float32,
    )
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.to(device)

# MiDaS (DPT-Hybrid) depth estimator used to build the ControlNet conditioning image.
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")


MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

USE_TORCH_COMPILE = 0
ENABLE_CPU_OFFLOAD = 0


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def get_depth_map(image, progress):
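    # Run DPT depth estimation, upsample the prediction back to the input resolution,
    # normalise it to [0, 1], and return it as a 3-channel PIL image for ControlNet.
    # Note: PIL's .size is (width, height) while torch.nn.functional.interpolate expects (height, width).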
    original_size = (image.size[1], image.size[0])
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=original_size,
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)
    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image


def upload_image_to_s3(image, account_id, access_key, secret_key, bucket_name, progress):
    print("upload_image_to_s3", account_id, access_key, secret_key, bucket_name)
    connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"

    s3 = boto3.client(
        's3',
        endpoint_url=connectionUrl,
        region_name='auto',
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key
    )

    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    s3.upload_fileobj(buffer, bucket_name, image_file)
    print("upload finish", image_file)
    return image_file



@spaces.GPU(enable_queue=True)
def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, control_strength, seed, upload_to_s3, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
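    # Build a depth map from the input image, then run SDXL + depth ControlNet at the (rounded) input size.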
    print("process start")
    if image_url:
        print(image_url)
        original_image = load_image(image_url)
    else:
        original_image = Image.fromarray(image)

    # SDXL requires width and height divisible by 8; round down to keep the pipeline happy.
    size = (original_image.size[0] // 8 * 8, original_image.size[1] // 8 * 8)
    print("image size", size)
    depth_image = get_depth_map(original_image, progress)
    generator = torch.Generator().manual_seed(int(seed))
    print(prompt, n_prompt, guidance_scale, num_steps, control_strength)
    generated_image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        width=size[0],
        height=size[1],
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
        # The text-to-image ControlNet pipeline takes a conditioning scale, not an img2img strength.
        controlnet_conditioning_scale=control_strength,
        generator=generator,
        image=depth_image
    ).images[0]

    if upload_to_s3:
        url = upload_image_to_s3(generated_image, account_id, access_key, secret_key, bucket, progress)
        result = {"status": "success", "url": url}
    else:
        result = {"status": "success", "message": "Image generated but not uploaded"}
    
    return [original_image, generated_image], json.dumps(result)

with gr.Blocks() as demo:
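    # Two-column layout: image/prompt inputs and advanced options on the left, result slider and logs on the right.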
    
    with gr.Row():
        with gr.Column():
            image = gr.Image()
            image_url = gr.Textbox(label="Image URL", placeholder="Enter image URL here (optional)")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button("Run")
            
            with gr.Accordion("Advanced options", open=True):
                
                num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=30, step=1)
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                control_strength = gr.Slider(label="Control Strength", minimum=0.1, maximum=4.0, value=0.8, step=0.1)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                n_prompt = gr.Textbox(
                    label="Negative prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                )

                upload_to_s3 = gr.Checkbox(label="Upload to R2", value=False)
                account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
                access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here", type="password")
                secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here", type="password")
                bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
        

        with gr.Column():
            images = ImageSlider(label="Generated images", type="pil", slider_color="pink")
            logs = gr.Textbox(label="Logs")
            
    inputs = [
        image,
        image_url,
        prompt,
        n_prompt,
        num_steps,
        guidance_scale,
        control_strength,
        seed,
        upload_to_s3,
        account_id,
        access_key,
        secret_key,
        bucket
    ]
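    # Clicking Run first (optionally) randomizes the seed, then generates with the updated seed value.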
    run_button.click(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=process,
            inputs=inputs,
            outputs=[images, logs],
            api_name=False
        )

demo.queue().launch()