Chi squared method

- README.md +66 -11
- app.py +15 -15
- cdim/diffusion/diffusion_pipeline.py +94 -40
- cdim/eta_scheduler.py +0 -61
- cdim/eta_utils.py +369 -0
- cdim/image_utils.py +208 -0
- cdim/operators/gaussian_blur_operator.py +1 -1
- cdim/operators/random_box_masker.py +28 -0
- cdim/operators/random_pixel_masker.py +29 -0
- inference.py +64 -29
- requirements.txt +7 -10
README.md
CHANGED
@@ -1,11 +1,66 @@

# Linearly Constrained Diffusion Implicit Models

### Authors
[Vivek Jayaram](http://www.vivekjayaram.com/), [John Thickstun](https://johnthickstun.com/), [Ira Kemelmacher-Shlizerman](https://homes.cs.washington.edu/~kemelmi/), and [Steve Seitz](https://homes.cs.washington.edu/~seitz/)

### Links
[[Gradio Demo]](https://huggingface.co/spaces/vivjay30/cdim) [[Project Page]](https://grail.cs.washington.edu/projects/cdim/) [[Paper]](https://arxiv.org/abs/2411.00359)

### Summary
We solve noisy linear inverse problems with diffusion models. The method is fast and addresses many problems such as inpainting, super-resolution, Gaussian deblur, and Poisson noise.

## Getting started

Recommended environment: Python 3.11, CUDA 12, Conda. For lower versions, please adjust the dependencies below.

### 1) Clone the repository

```
git clone https://github.com/vivjay30/cdim
cd cdim
```

### 2) Install dependencies

```
conda create -n cdim python=3.11
conda activate cdim
pip install -r requirements.txt
pip install torch==2.4.1+cu124 torchvision==0.19.1+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
```

## Inference Examples

(The underlying diffusion models will be downloaded automatically on the first run.)

#### CelebHQ Inpainting Example (T'=25 Denoising Steps)

`python inference.py sample_images/celebhq/00001.jpg 25 operator_configs/box_inpainting_config.yaml noise_configs/gaussian_noise_config.yaml google/ddpm-celebahq-256`

#### LSUN Churches Gaussian Deblur Example (T'=25 Denoising Steps)

`python inference.py sample_images/lsun_church.png 25 operator_configs/gaussian_blur_config.yaml noise_configs/gaussian_noise_config.yaml google/ddpm-church-256`

## FFHQ and Imagenet Models

These models are generally not as strong as the Google DDPM models, but are used for comparisons with baseline methods.

From [this link](https://drive.google.com/drive/folders/1jElnRoFv7b31fG0v6pTSQkelbSX3xGZh?usp=sharing), download the checkpoints "ffhq_10m.pt" and "imagenet_256.pt" to models/

#### Imagenet Super Resolution Example

Here we set T'=50 to show the algorithm running slower.

`python inference.py sample_images/imagenet_val_00002.png 50 operator_configs/super_resolution_config.yaml noise_configs/gaussian_noise_config.yaml models/imagenet_model_config.yaml`

#### FFHQ Random Inpainting (Faster)

Here we set T'=10 to show the algorithm running faster.

`python inference.py sample_images/ffhq_00010.png 10 operator_configs/random_inpainting_config.yaml noise_configs/gaussian_noise_config.yaml models/ffhq_model_config.yaml`

#### A Note on Exact Recovery

If you set the measurement noise to 0 in gaussian_noise_config.yaml, the recovered image should match the observation y exactly (e.g. inpainting doesn't change observed pixels). In practice this doesn't happen exactly, because the diffusion schedule sets $\overline{\alpha}_0 = 0.999$ for numeric stability, meaning a tiny amount of noise is injected even at t=0.
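To put the $\overline{\alpha}_0 = 0.999$ figure in perspective, here is a small back-of-the-envelope sketch (an illustration only, not part of the repository README):

```python
# Residual noise implied by alpha_bar_0 = 0.999 (illustration only).
alpha_bar_0 = 0.999
noise_std = (1 - alpha_bar_0) ** 0.5   # ~0.032 in the [-1, 1] image range
pixel_std = noise_std * 127.5          # ~4 levels on the 0-255 scale
print(f"noise std: {noise_std:.4f} (about {pixel_std:.1f} 8-bit levels)")
```

So even with zero measurement noise, reconstructions can deviate from y by a few 8-bit levels.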
app.py
CHANGED
@@ -2,24 +2,21 @@ import gradio as gr
 import spaces
 import torch
 import yaml
-import os
 import numpy as np
 from PIL import Image
 from cdim.noise import get_noise
 from cdim.operators import get_operator
-from cdim.image_utils import save_to_image
-from cdim.dps_model.dps_unet import create_model
 from cdim.diffusion.scheduling_ddim import DDIMScheduler
 from cdim.diffusion.diffusion_pipeline import run_diffusion
-from cdim.eta_scheduler import EtaScheduler
 from diffusers import DiffusionPipeline

-# Global variables
+# Global variables for model and scheduler (initialized inside GPU-decorated function)
 model = None
 ddim_scheduler = None
 model_type = None
 curr_model_name = None

+
 def load_image(image_path):
     """Process input image to tensor format."""
     image = Image.open(image_path)
@@ -27,22 +24,26 @@ def load_image(image_path):
     original_image = torch.from_numpy(original_image).unsqueeze(0).permute(0, 3, 1, 2)
     return (original_image / 127.5 - 1.0).to(torch.float)[:, :3]

+
 def load_yaml(file_path: str) -> dict:
+    """Load configurations from a YAML file."""
     with open(file_path) as f:
         config = yaml.load(f, Loader=yaml.FullLoader)
     return config

+
 def convert_to_np(torch_image):
     return ((torch_image.detach().clamp(-1, 1).cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).astype(np.uint8)

+
 @spaces.GPU
-def process_image(image_choice, noise_sigma, operator_key, T, K):
-    """Combined function to handle both generation and restoration"""
+def process_image(image_choice, noise_sigma, operator_key, T, stopping_sigma):
+    """Combined function to handle both generation and restoration."""
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     # Initialize model inside GPU-decorated function
     global model, curr_model_name, ddim_scheduler, model_type
+    model_name = "google/ddpm-celebahq-256" if "CelebA" in image_choice else "google/ddpm-church-256"

     if model is None or curr_model_name != model_name:
         model_type = "diffusers"
@@ -85,22 +86,22 @@ def process_image(image_choice, noise_sigma, operator_key, T, K):
     noisy_image = Image.fromarray(convert_to_np(noisy_measurement[0]))

     # Run restoration
-    eta_scheduler = EtaScheduler("gradnorm", operator.name, T, K, 'l2', noise_function, None)
     output_image = run_diffusion(
         model, ddim_scheduler, noisy_measurement, operator, noise_function, device,
+        stopping_sigma, num_inference_steps=T, model_type=model_type
     )

     output_image = Image.fromarray(convert_to_np(output_image[0]))
     return noisy_image, output_image

+
 # Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# Noisy Image Restoration with Diffusion Models")

     with gr.Row():
+        T = gr.Slider(10, 200, value=50, step=1, label="Number of Inference Steps (T)")
+        stopping_sigma = gr.Slider(0.1, 5.0, value=0.1, step=0.1, label="Stopping Sigma (c)")
         noise_sigma = gr.Slider(0, 0.6, value=0.05, step=0.01, label="Noise Sigma")

         image_select = gr.Dropdown(
@@ -119,12 +120,11 @@ with gr.Blocks() as demo:
     noisy_image = gr.Image(label="Noisy Image")
     restored_image = gr.Image(label="Restored Image")

-    # Single function call instead of chaining
     run_button.click(
         fn=process_image,
+        inputs=[image_select, noise_sigma, operator_select, T, stopping_sigma],
         outputs=[noisy_image, restored_image]
     )

 if __name__ == "__main__":
     demo.launch(server_name="0.0.0.0", server_port=7860)
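The demo relies on the same [-1, 1] ↔ uint8 convention as `load_image` / `convert_to_np` above. A self-contained round-trip check (illustration only; the random 8×8 image is a placeholder):

```python
# Illustration of the normalization convention used by load_image/convert_to_np.
import numpy as np
import torch

def convert_to_np(torch_image):
    return ((torch_image.detach().clamp(-1, 1).cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).astype(np.uint8)

rgb = np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8)
tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 127.5 - 1.0  # load_image's normalization
back = convert_to_np(tensor)
print(np.abs(back.astype(int) - rgb.astype(int)).max())  # round-trip error of at most one level
```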
cdim/diffusion/diffusion_pipeline.py
CHANGED
@@ -1,8 +1,10 @@
 import torch
 from tqdm import tqdm

-from cdim.image_utils import randn_tensor
+from cdim.image_utils import randn_tensor, trace_AAt, estimate_variance, save_to_image, compute_operator_distance
 from cdim.discrete_kl_loss import discrete_kl_loss
+from cdim.eta_utils import calculate_best_step_size, initial_guess_step_size
+

 def compute_kl_gaussian(residuals, sigma):
     # Only 0 centered for now
@@ -23,13 +25,13 @@ def run_diffusion(
     operator,
     noise_function,
     device,
+    stopping_sigma,
     num_inference_steps: int = 1000,
-    K=
+    K=20,
     image_dim=256,
     image_channels=3,
     model_type="diffusers",
+    original_image=None
 ):
     batch_size = noisy_observation.shape[0]
     image_shape = (batch_size, image_channels, image_dim, image_dim)
@@ -38,18 +40,60 @@
     scheduler.set_timesteps(num_inference_steps, device=device)
     t_skip = scheduler.timesteps[0] - scheduler.timesteps[1]

+    data = []
+    TOTAL_UPDATE_STEPS = 0
+    trace = trace_AAt(operator)
     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps), desc="Processing timesteps"):
+        # Using GT image noised up if you want to debug anything
+        # image = original_image * scheduler.alphas_cumprod[t] ** 0.5 + torch.randn_like(original_image) * (1 - scheduler.alphas_cumprod[t]) ** 0.5
+
         # 1. predict noise model_output
         model_output = model(image, t.unsqueeze(0).to(device))
         model_output = model_output.sample if model_type == "diffusers" else model_output[:, :3]

+        # Save image for debugging
+        # save_to_image(image, f"intermediates/{t}_xt.png")
+
         # 2. compute previous image: x_t -> x_t-1
         image = scheduler.step(model_output, t, image).prev_sample
         image.requires_grad_()
         alpha_prod_t_prev = scheduler.alphas_cumprod[t-t_skip] if t-t_skip >= 0 else 1
         beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        k = 0
+        while k < K:
             if t <= 0: break
+            a = scheduler.alphas_cumprod[t-t_skip]**0.5 - 1
+            # For inpainting, use the number of observed pixels
+            num_elements = operator.get_num_observed() if hasattr(operator, 'get_num_observed') else noisy_observation.numel()
+
+            # mu_{t-delta}(y), eq. 14
+            target_distance = (a**2 * torch.linalg.norm(noisy_observation)**2 + (1 - scheduler.alphas_cumprod[t-t_skip]) * trace).item()
+            target_distance += num_elements * noise_function.sigma**2 * (1 - a**2)
+
+            # ||Ax_{t-delta} - y||^2
+            actual_distance = compute_operator_distance(operator, image, noisy_observation, squared=True).item()
+
+            # sigma^2_{t-delta}(y), eq. 15
+            variance = estimate_variance(
+                operator,
+                noisy_observation,
+                scheduler.alphas_cumprod[t-t_skip],
+                image.shape,
+                trace=trace,
+                sigma_y=noise_function.sigma,
+                n_trace_samples=64,
+                n_y_samples=64,
+                device=image.device)
+
+            # c * sigma_{t-delta}(y)
+            threshold = stopping_sigma * variance**0.5
+            # print(f"Target Distance mean {target_distance} max {target_distance + threshold} actual distance {actual_distance}")
+
+            # R_{t-delta} is within rho_{t-delta}, eq. 16
+            if actual_distance <= target_distance + threshold:
+                break

             with torch.enable_grad():
                 # Calculate x^hat_0
@@ -57,42 +101,52 @@
                 model_output = model_output.sample if model_type == "diffusers" else model_output[:, :3]
                 x_0 = (image - beta_prod_t_prev ** (0.5) * model_output) / alpha_prod_t_prev ** (0.5)

-                    mask = x_0_pixel > 2  # Avoid numeric issues with pixel values near 0
-                    pearson = residuals[mask] / torch.sqrt(x_0_pixel[mask] * noise_function.rate)
-                    pearson_flat = pearson.view(-1)
-                    kl_div = compute_kl_gaussian(pearson_flat, 1.0)
-                    kl_div.backward()
-
-                elif loss_type == "categorical_kl" and noise_function.name == "bimodal":
-                    diff = (operator(x_0) - noisy_observation)
-                    indices = operator(torch.ones(image.shape).to(device))
-                    diff = diff[indices > 0]  # Don't consider masked out pixels in the distribution
-                    empirical_distribution = noise_function.sample_noise_distribution(image).to(device).view(-1)
-                    loss = discrete_kl_loss(diff, empirical_distribution, num_bins=15)
-                    print(f"Categorical KL {loss}")
-                    loss.backward()
-
-                else:
-                    raise ValueError(f"Unsupported combination: loss {loss_type} noise {noise_function.name}")
-
-            step_size = eta_scheduler.get_step_size(str(t.item()), torch.linalg.norm(image.grad))
+                # Save Tweedie's estimate for debugging
+                # save_to_image(x_0, f"intermediates/{t}_x0.png")
+
+                loss = compute_operator_distance(operator, x_0, noisy_observation, squared=True).mean()
+
+                # print(f"L2 loss {compute_operator_distance(operator, x_0, noisy_observation, squared=False)}")
+                data.append((t.item(), compute_operator_distance(operator, image, noisy_observation, squared=False).item()))
+                loss.backward()
+
+            initial_step_size = initial_guess_step_size(t.item(), torch.linalg.norm(image.grad))  # previously eta_scheduler.get_step_size(str(t.item()), torch.linalg.norm(image.grad))
+            with torch.no_grad():
+                # Set debug=True to see detailed step size search information
+                step_size = calculate_best_step_size(image, noisy_observation, operator, image.grad, target_distance, threshold, initial_step_size, debug=False)
+
+            # print(f"Step Size {step_size:.6e} initial guess {initial_step_size:.6e}")
+
             image -= step_size * image.grad
+            new_distance = compute_operator_distance(operator, image, noisy_observation, squared=True).item()
+            # print(f"New distance {new_distance}")
             image = image.detach().requires_grad_()
+            TOTAL_UPDATE_STEPS += 1
+
+            if step_size <= 1e-12: break
+
+            k += 1
+
+            # Check here because threshold is stochastic and can change from iteration to iteration
+            if new_distance <= target_distance + threshold:
+                break
+
+        # print("Step", t.item())
+        # Use num_elements for proper normalization with inpainting
+        num_elements = operator.get_num_observed() if hasattr(operator, 'get_num_observed') else noisy_observation.numel()
+        # print("Distance", 1 / num_elements * compute_operator_distance(operator, image, noisy_observation, squared=True).item())
+
+        # Print MAE if you want to track constraint error
+        if hasattr(operator, 'select'):
+            # Compute MAE over observed pixels only
+            Ax = operator.select(image).flatten()
+            y_selected = operator.select(noisy_observation).flatten()
+            # print("MAE", (torch.abs(Ax - y_selected).mean().item()))
+        else:
+            pass
+            # print("MAE", (torch.abs(operator(image) - noisy_observation).mean().item()))

+    print(f"Total Denoising {len(scheduler.timesteps)}")
+    print(f"Total Projection Steps {TOTAL_UPDATE_STEPS}")
+    print(f"Total NFEs {TOTAL_UPDATE_STEPS + len(scheduler.timesteps)}")
     return image
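The heart of this change is the chi-squared stopping rule: at each reverse step the squared residual $\|Ax_{t-\delta} - y\|^2$ is compared against its expected value $\mu_{t-\delta}(y)$ (eq. 14) plus a band of width $c\,\sigma_{t-\delta}(y)$ (eq. 15), and projection steps are taken only while the residual sits above that band (eq. 16). A self-contained sketch of the same test for the special case $A = I$ (illustration only; the pipeline uses `estimate_variance` for the band width, not the crude stand-in below):

```python
# Sketch of the stopping test for A = I, where tr(AA^T) = n.
import torch

def should_stop(x_t, y, alpha_bar_prev, sigma_y, c):
    n = y.numel()
    a = alpha_bar_prev ** 0.5 - 1
    trace = float(n)  # tr(AA^T) for A = I
    target = (a ** 2) * y.pow(2).sum() + (1 - alpha_bar_prev) * trace \
             + n * sigma_y ** 2 * (1 - a ** 2)            # mu_{t-delta}(y), eq. 14
    actual = (x_t - y).pow(2).sum()                       # ||A x_{t-delta} - y||^2
    # Crude stand-in for c * sigma_{t-delta}(y): std of a chi-squared with n dof,
    # scaled by the diffusion variance (the real code calls estimate_variance).
    rough_sigma = (2 * n) ** 0.5 * (1 - alpha_bar_prev)
    return bool(actual <= target + c * rough_sigma)

x_t = torch.randn(1, 3, 8, 8)
y = torch.randn(1, 3, 8, 8)
print(should_stop(x_t, y, alpha_bar_prev=0.5, sigma_y=0.05, c=1.0))
```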
cdim/eta_scheduler.py
DELETED
@@ -1,61 +0,0 @@

import json

class EtaScheduler:
    def __init__(self, method, task, T, K, loss_type,
                 noise_function, lambda_val=None):
        self.task = task
        self.T = T
        self.K = K
        self.loss_type = loss_type
        self.lambda_val = lambda_val
        self.method = method

        self.precomputed_etas = self._load_precomputed_etas()

        # Couldn't find expected gradnorm
        if not self.precomputed_etas and method == "expected_gradnorm":
            self.method = "gradnorm"
            print("Etas for this configuration not found. Switching to gradnorm.")

        # Precomputed gradients are only for gaussian noise
        if noise_function.name != "gaussian" and method == "expected_gradnorm":
            self.method = "gradnorm"
            print("Precomputed gradients are only for gaussian noise. Switching to gradnorm.")

        # Get the best lambda_val if it's not passed
        if self.lambda_val is None:
            if self.method == "expected_gradnorm":
                self.lambda_val = self.precomputed_etas["lambda"]
            else:
                self.lambda_val = self.best_guess_lambda()
            print(f"Using lambda {self.lambda_val}")

    def _load_precomputed_etas(self):
        steps_key = f"T{self.T}_K{self.K}"
        with open("cdim/etas.json") as f:
            all_etas = json.load(f)

        return all_etas.get(self.task, {}).get(self.loss_type, {}).get(steps_key, {})

    def get_step_size(self, t, grad_norm):
        """Use either precomputed expected gradnorm or gradnorm."""
        if self.method == "expected_gradnorm":
            step_size = self.lambda_val * 1 / self.precomputed_etas["etas"][t]
        else:
            step_size = self.lambda_val * 1 / grad_norm
        return step_size

    def best_guess_lambda(self):
        """Guess a lambda value if not provided. Based on trial and error"""
        total_steps = self.T * self.K

        # L2 tends to over optimize too aggressively, so the default lr is lower
        if self.loss_type == "kl":
            return 350 / total_steps
        elif self.loss_type == "l2":
            return 220 / total_steps
        else:
            raise ValueError(f"Please provide learning rate for loss type {self.loss_type}")
cdim/eta_utils.py
ADDED
@@ -0,0 +1,369 @@

import json
import torch

class EtaScheduler:
    def __init__(self, method, task, T, K, loss_type,
                 noise_function, lambda_val=None):
        self.task = task
        self.T = T
        self.K = K
        self.loss_type = loss_type
        self.lambda_val = lambda_val
        self.method = method

        self.precomputed_etas = self._load_precomputed_etas()

        # Couldn't find expected gradnorm
        if not self.precomputed_etas and method == "expected_gradnorm":
            self.method = "gradnorm"
            print("Etas for this configuration not found. Switching to gradnorm.")

        # Precomputed gradients are only for gaussian noise
        if noise_function.name != "gaussian" and method == "expected_gradnorm":
            self.method = "gradnorm"
            print("Precomputed gradients are only for gaussian noise. Switching to gradnorm.")

        # Get the best lambda_val if it's not passed
        if self.lambda_val is None:
            if self.method == "expected_gradnorm":
                self.lambda_val = self.precomputed_etas["lambda"]
            else:
                self.lambda_val = self.best_guess_lambda()
            print(f"Using lambda {self.lambda_val}")

    def _load_precomputed_etas(self):
        steps_key = f"T{self.T}_K{self.K}"
        with open("cdim/etas.json") as f:
            all_etas = json.load(f)

        return all_etas.get(self.task, {}).get(self.loss_type, {}).get(steps_key, {})

    def get_step_size(self, t, grad_norm):
        """Use either precomputed expected gradnorm or gradnorm."""
        if self.method == "expected_gradnorm":
            step_size = self.lambda_val * 1 / self.precomputed_etas["etas"][t]
        else:
            step_size = self.lambda_val * 1 / grad_norm
        return step_size

    def best_guess_lambda(self):
        """Guess a lambda value if not provided. Based on trial and error"""
        total_steps = self.T * self.K

        # L2 tends to over optimize too aggressively, so the default lr is lower
        if self.loss_type == "kl":
            return 350 / total_steps
        elif self.loss_type == "l2":
            return 220 / total_steps
        else:
            raise ValueError(f"Please provide learning rate for loss type {self.loss_type}")


def initial_guess_step_size(T, grad_norm):
    best_guess_lambda = 220 / T
    return best_guess_lambda / grad_norm


# def calculate_best_step_size(image, y, operator, gradient, target_distance, initial_guess,
#                              max_iters=20, tol=1e-4, bracket_factor=1.4):
#     def compute_distance(eta):
#         x_new = image - eta * gradient
#         diff = operator(x_new) - y
#         return torch.linalg.norm(diff)**2
#
#     def objective(eta):
#         return torch.abs(compute_distance(eta) - target_distance)
#
#     # Try to bracket the root
#     eta_low = initial_guess / bracket_factor
#     eta_high = initial_guess * bracket_factor
#
#     for _ in range(10):
#         import pdb
#         pdb.set_trace()
#         dist_low = compute_distance(eta_low)
#         dist_high = compute_distance(eta_high)
#         if (dist_low - target_distance) * (dist_high - target_distance) < 0:
#             break
#         eta_low /= bracket_factor
#         eta_high *= bracket_factor
#     else:
#         # Fallback: brute-force line search over eta to minimize distance
#         best_eta = None
#         best_val = float('inf')
#         for eta in torch.linspace(0, initial_guess * 5, steps=100, device=image.device):
#             val = objective(eta)
#             # print(f"ETA {eta} distance {compute_distance(eta)}")
#             if val < best_val:
#                 best_val = val
#                 best_eta = eta
#         return best_eta.item()
#
#     # Binary search
#     for _ in range(max_iters):
#         eta_mid = (eta_low + eta_high) / 2
#         dist_mid = compute_distance(eta_mid)
#         error = dist_mid - target_distance
#
#         if abs(error) < tol:
#             return eta_mid
#
#         if (compute_distance(eta_low) - target_distance) * error < 0:
#             eta_high = eta_mid
#         else:
#             eta_low = eta_mid
#
#     return eta_mid


import torch
from cdim.image_utils import compute_operator_distance

def calculate_best_step_size(
    image: torch.Tensor,
    y: torch.Tensor,
    operator,
    gradient: torch.Tensor,
    target_distance: float,
    threshold: float,
    initial_guess: float,
    *,
    tol: float = 1e-4,
    max_iters: int = 50,
    debug: bool = False,
):
    """
    Find the smallest η ≥ 0 that makes ||A(x − η g) − y||² ≈ target_distance + threshold.

    Uses a robust grid search followed by golden section search for fine-grained optimization.

    Note: For inpainting operators with 'select' method, distances are computed
    over only the observed pixels.

    Args:
        debug: If True, prints detailed search information
    """
    target_boundary = target_distance + threshold

    def distance(η: torch.Tensor) -> torch.Tensor:
        return compute_operator_distance(operator, image - η * gradient, y, squared=True)

    def error(η):
        return distance(η) - target_boundary

    # Phase 1: Coarse grid search to find promising regions
    # Search from very small to larger step sizes
    # Allow step sizes larger than 1 if needed
    max_eta = initial_guess * 200.0 if initial_guess > 0 else 10.0

    # Create a logarithmically-spaced grid for better coverage of small values
    # This ensures we search finely near 0 and coarser at larger values
    # Start from 1e-12 to handle cases with very large gradients
    n_coarse = 100  # Increased for better resolution
    eta_grid = torch.cat([
        torch.tensor([0.0]),
        torch.logspace(-12, torch.log10(torch.tensor(max_eta)), n_coarse - 1)
    ]).to(image.device)

    if debug:
        print(f"[Step Size] Searching from {eta_grid[1]:.2e} to {eta_grid[-1]:.2e} ({len(eta_grid)} points)")
        print(f"[Step Size] Sample grid points: {[f'{x:.2e}' for x in eta_grid[1:11].tolist()]}")

    # Evaluate distances at all grid points
    distances = torch.tensor([distance(eta).item() for eta in eta_grid])
    errors = distances - target_boundary

    dist_at_zero = distances[0].item()

    # Strategy: Find the SMALLEST eta that gets us AT OR BELOW target_boundary
    # Only consider non-zero etas
    below_target_mask = distances[1:] <= target_boundary

    if below_target_mask.any():
        # Found etas that reach target - pick the SMALLEST one (most conservative)
        below_indices = torch.where(below_target_mask)[0] + 1  # +1 because we excluded index 0
        best_idx = below_indices[0].item()  # Smallest eta that reaches target
        best_eta = eta_grid[best_idx].item()
        best_distance = distances[best_idx].item()

        if debug:
            print(f"[Step Size] Coarse grid: found eta={best_eta:.2e} that reaches target")
            print(f"[Step Size] Distance: {best_distance:.2f} (target: {target_boundary:.2f}, under by {target_boundary - best_distance:.2f})")
    else:
        # No eta reaches target - find the one that gets closest (minimize distance to target)
        non_zero_distances = distances[1:]
        closest_idx = torch.argmin(torch.abs(non_zero_distances - target_boundary)) + 1
        best_idx = closest_idx
        best_eta = eta_grid[best_idx].item()
        best_distance = distances[best_idx].item()

        if debug:
            print(f"[Step Size] Coarse grid: cannot reach target, best eta={best_eta:.2e}")
            print(f"[Step Size] Distance: {best_distance:.2f} (target: {target_boundary:.2f}, over by {best_distance - target_boundary:.2f})")

    # Check if eta=0 is better (already at target)
    if dist_at_zero <= target_boundary:
        if debug:
            print(f"[Step Size] Distance at eta=0: {dist_at_zero:.2f} - already at/below target")
        return 0.0

    if debug:
        print(f"[Step Size] Distance at eta=0: {dist_at_zero:.2f} (need to step)")

    # Phase 1.5: Fine search around the best point found
    # If best_eta is not at the boundaries, do a fine search around it
    if best_idx > 0 and best_idx < len(eta_grid) - 1:
        eta_low_bound = eta_grid[best_idx - 1].item()
        eta_high_bound = eta_grid[best_idx + 1].item()

        # Create a very fine linear grid between the neighboring points
        fine_grid = torch.linspace(eta_low_bound, eta_high_bound, 50).to(image.device)
        fine_distances = torch.tensor([distance(eta).item() for eta in fine_grid])

        # Find the SMALLEST eta in fine grid that gets us AT OR BELOW target
        fine_below_mask = fine_distances <= target_boundary

        if fine_below_mask.any():
            # Found fine etas that reach target - pick the SMALLEST
            fine_below_indices = torch.where(fine_below_mask)[0]
            fine_best_idx = fine_below_indices[0].item()
            fine_best_eta = fine_grid[fine_best_idx].item()
            fine_best_distance = fine_distances[fine_best_idx].item()

            # Only update if this is better (smaller eta that still reaches target, or gets closer)
            if fine_best_distance <= target_boundary and (best_distance > target_boundary or fine_best_eta < best_eta):
                best_eta = fine_best_eta
                best_distance = fine_best_distance
                best_idx = len(eta_grid) + fine_best_idx

                if debug:
                    print(f"[Step Size] Fine grid: improved to eta={best_eta:.2e}, distance={best_distance:.2f} (under by {target_boundary - best_distance:.2f})")
        else:
            # No fine eta reaches target - find closest
            fine_best_idx = torch.argmin(torch.abs(fine_distances - target_boundary))
            fine_best_eta = fine_grid[fine_best_idx].item()
            fine_best_distance = fine_distances[fine_best_idx].item()

            # Only update if closer to target than current best
            if abs(fine_best_distance - target_boundary) < abs(best_distance - target_boundary):
                best_eta = fine_best_eta
                best_distance = fine_best_distance
                best_idx = len(eta_grid) + fine_best_idx

                if debug:
                    print(f"[Step Size] Fine grid: improved to eta={best_eta:.2e}, distance={best_distance:.2f} (over by {best_distance - target_boundary:.2f})")

        # Always update the grid for potential bracketing
        distances = torch.cat([distances, fine_distances])
        errors = torch.cat([errors, fine_distances - target_boundary])
        eta_grid = torch.cat([eta_grid, fine_grid])

    # If best_eta is 0 and we're already at or below target, return 0
    if best_eta == 0.0 and dist_at_zero <= target_boundary:
        if debug:
            print(f"[Step Size] Already at target, no step needed")
        return 0.0

    # If we've reached target, we can return (no need for golden section)
    if best_distance <= target_boundary:
        if debug:
            print(f"[Step Size] Reached target with eta={best_eta:.2e}, returning")
        return best_eta

    # Phase 2: Check for bracketing around the best point
    # Look for a sign change (crossing the target boundary)
    bracket_found = False
    eta_lo, eta_hi = None, None

    # Check neighbors of best point
    for i in range(len(eta_grid) - 1):
        if errors[i] * errors[i + 1] < 0:  # Sign change
            eta_lo, eta_hi = eta_grid[i].item(), eta_grid[i + 1].item()
            bracket_found = True
            break

    # Phase 3: Refine using golden section search
    # Only refine if we haven't reached target yet and have a valid bracket
    if best_eta > 0 and best_distance > target_boundary and best_idx > 0 and best_idx < len(eta_grid) - 1:
        # Golden section search to find the smallest eta that reaches target_boundary
        phi = (1 + 5**0.5) / 2  # Golden ratio
        resphi = 2 - phi

        # Search in a small window around best_eta
        a = eta_grid[max(0, best_idx - 1)].item()
        b = eta_grid[min(len(eta_grid) - 1, best_idx + 1)].item()

        # Make sure we have a valid interval
        if b - a < 1e-20:
            if debug:
                print(f"[Step Size] Interval too small for refinement, returning eta={best_eta:.2e}")
            return best_eta

        dist_a = distance(torch.tensor(a)).item()
        dist_b = distance(torch.tensor(b)).item()

        for _ in range(max_iters):
            if abs(b - a) < 1e-20:  # Extremely tight tolerance
                break

            # Golden section points
            x1 = a + resphi * (b - a)
            x2 = b - resphi * (b - a)

            dist_x1 = distance(torch.tensor(x1)).item()
            dist_x2 = distance(torch.tensor(x2)).item()

            # Priority: prefer points that reach target (dist <= target_boundary)
            # Among those, prefer smaller eta
            # If neither reaches, prefer closer to target

            x1_reaches = dist_x1 <= target_boundary
            x2_reaches = dist_x2 <= target_boundary

            if x1_reaches and not x2_reaches:
                # x1 reaches target, x2 doesn't -> prefer x1's half
                b = x2
                dist_b = dist_x2
                if x1 < best_eta or not (best_distance <= target_boundary):
                    best_eta = x1
                    best_distance = dist_x1
            elif x2_reaches and not x1_reaches:
                # x2 reaches target, x1 doesn't -> prefer x2's half
                a = x1
                dist_a = dist_x1
                if x2 < best_eta or not (best_distance <= target_boundary):
                    best_eta = x2
                    best_distance = dist_x2
            elif x1_reaches and x2_reaches:
                # Both reach target -> prefer smaller eta (which is x1)
                b = x2
                dist_b = dist_x2
                best_eta = x1
                best_distance = dist_x1
            else:
                # Neither reaches target -> prefer closer to target
                if abs(dist_x1 - target_boundary) < abs(dist_x2 - target_boundary):
                    b = x2
                    dist_b = dist_x2
                    if abs(dist_x1 - target_boundary) < abs(best_distance - target_boundary):
                        best_eta = x1
                        best_distance = dist_x1
                else:
                    a = x1
                    dist_a = dist_x1
                    if abs(dist_x2 - target_boundary) < abs(best_distance - target_boundary):
                        best_eta = x2
                        best_distance = dist_x2

        if debug:
            if best_distance <= target_boundary:
                print(f"[Step Size] Final: eta={best_eta:.2e}, distance={best_distance:.2f} (under by {target_boundary - best_distance:.2f})")
            else:
                print(f"[Step Size] Final: eta={best_eta:.2e}, distance={best_distance:.2f} (over by {best_distance - target_boundary:.2f})")
    else:
        if debug:
            print(f"[Step Size] No refinement needed, returning best: eta={best_eta:.2e}")

    return best_eta
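A hedged, self-contained sketch of how `initial_guess_step_size` and `calculate_best_step_size` fit together; the identity operator, tensor shapes, and the demo target value are placeholders, not repository defaults:

```python
# Illustration only: exercising the step-size search with a toy identity operator.
import torch
from cdim.image_utils import compute_operator_distance
from cdim.eta_utils import calculate_best_step_size, initial_guess_step_size

operator = lambda x: x                       # stand-in for a linear operator A
x = torch.randn(1, 3, 8, 8, requires_grad=True)
y = torch.randn(1, 3, 8, 8)

loss = compute_operator_distance(operator, x, y, squared=True)
loss.backward()

eta0 = float(initial_guess_step_size(25, torch.linalg.norm(x.grad)))
with torch.no_grad():
    eta = calculate_best_step_size(
        x, y, operator, x.grad,
        target_distance=0.5 * loss.item(),   # arbitrary target chosen for the demo
        threshold=0.0, initial_guess=eta0, debug=True)
print(eta0, eta)
```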
cdim/image_utils.py
CHANGED
@@ -1,7 +1,9 @@
 from typing import List, Optional, Tuple, Union

 import torch
+from torch import Tensor
 from torchvision.transforms import ToPILImage
+from typing import Callable

 def save_to_image(tensor, filename):
     """
@@ -66,3 +68,209 @@ def randn_tensor(
     latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

     return latents
+
+
+@torch.no_grad()
+def estimate_variance(
+    operator: Callable[[Tensor], Tensor],
+    y: Tensor,                        # Ax_0 + noise (shape (m,))
+    alphabar_t: float,
+    in_shape: tuple[int, ...],        # e.g. (1, 3, 256, 256)
+    trace: float,                     # tr(AA^T)
+    sigma_y: float,
+    n_trace_samples: int = 64,
+    n_y_samples: int = 64,
+    device: torch.device | str = "cuda",
+    dtype: torch.dtype = torch.float32,
+) -> float:
+    """
+    Monte-Carlo estimator of Var(||A x_t – y||^2) without access to A^T.
+
+    Note: For inpainting operators with a 'select' method, this computes variance
+    over only the observed pixels.
+    """
+    use_select = hasattr(operator, 'select')
+
+    # For inpainting, select only observed pixels from y
+    if use_select:
+        y_selected = operator.select(y).flatten()
+        y = y_selected.to(device=device, dtype=dtype)
+    else:
+        y = y.to(device=device, dtype=dtype).flatten()
+    m = y.numel()
+
+    # ---------------- tr((AA^T)^2)
+    t1_acc = torch.zeros((), device=device, dtype=dtype)
+    for _ in range(n_trace_samples):
+        v = torch.randn(in_shape, device=device, dtype=dtype)
+        w = torch.randn(in_shape, device=device, dtype=dtype)
+        if use_select:
+            Av = operator.select(v).flatten()
+            Aw = operator.select(w).flatten()
+        else:
+            Av = operator(v).flatten()
+            Aw = operator(w).flatten()
+        s = torch.dot(Av, Aw)
+        t1_acc += s * s
+    T1 = t1_acc / n_trace_samples
+
+    # ---------------- y^T AA^T y
+    t2_acc = torch.zeros((), device=device, dtype=dtype)
+    for _ in range(n_y_samples):
+        v = torch.randn(in_shape, device=device, dtype=dtype)
+        if use_select:
+            Av = operator.select(v).flatten()
+        else:
+            Av = operator(v).flatten()
+        s = torch.dot(y, Av)
+        t2_acc += s * s
+    T2 = t2_acc / n_y_samples
+
+    # ---------------- assemble variance
+    alpha_bar = torch.as_tensor(alphabar_t, dtype=dtype, device=device)
+    sigma2 = torch.as_tensor(sigma_y**2, dtype=dtype, device=device)
+
+    a2 = (torch.sqrt(alpha_bar) - 1.0).pow(2)  # (1-√ᾱ)^2
+    b = 1.0 - alpha_bar                        # (1-ᾱ)
+
+    var = 2 * (b*b * T1 + 2 * b * sigma2 * trace + m * sigma2.pow(2)) \
+        + 4 * a2 * (b * (T2 - sigma2 * trace) + sigma2 * (y.pow(2).sum() - m * sigma2))
+    return var.item()
+
+
+def compute_operator_distance(
+    operator: Callable[[Tensor], Tensor],
+    x: Tensor,
+    y: Tensor,
+    squared: bool = True
+) -> Tensor:
+    """
+    Compute ||Ax - y||^2 (or ||Ax - y|| if squared=False).
+
+    For inpainting operators with a 'select' method, this computes the distance
+    over only the observed pixels. Otherwise uses the standard operator call.
+
+    Args:
+        operator: The forward operator A
+        x: Input tensor (e.g., image)
+        y: Measurement tensor (for inpainting, this should be the full masked measurement)
+        squared: If True, returns squared L2 norm. If False, returns L2 norm.
+
+    Returns:
+        Scalar tensor representing the distance
+    """
+    if hasattr(operator, 'select'):
+        # Use select method for inpainting operators
+        # Both x and y need to be selected to extract only observed pixels
+        Ax = operator.select(x).flatten()
+        y_selected = operator.select(y).flatten()
+    else:
+        # Standard operator application
+        Ax = operator(x).flatten()
+        y_selected = y.flatten()
+
+    diff = Ax - y_selected
+    if squared:
+        return (diff ** 2).sum()
+    else:
+        return torch.sqrt((diff ** 2).sum())
+
+
+def trace_AAt(
+    operator: Callable[[torch.Tensor], torch.Tensor],
+    input_shape = (1, 3, 256, 256),
+    num_samples: int = 256,
+    device: str = "cuda"  # or "cpu"
+) -> float:
+    """
+    Unbiased Monte-Carlo estimate of tr(A Aᵀ) for a black-box linear operator.
+
+    operator    : function that maps a (1,C,H,W) tensor → down-sampled tensor
+    input_shape : shape expected by the operator
+    num_samples : more samples → lower variance (error ≈ O(1/√num_samples))
+
+    Note: For inpainting operators with a 'select' method, this computes the trace
+    over only the observed pixels, not the full tensor with zeros.
+    """
+    total = 0.0
+    use_select = hasattr(operator, 'select')
+
+    for _ in range(num_samples):
+        # Rademacher noise (±1). Use torch.randn for Gaussian instead.
+        z = torch.empty(input_shape, device=device).bernoulli_().mul_(2).sub_(1)
+        if use_select:
+            Az = operator.select(z).flatten()   # only observed pixels
+        else:
+            Az = operator(z).flatten()          # output can have any shape
+        total += torch.dot(Az, Az).item()       # ||Az||²
+    return total / num_samples
+
+
+# def trace_AAt_squared(
+#     operator: Callable[[torch.Tensor], torch.Tensor],
+#     input_shape: tuple = (1, 3, 256, 256),
+#     num_samples: int = 32,
+#     device: str = "cuda") -> float:
+#     """
+#     Estimates tr((A Aᵀ)^2) using Hutchinson's method and autograd for Aᵀ.
+#     """
+#     total = 0.0
+#     for _ in range(num_samples):
+#         # Sample z ~ N(0, I) (same shape as operator's *output*)
+#         z = torch.randn(operator(torch.zeros(input_shape, device=device)).shape, device=device)
+#
+#         # Compute Aᵀz via gradient: ∇_w [⟨operator(w), z⟩] = Aᵀz
+#         w = torch.randn(input_shape, device=device, requires_grad=True)
+#         Az = operator(w).flatten()
+#         loss = torch.dot(Az, z.flatten())  # ⟨Az, z⟩ = ⟨w, Aᵀz⟩
+#         A_adj_z = torch.autograd.grad(loss, w, retain_graph=False)[0]
+#
+#         # Compute AAᵀz = operator(Aᵀz)
+#         AA_adj_z = operator(A_adj_z.detach()).flatten()
+#         total += torch.dot(AA_adj_z, AA_adj_z).item()  # ||AAᵀz||²
+#     return total / num_samples
+
+
+# def compute_yAAy(
+#     operator: Callable[[torch.Tensor], torch.Tensor],
+#     y: torch.Tensor,
+#     input_shape: tuple = (1, 3, 256, 256),
+#     device: str = "cuda") -> float:
+#     """
+#     Computes yᵀ (A Aᵀ) y using autograd to get Aᵀy.
+#     """
+#     # Compute Aᵀy via gradient: ∇_w [⟨operator(w), y⟩] = Aᵀy
+#     w = torch.randn(input_shape, device=device, requires_grad=True)
+#     Az = operator(w).flatten()
+#     loss = torch.dot(Az, y.flatten())
+#     A_adj_y = torch.autograd.grad(loss, w, retain_graph=False)[0]
+#
+#     # Compute A Aᵀ y = operator(Aᵀy)
+#     AA_adj_y = operator(A_adj_y.detach()).flatten()
+#     return torch.dot(AA_adj_y, y.flatten()).item()
+
+
+# def variance_Axt_minus_y_sq(
+#     operator: Callable[[torch.Tensor], torch.Tensor],
+#     y: torch.Tensor,
+#     alphabar_t: float,
+#     input_shape: tuple = (1, 3, 256, 256),
+#     num_samples_trace: int = 32,
+#     device: str = "cuda"
+# ) -> float:
+#     """
+#     Computes Var(||A𝐱ₜ - y||²) = 2(1-ᾱₜ)² tr((AAᵀ)²) + 4(1-ᾱₜ)(√ᾱₜ -1)² yᵀAAᵀy.
+#     """
+#     # Term 1: 2(1-ᾱₜ)^2 * tr((AAᵀ)^2)
+#     tr_AAt_sq = trace_AAt_squared(operator, input_shape, num_samples_trace, device)
+#     term1 = 2 * (1 - alphabar_t)**2 * tr_AAt_sq
+#
+#     # Term 2: 4(1-ᾱₜ)(√ᾱₜ -1)^2 * yᵀAAᵀy
+#     yAAy = compute_yAAy(operator, y, input_shape, device)
+#     term2 = 4 * (1 - alphabar_t) * (torch.sqrt(torch.tensor(alphabar_t)) - 1)**2 * yAAy
+#
+#     return term1 + term2
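As a quick sanity check, a self-contained toy (illustration only; the average-pooling operator, shapes, and sample counts are placeholders) exercising the new Monte-Carlo helpers on CPU:

```python
# Illustration only: trace_AAt and compute_operator_distance with a toy operator.
import torch
import torch.nn.functional as F
from cdim.image_utils import trace_AAt, compute_operator_distance

operator = lambda x: F.avg_pool2d(x, 4)          # stand-in linear operator A
x = torch.randn(1, 3, 64, 64)
y = operator(torch.randn(1, 3, 64, 64))

tr = trace_AAt(operator, input_shape=(1, 3, 64, 64), num_samples=128, device="cpu")
d = compute_operator_distance(operator, x, y, squared=True)
# For 4x4 average pooling, each row of A has 16 entries of 1/16, so tr(AA^T) = 3*16*16/16 = 48.
print(f"estimated tr(AA^T) ~ {tr:.1f} (exact: 48.0), ||Ax - y||^2 = {d.item():.1f}")
```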
cdim/operators/gaussian_blur_operator.py
CHANGED
@@ -4,7 +4,7 @@ from cdim.operators.blur_kernel import BlurKernel

 @register_operator(name='gaussian_blur')
 class GaussianBlurOperator:
-    def __init__(self, kernel_size, intensity, device='
+    def __init__(self, kernel_size, intensity, device='cuda'):
         self.device = device
         self.kernel_size = kernel_size
         self.conv = BlurKernel(blur_type='gaussian',
cdim/operators/random_box_masker.py
CHANGED

@@ -54,3 +54,31 @@ class RandomBoxMasker:
 
         # Apply the mask to the input tensor
         return tensor * self.mask
+
+    def select(self, tensor):
+        """
+        Extract only the observed pixels from the tensor (pixels outside the box).
+
+        Args:
+            tensor (torch.Tensor): Input tensor of shape (b, channels, height, width)
+
+        Returns:
+            torch.Tensor: Flattened tensor containing only observed pixels (b, num_observed)
+        """
+        b, c, h, w = tensor.shape
+        assert c == self.channels and h == self.height and w == self.width, \
+            f"Input tensor must be of shape (b, {self.channels}, {self.height}, {self.width})"
+
+        # Move the mask to the same device as the input tensor if necessary
+        if tensor.device != self.mask.device:
+            self.mask = self.mask.to(tensor.device)
+
+        # Extract only observed pixels (where mask is 1, outside the box)
+        observed = (tensor * self.mask).flatten(1)  # (b, c*h*w)
+        # Keep only non-zero elements
+        mask_flat = self.mask.flatten(1)  # (1, c*h*w)
+        return observed[:, mask_flat[0] > 0]  # (b, num_observed)
+
+    def get_num_observed(self):
+        """Return the number of observed elements."""
+        return int(self.mask.sum().item())
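To make the new `select()` and `get_num_observed()` helpers concrete, here is a minimal standalone sketch of the same indexing logic, using an explicit binary box mask instead of a `RandomBoxMasker` instance (shapes and box location are illustrative, not taken from the repository):

```python
import torch

b, c, h, w = 2, 3, 8, 8
mask = torch.ones(1, c, h, w)
mask[:, :, 2:6, 2:6] = 0                    # hide a 4x4 box; everything else is observed

x = torch.randn(b, c, h, w)
observed = (x * mask).flatten(1)            # (b, c*h*w)
mask_flat = mask.flatten(1)                 # (1, c*h*w)
selected = observed[:, mask_flat[0] > 0]    # keep only the observed entries

print(selected.shape)                       # torch.Size([2, 144])
print(int(mask.sum().item()))               # 144, what get_num_observed() would report
```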
cdim/operators/random_pixel_masker.py
CHANGED

@@ -56,3 +56,32 @@ class RandomPixelMasker:
 
         # Apply the mask to the input tensor
         return tensor * self.mask
+
+    def select(self, tensor):
+        """
+        Extract only the observed pixels from the tensor.
+
+        Args:
+            tensor (torch.Tensor): Input tensor of shape (b, channels, height, width)
+
+        Returns:
+            torch.Tensor: Flattened tensor containing only observed pixels (b, num_observed)
+        """
+        b, c, h, w = tensor.shape
+        assert c == self.channels and h == self.height and w == self.width, \
+            f"Input tensor must be of shape (b, {self.channels}, {self.height}, {self.width})"
+
+        # Move the mask to the same device as the input tensor if necessary
+        if tensor.device != self.mask.device:
+            self.mask = self.mask.to(tensor.device)
+
+        # Extract only observed pixels (where mask is 1)
+        # mask is (1, c, h, w), we want to select pixels across all channels
+        observed = (tensor * self.mask).flatten(1)  # (b, c*h*w)
+        # Keep only non-zero elements
+        mask_flat = self.mask.flatten(1)  # (1, c*h*w)
+        return observed[:, mask_flat[0] > 0]  # (b, num_observed)
+
+    def get_num_observed(self):
+        """Return the number of observed elements."""
+        return int(self.mask.sum().item())
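Both maskers now expose the same `select()` / `get_num_observed()` interface, which supplies the degrees-of-freedom count a chi-squared style residual test needs. As a hypothetical illustration, not repository code: for i.i.d. Gaussian residuals with standard deviation `sigma_y`, the sum of `m` squared residuals has mean `m * sigma_y**2` and standard deviation `sqrt(2m) * sigma_y**2`, so a "stop within k standard deviations" check could look like this (names and thresholds are assumptions):

```python
import torch

def within_k_std(residual_sq_sum: float, m: int, sigma_y: float, k: float) -> bool:
    # Sum of m squared N(0, sigma_y^2) residuals is sigma_y^2 * chi^2_m
    mean = m * sigma_y ** 2
    std = (2 * m) ** 0.5 * sigma_y ** 2
    return abs(residual_sq_sum - mean) <= k * std

# m would come from masker.get_num_observed(); k is a tolerance in standard deviations
torch.manual_seed(0)
residual = 0.05 * torch.randn(1000)
print(within_k_std(float((residual ** 2).sum()), m=1000, sigma_y=0.05, k=3.0))
```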
inference.py
CHANGED

@@ -2,6 +2,7 @@ import argparse
 import os
 import yaml
 import time
+from pathlib import Path
 
 from PIL import Image
 import numpy as np

@@ -15,7 +16,6 @@ from cdim.image_utils import save_to_image
 from cdim.dps_model.dps_unet import create_model
 from cdim.diffusion.scheduling_ddim import DDIMScheduler
 from cdim.diffusion.diffusion_pipeline import run_diffusion
-from cdim.eta_scheduler import EtaScheduler
 
 
 def load_image(path):

@@ -36,13 +36,39 @@ def load_yaml(file_path: str) -> dict:
     return config
 
 
+def process_image(image_path, output_dir, model, ddim_scheduler, operator, noise_function,
+                  device, args, model_type):
+    """
+    Process a single image with the given model and parameters.
+    """
+    original_image = load_image(image_path).to(device)
+
+    # Get the base filename without extension
+    base_name = Path(image_path).stem
+
+    noisy_measurement = noise_function(operator(original_image))
+    save_to_image(noisy_measurement, os.path.join(output_dir, f"{base_name}_noisy_measurement.png"))
+
+    t0 = time.time()
+    output_image = run_diffusion(
+        model, ddim_scheduler,
+        noisy_measurement, operator, noise_function, device,
+        args.stopping_sigma,
+        num_inference_steps=args.T,
+        K=args.K,
+        model_type=model_type,
+        original_image=original_image)
+    print(f"Processing time for {base_name}: {time.time() - t0:.2f}s")
+
+    save_to_image(output_image, os.path.join(output_dir, f"{base_name}_output.png"))
+
+
 def main(args):
     device_str = f"cuda" if args.cuda and torch.cuda.is_available() else 'cpu'
     print(f"Using device {device_str}")
     device = torch.device(device_str)
 
     os.makedirs(args.output_dir, exist_ok=True)
-    original_image = load_image(args.input_image).to(device)
 
     # Load the noise function
     noise_config = load_yaml(args.noise_config)

@@ -78,43 +104,52 @@ def main(args):
         steps_offset=0,
     )
 
-    [old lines 81-96 removed: content not shown in the source]
-    save_to_image(output_image, os.path.join(args.output_dir, "output.png"))
+    # Process input (either a single image or all images in a directory)
+    input_path = Path(args.input)
+
+    if input_path.is_file():
+        # Process a single image
+        print(f"Processing single image: {input_path.name}")
+        process_image(
+            str(input_path), args.output_dir, model, ddim_scheduler,
+            operator, noise_function, device, args, model_type
+        )
+    elif input_path.is_dir():
+        # Process all images in the directory
+        image_files = [
+            f for f in input_path.iterdir()
+            if not f.name.startswith('.') and f.suffix.lower() in ['.png', '.jpg', '.jpeg']
+        ]
+        image_files = sorted(image_files)
+
+        print(f"Found {len(image_files)} images to process")
+
+        for image_file in image_files:
+            print(f"Processing {image_file.name}...")
+            # Recreate the operator for each image (comment this out to reuse the same operator)
+            operator = get_operator(**operator_config)
+            process_image(
+                str(image_file), args.output_dir, model, ddim_scheduler,
+                operator, noise_function, device, args, model_type
+            )
+    else:
+        raise ValueError(f"Input path '{input_path}' is neither a file nor a directory")
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("
+    parser.add_argument("input", type=str, help="Path to input image or folder containing input images")
     parser.add_argument("T", type=int)
-    parser.add_argument("K", type=int)
     parser.add_argument("operator_config", type=str)
     parser.add_argument("noise_config", type=str)
     parser.add_argument("model_config", type=str)
-    parser.add_argument("--
-                        choices=['gradnorm', 'expected_gradnorm'],
-                        default='expected_gradnorm')
+    parser.add_argument("--stopping-sigma", type=float, default=0.1, help="How many std deviations away to stop")
     parser.add_argument("--lambda-val", type=float,
                         default=None, help="Constant to scale learning rate. Leave empty to use a heuristic best guess.")
-    parser.add_argument("--output-dir", default="
-    parser.add_argument("--loss", type=str,
-                        choices=['l2', 'kl', 'categorical_kl'], default='l2',
-                        help="Algorithm to use. Options: 'l2', 'kl', 'categorical_kl'. Default is 'l2'."
-    )
+    parser.add_argument("--output-dir", default="output", type=str)
     parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction)
+    parser.add_argument("--K", type=int, default=20,
+                        help="Cap the number of steps K at any iteration. Helps avoid edge cases or cap NFEs.")
 
     main(parser.parse_args())
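With the new positional `input` argument, `inference.py` accepts either a single image or a folder of images. An illustrative invocation (all paths are placeholders; the option values shown are the new defaults):

```
python inference.py path/to/images/ 25 path/to/operator_config.yaml path/to/noise_config.yaml path/to/model_config --stopping-sigma 0.1 --K 20 --output-dir output
```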
requirements.txt
CHANGED

@@ -1,11 +1,8 @@
-[old lines 1-3: content not shown in the source]
-numpy
+diffusers==0.30.3
+gradio==5.3.0
+numpy==2.1.2
 Pillow
-PyYAML
-scipy
-[old lines 8-9: content not shown in the source]
-torchvision
-tqdm
+PyYAML==6.0.2
+scipy==1.14.1
+tqdm==4.66.5
+accelerate