'''This module handles task-dependent operations (A) and noises (n) to simulate a measurement y=Ax+n.'''

from abc import ABC, abstractmethod
from functools import partial

import torch
import yaml
from torch.nn import functional as F

from util.resizer import Resizer
from util.img_utils import Blurkernel, fft2_m


# =================
# Operation classes
# =================

# Registry: name passed to `register_operator` -> operator class.
__OPERATOR__ = {}


def register_operator(name: str):
    """Return a class decorator that registers the decorated class under ``name``.

    Raises:
        NameError: if ``name`` is already registered.
    """
    def wrapper(cls):
        if name in __OPERATOR__:
            raise NameError(f"Name {name} is already registered!")
        __OPERATOR__[name] = cls
        return cls
    return wrapper


def get_operator(name: str, **kwargs):
    """Instantiate the operator registered under ``name``.

    Args:
        name: key previously passed to ``register_operator``.
        **kwargs: forwarded unchanged to the operator's constructor.

    Raises:
        NameError: if no operator is registered under ``name``.
    """
    if name not in __OPERATOR__:
        raise NameError(f"Name {name} is not defined.")
    return __OPERATOR__[name](**kwargs)


class LinearOperator(ABC):
    """Interface for a linear measurement operator A.

    Subclasses implement ``forward`` (A x) and ``transpose`` (A^T x); the
    projection helpers below are built from those two.
    """

    @abstractmethod
    def forward(self, data, **kwargs):
        # calculate A * X
        pass

    @abstractmethod
    def transpose(self, data, **kwargs):
        # calculate A^T * X
        pass

    def ortho_project(self, data, **kwargs):
        # calculate (I - A^T * A)X
        return data - self.transpose(self.forward(data, **kwargs), **kwargs)

    def project(self, data, measurement, **kwargs):
        # calculate (I - A^T * A)Y - AX
        return self.ortho_project(measurement, **kwargs) - self.forward(data, **kwargs)


@register_operator(name='noise')
class DenoiseOperator(LinearOperator):
    """Identity operator (A = I), used for pure denoising.

    Every method returns ``data`` unchanged. The extra ``measurement`` /
    ``**kwargs`` parameters exist only so this class can be called through the
    same signatures as ``LinearOperator``; they are ignored.
    """

    def __init__(self, device):
        self.device = device

    def forward(self, data, **kwargs):
        return data

    def transpose(self, data, **kwargs):
        return data

    def ortho_project(self, data, **kwargs):
        # NOTE(review): for A = I the base-class formula (I - A^T A)x would be 0;
        # this returns `data` unchanged — confirm this is intended.
        return data

    def project(self, data, measurement=None, **kwargs):
        return data


@register_operator(name='super_resolution')
class SuperResolutionOperator(LinearOperator):
    """Downsampling operator.

    ``forward`` downsamples with ``Resizer`` by ``1/scale_factor``;
    ``transpose`` is approximated by ``F.interpolate`` upsampling by
    ``scale_factor`` (default interpolate mode).
    """

    def __init__(self, in_shape, scale_factor, device):
        self.device = device
        self.up_sample = partial(F.interpolate, scale_factor=scale_factor)
        self.down_sample = Resizer(in_shape, 1/scale_factor).to(device)

    def forward(self, data, **kwargs):
        return self.down_sample(data)

    def transpose(self, data, **kwargs):
        return self.up_sample(data)

    def project(self, data, measurement, **kwargs):
        # X - A^T A X + A^T Y
        return data - self.transpose(self.forward(data)) + self.transpose(measurement)


@register_operator(name='motion_blur')
class MotionBlurOperator(LinearOperator):
    """Convolution with a randomly generated motion-blur kernel.

    ``transpose`` returns its input unchanged (identity stand-in for A^T).
    """

    def __init__(self, kernel_size, intensity, device):
        # `Kernel` is not defined or imported anywhere else in this module, so
        # referencing it bare raised NameError. Import it lazily (as the bkse
        # import in NonlinearBlurOperator does) so the rest of the module still
        # loads without this optional dependency.
        # NOTE(review): assumed to come from the external `motionblur` package
        # (Kernel(size=..., intensity=...).kernelMatrix) — confirm the path.
        from motionblur.motionblur import Kernel

        self.device = device
        self.kernel_size = kernel_size
        self.conv = Blurkernel(blur_type='motion',
                               kernel_size=kernel_size,
                               std=intensity,
                               device=device).to(device)  # should we keep this device term?

        self.kernel = Kernel(size=(kernel_size, kernel_size), intensity=intensity)
        kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32)
        self.conv.update_weights(kernel)

    def forward(self, data, **kwargs):
        # A * X: blur the input with the motion kernel.
        return self.conv(data)

    def transpose(self, data, **kwargs):
        return data

    def get_kernel(self):
        """Return the blur kernel as a float32 tensor of shape (1, 1, k, k) on ``self.device``."""
        # `kernelMatrix` is converted with torch.tensor(...) in __init__, i.e. it is
        # array-like; numpy arrays have no `.type()` method, so convert explicitly.
        kernel = torch.as_tensor(self.kernel.kernelMatrix, dtype=torch.float32, device=self.device)
        return kernel.view(1, 1, self.kernel_size, self.kernel_size)


