# cdim / app.py — author: vivjay30, commit c5833bf ("Update app.py", verified)
import gradio as gr
import spaces
import torch
import yaml
import os
import numpy as np
from PIL import Image
from cdim.noise import get_noise
from cdim.operators import get_operator
from cdim.image_utils import save_to_image
from cdim.dps_model.dps_unet import create_model
from cdim.diffusion.scheduling_ddim import DDIMScheduler
from cdim.diffusion.diffusion_pipeline import run_diffusion
from cdim.eta_scheduler import EtaScheduler
from diffusers import DiffusionPipeline
# Lazily initialized module-level state. Everything starts as None and is
# populated (via ``global``) inside the @spaces.GPU-decorated process_image the
# first time it runs, and again whenever the selected checkpoint changes.
model = None  # UNet taken from the loaded diffusers pipeline
ddim_scheduler = None  # DDIMScheduler passed to run_diffusion alongside ``model``
model_type = None  # model family string forwarded to run_diffusion ("diffusers")
curr_model_name = None  # Hugging Face model id that ``model`` was loaded from
def load_image(image_path, size=256):
    """Load an image file and convert it to a normalized batch tensor.

    The image is converted to RGB, resized to ``size`` x ``size`` with bicubic
    resampling, and rescaled from [0, 255] to [-1, 1].

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        size: Side length of the square output in pixels. Defaults to 256.

    Returns:
        A float32 tensor of shape (1, 3, size, size) with values in [-1, 1].
    """
    # Context manager so the file handle is closed promptly rather than
    # whenever the Image object happens to be garbage collected.
    with Image.open(image_path) as image:
        # Converting to RGB makes grayscale ("L") and palette ("P") inputs yield
        # a 3-channel array (otherwise np.array is 2-D and permute below fails),
        # and discards any alpha channel as the previous [:, :3] slice did.
        rgb = image.convert("RGB").resize((size, size), Image.BICUBIC)
    pixels = torch.from_numpy(np.array(rgb)).unsqueeze(0).permute(0, 3, 1, 2)
    return (pixels / 127.5 - 1.0).to(torch.float)
def load_yaml(file_path: str) -> dict:
    """Parse a YAML configuration file.

    Args:
        file_path: Path to the YAML file.

    Returns:
        The parsed document. An empty file yields ``{}`` instead of ``None`` so
        callers can index or update the result directly.
    """
    # Explicit UTF-8 so parsing does not depend on the platform locale encoding.
    with open(file_path, encoding="utf-8") as f:
        # FullLoader is kept for compatibility with the existing configs; the
        # callers in this app pass fixed, bundled config paths, not user input.
        config = yaml.load(f, Loader=yaml.FullLoader)
    return {} if config is None else config
def convert_to_np(torch_image):
    """Convert a CHW image tensor in [-1, 1] to an HWC uint8 NumPy array.

    Values outside [-1, 1] are clamped. The tensor is detached and moved to the
    CPU, so tensors that require grad or live on a GPU are accepted.

    Args:
        torch_image: Tensor indexed as (channels, height, width).

    Returns:
        ``np.ndarray`` of dtype uint8 and shape (height, width, channels), with
        -1 mapped to 0 and 1 mapped to 255.
    """
    scaled = (torch_image.detach().clamp(-1, 1) + 1) * 127.5
    # Round to the nearest level instead of truncating: float error after a
    # [0, 255] -> [-1, 1] -> [0, 255] round trip (e.g. 127.99998) must map back
    # to 128, not 127.
    pixels = scaled.round().to(torch.uint8)
    return pixels.permute(1, 2, 0).cpu().numpy()
# Sample images selectable in the UI, keyed by their dropdown label.
_IMAGE_PATHS = {
    "CelebA HQ 1": "sample_images/celebhq_29999.jpg",
    "CelebA HQ 2": "sample_images/celebhq_00001.jpg",
    "CelebA HQ 3": "sample_images/celebhq_00000.jpg",
    "LSUN Church": "sample_images/lsun_church.png",
}

# Operator (degradation) config files, keyed by their dropdown label.
_OPERATOR_CONFIG_PATHS = {
    "Box Inpainting": "operator_configs/box_inpainting_config.yaml",
    "Random Inpainting": "operator_configs/random_inpainting_config.yaml",
    "Super Resolution": "operator_configs/super_resolution_config.yaml",
    "Gaussian Deblur": "operator_configs/gaussian_blur_config.yaml",
}


