Commit 3e648fb · 1 Parent(s): 96448a9
Fediory committed

Add application file
app.py ADDED
@@ -0,0 +1,82 @@
+ import numpy as np
+ import torch
+ import gradio as gr
+ from PIL import Image
+ from net.CIDNet import CIDNet
+ import torchvision.transforms as transforms
+ import torch.nn.functional as F
+ import os
+ import imquality.brisque as brisque
+ from loss.niqe_utils import *
+
+ eval_net = CIDNet()
+ eval_net.trans.gated = True
+ eval_net.trans.gated2 = True
+
+ def process_image(input_img, score, model_path, gamma, alpha_s=1.0, alpha_i=1.0):
+     torch.set_grad_enabled(False)
+     eval_net.load_state_dict(torch.load(os.path.join(directory, model_path), map_location=lambda storage, loc: storage))
+     eval_net.eval()
+
+     pil2tensor = transforms.Compose([transforms.ToTensor()])
+     input = pil2tensor(input_img)
+     # pad H and W up to a multiple of 8 so the three 2x down/upsampling stages align
+     factor = 8
+     h, w = input.shape[1], input.shape[2]
+     H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
+     padh = H - h if h % factor != 0 else 0
+     padw = W - w if w % factor != 0 else 0
+     input = F.pad(input.unsqueeze(0), (0, padw, 0, padh), 'reflect')
+     with torch.no_grad():
+         eval_net.trans.alpha_s = alpha_s
+         eval_net.trans.alpha = alpha_i
+         output = eval_net(input**gamma)
+     output = torch.clamp(output, 0, 1)
+     output = output[:, :, :h, :w]  # crop the padding back off
+     enhanced_img = transforms.ToPILImage()(output.squeeze(0))
+     if score == 'Yes':
+         im1 = enhanced_img.convert('RGB')
+         score_brisque = brisque.score(im1)
+         im1 = np.array(im1)
+         score_niqe = calculate_niqe(im1)
+         return enhanced_img, score_niqe, score_brisque
+     else:
+         return enhanced_img, 0, 0
+
+ def find_pth_files(directory):
+     pth_files = []
+     for root, dirs, files in os.walk(directory):
+         if 'train' in root.split(os.sep):
+             continue
+         for file in files:
+             if file.endswith('.pth'):
+                 pth_files.append(os.path.join(root, file))
+     return pth_files
+
+ def remove_weights_prefix(paths):
+     # show weight paths relative to the weights directory
+     # (os.path.relpath keeps this portable across OS path separators)
+     cleaned_paths = [os.path.relpath(path, directory) for path in paths]
+     return cleaned_paths
+
+ directory = "./weights"
+ pth_files = find_pth_files(directory)
+ pth_files2 = remove_weights_prefix(pth_files)
+
+ interface = gr.Interface(
+     fn=process_image,
+     inputs=[
+         gr.Image(label="Low-light Image", type="pil"),
+         gr.Radio(choices=['Yes', 'No'], label="Image Score"),
+         gr.Radio(choices=pth_files2, label="Model Path"),
+         gr.Slider(0.1, 10, label="gamma curve", step=0.01, value=1.0),
+         gr.Slider(0, 2, label="Alpha-s", step=0.01, value=1.0),
+         gr.Slider(0.1, 2, label="Alpha-i", step=0.01, value=1.0)
+     ],
+     outputs=[
+         gr.Image(label="Result", type="pil"),
+         gr.Textbox(label="NIQE"),
+         gr.Textbox(label="BRISQUE")
+     ],
+     title="HVI-CIDNet (Low-Light Image Enhancement)",
+     allow_flagging="never"
+ )
+
+ interface.launch(server_port=7862)
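
Usage note (not part of the commit): process_image can also be driven outside the Gradio UI. A minimal sketch, assuming it is run from the repo root on a CUDA device (CIDNet's constructor moves its RGB_HVI transform to the GPU), with a hypothetical input file and one of the relative weight names produced by remove_weights_prefix:

from PIL import Image

img = Image.open("low_light.png").convert("RGB")       # hypothetical input file
enhanced, niqe_score, brisque_score = process_image(
    img, score='Yes', model_path='LOLv1/w_perc.pth',   # any entry from pth_files2
    gamma=1.0, alpha_s=1.0, alpha_i=1.0)
enhanced.save("enhanced.png")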
loss/niqe_pris_params.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2a7c182a68c9e7f1b2e2e5ec723279d6f65d912b6fcaf37eb2bf03d7367c4296
+ size 11850
loss/niqe_utils.py ADDED
@@ -0,0 +1,559 @@
+
+ import cv2
+ import math
+ import numpy as np
+ from scipy.ndimage import convolve
+ from scipy.special import gamma
+ import torch
+
+ def cubic(x):
+     """Cubic function used for calculate_weights_indices."""
+     absx = torch.abs(x)
+     absx2 = absx**2
+     absx3 = absx**3
+     return (1.5 * absx3 - 2.5 * absx2 + 1) * (
+         (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
+                                                                                      (absx <= 2)).type_as(absx))
+
+
+ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+     """Calculate weights and indices, used for the imresize function.
+     Args:
+         in_length (int): Input length.
+         out_length (int): Output length.
+         scale (float): Scale factor.
+         kernel (str): Kernel type; only 'cubic' is used.
+         kernel_width (int): Kernel width.
+         antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+     """
+
+     if (scale < 1) and antialiasing:
+         # Use a modified kernel (larger kernel width) to simultaneously
+         # interpolate and antialias
+         kernel_width = kernel_width / scale
+
+     # Output-space coordinates
+     x = torch.linspace(1, out_length, out_length)
+
+     # Input-space coordinates. Calculate the inverse mapping such that 0.5
+     # in output space maps to 0.5 in input space, and 0.5 + scale in output
+     # space maps to 1.5 in input space.
+     u = x / scale + 0.5 * (1 - 1 / scale)
+
+     # What is the left-most pixel that can be involved in the computation?
+     left = torch.floor(u - kernel_width / 2)
+
+     # What is the maximum number of pixels that can be involved in the
+     # computation? Note: it's OK to use an extra pixel here; if the
+     # corresponding weights are all zero, it will be eliminated at the end
+     # of this function.
+     p = math.ceil(kernel_width) + 2
+
+     # The indices of the input pixels involved in computing the k-th output
+     # pixel are in row k of the indices matrix.
+     indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
+         out_length, p)
+
+     # The weights used to compute the k-th output pixel are in row k of the
+     # weights matrix.
+     distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
+
+     # apply cubic kernel
+     if (scale < 1) and antialiasing:
+         weights = scale * cubic(distance_to_center * scale)
+     else:
+         weights = cubic(distance_to_center)
+
+     # Normalize the weights matrix so that each row sums to 1.
+     weights_sum = torch.sum(weights, 1).view(out_length, 1)
+     weights = weights / weights_sum.expand(out_length, p)
+
+     # If a column in weights is all zero, get rid of it. Only consider the
+     # first and last column.
+     weights_zero_tmp = torch.sum((weights == 0), 0)
+     if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+         indices = indices.narrow(1, 1, p - 2)
+         weights = weights.narrow(1, 1, p - 2)
+     if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+         indices = indices.narrow(1, 0, p - 2)
+         weights = weights.narrow(1, 0, p - 2)
+     weights = weights.contiguous()
+     indices = indices.contiguous()
+     sym_len_s = -indices.min() + 1
+     sym_len_e = indices.max() - in_length
+     indices = indices + sym_len_s - 1
+     return weights, indices, int(sym_len_s), int(sym_len_e)
+
+ def imresize(img, scale, antialiasing=True):
+     """imresize function, same as MATLAB's.
+     It currently only supports bicubic.
+     The same scale applies for both height and width.
+     Args:
+         img (Tensor | Numpy array):
+             Tensor: Input image with shape (c, h, w), [0, 1] range.
+             Numpy: Input image with shape (h, w, c), [0, 1] range.
+         scale (float): Scale factor. The same scale applies for both height
+             and width.
+         antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+             Default: True.
+     Returns:
+         Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
+     """
+     squeeze_flag = False
+     if type(img).__module__ == np.__name__:  # numpy type
+         numpy_type = True
+         if img.ndim == 2:
+             img = img[:, :, None]
+             squeeze_flag = True
+         img = torch.from_numpy(img.transpose(2, 0, 1)).float()
+     else:
+         numpy_type = False
+         if img.ndim == 2:
+             img = img.unsqueeze(0)
+             squeeze_flag = True
+
+     in_c, in_h, in_w = img.size()
+     out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
+     kernel_width = 4
+     kernel = 'cubic'
+
+     # get weights and indices
+     weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
+                                                                              antialiasing)
+     weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
+                                                                              antialiasing)
+     # process H dimension
+     # symmetric copying
+     img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
+     img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
+
+     sym_patch = img[:, :sym_len_hs, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
+
+     sym_patch = img[:, -sym_len_he:, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
+
+     out_1 = torch.FloatTensor(in_c, out_h, in_w)
+     kernel_width = weights_h.size(1)
+     for i in range(out_h):
+         idx = int(indices_h[i][0])
+         for j in range(in_c):
+             out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
+
+     # process W dimension
+     # symmetric copying
+     out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
+     out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
+
+     sym_patch = out_1[:, :, :sym_len_ws]
+     inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(2, inv_idx)
+     out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
+
+     sym_patch = out_1[:, :, -sym_len_we:]
+     inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(2, inv_idx)
+     out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
+
+     out_2 = torch.FloatTensor(in_c, out_h, out_w)
+     kernel_width = weights_w.size(1)
+     for i in range(out_w):
+         idx = int(indices_w[i][0])
+         for j in range(in_c):
+             out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
+
+     if squeeze_flag:
+         out_2 = out_2.squeeze(0)
+     if numpy_type:
+         out_2 = out_2.numpy()
+         if not squeeze_flag:
+             out_2 = out_2.transpose(1, 2, 0)
+
+     return out_2
+
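
A quick sanity check for imresize (editorial sketch, not part of the commit): it follows MATLAB's bicubic convention, which differs from cv2.resize, and a (64, 48, 3) float image scaled by 0.5 should come back as (32, 24, 3):

small = imresize(np.random.rand(64, 48, 3).astype(np.float32), scale=0.5)
print(small.shape)  # (32, 24, 3); values stay roughly in [0, 1], antialiased by default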
+
+ def _convert_input_type_range(img):
+     """Convert the type and range of the input image.
+     It converts the input image to np.float32 type and range of [0, 1].
+     It is mainly used for pre-processing the input image in colorspace
+     conversion functions such as rgb2ycbcr and ycbcr2rgb.
+     Args:
+         img (ndarray): The input image. It accepts:
+             1. np.uint8 type with range [0, 255];
+             2. np.float32 type with range [0, 1].
+     Returns:
+         (ndarray): The converted image with type of np.float32 and range of
+             [0, 1].
+     """
+     img_type = img.dtype
+     img = img.astype(np.float32)
+     if img_type == np.float32:
+         pass
+     elif img_type == np.uint8:
+         img /= 255.
+     else:
+         raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
+     return img
+
+
+ def _convert_output_type_range(img, dst_type):
+     """Convert the type and range of the image according to dst_type.
+     It converts the image to the desired type and range. If `dst_type` is
+     np.uint8, images will be converted to np.uint8 type with range [0, 255].
+     If `dst_type` is np.float32, it converts the image to np.float32 type
+     with range [0, 1].
+     It is mainly used for post-processing images in colorspace conversion
+     functions such as rgb2ycbcr and ycbcr2rgb.
+     Args:
+         img (ndarray): The image to be converted with np.float32 type and
+             range [0, 255].
+         dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+             converts the image to np.uint8 type with range [0, 255]. If
+             dst_type is np.float32, it converts the image to np.float32 type
+             with range [0, 1].
+     Returns:
+         (ndarray): The converted image with desired type and range.
+     """
+     if dst_type not in (np.uint8, np.float32):
+         raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
+     if dst_type == np.uint8:
+         img = img.round()
+     else:
+         img /= 255.
+     return img.astype(dst_type)
+
+
+ def rgb2ycbcr(img, y_only=False):
+     """Convert an RGB image to a YCbCr image.
+     This function produces the same results as MATLAB's `rgb2ycbcr` function.
+     It implements the ITU-R BT.601 conversion for standard-definition
+     television. See more details in
+     https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+     It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+     In OpenCV, it implements a JPEG conversion. See more details in
+     https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+     Args:
+         img (ndarray): The input image. It accepts:
+             1. np.uint8 type with range [0, 255];
+             2. np.float32 type with range [0, 1].
+         y_only (bool): Whether to only return the Y channel. Default: False.
+     Returns:
+         ndarray: The converted YCbCr image. The output image has the same type
+             and range as the input image.
+     """
+     img_type = img.dtype
+     img = _convert_input_type_range(img)
+     if y_only:
+         out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+     else:
+         out_img = np.matmul(
+             img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
+     out_img = _convert_output_type_range(out_img, img_type)
+     return out_img
+
+
+ def bgr2ycbcr(img, y_only=False):
+     """Convert a BGR image to a YCbCr image.
+     The BGR version of rgb2ycbcr.
+     It implements the ITU-R BT.601 conversion for standard-definition
+     television. See more details in
+     https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+     It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+     In OpenCV, it implements a JPEG conversion. See more details in
+     https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+     Args:
+         img (ndarray): The input image. It accepts:
+             1. np.uint8 type with range [0, 255];
+             2. np.float32 type with range [0, 1].
+         y_only (bool): Whether to only return the Y channel. Default: False.
+     Returns:
+         ndarray: The converted YCbCr image. The output image has the same type
+             and range as the input image.
+     """
+     img_type = img.dtype
+     img = _convert_input_type_range(img)
+     if y_only:
+         out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+     else:
+         out_img = np.matmul(
+             img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
+     out_img = _convert_output_type_range(out_img, img_type)
+     return out_img
+
+ def ycbcr2rgb(img):
+     """Convert a YCbCr image to an RGB image.
+     This function produces the same results as MATLAB's ycbcr2rgb function.
+     It implements the ITU-R BT.601 conversion for standard-definition
+     television. See more details in
+     https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+     It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+     In OpenCV, it implements a JPEG conversion. See more details in
+     https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+     Args:
+         img (ndarray): The input image. It accepts:
+             1. np.uint8 type with range [0, 255];
+             2. np.float32 type with range [0, 1].
+     Returns:
+         ndarray: The converted RGB image. The output image has the same type
+             and range as the input image.
+     """
+     img_type = img.dtype
+     img = _convert_input_type_range(img) * 255
+     out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+                               [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]  # noqa: E126
+     out_img = _convert_output_type_range(out_img, img_type)
+     return out_img
+
+
+ def to_y_channel(img):
+     """Change to the Y channel of YCbCr.
+     Args:
+         img (ndarray): Images with range [0, 255].
+     Returns:
+         (ndarray): Images with range [0, 255] (float type) without round.
+     """
+     img = img.astype(np.float32) / 255.
+     if img.ndim == 3 and img.shape[2] == 3:
+         img = bgr2ycbcr(img, y_only=True)
+         img = img[..., None]
+     return img * 255.
+
+
+ def reorder_image(img, input_order='HWC'):
+     """Reorder images to 'HWC' order.
+     If the input_order is (h, w), return (h, w, 1);
+     If the input_order is (c, h, w), return (h, w, c);
+     If the input_order is (h, w, c), return as is.
+     Args:
+         img (ndarray): Input image.
+         input_order (str): Whether the input order is 'HWC' or 'CHW'.
+             If the input image shape is (h, w), input_order has no
+             effect. Default: 'HWC'.
+     Returns:
+         ndarray: Reordered image.
+     """
+
+     if input_order not in ['HWC', 'CHW']:
+         raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
+     if len(img.shape) == 2:
+         img = img[..., None]
+     if input_order == 'CHW':
+         img = img.transpose(1, 2, 0)
+     return img
+
+ def rgb2ycbcr_pt(img, y_only=False):
+     """Convert RGB images to YCbCr images (PyTorch version).
+     It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
+     https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+     Args:
+         img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
+         y_only (bool): Whether to only return the Y channel. Default: False.
+     Returns:
+         (Tensor): Converted images with the shape (n, 3/1, h, w), the range [0, 1], float.
+     """
+     if y_only:
+         weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
+         out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
+     else:
+         weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
+         bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
+         out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
+
+     out_img = out_img / 255.
+     return out_img
+
+ def tensor2img(tensor):
+     im = (255. * tensor).data.cpu().numpy()
+     # clamp
+     im[im > 255] = 255
+     im[im < 0] = 0
+     im = im.astype(np.uint8)
+     return im
+
+ def img2tensor(img):
+     img = (img / 255.).astype('float32')
+     if img.ndim == 2:
+         img = np.expand_dims(np.expand_dims(img, axis=0), axis=0)
+     else:
+         img = np.transpose(img, (2, 0, 1))  # C, H, W
+         img = np.expand_dims(img, axis=0)
+     img = np.ascontiguousarray(img, dtype=np.float32)
+     tensor = torch.from_numpy(img)
+     return tensor
+
+ def estimate_aggd_param(block):
+     """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.
+     Args:
+         block (ndarray): 2D image block.
+     Returns:
+         tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
+             distribution (estimating the params in Equation 7 of the paper).
+     """
+     block = block.flatten()
+     gam = np.arange(0.2, 10.001, 0.001)  # len = 9801
+     gam_reciprocal = np.reciprocal(gam)
+     r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
+
+     left_std = np.sqrt(np.mean(block[block < 0]**2))
+     right_std = np.sqrt(np.mean(block[block > 0]**2))
+     gammahat = left_std / right_std
+     rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
+     rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2)
+     array_position = np.argmin((r_gam - rhatnorm)**2)
+
+     alpha = gam[array_position]
+     beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
+     beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
+     return (alpha, beta_l, beta_r)
+
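
A plausibility check for the AGGD fit (editorial sketch, not part of the commit): a zero-mean Gaussian block is a generalized Gaussian with shape parameter 2, so the alpha estimate should land near 2:

rng = np.random.default_rng(0)
alpha, beta_l, beta_r = estimate_aggd_param(rng.standard_normal((96, 96)))
print(alpha)  # close to 2.0 for Gaussian input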
+
+ def compute_feature(block):
+     """Compute features.
+     Args:
+         block (ndarray): 2D image block.
+     Returns:
+         list: Features with length of 18.
+     """
+     feat = []
+     alpha, beta_l, beta_r = estimate_aggd_param(block)
+     feat.extend([alpha, (beta_l + beta_r) / 2])
+
+     # Distortions disturb the fairly regular structure of natural images.
+     # This deviation can be captured by analyzing the sample distribution of
+     # the products of pairs of adjacent coefficients computed along
+     # horizontal, vertical and diagonal orientations.
+     shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
+     for i in range(len(shifts)):
+         shifted_block = np.roll(block, shifts[i], axis=(0, 1))
+         alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
+         # Eq. 8
+         mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
+         feat.extend([alpha, mean, beta_l, beta_r])
+     return feat
+
+
+ def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96):
+     """Calculate the NIQE (Natural Image Quality Evaluator) metric.
+     ``Paper: Making a "Completely Blind" Image Quality Analyzer``
+     This implementation can produce almost the same results as the official
+     MATLAB code: http://live.ece.utexas.edu/research/quality/niqe_release.zip
+     Note that we do not include block overlap height and width, since they are
+     always 0 in the official implementation.
+     For good performance, the official implementation advises dividing the
+     distorted image into patches of the same size as those used for the
+     construction of the multivariate Gaussian model.
+     Args:
+         img (ndarray): Input image whose quality needs to be computed. The
+             image must be a gray or Y (of YCbCr) image with shape (h, w).
+             Range [0, 255] with float type.
+         mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
+             model calculated on the pristine dataset.
+         cov_pris_param (ndarray): Covariance of a pre-defined multivariate
+             Gaussian model calculated on the pristine dataset.
+         gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
+             image.
+         block_size_h (int): Height of the blocks into which the image is divided.
+             Default: 96 (the official recommended value).
+         block_size_w (int): Width of the blocks into which the image is divided.
+             Default: 96 (the official recommended value).
+     """
+     assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
+     # crop image
+     h, w = img.shape
+     num_block_h = math.floor(h / block_size_h)
+     num_block_w = math.floor(w / block_size_w)
+     img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
+
+     distparam = []  # dist param is actually the multiscale features
+     for scale in (1, 2):  # perform on two scales (1, 2)
+         mu = convolve(img, gaussian_window, mode='nearest')
+         sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu)))
+         # normalize, as in Eq. 1 in the paper
+         img_normalized = (img - mu) / (sigma + 1)
+
+         feat = []
+         for idx_w in range(num_block_w):
+             for idx_h in range(num_block_h):
+                 # process each block
+                 block = img_normalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
+                                        idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale]
+                 feat.append(compute_feature(block))
+
+         distparam.append(np.array(feat))
+
+         if scale == 1:
+             img = imresize(img / 255., scale=0.5, antialiasing=True)
+             img = img * 255.
+
+     distparam = np.concatenate(distparam, axis=1)
+
+     # fit an MVG (multivariate Gaussian) model to distorted patch features
+     mu_distparam = np.nanmean(distparam, axis=0)
+     # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
+     distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
+     cov_distparam = np.cov(distparam_no_nan, rowvar=False)
+
+     # compute NIQE quality, Eq. 10 in the paper
+     invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
+     quality = np.matmul(
+         np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam)))
+
+     quality = np.sqrt(quality)
+     quality = float(np.squeeze(quality))
+     return quality
+
+
+ def calculate_niqe(img, crop_border=0, input_order='HWC', convert_to='y', **kwargs):
+     """Calculate the NIQE (Natural Image Quality Evaluator) metric.
+     ``Paper: Making a "Completely Blind" Image Quality Analyzer``
+     This implementation can produce almost the same results as the official
+     MATLAB code: http://live.ece.utexas.edu/research/quality/niqe_release.zip
+     > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296)
+     > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296)
+     We use the official params estimated from the pristine dataset.
+     We use the recommended block size (96, 96) without overlaps.
+     Args:
+         img (ndarray): Input image whose quality needs to be computed.
+             The input image must be in range [0, 255] with float/int type.
+             The input_order of the image can be 'HW', 'HWC' or 'CHW' (BGR order).
+             If the input order is 'HWC' or 'CHW', it will be converted to gray
+             or Y (of YCbCr) according to the ``convert_to`` argument.
+         crop_border (int): Cropped pixels in each edge of an image. These
+             pixels are not involved in the metric calculation.
+         input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
+             Default: 'HWC'.
+         convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
+             Default: 'y'.
+     Returns:
+         float: NIQE result.
+     """
+     # ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+     # we use the official params estimated from the pristine dataset.
+     niqe_pris_params = np.load('./loss/niqe_pris_params.npz')
+     mu_pris_param = niqe_pris_params['mu_pris_param']
+     cov_pris_param = niqe_pris_params['cov_pris_param']
+     gaussian_window = niqe_pris_params['gaussian_window']
+
+     img = img.astype(np.float32)
+     if input_order != 'HW':
+         img = reorder_image(img, input_order=input_order)
+         if convert_to == 'y':
+             img = to_y_channel(img)
+         elif convert_to == 'gray':
+             img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
+         img = np.squeeze(img)
+
+     if crop_border != 0:
+         img = img[crop_border:-crop_border, crop_border:-crop_border]
+
+     # rounding is necessary for consistency with MATLAB's result
+     img = img.round()
+
+     niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
+
+     return niqe_result
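
For reference (editorial sketch, not part of the commit), scoring an image the way app.py does. This assumes the repo root as the working directory, since calculate_niqe loads ./loss/niqe_pris_params.npz relatively; note also that to_y_channel treats the array as BGR, while app.py passes an RGB array:

import numpy as np
from PIL import Image
im = np.array(Image.open("enhanced.png").convert('RGB'))  # hypothetical file
print(calculate_niqe(im))  # lower is better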
net/CIDNet.py ADDED
@@ -0,0 +1,130 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from net.HVI_transform import RGB_HVI
+ from net.transformer_utils import *
+ from net.LCA import *
+
+ class CIDNet(nn.Module):
+     def __init__(self,
+                  channels=[36, 36, 72, 144],
+                  heads=[1, 2, 4, 8],
+                  norm=False
+                  ):
+         super(CIDNet, self).__init__()
+
+         [ch1, ch2, ch3, ch4] = channels
+         [head1, head2, head3, head4] = heads
+
+         # HV_ways
+         self.HVE_block0 = nn.Sequential(
+             nn.ReplicationPad2d(1),
+             nn.Conv2d(3, ch1, 3, stride=1, padding=0, bias=False)
+         )
+         self.HVE_block1 = NormDownsample(ch1, ch2, use_norm=norm)
+         self.HVE_block2 = NormDownsample(ch2, ch3, use_norm=norm)
+         self.HVE_block3 = NormDownsample(ch3, ch4, use_norm=norm)
+
+         self.HVD_block3 = NormUpsample(ch4, ch3, use_norm=norm)
+         self.HVD_block2 = NormUpsample(ch3, ch2, use_norm=norm)
+         self.HVD_block1 = NormUpsample(ch2, ch1, use_norm=norm)
+         self.HVD_block0 = nn.Sequential(
+             nn.ReplicationPad2d(1),
+             nn.Conv2d(ch1, 2, 3, stride=1, padding=0, bias=False)
+         )
+
+         # I_ways
+         self.IE_block0 = nn.Sequential(
+             nn.ReplicationPad2d(1),
+             nn.Conv2d(1, ch1, 3, stride=1, padding=0, bias=False),
+         )
+         self.IE_block1 = NormDownsample(ch1, ch2, use_norm=norm)
+         self.IE_block2 = NormDownsample(ch2, ch3, use_norm=norm)
+         self.IE_block3 = NormDownsample(ch3, ch4, use_norm=norm)
+
+         self.ID_block3 = NormUpsample(ch4, ch3, use_norm=norm)
+         self.ID_block2 = NormUpsample(ch3, ch2, use_norm=norm)
+         self.ID_block1 = NormUpsample(ch2, ch1, use_norm=norm)
+         self.ID_block0 = nn.Sequential(
+             nn.ReplicationPad2d(1),
+             nn.Conv2d(ch1, 1, 3, stride=1, padding=0, bias=False),
+         )
+
+         self.HV_LCA1 = HV_LCA(ch2, head2)
+         self.HV_LCA2 = HV_LCA(ch3, head3)
+         self.HV_LCA3 = HV_LCA(ch4, head4)
+         self.HV_LCA4 = HV_LCA(ch4, head4)
+         self.HV_LCA5 = HV_LCA(ch3, head3)
+         self.HV_LCA6 = HV_LCA(ch2, head2)
+
+         self.I_LCA1 = I_LCA(ch2, head2)
+         self.I_LCA2 = I_LCA(ch3, head3)
+         self.I_LCA3 = I_LCA(ch4, head4)
+         self.I_LCA4 = I_LCA(ch4, head4)
+         self.I_LCA5 = I_LCA(ch3, head3)
+         self.I_LCA6 = I_LCA(ch2, head2)
+
+         self.trans = RGB_HVI().cuda()
+
+     def forward(self, x):
+         dtypes = x.dtype
+         hvi = self.trans.HVIT(x)
+         i = hvi[:, 2, :, :].unsqueeze(1).to(dtypes)
+         # low
+         i_enc0 = self.IE_block0(i)
+         i_enc1 = self.IE_block1(i_enc0)
+         hv_0 = self.HVE_block0(hvi)
+         hv_1 = self.HVE_block1(hv_0)
+         i_jump0 = i_enc0
+         hv_jump0 = hv_0
+
+         i_enc2 = self.I_LCA1(i_enc1, hv_1)
+         hv_2 = self.HV_LCA1(hv_1, i_enc1)
+         v_jump1 = i_enc2
+         hv_jump1 = hv_2
+         i_enc2 = self.IE_block2(i_enc2)
+         hv_2 = self.HVE_block2(hv_2)
+
+         i_enc3 = self.I_LCA2(i_enc2, hv_2)
+         hv_3 = self.HV_LCA2(hv_2, i_enc2)
+         v_jump2 = i_enc3
+         hv_jump2 = hv_3
+         i_enc3 = self.IE_block3(i_enc2)
+         hv_3 = self.HVE_block3(hv_2)
+
+         i_enc4 = self.I_LCA3(i_enc3, hv_3)
+         hv_4 = self.HV_LCA3(hv_3, i_enc3)
+
+         i_dec4 = self.I_LCA4(i_enc4, hv_4)
+         hv_4 = self.HV_LCA4(hv_4, i_enc4)
+
+         hv_3 = self.HVD_block3(hv_4, hv_jump2)
+         i_dec3 = self.ID_block3(i_dec4, v_jump2)
+         i_dec2 = self.I_LCA5(i_dec3, hv_3)
+         hv_2 = self.HV_LCA5(hv_3, i_dec3)
+
+         hv_2 = self.HVD_block2(hv_2, hv_jump1)
+         i_dec2 = self.ID_block2(i_dec3, v_jump1)
+
+         i_dec1 = self.I_LCA6(i_dec2, hv_2)
+         hv_1 = self.HV_LCA6(hv_2, i_dec2)
+
+         i_dec1 = self.ID_block1(i_dec1, i_jump0)
+         i_dec0 = self.ID_block0(i_dec1)
+         hv_1 = self.HVD_block1(hv_1, hv_jump0)
+         hv_0 = self.HVD_block0(hv_1)
+
+         output_hvi = torch.cat([hv_0, i_dec0], dim=1) + hvi
+         output_rgb = self.trans.PHVIT(output_hvi)
+
+         return output_rgb
+
+     def HVIT(self, x):
+         hvi = self.trans.HVIT(x)
+         return hvi
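
Shape-wise (editorial sketch, not part of the commit), the network maps an RGB tensor to an RGB tensor of the same size, provided H and W are multiples of 8 — which is why app.py reflect-pads its input. This assumes a CUDA device, since __init__ moves the RGB_HVI transform to the GPU:

import torch
net = CIDNet().cuda().eval()
x = torch.rand(1, 3, 64, 64).cuda()
with torch.no_grad():
    print(net(x).shape)  # torch.Size([1, 3, 64, 64])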
net/HVI_transform.py ADDED
@@ -0,0 +1,122 @@
+ import torch
+ import torch.nn as nn
+
+ pi = 3.141592653589793
+
+ class RGB_HVI(nn.Module):
+     def __init__(self):
+         super(RGB_HVI, self).__init__()
+         self.density_k = torch.nn.Parameter(torch.full([1], 0.2))  # k is the reciprocal of the k mentioned in the paper
+         self.gated = False
+         self.gated2 = False
+         self.alpha = 1.0
+         self.alpha_s = 1.3
+         self.this_k = 0
+
+     def HVIT(self, img):
+         eps = 1e-8
+         device = img.device
+         dtypes = img.dtype
+         hue = torch.Tensor(img.shape[0], img.shape[2], img.shape[3]).to(device).to(dtypes)
+         value = img.max(1)[0].to(dtypes)
+         img_min = img.min(1)[0].to(dtypes)
+         hue[img[:, 2] == value] = 4.0 + ((img[:, 0] - img[:, 1]) / (value - img_min + eps))[img[:, 2] == value]
+         hue[img[:, 1] == value] = 2.0 + ((img[:, 2] - img[:, 0]) / (value - img_min + eps))[img[:, 1] == value]
+         hue[img[:, 0] == value] = (0.0 + ((img[:, 1] - img[:, 2]) / (value - img_min + eps))[img[:, 0] == value]) % 6
+
+         hue[img.min(1)[0] == value] = 0.0
+         hue = hue / 6.0
+
+         saturation = (value - img_min) / (value + eps)
+         saturation[value == 0] = 0
+
+         hue = hue.unsqueeze(1)
+         saturation = saturation.unsqueeze(1)
+         value = value.unsqueeze(1)
+
+         k = self.density_k
+         self.this_k = k.item()
+
+         color_sensitive = ((value * 0.5 * pi).sin() + eps).pow(k)
+         ch = (2.0 * pi * hue).cos()
+         cv = (2.0 * pi * hue).sin()
+         H = color_sensitive * saturation * ch
+         V = color_sensitive * saturation * cv
+         I = value
+         xyz = torch.cat([H, V, I], dim=1)
+         return xyz
+
+     def PHVIT(self, img):
+         eps = 1e-8
+         H, V, I = img[:, 0, :, :], img[:, 1, :, :], img[:, 2, :, :]
+
+         # clip
+         H = torch.clamp(H, -1, 1)
+         V = torch.clamp(V, -1, 1)
+         I = torch.clamp(I, 0, 1)
+
+         v = I
+         k = self.this_k
+         color_sensitive = ((v * 0.5 * pi).sin() + eps).pow(k)
+         H = (H) / (color_sensitive + eps)
+         V = (V) / (color_sensitive + eps)
+         H = torch.clamp(H, -1, 1)
+         V = torch.clamp(V, -1, 1)
+         h = torch.atan2(V + eps, H + eps) / (2 * pi)
+         h = h % 1
+         s = torch.sqrt(H**2 + V**2 + eps)
+
+         if self.gated:
+             s = s * self.alpha_s
+
+         s = torch.clamp(s, 0, 1)
+         v = torch.clamp(v, 0, 1)
+
+         r = torch.zeros_like(h)
+         g = torch.zeros_like(h)
+         b = torch.zeros_like(h)
+
+         hi = torch.floor(h * 6.0)
+         f = h * 6.0 - hi
+         p = v * (1. - s)
+         q = v * (1. - (f * s))
+         t = v * (1. - ((1. - f) * s))
+
+         hi0 = hi == 0
+         hi1 = hi == 1
+         hi2 = hi == 2
+         hi3 = hi == 3
+         hi4 = hi == 4
+         hi5 = hi == 5
+
+         r[hi0] = v[hi0]
+         g[hi0] = t[hi0]
+         b[hi0] = p[hi0]
+
+         r[hi1] = q[hi1]
+         g[hi1] = v[hi1]
+         b[hi1] = p[hi1]
+
+         r[hi2] = p[hi2]
+         g[hi2] = v[hi2]
+         b[hi2] = t[hi2]
+
+         r[hi3] = p[hi3]
+         g[hi3] = q[hi3]
+         b[hi3] = v[hi3]
+
+         r[hi4] = t[hi4]
+         g[hi4] = p[hi4]
+         b[hi4] = v[hi4]
+
+         r[hi5] = v[hi5]
+         g[hi5] = p[hi5]
+         b[hi5] = q[hi5]
+
+         r = r.unsqueeze(1)
+         g = g.unsqueeze(1)
+         b = b.unsqueeze(1)
+         rgb = torch.cat([r, g, b], dim=1)
+         if self.gated2:
+             rgb = rgb * self.alpha
+         return rgb
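
With the gates disabled (the defaults here), PHVIT approximately inverts HVIT, with error at the level of the eps terms and clamping. A round-trip sketch (not part of the commit):

import torch
trans = RGB_HVI()
rgb = torch.rand(1, 3, 8, 8)
hvi = trans.HVIT(rgb)                        # also caches this_k for the inverse
print((trans.PHVIT(hvi) - rgb).abs().max())  # small reconstruction error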
net/LCA.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from net.transformer_utils import *
+
+ # Cross Attention Block
+ class CAB(nn.Module):
+     def __init__(self, dim, num_heads, bias):
+         super(CAB, self).__init__()
+         self.num_heads = num_heads
+         self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+
+         self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
+         self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
+         self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
+         self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias)
+         self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
+
+     def forward(self, x, y):
+         b, c, h, w = x.shape
+
+         q = self.q_dwconv(self.q(x))
+         kv = self.kv_dwconv(self.kv(y))
+         k, v = kv.chunk(2, dim=1)
+
+         q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+         k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+         v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+
+         q = torch.nn.functional.normalize(q, dim=-1)
+         k = torch.nn.functional.normalize(k, dim=-1)
+
+         attn = (q @ k.transpose(-2, -1)) * self.temperature
+         attn = nn.functional.softmax(attn, dim=-1)
+
+         out = (attn @ v)
+
+         out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
+
+         out = self.project_out(out)
+         return out
+
+
+ # Intensity Enhancement Layer
+ class IEL(nn.Module):
+     def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
+         super(IEL, self).__init__()
+
+         hidden_features = int(dim * ffn_expansion_factor)
+
+         self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
+
+         self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias)
+         self.dwconv1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias)
+         self.dwconv2 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias)
+
+         self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
+
+         self.Tanh = nn.Tanh()
+
+     def forward(self, x):
+         x = self.project_in(x)
+         x1, x2 = self.dwconv(x).chunk(2, dim=1)
+         x1 = self.Tanh(self.dwconv1(x1)) + x1
+         x2 = self.Tanh(self.dwconv2(x2)) + x2
+         x = x1 * x2
+         x = self.project_out(x)
+         return x
+
+
+ # Lightweight Cross Attention
+ class HV_LCA(nn.Module):
+     def __init__(self, dim, num_heads, bias=False):
+         super(HV_LCA, self).__init__()
+         self.gdfn = IEL(dim)  # IEL and CDL have the same structure
+         self.norm = LayerNorm(dim)
+         self.ffn = CAB(dim, num_heads, bias)
+
+     def forward(self, x, y):
+         x = x + self.ffn(self.norm(x), self.norm(y))
+         x = self.gdfn(self.norm(x))
+         return x
+
+ class I_LCA(nn.Module):
+     def __init__(self, dim, num_heads, bias=False):
+         super(I_LCA, self).__init__()
+         self.norm = LayerNorm(dim)
+         self.gdfn = IEL(dim)
+         self.ffn = CAB(dim, num_heads, bias=bias)
+
+     def forward(self, x, y):
+         x = x + self.ffn(self.norm(x), self.norm(y))
+         x = x + self.gdfn(self.norm(x))
+         return x
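
Note that CAB computes transposed attention: after the rearranges, the attention matrix is channels x channels per head rather than spatial, so its cost scales with channel count, not image size. A shape sketch (not part of the commit):

import torch
cab = CAB(dim=36, num_heads=2, bias=False)
x = torch.randn(1, 36, 16, 16)   # queries
y = torch.randn(1, 36, 16, 16)   # keys/values from the other branch
print(cab(x, y).shape)           # torch.Size([1, 36, 16, 16])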
net/transformer_utils.py ADDED
@@ -0,0 +1,72 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+
+ class LayerNorm(nn.Module):
+     r"""LayerNorm that supports two data formats: channels_last (default) or channels_first,
+     the ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+     shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+     with shape (batch_size, channels, height, width).
+     """
+     def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+         self.eps = eps
+         self.data_format = data_format
+         if self.data_format not in ["channels_last", "channels_first"]:
+             raise NotImplementedError
+         self.normalized_shape = (normalized_shape, )
+
+     def forward(self, x):
+         if self.data_format == "channels_last":
+             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+         elif self.data_format == "channels_first":
+             u = x.mean(1, keepdim=True)
+             s = (x - u).pow(2).mean(1, keepdim=True)
+             x = (x - u) / torch.sqrt(s + self.eps)
+             x = self.weight[:, None, None] * x + self.bias[:, None, None]
+             return x
+
+ class NormDownsample(nn.Module):
+     def __init__(self, in_ch, out_ch, scale=0.5, use_norm=False):
+         super(NormDownsample, self).__init__()
+         self.use_norm = use_norm
+         if self.use_norm:
+             self.norm = LayerNorm(out_ch)
+         self.prelu = nn.PReLU()
+         self.down = nn.Sequential(
+             nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
+             nn.UpsamplingBilinear2d(scale_factor=scale))
+
+     def forward(self, x):
+         x = self.down(x)
+         x = self.prelu(x)
+         if self.use_norm:
+             x = self.norm(x)
+             return x
+         else:
+             return x
+
+ class NormUpsample(nn.Module):
+     def __init__(self, in_ch, out_ch, scale=2, use_norm=False):
+         super(NormUpsample, self).__init__()
+         self.use_norm = use_norm
+         if self.use_norm:
+             self.norm = LayerNorm(out_ch)
+         self.prelu = nn.PReLU()
+         self.up_scale = nn.Sequential(
+             nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
+             nn.UpsamplingBilinear2d(scale_factor=scale))
+         self.up = nn.Conv2d(out_ch * 2, out_ch, kernel_size=1, stride=1, padding=0, bias=False)
+
+     def forward(self, x, y):
+         x = self.up_scale(x)
+         x = torch.cat([x, y], dim=1)
+         x = self.up(x)
+         x = self.prelu(x)
+         if self.use_norm:
+             return self.norm(x)
+         else:
+             return x
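
The down/upsample pair is symmetric: NormDownsample halves the spatial size after a 3x3 conv, and NormUpsample doubles it, concatenates the skip tensor, and fuses with a 1x1 conv. A shape sketch (not part of the commit):

import torch
down = NormDownsample(36, 72)
up = NormUpsample(72, 36)
x = torch.randn(1, 36, 32, 32)
y = down(x)               # -> (1, 72, 16, 16)
print(up(y, x).shape)     # torch.Size([1, 36, 32, 32])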
weights/LOL-Blur.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eba53573097b4138be26ffc219826cf87eaeafbc0a6ae90cf4b35330070b5494
+ size 7971076
weights/LOLv1/BestLPIPS_0.0868.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:533bec252359efc8f05720ce5a73c3a8c891384e62ba62b47b457b4a4c87247e
+ size 7971706
weights/LOLv1/BestSSIM_0.8631.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:267ed28a7f65af87dd7bd6f486d88cfa25e654234badb974df7bc85dbf267ee3
+ size 7971706
weights/LOLv1/PSNR_24.74.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bb268bdf320d7236ba2e70c2b95e8c116d9766a3f3dbf4ccbbab1387228dfbf7
+ size 7971706
weights/LOLv1/w_perc.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9206385b514014b60bf9396151a011cabf68abb380c2ad87ffde3d53e8227926
+ size 7968002
weights/LOLv1/wo_perc.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2c5afb426fee910dd9806b4e6214f05914fb5988b7982a4624560815aefce2b5
+ size 7970627
weights/LOLv2_real/best_PSNR.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:44a7f92cf02a1da322c1a34b3aa940341dc93fd53840b7f19b49f461200b00a6
+ size 7971269
weights/LOLv2_real/best_SSIM.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3d7ad3e93d65effaeb937cf833b1f2b3fb78c7529a4db95c589ff0ca569347c3
+ size 7971269
weights/LOLv2_real/w_perc.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e0affc8ce89cceeb283dfa3ce7108c767cefc6f59826da37ec4c39482839db8
+ size 7967682
weights/LOLv2_syn/w_perc.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:460a9956018935cab61ddef3cea14f1dbbc02a65e12ee4d5b5ab45d704d2575b
+ size 7967682
weights/LOLv2_syn/wo_perc.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:519cc5fc55491cf36b47b8295c0204dde29ba887feeed06a77e716072469e55e
+ size 7974353
weights/SICE.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7b0824026124427fe0419465fc06a6be0129356c949794d1d1f52c4173091607
+ size 7964800
weights/SID.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:49678f69495610a2691bb5b2665d2766b57a6a79201a236f7afa84f67c5fec5f
+ size 7964543
weights/generalization.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6395370762ec1fa0f6fe92a38c62382b1ebe21edc012f597023bfcec8b95cd27
+ size 7971462