@register_operator(name='colorization')
class ColorizationOperator(LinearOperator):
    """Grayscale operator: averages the channel dimension (dim=1), keeping it as size 1."""

    def __init__(self, device):
        self.device = device

    def forward(self, data, **kwargs):
        return (1/3) * torch.sum(data, dim=1, keepdim=True)

    def transpose(self, data, **kwargs):
        # NOTE(review): returns the input unchanged rather than the true adjoint
        # of channel averaging — confirm callers rely on this.
        return data


@register_operator(name='gaussian_blur')
class GaussialBlurOperator(LinearOperator):
    """Convolution with a Gaussian blur kernel of standard deviation ``intensity``.

    ``transpose`` returns its input unchanged (identity stand-in for A^T).
    """

    def __init__(self, kernel_size, intensity, device):
        self.device = device
        self.kernel_size = kernel_size
        self.conv = Blurkernel(blur_type='gaussian',
                               kernel_size=kernel_size,
                               std=intensity,
                               device=device).to(device)
        self.kernel = self.conv.get_kernel()
        self.conv.update_weights(self.kernel.type(torch.float32))

    def forward(self, data, **kwargs):
        return self.conv(data)

    def transpose(self, data, **kwargs):
        return data

    def get_kernel(self):
        return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)

    def project(self, data, measurement, **kwargs):
        # calculate (I - A)X + Y
        return data - self.forward(data, **kwargs) + measurement


# Correctly spelled alias; the original (misspelled) class name is kept for
# backward compatibility.
GaussianBlurOperator = GaussialBlurOperator


@register_operator(name='inpainting')
class InpaintingOperator(LinearOperator):
    '''This operator gets a pre-defined mask (via ``set_mask``) and returns the masked image.'''

    def __init__(self, device):
        self.device = device
        self.mask = None

    def set_mask(self, mask):
        self.mask = mask

    def forward(self, data, **kwargs):
        """Return ``data * mask``.

        Raises:
            ValueError: if no mask has been set.
        """
        # Check explicitly instead of a bare `except:` so genuine errors
        # (shape/device mismatches) are not misreported as a missing mask.
        mask = getattr(self, 'mask', None)
        if mask is None:
            raise ValueError("Require mask")
        return data * mask.to(self.device)

    def transpose(self, data, **kwargs):
        return data

    def ortho_project(self, data, **kwargs):
        return data - self.forward(data, **kwargs)

    def project(self, data, measurement, **kwargs):
        return data - self.forward(data, **kwargs) + measurement


class NonLinearOperator(ABC):
    """Interface for a non-linear measurement operator A(x)."""

    @abstractmethod
    def forward(self, data, **kwargs):
        pass

    def project(self, data, measurement, **kwargs):
        # X + Y - A(X)
        return data + measurement - self.forward(data)


@register_operator(name='phase_retrieval')
class PhaseRetrievalOperator(NonLinearOperator):
    """Returns the amplitude (abs) of ``fft2_m`` applied to the zero-padded input."""

    def __init__(self, oversample, device):
        # NOTE(review): padding is oversample/8 of 256 pixels per side, which
        # presumably assumes 256x256 inputs — confirm.
        self.pad = int((oversample / 8.0) * 256)
        self.device = device

    def forward(self, data, **kwargs):
        padded = F.pad(data, (self.pad, self.pad, self.pad, self.pad))
        amplitude = fft2_m(padded).abs()
        return amplitude