@spaces.GPU
def process_image(image_choice, noise_sigma, operator_key, T, K):
    """Degrade a sample image with the selected operator, then restore it.

    The diffusion model is cached in module-level globals and reloaded only
    when none is loaded yet or the requested checkpoint differs from the cached
    one. The CelebA-HQ checkpoint is used for labels containing "Celeb", the
    LSUN church checkpoint otherwise.

    Args:
        image_choice: Label of the input image (a key of ``_IMAGE_PATHS``).
        noise_sigma: Value written into the Gaussian noise config's ``sigma``.
        operator_key: Label of the task (a key of ``_OPERATOR_CONFIG_PATHS``).
        T: Number of inference steps for the diffusion sampler.
        K: ``K`` value forwarded to ``EtaScheduler`` and ``run_diffusion``.

    Returns:
        Tuple ``(noisy_image, output_image)`` of PIL images: the noisy
        measurement and the restored result.

    Raises:
        KeyError: If ``image_choice`` or ``operator_key`` is not a known label.
    """
    global model, curr_model_name, ddim_scheduler, model_type
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Gradio sliders may deliver floats (e.g. 25.0); step counts must be ints.
    T = int(T)
    K = int(K)

    # Initialize the model inside the GPU-decorated function and cache it
    # across calls; reload only when the requested checkpoint changes.
    model_name = "google/ddpm-celebahq-256" if "Celeb" in image_choice else "google/ddpm-church-256"
    if model is None or curr_model_name != model_name:
        # Drop the reference to the previous UNet so it can be freed before the
        # new checkpoint is loaded, instead of both being held at once.
        model = None
        model_type = "diffusers"
        model = DiffusionPipeline.from_pretrained(model_name).to(device).unet
        curr_model_name = model_name
        ddim_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.0001,
            beta_end=0.02,
            beta_schedule="linear",
        )

    # Build the noisy measurement: noise applied to the operator's output.
    original_image = load_image(_IMAGE_PATHS[image_choice]).to(device)

    noise_config = load_yaml("noise_configs/gaussian_noise_config.yaml")
    noise_config["sigma"] = noise_sigma
    noise_function = get_noise(**noise_config)

    operator_config = load_yaml(_OPERATOR_CONFIG_PATHS[operator_key])
    operator_config["device"] = device
    operator = get_operator(**operator_config)

    noisy_measurement = noise_function(operator(original_image))
    noisy_image = Image.fromarray(convert_to_np(noisy_measurement[0]))

    # Run restoration.
    eta_scheduler = EtaScheduler("gradnorm", operator.name, T, K, 'l2', noise_function, None)
    output_image = run_diffusion(
        model, ddim_scheduler, noisy_measurement, operator, noise_function, device,
        eta_scheduler, num_inference_steps=T, K=K, model_type=model_type, loss_type='l2'
    )
    output_image = Image.fromarray(convert_to_np(output_image[0]))
    return noisy_image, output_image
# Gradio UI: sampler/noise sliders, image and task pickers, a run button and two
# output panels, all wired to process_image.
with gr.Blocks() as demo:
    gr.Markdown("# Noisy Image Restoration with Diffusion Models")

    # Numeric controls shown side by side.
    with gr.Row():
        T = gr.Slider(minimum=10, maximum=200, value=25, step=1, label="Number of Inference Steps (T)")
        K = gr.Slider(minimum=1, maximum=10, value=2, step=1, label="K Value")
        noise_sigma = gr.Slider(minimum=0, maximum=0.6, value=0.05, step=0.01, label="Noise Sigma")

    image_select = gr.Dropdown(
        label="Select Input Image",
        choices=["CelebA HQ 1", "CelebA HQ 2", "CelebA HQ 3", "LSUN Church"],
        value="CelebA HQ 1",
    )
    operator_select = gr.Dropdown(
        label="Select Task",
        choices=["Box Inpainting", "Random Inpainting", "Super Resolution", "Gaussian Deblur"],
        value="Box Inpainting",
    )
    run_button = gr.Button("Run Inference")

    noisy_image = gr.Image(label="Noisy Image")
    restored_image = gr.Image(label="Restored Image")

    # A single handler produces both the degraded input and the restoration.
    run_button.click(
        process_image,
        inputs=[image_select, noise_sigma, operator_select, T, K],
        outputs=[noisy_image, restored_image],
    )


if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)