# Spaces: Running Running  (stray page-header text, not part of the module)
from typing import Tuple | |
import numpy as np | |
import random | |
import torch | |
from numpy.typing import NDArray | |
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt | |
from basicsr.data.transforms import paired_random_crop | |
from basicsr.utils import DiffJPEG, USMSharp | |
from basicsr.utils.img_process_util import filter2D | |
from torch import Tensor | |
from torch.nn import functional as F | |
def blur(img: Tensor, kernel: NDArray) -> Tensor:
    """Filter ``img`` with ``kernel`` by delegating to ``filter2D``.

    Args:
        img: Image tensor to be filtered.
        kernel: Blur kernel, forwarded to ``filter2D`` unchanged.
            NOTE(review): annotated as ``NDArray``; confirm that
            ``filter2D`` accepts arrays rather than tensors.

    Returns:
        The result of ``filter2D(img, kernel)``.
    """
    blurred = filter2D(img, kernel)
    return blurred
def random_resize(
    img: Tensor,
    resize_prob: Tuple[float, float, float],
    resize_range: Tuple[float, float],
    output_scale: float = 1
) -> Tensor:
    """Rescale ``img`` by a randomly drawn factor and interpolation mode.

    A direction is drawn from ``['up', 'down', 'keep']`` using
    ``resize_prob`` as the weights (in that order). The scale factor is
    then sampled uniformly between 1 and ``resize_range[1]`` for 'up',
    between ``resize_range[0]`` and 1 for 'down', and is exactly 1 for
    'keep'. The final factor handed to ``F.interpolate`` is
    ``output_scale`` times that random factor.

    Args:
        img: Tensor to resize; it must be accepted by ``F.interpolate``
            for the 'area', 'bilinear' and 'bicubic' modes (presumably
            a B x C x H x W batch -- confirm against callers).
        resize_prob: Three weights for 'up', 'down' and 'keep'. Any
            sequence accepted by ``random.choices`` works at runtime.
        resize_range: ``(lowest_down_scale, highest_up_scale)``.
        output_scale: Extra constant multiplier applied on top of the
            random scale factor.

    Returns:
        The interpolated tensor.
    """
    direction = random.choices(['up', 'down', 'keep'], resize_prob)[0]
    if direction == 'up':
        random_scale = np.random.uniform(1, resize_range[1])
    elif direction == 'down':
        random_scale = np.random.uniform(resize_range[0], 1)
    else:
        random_scale = 1

    # The interpolation kernel is randomised as well.
    mode = random.choice(['area', 'bilinear', 'bicubic'])
    return F.interpolate(
        img, scale_factor=output_scale * random_scale, mode=mode)
def add_noise(
    img: Tensor,
    gray_noise_prob: float,
    gaussian_noise_prob: float,
    noise_range: Tuple[float, float],
    poisson_scale_range: Tuple[float, float]
) -> Tensor:
    """Add either Gaussian or Poisson noise to ``img``, chosen at random.

    A single uniform sample decides the noise family: if it is below
    ``gaussian_noise_prob`` Gaussian noise is added, otherwise Poisson
    noise is added.

    Args:
        img: Image tensor to corrupt.
        gray_noise_prob: Forwarded as ``gray_prob`` to the noise helper.
        gaussian_noise_prob: Probability of picking Gaussian noise.
        noise_range: Forwarded as ``sigma_range`` for Gaussian noise.
        poisson_scale_range: Forwarded as ``scale_range`` for Poisson
            noise.

    Returns:
        The tensor returned by the selected noise helper (called with
        ``clip=True`` and ``rounds=False``).
    """
    pick_gaussian = np.random.uniform() < gaussian_noise_prob
    if not pick_gaussian:
        return random_add_poisson_noise_pt(
            img,
            scale_range=poisson_scale_range,
            gray_prob=gray_noise_prob,
            clip=True,
            rounds=False,
        )
    return random_add_gaussian_noise_pt(
        img,
        sigma_range=noise_range,
        clip=True,
        rounds=False,
        gray_prob=gray_noise_prob,
    )
def jpeg_compression_simulation(
    img: Tensor,
    jpeg_range: Tuple[float, float],
    jpeg_simulator: DiffJPEG
) -> Tensor:
    """Run ``img`` through ``jpeg_simulator`` at random per-sample quality.

    Args:
        img: Batched image tensor; one quality value is drawn for each
            entry along dimension 0.
        jpeg_range: Bounds of the uniform distribution the quality
            values are sampled from.
        jpeg_simulator: Callable invoked as
            ``jpeg_simulator(img, quality=...)``.

    Returns:
        Whatever ``jpeg_simulator`` returns for the clamped image.
    """
    batch_size = img.size(0)
    # Allocate on the same device/dtype as ``img``, then fill in place.
    quality = img.new_zeros(batch_size).uniform_(*jpeg_range)
    # Values outside [0, 1] would make the JPEG simulator produce
    # unpleasant artifacts, so clamp first.
    clamped = img.clamp(0, 1)
    return jpeg_simulator(clamped, quality=quality)
def _resize_back_and_sinc(
    out: Tensor,
    size: Tuple[int, int],
    sinc_kernel: NDArray
) -> Tensor:
    """Resize ``out`` to ``size`` with a random mode, then sinc-filter it."""
    mode = random.choice(['area', 'bilinear', 'bicubic'])
    out = F.interpolate(out, size=size, mode=mode)
    return blur(out, sinc_kernel)


