# LambdaSuperRes/KAIR/data/degradations.py
from typing import Tuple
import numpy as np
import random
import torch
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from torch import Tensor
from torch.nn import functional as F


def blur(img: Tensor, kernel: Tensor) -> Tensor:
    """Convolve a batched image (B x C x H x W) with a blur kernel of shape (1, k, k) or (B, k, k)."""
    return filter2D(img, kernel)


def random_resize(
        img: Tensor,
        resize_prob: Tuple[float, float, float],
        resize_range: Tuple[float, float],
        output_scale: float = 1
) -> Tensor:
    """Randomly up-scale, down-scale, or keep the image size, using a random interpolation mode.

    resize_prob gives the sampling weights for the 'up', 'down' and 'keep' branches;
    resize_range bounds the random scale factor.
    """
    updown_type = random.choices(['up', 'down', 'keep'], resize_prob)[0]
    if updown_type == 'up':
        random_scale = np.random.uniform(1, resize_range[1])
    elif updown_type == 'down':
        random_scale = np.random.uniform(resize_range[0], 1)
    else:
        random_scale = 1
    mode = random.choice(['area', 'bilinear', 'bicubic'])
    out = F.interpolate(img, scale_factor=output_scale * random_scale, mode=mode)
    return out


def add_noise(
        img: Tensor,
        gray_noise_prob: float,
        gaussian_noise_prob: float,
        noise_range: Tuple[float, float],
        poisson_scale_range: Tuple[float, float]
) -> Tensor:
    """Add Gaussian noise with probability gaussian_noise_prob, otherwise Poisson noise."""
    if np.random.uniform() < gaussian_noise_prob:
        img = random_add_gaussian_noise_pt(
            img, sigma_range=noise_range, clip=True, rounds=False,
            gray_prob=gray_noise_prob)
    else:
        img = random_add_poisson_noise_pt(
            img, scale_range=poisson_scale_range,
            gray_prob=gray_noise_prob, clip=True, rounds=False)
    return img


def jpeg_compression_simulation(
        img: Tensor,
        jpeg_range: Tuple[float, float],
        jpeg_simulator: DiffJPEG
) -> Tensor:
    """Simulate JPEG compression with a quality factor sampled uniformly from jpeg_range."""
    jpeg_p = img.new_zeros(img.size(0)).uniform_(*jpeg_range)
    # Clamp to [0, 1]; otherwise the JPEG simulator produces unpleasant artifacts.
    img = torch.clamp(img, 0, 1)
    return jpeg_simulator(img, quality=jpeg_p)
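
# Note: DiffJPEG and USMSharp are stateful nn.Modules from basicsr; callers are expected
# to construct them once and reuse them across batches, e.g.
#   jpeg_simulator = DiffJPEG(differentiable=False); usm_sharpener = USMSharp()
# (an illustrative hint, not original code; see the __main__ sketch at the bottom of this file).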


@torch.no_grad()
def apply_real_esrgan_degradations(
        gt: Tensor,
        blur_kernel1: Tensor,
        blur_kernel2: Tensor,
        second_blur_prob: float,
        sinc_kernel: Tensor,
        resize_prob1: Tuple[float, float, float],
        resize_prob2: Tuple[float, float, float],
        resize_range1: Tuple[float, float],
        resize_range2: Tuple[float, float],
        gray_noise_prob1: float,
        gray_noise_prob2: float,
        gaussian_noise_prob1: float,
        gaussian_noise_prob2: float,
        noise_range: Tuple[float, float],
        poisson_scale_range: Tuple[float, float],
        jpeg_compression_range1: Tuple[float, float],
        jpeg_compression_range2: Tuple[float, float],
        jpeg_simulator: DiffJPEG,
        random_crop_gt_size: int,  # e.g. 512
        sr_upsample_scale: int,
        usm_sharpener: USMSharp
):
    """
    Accept a batch from the dataloader and apply the two-stage (second-order)
    Real-ESRGAN degradation process to obtain LQ images.

    gt: Tensor of shape (B x C x H x W), values expected in [0, 1]
    """
    gt_usm = usm_sharpener(gt)
    # from PIL import Image
    # Image.fromarray((gt_usm[0].permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8)).save(
    #     "/home/cll/Desktop/GT_USM_orig.png")
    orig_h, orig_w = gt.size()[2:4]

    # ----------------------- The first degradation process ----------------------- #
    out = blur(gt_usm, blur_kernel1)
    out = random_resize(out, resize_prob1, resize_range1)
    out = add_noise(out, gray_noise_prob1, gaussian_noise_prob1, noise_range, poisson_scale_range)
    out = jpeg_compression_simulation(out, jpeg_compression_range1, jpeg_simulator)

    # ----------------------- The second degradation process ----------------------- #
    if np.random.uniform() < second_blur_prob:
        out = blur(out, blur_kernel2)
    out = random_resize(out, resize_prob2, resize_range2, output_scale=(1 / sr_upsample_scale))
    out = add_noise(out, gray_noise_prob2, gaussian_noise_prob2,
                    noise_range, poisson_scale_range)

    # JPEG compression + the final sinc filter.
    # We also need to resize images to the desired size. We group
    # [resize back + sinc filter] together as one operation and consider two orders:
    #   1. [resize back + sinc filter] + JPEG compression
    #   2. JPEG compression + [resize back + sinc filter]
    # Empirically, we find other combinations (sinc + JPEG + resize) introduce twisted lines.
    if np.random.uniform() < 0.5:
        # resize back + the final sinc filter
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, size=(orig_h // sr_upsample_scale,
                                       orig_w // sr_upsample_scale), mode=mode)
        out = blur(out, sinc_kernel)
        out = jpeg_compression_simulation(out, jpeg_compression_range2, jpeg_simulator)
    else:
        out = jpeg_compression_simulation(out, jpeg_compression_range2, jpeg_simulator)
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, size=(orig_h // sr_upsample_scale,
                                       orig_w // sr_upsample_scale), mode=mode)
        out = blur(out, sinc_kernel)

    # clamp and round
    lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

    (gt, gt_usm), lq = paired_random_crop([gt, gt_usm], lq, random_crop_gt_size, sr_upsample_scale)
    return gt, gt_usm, lq
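

# ------------------------------------------------------------------------------------ #
# Minimal usage sketch (not part of the original pipeline). It shows how the entry point
# above could be wired up end-to-end. All kernel settings, probabilities and ranges below
# are illustrative assumptions, not the values of any particular KAIR / Real-ESRGAN
# training configuration; only the kernel helpers (random_mixed_kernels,
# circular_lowpass_kernel) and DiffJPEG / USMSharp are real basicsr components.
# ------------------------------------------------------------------------------------ #
if __name__ == '__main__':
    from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels

    def _random_blur_kernel(kernel_size: int = 21) -> Tensor:
        # filter2D (used by blur) operates on torch kernels of shape (1, k, k) or (B, k, k),
        # so the numpy kernel from basicsr is converted here.
        kernel = random_mixed_kernels(
            ['iso', 'aniso'], [0.7, 0.3], kernel_size,
            sigma_x_range=(0.2, 3), sigma_y_range=(0.2, 3),
            rotation_range=(-np.pi, np.pi))
        return torch.from_numpy(kernel).float().unsqueeze(0)

    # Sinc (ringing / overshoot) kernel with a random cutoff frequency.
    omega_c = np.random.uniform(np.pi / 3, np.pi)
    sinc_kernel = torch.from_numpy(
        circular_lowpass_kernel(omega_c, 21, pad_to=21)).float().unsqueeze(0)

    gt_batch = torch.rand(1, 3, 256, 256)  # dummy HR batch in [0, 1]
    gt_crop, gt_usm_crop, lq = apply_real_esrgan_degradations(
        gt_batch,
        blur_kernel1=_random_blur_kernel(),
        blur_kernel2=_random_blur_kernel(),
        second_blur_prob=0.8,
        sinc_kernel=sinc_kernel,
        resize_prob1=(0.2, 0.7, 0.1),
        resize_prob2=(0.3, 0.4, 0.3),
        resize_range1=(0.5, 1.5),
        resize_range2=(0.3, 1.2),
        gray_noise_prob1=0.4,
        gray_noise_prob2=0.4,
        gaussian_noise_prob1=0.5,
        gaussian_noise_prob2=0.5,
        noise_range=(1, 30),
        poisson_scale_range=(0.05, 3),
        jpeg_compression_range1=(30, 95),
        jpeg_compression_range2=(30, 95),
        jpeg_simulator=DiffJPEG(differentiable=False),
        random_crop_gt_size=128,
        sr_upsample_scale=4,
        usm_sharpener=USMSharp())
    # Expected shapes: (1, 3, 128, 128), (1, 3, 128, 128), (1, 3, 32, 32)
    print(gt_crop.shape, gt_usm_crop.shape, lq.shape)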