Update the calling method of the HAT Module.
app.py (changed)
@@ -324,7 +324,7 @@ class Upscale:
         self.img_name = os.path.basename(str(img))
         self.basename, self.extension = os.path.splitext(self.img_name)

-        img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED)
+        img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED)  # numpy.ndarray

         self.img_mode = "RGBA" if len(img.shape) == 3 and img.shape[2] == 4 else None
         if len(img.shape) == 2:  # for gray inputs
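Side note (not part of the diff): the np.fromfile + cv2.imdecode pair is a common substitute for cv2.imread when file paths may contain non-ASCII characters, which cv2.imread can fail to open on some platforms. A minimal sketch of the pattern, using a hypothetical load_image helper name:

import cv2
import numpy as np

def load_image(path: str) -> np.ndarray:
    # Read the raw bytes first (np.fromfile copes with non-ASCII paths),
    # then decode; IMREAD_UNCHANGED preserves an alpha channel if present.
    data = np.fromfile(path, dtype=np.uint8)
    return cv2.imdecode(data, cv2.IMREAD_UNCHANGED)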
@@ -381,51 +381,7 @@ class Upscale:
                 # print(f"{param}: {value}")
         elif upscale_type == "HAT":
             half = False
-            import torch.nn.functional as F
             from basicsr.archs.hat_arch import HAT
-            class HATWithAutoPadding(HAT):
-                def pad_to_multiple(self, img, multiple):
-                    """
-                    Pad the image so that its height and width are integer multiples of `multiple`.
-                    """
-                    _, _, h, w = img.shape
-                    pad_h = (multiple - h % multiple) % multiple
-                    pad_w = (multiple - w % multiple) % multiple
-
-                    # Padding on the top, bottom, left, and right.
-                    pad_top = pad_h // 2
-                    pad_bottom = pad_h - pad_top
-                    pad_left = pad_w // 2
-                    pad_right = pad_w - pad_left
-
-                    img_padded = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode="reflect")
-                    return img_padded, (pad_top, pad_bottom, pad_left, pad_right)
-
-                def remove_padding(self, img, pad_info):
-                    """
-                    Remove the padding and restore the original size, accounting for the upscale factor.
-                    """
-                    pad_top, pad_bottom, pad_left, pad_right = pad_info
-
-                    # Adjust the padding by the upscaling factor
-                    pad_top = int(pad_top * self.upscale)
-                    pad_bottom = int(pad_bottom * self.upscale)
-                    pad_left = int(pad_left * self.upscale)
-                    pad_right = int(pad_right * self.upscale)
-
-                    return img[:, :, pad_top:-pad_bottom if pad_bottom > 0 else None, pad_left:-pad_right if pad_right > 0 else None]
-
-                def forward(self, x):
-                    # Step 1: auto padding
-                    x_padded, pad_info = self.pad_to_multiple(x, self.window_size)
-
-                    # Step 2: normal model processing
-                    x_processed = super().forward(x_padded)
-
-                    # Step 3: remove padding
-                    x_cropped = self.remove_padding(x_processed, pad_info)
-                    return x_cropped
-
             # The parameters are derived from the XPixelGroup project files: HAT-L_SRx4_ImageNet-pretrain.yml and HAT-S_SRx4.yml.
             # https://github.com/XPixelGroup/HAT/tree/main/options/test
             if "hat-l" in upscale_model.lower():
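Aside (not part of the diff): the wrapper removed above implements a generic trick for window-attention models such as HAT: reflect-pad the input so its height and width become multiples of the attention window size, run the network, then crop the padding (scaled by the upscale factor) back off. A standalone sketch of that technique, assuming an NCHW float tensor and an integer scale factor:

import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, multiple: int):
    # Reflect-pad an NCHW tensor so its height and width are divisible by `multiple`.
    _, _, h, w = x.shape
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    pads = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)  # left, right, top, bottom
    return F.pad(x, pads, mode="reflect"), pads

def crop_padding(y: torch.Tensor, pads, scale: int):
    # Undo the padding after the network has upscaled everything by `scale`.
    left, right, top, bottom = (p * scale for p in pads)
    return y[:, :, top:y.shape[2] - bottom, left:y.shape[3] - right]

# Usage: padded, pads = pad_to_multiple(x, window_size)
#        out = crop_padding(model(padded), pads, scale)

The commit removes this wrapper and instantiates HAT directly; the new enhance helper further down calls the model object's __call__ on a PIL image instead.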
@@ -446,7 +402,7 @@ class Upscale:
                 num_heads = [6, 6, 6, 6, 6, 6]
                 mlp_ratio = 2
                 upsampler = "pixelshuffle"
-            model = 
+            model = HAT(img_size=64, patch_size=1, in_chans=3, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, compress_ratio=compress_ratio,
                         squeeze_factor=squeeze_factor, conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=mlp_ratio, upsampler=upsampler, upscale=self.netscale,)
         elif "RealPLKSR" in upscale_type:
             from basicsr.archs.realplksr_arch import realplksr
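Aside (not shown in this diff): once the HAT architecture is constructed, its weights still need to be loaded before inference. A hedged sketch of the usual BasicSR-style loading step, assuming the checkpoint nests its tensors under a "params_ema" or "params" key; the model_path file name below is only illustrative:

import torch

# Hypothetical file name for illustration; real checkpoint paths vary.
model_path = "HAT-L_SRx4_ImageNet-pretrain.pth"
state = torch.load(model_path, map_location="cpu")
# BasicSR-style checkpoints commonly wrap the weights in "params_ema"/"params".
if isinstance(state, dict):
    state = state.get("params_ema", state.get("params", state))
model.load_state_dict(state, strict=True)
model.eval()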
@@ -493,18 +449,19 @@ class Upscale:
            new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
            return new_image

-       def enhance(
+       def enhance(self_, img, outscale=None):
            # img: numpy
            h_input, w_input = img.shape[0:2]
            pil_img = self.cv2pil(img)
-           pil_img = 
+           pil_img = self_.__call__(pil_img)
            cv_image = self.pil2cv(pil_img)
            if outscale is not None and outscale != float(self.netscale):
+               interpolation = cv2.INTER_AREA if outscale < float(self.netscale) else cv2.INTER_LANCZOS4
                cv_image = cv2.resize(
                    cv_image, (
                        int(w_input * outscale),
                        int(h_input * outscale),
-                   ), interpolation=
+                   ), interpolation=interpolation)
            return cv_image, None

        device = "cuda" if torch.cuda.is_available() else "cpu"
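The added interpolation line picks cv2.INTER_AREA when the requested outscale is smaller than the network's native scale (the SR output is being shrunk) and cv2.INTER_LANCZOS4 when it is enlarged further. A minimal standalone sketch of that post-resize step, using a hypothetical resize_to_outscale helper name and assuming a BGR numpy image:

import cv2
import numpy as np

def resize_to_outscale(sr_img: np.ndarray, h_input: int, w_input: int,
                       outscale: float, netscale: float) -> np.ndarray:
    # Only resize when the requested scale differs from the network's native scale.
    if outscale is None or outscale == float(netscale):
        return sr_img
    # INTER_AREA averages source pixels and avoids aliasing when shrinking;
    # INTER_LANCZOS4 preserves detail better when enlarging.
    interpolation = cv2.INTER_AREA if outscale < float(netscale) else cv2.INTER_LANCZOS4
    return cv2.resize(sr_img, (int(w_input * outscale), int(h_input * outscale)),
                      interpolation=interpolation)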