avans06 commited on
Commit
ac599ac
·
1 Parent(s): bb6a756

Update the calling method of the HAT Module.

Browse files
Files changed (1) hide show
  1. app.py +6 -49
app.py CHANGED
@@ -324,7 +324,7 @@ class Upscale:
324
  self.img_name = os.path.basename(str(img))
325
  self.basename, self.extension = os.path.splitext(self.img_name)
326
 
327
- img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED)
328
 
329
  self.img_mode = "RGBA" if len(img.shape) == 3 and img.shape[2] == 4 else None
330
  if len(img.shape) == 2: # for gray inputs
@@ -381,51 +381,7 @@ class Upscale:
381
  # print(f"{param}: {value}")
382
  elif upscale_type == "HAT":
383
  half = False
384
- import torch.nn.functional as F
385
  from basicsr.archs.hat_arch import HAT
386
- class HATWithAutoPadding(HAT):
387
- def pad_to_multiple(self, img, multiple):
388
- """
389
- Fill the image to multiples of both width and height as integers.
390
- """
391
- _, _, h, w = img.shape
392
- pad_h = (multiple - h % multiple) % multiple
393
- pad_w = (multiple - w % multiple) % multiple
394
-
395
- # Padding on the top, bottom, left, and right.
396
- pad_top = pad_h // 2
397
- pad_bottom = pad_h - pad_top
398
- pad_left = pad_w // 2
399
- pad_right = pad_w - pad_left
400
-
401
- img_padded = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode="reflect")
402
- return img_padded, (pad_top, pad_bottom, pad_left, pad_right)
403
-
404
- def remove_padding(self, img, pad_info):
405
- """
406
- Remove padding and restore to the original size, considering upscaling.
407
- """
408
- pad_top, pad_bottom, pad_left, pad_right = pad_info
409
-
410
- # Adjust padding based on upscaling factor
411
- pad_top = int(pad_top * self.upscale)
412
- pad_bottom = int(pad_bottom * self.upscale)
413
- pad_left = int(pad_left * self.upscale)
414
- pad_right = int(pad_right * self.upscale)
415
-
416
- return img[:, :, pad_top:-pad_bottom if pad_bottom > 0 else None, pad_left:-pad_right if pad_right > 0 else None]
417
-
418
- def forward(self, x):
419
- # Step 1: Auto padding
420
- x_padded, pad_info = self.pad_to_multiple(x, self.window_size)
421
-
422
- # Step 2: Normal model processing
423
- x_processed = super().forward(x_padded)
424
-
425
- # Step 3: Remove padding
426
- x_cropped = self.remove_padding(x_processed, pad_info)
427
- return x_cropped
428
-
429
  # The parameters are derived from the XPixelGroup project files: HAT-L_SRx4_ImageNet-pretrain.yml and HAT-S_SRx4.yml.
430
  # https://github.com/XPixelGroup/HAT/tree/main/options/test
431
  if "hat-l" in upscale_model.lower():
@@ -446,7 +402,7 @@ class Upscale:
446
  num_heads = [6, 6, 6, 6, 6, 6]
447
  mlp_ratio = 2
448
  upsampler = "pixelshuffle"
449
- model = HATWithAutoPadding(img_size=64, patch_size=1, in_chans=3, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, compress_ratio=compress_ratio,
450
  squeeze_factor=squeeze_factor, conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=mlp_ratio, upsampler=upsampler, upscale=self.netscale,)
451
  elif "RealPLKSR" in upscale_type:
452
  from basicsr.archs.realplksr_arch import realplksr
@@ -493,18 +449,19 @@ class Upscale:
493
  new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
494
  return new_image
495
 
496
- def enhance(self, img, outscale=None):
497
  # img: numpy
498
  h_input, w_input = img.shape[0:2]
499
  pil_img = self.cv2pil(img)
500
- pil_img = self.__call__(pil_img)
501
  cv_image = self.pil2cv(pil_img)
502
  if outscale is not None and outscale != float(self.netscale):
 
503
  cv_image = cv2.resize(
504
  cv_image, (
505
  int(w_input * outscale),
506
  int(h_input * outscale),
507
- ), interpolation=cv2.INTER_LANCZOS4)
508
  return cv_image, None
509
 
510
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
324
  self.img_name = os.path.basename(str(img))
325
  self.basename, self.extension = os.path.splitext(self.img_name)
326
 
327
+ img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED) # numpy.ndarray
328
 
329
  self.img_mode = "RGBA" if len(img.shape) == 3 and img.shape[2] == 4 else None
330
  if len(img.shape) == 2: # for gray inputs
 
381
  # print(f"{param}: {value}")
382
  elif upscale_type == "HAT":
383
  half = False
 
384
  from basicsr.archs.hat_arch import HAT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  # The parameters are derived from the XPixelGroup project files: HAT-L_SRx4_ImageNet-pretrain.yml and HAT-S_SRx4.yml.
386
  # https://github.com/XPixelGroup/HAT/tree/main/options/test
387
  if "hat-l" in upscale_model.lower():
 
402
  num_heads = [6, 6, 6, 6, 6, 6]
403
  mlp_ratio = 2
404
  upsampler = "pixelshuffle"
405
+ model = HAT(img_size=64, patch_size=1, in_chans=3, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, compress_ratio=compress_ratio,
406
  squeeze_factor=squeeze_factor, conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=mlp_ratio, upsampler=upsampler, upscale=self.netscale,)
407
  elif "RealPLKSR" in upscale_type:
408
  from basicsr.archs.realplksr_arch import realplksr
 
449
  new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
450
  return new_image
451
 
452
+ def enhance(self_, img, outscale=None):
453
  # img: numpy
454
  h_input, w_input = img.shape[0:2]
455
  pil_img = self.cv2pil(img)
456
+ pil_img = self_.__call__(pil_img)
457
  cv_image = self.pil2cv(pil_img)
458
  if outscale is not None and outscale != float(self.netscale):
459
+ interpolation = cv2.INTER_AREA if outscale < float(self.netscale) else cv2.INTER_LANCZOS4
460
  cv_image = cv2.resize(
461
  cv_image, (
462
  int(w_input * outscale),
463
  int(h_input * outscale),
464
+ ), interpolation=interpolation)
465
  return cv_image, None
466
 
467
  device = "cuda" if torch.cuda.is_available() else "cpu"