vivjay30 committed on
Commit 89d5874 · 1 Parent(s): e1abd0f

Chi squared method

README.md CHANGED
@@ -1,11 +1,66 @@
- ---
- title: CDIM
- emoji: 😃
- colorFrom: purple
- colorTo: blue
- sdk: gradio
- sdk_version: 5.1.0
- app_file: app.py
- pinned: true
- arxiv: https://arxiv.org/abs/2411.00359
- ---
+ # Linearly Constrained Diffusion Implicit Models
+ ![Teaser](Teaser.jpg)
+
+ ### Authors
+ [Vivek Jayaram](http://www.vivekjayaram.com/), [John Thickstun](https://johnthickstun.com/), [Ira Kemelmacher-Shlizerman](https://homes.cs.washington.edu/~kemelmi/), and [Steve Seitz](https://homes.cs.washington.edu/~seitz/)
+
+ ### Links
+ [[Gradio Demo]](https://huggingface.co/spaces/vivjay30/cdim) [[Project Page]](https://grail.cs.washington.edu/projects/cdim/) [[Paper]](https://arxiv.org/abs/2411.00359)
+
+ ### Summary
+ We solve noisy linear inverse problems with diffusion models. The method is fast and handles many tasks, including inpainting, super-resolution, Gaussian deblurring, and Poisson noise.
+
+
+ ## Getting started
+
+ Recommended environment: Python 3.11, CUDA 12, Conda. For lower versions, please adjust the dependencies below.
+
+ ### 1) Clone the repository
+
+ ```
+ git clone https://github.com/vivjay30/cdim
+
+ cd cdim
+ ```
+
+ ### 2) Install dependencies
+
+ ```
+ conda create -n cdim python=3.11
+
+ conda activate cdim
+
+ pip install -r requirements.txt
+
+ pip install torch==2.4.1+cu124 torchvision==0.19.1+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
+ ```
+
+ ## Inference Examples
+
+ (The underlying diffusion models are downloaded automatically on the first run.)
+
+ #### CelebA-HQ Inpainting Example (T'=25 Denoising Steps)
+
+ `python inference.py sample_images/celebhq/00001.jpg 25 operator_configs/box_inpainting_config.yaml noise_configs/gaussian_noise_config.yaml google/ddpm-celebahq-256`
+
+ #### LSUN Churches Gaussian Deblur Example (T'=25 Denoising Steps)
+ `python inference.py sample_images/lsun_church.png 25 operator_configs/gaussian_blur_config.yaml noise_configs/gaussian_noise_config.yaml google/ddpm-church-256`
+
+
+ ## FFHQ and ImageNet Models
+ These models are generally not as strong as the Google DDPM models, but they are used for comparisons with baseline methods.
+
+ From [this link](https://drive.google.com/drive/folders/1jElnRoFv7b31fG0v6pTSQkelbSX3xGZh?usp=sharing), download the checkpoints "ffhq_10m.pt" and "imagenet_256.pt" into models/
+
+ #### ImageNet Super-Resolution Example
+ Here we set T'=50 to show the algorithm running more slowly.
+ `python inference.py sample_images/imagenet_val_00002.png 50 operator_configs/super_resolution_config.yaml noise_configs/gaussian_noise_config.yaml models/imagenet_model_config.yaml`
+
+ #### FFHQ Random Inpainting (Faster)
+ Here we set T'=10 to show the algorithm running faster.
+ `python inference.py sample_images/ffhq_00010.png 10 operator_configs/random_inpainting_config.yaml noise_configs/gaussian_noise_config.yaml models/ffhq_model_config.yaml`
+
+ #### A Note on Exact Recovery
+ If you set the measurement noise to 0 in gaussian_noise_config.yaml, the recovered image should match the observation y exactly (e.g. inpainting shouldn't change observed pixels). In practice this doesn't happen, because the diffusion schedule sets $\overline{\alpha}_0 = 0.999$ for numeric stability, meaning a tiny amount of noise is injected even at t=0.
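For reference, a minimal sketch of what `noise_configs/gaussian_noise_config.yaml` plausibly contains, inferred from the `get_noise(**noise_config)` call and the `noise_function.name` / `noise_function.sigma` attributes used elsewhere in this commit (the exact field names are assumptions, not confirmed here):

```yaml
# Sketch only: keys inferred from get_noise(**noise_config) usage
name: gaussian   # registered noise type; noise_function.name == "gaussian"
sigma: 0.05      # measurement noise std; set to 0.0 to test exact recovery
```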
app.py CHANGED
@@ -2,24 +2,21 @@ import gradio as gr
  import spaces
  import torch
  import yaml
- import os
  import numpy as np
  from PIL import Image
  from cdim.noise import get_noise
  from cdim.operators import get_operator
- from cdim.image_utils import save_to_image
- from cdim.dps_model.dps_unet import create_model
  from cdim.diffusion.scheduling_ddim import DDIMScheduler
  from cdim.diffusion.diffusion_pipeline import run_diffusion
- from cdim.eta_scheduler import EtaScheduler
  from diffusers import DiffusionPipeline

- # Global variables moved inside GPU-decorated functions
+ # Global variables for model and scheduler (initialized inside GPU-decorated function)
  model = None
  ddim_scheduler = None
  model_type = None
  curr_model_name = None

+
  def load_image(image_path):
      """Process input image to tensor format."""
      image = Image.open(image_path)
@@ -27,22 +24,26 @@ def load_image(image_path):
      original_image = torch.from_numpy(original_image).unsqueeze(0).permute(0, 3, 1, 2)
      return (original_image / 127.5 - 1.0).to(torch.float)[:, :3]

+
  def load_yaml(file_path: str) -> dict:
+     """Load configurations from a YAML file."""
      with open(file_path) as f:
          config = yaml.load(f, Loader=yaml.FullLoader)
      return config

+
  def convert_to_np(torch_image):
      return ((torch_image.detach().clamp(-1, 1).cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).astype(np.uint8)

+
  @spaces.GPU
- def process_image(image_choice, noise_sigma, operator_key, T, K):
-     """Combined function to handle both generation and restoration"""
+ def process_image(image_choice, noise_sigma, operator_key, T, stopping_sigma):
+     """Combined function to handle both generation and restoration."""
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

      # Initialize model inside GPU-decorated function
      global model, curr_model_name, ddim_scheduler, model_type
-     model_name = "google/ddpm-celebahq-256" if "Celeb" in image_choice else "google/ddpm-church-256"
+     model_name = "google/ddpm-celebahq-256" if "CelebA" in image_choice else "google/ddpm-church-256"

      if model is None or curr_model_name != model_name:
          model_type = "diffusers"
@@ -85,22 +86,22 @@ def process_image(image_choice, noise_sigma, operator_key, T, K):
      noisy_image = Image.fromarray(convert_to_np(noisy_measurement[0]))

      # Run restoration
-     eta_scheduler = EtaScheduler("gradnorm", operator.name, T, K, 'l2', noise_function, None)
      output_image = run_diffusion(
          model, ddim_scheduler, noisy_measurement, operator, noise_function, device,
-         eta_scheduler, num_inference_steps=T, K=K, model_type=model_type, loss_type='l2'
+         stopping_sigma, num_inference_steps=T, model_type=model_type
      )

      output_image = Image.fromarray(convert_to_np(output_image[0]))
      return noisy_image, output_image

+
  # Gradio interface
  with gr.Blocks() as demo:
      gr.Markdown("# Noisy Image Restoration with Diffusion Models")

      with gr.Row():
-         T = gr.Slider(10, 200, value=25, step=1, label="Number of Inference Steps (T)")
-         K = gr.Slider(1, 10, value=2, step=1, label="K Value")
+         T = gr.Slider(10, 200, value=50, step=1, label="Number of Inference Steps (T)")
+         stopping_sigma = gr.Slider(0.1, 5.0, value=0.1, step=0.1, label="Stopping Sigma (c)")
          noise_sigma = gr.Slider(0, 0.6, value=0.05, step=0.01, label="Noise Sigma")

      image_select = gr.Dropdown(
@@ -119,12 +120,11 @@
      noisy_image = gr.Image(label="Noisy Image")
      restored_image = gr.Image(label="Restored Image")

-     # Single function call instead of chaining
      run_button.click(
          fn=process_image,
-         inputs=[image_select, noise_sigma, operator_select, T, K],
+         inputs=[image_select, noise_sigma, operator_select, T, stopping_sigma],
          outputs=[noisy_image, restored_image]
      )

  if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=7860)
+     demo.launch(server_name="0.0.0.0", server_port=7860)
cdim/diffusion/diffusion_pipeline.py CHANGED
@@ -1,8 +1,10 @@
  import torch
  from tqdm import tqdm

- from cdim.image_utils import randn_tensor
+ from cdim.image_utils import randn_tensor, trace_AAt, estimate_variance, save_to_image, compute_operator_distance
  from cdim.discrete_kl_loss import discrete_kl_loss
+ from cdim.eta_utils import calculate_best_step_size, initial_guess_step_size
+

  def compute_kl_gaussian(residuals, sigma):
      # Only 0 centered for now
@@ -23,13 +25,13 @@ def run_diffusion(
      operator,
      noise_function,
      device,
-     eta_scheduler,
+     stopping_sigma,
      num_inference_steps: int = 1000,
-     K=5,
+     K=20,
      image_dim=256,
      image_channels=3,
      model_type="diffusers",
-     loss_type="l2"
+     original_image=None
  ):
      batch_size = noisy_observation.shape[0]
      image_shape = (batch_size, image_channels, image_dim, image_dim)
@@ -38,18 +40,60 @@
      scheduler.set_timesteps(num_inference_steps, device=device)
      t_skip = scheduler.timesteps[0] - scheduler.timesteps[1]

+     data = []
+     TOTAL_UPDATE_STEPS = 0
+     trace = trace_AAt(operator)
      for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps), desc="Processing timesteps"):
+         # Use the ground-truth image, noised up, if you want to debug anything
+         # image = original_image * scheduler.alphas_cumprod[t] ** 0.5 + torch.randn_like(original_image) * (1 - scheduler.alphas_cumprod[t]) ** 0.5
+
          # 1. predict noise model_output
          model_output = model(image, t.unsqueeze(0).to(device))
          model_output = model_output.sample if model_type == "diffusers" else model_output[:, :3]

+         # Save image for debugging
+         # save_to_image(image, f"intermediates/{t}_xt.png")
+
          # 2. compute previous image: x_t -> x_t-1
          image = scheduler.step(model_output, t, image).prev_sample
          image.requires_grad_()
          alpha_prod_t_prev = scheduler.alphas_cumprod[t-t_skip] if t-t_skip >= 0 else 1
          beta_prod_t_prev = 1 - alpha_prod_t_prev
-         for j in range(K):
+
+         k = 0
+         while k < K:
              if t <= 0: break
+             a = scheduler.alphas_cumprod[t-t_skip]**0.5 - 1
+             # For inpainting, use the number of observed pixels
+             num_elements = operator.get_num_observed() if hasattr(operator, 'get_num_observed') else noisy_observation.numel()
+
+             # mu_{t-delta}(y), Eq. 14
+             target_distance = (a**2 * torch.linalg.norm(noisy_observation)**2 + (1 - scheduler.alphas_cumprod[t-t_skip]) * trace).item()
+             target_distance += num_elements * noise_function.sigma**2 * (1 - a**2)
+
+             # ||Ax_{t-delta} - y||^2
+             actual_distance = compute_operator_distance(operator, image, noisy_observation, squared=True).item()
+
+             # sigma^2_{t-delta}(y), Eq. 15
+             variance = estimate_variance(
+                 operator,
+                 noisy_observation,
+                 scheduler.alphas_cumprod[t-t_skip],
+                 image.shape,
+                 trace=trace,
+                 sigma_y=noise_function.sigma,
+                 n_trace_samples=64,
+                 n_y_samples=64,
+                 device=image.device)
+
+             # c * sigma_{t-delta}(y)
+             threshold = stopping_sigma * variance**0.5
+             # print(f"Target Distance mean {target_distance} max {target_distance + threshold} actual distance {actual_distance}")
+
+             # Stop once R_{t-delta} is within rho_{t-delta}, Eq. 16
+             if actual_distance <= target_distance + threshold:
+                 break

              with torch.enable_grad():
                  # Calculate x^hat_0
@@ -57,42 +101,52 @@
                  model_output = model_output.sample if model_type == "diffusers" else model_output[:, :3]
                  x_0 = (image - beta_prod_t_prev ** (0.5) * model_output) / alpha_prod_t_prev ** (0.5)

-                 if loss_type == "l2" and noise_function.name == "gaussian":
-                     distance = operator(x_0) - noisy_observation
-                     if (distance ** 2).mean() < noise_function.sigma ** 2:
-                         break
-                     loss = ((distance) ** 2).mean()
-                     print(f"L2 loss {loss}")
-                     loss.backward()
-
-                 elif loss_type == "kl" and noise_function.name == "gaussian":
-                     diff = (operator(x_0) - noisy_observation)  # Residuals
-                     kl_div = compute_kl_gaussian(diff, noise_function.sigma)
-                     kl_div.backward()
-
-                 elif loss_type == "kl" and noise_function.name == "poisson":
-                     residuals = (operator(x_0) * noise_function.rate - noisy_observation * noise_function.rate) * 127.5  # Residuals
-                     x_0_pixel = operator((x_0 + 1) * 127.5)
-                     mask = x_0_pixel > 2  # Avoid numeric issues with pixel values near 0
-                     pearson = residuals[mask] / torch.sqrt(x_0_pixel[mask] * noise_function.rate)
-                     pearson_flat = pearson.view(-1)
-                     kl_div = compute_kl_gaussian(pearson_flat, 1.0)
-                     kl_div.backward()
-
-                 elif loss_type == "categorical_kl" and noise_function.name == "bimodal":
-                     diff = (operator(x_0) - noisy_observation)
-                     indices = operator(torch.ones(image.shape).to(device))
-                     diff = diff[indices > 0]  # Don't consider masked out pixels in the distribution
-                     empirical_distribution = noise_function.sample_noise_distribution(image).to(device).view(-1)
-                     loss = discrete_kl_loss(diff, empirical_distribution, num_bins=15)
-                     print(f"Categorical KL {loss}")
-                     loss.backward()
-
-                 else:
-                     raise ValueError(f"Unsupported combination: loss {loss_type} noise {noise_function.name}")
-
-             step_size = eta_scheduler.get_step_size(str(t.item()), torch.linalg.norm(image.grad))
+                 # Save Tweedie's estimate for debugging
+                 # save_to_image(x_0, f"intermediates/{t}_x0.png")
+
+                 loss = compute_operator_distance(operator, x_0, noisy_observation, squared=True).mean()
+
+                 # print(f"L2 loss {compute_operator_distance(operator, x_0, noisy_observation, squared=False)}")
+                 data.append((t.item(), compute_operator_distance(operator, image, noisy_observation, squared=False).item()))
+                 loss.backward()
+
+             initial_step_size = initial_guess_step_size(t.item(), torch.linalg.norm(image.grad))
+             with torch.no_grad():
+                 # Set debug=True to see detailed step size search information
+                 step_size = calculate_best_step_size(image, noisy_observation, operator, image.grad, target_distance, threshold, initial_step_size, debug=False)
+
+             # print(f"Step Size {step_size:.6e} initial guess {initial_step_size:.6e}")
+
              image -= step_size * image.grad
+             new_distance = compute_operator_distance(operator, image, noisy_observation, squared=True).item()
+             # print(f"New distance {new_distance}")
              image = image.detach().requires_grad_()
+             TOTAL_UPDATE_STEPS += 1
+
+             if step_size <= 1e-12: break
+
+             k += 1
+
+             # Check here because the threshold is stochastic and can change from iteration to iteration
+             if new_distance <= target_distance + threshold:
+                 break
+
+         # print("Step", t.item())
+         # Use num_elements for proper normalization with inpainting
+         num_elements = operator.get_num_observed() if hasattr(operator, 'get_num_observed') else noisy_observation.numel()
+         # print("Distance", 1 / num_elements * compute_operator_distance(operator, image, noisy_observation, squared=True).item())
+
+         # Print MAE if you want to track constraint error
+         if hasattr(operator, 'select'):
+             # Compute MAE over observed pixels only
+             Ax = operator.select(image).flatten()
+             y_selected = operator.select(noisy_observation).flatten()
+             # print("MAE", (torch.abs(Ax - y_selected).mean().item()))
+         else:
+             pass
+             # print("MAE", (torch.abs(operator(image) - noisy_observation).mean().item()))

+     print(f"Total Denoising {len(scheduler.timesteps)}")
+     print(f"Total Projection Steps {TOTAL_UPDATE_STEPS}")
+     print(f"Total NFEs {TOTAL_UPDATE_STEPS + len(scheduler.timesteps)}")
      return image
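Reading the new stopping rule off the committed code above (a transcription of the code, not of the paper's Eqs. 14-16, which this commit does not reproduce), with $a = \sqrt{\bar\alpha_{t-\delta}} - 1$, $m$ the number of observed measurements, and $c$ the `stopping_sigma` value, the target squared residual is

$$\mu_{t-\delta}(y) = a^2\,\lVert y\rVert^2 + (1-\bar\alpha_{t-\delta})\,\operatorname{tr}(AA^\top) + m\,\sigma_y^2\,(1-a^2),$$

and the projection loop stops once

$$\lVert A x_{t-\delta} - y\rVert^2 \le \mu_{t-\delta}(y) + c\,\sigma_{t-\delta}(y),$$

where $\sigma^2_{t-\delta}(y)$ is the Monte-Carlo estimate returned by `estimate_variance`. This is the chi-squared-style test the commit message refers to: the squared residual of a properly noised sample concentrates around $\mu_{t-\delta}(y)$, so optimization continues only while the residual is implausibly large.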
cdim/eta_scheduler.py DELETED
@@ -1,61 +0,0 @@
- import json
-
- class EtaScheduler:
-     def __init__(self, method, task, T, K, loss_type,
-                  noise_function, lambda_val=None):
-         self.task = task
-         self.T = T
-         self.K = K
-         self.loss_type = loss_type
-         self.lambda_val = lambda_val
-         self.method = method
-
-         self.precomputed_etas = self._load_precomputed_etas()
-
-         # Couldn't find expected gradnorm
-         if not self.precomputed_etas and method == "expected_gradnorm":
-             self.method = "gradnorm"
-             print("Etas for this configuration not found. Switching to gradnorm.")
-
-         # Precomputed gradients are only for gaussian noise
-         if noise_function.name != "gaussian" and method == "expected_gradnorm":
-             self.method = "gradnorm"
-             print("Precomputed gradients are only for gaussian noise. Switching to gradnorm.")
-
-         # Get the best lambda_val if it's not passed
-         if self.lambda_val is None:
-             if self.method == "expected_gradnorm":
-                 self.lambda_val = self.precomputed_etas["lambda"]
-             else:
-                 self.lambda_val = self.best_guess_lambda()
-             print(f"Using lambda {self.lambda_val}")
-
-     def _load_precomputed_etas(self):
-         steps_key = f"T{self.T}_K{self.K}"
-         with open("cdim/etas.json") as f:
-             all_etas = json.load(f)
-
-         return all_etas.get(self.task, {}).get(self.loss_type, {}).get(steps_key, {})
-
-     def get_step_size(self, t, grad_norm):
-         """Use either precomputed expected gradnorm or gradnorm."""
-         if self.method == "expected_gradnorm":
-             step_size = self.lambda_val * 1 / self.precomputed_etas["etas"][t]
-         else:
-             step_size = self.lambda_val * 1 / grad_norm
-         return step_size
-
-     def best_guess_lambda(self):
-         """Guess a lambda value if not provided. Based on trial and error."""
-         total_steps = self.T * self.K
-
-         # L2 tends to over-optimize too aggressively, so the default lr is lower
-         if self.loss_type == "kl":
-             return 350 / total_steps
-         elif self.loss_type == "l2":
-             return 220 / total_steps
-         else:
-             raise ValueError(f"Please provide learning rate for loss type {self.loss_type}")
cdim/eta_utils.py ADDED
@@ -0,0 +1,369 @@
+ import json
+ import torch
+
+ class EtaScheduler:
+     def __init__(self, method, task, T, K, loss_type,
+                  noise_function, lambda_val=None):
+         self.task = task
+         self.T = T
+         self.K = K
+         self.loss_type = loss_type
+         self.lambda_val = lambda_val
+         self.method = method
+
+         self.precomputed_etas = self._load_precomputed_etas()
+
+         # Couldn't find expected gradnorm
+         if not self.precomputed_etas and method == "expected_gradnorm":
+             self.method = "gradnorm"
+             print("Etas for this configuration not found. Switching to gradnorm.")
+
+         # Precomputed gradients are only for gaussian noise
+         if noise_function.name != "gaussian" and method == "expected_gradnorm":
+             self.method = "gradnorm"
+             print("Precomputed gradients are only for gaussian noise. Switching to gradnorm.")
+
+         # Get the best lambda_val if it's not passed
+         if self.lambda_val is None:
+             if self.method == "expected_gradnorm":
+                 self.lambda_val = self.precomputed_etas["lambda"]
+             else:
+                 self.lambda_val = self.best_guess_lambda()
+             print(f"Using lambda {self.lambda_val}")
+
+     def _load_precomputed_etas(self):
+         steps_key = f"T{self.T}_K{self.K}"
+         with open("cdim/etas.json") as f:
+             all_etas = json.load(f)
+
+         return all_etas.get(self.task, {}).get(self.loss_type, {}).get(steps_key, {})
+
+     def get_step_size(self, t, grad_norm):
+         """Use either precomputed expected gradnorm or gradnorm."""
+         if self.method == "expected_gradnorm":
+             step_size = self.lambda_val * 1 / self.precomputed_etas["etas"][t]
+         else:
+             step_size = self.lambda_val * 1 / grad_norm
+         return step_size
+
+     def best_guess_lambda(self):
+         """Guess a lambda value if not provided. Based on trial and error."""
+         total_steps = self.T * self.K
+
+         # L2 tends to over-optimize too aggressively, so the default lr is lower
+         if self.loss_type == "kl":
+             return 350 / total_steps
+         elif self.loss_type == "l2":
+             return 220 / total_steps
+         else:
+             raise ValueError(f"Please provide learning rate for loss type {self.loss_type}")
+
+
+ def initial_guess_step_size(T, grad_norm):
+     best_guess_lambda = 220 / T
+     return best_guess_lambda / grad_norm
+
+
+ # def calculate_best_step_size(image, y, operator, gradient, target_distance, initial_guess,
+ #                              max_iters=20, tol=1e-4, bracket_factor=1.4):
+ #     def compute_distance(eta):
+ #         x_new = image - eta * gradient
+ #         diff = operator(x_new) - y
+ #         return torch.linalg.norm(diff)**2
+
+ #     def objective(eta):
+ #         return torch.abs(compute_distance(eta) - target_distance)
+
+ #     # Try to bracket the root
+ #     eta_low = initial_guess / bracket_factor
+ #     eta_high = initial_guess * bracket_factor
+
+ #     for _ in range(10):
+ #         import pdb
+ #         pdb.set_trace()
+ #         dist_low = compute_distance(eta_low)
+ #         dist_high = compute_distance(eta_high)
+ #         if (dist_low - target_distance) * (dist_high - target_distance) < 0:
+ #             break
+ #         eta_low /= bracket_factor
+ #         eta_high *= bracket_factor
+ #     else:
+ #         # Fallback: brute-force line search over eta to minimize distance
+ #         best_eta = None
+ #         best_val = float('inf')
+ #         for eta in torch.linspace(0, initial_guess * 5, steps=100, device=image.device):
+ #             val = objective(eta)
+ #             # print(f"ETA {eta} distance {compute_distance(eta)}")
+ #             if val < best_val:
+ #                 best_val = val
+ #                 best_eta = eta
+ #         return best_eta.item()
+
+ #     # Binary search
+ #     for _ in range(max_iters):
+ #         eta_mid = (eta_low + eta_high) / 2
+ #         dist_mid = compute_distance(eta_mid)
+ #         error = dist_mid - target_distance
+
+ #         if abs(error) < tol:
+ #             return eta_mid
+
+ #         if (compute_distance(eta_low) - target_distance) * error < 0:
+ #             eta_high = eta_mid
+ #         else:
+ #             eta_low = eta_mid
+
+ #     return eta_mid
+
+
+ import torch
+ from cdim.image_utils import compute_operator_distance
+
+ def calculate_best_step_size(
+     image: torch.Tensor,
+     y: torch.Tensor,
+     operator,
+     gradient: torch.Tensor,
+     target_distance: float,
+     threshold: float,
+     initial_guess: float,
+     *,
+     tol: float = 1e-4,
+     max_iters: int = 50,
+     debug: bool = False,
+ ):
+     """
+     Find the smallest η ≥ 0 that makes ||A(x − η g) − y||² ≈ target_distance + threshold.
+
+     Uses a robust grid search followed by golden section search for fine-grained optimization.
+
+     Note: For inpainting operators with a 'select' method, distances are computed
+     over only the observed pixels.
+
+     Args:
+         debug: If True, prints detailed search information
+     """
+     target_boundary = target_distance + threshold
+
+     def distance(η: torch.Tensor) -> torch.Tensor:
+         return compute_operator_distance(operator, image - η * gradient, y, squared=True)
+
+     def error(η):
+         return distance(η) - target_boundary
+
+     # Phase 1: Coarse grid search to find promising regions
+     # Search from very small to larger step sizes
+     # Allow step sizes larger than 1 if needed
+     max_eta = initial_guess * 200.0 if initial_guess > 0 else 10.0
+
+     # Create a logarithmically-spaced grid for better coverage of small values
+     # This ensures we search finely near 0 and coarser at larger values
+     # Start from 1e-12 to handle cases with very large gradients
+     n_coarse = 100  # Increased for better resolution
+     eta_grid = torch.cat([
+         torch.tensor([0.0]),
+         torch.logspace(-12, torch.log10(torch.tensor(max_eta)), n_coarse - 1)
+     ]).to(image.device)
+
+     if debug:
+         print(f"[Step Size] Searching from {eta_grid[1]:.2e} to {eta_grid[-1]:.2e} ({len(eta_grid)} points)")
+         print(f"[Step Size] Sample grid points: {[f'{x:.2e}' for x in eta_grid[1:11].tolist()]}")
+
+     # Evaluate distances at all grid points
+     distances = torch.tensor([distance(eta).item() for eta in eta_grid])
+     errors = distances - target_boundary
+
+     dist_at_zero = distances[0].item()
+
+     # Strategy: Find the SMALLEST eta that gets us AT OR BELOW target_boundary
+     # Only consider non-zero etas
+     below_target_mask = distances[1:] <= target_boundary
+
+     if below_target_mask.any():
+         # Found etas that reach target - pick the SMALLEST one (most conservative)
+         below_indices = torch.where(below_target_mask)[0] + 1  # +1 because we excluded index 0
+         best_idx = below_indices[0].item()  # Smallest eta that reaches target
+         best_eta = eta_grid[best_idx].item()
+         best_distance = distances[best_idx].item()
+
+         if debug:
+             print(f"[Step Size] Coarse grid: found eta={best_eta:.2e} that reaches target")
+             print(f"[Step Size] Distance: {best_distance:.2f} (target: {target_boundary:.2f}, under by {target_boundary - best_distance:.2f})")
+     else:
+         # No eta reaches target - find the one that gets closest (minimize distance to target)
+         non_zero_distances = distances[1:]
+         closest_idx = torch.argmin(torch.abs(non_zero_distances - target_boundary)) + 1
+         best_idx = closest_idx
+         best_eta = eta_grid[best_idx].item()
+         best_distance = distances[best_idx].item()
+
+         if debug:
+             print(f"[Step Size] Coarse grid: cannot reach target, best eta={best_eta:.2e}")
+             print(f"[Step Size] Distance: {best_distance:.2f} (target: {target_boundary:.2f}, over by {best_distance - target_boundary:.2f})")
+
+     # Check if eta=0 is better (already at target)
+     if dist_at_zero <= target_boundary:
+         if debug:
+             print(f"[Step Size] Distance at eta=0: {dist_at_zero:.2f} - already at/below target")
+         return 0.0
+
+     if debug:
+         print(f"[Step Size] Distance at eta=0: {dist_at_zero:.2f} (need to step)")
+
+     # Phase 1.5: Fine search around the best point found
+     # If best_eta is not at the boundaries, do a fine search around it
+     if best_idx > 0 and best_idx < len(eta_grid) - 1:
+         eta_low_bound = eta_grid[best_idx - 1].item()
+         eta_high_bound = eta_grid[best_idx + 1].item()
+
+         # Create a very fine linear grid between the neighboring points
+         fine_grid = torch.linspace(eta_low_bound, eta_high_bound, 50).to(image.device)
+         fine_distances = torch.tensor([distance(eta).item() for eta in fine_grid])
+
+         # Find the SMALLEST eta in the fine grid that gets us AT OR BELOW target
+         fine_below_mask = fine_distances <= target_boundary
+
+         if fine_below_mask.any():
+             # Found fine etas that reach target - pick the SMALLEST
+             fine_below_indices = torch.where(fine_below_mask)[0]
+             fine_best_idx = fine_below_indices[0].item()
+             fine_best_eta = fine_grid[fine_best_idx].item()
+             fine_best_distance = fine_distances[fine_best_idx].item()
+
+             # Only update if this is better (smaller eta that still reaches target, or gets closer)
+             if fine_best_distance <= target_boundary and (best_distance > target_boundary or fine_best_eta < best_eta):
+                 best_eta = fine_best_eta
+                 best_distance = fine_best_distance
+                 best_idx = len(eta_grid) + fine_best_idx
+
+                 if debug:
+                     print(f"[Step Size] Fine grid: improved to eta={best_eta:.2e}, distance={best_distance:.2f} (under by {target_boundary - best_distance:.2f})")
+         else:
+             # No fine eta reaches target - find closest
+             fine_best_idx = torch.argmin(torch.abs(fine_distances - target_boundary))
+             fine_best_eta = fine_grid[fine_best_idx].item()
+             fine_best_distance = fine_distances[fine_best_idx].item()
+
+             # Only update if closer to target than current best
+             if abs(fine_best_distance - target_boundary) < abs(best_distance - target_boundary):
+                 best_eta = fine_best_eta
+                 best_distance = fine_best_distance
+                 best_idx = len(eta_grid) + fine_best_idx
+
+                 if debug:
+                     print(f"[Step Size] Fine grid: improved to eta={best_eta:.2e}, distance={best_distance:.2f} (over by {best_distance - target_boundary:.2f})")
+
+         # Always update the grid for potential bracketing
+         distances = torch.cat([distances, fine_distances])
+         errors = torch.cat([errors, fine_distances - target_boundary])
+         eta_grid = torch.cat([eta_grid, fine_grid])
+
+     # If best_eta is 0 and we're already at or below target, return 0
+     if best_eta == 0.0 and dist_at_zero <= target_boundary:
+         if debug:
+             print(f"[Step Size] Already at target, no step needed")
+         return 0.0
+
+     # If we've reached target, we can return (no need for golden section)
+     if best_distance <= target_boundary:
+         if debug:
+             print(f"[Step Size] Reached target with eta={best_eta:.2e}, returning")
+         return best_eta
+
+     # Phase 2: Check for bracketing around the best point
+     # Look for a sign change (crossing the target boundary)
+     bracket_found = False
+     eta_lo, eta_hi = None, None
+
+     # Check neighbors of best point
+     for i in range(len(eta_grid) - 1):
+         if errors[i] * errors[i + 1] < 0:  # Sign change
+             eta_lo, eta_hi = eta_grid[i].item(), eta_grid[i + 1].item()
+             bracket_found = True
+             break
+
+     # Phase 3: Refine using golden section search
+     # Only refine if we haven't reached target yet and have a valid bracket
+     if best_eta > 0 and best_distance > target_boundary and best_idx > 0 and best_idx < len(eta_grid) - 1:
+         # Golden section search to find the smallest eta that reaches target_boundary
+         phi = (1 + 5**0.5) / 2  # Golden ratio
+         resphi = 2 - phi
+
+         # Search in a small window around best_eta
+         a = eta_grid[max(0, best_idx - 1)].item()
+         b = eta_grid[min(len(eta_grid) - 1, best_idx + 1)].item()
+
+         # Make sure we have a valid interval
+         if b - a < 1e-20:
+             if debug:
+                 print(f"[Step Size] Interval too small for refinement, returning eta={best_eta:.2e}")
+             return best_eta
+
+         dist_a = distance(torch.tensor(a)).item()
+         dist_b = distance(torch.tensor(b)).item()
+
+         for _ in range(max_iters):
+             if abs(b - a) < 1e-20:  # Extremely tight tolerance
+                 break
+
+             # Golden section points
+             x1 = a + resphi * (b - a)
+             x2 = b - resphi * (b - a)
+
+             dist_x1 = distance(torch.tensor(x1)).item()
+             dist_x2 = distance(torch.tensor(x2)).item()
+
+             # Priority: prefer points that reach target (dist <= target_boundary)
+             # Among those, prefer smaller eta
+             # If neither reaches, prefer closer to target
+
+             x1_reaches = dist_x1 <= target_boundary
+             x2_reaches = dist_x2 <= target_boundary
+
+             if x1_reaches and not x2_reaches:
+                 # x1 reaches target, x2 doesn't -> prefer x1's half
+                 b = x2
+                 dist_b = dist_x2
+                 if x1 < best_eta or not (best_distance <= target_boundary):
+                     best_eta = x1
+                     best_distance = dist_x1
+             elif x2_reaches and not x1_reaches:
+                 # x2 reaches target, x1 doesn't -> prefer x2's half
+                 a = x1
+                 dist_a = dist_x1
+                 if x2 < best_eta or not (best_distance <= target_boundary):
+                     best_eta = x2
+                     best_distance = dist_x2
+             elif x1_reaches and x2_reaches:
+                 # Both reach target -> prefer smaller eta (which is x1)
+                 b = x2
+                 dist_b = dist_x2
+                 best_eta = x1
+                 best_distance = dist_x1
+             else:
+                 # Neither reaches target -> prefer closer to target
+                 if abs(dist_x1 - target_boundary) < abs(dist_x2 - target_boundary):
+                     b = x2
+                     dist_b = dist_x2
+                     if abs(dist_x1 - target_boundary) < abs(best_distance - target_boundary):
+                         best_eta = x1
+                         best_distance = dist_x1
+                 else:
+                     a = x1
+                     dist_a = dist_x1
+                     if abs(dist_x2 - target_boundary) < abs(best_distance - target_boundary):
+                         best_eta = x2
+                         best_distance = dist_x2
+
+         if debug:
+             if best_distance <= target_boundary:
+                 print(f"[Step Size] Final: eta={best_eta:.2e}, distance={best_distance:.2f} (under by {target_boundary - best_distance:.2f})")
+             else:
+                 print(f"[Step Size] Final: eta={best_eta:.2e}, distance={best_distance:.2f} (over by {best_distance - target_boundary:.2f})")
+     else:
+         if debug:
+             print(f"[Step Size] No refinement needed, returning best: eta={best_eta:.2e}")
+
+     return best_eta
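The search strategy above is easier to see on a toy 1-D problem. The sketch below mimics only the coarse log-grid phase of `calculate_best_step_size`; the operator, tensors, and boundary value are made up for illustration:

```python
import torch

# Toy stand-ins: A scales by 2; we want the smallest eta with
# ||A(x - eta*g) - y||^2 <= target_boundary, as in the committed search.
x, g, y = torch.tensor([3.0]), torch.tensor([1.0]), torch.tensor([2.0])
A = lambda v: 2.0 * v
target_boundary = 0.5  # plays the role of target_distance + threshold

def distance(eta):
    return ((A(x - eta * g) - y) ** 2).sum()

# Phase-1-style grid: eta = 0 plus a log-spaced sweep
etas = torch.cat([torch.tensor([0.0]), torch.logspace(-12, 1, 99)])
dists = torch.stack([distance(e) for e in etas])

below = (dists <= target_boundary).nonzero()
if len(below):                      # smallest eta that reaches the boundary
    eta = etas[below[0, 0]].item()
else:                               # otherwise the closest miss
    eta = etas[(dists - target_boundary).abs().argmin()].item()
print(f"eta={eta:.3f}, distance={distance(torch.tensor(eta)).item():.3f}")
```

Picking the smallest step that reaches the boundary, rather than the step that minimizes the residual outright, is what keeps the update conservative: it moves the iterate just inside the plausible chi-squared band instead of over-fitting the measurement.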
cdim/image_utils.py CHANGED
@@ -1,7 +1,9 @@
  from typing import List, Optional, Tuple, Union

  import torch
+ from torch import Tensor
  from torchvision.transforms import ToPILImage
+ from typing import Callable

  def save_to_image(tensor, filename):
      """
@@ -66,3 +68,209 @@ def randn_tensor(
      latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

      return latents
+
+
+ @torch.no_grad()
+ def estimate_variance(
+     operator: Callable[[Tensor], Tensor],
+     y: Tensor,                        # Ax_0 + noise (shape (m,))
+     alphabar_t: float,
+     in_shape: tuple[int, ...],        # e.g. (1, 3, 256, 256)
+     trace: float,                     # tr(AA^T)
+     sigma_y: float,
+     n_trace_samples: int = 64,
+     n_y_samples: int = 64,
+     device: torch.device | str = "cuda",
+     dtype: torch.dtype = torch.float32,
+ ) -> float:
+     """
+     Monte-Carlo estimator of Var(||A x_t - y||^2) without access to A^T.
+
+     Note: For inpainting operators with a 'select' method, this computes variance
+     over only the observed pixels.
+     """
+     use_select = hasattr(operator, 'select')
+
+     # For inpainting, select only observed pixels from y
+     if use_select:
+         y_selected = operator.select(y).flatten()
+         y = y_selected.to(device=device, dtype=dtype)
+     else:
+         y = y.to(device=device, dtype=dtype).flatten()
+     m = y.numel()
+
+     # ---------------- tr((AA^T)^2)
+     t1_acc = torch.zeros((), device=device, dtype=dtype)
+     for _ in range(n_trace_samples):
+         v = torch.randn(in_shape, device=device, dtype=dtype)
+         w = torch.randn(in_shape, device=device, dtype=dtype)
+         if use_select:
+             Av = operator.select(v).flatten()
+             Aw = operator.select(w).flatten()
+         else:
+             Av = operator(v).flatten()
+             Aw = operator(w).flatten()
+         s = torch.dot(Av, Aw)
+         t1_acc += s * s
+     T1 = t1_acc / n_trace_samples
+
+     # ---------------- y^T AA^T y
+     t2_acc = torch.zeros((), device=device, dtype=dtype)
+     for _ in range(n_y_samples):
+         v = torch.randn(in_shape, device=device, dtype=dtype)
+         if use_select:
+             Av = operator.select(v).flatten()
+         else:
+             Av = operator(v).flatten()
+         s = torch.dot(y, Av)
+         t2_acc += s * s
+     T2 = t2_acc / n_y_samples
+
+     # ---------------- assemble variance
+     alpha_bar = torch.as_tensor(alphabar_t, dtype=dtype, device=device)
+     sigma2 = torch.as_tensor(sigma_y**2, dtype=dtype, device=device)
+
+     a2 = (torch.sqrt(alpha_bar) - 1.0).pow(2)  # (1-√ᾱ)^2
+     b = 1.0 - alpha_bar                        # (1-ᾱ)
+
+     var = 2 * (b*b * T1 + 2 * b * sigma2 * trace + m * sigma2.pow(2)) \
+         + 4 * a2 * (b * (T2 - sigma2 * trace) + sigma2 * (y.pow(2).sum() - m * sigma2))
+     return var.item()
+
+
+ def compute_operator_distance(
+     operator: Callable[[Tensor], Tensor],
+     x: Tensor,
+     y: Tensor,
+     squared: bool = True
+ ) -> Tensor:
+     """
+     Compute ||Ax - y||^2 (or ||Ax - y|| if squared=False).
+
+     For inpainting operators with a 'select' method, this computes the distance
+     over only the observed pixels. Otherwise uses the standard operator call.
+
+     Args:
+         operator: The forward operator A
+         x: Input tensor (e.g., image)
+         y: Measurement tensor (for inpainting, this should be the full masked measurement)
+         squared: If True, returns squared L2 norm. If False, returns L2 norm.
+
+     Returns:
+         Scalar tensor representing the distance
+     """
+     if hasattr(operator, 'select'):
+         # Use select method for inpainting operators
+         # Both x and y need to be selected to extract only observed pixels
+         Ax = operator.select(x).flatten()
+         y_selected = operator.select(y).flatten()
+     else:
+         # Standard operator application
+         Ax = operator(x).flatten()
+         y_selected = y.flatten()
+
+     diff = Ax - y_selected
+     if squared:
+         return (diff ** 2).sum()
+     else:
+         return torch.sqrt((diff ** 2).sum())
+
+
+ def trace_AAt(
+     operator: Callable[[torch.Tensor], torch.Tensor],
+     input_shape = (1, 3, 256, 256),
+     num_samples: int = 256,
+     device: str = "cuda"  # or "cpu"
+ ) -> float:
+     """
+     Unbiased Monte-Carlo estimate of tr(A Aᵀ) for a black-box linear operator.
+
+     operator    : function that maps a (1,C,H,W) tensor → down-sampled tensor
+     input_shape : shape expected by the operator
+     num_samples : more samples → lower variance (error ≈ O(1/√num_samples))
+
+     Note: For inpainting operators with a 'select' method, this computes the trace
+     over only the observed pixels, not the full tensor with zeros.
+     """
+     total = 0.0
+     use_select = hasattr(operator, 'select')
+
+     for _ in range(num_samples):
+         # Rademacher noise (±1). Use torch.randn for Gaussian instead.
+         z = torch.empty(input_shape, device=device).bernoulli_().mul_(2).sub_(1)
+         if use_select:
+             Az = operator.select(z).flatten()  # only observed pixels
+         else:
+             Az = operator(z).flatten()         # output can have any shape
+         total += torch.dot(Az, Az).item()      # ||Az||²
+     return total / num_samples
+
+
+ # def trace_AAt_squared(
+ #     operator: Callable[[torch.Tensor], torch.Tensor],
+ #     input_shape: tuple = (1, 3, 256, 256),
+ #     num_samples: int = 32,
+ #     device: str = "cuda") -> float:
+ #     """
+ #     Estimates tr((A Aᵀ)^2) using Hutchinson's method and autograd for Aᵀ.
+ #     """
+ #     total = 0.0
+ #     for _ in range(num_samples):
+ #         # Sample z ~ N(0, I) (same shape as operator's *output*)
+ #         z = torch.randn(operator(torch.zeros(input_shape, device=device)).shape, device=device)
+
+ #         # Compute Aᵀz via gradient: ∇_w [⟨operator(w), z⟩] = Aᵀz
+ #         w = torch.randn(input_shape, device=device, requires_grad=True)
+ #         Az = operator(w).flatten()
+ #         loss = torch.dot(Az, z.flatten())  # ⟨Az, z⟩ = ⟨w, Aᵀz⟩
+ #         A_adj_z = torch.autograd.grad(loss, w, retain_graph=False)[0]
+
+ #         # Compute AAᵀz = operator(Aᵀz)
+ #         AA_adj_z = operator(A_adj_z.detach()).flatten()
+ #         total += torch.dot(AA_adj_z, AA_adj_z).item()  # ||AAᵀz||²
+ #     return total / num_samples
+
+
+ # def compute_yAAy(
+ #     operator: Callable[[torch.Tensor], torch.Tensor],
+ #     y: torch.Tensor,
+ #     input_shape: tuple = (1, 3, 256, 256),
+ #     device: str = "cuda") -> float:
+ #     """
+ #     Computes yᵀ (A Aᵀ) y using autograd to get Aᵀy.
+ #     """
+ #     # Compute Aᵀy via gradient: ∇_w [⟨operator(w), y⟩] = Aᵀy
+ #     w = torch.randn(input_shape, device=device, requires_grad=True)
+ #     Az = operator(w).flatten()
+ #     loss = torch.dot(Az, y.flatten())
+ #     A_adj_y = torch.autograd.grad(loss, w, retain_graph=False)[0]
+
+ #     # Compute A Aᵀ y = operator(Aᵀy)
+ #     AA_adj_y = operator(A_adj_y.detach()).flatten()
+ #     return torch.dot(AA_adj_y, y.flatten()).item()
+
+
+ # def variance_Axt_minus_y_sq(
+ #     operator: Callable[[torch.Tensor], torch.Tensor],
+ #     y: torch.Tensor,
+ #     alphabar_t: float,
+ #     input_shape: tuple = (1, 3, 256, 256),
+ #     num_samples_trace: int = 32,
+ #     device: str = "cuda"
+ # ) -> float:
+ #     """
+ #     Computes Var(||A𝐱ₜ - y||²) = 2(1-ᾱₜ)² tr((AAᵀ)²) + 4(1-ᾱₜ)(√ᾱₜ -1)² yᵀAAᵀy.
+ #     """
+ #     # Term 1: 2(1-ᾱₜ)^2 * tr((AAᵀ)^2)
+ #     tr_AAt_sq = trace_AAt_squared(operator, input_shape, num_samples_trace, device)
+ #     term1 = 2 * (1 - alphabar_t)**2 * tr_AAt_sq
+
+ #     # Term 2: 4(1-ᾱₜ)(√ᾱₜ -1)^2 * yᵀAAᵀy
+ #     yAAy = compute_yAAy(operator, y, input_shape, device)
+ #     term2 = 4 * (1 - alphabar_t) * (torch.sqrt(torch.tensor(alphabar_t)) - 1)**2 * yAAy
+
+ #     return term1 + term2
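As a quick sanity check of the Hutchinson-style estimator in `trace_AAt`, here is a standalone sketch with a made-up diagonal operator, where the answer is known in closed form (for A = diag(d), tr(AAᵀ) = Σ dᵢ²):

```python
import torch

# For A = diag(d), tr(A A^T) = sum(d**2) = 14 exactly.
d = torch.tensor([1.0, 2.0, 3.0])
op = lambda v: (d * v.flatten()).reshape(v.shape)

total = 0.0
num_samples = 1000
for _ in range(num_samples):
    # Rademacher probe z (+/-1), same construction as trace_AAt
    z = torch.empty(1, 3).bernoulli_(0.5).mul_(2).sub_(1)
    Az = op(z).flatten()
    total += torch.dot(Az, Az).item()  # ||Az||^2, an unbiased sample of the trace
print(total / num_samples, "vs exact", (d ** 2).sum().item())  # both 14.0 here
```

For a diagonal operator every Rademacher sample gives exactly Σ dᵢ² (since zᵢ² = 1), so the estimate has zero variance in this toy; for general operators the error shrinks like O(1/√num_samples), as the docstring notes.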
cdim/operators/gaussian_blur_operator.py CHANGED
@@ -4,7 +4,7 @@ from cdim.operators.blur_kernel import BlurKernel
 
  @register_operator(name='gaussian_blur')
  class GaussianBlurOperator:
-     def __init__(self, kernel_size, intensity, device='cpu'):
+     def __init__(self, kernel_size, intensity, device='cuda'):
          self.device = device
          self.kernel_size = kernel_size
          self.conv = BlurKernel(blur_type='gaussian',
cdim/operators/random_box_masker.py CHANGED
@@ -54,3 +54,31 @@ class RandomBoxMasker:
 
          # Apply the mask to the input tensor
          return tensor * self.mask
+
+     def select(self, tensor):
+         """
+         Extract only the observed pixels from the tensor (pixels outside the box).
+
+         Args:
+             tensor (torch.Tensor): Input tensor of shape (b, channels, height, width)
+
+         Returns:
+             torch.Tensor: Flattened tensor containing only observed pixels (b, num_observed)
+         """
+         b, c, h, w = tensor.shape
+         assert c == self.channels and h == self.height and w == self.width, \
+             f"Input tensor must be of shape (b, {self.channels}, {self.height}, {self.width})"
+
+         # Move the mask to the same device as the input tensor if necessary
+         if tensor.device != self.mask.device:
+             self.mask = self.mask.to(tensor.device)
+
+         # Extract only observed pixels (where mask is 1, outside the box)
+         observed = (tensor * self.mask).flatten(1)  # (b, c*h*w)
+         # Keep only non-zero elements
+         mask_flat = self.mask.flatten(1)  # (1, c*h*w)
+         return observed[:, mask_flat[0] > 0]  # (b, num_observed)
+
+     def get_num_observed(self):
+         """Return the number of observed elements."""
+         return int(self.mask.sum().item())
cdim/operators/random_pixel_masker.py CHANGED
@@ -56,3 +56,32 @@ class RandomPixelMasker:
 
          # Apply the mask to the input tensor
          return tensor * self.mask
+
+     def select(self, tensor):
+         """
+         Extract only the observed pixels from the tensor.
+
+         Args:
+             tensor (torch.Tensor): Input tensor of shape (b, channels, height, width)
+
+         Returns:
+             torch.Tensor: Flattened tensor containing only observed pixels (b, num_observed)
+         """
+         b, c, h, w = tensor.shape
+         assert c == self.channels and h == self.height and w == self.width, \
+             f"Input tensor must be of shape (b, {self.channels}, {self.height}, {self.width})"
+
+         # Move the mask to the same device as the input tensor if necessary
+         if tensor.device != self.mask.device:
+             self.mask = self.mask.to(tensor.device)
+
+         # Extract only observed pixels (where mask is 1)
+         # mask is (1, c, h, w); we want to select pixels across all channels
+         observed = (tensor * self.mask).flatten(1)  # (b, c*h*w)
+         # Keep only non-zero elements
+         mask_flat = self.mask.flatten(1)  # (1, c*h*w)
+         return observed[:, mask_flat[0] > 0]  # (b, num_observed)
+
+     def get_num_observed(self):
+         """Return the number of observed elements."""
+         return int(self.mask.sum().item())
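The `select`/`get_num_observed` contract both maskers now share can be seen on a tiny hand-built mask (a standalone sketch, not the repo classes):

```python
import torch

# (1, 1, 2, 2) mask: the top-right pixel is masked out (unobserved)
mask = torch.tensor([[[[1., 0.],
                       [1., 1.]]]])
x = torch.tensor([[[[10., 20.],
                    [30., 40.]]]])

# Same logic as the committed select(): flatten, then keep observed entries
observed = (x * mask).flatten(1)[:, mask.flatten(1)[0] > 0]
print(observed)                  # tensor([[10., 30., 40.]])
print(int(mask.sum().item()))    # 3, matching get_num_observed()
```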
inference.py CHANGED
@@ -2,6 +2,7 @@ import argparse
  import os
  import yaml
  import time
+ from pathlib import Path

  from PIL import Image
  import numpy as np
@@ -15,7 +16,6 @@ from cdim.image_utils import save_to_image
  from cdim.dps_model.dps_unet import create_model
  from cdim.diffusion.scheduling_ddim import DDIMScheduler
  from cdim.diffusion.diffusion_pipeline import run_diffusion
- from cdim.eta_scheduler import EtaScheduler


  def load_image(path):
@@ -36,13 +36,39 @@ def load_yaml(file_path: str) -> dict:
      return config


+ def process_image(image_path, output_dir, model, ddim_scheduler, operator, noise_function,
+                   device, args, model_type):
+     """
+     Process a single image with the given model and parameters.
+     """
+     original_image = load_image(image_path).to(device)
+
+     # Get the base filename without extension
+     base_name = Path(image_path).stem
+
+     noisy_measurement = noise_function(operator(original_image))
+     save_to_image(noisy_measurement, os.path.join(output_dir, f"{base_name}_noisy_measurement.png"))
+
+     t0 = time.time()
+     output_image = run_diffusion(
+         model, ddim_scheduler,
+         noisy_measurement, operator, noise_function, device,
+         args.stopping_sigma,
+         num_inference_steps=args.T,
+         K=args.K,
+         model_type=model_type,
+         original_image=original_image)
+     print(f"Processing time for {base_name}: {time.time() - t0:.2f}s")
+
+     save_to_image(output_image, os.path.join(output_dir, f"{base_name}_output.png"))
+
+
  def main(args):
      device_str = f"cuda" if args.cuda and torch.cuda.is_available() else 'cpu'
      print(f"Using device {device_str}")
      device = torch.device(device_str)

      os.makedirs(args.output_dir, exist_ok=True)
-     original_image = load_image(args.input_image).to(device)

      # Load the noise function
      noise_config = load_yaml(args.noise_config)
@@ -78,43 +104,52 @@
      steps_offset=0,
      )

-     noisy_measurement = noise_function(operator(original_image))
-     save_to_image(noisy_measurement, os.path.join(args.output_dir, "noisy_measurement.png"))
-
-     eta_scheduler = EtaScheduler(args.eta_type, operator.name, args.T,
-                                  args.K, args.loss, noise_function, args.lambda_val)
-
-     t0 = time.time()
-     output_image = run_diffusion(
-         model, ddim_scheduler,
-         noisy_measurement, operator, noise_function, device,
-         eta_scheduler,
-         num_inference_steps=args.T,
-         K=args.K,
-         model_type=model_type,
-         loss_type=args.loss)
-     print(f"total time {time.time() - t0}")
-
-     save_to_image(output_image, os.path.join(args.output_dir, "output.png"))
+     # Process input (either a single image or all images in a directory)
+     input_path = Path(args.input)
+
+     if input_path.is_file():
+         # Process a single image
+         print(f"Processing single image: {input_path.name}")
+         process_image(
+             str(input_path), args.output_dir, model, ddim_scheduler,
+             operator, noise_function, device, args, model_type
+         )
+     elif input_path.is_dir():
+         # Process all images in the directory
+         image_files = [
+             f for f in input_path.iterdir()
+             if not f.name.startswith('.') and f.suffix.lower() in ['.png', '.jpg', '.jpeg']
+         ]
+         image_files = sorted(image_files)
+
+         print(f"Found {len(image_files)} images to process")
+
+         for image_file in image_files:
+             print(f"Processing {image_file.name}...")
+             # Recreate the operator for each image (comment out to reuse the same operator, e.g. the same random mask)
+             operator = get_operator(**operator_config)
+             process_image(
+                 str(image_file), args.output_dir, model, ddim_scheduler,
+                 operator, noise_function, device, args, model_type
+             )
+     else:
+         raise ValueError(f"Input path '{input_path}' is neither a file nor a directory")


  if __name__ == '__main__':
      parser = argparse.ArgumentParser()
-     parser.add_argument("input_image", type=str)
+     parser.add_argument("input", type=str, help="Path to input image or folder containing input images")
      parser.add_argument("T", type=int)
-     parser.add_argument("K", type=int)
      parser.add_argument("operator_config", type=str)
      parser.add_argument("noise_config", type=str)
      parser.add_argument("model_config", type=str)
-     parser.add_argument("--eta-type", type=str,
-                         choices=['gradnorm', 'expected_gradnorm'],
-                         default='expected_gradnorm')
+     parser.add_argument("--stopping-sigma", type=float, default=0.1, help="How many std deviations away to stop")
      parser.add_argument("--lambda-val", type=float,
                          default=None, help="Constant to scale learning rate. Leave empty to use a heuristic best guess.")
-     parser.add_argument("--output-dir", default=".", type=str)
-     parser.add_argument("--loss", type=str,
-                         choices=['l2', 'kl', 'categorical_kl'], default='l2',
-                         help="Algorithm to use. Options: 'l2', 'kl', 'categorical_kl'. Default is 'l2'."
-                         )
+     parser.add_argument("--output-dir", default="output", type=str)
      parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction)
+     parser.add_argument("--K", type=int, default=20,
+                         help="Cap the number of steps K at any iteration. Helps avoid edge cases or cap NFEs.")

      main(parser.parse_args())
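With the new `input` argument, the README commands work unchanged for single files, and a folder can be passed the same way. A hedged example of directory mode built from the flags defined above (paths assume the repo's sample layout):

```
python inference.py sample_images/celebhq 25 operator_configs/box_inpainting_config.yaml noise_configs/gaussian_noise_config.yaml google/ddpm-celebahq-256 --stopping-sigma 0.5 --output-dir output
```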
requirements.txt CHANGED
@@ -1,11 +1,8 @@
- accelerate
- diffusers
- gradio
- numpy
+ diffusers==0.30.3
+ gradio==5.3.0
+ numpy==2.1.2
  Pillow
- PyYAML
- scipy
- --extra-index-url https://download.pytorch.org/whl/cu113
- torch
- torchvision
- tqdm
+ PyYAML==6.0.2
+ scipy==1.14.1
+ tqdm==4.66.5
+ accelerate