VIVEK JAYARAM commited on
Commit
1f460ce
·
1 Parent(s): d8bc485

Gradio demo

Browse files
Files changed (2) hide show
  1. gradio_demo.py +155 -0
  2. requirements.txt +1 -0
gradio_demo.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import yaml
4
+ import os
5
+ import numpy as np
6
+ from PIL import Image
7
+ import time
8
+ from cdim.noise import get_noise
9
+ from cdim.operators import get_operator
10
+ from cdim.image_utils import save_to_image
11
+ from cdim.dps_model.dps_unet import create_model
12
+ from cdim.diffusion.scheduling_ddim import DDIMScheduler
13
+ from cdim.diffusion.diffusion_pipeline import run_diffusion
14
+ from cdim.eta_scheduler import EtaScheduler
15
+ from diffusers import DiffusionPipeline
16
+
17
+
18
+ # Global variables for model and scheduler
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ model = None
21
+ ddim_scheduler = None
22
+ model_type = None
23
+
24
+
25
+ def load_image(image_path):
26
+ """Process input image to tensor format."""
27
+ image = Image.open(image_path)
28
+ original_image = np.array(image.resize((256, 256), Image.BICUBIC))
29
+ original_image = torch.from_numpy(original_image).unsqueeze(0).permute(0, 3, 1, 2)
30
+ return (original_image / 127.5 - 1.0).to(torch.float)[:, :3]
31
+
32
+
33
+ def load_yaml(file_path: str) -> dict:
34
+ """Load configurations from a YAML file."""
35
+ with open(file_path) as f:
36
+ config = yaml.load(f, Loader=yaml.FullLoader)
37
+ return config
38
+
39
+
40
+ def convert_to_np(torch_image):
41
+ return ((torch_image.detach().clamp(-1, 1).cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).astype(np.uint8)
42
+
43
+
44
+ def generate_noisy_image(image_choice, noise_sigma, operator_key):
45
+ """Generate the noisy image and store necessary data for restoration."""
46
+ # Map image choice to path
47
+ image_paths = {
48
+ "CelebA HQ 1": "sample_images/celebhq_29999.jpg",
49
+ "CelebA HQ 2": "sample_images/celebhq_00001.jpg",
50
+ "CelebA HQ 3": "sample_images/celebhq_00000.jpg"
51
+ }
52
+
53
+ config_paths = {
54
+ "Box Inpainting": "operator_configs/box_inpainting_config.yaml",
55
+ "Random Inpainting": "operator_configs/random_inpainting_config.yaml",
56
+ "Super Resolution": "operator_configs/super_resolution_config.yaml",
57
+ "Gaussian Deblur": "operator_configs/gaussian_blur_config.yaml"
58
+ }
59
+
60
+ image_path = image_paths[image_choice]
61
+
62
+ # Load image and get noisy version
63
+ original_image = load_image(image_path).to(device)
64
+ noise_config = load_yaml("noise_configs/gaussian_noise_config.yaml")
65
+ noise_config["sigma"] = noise_sigma
66
+ noise_function = get_noise(**noise_config)
67
+ operator_config = load_yaml(config_paths[operator_key])
68
+ operator_config["device"] = device
69
+ operator = get_operator(**operator_config)
70
+
71
+ noisy_measurement = noise_function(operator(original_image))
72
+ noisy_image = Image.fromarray(convert_to_np(noisy_measurement[0]))
73
+
74
+ # Store necessary data for restoration
75
+ data = {
76
+ 'noisy_measurement': noisy_measurement,
77
+ 'operator': operator,
78
+ 'noise_function': noise_function
79
+ }
80
+
81
+ return noisy_image, data # Return the noisy image and data for restoration
82
+
83
+
84
+ def run_restoration(data, T, K):
85
+ """Run the restoration process and return the restored image."""
86
+ global model, ddim_scheduler, model_type
87
+
88
+ # Extract stored data
89
+ noisy_measurement = data['noisy_measurement']
90
+ operator = data['operator']
91
+ noise_function = data['noise_function']
92
+
93
+ # Initialize model if not already done
94
+ if model is None:
95
+ model_type = "diffusers"
96
+ model = DiffusionPipeline.from_pretrained("google/ddpm-celebahq-256").to("cuda").unet
97
+
98
+ ddim_scheduler = DDIMScheduler(
99
+ num_train_timesteps=1000,
100
+ beta_start=0.0001,
101
+ beta_end=0.02,
102
+ beta_schedule="linear"
103
+ )
104
+
105
+ # Run restoration
106
+ eta_scheduler = EtaScheduler("gradnorm", operator.name, T, K, 'l2', noise_function, None)
107
+ output_image = run_diffusion(
108
+ model, ddim_scheduler, noisy_measurement, operator, noise_function, device,
109
+ eta_scheduler, num_inference_steps=T, K=K, model_type=model_type, loss_type='l2'
110
+ )
111
+
112
+ # Convert output image for display
113
+ output_image = Image.fromarray(convert_to_np(output_image[0]))
114
+ return output_image
115
+
116
+
117
+ with gr.Blocks() as demo:
118
+ gr.Markdown("# Noisy Image Restoration with Diffusion Models")
119
+
120
+ with gr.Row():
121
+ T = gr.Slider(10, 200, value=50, step=1, label="Number of Inference Steps (T)")
122
+ K = gr.Slider(1, 10, value=3, step=1, label="K Value")
123
+ noise_sigma = gr.Slider(0, 0.6, value=0.05, step=0.01, label="Noise Sigma")
124
+
125
+ image_select = gr.Dropdown(
126
+ choices=["CelebA HQ 1", "CelebA HQ 2", "CelebA HQ 3"],
127
+ value="CelebA HQ 1",
128
+ label="Select Input Image"
129
+ )
130
+
131
+ operator_select = gr.Dropdown(
132
+ choices=["Box Inpainting", "Random Inpainting", "Super Resolution", "Gaussian Deblur"],
133
+ value="Box Inpainting",
134
+ label="Select Task"
135
+ )
136
+
137
+ run_button = gr.Button("Run Inference")
138
+ noisy_image = gr.Image(label="Noisy Image")
139
+ restored_image = gr.Image(label="Restored Image")
140
+ state = gr.State() # To store intermediate data
141
+
142
+ # First function generates the noisy image and stores data
143
+ run_button.click(
144
+ fn=generate_noisy_image,
145
+ inputs=[image_select, noise_sigma, operator_select],
146
+ outputs=[noisy_image, state],
147
+ ).then(
148
+ fn=run_restoration,
149
+ inputs=[state, T, K],
150
+ outputs=restored_image
151
+ )
152
+
153
+
154
+ if __name__ == "__main__":
155
+ demo.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt CHANGED
@@ -4,3 +4,4 @@ Pillow==11.0.0
4
  PyYAML==6.0.2
5
  scipy==1.14.1
6
  tqdm==4.66.5
 
 
4
  PyYAML==6.0.2
5
  scipy==1.14.1
6
  tqdm==4.66.5
7
+ graio==5.3.0