avans06 committed
Commit 086ab89 · Parent: c59704a

Improve RRDB and ESRGAN model loading methods.
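
The core of the change: instead of hard-coding architecture parameters (block count, feature width) per model name, app.py now reads them from the checkpoint's state dict via the new `find_max_numbers` helper (see the app.py diff below). A minimal standalone sketch of that inference, assuming the `RealESRGAN_x4plus_anime_6B.pth` checkpoint from the model table has already been downloaded into the app's `weights/upscale` folder; the helper body mirrors the one added in this commit:

```python
# Sketch of the state-dict-based parameter inference added in this commit.
# The checkpoint path below is illustrative; any RealESRGAN-style RRDB
# checkpoint with "body.N." keys would work.
import re
from collections import defaultdict

import torch

def find_max_numbers(state_dict, findkeys):
    """Return the largest index N found in keys shaped like '<findkey>.N.'."""
    if isinstance(findkeys, str):
        findkeys = [findkeys]
    max_values = defaultdict(lambda: None)
    patterns = {findkey: re.compile(rf"^{re.escape(findkey)}\.(\d+)\.") for findkey in findkeys}
    for key in state_dict:
        for findkey, pattern in patterns.items():
            if match := pattern.match(key):
                num = int(match.group(1))
                max_values[findkey] = max(num, max_values[findkey] if max_values[findkey] is not None else num)
    return tuple(max_values[findkey] for findkey in findkeys) if len(findkeys) > 1 else max_values[findkeys[0]]

loadnet = torch.load("weights/upscale/RealESRGAN_x4plus_anime_6B.pth",
                     map_location="cpu", weights_only=True)
# Unwrap EMA / plain params dicts, as the new loader does.
if "params_ema" in loadnet or "params" in loadnet:
    loadnet = loadnet["params_ema"] if "params_ema" in loadnet else loadnet["params"]

num_block = 1 + find_max_numbers(loadnet, "body")      # RRDB block count (6 for the anime_6B model)
num_feat = loadnet["conv_first.weight"].shape[0]       # feature channels (64)
print(f"num_block={num_block}, num_feat={num_feat}")
```

The SRVGG and old-style ESRGAN branches apply the same idea to the `body.0.weight` and `model.1.sub` keys, respectively, as shown in the diff.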


Set torch.backends.cudnn options to avoid getting black images on RTX 16xx cards.
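
For reference, the workaround is the two cuDNN flags added at the top of `main()` (quoted from the diff below), following the linked stable-diffusion issue about 16xx-series cards producing black output images:

```python
import torch

if torch.cuda.is_available():
    # https://github.com/CompVis/stable-diffusion/issues/69#issuecomment-1260722801
    # Avoids black output images on 16xx-series cards.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
```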

Files changed (3):
  1. README.md +1 -1
  2. app.py +86 -45
  3. requirements.txt +1 -2
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📈
 colorFrom: blue
 colorTo: gray
 sdk: gradio
-sdk_version: 5.14.0
+sdk_version: 5.15.0
 app_file: app.py
 pinned: true
 license: apache-2.0
app.py CHANGED
@@ -1,12 +1,14 @@
 import os
 import gc
+import re
 import cv2
 import numpy as np
 import gradio as gr
 import torch
 import traceback
+from collections import defaultdict
 from facexlib.utils.misc import download_from_url
-from realesrgan.utils import RealESRGANer
+from basicsr.utils.realesrganer import RealESRGANer
 
 
 # Define URLs and their corresponding local storage paths
@@ -111,7 +113,6 @@ I am releasing the Series 3 from my 4xLSDIRCompact models. In general my suggest
 "https://github.com/Phhofm/models/releases/tag/1xExposureCorrection_compact",
 """This model is meant as an experiment to see if compact can be used to train on overexposed images to exposure correct those using the pixel, perceptual, color, color and ldl losses. There is no brightness loss. Still it seems to kinda work."""],
 
-
 # RRDBNet
 "RealESRGAN_x4plus_anime_6B.pth": ["https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
 "https://github.com/xinntao/Real-ESRGAN/releases/tag/v0.2.2.4",
@@ -140,6 +141,31 @@ I am releasing the Series 3 from my 4xLSDIRCompact models. In general my suggest
 Model for color images including manga covers and color illustrations, digital art, visual novel art, artbooks, and more.
 DAT2 version is the highest quality version but also the slowest. See the ESRGAN version for faster performance."""],
 
+"2x-sudo-RealESRGAN.pth": ["https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/2x-sudo-RealESRGAN.pth",
+"https://openmodeldb.info/models/2x-sudo-RealESRGAN",
+"""Pretrained: Pretrained_Model_G: RealESRGAN_x4plus_anime_6B.pth / RealESRGAN_x4plus_anime_6B.pth (sudo_RealESRGAN2x_3.332.758_G.pth)
+Tried to make the best 2x model there is for drawings. I think i archived that.
+And yes, it is nearly 3.8 million iterations (probably a record nobody will beat here), took me nearly half a year to train.
+It can happen that in one edge is a noisy pattern in edges. You can use padding/crop for that.
+I aimed for perceptual quality without zooming in like 400%. Since RealESRGAN is 4x, I downscaled these images with bicubic."""],
+
+"2x-sudo-RealESRGAN-Dropout.pth": ["https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/2x-sudo-RealESRGAN-Dropout.pth",
+"https://openmodeldb.info/models/2x-sudo-RealESRGAN-Dropout",
+"""Pretrained: Pretrained_Model_G: RealESRGAN_x4plus_anime_6B.pth / RealESRGAN_x4plus_anime_6B.pth (sudo_RealESRGAN2x_3.332.758_G.pth)
+Tried to make the best 2x model there is for drawings. I think i archived that.
+And yes, it is nearly 3.8 million iterations (probably a record nobody will beat here), took me nearly half a year to train.
+It can happen that in one edge is a noisy pattern in edges. You can use padding/crop for that.
+I aimed for perceptual quality without zooming in like 400%. Since RealESRGAN is 4x, I downscaled these images with bicubic."""],
+
+"4xNomos2_otf_esrgan.pth": ["https://github.com/Phhofm/models/releases/download/4xNomos2_otf_esrgan/4xNomos2_otf_esrgan.pth",
+"https://github.com/Phhofm/models/releases/tag/4xNomos2_otf_esrgan",
+"""Purpose: Restoration, 4x ESRGAN model for photography, trained using the Real-ESRGAN otf degradation pipeline."""],
+
+"4xNomosWebPhoto_esrgan.pth": ["https://github.com/Phhofm/models/releases/download/4xNomosWebPhoto_esrgan/4xNomosWebPhoto_esrgan.pth",
+"https://github.com/Phhofm/models/releases/tag/4xNomosWebPhoto_esrgan",
+"""Purpose: Restoration, 4x ESRGAN model for photography, trained with realistic noise, lens blur, jpg and webp re-compression.
+ESRGAN version of 4xNomosWebPhoto_RealPLKSR, trained on the same dataset and in the same way."""],
+
 # DATNet
 "4xNomos8kDAT.pth" : ["https://github.com/Phhofm/models/releases/download/4xNomos8kDAT/4xNomos8kDAT.pth",
 "https://openmodeldb.info/models/4x-Nomos8kDAT",
@@ -264,11 +290,13 @@ example_list = ["images/a01.jpg", "images/a02.jpg", "images/a03.jpg", "images/a0
 def get_model_type(model_name):
 # Define model type mappings based on key parts of the model names
 model_type = "other"
-if any(value in model_name.lower() for value in ("realesrgan", "realesrnet")):
+if any(value in model_name.lower() for value in ("4x-animesharp.pth", "sudo-realesrgan")):
+model_type = "ESRGAN"
+elif any(value in model_name.lower() for value in ("realesrgan", "realesrnet")):
 model_type = "RRDB"
 elif any(value in model_name.lower() for value in ("realesr", "exposurecorrection", "parimgcompact", "lsdircompact")):
 model_type = "SRVGG"
-elif "esrgan" in model_name.lower() or "4x-AnimeSharp.pth" == model_name:
+elif "esrgan" in model_name.lower():
 model_type = "ESRGAN"
 elif "dat" in model_name.lower():
 model_type = "DAT"
@@ -296,13 +324,13 @@ class Upscale:
 self.img_name = os.path.basename(str(img))
 self.basename, self.extension = os.path.splitext(self.img_name)
 
-img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED) # cv2.imread(img, cv2.IMREAD_UNCHANGED)
+img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED)
 
 self.img_mode = "RGBA" if len(img.shape) == 3 and img.shape[2] == 4 else None
 if len(img.shape) == 2: # for gray inputs
 img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
 
-h, w = img.shape[0:2]
+self.h_input, self.w_input = img.shape[0:2]
 
 if face_restoration:
 download_from_url(face_models[face_restoration][0], face_restoration, os.path.join("weights", "face"))
@@ -314,48 +342,45 @@ class Upscale:
 download_from_url(upscale_models[upscale_model][0], upscale_model, os.path.join("weights", "upscale"))
 modelInUse = f"_{os.path.splitext(upscale_model)[0]}"
 
-netscale = 4
+self.netscale = 1 if any(sub in upscale_model for sub in ("x1", "1x")) else (2 if any(sub in upscale_model for sub in ("x2", "2x")) else 4)
 loadnet = None
 model = None
 is_auto_split_upscale = True
 half = True if torch.cuda.is_available() else False
 if upscale_type:
 from basicsr.archs.rrdbnet_arch import RRDBNet
-from basicsr.archs.realplksr_arch import realplksr
 # background enhancer with upscale model
-if upscale_type == "RRDB":
-netscale = 2 if "x2" in upscale_model else 4
-num_block = 6 if "6B" in upscale_model else 23
-model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_block, num_grow_ch=32, scale=netscale)
-elif upscale_type == "SRVGG":
-from realesrgan.archs.srvgg_arch import SRVGGNetCompact
-netscale = 1 if "1x" in upscale_model else (2 if "2x" in upscale_model else 4)
-num_conv = 16 if any(value in upscale_model for value in ("animevideov3", "ExposureCorrection", "ParimgCompact", "LSDIRCompact")) else 32
-model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=netscale, act_type='prelu')
-elif upscale_type == "ESRGAN":
-netscale = 4
-model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale)
-loadnet = {}
+if any(value == upscale_type for value in ("SRVGG", "RRDB", "ESRGAN")):
 loadnet_origin = torch.load(os.path.join("weights", "upscale", upscale_model), map_location=torch.device('cpu'), weights_only=True)
-for key, value in loadnet_origin.items():
-new_key = key.replace("model.0", "conv_first").replace("model.1.sub.23.", "conv_body.").replace("model.1.sub", "body") \
-.replace(".0.weight", ".weight").replace(".0.bias", ".bias").replace(".RDB1.", ".rdb1.").replace(".RDB2.", ".rdb2.").replace(".RDB3.", ".rdb3.") \
-.replace("model.3.", "conv_up1.").replace("model.6.", "conv_up2.").replace("model.8.", "conv_hr.").replace("model.10.", "conv_last.")
-loadnet[new_key] = value
+if 'params_ema' in loadnet_origin or 'params' in loadnet_origin:
+loadnet_origin = loadnet_origin['params_ema'] if 'params_ema' in loadnet_origin else loadnet_origin['params']
+if upscale_type == "SRVGG":
+from basicsr.archs.srvgg_arch import SRVGGNetCompact
+body_max_num = self.find_max_numbers(loadnet_origin, "body")
+num_feat = loadnet_origin["body.0.weight"].shape[0]
+num_in_ch = loadnet_origin["body.0.weight"].shape[1]
+num_conv = body_max_num // 2 - 1 #16 if any(value in upscale_model for value in ("animevideov3", "ExposureCorrection", "ParimgCompact", "LSDIRCompact")) else 32
+model = SRVGGNetCompact(num_in_ch=num_in_ch, num_out_ch=3, num_feat=num_feat, num_conv=num_conv, upscale=self.netscale, act_type='prelu')
+elif upscale_type == "RRDB" or upscale_type == "ESRGAN":
+if upscale_type == "RRDB":
+num_block = 1 + self.find_max_numbers(loadnet_origin, "body")
+num_feat = loadnet_origin["conv_first.weight"].shape[0]
+else:
+num_block = self.find_max_numbers(loadnet_origin, "model.1.sub")
+num_feat = loadnet_origin["model.0.weight"].shape[0]
+model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=num_feat, num_block=num_block, num_grow_ch=32, scale=self.netscale, is_real_esrgan=upscale_type == "RRDB")
 elif upscale_type == "DAT":
 from basicsr.archs.dat_arch import DAT
 half = False
-netscale = 4
 expansion_factor = 2. if "dat2" in upscale_model.lower() else 4.
-model = DAT(img_size=64, in_chans=3, embed_dim=180, split_size=[8,32], depth=[6,6,6,6,6,6], num_heads=[6,6,6,6,6,6], expansion_factor=expansion_factor, upscale=netscale)
+model = DAT(img_size=64, in_chans=3, embed_dim=180, split_size=[8,32], depth=[6,6,6,6,6,6], num_heads=[6,6,6,6,6,6], expansion_factor=expansion_factor, upscale=self.netscale)
 # # Speculate on the parameters.
 # loadnet_origin = torch.load(os.path.join("weights", "upscale", upscale_model), map_location=torch.device('cpu'), weights_only=True)
-# inferred_params = self.infer_parameters_from_state_dict_for_dat(loadnet_origin, netscale)
+# inferred_params = self.infer_parameters_from_state_dict_for_dat(loadnet_origin, self.netscale)
 # for param, value in inferred_params.items():
 # print(f"{param}: {value}")
 elif upscale_type == "HAT":
 half = False
-netscale = 4
 import torch.nn.functional as F
 from basicsr.archs.hat_arch import HAT
 class HATWithAutoPadding(HAT):
@@ -422,21 +447,20 @@ class Upscale:
 mlp_ratio = 2
 upsampler = "pixelshuffle"
 model = HATWithAutoPadding(img_size=64, patch_size=1, in_chans=3, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, compress_ratio=compress_ratio,
-squeeze_factor=squeeze_factor, conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=mlp_ratio, upsampler=upsampler, upscale=netscale,)
-elif upscale_type == "RealPLKSR_dysample":
-netscale = 4
-model = realplksr(dim=64, n_blocks=28, kernel_size=17, split_ratio=0.25, upscaling_factor=netscale, dysample=True)
-elif upscale_type == "RealPLKSR":
-half = False if "RealPLSKR" in upscale_model else half
-netscale = 2 if upscale_model.startswith("2x") else 4
-model = realplksr(dim=64, n_blocks=28, kernel_size=17, split_ratio=0.25, upscaling_factor=netscale)
+squeeze_factor=squeeze_factor, conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=mlp_ratio, upsampler=upsampler, upscale=self.netscale,)
+elif "RealPLKSR" in upscale_type:
+from basicsr.archs.realplksr_arch import realplksr
+if upscale_type == "RealPLKSR_dysample":
+model = realplksr(dim=64, n_blocks=28, kernel_size=17, split_ratio=0.25, upscaling_factor=self.netscale, dysample=True)
+elif upscale_type == "RealPLKSR":
+half = False if "RealPLSKR" in upscale_model else half
+model = realplksr(dim=64, n_blocks=28, kernel_size=17, split_ratio=0.25, upscaling_factor=self.netscale)
 
-
 self.upsampler = None
 if loadnet:
-self.upsampler = RealESRGANer(scale=netscale, loadnet=loadnet, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
+self.upsampler = RealESRGANer(scale=self.netscale, loadnet=loadnet, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
 elif model:
-self.upsampler = RealESRGANer(scale=netscale, model_path=os.path.join("weights", "upscale", upscale_model), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
+self.upsampler = RealESRGANer(scale=self.netscale, model_path=os.path.join("weights", "upscale", upscale_model), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
 elif upscale_model:
 self.upsampler = None
 import PIL
@@ -475,7 +499,7 @@ class Upscale:
 pil_img = self.cv2pil(img)
 pil_img = self.__call__(pil_img)
 cv_image = self.pil2cv(pil_img)
-if outscale is not None and outscale != float(netscale):
+if outscale is not None and outscale != float(self.netscale):
 cv_image = cv2.resize(
 cv_image, (
 int(w_input * outscale),
@@ -514,7 +538,7 @@ class Upscale:
 arch = "GPEN-2048"
 resolution = 2048
 
-self.face_enhancer = GFPGANer(model_path=model_path, upscale=self.scale, arch=arch, channel_multiplier=channel_multiplier, bg_upsampler=self.upsampler, model_rootpath=model_rootpath, det_model=face_detection, resolution=resolution)
+self.face_enhancer = GFPGANer(model_path=model_path, upscale=self.scale, arch=arch, channel_multiplier=channel_multiplier, model_rootpath=model_rootpath, det_model=face_detection, resolution=resolution)
 
 files = []
 if not outputWithModelName:
@@ -522,10 +546,10 @@ class Upscale:
 
 try:
 bg_upsample_img = None
-if self.upsampler and self.upsampler.enhance:
+if self.upsampler and hasattr(self.upsampler, "enhance"):
 from utils.dataops import auto_split_upscale
 bg_upsample_img, _ = auto_split_upscale(img, self.upsampler.enhance, self.scale) if is_auto_split_upscale else self.upsampler.enhance(img, outscale=self.scale)
-
+
 if self.face_enhancer:
 cropped_faces, restored_aligned, bg_upsample_img = self.face_enhancer.enhance(img, has_aligned=False, only_center_face=face_detection_only_center, paste_back=True, bg_upsample_img=bg_upsample_img, eye_dist_threshold=face_detection_threshold)
 # save faces
@@ -571,6 +595,19 @@ class Upscale:
 print("global exception", error)
 return None, None
 
+def find_max_numbers(self, state_dict, findkeys):
+if isinstance(findkeys, str):
+findkeys = [findkeys]
+max_values = defaultdict(lambda: None)
+patterns = {findkey: re.compile(rf"^{re.escape(findkey)}\.(\d+)\.") for findkey in findkeys}
+
+for key in state_dict:
+for findkey, pattern in patterns.items():
+if match := pattern.match(key):
+num = int(match.group(1))
+max_values[findkey] = max(num, max_values[findkey] if max_values[findkey] is not None else num)
+
+return tuple(max_values[findkey] for findkey in findkeys) if len(findkeys) > 1 else max_values[findkeys[0]]
 
 def infer_parameters_from_state_dict_for_dat(self, state_dict, upscale=4):
 if "params" in state_dict:
@@ -659,6 +696,10 @@ class Upscale:
 def main():
 if torch.cuda.is_available():
 torch.cuda.set_per_process_memory_fraction(0.975, device='cuda:0')
+# set torch options to avoid get black image for RTX16xx card
+# https://github.com/CompVis/stable-diffusion/issues/69#issuecomment-1260722801
+torch.backends.cudnn.enabled = True
+torch.backends.cudnn.benchmark = True
 # Ensure the target directory exists
 os.makedirs('output', exist_ok=True)
 
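
One thing the diff does not spell out is why the `get_model_type` checks were reordered: the new sudo-RealESRGAN checkpoints contain "realesrgan" in their filenames but use the plain ESRGAN key layout, so they must be caught before the RRDB branch. A quick sketch of the dispatch; the branches are copied from the hunk above, while the trailing branches and the final `return` are assumed to match the unchanged remainder of the function:

```python
def get_model_type(model_name):
    # Branches copied from the new app.py hunk above; the remaining branches
    # (HAT, RealPLKSR, ...) and the final return are assumed unchanged.
    model_type = "other"
    if any(value in model_name.lower() for value in ("4x-animesharp.pth", "sudo-realesrgan")):
        model_type = "ESRGAN"
    elif any(value in model_name.lower() for value in ("realesrgan", "realesrnet")):
        model_type = "RRDB"
    elif any(value in model_name.lower() for value in ("realesr", "exposurecorrection", "parimgcompact", "lsdircompact")):
        model_type = "SRVGG"
    elif "esrgan" in model_name.lower():
        model_type = "ESRGAN"
    elif "dat" in model_name.lower():
        model_type = "DAT"
    return model_type

print(get_model_type("2x-sudo-RealESRGAN.pth"))         # ESRGAN (would have matched the RRDB branch before this commit)
print(get_model_type("RealESRGAN_x4plus_anime_6B.pth"))  # RRDB
print(get_model_type("4xNomos2_otf_esrgan.pth"))         # ESRGAN
```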
 
requirements.txt CHANGED
@@ -1,11 +1,10 @@
 --extra-index-url https://download.pytorch.org/whl/cu124
 
-gradio==5.14.0
+gradio==5.15.0
 
 basicsr @ git+https://github.com/avan06/BasicSR
 facexlib @ git+https://github.com/avan06/facexlib
 gfpgan @ git+https://github.com/avan06/GFPGAN
-realesrgan @ git+https://github.com/avan06/Real-ESRGAN
 
 numpy
 opencv-python