VIVEK JAYARAM commited on
Commit
235a140
·
1 Parent(s): 22a317f
Files changed (1) hide show
  1. app.py +28 -63
app.py CHANGED
@@ -5,7 +5,6 @@ import yaml
5
  import os
6
  import numpy as np
7
  from PIL import Image
8
- import time
9
  from cdim.noise import get_noise
10
  from cdim.operators import get_operator
11
  from cdim.image_utils import save_to_image
@@ -15,13 +14,11 @@ from cdim.diffusion.diffusion_pipeline import run_diffusion
15
  from cdim.eta_scheduler import EtaScheduler
16
  from diffusers import DiffusionPipeline
17
 
18
-
19
- # Global variables for model and scheduler
20
  model = None
21
  ddim_scheduler = None
22
  model_type = None
23
 
24
-
25
  def load_image(image_path):
26
  """Process input image to tensor format."""
27
  image = Image.open(image_path)
@@ -29,23 +26,30 @@ def load_image(image_path):
29
  original_image = torch.from_numpy(original_image).unsqueeze(0).permute(0, 3, 1, 2)
30
  return (original_image / 127.5 - 1.0).to(torch.float)[:, :3]
31
 
32
-
33
  def load_yaml(file_path: str) -> dict:
34
- """Load configurations from a YAML file."""
35
  with open(file_path) as f:
36
  config = yaml.load(f, Loader=yaml.FullLoader)
37
  return config
38
 
39
-
40
  def convert_to_np(torch_image):
