from abc import ABC, abstractmethod

import torch

# Registry mapping a string name to a ConditioningMethod subclass.
__CONDITIONING_METHOD__ = {}

def register_conditioning_method(name: str):
    """Class decorator that registers a ConditioningMethod subclass under `name`."""
    def wrapper(cls):
        if __CONDITIONING_METHOD__.get(name, None):
            raise NameError(f"Name {name} is already registered!")
        __CONDITIONING_METHOD__[name] = cls
        return cls
    return wrapper

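# Usage sketch (illustrative; `MyMethod` is a hypothetical name): decorating a
# subclass adds it to the registry so it can later be retrieved by name.
#
#   @register_conditioning_method(name='my_method')
#   class MyMethod(ConditioningMethod):
#       ...
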
def get_conditioning_method(name: str, operator, noiser, **kwargs):
    """Look up a registered conditioning method by name and instantiate it."""
    if __CONDITIONING_METHOD__.get(name, None) is None:
        raise NameError(f"Name {name} is not defined!")
    return __CONDITIONING_METHOD__[name](operator=operator, noiser=noiser, **kwargs)

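# Usage sketch (illustrative; assumes `operator` and `noiser` objects built
# elsewhere in the repo by its operator/noise factories):
#
#   cond_method = get_conditioning_method('ps', operator, noiser, scale=0.5)
#   x_t, norm = cond_method.conditioning(x_prev=x_prev, x_t=x_t,
#                                        x_0_hat=x_0_hat, measurement=y)
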
class ConditioningMethod(ABC):
    def __init__(self, operator, noiser, **kwargs):
        self.operator = operator  # forward measurement operator A(.)
        self.noiser = noiser      # measurement-noise model (e.g. gaussian, poisson)

    def project(self, data, noisy_measurement, **kwargs):
        # Project `data` onto the set consistent with the noisy measurement.
        return self.operator.project(data=data, measurement=noisy_measurement, **kwargs)

    def grad_and_value(self, x_prev, x_0_hat, measurement, **kwargs):
        # Gradient of the data-fidelity norm w.r.t. x_prev; x_prev must be in
        # the autograd graph of x_0_hat (i.e. requires_grad=True).
        if self.noiser.__name__ == 'gaussian':
            # L2 residual ||y - A(x_0_hat)||
            difference = measurement - self.operator.forward(x_0_hat, **kwargs)
            norm = torch.linalg.norm(difference)
            norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]

        elif self.noiser.__name__ == 'poisson':
            # Residual normalized by the measurement magnitude.
            Ax = self.operator.forward(x_0_hat, **kwargs)
            difference = measurement - Ax
            norm = torch.linalg.norm(difference) / measurement.abs()
            norm = norm.mean()
            norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]

        else:
            raise NotImplementedError

        return norm_grad, norm

    @abstractmethod
    def conditioning(self, x_t, measurement, noisy_measurement=None, **kwargs):
        pass

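# How the pieces fit into a DPS-style reverse-diffusion loop (sketch;
# `reverse_step`, `model`, and the variable names are assumptions, not part of
# this file):
#
#   x_prev = x_prev.requires_grad_()   # grad_and_value differentiates w.r.t. x_prev
#   x_t, x_0_hat = reverse_step(model, x_prev, t)   # unconditional update + denoised estimate
#   x_t, norm = cond_method.conditioning(x_prev=x_prev, x_t=x_t,
#                                        x_0_hat=x_0_hat, measurement=y)
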
@register_conditioning_method(name='vanilla')
class Identity(ConditioningMethod):
    # Just pass the input through without conditioning.
    def conditioning(self, x_t, **kwargs):
        return x_t

@register_conditioning_method(name='projection')
class Projection(ConditioningMethod):
    def conditioning(self, x_t, noisy_measurement, **kwargs):
        x_t = self.project(data=x_t, noisy_measurement=noisy_measurement)
        return x_t

@register_conditioning_method(name='mcg')
class ManifoldConstraintGradient(ConditioningMethod):
    def __init__(self, operator, noiser, **kwargs):
        super().__init__(operator, noiser)
        self.scale = kwargs.get('scale', 1.0)

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, noisy_measurement, **kwargs):
        # Posterior sampling: gradient step on the data-fidelity norm.
        norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat, measurement=measurement, **kwargs)
        x_t -= norm_grad * self.scale

        # Projection onto the measurement-consistent set.
        x_t = self.project(data=x_t, noisy_measurement=noisy_measurement, **kwargs)
        return x_t, norm

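# Note: ManifoldConstraintGradient differs from plain posterior sampling only
# in the extra projection step, which hard-enforces consistency with the noisy
# measurement after the gradient update.
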
@register_conditioning_method(name='ps')
class PosteriorSampling(ConditioningMethod):
    def __init__(self, operator, noiser, **kwargs):
        super().__init__(operator, noiser)
        self.scale = kwargs.get('scale', 1.0)

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, **kwargs):
        # Single gradient step on the data-fidelity norm, scaled by self.scale.
        norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat, measurement=measurement, **kwargs)
        x_t -= norm_grad * self.scale
        return x_t, norm

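# PosteriorSampling applies only the scaled gradient step, with no projection;
# `scale` trades off measurement consistency against staying close to the
# unconditional diffusion trajectory (a tuning knob, not a value prescribed
# by this file).
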
@register_conditioning_method(name='ps+')
class PosteriorSamplingPlus(ConditioningMethod):
    def __init__(self, operator, noiser, **kwargs):
        super().__init__(operator, noiser)
        self.num_sampling = kwargs.get('num_sampling', 5)
        self.scale = kwargs.get('scale', 1.0)

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, **kwargs):
        # Average the residual norm over several noisy perturbations of x_0_hat.
        norm = 0
        for _ in range(self.num_sampling):
            # TODO: use noiser?
            x_0_hat_noise = x_0_hat + 0.05 * torch.rand_like(x_0_hat)
            difference = measurement - self.operator.forward(x_0_hat_noise)
            norm += torch.linalg.norm(difference) / self.num_sampling

        norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
        x_t -= norm_grad * self.scale
        return x_t, norm
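
# Minimal runnable sketch. Everything below is illustrative: `DummyOperator`
# and `DummyNoiser` are hypothetical stand-ins for the repo's real operator
# and noiser classes, which are defined elsewhere.
if __name__ == "__main__":
    class DummyOperator:
        def forward(self, data, **kwargs):
            # Identity "measurement" operator, for demonstration only.
            return data

        def project(self, data, measurement, **kwargs):
            # Trivial projection: replace data with the measurement.
            return measurement

    class DummyNoiser:
        def __init__(self):
            # grad_and_value dispatches on this attribute; the repo's noise
            # factory tags real noiser instances with a name the same way.
            self.__name__ = 'gaussian'

    cond = get_conditioning_method('ps', DummyOperator(), DummyNoiser(), scale=0.5)

    x_prev = torch.randn(1, 3, 8, 8, requires_grad=True)  # current sample, in the graph
    x_0_hat = 0.9 * x_prev             # stand-in denoised estimate (keeps the graph)
    x_t = x_prev.detach().clone()      # stand-in unconditional next sample
    y = torch.randn(1, 3, 8, 8)        # stand-in measurement

    x_t, norm = cond.conditioning(x_prev=x_prev, x_t=x_t,
                                  x_0_hat=x_0_hat, measurement=y)
    print(f"residual norm: {norm.item():.4f}, x_t shape: {tuple(x_t.shape)}")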