jhj0517 committed
Commit 6140e4d · Parent(s): b9a973f

Add RealESRGANer wrapper

modules/image_restoration/real_esrgan/wrapper/real_esrganer.py ADDED
@@ -0,0 +1,301 @@
+ import cv2
+ import math
+ import numpy as np
+ import queue
+ import threading
+ import torch
+ from torch.nn import functional as F
+
+
+ class RealESRGANer():
+     """A helper class for upsampling images with RealESRGAN.
+
+     Args:
+         scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+         model_path (str | list[str]): Path to the pretrained model, or a list of two paths
+             for deep network interpolation (see ``dni``).
+         dni_weight (list[float]): Blend weights for deep network interpolation when
+             ``model_path`` is a list of two checkpoints. Default: None.
+         model (nn.Module): The defined network. Default: None.
+         tile (int): Since overly large images can cause GPU out-of-memory errors, this tile
+             option first crops the input image into tiles, processes each of them, and finally
+             merges them back into one image. 0 disables tiling. Default: 0.
+         tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
+         pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+         half (bool): Whether to use half precision during inference. Default: False.
+         device (torch.device): Device to run inference on. Default: None (auto-select).
+         gpu_id (int): Specific GPU id to use when CUDA is available. Default: None.
+     """
+
+     def __init__(self,
+                  scale,
+                  model_path,
+                  dni_weight=None,
+                  model=None,
+                  tile=0,
+                  tile_pad=10,
+                  pre_pad=10,
+                  half=False,
+                  device=None,
+                  gpu_id=None):
+         self.scale = scale
+         self.tile_size = tile
+         self.tile_pad = tile_pad
+         self.pre_pad = pre_pad
+         self.mod_scale = None
+         self.half = half
+
+         # initialize model
+         if gpu_id:
+             self.device = torch.device(
+                 f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
+         else:
+             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
+
+         if dni_weight is not None and isinstance(model_path, list):
+             # deep network interpolation needs two checkpoints and their blend weights
+             assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the same length.'
+             loadnet = self.dni(model_path[0], model_path[1], dni_weight)
+         else:
+             loadnet = torch.load(model_path, map_location=torch.device('cpu'))
+
+         # prefer to use params_ema
+         if 'params_ema' in loadnet:
+             keyname = 'params_ema'
+         else:
+             keyname = 'params'
+         model.load_state_dict(loadnet[keyname], strict=True)
+
+         model.eval()
+         self.model = model.to(self.device)
+         if self.half:
+             self.model = self.model.half()
+
+     def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
+         """Deep network interpolation.
+
+         ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
+         """
+         net_a = torch.load(net_a, map_location=torch.device(loc))
+         net_b = torch.load(net_b, map_location=torch.device(loc))
+         for k, v_a in net_a[key].items():
+             net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
+         return net_a
+
+     def pre_process(self, img):
+         """Pre-process (pre-pad and mod-pad) so that the image dimensions are divisible by the scale factor.
+         """
+         img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+         self.img = img.unsqueeze(0).to(self.device)
+         if self.half:
+             self.img = self.img.half()
+
+         # pre_pad
+         if self.pre_pad != 0:
+             self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+         # mod pad for divisible borders
+         if self.scale == 2:
+             self.mod_scale = 2
+         elif self.scale == 1:
+             self.mod_scale = 4
+         if self.mod_scale is not None:
+             self.mod_pad_h, self.mod_pad_w = 0, 0
+             _, _, h, w = self.img.size()
+             if (h % self.mod_scale != 0):
+                 self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+             if (w % self.mod_scale != 0):
+                 self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+             self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+
+     def process(self):
+         # model inference
+         self.output = self.model(self.img)
+
+     def tile_process(self):
+         """Crop the input image into tiles, process each tile, and merge
+         the processed tiles back into one image.
+
+         Modified from: https://github.com/ata4/esrgan-launcher
+         """
+         batch, channel, height, width = self.img.shape
+         output_height = height * self.scale
+         output_width = width * self.scale
+         output_shape = (batch, channel, output_height, output_width)
+
+         # start with black image
+         self.output = self.img.new_zeros(output_shape)
+         tiles_x = math.ceil(width / self.tile_size)
+         tiles_y = math.ceil(height / self.tile_size)
+
+         # loop over all tiles
+         for y in range(tiles_y):
+             for x in range(tiles_x):
+                 # extract tile from input image
+                 ofs_x = x * self.tile_size
+                 ofs_y = y * self.tile_size
+                 # input tile area on total image
+                 input_start_x = ofs_x
+                 input_end_x = min(ofs_x + self.tile_size, width)
+                 input_start_y = ofs_y
+                 input_end_y = min(ofs_y + self.tile_size, height)
+
+                 # input tile area on total image with padding
+                 input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+                 input_end_x_pad = min(input_end_x + self.tile_pad, width)
+                 input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+                 input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+                 # input tile dimensions
+                 input_tile_width = input_end_x - input_start_x
+                 input_tile_height = input_end_y - input_start_y
+                 tile_idx = y * tiles_x + x + 1
+                 input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+                 # upscale tile
+                 try:
+                     with torch.no_grad():
+                         output_tile = self.model(input_tile)
+                 except RuntimeError as error:
+                     print('Error', error)
+                     print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+                     continue  # skip this tile; output_tile is undefined below otherwise
+
+                 # output tile area on total image
+                 output_start_x = input_start_x * self.scale
+                 output_end_x = input_end_x * self.scale
+                 output_start_y = input_start_y * self.scale
+                 output_end_y = input_end_y * self.scale
+
+                 # output tile area without padding
+                 output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+                 output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+                 output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+                 output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+                 # put tile into output image
+                 self.output[:, :, output_start_y:output_end_y,
+                             output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+                                                                        output_start_x_tile:output_end_x_tile]
+
+     def post_process(self):
+         # remove extra pad
+         if self.mod_scale is not None:
+             _, _, h, w = self.output.size()
+             self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
+         # remove prepad
+         if self.pre_pad != 0:
+             _, _, h, w = self.output.size()
+             self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
+         return self.output
+
+     @torch.no_grad()
+     def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
+         h_input, w_input = img.shape[0:2]
+         # img: numpy
+         img = img.astype(np.float32)
+         if np.max(img) > 256:  # 16-bit image
+             max_range = 65535
+             print('\tInput is a 16-bit image')
+         else:
+             max_range = 255
+         img = img / max_range
+         if len(img.shape) == 2:  # gray image
+             img_mode = 'L'
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+         elif img.shape[2] == 4:  # RGBA image with alpha channel
+             img_mode = 'RGBA'
+             alpha = img[:, :, 3]
+             img = img[:, :, 0:3]
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+             if alpha_upsampler == 'realesrgan':
+                 alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+         else:
+             img_mode = 'RGB'
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+         # ------------------- process image (without the alpha channel) ------------------- #
+         self.pre_process(img)
+         if self.tile_size > 0:
+             self.tile_process()
+         else:
+             self.process()
+         output_img = self.post_process()
+         output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+         output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+         if img_mode == 'L':
+             output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+
+         # ------------------- process the alpha channel if necessary ------------------- #
+         if img_mode == 'RGBA':
+             if alpha_upsampler == 'realesrgan':
+                 self.pre_process(alpha)
+                 if self.tile_size > 0:
+                     self.tile_process()
+                 else:
+                     self.process()
+                 output_alpha = self.post_process()
+                 output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+                 output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+                 output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+             else:  # use the cv2 resize for alpha channel
+                 h, w = alpha.shape[0:2]
+                 output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+             # merge the alpha channel
+             output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+             output_img[:, :, 3] = output_alpha
+
+         # ------------------------------ return ------------------------------ #
+         if max_range == 65535:  # 16-bit image
+             output = (output_img * 65535.0).round().astype(np.uint16)
+         else:
+             output = (output_img * 255.0).round().astype(np.uint8)
+
+         if outscale is not None and outscale != float(self.scale):
+             output = cv2.resize(
+                 output, (
+                     int(w_input * outscale),
+                     int(h_input * outscale),
+                 ), interpolation=cv2.INTER_LANCZOS4)
+
+         return output, img_mode
+
+
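For context, a minimal usage sketch of the wrapper above. It assumes basicsr's RRDBNet architecture and a locally available x4 checkpoint; the weights path and tile settings are illustrative, not part of this commit.

    # Minimal sketch (assumptions: basicsr installed, RealESRGAN x4 weights on disk).
    import cv2
    from basicsr.archs.rrdbnet_arch import RRDBNet

    # RRDBNet configured for the x4 RealESRGAN models.
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    upsampler = RealESRGANer(
        scale=4,
        model_path='weights/RealESRGAN_x4plus.pth',  # hypothetical path
        model=model,
        tile=512,      # tile to keep GPU memory bounded; 0 disables tiling
        tile_pad=10,
        pre_pad=10,
        half=False)

    img = cv2.imread('input.png', cv2.IMREAD_UNCHANGED)  # BGR, as enhance() expects
    output, img_mode = upsampler.enhance(img, outscale=4)
    cv2.imwrite('output.png', output)

Note that `enhance()` accepts and returns BGR(A) numpy arrays, so the result can be written back with cv2 directly.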
+ class PrefetchReader(threading.Thread):
+     """Prefetch images.
+
+     Args:
+         img_list (list[str]): A list of image paths to be read.
+         num_prefetch_queue (int): Size of the prefetch queue.
+     """
+
+     def __init__(self, img_list, num_prefetch_queue):
+         super().__init__()
+         self.que = queue.Queue(num_prefetch_queue)
+         self.img_list = img_list
+
+     def run(self):
+         for img_path in self.img_list:
+             img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+             self.que.put(img)
+
+         self.que.put(None)
+
+     def __next__(self):
+         next_item = self.que.get()
+         if next_item is None:
+             raise StopIteration
+         return next_item
+
+     def __iter__(self):
+         return self
+
+
+ class IOConsumer(threading.Thread):
+
+     def __init__(self, opt, que, qid):
+         super().__init__()
+         self._queue = que
+         self.qid = qid
+         self.opt = opt
+
+     def run(self):
+         while True:
+             msg = self._queue.get()
+             if isinstance(msg, str) and msg == 'quit':
+                 break
+
+             output = msg['output']
+             save_path = msg['save_path']
+             cv2.imwrite(save_path, output)
+         print(f'IO worker {self.qid} is done.')
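PrefetchReader and IOConsumer exist to overlap disk reads and writes with GPU inference. A minimal wiring sketch, assuming the `upsampler` instance from the sketch above and hypothetical file paths:

    # Sketch: read on one thread, upscale on the main thread, write on another.
    import queue

    img_paths = ['a.png', 'b.png']  # hypothetical input list
    reader = PrefetchReader(img_paths, num_prefetch_queue=4)
    reader.start()

    save_queue = queue.Queue()
    consumer = IOConsumer(opt={}, que=save_queue, qid=0)
    consumer.start()

    for i, img in enumerate(reader):        # iteration ends at the None sentinel
        output, _ = upsampler.enhance(img)
        save_queue.put({'output': output, 'save_path': f'out_{i}.png'})

    save_queue.put('quit')                  # tells IOConsumer to exit its loop
    consumer.join()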