Spaces:
Running
on
Zero
Running
on
Zero
Add application file
Browse files- app.py +82 -0
- loss/niqe_pris_params.npz +3 -0
- loss/niqe_utils.py +559 -0
- net/CIDNet.py +130 -0
- net/HVI_transform.py +122 -0
- net/LCA.py +93 -0
- net/transformer_utils.py +72 -0
- weights/LOL-Blur.pth +3 -0
- weights/LOLv1/BestLPIPS_0.0868.pth +3 -0
- weights/LOLv1/BestSSIM_0.8631.pth +3 -0
- weights/LOLv1/PSNR_24.74.pth +3 -0
- weights/LOLv1/w_perc.pth +3 -0
- weights/LOLv1/wo_perc.pth +3 -0
- weights/LOLv2_real/best_PSNR.pth +3 -0
- weights/LOLv2_real/best_SSIM.pth +3 -0
- weights/LOLv2_real/w_perc.pth +3 -0
- weights/LOLv2_syn/w_perc.pth +3 -0
- weights/LOLv2_syn/wo_perc.pth +3 -0
- weights/SICE.pth +3 -0
- weights/SID.pth +3 -0
- weights/generalization.pth +3 -0
app.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import gradio as gr
|
4 |
+
from PIL import Image
|
5 |
+
from net.CIDNet import CIDNet
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import os
|
9 |
+
import imquality.brisque as brisque
|
10 |
+
from loss.niqe_utils import *
|
11 |
+
|
12 |
+
eval_net = CIDNet()
|
13 |
+
eval_net.trans.gated = True
|
14 |
+
eval_net.trans.gated2 = True
|
15 |
+
|
16 |
+
def process_image(input_img,score,model_path,gamma,alpha_s=1.0,alpha_i=1.0):
|
17 |
+
torch.set_grad_enabled(False)
|
18 |
+
eval_net.load_state_dict(torch.load(os.path.join(directory,model_path), map_location=lambda storage, loc: storage))
|
19 |
+
eval_net.eval()
|
20 |
+
|
21 |
+
pil2tensor = transforms.Compose([transforms.ToTensor()])
|
22 |
+
input = pil2tensor(input_img)
|
23 |
+
factor = 8
|
24 |
+
h, w = input.shape[1], input.shape[2]
|
25 |
+
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
26 |
+
padh = H - h if h % factor != 0 else 0
|
27 |
+
padw = W - w if w % factor != 0 else 0
|
28 |
+
input = F.pad(input.unsqueeze(0), (0,padw,0,padh), 'reflect')
|
29 |
+
with torch.no_grad():
|
30 |
+
eval_net.trans.alpha_s = alpha_s
|
31 |
+
eval_net.trans.alpha = alpha_i
|
32 |
+
output = eval_net(input**gamma)
|
33 |
+
output = torch.clamp(output,0,1)
|
34 |
+
output = output[:, :, :h, :w]
|
35 |
+
enhanced_img = transforms.ToPILImage()(output.squeeze(0))
|
36 |
+
if score == 'Yes':
|
37 |
+
im1 = enhanced_img.convert('RGB')
|
38 |
+
score_brisque = brisque.score(im1)
|
39 |
+
im1 = np.array(im1)
|
40 |
+
score_niqe = calculate_niqe(im1)
|
41 |
+
return enhanced_img,score_niqe,score_brisque
|
42 |
+
else:
|
43 |
+
return enhanced_img,0,0
|
44 |
+
|
45 |
+
def find_pth_files(directory):
|
46 |
+
pth_files = []
|
47 |
+
for root, dirs, files in os.walk(directory):
|
48 |
+
if 'train' in root.split(os.sep):
|
49 |
+
continue
|
50 |
+
for file in files:
|
51 |
+
if file.endswith('.pth'):
|
52 |
+
pth_files.append(os.path.join(root, file))
|
53 |
+
return pth_files
|
54 |
+
|
55 |
+
def remove_weights_prefix(paths):
|
56 |
+
cleaned_paths = [path.replace('.\\weights\\', '') for path in paths]
|
57 |
+
return cleaned_paths
|
58 |
+
|
59 |
+
directory = ".\weights"
|
60 |
+
pth_files = find_pth_files(directory)
|
61 |
+
pth_files2 = remove_weights_prefix(pth_files)
|
62 |
+
|
63 |
+
interface = gr.Interface(
|
64 |
+
fn=process_image,
|
65 |
+
inputs=[
|
66 |
+
gr.Image(label="Low-light Image", type="pil"),
|
67 |
+
gr.Radio(choices=['Yes','No'],label="Image Score"),
|
68 |
+
gr.Radio(choices=pth_files2,label="Model Path"),
|
69 |
+
gr.Slider(0.1,10,label="gamma curve",step=0.01,value=1.0),
|
70 |
+
gr.Slider(0,2,label="Alpha-s",step=0.01,value=1.0),
|
71 |
+
gr.Slider(0.1,2,label="Alpha-i",step=0.01,value=1.0)
|
72 |
+
],
|
73 |
+
outputs=[
|
74 |
+
gr.Image(label="Result", type="pil"),
|
75 |
+
gr.Textbox(label="NIQE"),
|
76 |
+
gr.Textbox(label="BRISQUE")
|
77 |
+
],
|
78 |
+
title="HVI-CIDNet (Low-Light Image Enhancement)",
|
79 |
+
allow_flagging="never"
|
80 |
+
)
|
81 |
+
|
82 |
+
interface.launch(server_port=7862)
|
loss/niqe_pris_params.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2a7c182a68c9e7f1b2e2e5ec723279d6f65d912b6fcaf37eb2bf03d7367c4296
|
3 |
+
size 11850
|
loss/niqe_utils.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import cv2
|
3 |
+
import math
|
4 |
+
import numpy as np
|
5 |
+
from scipy.ndimage import convolve
|
6 |
+
from scipy.special import gamma
|
7 |
+
import torch
|
8 |
+
|
9 |
+
def cubic(x):
|
10 |
+
"""cubic function used for calculate_weights_indices."""
|
11 |
+
absx = torch.abs(x)
|
12 |
+
absx2 = absx**2
|
13 |
+
absx3 = absx**3
|
14 |
+
return (1.5 * absx3 - 2.5 * absx2 + 1) * (
|
15 |
+
(absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
|
16 |
+
(absx <= 2)).type_as(absx))
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
|
21 |
+
"""Calculate weights and indices, used for imresize function.
|
22 |
+
Args:
|
23 |
+
in_length (int): Input length.
|
24 |
+
out_length (int): Output length.
|
25 |
+
scale (float): Scale factor.
|
26 |
+
kernel_width (int): Kernel width.
|
27 |
+
antialisaing (bool): Whether to apply anti-aliasing when downsampling.
|
28 |
+
"""
|
29 |
+
|
30 |
+
if (scale < 1) and antialiasing:
|
31 |
+
# Use a modified kernel (larger kernel width) to simultaneously
|
32 |
+
# interpolate and antialias
|
33 |
+
kernel_width = kernel_width / scale
|
34 |
+
|
35 |
+
# Output-space coordinates
|
36 |
+
x = torch.linspace(1, out_length, out_length)
|
37 |
+
|
38 |
+
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
39 |
+
# in output space maps to 0.5 in input space, and 0.5 + scale in output
|
40 |
+
# space maps to 1.5 in input space.
|
41 |
+
u = x / scale + 0.5 * (1 - 1 / scale)
|
42 |
+
|
43 |
+
# What is the left-most pixel that can be involved in the computation?
|
44 |
+
left = torch.floor(u - kernel_width / 2)
|
45 |
+
|
46 |
+
# What is the maximum number of pixels that can be involved in the
|
47 |
+
# computation? Note: it's OK to use an extra pixel here; if the
|
48 |
+
# corresponding weights are all zero, it will be eliminated at the end
|
49 |
+
# of this function.
|
50 |
+
p = math.ceil(kernel_width) + 2
|
51 |
+
|
52 |
+
# The indices of the input pixels involved in computing the k-th output
|
53 |
+
# pixel are in row k of the indices matrix.
|
54 |
+
indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
|
55 |
+
out_length, p)
|
56 |
+
|
57 |
+
# The weights used to compute the k-th output pixel are in row k of the
|
58 |
+
# weights matrix.
|
59 |
+
distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
|
60 |
+
|
61 |
+
# apply cubic kernel
|
62 |
+
if (scale < 1) and antialiasing:
|
63 |
+
weights = scale * cubic(distance_to_center * scale)
|
64 |
+
else:
|
65 |
+
weights = cubic(distance_to_center)
|
66 |
+
|
67 |
+
# Normalize the weights matrix so that each row sums to 1.
|
68 |
+
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
69 |
+
weights = weights / weights_sum.expand(out_length, p)
|
70 |
+
|
71 |
+
# If a column in weights is all zero, get rid of it. only consider the
|
72 |
+
# first and last column.
|
73 |
+
weights_zero_tmp = torch.sum((weights == 0), 0)
|
74 |
+
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
75 |
+
indices = indices.narrow(1, 1, p - 2)
|
76 |
+
weights = weights.narrow(1, 1, p - 2)
|
77 |
+
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
78 |
+
indices = indices.narrow(1, 0, p - 2)
|
79 |
+
weights = weights.narrow(1, 0, p - 2)
|
80 |
+
weights = weights.contiguous()
|
81 |
+
indices = indices.contiguous()
|
82 |
+
sym_len_s = -indices.min() + 1
|
83 |
+
sym_len_e = indices.max() - in_length
|
84 |
+
indices = indices + sym_len_s - 1
|
85 |
+
return weights, indices, int(sym_len_s), int(sym_len_e)
|
86 |
+
|
87 |
+
def imresize(img, scale, antialiasing=True):
|
88 |
+
"""imresize function same as MATLAB.
|
89 |
+
It now only supports bicubic.
|
90 |
+
The same scale applies for both height and width.
|
91 |
+
Args:
|
92 |
+
img (Tensor | Numpy array):
|
93 |
+
Tensor: Input image with shape (c, h, w), [0, 1] range.
|
94 |
+
Numpy: Input image with shape (h, w, c), [0, 1] range.
|
95 |
+
scale (float): Scale factor. The same scale applies for both height
|
96 |
+
and width.
|
97 |
+
antialisaing (bool): Whether to apply anti-aliasing when downsampling.
|
98 |
+
Default: True.
|
99 |
+
Returns:
|
100 |
+
Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
|
101 |
+
"""
|
102 |
+
squeeze_flag = False
|
103 |
+
if type(img).__module__ == np.__name__: # numpy type
|
104 |
+
numpy_type = True
|
105 |
+
if img.ndim == 2:
|
106 |
+
img = img[:, :, None]
|
107 |
+
squeeze_flag = True
|
108 |
+
img = torch.from_numpy(img.transpose(2, 0, 1)).float()
|
109 |
+
else:
|
110 |
+
numpy_type = False
|
111 |
+
if img.ndim == 2:
|
112 |
+
img = img.unsqueeze(0)
|
113 |
+
squeeze_flag = True
|
114 |
+
|
115 |
+
in_c, in_h, in_w = img.size()
|
116 |
+
out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
|
117 |
+
kernel_width = 4
|
118 |
+
kernel = 'cubic'
|
119 |
+
|
120 |
+
# get weights and indices
|
121 |
+
weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
|
122 |
+
antialiasing)
|
123 |
+
weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
|
124 |
+
antialiasing)
|
125 |
+
# process H dimension
|
126 |
+
# symmetric copying
|
127 |
+
img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
|
128 |
+
img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
|
129 |
+
|
130 |
+
sym_patch = img[:, :sym_len_hs, :]
|
131 |
+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
132 |
+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
133 |
+
img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
|
134 |
+
|
135 |
+
sym_patch = img[:, -sym_len_he:, :]
|
136 |
+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
137 |
+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
138 |
+
img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
|
139 |
+
|
140 |
+
out_1 = torch.FloatTensor(in_c, out_h, in_w)
|
141 |
+
kernel_width = weights_h.size(1)
|
142 |
+
for i in range(out_h):
|
143 |
+
idx = int(indices_h[i][0])
|
144 |
+
for j in range(in_c):
|
145 |
+
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
|
146 |
+
|
147 |
+
# process W dimension
|
148 |
+
# symmetric copying
|
149 |
+
out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
|
150 |
+
out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
|
151 |
+
|
152 |
+
sym_patch = out_1[:, :, :sym_len_ws]
|
153 |
+
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
154 |
+
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
155 |
+
out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
|
156 |
+
|
157 |
+
sym_patch = out_1[:, :, -sym_len_we:]
|
158 |
+
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
159 |
+
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
160 |
+
out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
|
161 |
+
|
162 |
+
out_2 = torch.FloatTensor(in_c, out_h, out_w)
|
163 |
+
kernel_width = weights_w.size(1)
|
164 |
+
for i in range(out_w):
|
165 |
+
idx = int(indices_w[i][0])
|
166 |
+
for j in range(in_c):
|
167 |
+
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
|
168 |
+
|
169 |
+
if squeeze_flag:
|
170 |
+
out_2 = out_2.squeeze(0)
|
171 |
+
if numpy_type:
|
172 |
+
out_2 = out_2.numpy()
|
173 |
+
if not squeeze_flag:
|
174 |
+
out_2 = out_2.transpose(1, 2, 0)
|
175 |
+
|
176 |
+
return out_2
|
177 |
+
|
178 |
+
|
179 |
+
def _convert_input_type_range(img):
|
180 |
+
"""Convert the type and range of the input image.
|
181 |
+
It converts the input image to np.float32 type and range of [0, 1].
|
182 |
+
It is mainly used for pre-processing the input image in colorspace
|
183 |
+
conversion functions such as rgb2ycbcr and ycbcr2rgb.
|
184 |
+
Args:
|
185 |
+
img (ndarray): The input image. It accepts:
|
186 |
+
1. np.uint8 type with range [0, 255];
|
187 |
+
2. np.float32 type with range [0, 1].
|
188 |
+
Returns:
|
189 |
+
(ndarray): The converted image with type of np.float32 and range of
|
190 |
+
[0, 1].
|
191 |
+
"""
|
192 |
+
img_type = img.dtype
|
193 |
+
img = img.astype(np.float32)
|
194 |
+
if img_type == np.float32:
|
195 |
+
pass
|
196 |
+
elif img_type == np.uint8:
|
197 |
+
img /= 255.
|
198 |
+
else:
|
199 |
+
raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
|
200 |
+
return img
|
201 |
+
|
202 |
+
|
203 |
+
def _convert_output_type_range(img, dst_type):
|
204 |
+
"""Convert the type and range of the image according to dst_type.
|
205 |
+
It converts the image to desired type and range. If `dst_type` is np.uint8,
|
206 |
+
images will be converted to np.uint8 type with range [0, 255]. If
|
207 |
+
`dst_type` is np.float32, it converts the image to np.float32 type with
|
208 |
+
range [0, 1].
|
209 |
+
It is mainly used for post-processing images in colorspace conversion
|
210 |
+
functions such as rgb2ycbcr and ycbcr2rgb.
|
211 |
+
Args:
|
212 |
+
img (ndarray): The image to be converted with np.float32 type and
|
213 |
+
range [0, 255].
|
214 |
+
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
|
215 |
+
converts the image to np.uint8 type with range [0, 255]. If
|
216 |
+
dst_type is np.float32, it converts the image to np.float32 type
|
217 |
+
with range [0, 1].
|
218 |
+
Returns:
|
219 |
+
(ndarray): The converted image with desired type and range.
|
220 |
+
"""
|
221 |
+
if dst_type not in (np.uint8, np.float32):
|
222 |
+
raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
|
223 |
+
if dst_type == np.uint8:
|
224 |
+
img = img.round()
|
225 |
+
else:
|
226 |
+
img /= 255.
|
227 |
+
return img.astype(dst_type)
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
def rgb2ycbcr(img, y_only=False):
|
232 |
+
"""Convert a RGB image to YCbCr image.
|
233 |
+
This function produces the same results as Matlab's `rgb2ycbcr` function.
|
234 |
+
It implements the ITU-R BT.601 conversion for standard-definition
|
235 |
+
television. See more details in
|
236 |
+
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
237 |
+
It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
|
238 |
+
In OpenCV, it implements a JPEG conversion. See more details in
|
239 |
+
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
|
240 |
+
Args:
|
241 |
+
img (ndarray): The input image. It accepts:
|
242 |
+
1. np.uint8 type with range [0, 255];
|
243 |
+
2. np.float32 type with range [0, 1].
|
244 |
+
y_only (bool): Whether to only return Y channel. Default: False.
|
245 |
+
Returns:
|
246 |
+
ndarray: The converted YCbCr image. The output image has the same type
|
247 |
+
and range as input image.
|
248 |
+
"""
|
249 |
+
img_type = img.dtype
|
250 |
+
img = _convert_input_type_range(img)
|
251 |
+
if y_only:
|
252 |
+
out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
|
253 |
+
else:
|
254 |
+
out_img = np.matmul(
|
255 |
+
img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
|
256 |
+
out_img = _convert_output_type_range(out_img, img_type)
|
257 |
+
return out_img
|
258 |
+
|
259 |
+
|
260 |
+
def bgr2ycbcr(img, y_only=False):
|
261 |
+
"""Convert a BGR image to YCbCr image.
|
262 |
+
The bgr version of rgb2ycbcr.
|
263 |
+
It implements the ITU-R BT.601 conversion for standard-definition
|
264 |
+
television. See more details in
|
265 |
+
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
266 |
+
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
|
267 |
+
In OpenCV, it implements a JPEG conversion. See more details in
|
268 |
+
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
|
269 |
+
Args:
|
270 |
+
img (ndarray): The input image. It accepts:
|
271 |
+
1. np.uint8 type with range [0, 255];
|
272 |
+
2. np.float32 type with range [0, 1].
|
273 |
+
y_only (bool): Whether to only return Y channel. Default: False.
|
274 |
+
Returns:
|
275 |
+
ndarray: The converted YCbCr image. The output image has the same type
|
276 |
+
and range as input image.
|
277 |
+
"""
|
278 |
+
img_type = img.dtype
|
279 |
+
img = _convert_input_type_range(img)
|
280 |
+
if y_only:
|
281 |
+
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
|
282 |
+
else:
|
283 |
+
out_img = np.matmul(
|
284 |
+
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
|
285 |
+
out_img = _convert_output_type_range(out_img, img_type)
|
286 |
+
return out_img
|
287 |
+
|
288 |
+
def ycbcr2rgb(img):
|
289 |
+
"""Convert a YCbCr image to RGB image.
|
290 |
+
This function produces the same results as Matlab's ycbcr2rgb function.
|
291 |
+
It implements the ITU-R BT.601 conversion for standard-definition
|
292 |
+
television. See more details in
|
293 |
+
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
294 |
+
It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
|
295 |
+
In OpenCV, it implements a JPEG conversion. See more details in
|
296 |
+
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
|
297 |
+
Args:
|
298 |
+
img (ndarray): The input image. It accepts:
|
299 |
+
1. np.uint8 type with range [0, 255];
|
300 |
+
2. np.float32 type with range [0, 1].
|
301 |
+
Returns:
|
302 |
+
ndarray: The converted RGB image. The output image has the same type
|
303 |
+
and range as input image.
|
304 |
+
"""
|
305 |
+
img_type = img.dtype
|
306 |
+
img = _convert_input_type_range(img) * 255
|
307 |
+
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
|
308 |
+
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
|
309 |
+
out_img = _convert_output_type_range(out_img, img_type)
|
310 |
+
return out_img
|
311 |
+
|
312 |
+
|
313 |
+
def to_y_channel(img):
|
314 |
+
"""Change to Y channel of YCbCr.
|
315 |
+
Args:
|
316 |
+
img (ndarray): Images with range [0, 255].
|
317 |
+
Returns:
|
318 |
+
(ndarray): Images with range [0, 255] (float type) without round.
|
319 |
+
"""
|
320 |
+
img = img.astype(np.float32) / 255.
|
321 |
+
if img.ndim == 3 and img.shape[2] == 3:
|
322 |
+
img = bgr2ycbcr(img, y_only=True)
|
323 |
+
img = img[..., None]
|
324 |
+
return img * 255.
|
325 |
+
|
326 |
+
|
327 |
+
def reorder_image(img, input_order='HWC'):
|
328 |
+
"""Reorder images to 'HWC' order.
|
329 |
+
If the input_order is (h, w), return (h, w, 1);
|
330 |
+
If the input_order is (c, h, w), return (h, w, c);
|
331 |
+
If the input_order is (h, w, c), return as it is.
|
332 |
+
Args:
|
333 |
+
img (ndarray): Input image.
|
334 |
+
input_order (str): Whether the input order is 'HWC' or 'CHW'.
|
335 |
+
If the input image shape is (h, w), input_order will not have
|
336 |
+
effects. Default: 'HWC'.
|
337 |
+
Returns:
|
338 |
+
ndarray: reordered image.
|
339 |
+
"""
|
340 |
+
|
341 |
+
if input_order not in ['HWC', 'CHW']:
|
342 |
+
raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
|
343 |
+
if len(img.shape) == 2:
|
344 |
+
img = img[..., None]
|
345 |
+
if input_order == 'CHW':
|
346 |
+
img = img.transpose(1, 2, 0)
|
347 |
+
return img
|
348 |
+
|
349 |
+
def rgb2ycbcr_pt(img, y_only=False):
|
350 |
+
"""Convert RGB images to YCbCr images (PyTorch version).
|
351 |
+
It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
|
352 |
+
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
353 |
+
Args:
|
354 |
+
img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
|
355 |
+
y_only (bool): Whether to only return Y channel. Default: False.
|
356 |
+
Returns:
|
357 |
+
(Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float.
|
358 |
+
"""
|
359 |
+
if y_only:
|
360 |
+
weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
|
361 |
+
out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
|
362 |
+
else:
|
363 |
+
weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
|
364 |
+
bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
|
365 |
+
out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
|
366 |
+
|
367 |
+
out_img = out_img / 255.
|
368 |
+
return
|
369 |
+
|
370 |
+
def tensor2img(tensor):
|
371 |
+
im = (255. * tensor).data.cpu().numpy()
|
372 |
+
# clamp
|
373 |
+
im[im > 255] = 255
|
374 |
+
im[im < 0] = 0
|
375 |
+
im = im.astype(np.uint8)
|
376 |
+
return im
|
377 |
+
|
378 |
+
def img2tensor(img):
|
379 |
+
img = (img / 255.).astype('float32')
|
380 |
+
if img.ndim ==2:
|
381 |
+
img = np.expand_dims(np.expand_dims(img, axis = 0),axis=0)
|
382 |
+
else:
|
383 |
+
img = np.transpose(img, (2, 0, 1)) # C, H, W
|
384 |
+
img = np.expand_dims(img, axis=0)
|
385 |
+
img = np.ascontiguousarray(img, dtype=np.float32)
|
386 |
+
tensor = torch.from_numpy(img)
|
387 |
+
return tensor
|
388 |
+
|
389 |
+
def estimate_aggd_param(block):
|
390 |
+
"""Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.
|
391 |
+
Args:
|
392 |
+
block (ndarray): 2D Image block.
|
393 |
+
Returns:
|
394 |
+
tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
|
395 |
+
distribution (Estimating the parames in Equation 7 in the paper).
|
396 |
+
"""
|
397 |
+
block = block.flatten()
|
398 |
+
gam = np.arange(0.2, 10.001, 0.001) # len = 9801
|
399 |
+
gam_reciprocal = np.reciprocal(gam)
|
400 |
+
r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
|
401 |
+
|
402 |
+
left_std = np.sqrt(np.mean(block[block < 0]**2))
|
403 |
+
right_std = np.sqrt(np.mean(block[block > 0]**2))
|
404 |
+
gammahat = left_std / right_std
|
405 |
+
rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
|
406 |
+
rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2)
|
407 |
+
array_position = np.argmin((r_gam - rhatnorm)**2)
|
408 |
+
|
409 |
+
alpha = gam[array_position]
|
410 |
+
beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
|
411 |
+
beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
|
412 |
+
return (alpha, beta_l, beta_r)
|
413 |
+
|
414 |
+
|
415 |
+
def compute_feature(block):
|
416 |
+
"""Compute features.
|
417 |
+
Args:
|
418 |
+
block (ndarray): 2D Image block.
|
419 |
+
Returns:
|
420 |
+
list: Features with length of 18.
|
421 |
+
"""
|
422 |
+
feat = []
|
423 |
+
alpha, beta_l, beta_r = estimate_aggd_param(block)
|
424 |
+
feat.extend([alpha, (beta_l + beta_r) / 2])
|
425 |
+
|
426 |
+
# distortions disturb the fairly regular structure of natural images.
|
427 |
+
# This deviation can be captured by analyzing the sample distribution of
|
428 |
+
# the products of pairs of adjacent coefficients computed along
|
429 |
+
# horizontal, vertical and diagonal orientations.
|
430 |
+
shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
|
431 |
+
for i in range(len(shifts)):
|
432 |
+
shifted_block = np.roll(block, shifts[i], axis=(0, 1))
|
433 |
+
alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
|
434 |
+
# Eq. 8
|
435 |
+
mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
|
436 |
+
feat.extend([alpha, mean, beta_l, beta_r])
|
437 |
+
return feat
|
438 |
+
|
439 |
+
|
440 |
+
def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96):
|
441 |
+
"""Calculate NIQE (Natural Image Quality Evaluator) metric.
|
442 |
+
``Paper: Making a "Completely Blind" Image Quality Analyzer``
|
443 |
+
This implementation could produce almost the same results as the official
|
444 |
+
MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
|
445 |
+
Note that we do not include block overlap height and width, since they are
|
446 |
+
always 0 in the official implementation.
|
447 |
+
For good performance, it is advisable by the official implementation to
|
448 |
+
divide the distorted image in to the same size patched as used for the
|
449 |
+
construction of multivariate Gaussian model.
|
450 |
+
Args:
|
451 |
+
img (ndarray): Input image whose quality needs to be computed. The
|
452 |
+
image must be a gray or Y (of YCbCr) image with shape (h, w).
|
453 |
+
Range [0, 255] with float type.
|
454 |
+
mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
|
455 |
+
model calculated on the pristine dataset.
|
456 |
+
cov_pris_param (ndarray): Covariance of a pre-defined multivariate
|
457 |
+
Gaussian model calculated on the pristine dataset.
|
458 |
+
gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
|
459 |
+
image.
|
460 |
+
block_size_h (int): Height of the blocks in to which image is divided.
|
461 |
+
Default: 96 (the official recommended value).
|
462 |
+
block_size_w (int): Width of the blocks in to which image is divided.
|
463 |
+
Default: 96 (the official recommended value).
|
464 |
+
"""
|
465 |
+
assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
|
466 |
+
# crop image
|
467 |
+
h, w = img.shape
|
468 |
+
num_block_h = math.floor(h / block_size_h)
|
469 |
+
num_block_w = math.floor(w / block_size_w)
|
470 |
+
img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
|
471 |
+
|
472 |
+
distparam = [] # dist param is actually the multiscale features
|
473 |
+
for scale in (1, 2): # perform on two scales (1, 2)
|
474 |
+
mu = convolve(img, gaussian_window, mode='nearest')
|
475 |
+
sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu)))
|
476 |
+
# normalize, as in Eq. 1 in the paper
|
477 |
+
img_nomalized = (img - mu) / (sigma + 1)
|
478 |
+
|
479 |
+
feat = []
|
480 |
+
for idx_w in range(num_block_w):
|
481 |
+
for idx_h in range(num_block_h):
|
482 |
+
# process ecah block
|
483 |
+
block = img_nomalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
|
484 |
+
idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale]
|
485 |
+
feat.append(compute_feature(block))
|
486 |
+
|
487 |
+
distparam.append(np.array(feat))
|
488 |
+
|
489 |
+
if scale == 1:
|
490 |
+
img = imresize(img / 255., scale=0.5, antialiasing=True)
|
491 |
+
img = img * 255.
|
492 |
+
|
493 |
+
distparam = np.concatenate(distparam, axis=1)
|
494 |
+
|
495 |
+
# fit a MVG (multivariate Gaussian) model to distorted patch features
|
496 |
+
mu_distparam = np.nanmean(distparam, axis=0)
|
497 |
+
# use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
|
498 |
+
distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
|
499 |
+
cov_distparam = np.cov(distparam_no_nan, rowvar=False)
|
500 |
+
|
501 |
+
# compute niqe quality, Eq. 10 in the paper
|
502 |
+
invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
|
503 |
+
quality = np.matmul(
|
504 |
+
np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam)))
|
505 |
+
|
506 |
+
quality = np.sqrt(quality)
|
507 |
+
quality = float(np.squeeze(quality))
|
508 |
+
return quality
|
509 |
+
|
510 |
+
|
511 |
+
def calculate_niqe(img, crop_border=0,input_order='HWC', convert_to='y', **kwargs):
|
512 |
+
"""Calculate NIQE (Natural Image Quality Evaluator) metric.
|
513 |
+
``Paper: Making a "Completely Blind" Image Quality Analyzer``
|
514 |
+
This implementation could produce almost the same results as the official
|
515 |
+
MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
|
516 |
+
> MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296)
|
517 |
+
> Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296)
|
518 |
+
We use the official params estimated from the pristine dataset.
|
519 |
+
We use the recommended block size (96, 96) without overlaps.
|
520 |
+
Args:
|
521 |
+
img (ndarray): Input image whose quality needs to be computed.
|
522 |
+
The input image must be in range [0, 255] with float/int type.
|
523 |
+
The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
|
524 |
+
If the input order is 'HWC' or 'CHW', it will be converted to gray
|
525 |
+
or Y (of YCbCr) image according to the ``convert_to`` argument.
|
526 |
+
crop_border (int): Cropped pixels in each edge of an image. These
|
527 |
+
pixels are not involved in the metric calculation.
|
528 |
+
input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
|
529 |
+
Default: 'HWC'.
|
530 |
+
convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
|
531 |
+
Default: 'y'.
|
532 |
+
Returns:
|
533 |
+
float: NIQE result.
|
534 |
+
"""
|
535 |
+
# ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
536 |
+
# we use the official params estimated from the pristine dataset.
|
537 |
+
niqe_pris_params = np.load('./loss/niqe_pris_params.npz')
|
538 |
+
mu_pris_param = niqe_pris_params['mu_pris_param']
|
539 |
+
cov_pris_param = niqe_pris_params['cov_pris_param']
|
540 |
+
gaussian_window = niqe_pris_params['gaussian_window']
|
541 |
+
|
542 |
+
img = img.astype(np.float32)
|
543 |
+
if input_order != 'HW':
|
544 |
+
img = reorder_image(img, input_order=input_order)
|
545 |
+
if convert_to == 'y':
|
546 |
+
img = to_y_channel(img)
|
547 |
+
elif convert_to == 'gray':
|
548 |
+
img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
|
549 |
+
img = np.squeeze(img)
|
550 |
+
|
551 |
+
if crop_border != 0:
|
552 |
+
img = img[crop_border:-crop_border, crop_border:-crop_border]
|
553 |
+
|
554 |
+
# round is necessary for being consistent with MATLAB's result
|
555 |
+
img = img.round()
|
556 |
+
|
557 |
+
niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
|
558 |
+
|
559 |
+
return niqe_result
|
net/CIDNet.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
from net.HVI_transform import RGB_HVI
|
6 |
+
from net.transformer_utils import *
|
7 |
+
from net.LCA import *
|
8 |
+
|
9 |
+
class CIDNet(nn.Module):
|
10 |
+
def __init__(self,
|
11 |
+
channels=[36, 36, 72, 144],
|
12 |
+
heads=[1, 2, 4, 8],
|
13 |
+
norm=False
|
14 |
+
):
|
15 |
+
super(CIDNet, self).__init__()
|
16 |
+
|
17 |
+
|
18 |
+
[ch1, ch2, ch3, ch4] = channels
|
19 |
+
[head1, head2, head3, head4] = heads
|
20 |
+
|
21 |
+
# HV_ways
|
22 |
+
self.HVE_block0 = nn.Sequential(
|
23 |
+
nn.ReplicationPad2d(1),
|
24 |
+
nn.Conv2d(3, ch1, 3, stride=1, padding=0,bias=False)
|
25 |
+
)
|
26 |
+
self.HVE_block1 = NormDownsample(ch1, ch2, use_norm = norm)
|
27 |
+
self.HVE_block2 = NormDownsample(ch2, ch3, use_norm = norm)
|
28 |
+
self.HVE_block3 = NormDownsample(ch3, ch4, use_norm = norm)
|
29 |
+
|
30 |
+
self.HVD_block3 = NormUpsample(ch4, ch3, use_norm = norm)
|
31 |
+
self.HVD_block2 = NormUpsample(ch3, ch2, use_norm = norm)
|
32 |
+
self.HVD_block1 = NormUpsample(ch2, ch1, use_norm = norm)
|
33 |
+
self.HVD_block0 = nn.Sequential(
|
34 |
+
nn.ReplicationPad2d(1),
|
35 |
+
nn.Conv2d(ch1, 2, 3, stride=1, padding=0,bias=False)
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
# I_ways
|
40 |
+
self.IE_block0 = nn.Sequential(
|
41 |
+
nn.ReplicationPad2d(1),
|
42 |
+
nn.Conv2d(1, ch1, 3, stride=1, padding=0,bias=False),
|
43 |
+
)
|
44 |
+
self.IE_block1 = NormDownsample(ch1, ch2, use_norm = norm)
|
45 |
+
self.IE_block2 = NormDownsample(ch2, ch3, use_norm = norm)
|
46 |
+
self.IE_block3 = NormDownsample(ch3, ch4, use_norm = norm)
|
47 |
+
|
48 |
+
self.ID_block3 = NormUpsample(ch4, ch3, use_norm=norm)
|
49 |
+
self.ID_block2 = NormUpsample(ch3, ch2, use_norm=norm)
|
50 |
+
self.ID_block1 = NormUpsample(ch2, ch1, use_norm=norm)
|
51 |
+
self.ID_block0 = nn.Sequential(
|
52 |
+
nn.ReplicationPad2d(1),
|
53 |
+
nn.Conv2d(ch1, 1, 3, stride=1, padding=0,bias=False),
|
54 |
+
)
|
55 |
+
|
56 |
+
self.HV_LCA1 = HV_LCA(ch2, head2)
|
57 |
+
self.HV_LCA2 = HV_LCA(ch3, head3)
|
58 |
+
self.HV_LCA3 = HV_LCA(ch4, head4)
|
59 |
+
self.HV_LCA4 = HV_LCA(ch4, head4)
|
60 |
+
self.HV_LCA5 = HV_LCA(ch3, head3)
|
61 |
+
self.HV_LCA6 = HV_LCA(ch2, head2)
|
62 |
+
|
63 |
+
self.I_LCA1 = I_LCA(ch2, head2)
|
64 |
+
self.I_LCA2 = I_LCA(ch3, head3)
|
65 |
+
self.I_LCA3 = I_LCA(ch4, head4)
|
66 |
+
self.I_LCA4 = I_LCA(ch4, head4)
|
67 |
+
self.I_LCA5 = I_LCA(ch3, head3)
|
68 |
+
self.I_LCA6 = I_LCA(ch2, head2)
|
69 |
+
|
70 |
+
self.trans = RGB_HVI().cuda()
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
dtypes = x.dtype
|
74 |
+
hvi = self.trans.HVIT(x)
|
75 |
+
i = hvi[:,2,:,:].unsqueeze(1).to(dtypes)
|
76 |
+
# low
|
77 |
+
i_enc0 = self.IE_block0(i)
|
78 |
+
i_enc1 = self.IE_block1(i_enc0)
|
79 |
+
hv_0 = self.HVE_block0(hvi)
|
80 |
+
hv_1 = self.HVE_block1(hv_0)
|
81 |
+
i_jump0 = i_enc0
|
82 |
+
hv_jump0 = hv_0
|
83 |
+
|
84 |
+
i_enc2 = self.I_LCA1(i_enc1, hv_1)
|
85 |
+
hv_2 = self.HV_LCA1(hv_1, i_enc1)
|
86 |
+
v_jump1 = i_enc2
|
87 |
+
hv_jump1 = hv_2
|
88 |
+
i_enc2 = self.IE_block2(i_enc2)
|
89 |
+
hv_2 = self.HVE_block2(hv_2)
|
90 |
+
|
91 |
+
i_enc3 = self.I_LCA2(i_enc2, hv_2)
|
92 |
+
hv_3 = self.HV_LCA2(hv_2, i_enc2)
|
93 |
+
v_jump2 = i_enc3
|
94 |
+
hv_jump2 = hv_3
|
95 |
+
i_enc3 = self.IE_block3(i_enc2)
|
96 |
+
hv_3 = self.HVE_block3(hv_2)
|
97 |
+
|
98 |
+
i_enc4 = self.I_LCA3(i_enc3, hv_3)
|
99 |
+
hv_4 = self.HV_LCA3(hv_3, i_enc3)
|
100 |
+
|
101 |
+
i_dec4 = self.I_LCA4(i_enc4,hv_4)
|
102 |
+
hv_4 = self.HV_LCA4(hv_4, i_enc4)
|
103 |
+
|
104 |
+
hv_3 = self.HVD_block3(hv_4, hv_jump2)
|
105 |
+
i_dec3 = self.ID_block3(i_dec4, v_jump2)
|
106 |
+
i_dec2 = self.I_LCA5(i_dec3, hv_3)
|
107 |
+
hv_2 = self.HV_LCA5(hv_3, i_dec3)
|
108 |
+
|
109 |
+
hv_2 = self.HVD_block2(hv_2, hv_jump1)
|
110 |
+
i_dec2 = self.ID_block2(i_dec3, v_jump1)
|
111 |
+
|
112 |
+
i_dec1 = self.I_LCA6(i_dec2, hv_2)
|
113 |
+
hv_1 = self.HV_LCA6(hv_2, i_dec2)
|
114 |
+
|
115 |
+
i_dec1 = self.ID_block1(i_dec1, i_jump0)
|
116 |
+
i_dec0 = self.ID_block0(i_dec1)
|
117 |
+
hv_1 = self.HVD_block1(hv_1, hv_jump0)
|
118 |
+
hv_0 = self.HVD_block0(hv_1)
|
119 |
+
|
120 |
+
output_hvi = torch.cat([hv_0, i_dec0], dim=1) + hvi
|
121 |
+
output_rgb = self.trans.PHVIT(output_hvi)
|
122 |
+
|
123 |
+
return output_rgb
|
124 |
+
|
125 |
+
def HVIT(self,x):
|
126 |
+
hvi = self.trans.HVIT(x)
|
127 |
+
return hvi
|
128 |
+
|
129 |
+
|
130 |
+
|
net/HVI_transform.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
pi = 3.141592653589793
|
5 |
+
|
6 |
+
class RGB_HVI(nn.Module):
|
7 |
+
def __init__(self):
|
8 |
+
super(RGB_HVI, self).__init__()
|
9 |
+
self.density_k = torch.nn.Parameter(torch.full([1],0.2)) # k is reciprocal to the paper mentioned
|
10 |
+
self.gated = False
|
11 |
+
self.gated2= False
|
12 |
+
self.alpha = 1.0
|
13 |
+
self.alpha_s = 1.3
|
14 |
+
self.this_k = 0
|
15 |
+
|
16 |
+
def HVIT(self, img):
|
17 |
+
eps = 1e-8
|
18 |
+
device = img.device
|
19 |
+
dtypes = img.dtype
|
20 |
+
hue = torch.Tensor(img.shape[0], img.shape[2], img.shape[3]).to(device).to(dtypes)
|
21 |
+
value = img.max(1)[0].to(dtypes)
|
22 |
+
img_min = img.min(1)[0].to(dtypes)
|
23 |
+
hue[img[:,2]==value] = 4.0 + ( (img[:,0]-img[:,1]) / (value - img_min + eps)) [img[:,2]==value]
|
24 |
+
hue[img[:,1]==value] = 2.0 + ( (img[:,2]-img[:,0]) / (value - img_min + eps)) [img[:,1]==value]
|
25 |
+
hue[img[:,0]==value] = (0.0 + ((img[:,1]-img[:,2]) / (value - img_min + eps)) [img[:,0]==value]) % 6
|
26 |
+
|
27 |
+
hue[img.min(1)[0]==value] = 0.0
|
28 |
+
hue = hue/6.0
|
29 |
+
|
30 |
+
saturation = (value - img_min ) / (value + eps )
|
31 |
+
saturation[value==0] = 0
|
32 |
+
|
33 |
+
hue = hue.unsqueeze(1)
|
34 |
+
saturation = saturation.unsqueeze(1)
|
35 |
+
value = value.unsqueeze(1)
|
36 |
+
|
37 |
+
k = self.density_k
|
38 |
+
self.this_k = k.item()
|
39 |
+
|
40 |
+
color_sensitive = ((value * 0.5 * pi).sin() + eps).pow(k)
|
41 |
+
ch = (2.0 * pi * hue).cos()
|
42 |
+
cv = (2.0 * pi * hue).sin()
|
43 |
+
H = color_sensitive * saturation * ch
|
44 |
+
V = color_sensitive * saturation * cv
|
45 |
+
I = value
|
46 |
+
xyz = torch.cat([H, V, I],dim=1)
|
47 |
+
return xyz
|
48 |
+
|
49 |
+
def PHVIT(self, img):
|
50 |
+
eps = 1e-8
|
51 |
+
H,V,I = img[:,0,:,:],img[:,1,:,:],img[:,2,:,:]
|
52 |
+
|
53 |
+
# clip
|
54 |
+
H = torch.clamp(H,-1,1)
|
55 |
+
V = torch.clamp(V,-1,1)
|
56 |
+
I = torch.clamp(I,0,1)
|
57 |
+
|
58 |
+
v = I
|
59 |
+
k = self.this_k
|
60 |
+
color_sensitive = ((v * 0.5 * pi).sin() + eps).pow(k)
|
61 |
+
H = (H) / (color_sensitive + eps)
|
62 |
+
V = (V) / (color_sensitive + eps)
|
63 |
+
H = torch.clamp(H,-1,1)
|
64 |
+
V = torch.clamp(V,-1,1)
|
65 |
+
h = torch.atan2(V + eps,H + eps) / (2*pi)
|
66 |
+
h = h%1
|
67 |
+
s = torch.sqrt(H**2 + V**2 + eps)
|
68 |
+
|
69 |
+
if self.gated:
|
70 |
+
s = s * self.alpha_s
|
71 |
+
|
72 |
+
s = torch.clamp(s,0,1)
|
73 |
+
v = torch.clamp(v,0,1)
|
74 |
+
|
75 |
+
r = torch.zeros_like(h)
|
76 |
+
g = torch.zeros_like(h)
|
77 |
+
b = torch.zeros_like(h)
|
78 |
+
|
79 |
+
hi = torch.floor(h * 6.0)
|
80 |
+
f = h * 6.0 - hi
|
81 |
+
p = v * (1. - s)
|
82 |
+
q = v * (1. - (f * s))
|
83 |
+
t = v * (1. - ((1. - f) * s))
|
84 |
+
|
85 |
+
hi0 = hi==0
|
86 |
+
hi1 = hi==1
|
87 |
+
hi2 = hi==2
|
88 |
+
hi3 = hi==3
|
89 |
+
hi4 = hi==4
|
90 |
+
hi5 = hi==5
|
91 |
+
|
92 |
+
r[hi0] = v[hi0]
|
93 |
+
g[hi0] = t[hi0]
|
94 |
+
b[hi0] = p[hi0]
|
95 |
+
|
96 |
+
r[hi1] = q[hi1]
|
97 |
+
g[hi1] = v[hi1]
|
98 |
+
b[hi1] = p[hi1]
|
99 |
+
|
100 |
+
r[hi2] = p[hi2]
|
101 |
+
g[hi2] = v[hi2]
|
102 |
+
b[hi2] = t[hi2]
|
103 |
+
|
104 |
+
r[hi3] = p[hi3]
|
105 |
+
g[hi3] = q[hi3]
|
106 |
+
b[hi3] = v[hi3]
|
107 |
+
|
108 |
+
r[hi4] = t[hi4]
|
109 |
+
g[hi4] = p[hi4]
|
110 |
+
b[hi4] = v[hi4]
|
111 |
+
|
112 |
+
r[hi5] = v[hi5]
|
113 |
+
g[hi5] = p[hi5]
|
114 |
+
b[hi5] = q[hi5]
|
115 |
+
|
116 |
+
r = r.unsqueeze(1)
|
117 |
+
g = g.unsqueeze(1)
|
118 |
+
b = b.unsqueeze(1)
|
119 |
+
rgb = torch.cat([r, g, b], dim=1)
|
120 |
+
if self.gated2:
|
121 |
+
rgb = rgb * self.alpha
|
122 |
+
return rgb
|
net/LCA.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from einops import rearrange
|
4 |
+
from net.transformer_utils import *
|
5 |
+
|
6 |
+
# Cross Attention Block
|
7 |
+
class CAB(nn.Module):
|
8 |
+
def __init__(self, dim, num_heads, bias):
|
9 |
+
super(CAB, self).__init__()
|
10 |
+
self.num_heads = num_heads
|
11 |
+
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
|
12 |
+
|
13 |
+
self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
14 |
+
self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
|
15 |
+
self.kv = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
|
16 |
+
self.kv_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
|
17 |
+
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
18 |
+
|
19 |
+
def forward(self, x, y):
|
20 |
+
b, c, h, w = x.shape
|
21 |
+
|
22 |
+
q = self.q_dwconv(self.q(x))
|
23 |
+
kv = self.kv_dwconv(self.kv(y))
|
24 |
+
k, v = kv.chunk(2, dim=1)
|
25 |
+
|
26 |
+
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
27 |
+
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
28 |
+
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
29 |
+
|
30 |
+
q = torch.nn.functional.normalize(q, dim=-1)
|
31 |
+
k = torch.nn.functional.normalize(k, dim=-1)
|
32 |
+
|
33 |
+
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
34 |
+
attn = nn.functional.softmax(attn,dim=-1)
|
35 |
+
|
36 |
+
out = (attn @ v)
|
37 |
+
|
38 |
+
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
|
39 |
+
|
40 |
+
out = self.project_out(out)
|
41 |
+
return out
|
42 |
+
|
43 |
+
|
44 |
+
# Intensity Enhancement Layer
|
45 |
+
class IEL(nn.Module):
|
46 |
+
def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
|
47 |
+
super(IEL, self).__init__()
|
48 |
+
|
49 |
+
hidden_features = int(dim*ffn_expansion_factor)
|
50 |
+
|
51 |
+
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
|
52 |
+
|
53 |
+
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
|
54 |
+
self.dwconv1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias)
|
55 |
+
self.dwconv2 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias)
|
56 |
+
|
57 |
+
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
58 |
+
|
59 |
+
self.Tanh = nn.Tanh()
|
60 |
+
def forward(self, x):
|
61 |
+
x = self.project_in(x)
|
62 |
+
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
63 |
+
x1 = self.Tanh(self.dwconv1(x1)) + x1
|
64 |
+
x2 = self.Tanh(self.dwconv2(x2)) + x2
|
65 |
+
x = x1 * x2
|
66 |
+
x = self.project_out(x)
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
# Lightweight Cross Attention
|
71 |
+
class HV_LCA(nn.Module):
|
72 |
+
def __init__(self, dim,num_heads, bias=False):
|
73 |
+
super(HV_LCA, self).__init__()
|
74 |
+
self.gdfn = IEL(dim) # IEL and CDL have same structure
|
75 |
+
self.norm = LayerNorm(dim)
|
76 |
+
self.ffn = CAB(dim, num_heads, bias)
|
77 |
+
|
78 |
+
def forward(self, x, y):
|
79 |
+
x = x + self.ffn(self.norm(x),self.norm(y))
|
80 |
+
x = self.gdfn(self.norm(x))
|
81 |
+
return x
|
82 |
+
|
83 |
+
class I_LCA(nn.Module):
|
84 |
+
def __init__(self, dim,num_heads, bias=False):
|
85 |
+
super(I_LCA, self).__init__()
|
86 |
+
self.norm = LayerNorm(dim)
|
87 |
+
self.gdfn = IEL(dim)
|
88 |
+
self.ffn = CAB(dim, num_heads, bias=bias)
|
89 |
+
|
90 |
+
def forward(self, x, y):
|
91 |
+
x = x + self.ffn(self.norm(x),self.norm(y))
|
92 |
+
x = x + self.gdfn(self.norm(x))
|
93 |
+
return x
|
net/transformer_utils.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
class LayerNorm(nn.Module):
|
7 |
+
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
8 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
9 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
10 |
+
with shape (batch_size, channels, height, width).
|
11 |
+
"""
|
12 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
|
13 |
+
super().__init__()
|
14 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
15 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
16 |
+
self.eps = eps
|
17 |
+
self.data_format = data_format
|
18 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
19 |
+
raise NotImplementedError
|
20 |
+
self.normalized_shape = (normalized_shape, )
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
if self.data_format == "channels_last":
|
24 |
+
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
25 |
+
elif self.data_format == "channels_first":
|
26 |
+
u = x.mean(1, keepdim=True)
|
27 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
28 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
29 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
30 |
+
return x
|
31 |
+
|
32 |
+
class NormDownsample(nn.Module):
|
33 |
+
def __init__(self,in_ch,out_ch,scale=0.5,use_norm=False):
|
34 |
+
super(NormDownsample, self).__init__()
|
35 |
+
self.use_norm=use_norm
|
36 |
+
if self.use_norm:
|
37 |
+
self.norm=LayerNorm(out_ch)
|
38 |
+
self.prelu = nn.PReLU()
|
39 |
+
self.down = nn.Sequential(
|
40 |
+
nn.Conv2d(in_ch, out_ch,kernel_size=3,stride=1, padding=1, bias=False),
|
41 |
+
nn.UpsamplingBilinear2d(scale_factor=scale))
|
42 |
+
def forward(self, x):
|
43 |
+
x = self.down(x)
|
44 |
+
x = self.prelu(x)
|
45 |
+
if self.use_norm:
|
46 |
+
x = self.norm(x)
|
47 |
+
return x
|
48 |
+
else:
|
49 |
+
return x
|
50 |
+
|
51 |
+
class NormUpsample(nn.Module):
|
52 |
+
def __init__(self, in_ch,out_ch,scale=2,use_norm=False):
|
53 |
+
super(NormUpsample, self).__init__()
|
54 |
+
self.use_norm=use_norm
|
55 |
+
if self.use_norm:
|
56 |
+
self.norm=LayerNorm(out_ch)
|
57 |
+
self.prelu = nn.PReLU()
|
58 |
+
self.up_scale = nn.Sequential(
|
59 |
+
nn.Conv2d(in_ch,out_ch,kernel_size=3,stride=1, padding=1, bias=False),
|
60 |
+
nn.UpsamplingBilinear2d(scale_factor=scale))
|
61 |
+
self.up = nn.Conv2d(out_ch*2,out_ch,kernel_size=1,stride=1, padding=0, bias=False)
|
62 |
+
|
63 |
+
def forward(self, x,y):
|
64 |
+
x = self.up_scale(x)
|
65 |
+
x = torch.cat([x, y],dim=1)
|
66 |
+
x = self.up(x)
|
67 |
+
x = self.prelu(x)
|
68 |
+
if self.use_norm:
|
69 |
+
return self.norm(x)
|
70 |
+
else:
|
71 |
+
return x
|
72 |
+
|
weights/LOL-Blur.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eba53573097b4138be26ffc219826cf87eaeafbc0a6ae90cf4b35330070b5494
|
3 |
+
size 7971076
|
weights/LOLv1/BestLPIPS_0.0868.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:533bec252359efc8f05720ce5a73c3a8c891384e62ba62b47b457b4a4c87247e
|
3 |
+
size 7971706
|
weights/LOLv1/BestSSIM_0.8631.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:267ed28a7f65af87dd7bd6f486d88cfa25e654234badb974df7bc85dbf267ee3
|
3 |
+
size 7971706
|
weights/LOLv1/PSNR_24.74.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bb268bdf320d7236ba2e70c2b95e8c116d9766a3f3dbf4ccbbab1387228dfbf7
|
3 |
+
size 7971706
|
weights/LOLv1/w_perc.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9206385b514014b60bf9396151a011cabf68abb380c2ad87ffde3d53e8227926
|
3 |
+
size 7968002
|
weights/LOLv1/wo_perc.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2c5afb426fee910dd9806b4e6214f05914fb5988b7982a4624560815aefce2b5
|
3 |
+
size 7970627
|
weights/LOLv2_real/best_PSNR.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:44a7f92cf02a1da322c1a34b3aa940341dc93fd53840b7f19b49f461200b00a6
|
3 |
+
size 7971269
|
weights/LOLv2_real/best_SSIM.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3d7ad3e93d65effaeb937cf833b1f2b3fb78c7529a4db95c589ff0ca569347c3
|
3 |
+
size 7971269
|
weights/LOLv2_real/w_perc.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1e0affc8ce89cceeb283dfa3ce7108c767cefc6f59826da37ec4c39482839db8
|
3 |
+
size 7967682
|
weights/LOLv2_syn/w_perc.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:460a9956018935cab61ddef3cea14f1dbbc02a65e12ee4d5b5ab45d704d2575b
|
3 |
+
size 7967682
|
weights/LOLv2_syn/wo_perc.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:519cc5fc55491cf36b47b8295c0204dde29ba887feeed06a77e716072469e55e
|
3 |
+
size 7974353
|
weights/SICE.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7b0824026124427fe0419465fc06a6be0129356c949794d1d1f52c4173091607
|
3 |
+
size 7964800
|
weights/SID.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:49678f69495610a2691bb5b2665d2766b57a6a79201a236f7afa84f67c5fec5f
|
3 |
+
size 7964543
|
weights/generalization.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6395370762ec1fa0f6fe92a38c62382b1ebe21edc012f597023bfcec8b95cd27
|
3 |
+
size 7971462
|