import gradio as gr
import spaces
import torch
import yaml
import os
import numpy as np
from PIL import Image
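
# The cdim package (imported below) supplies the measurement operators, noise models,
# eta scheduler, and the guided DDIM sampling loop used for restoration.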
from cdim.noise import get_noise
from cdim.operators import get_operator
from cdim.image_utils import save_to_image
from cdim.dps_model.dps_unet import create_model
from cdim.diffusion.scheduling_ddim import DDIMScheduler
from cdim.diffusion.diffusion_pipeline import run_diffusion
from cdim.eta_scheduler import EtaScheduler
from diffusers import DiffusionPipeline

# Model state, initialized lazily inside the GPU-decorated function and cached across calls
model = None
ddim_scheduler = None
model_type = None
loaded_model_name = None

def load_image(image_path):
    """Load an image, resize it to 256x256, and convert it to a [-1, 1] float tensor (NCHW)."""
    image = Image.open(image_path)
    original_image = np.array(image.resize((256, 256), Image.BICUBIC))
    original_image = torch.from_numpy(original_image).unsqueeze(0).permute(0, 3, 1, 2)
    # Scale to [-1, 1] and keep only the RGB channels (drops alpha if present).
    return (original_image / 127.5 - 1.0).to(torch.float)[:, :3]

def load_yaml(file_path: str) -> dict:
    with open(file_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config

def convert_to_np(torch_image):
    """Convert a [-1, 1] CHW tensor back to an HWC uint8 numpy image."""
    return ((torch_image.detach().clamp(-1, 1).cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).astype(np.uint8)

@spaces.GPU
def process_image(image_choice, noise_sigma, operator_key, T, K):
    """Combined function to handle both generation and restoration"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Lazily load the diffusion model inside the GPU-decorated function and
    # reload it if the selected image calls for a different checkpoint.
    global model, ddim_scheduler, model_type, loaded_model_name
    model_name = "google/ddpm-celebahq-256" if "Celeb" in image_choice else "google/ddpm-church-256"
    if model is None or loaded_model_name != model_name:
        model_type = "diffusers"
        model = DiffusionPipeline.from_pretrained(model_name).to(device).unet
        loaded_model_name = model_name
        ddim_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.0001,
            beta_end=0.02,
            beta_schedule="linear"
        )

    image_paths = {
        "CelebA HQ 1": "sample_images/celebhq_29999.jpg",
        "CelebA HQ 2": "sample_images/celebhq_00001.jpg",
        "CelebA HQ 3": "sample_images/celebhq_00000.jpg",
        "LSUN Church": "sample_images/lsun_church.png"
    }

    config_paths = {
        "Box Inpainting": "operator_configs/box_inpainting_config.yaml",
        "Random Inpainting": "operator_configs/random_inpainting_config.yaml",
        "Super Resolution": "operator_configs/super_resolution_config.yaml",
        "Gaussian Deblur": "operator_configs/gaussian_blur_config.yaml"
    }

    # Generate noisy image
    image_path = image_paths[image_choice]
    original_image = load_image(image_path).to(device)
    
    noise_config = load_yaml("noise_configs/gaussian_noise_config.yaml")
    noise_config["sigma"] = noise_sigma
    noise_function = get_noise(**noise_config)
    
    operator_config = load_yaml(config_paths[operator_key])
    operator_config["device"] = device
    operator = get_operator(**operator_config)
    
    # Simulate the degraded measurement: apply the forward operator, then add noise.
    noisy_measurement = noise_function(operator(original_image))
    noisy_image = Image.fromarray(convert_to_np(noisy_measurement[0]))

    # Run restoration
    eta_scheduler = EtaScheduler("gradnorm", operator.name, T, K, 'l2', noise_function, None)
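    # Run the guided DDIM sampling loop conditioned on the noisy measurement.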
    output_image = run_diffusion(
        model, ddim_scheduler, noisy_measurement, operator, noise_function, device,
        eta_scheduler, num_inference_steps=T, K=K, model_type=model_type, loss_type='l2'
    )
    
    output_image = Image.fromarray(convert_to_np(output_image[0]))
    return noisy_image, output_image

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Noisy Image Restoration with Diffusion Models")
    
    with gr.Row():
        T = gr.Slider(10, 200, value=50, step=1, label="Number of Inference Steps (T)")
        K = gr.Slider(1, 10, value=3, step=1, label="K Value")
        noise_sigma = gr.Slider(0, 0.6, value=0.05, step=0.01, label="Noise Sigma")
    
    image_select = gr.Dropdown(
        choices=["CelebA HQ 1", "CelebA HQ 2", "CelebA HQ 3", "LSUN Church"],
        value="CelebA HQ 1",
        label="Select Input Image"
    )
    
    operator_select = gr.Dropdown(
        choices=["Box Inpainting", "Random Inpainting", "Super Resolution", "Gaussian Deblur"],
        value="Box Inpainting",
        label="Select Task"
    )
    
    run_button = gr.Button("Run Inference")
    noisy_image = gr.Image(label="Noisy Image")
    restored_image = gr.Image(label="Restored Image")

    # One callback runs the full simulate-and-restore pipeline and fills both image panes
    run_button.click(
        fn=process_image,
        inputs=[image_select, noise_sigma, operator_select, T, K],
        outputs=[noisy_image, restored_image]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)