41
  return ((torch_image.detach().clamp(-1, 1).cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).astype(np.uint8)
42
 
43
-
44
  @spaces.GPU
45
- def generate_noisy_image(image_choice, noise_sigma, operator_key):
46
- """Generate the noisy image and store necessary data for restoration."""
47
- # Map image choice to path
48
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  image_paths = {
51
  "CelebA HQ 1": "sample_images/celebhq_29999.jpg",
@@ -55,70 +59,37 @@ def generate_noisy_image(image_choice, noise_sigma, operator_key):
55
 
56
  config_paths = {
57
  "Box Inpainting": "operator_configs/box_inpainting_config.yaml",
58
- "Random Inpainting": "operator_configs/random_inpainting_config.yaml",
59
  "Super Resolution": "operator_configs/super_resolution_config.yaml",
60
  "Gaussian Deblur": "operator_configs/gaussian_blur_config.yaml"
61
  }
62
 
 
63
  image_path = image_paths[image_choice]
64
-
65
- # Load image and get noisy version
66
  original_image = load_image(image_path).to(device)
 
67
  noise_config = load_yaml("noise_configs/gaussian_noise_config.yaml")
68
  noise_config["sigma"] = noise_sigma
69
  noise_function = get_noise(**noise_config)
 
70
  operator_config = load_yaml(config_paths[operator_key])
71
  operator_config["device"] = device
72
  operator = get_operator(**operator_config)
73
-
74
  noisy_measurement = noise_function(operator(original_image))
75
  noisy_image = Image.fromarray(convert_to_np(noisy_measurement[0]))
76
 
77
- # Store necessary data for restoration
78
- data = {
79
- 'noisy_measurement': noisy_measurement.cpu(),
80
- 'operator': operator,
81
- 'noise_function': noise_function
82
- }
83
-
84
- return noisy_image, data # Return the noisy image and data for restoration
85
-
86
- @spaces.GPU
87
- def run_restoration(data, T, K):
88
- """Run the restoration process and return the restored image."""
89
- global model, ddim_scheduler, model_type
90
-
91
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
-
93
- # Extract stored data
94
- noisy_measurement = data['noisy_measurement'].to(device)
95
- operator = data['operator']
96
- noise_function = data['noise_function']
97
-
98
- # Initialize model if not already done
99
- if model is None:
100
- model_type = "diffusers"
101
- model = DiffusionPipeline.from_pretrained("google/ddpm-celebahq-256").to(device).unet
102
-
103
- ddim_scheduler = DDIMScheduler(
104
- num_train_timesteps=1000,
105
- beta_start=0.0001,
106
- beta_end=0.02,
107
- beta_schedule="linear"
108
- )
109
-
110
  # Run restoration
111
  eta_scheduler = EtaScheduler("gradnorm", operator.name, T, K, 'l2', noise_function, None)
112
  output_image = run_diffusion(
113
  model, ddim_scheduler, noisy_measurement, operator, noise_function, device,
114
  eta_scheduler, num_inference_steps=T, K=K, model_type=model_type, loss_type='l2'
115
  )
116
-
117
- # Convert output image for display
118
  output_image = Image.fromarray(convert_to_np(output_image[0]))
119
- return output_image
120
-
121
 
 
122
  with gr.Blocks() as demo:
123
  gr.Markdown("# Noisy Image Restoration with Diffusion Models")
124
 
@@ -142,19 +113,13 @@ with gr.Blocks() as demo:
142
  run_button = gr.Button("Run Inference")
143
  noisy_image = gr.Image(label="Noisy Image")
144
  restored_image = gr.Image(label="Restored Image")
145
- state = gr.State() # To store intermediate data
146
 
147
- # First function generates the noisy image and stores data
148
  run_button.click(
149
- fn=generate_noisy_image,
150
- inputs=[image_select, noise_sigma, operator_select],
151
- outputs=[noisy_image, state],
152
- ).then(
153
- fn=run_restoration,
154
- inputs=[state, T, K],
155
- outputs=restored_image
156
  )
157
 
158
-
159
  if __name__ == "__main__":
160
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
5
  import os
6
  import numpy as np
7
  from PIL import Image
 
8
  from cdim.noise import get_noise
9
  from cdim.operators import get_operator
10
  from cdim.image_utils import save_to_image
 
14
  from cdim.eta_scheduler import EtaScheduler
15
  from diffusers import DiffusionPipeline
16
 
17
+ # Global variables moved inside GPU-decorated functions
 
18
  model = None
19
  ddim_scheduler = None
20
  model_type = None
21
 
 
22
  def load_image(image_path):
23
  """Process input image to tensor format."""
24
  image = Image.open(image_path)
 
26
  original_image = torch.from_numpy(original_image).unsqueeze(0).permute(0, 3, 1, 2)
27
  return (original_image / 127.5 - 1.0).to(torch.float)[:, :3]
28
 
 
29
  def load_yaml(file_path: str) -> dict:
 
30
  with open(file_path) as f:
31
  config = yaml.load(f, Loader=yaml.FullLoader)
32
  return config
33
 
 
34
  def convert_to_np(torch_image):
35
  return ((torch_image.detach().clamp(-1, 1).cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).astype(np.uint8)
36
 
 
37
  @spaces.GPU
38
+ def process_image(image_choice, noise_sigma, operator_key, T, K):
39
+ """Combined function to handle both generation and restoration"""
 
40
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+
42
+ # Initialize model inside GPU-decorated function
43
+ global model, ddim_scheduler, model_type
44
+ if model is None:
45
+ model_type = "diffusers"
46
+ model = DiffusionPipeline.from_pretrained("google/ddpm-celebahq-256").to(device).unet
47
+ ddim_scheduler = DDIMScheduler(
48
+ num_train_timesteps=1000,
49
+ beta_start=0.0001,
50
+ beta_end=0.02,
51
+ beta_schedule="linear"
52
+ )
53
 
54
  image_paths = {
55
  "CelebA HQ 1": "sample_images/celebhq_29999.jpg",
 
59
 
60
  config_paths = {
61
  "Box Inpainting": "operator_configs/box_inpainting_config.yaml",
62
+ "Random Inpainting": "operator_configs/random_inpainting_config.yaml",
63
  "Super Resolution": "operator_configs/super_resolution_config.yaml",
64
  "Gaussian Deblur": "operator_configs/gaussian_blur_config.yaml"
65
  }
66
 
67
+ # Generate noisy image
68
  image_path = image_paths[image_choice]
 
 
69
  original_image = load_image(image_path).to(device)
70
+
71
  noise_config = load_yaml("noise_configs/gaussian_noise_config.yaml")
72
  noise_config["sigma"] = noise_sigma
73
  noise_function = get_noise(**noise_config)
74
+
75
  operator_config = load_yaml(config_paths[operator_key])
76
  operator_config["device"] = device
77
  operator = get_operator(**operator_config)
78
+
79
  noisy_measurement = noise_function(operator(original_image))
80
  noisy_image = Image.fromarray(convert_to_np(noisy_measurement[0]))
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  # Run restoration
83
  eta_scheduler = EtaScheduler("gradnorm", operator.name, T, K, 'l2', noise_function, None)
84
  output_image = run_diffusion(
85
  model, ddim_scheduler, noisy_measurement, operator, noise_function, device,
86
  eta_scheduler, num_inference_steps=T, K=K, model_type=model_type, loss_type='l2'
87
  )
88
+
 
89
  output_image = Image.fromarray(convert_to_np(output_image[0]))
90
+ return noisy_image, output_image
 
91
 
92
+ # Gradio interface
93
  with gr.Blocks() as demo:
94
  gr.Markdown("# Noisy Image Restoration with Diffusion Models")
95
 
 
113
  run_button = gr.Button("Run Inference")
114
  noisy_image = gr.Image(label="Noisy Image")
115
  restored_image = gr.Image(label="Restored Image")
 
116
 
117
+ # Single function call instead of chaining
118
  run_button.click(
119
+ fn=process_image,
120
+ inputs=[image_select, noise_sigma, operator_select, T, K],
121
+ outputs=[noisy_image, restored_image]
 
 
 
 
122
  )
123
 
 
124
  if __name__ == "__main__":
125
+ demo.launch(server_name="0.0.0.0", server_port=7860)