def apply_real_esrgan_degradations(
    gt: Tensor,
    blur_kernel1: NDArray,
    blur_kernel2: NDArray,
    second_blur_prob: float,
    sinc_kernel: NDArray,
    resize_prob1: Tuple[float, float, float],
    resize_prob2: Tuple[float, float, float],
    resize_range1: Tuple[float, float],
    resize_range2: Tuple[float, float],
    gray_noise_prob1: float,
    gray_noise_prob2: float,
    gaussian_noise_prob1: float,
    gaussian_noise_prob2: float,
    noise_range: Tuple[float, float],
    poisson_scale_range: Tuple[float, float],
    jpeg_compression_range1: Tuple[float, float],
    jpeg_compression_range2: Tuple[float, float],
    jpeg_simulator: DiffJPEG,
    random_crop_gt_size: int,
    sr_upsample_scale: float,
    usm_sharpener: USMSharp
) -> Tuple[Tensor, Tensor, Tensor]:
    """Apply two-order degradations to a GT batch to obtain LQ images.

    The batch is sharpened with ``usm_sharpener`` and then degraded twice
    (blur -> random resize -> noise -> JPEG). The second pass skips the
    blur with probability ``1 - second_blur_prob`` and finishes with a
    resize to ``1 / sr_upsample_scale`` of the original size plus a sinc
    filter, applied either before or after the JPEG step.

    Args:
        gt: Tensor of shape (B x C x H x W).
        blur_kernel1: Kernel for the first blur.
        blur_kernel2: Kernel for the optional second blur.
        second_blur_prob: Probability of applying the second blur.
        sinc_kernel: Kernel for the final sinc filter.
        resize_prob1: 'up' / 'down' / 'keep' weights for the first resize.
        resize_prob2: 'up' / 'down' / 'keep' weights for the second resize.
        resize_range1: Scale bounds for the first resize.
        resize_range2: Scale bounds for the second resize.
        gray_noise_prob1: ``gray_noise_prob`` for the first noise step.
        gray_noise_prob2: ``gray_noise_prob`` for the second noise step.
        gaussian_noise_prob1: Probability of Gaussian (vs. Poisson) noise
            in the first noise step.
        gaussian_noise_prob2: Same, for the second noise step.
        noise_range: Gaussian sigma range, shared by both noise steps.
        poisson_scale_range: Poisson scale range, shared by both steps.
        jpeg_compression_range1: Quality range for the first JPEG step.
        jpeg_compression_range2: Quality range for the second JPEG step.
        jpeg_simulator: JPEG simulator used by both JPEG steps.
        random_crop_gt_size: GT patch size passed to
            ``paired_random_crop``.
        sr_upsample_scale: Ratio between GT and LQ resolution.
            NOTE(review): forwarded unchanged to ``paired_random_crop``;
            confirm it accepts non-integer scales.
        usm_sharpener: Sharpener applied to ``gt`` before degrading.

    Returns:
        ``(gt, gt_usm, lq)`` as cropped by ``paired_random_crop``.
    """
    gt_usm = usm_sharpener(gt)
    orig_h, orig_w = gt.size()[2:4]
    # ``F.interpolate`` needs integer sizes, but ``//`` yields a float when
    # the scale is a float, so cast explicitly.
    lq_size = (int(orig_h // sr_upsample_scale),
               int(orig_w // sr_upsample_scale))

    # ----------------------- The first degradation process ----------------------- #
    out = blur(gt_usm, blur_kernel1)
    out = random_resize(out, resize_prob1, resize_range1)
    out = add_noise(out, gray_noise_prob1, gaussian_noise_prob1,
                    noise_range, poisson_scale_range)
    out = jpeg_compression_simulation(out, jpeg_compression_range1, jpeg_simulator)

    # ----------------------- The second degradation process ----------------------- #
    if np.random.uniform() < second_blur_prob:
        out = blur(out, blur_kernel2)
    out = random_resize(out, resize_prob2, resize_range2,
                        output_scale=(1 / sr_upsample_scale))
    out = add_noise(out, gray_noise_prob2, gaussian_noise_prob2,
                    noise_range, poisson_scale_range)

    # JPEG compression + the final sinc filter.
    # We also need to resize images to the desired size, and we group
    # [resize back + sinc filter] together as one operation.
    # We consider two orders:
    #   1. [resize back + sinc filter] + JPEG compression
    #   2. JPEG compression + [resize back + sinc filter]
    # Empirically, other combinations (sinc + JPEG + resize) introduce
    # twisted lines.
    if np.random.uniform() < 0.5:
        out = _resize_back_and_sinc(out, lq_size, sinc_kernel)
        out = jpeg_compression_simulation(out, jpeg_compression_range2, jpeg_simulator)
    else:
        out = jpeg_compression_simulation(out, jpeg_compression_range2, jpeg_simulator)
        out = _resize_back_and_sinc(out, lq_size, sinc_kernel)

    # Clamp to [0, 1] and round to 1/255 steps.
    lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    (gt, gt_usm), lq = paired_random_crop(
        [gt, gt_usm], lq, random_crop_gt_size, sr_upsample_scale)
    return gt, gt_usm, lq