avans06 commited on
Commit
308a135
·
1 Parent(s): ac599ac

Added support for parsing DRCT models from the DRCT project by author ming053l.

Browse files
Files changed (1) hide show
  1. app.py +23 -11
app.py CHANGED
@@ -270,10 +270,10 @@ Optimized primarily for PAL resolution (NTSC might work good as well)."""],
270
  "https://openmodeldb.info/models/4x-NomosWebPhoto-RealPLKSR",
271
  """4x RealPLKSR model for photography, trained with realistic noise, lens blur, jpg and webp re-compression."""],
272
 
273
- # "4xNomos2_hq_drct-l.pth" : ["https://github.com/Phhofm/models/releases/download/4xNomos2_hq_drct-l/4xNomos2_hq_drct-l.pth",
274
- # "https://github.com/Phhofm/models/releases/tag/4xNomos2_hq_drct-l",
275
- # """An drct-l 4x upscaling model, similiar to the 4xNomos2_hq_atd, 4xNomos2_hq_dat2 and 4xNomos2_hq_mosr models, trained and for usage on non-degraded input to give good quality output.
276
- # """],
277
 
278
  # "4xNomos2_hq_atd.pth" : ["https://github.com/Phhofm/models/releases/download/4xNomos2_hq_atd/4xNomos2_hq_atd.pth",
279
  # "https://github.com/Phhofm/models/releases/tag/4xNomos2_hq_atd",
@@ -306,8 +306,8 @@ def get_model_type(model_name):
306
  model_type = "RealPLKSR_dysample"
307
  elif any(value in model_name.lower() for value in ("realplksr", "rplksr", "realplskr")):
308
  model_type = "RealPLKSR"
309
- elif "drct-l" in model_name.lower():
310
- model_type = "DRCT-L"
311
  elif "atd" in model_name.lower():
312
  model_type = "ATD"
313
  return f"{model_type}, {model_name}"
@@ -343,7 +343,6 @@ class Upscale:
343
  modelInUse = f"_{os.path.splitext(upscale_model)[0]}"
344
 
345
  self.netscale = 1 if any(sub in upscale_model for sub in ("x1", "1x")) else (2 if any(sub in upscale_model for sub in ("x2", "2x")) else 4)
346
- loadnet = None
347
  model = None
348
  is_auto_split_upscale = True
349
  half = True if torch.cuda.is_available() else False
@@ -411,11 +410,24 @@ class Upscale:
411
  elif upscale_type == "RealPLKSR":
412
  half = False if "RealPLSKR" in upscale_model else half
413
  model = realplksr(dim=64, n_blocks=28, kernel_size=17, split_ratio=0.25, upscaling_factor=self.netscale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
  self.upsampler = None
416
- if loadnet:
417
- self.upsampler = RealESRGANer(scale=self.netscale, loadnet=loadnet, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
418
- elif model:
419
  self.upsampler = RealESRGANer(scale=self.netscale, model_path=os.path.join("weights", "upscale", upscale_model), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
420
  elif upscale_model:
421
  self.upsampler = None
@@ -684,7 +696,7 @@ def main():
684
  for key, _ in typed_upscale_models.items():
685
  upscale_type, upscale_model = key.split(", ", 1)
686
  if tmptype and tmptype != upscale_type:#RRDB ESRGAN
687
- speed = "Fast" if tmptype == "SRVGG" else ("Slow" if any(value == tmptype for value in ("DAT", "HAT")) else "Normal")
688
  upscale_model_header = f"| Upscale Model | Info, Type: {tmptype}, Model execution speed: {speed} | Download URL |\n|------------|------|--------------|"
689
  upscale_model_tables.append(upscale_model_header + "\n" + "\n".join(rows))
690
  rows.clear()
 
270
  "https://openmodeldb.info/models/4x-NomosWebPhoto-RealPLKSR",
271
  """4x RealPLKSR model for photography, trained with realistic noise, lens blur, jpg and webp re-compression."""],
272
 
273
+ "4xNomos2_hq_drct-l.pth" : ["https://github.com/Phhofm/models/releases/download/4xNomos2_hq_drct-l/4xNomos2_hq_drct-l.pth",
274
+ "https://github.com/Phhofm/models/releases/tag/4xNomos2_hq_drct-l",
275
+ """An drct-l 4x upscaling model, similiar to the 4xNomos2_hq_atd, 4xNomos2_hq_dat2 and 4xNomos2_hq_mosr models, trained and for usage on non-degraded input to give good quality output.
276
+ """],
277
 
278
  # "4xNomos2_hq_atd.pth" : ["https://github.com/Phhofm/models/releases/download/4xNomos2_hq_atd/4xNomos2_hq_atd.pth",
279
  # "https://github.com/Phhofm/models/releases/tag/4xNomos2_hq_atd",
 
306
  model_type = "RealPLKSR_dysample"
307
  elif any(value in model_name.lower() for value in ("realplksr", "rplksr", "realplskr")):
308
  model_type = "RealPLKSR"
309
+ elif "drct" in model_name.lower():
310
+ model_type = "DRCT"
311
  elif "atd" in model_name.lower():
312
  model_type = "ATD"
313
  return f"{model_type}, {model_name}"
 
343
  modelInUse = f"_{os.path.splitext(upscale_model)[0]}"
344
 
345
  self.netscale = 1 if any(sub in upscale_model for sub in ("x1", "1x")) else (2 if any(sub in upscale_model for sub in ("x2", "2x")) else 4)
 
346
  model = None
347
  is_auto_split_upscale = True
348
  half = True if torch.cuda.is_available() else False
 
410
  elif upscale_type == "RealPLKSR":
411
  half = False if "RealPLSKR" in upscale_model else half
412
  model = realplksr(dim=64, n_blocks=28, kernel_size=17, split_ratio=0.25, upscaling_factor=self.netscale)
413
+ elif upscale_type == "DRCT":
414
+ half = False
415
+ from basicsr.archs.DRCT_arch import DRCT
416
+ window_size = 16
417
+ compress_ratio = 3
418
+ squeeze_factor = 30
419
+ depths = [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
420
+ embed_dim = 180
421
+ num_heads = [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
422
+ mlp_ratio = 2
423
+ upsampler = "pixelshuffle"
424
+ model = DRCT(upscale=self.netscale, in_chans=3, img_size= 64, window_size=window_size, compress_ratio=compress_ratio,squeeze_factor=squeeze_factor,
425
+ conv_scale= 0.01, overlap_ratio= 0.5, img_range= 1., depths=depths,
426
+ embed_dim=embed_dim, num_heads=num_heads, gc= 32,
427
+ mlp_ratio=mlp_ratio, upsampler=upsampler, resi_connection= '1conv')
428
 
429
  self.upsampler = None
430
+ if model:
 
 
431
  self.upsampler = RealESRGANer(scale=self.netscale, model_path=os.path.join("weights", "upscale", upscale_model), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
432
  elif upscale_model:
433
  self.upsampler = None
 
696
  for key, _ in typed_upscale_models.items():
697
  upscale_type, upscale_model = key.split(", ", 1)
698
  if tmptype and tmptype != upscale_type:#RRDB ESRGAN
699
+ speed = "Fast" if tmptype == "SRVGG" else ("Slow" if any(value == tmptype for value in ("DAT", "HAT", "DRCT")) else "Normal")
700
  upscale_model_header = f"| Upscale Model | Info, Type: {tmptype}, Model execution speed: {speed} | Download URL |\n|------------|------|--------------|"
701
  upscale_model_tables.append(upscale_model_header + "\n" + "\n".join(rows))
702
  rows.clear()