@register_operator(name='nonlinear_blur')
class NonlinearBlurOperator(NonLinearOperator):
    """Blur produced by a pretrained bkse ``KernelWizard`` model with a random latent kernel."""

    def __init__(self, opt_yml_path, device):
        self.device = device
        self.blur_model = self.prepare_nonlinear_blur_model(opt_yml_path)

    def prepare_nonlinear_blur_model(self, opt_yml_path):
        '''
        Nonlinear deblur requires external codes (bkse).

        Reads the "KernelWizard" section of the YAML at ``opt_yml_path``,
        builds the model, loads the checkpoint named by its "pretrained" key,
        and returns the model in eval mode on ``self.device``.
        '''
        from bkse.models.kernel_encoding.kernel_wizard import KernelWizard

        with open(opt_yml_path, "r") as f:
            opt = yaml.safe_load(f)["KernelWizard"]
            model_path = opt["pretrained"]
        blur_model = KernelWizard(opt)
        blur_model.eval()
        # Load onto CPU first so checkpoints saved on a GPU also load on
        # CPU-only hosts; the model is moved to self.device below.
        # SECURITY: torch.load unpickles the file — only load trusted checkpoints.
        blur_model.load_state_dict(torch.load(model_path, map_location='cpu'))
        blur_model = blur_model.to(self.device)
        return blur_model

    def forward(self, data, **kwargs):
        random_kernel = torch.randn(1, 512, 2, 2).to(self.device) * 1.2
        data = (data + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        blurred = self.blur_model.adaptKernel(data, kernel=random_kernel)
        blurred = (blurred * 2.0 - 1.0).clamp(-1, 1)  # [0, 1] -> [-1, 1]
        return blurred


# =============
# Noise classes
# =============

# Registry: name passed to `register_noise` -> noise class.
__NOISE__ = {}


def register_noise(name: str):
    """Return a class decorator that registers the decorated class under ``name``.

    Raises:
        NameError: if ``name`` is already registered.
    """
    def wrapper(cls):
        if name in __NOISE__:
            raise NameError(f"Name {name} is already defined!")
        __NOISE__[name] = cls
        return cls
    return wrapper


def get_noise(name: str, **kwargs):
    """Instantiate the noise registered under ``name`` and tag it with ``__name__ = name``.

    Raises:
        NameError: if no noise is registered under ``name``.
    """
    if name not in __NOISE__:
        raise NameError(f"Name {name} is not defined.")
    noiser = __NOISE__[name](**kwargs)
    noiser.__name__ = name
    return noiser


class Noise(ABC):
    """Callable noise model; calling an instance delegates to ``forward``."""

    def __call__(self, data):
        return self.forward(data)

    @abstractmethod
    def forward(self, data):
        pass


@register_noise(name='clean')
class Clean(Noise):
    """No-op noise: returns the input unchanged."""

    def forward(self, data):
        return data


@register_noise(name='gaussian')
class GaussianNoise(Noise):
    """Additive Gaussian noise with standard deviation ``2 * sigma``."""

    def __init__(self, sigma):
        self.sigma = sigma

    def forward(self, data):
        # NOTE(review): the factor 2 presumably rescales sigma from a [0, 1]
        # intensity scale to the [-1, 1] range used in this module — confirm.
        return data + torch.randn_like(data, device=data.device) * self.sigma * 2


@register_noise(name='poisson')
class PoissonNoise(Noise):
    """Poisson (shot) noise with ``rate`` scaling the photon counts."""

    def __init__(self, rate):
        self.rate = rate

    def forward(self, data):
        '''
        Follow skimage.util.random_noise.

        Maps ``data`` from [-1, 1] to [0, 1], samples Poisson counts with mean
        ``x * 255 * rate`` using numpy's global RNG, rescales back to [-1, 1],
        and returns the result on the input's device (and in its dtype when
        the input is floating point).
        '''
        # TODO: settle on a single Poisson implementation (an skimage-style
        # variant was previously kept here as commented-out code).
        # version 3 (stack-overflow)
        import numpy as np

        device = data.device
        orig_dtype = data.dtype

        data = ((data + 1.0) / 2.0).clamp(0, 1)
        data = data.detach().cpu()
        data = torch.from_numpy(np.random.poisson(data * 255.0 * self.rate) / 255.0 / self.rate)
        data = (data * 2.0 - 1.0).clamp(-1, 1)

        out = data.to(device)
        # numpy division yields float64; restore the caller's floating dtype
        # so downstream float32/float16 code does not silently upcast.
        if orig_dtype.is_floating_point:
            out = out.to(orig_dtype)
        return out