'''This module defines the task-dependent operators (A) and noise models (n) used to simulate a measurement y = Ax + n.'''

from abc import ABC, abstractmethod
from functools import partial

import numpy as np
import torch
import yaml
from torch.nn import functional as F

from motionblur.motionblur import Kernel  # motion-blur kernel sampler (external motionblur package) used by MotionBlurOperator

from util.resizer import Resizer
from util.img_utils import Blurkernel, fft2_m
# =================
# Operation classes
# =================

__OPERATOR__ = {}

def register_operator(name: str):
    def wrapper(cls):
        if __OPERATOR__.get(name, None):
            raise NameError(f"Name {name} is already registered!")
        __OPERATOR__[name] = cls
        return cls
    return wrapper

def get_operator(name: str, **kwargs):
    if __OPERATOR__.get(name, None) is None:
        raise NameError(f"Name {name} is not defined.")
    return __OPERATOR__[name](**kwargs)
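
# Usage sketch (commented out): operators are meant to be registered with the
# decorator and then built by name. The name 'example' and ExampleOperator are
# hypothetical; the classes in this listing appear without their
# @register_operator decorators, so nothing is registered here.
#
#   @register_operator(name='example')
#   class ExampleOperator(LinearOperator):
#       ...
#
#   operator = get_operator(name='example', device='cpu')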
class LinearOperator(ABC):
    @abstractmethod
    def forward(self, data, **kwargs):
        # calculate A * X
        pass

    @abstractmethod
    def transpose(self, data, **kwargs):
        # calculate A^T * X
        pass

    def ortho_project(self, data, **kwargs):
        # calculate (I - A^T * A)X
        return data - self.transpose(self.forward(data, **kwargs), **kwargs)

    def project(self, data, measurement, **kwargs):
        # calculate (I - A^T * A)Y - AX
        return self.ortho_project(measurement, **kwargs) - self.forward(data, **kwargs)
class DenoiseOperator(LinearOperator):
    def __init__(self, device):
        self.device = device

    def forward(self, data):
        return data

    def transpose(self, data):
        return data

    def ortho_project(self, data):
        return data

    def project(self, data):
        return data
class SuperResolutionOperator(LinearOperator):
    def __init__(self, in_shape, scale_factor, device):
        self.device = device
        self.up_sample = partial(F.interpolate, scale_factor=scale_factor)
        self.down_sample = Resizer(in_shape, 1/scale_factor).to(device)

    def forward(self, data, **kwargs):
        return self.down_sample(data)

    def transpose(self, data, **kwargs):
        return self.up_sample(data)

    def project(self, data, measurement, **kwargs):
        return data - self.transpose(self.forward(data)) + self.transpose(measurement)
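
# Usage sketch (commented out): a 4x super-resolution measurement y = A x.
# The shapes and scale factor are illustrative assumptions, not values fixed
# by this module.
#
#   operator = SuperResolutionOperator(in_shape=(1, 3, 256, 256),
#                                      scale_factor=4, device='cpu')
#   x = torch.randn(1, 3, 256, 256)   # image in [-1, 1]
#   y = operator.forward(x)           # downsampled to (1, 3, 64, 64)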
class MotionBlurOperator(LinearOperator):
    def __init__(self, kernel_size, intensity, device):
        self.device = device
        self.kernel_size = kernel_size
        self.conv = Blurkernel(blur_type='motion',
                               kernel_size=kernel_size,
                               std=intensity,
                               device=device).to(device)  # should we keep this device term?

        self.kernel = Kernel(size=(kernel_size, kernel_size), intensity=intensity)
        kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32)
        self.conv.update_weights(kernel)

    def forward(self, data, **kwargs):
        # calculate A * X (apply the motion-blur kernel)
        return self.conv(data)

    def transpose(self, data, **kwargs):
        return data

    def get_kernel(self):
        # kernelMatrix is a numpy array, so convert it before moving to device
        kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32).to(self.device)
        return kernel.view(1, 1, self.kernel_size, self.kernel_size)
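
# Usage sketch (commented out): motion-blur measurement. kernel_size and
# intensity are illustrative values; Kernel comes from the external
# motionblur package imported above.
#
#   operator = MotionBlurOperator(kernel_size=61, intensity=0.5, device='cpu')
#   x = torch.randn(1, 3, 256, 256)   # image in [-1, 1]
#   y = operator.forward(x)           # A x, same spatial size as x
#   k = operator.get_kernel()         # (1, 1, 61, 61) kernel for inspection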
class ColorizationOperator(LinearOperator):
    def __init__(self, device):
        self.device = device

    def forward(self, data, **kwargs):
        return (1/3) * torch.sum(data, dim=1, keepdim=True)

    def transpose(self, data, **kwargs):
        return data
class GaussialBlurOperator(LinearOperator):
    def __init__(self, kernel_size, intensity, device):
        self.device = device
        self.kernel_size = kernel_size
        self.conv = Blurkernel(blur_type='gaussian',
                               kernel_size=kernel_size,
                               std=intensity,
                               device=device).to(device)
        self.kernel = self.conv.get_kernel()
        self.conv.update_weights(self.kernel.type(torch.float32))

    def forward(self, data, **kwargs):
        return self.conv(data)

    def transpose(self, data, **kwargs):
        return data

    def get_kernel(self):
        return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)

    def project(self, data, measurement, **kwargs):
        # calculate (I - A)X + Y: keep the un-blurred residual of X and add the measurement
        return data - self.forward(data, **kwargs) + measurement
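
# Usage sketch (commented out): Gaussian-blur measurement with illustrative
# settings.
#
#   operator = GaussialBlurOperator(kernel_size=61, intensity=3.0, device='cpu')
#   x = torch.randn(1, 3, 256, 256)   # image in [-1, 1]
#   y = operator.forward(x)           # blurred image, same shape as x
#   x_proj = operator.project(x, y)   # (I - A)X + Y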
class InpaintingOperator(LinearOperator):
    '''This operator takes a pre-defined mask and returns the masked image.'''
    def __init__(self, device):
        self.device = device

    def set_mask(self, mask):
        self.mask = mask

    def forward(self, data, **kwargs):
        try:
            return data * self.mask.to(self.device)
        except AttributeError:
            raise ValueError("A mask is required; call set_mask() before forward().")

    def transpose(self, data, **kwargs):
        return data

    def ortho_project(self, data, **kwargs):
        return data - self.forward(data, **kwargs)

    def project(self, data, measurement, **kwargs):
        return data - self.forward(data, **kwargs) + measurement
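
# Usage sketch (commented out): the mask must be set before forward() is
# called. The central box mask is an illustrative choice.
#
#   operator = InpaintingOperator(device='cpu')
#   x = torch.randn(1, 3, 256, 256)   # image in [-1, 1]
#   mask = torch.ones(1, 3, 256, 256)
#   mask[..., 96:160, 96:160] = 0     # hide a central box
#   operator.set_mask(mask)
#   y = operator.forward(x)           # x with the box zeroed out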
class NonLinearOperator(ABC):
    @abstractmethod
    def forward(self, data, **kwargs):
        pass

    def project(self, data, measurement, **kwargs):
        return data + measurement - self.forward(data)
class PhaseRetrievalOperator(NonLinearOperator):
    def __init__(self, oversample, device):
        # padding amount is defined relative to a 256x256 input image
        self.pad = int((oversample / 8.0) * 256)
        self.device = device

    def forward(self, data, **kwargs):
        padded = F.pad(data, (self.pad, self.pad, self.pad, self.pad))
        amplitude = fft2_m(padded).abs()
        return amplitude
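
# Usage sketch (commented out): the measurement is the Fourier magnitude of the
# zero-padded image, so the output is spatially larger than the input
# (256 + 2 * pad per side for a 256x256 image). oversample=2.0 is illustrative.
#
#   operator = PhaseRetrievalOperator(oversample=2.0, device='cpu')
#   x = torch.randn(1, 3, 256, 256)   # image in [-1, 1]
#   y = operator.forward(x)           # amplitude spectrum, (1, 3, 384, 384)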
class NonlinearBlurOperator(NonLinearOperator):
    def __init__(self, opt_yml_path, device):
        self.device = device
        self.blur_model = self.prepare_nonlinear_blur_model(opt_yml_path)

    def prepare_nonlinear_blur_model(self, opt_yml_path):
        '''
        Nonlinear deblur requires external codes (bkse).
        '''
        from bkse.models.kernel_encoding.kernel_wizard import KernelWizard

        with open(opt_yml_path, "r") as f:
            opt = yaml.safe_load(f)["KernelWizard"]
            model_path = opt["pretrained"]
        blur_model = KernelWizard(opt)
        blur_model.eval()
        blur_model.load_state_dict(torch.load(model_path))
        blur_model = blur_model.to(self.device)
        return blur_model

    def forward(self, data, **kwargs):
        random_kernel = torch.randn(1, 512, 2, 2).to(self.device) * 1.2
        data = (data + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        blurred = self.blur_model.adaptKernel(data, kernel=random_kernel)
        blurred = (blurred * 2.0 - 1.0).clamp(-1, 1)  # [0, 1] -> [-1, 1]
        return blurred
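
# Configuration sketch: prepare_nonlinear_blur_model only reads the
# "KernelWizard" section of the YAML file and its "pretrained" checkpoint path;
# the path below is a placeholder and the remaining options come from the
# external bkse project.
#
#   KernelWizard:
#     pretrained: <path/to/kernel_wizard_checkpoint.pth>
#     ...                             # other KernelWizard options (see bkse)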
# =============
# Noise classes
# =============

__NOISE__ = {}

def register_noise(name: str):
    def wrapper(cls):
        if __NOISE__.get(name, None):
            raise NameError(f"Name {name} is already defined!")
        __NOISE__[name] = cls
        return cls
    return wrapper

def get_noise(name: str, **kwargs):
    if __NOISE__.get(name, None) is None:
        raise NameError(f"Name {name} is not defined.")
    noiser = __NOISE__[name](**kwargs)
    noiser.__name__ = name
    return noiser
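
# Usage sketch (commented out): noise models are registered by name and built
# via get_noise; keyword arguments are forwarded to the constructor. The name
# 'gaussian' and the sigma value are illustrative, since the decorators are
# not shown in this listing.
#
#   @register_noise(name='gaussian')
#   class GaussianNoise(Noise):
#       ...
#
#   noiser = get_noise(name='gaussian', sigma=0.05)
#   y_n = noiser(y)                   # add noise to a measurement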
class Noise(ABC):
    def __call__(self, data):
        return self.forward(data)

    @abstractmethod
    def forward(self, data):
        pass

class Clean(Noise):
    def forward(self, data):
        return data
class GaussianNoise(Noise):
    def __init__(self, sigma):
        self.sigma = sigma

    def forward(self, data):
        return data + torch.randn_like(data, device=data.device) * self.sigma * 2
class PoissonNoise(Noise):
    def __init__(self, rate):
        self.rate = rate

    def forward(self, data):
        '''
        Follow skimage.util.random_noise.
        '''
        # TODO: set one version of poisson
        # version 3 (stack-overflow)
        data = (data + 1.0) / 2.0
        data = data.clamp(0, 1)
        device = data.device
        data = data.detach().cpu()
        data = torch.from_numpy(np.random.poisson(data * 255.0 * self.rate) / 255.0 / self.rate)
        data = data * 2.0 - 1.0
        data = data.clamp(-1, 1)
        return data.to(device)

        # version 2 (skimage)
        # if data.min() < 0:
        #     low_clip = -1
        # else:
        #     low_clip = 0
        # # Determine unique values in image & calculate the next power of two
        # vals = torch.Tensor([len(torch.unique(data))])
        # vals = 2 ** torch.ceil(torch.log2(vals))
        # vals = vals.to(data.device)
        # if low_clip == -1:
        #     old_max = data.max()
        #     data = (data + 1.0) / (old_max + 1.0)
        # data = torch.poisson(data * vals) / float(vals)
        # if low_clip == -1:
        #     data = data * (old_max + 1.0) - 1.0
        # return data.clamp(low_clip, 1.0)
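
# Minimal end-to-end sketch of y = A x + n. This assumes util.resizer.Resizer
# resolves as in the surrounding codebase; the classes are instantiated
# directly because this listing shows them without registration decorators.
if __name__ == "__main__":
    device = 'cpu'
    x = torch.rand(1, 3, 256, 256) * 2.0 - 1.0  # toy image in [-1, 1]

    operator = SuperResolutionOperator(in_shape=(1, 3, 256, 256),
                                       scale_factor=4, device=device)
    noiser = GaussianNoise(sigma=0.05)

    y = noiser(operator.forward(x))             # y = A x + n
    print(x.shape, y.shape)                     # -> (1, 3, 256, 256) (1, 3, 64, 64)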