import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def gauss_kernel(size=5, channels=3):
    # 5x5 binomial (Gaussian) kernel, normalized to sum to 1 and replicated
    # per channel so it can be applied as a depthwise (grouped) convolution.
    kernel = torch.tensor(
        [
            [1.0, 4.0, 6.0, 4.0, 1.0],
            [4.0, 16.0, 24.0, 16.0, 4.0],
            [6.0, 24.0, 36.0, 24.0, 6.0],
            [4.0, 16.0, 24.0, 16.0, 4.0],
            [1.0, 4.0, 6.0, 4.0, 1.0],
        ]
    )
    kernel /= 256.0
    kernel = kernel.repeat(channels, 1, 1, 1)  # shape: (channels, 1, 5, 5)
    kernel = kernel.to(device)
    return kernel
def downsample(x):
    # Keep every other row and column (stride-2 decimation).
    return x[:, :, ::2, ::2]
def upsample(x):
    # Interleave zero rows and columns to double the spatial size, then blur
    # with 4x the Gaussian kernel so the inserted zeros are filled in.
    cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
    cc = cc.permute(0, 1, 3, 2)
    cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
    x_up = cc.permute(0, 1, 3, 2)
    return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1]))
def conv_gauss(img, kernel):
    # Reflect-pad by 2 on each side so the 5x5 depthwise convolution keeps
    # the spatial size unchanged.
    img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode="reflect")
    out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
    return out
def laplacian_pyramid(img, kernel, max_levels=3):
    # Each level stores the detail lost when the current image is blurred,
    # downsampled, and upsampled back (a band-pass residual).
    current = img
    pyr = []
    for level in range(max_levels):
        filtered = conv_gauss(current, kernel)
        down = downsample(filtered)
        up = upsample(down)
        diff = current - up
        pyr.append(diff)
        current = down
    return pyr
class LapLoss(torch.nn.Module):
    def __init__(self, max_levels=5, channels=3):
        super(LapLoss, self).__init__()
        self.max_levels = max_levels
        self.gauss_kernel = gauss_kernel(channels=channels)

    def forward(self, input, target):
        # Sum the L1 distances between corresponding pyramid levels of the
        # prediction and the target.
        pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
        pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
        return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))
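
# A minimal usage sketch (not from the original source): assumes NCHW image
# batches whose height and width are divisible by 2**max_levels, so every
# downsample/upsample round trip matches the level it is compared against.
if __name__ == "__main__":
    criterion = LapLoss(max_levels=3, channels=3)
    pred = torch.rand(2, 3, 64, 64, device=device, requires_grad=True)
    target = torch.rand(2, 3, 64, 64, device=device)
    loss = criterion(pred, target)
    loss.backward()
    print(loss.item())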