Improve RRDB and ESRGAN model loading methods.
Set torch.backends.cudnn options to avoid producing black images on RTX 16xx cards.
- README.md +1 -1
- app.py +86 -45
- requirements.txt +1 -2
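For context on the second half of the commit message: the workaround amounts to two cuDNN backend flags set before inference. A minimal standalone sketch of the same settings, guarded the same way as the change in app.py's main() below (the flag values and the issue link come straight from the diff):

import torch

if torch.cuda.is_available():
    # GTX/RTX 16xx cards can reportedly produce all-black outputs during
    # half-precision inference; per the linked upstream report, enabling
    # cuDNN together with benchmark mode works around it.
    # https://github.com/CompVis/stable-diffusion/issues/69#issuecomment-1260722801
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True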
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 📈
 colorFrom: blue
 colorTo: gray
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.15.0
 app_file: app.py
 pinned: true
 license: apache-2.0
app.py
CHANGED
@@ -1,12 +1,14 @@
 import os
 import gc
+import re
 import cv2
 import numpy as np
 import gradio as gr
 import torch
 import traceback
+from collections import defaultdict
 from facexlib.utils.misc import download_from_url
-from realesrgan import RealESRGANer
+from basicsr.utils.realesrganer import RealESRGANer


 # Define URLs and their corresponding local storage paths
@@ -111,7 +113,6 @@ I am releasing the Series 3 from my 4xLSDIRCompact models. In general my suggest
         "https://github.com/Phhofm/models/releases/tag/1xExposureCorrection_compact",
         """This model is meant as an experiment to see if compact can be used to train on overexposed images to exposure correct those using the pixel, perceptual, color, color and ldl losses. There is no brightness loss. Still it seems to kinda work."""],

-
     # RRDBNet
     "RealESRGAN_x4plus_anime_6B.pth": ["https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
         "https://github.com/xinntao/Real-ESRGAN/releases/tag/v0.2.2.4",
@@ -140,6 +141,31 @@ I am releasing the Series 3 from my 4xLSDIRCompact models. In general my suggest
     Model for color images including manga covers and color illustrations, digital art, visual novel art, artbooks, and more.
     DAT2 version is the highest quality version but also the slowest. See the ESRGAN version for faster performance."""],

+    "2x-sudo-RealESRGAN.pth": ["https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/2x-sudo-RealESRGAN.pth",
+        "https://openmodeldb.info/models/2x-sudo-RealESRGAN",
+        """Pretrained: Pretrained_Model_G: RealESRGAN_x4plus_anime_6B.pth / RealESRGAN_x4plus_anime_6B.pth (sudo_RealESRGAN2x_3.332.758_G.pth)
+    Tried to make the best 2x model there is for drawings. I think i archived that.
+    And yes, it is nearly 3.8 million iterations (probably a record nobody will beat here), took me nearly half a year to train.
+    It can happen that in one edge is a noisy pattern in edges. You can use padding/crop for that.
+    I aimed for perceptual quality without zooming in like 400%. Since RealESRGAN is 4x, I downscaled these images with bicubic."""],
+
+    "2x-sudo-RealESRGAN-Dropout.pth": ["https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/2x-sudo-RealESRGAN-Dropout.pth",
+        "https://openmodeldb.info/models/2x-sudo-RealESRGAN-Dropout",
+        """Pretrained: Pretrained_Model_G: RealESRGAN_x4plus_anime_6B.pth / RealESRGAN_x4plus_anime_6B.pth (sudo_RealESRGAN2x_3.332.758_G.pth)
+    Tried to make the best 2x model there is for drawings. I think i archived that.
+    And yes, it is nearly 3.8 million iterations (probably a record nobody will beat here), took me nearly half a year to train.
+    It can happen that in one edge is a noisy pattern in edges. You can use padding/crop for that.
+    I aimed for perceptual quality without zooming in like 400%. Since RealESRGAN is 4x, I downscaled these images with bicubic."""],
+
+    "4xNomos2_otf_esrgan.pth": ["https://github.com/Phhofm/models/releases/download/4xNomos2_otf_esrgan/4xNomos2_otf_esrgan.pth",
+        "https://github.com/Phhofm/models/releases/tag/4xNomos2_otf_esrgan",
+        """Purpose: Restoration, 4x ESRGAN model for photography, trained using the Real-ESRGAN otf degradation pipeline."""],
+
+    "4xNomosWebPhoto_esrgan.pth": ["https://github.com/Phhofm/models/releases/download/4xNomosWebPhoto_esrgan/4xNomosWebPhoto_esrgan.pth",
+        "https://github.com/Phhofm/models/releases/tag/4xNomosWebPhoto_esrgan",
+        """Purpose: Restoration, 4x ESRGAN model for photography, trained with realistic noise, lens blur, jpg and webp re-compression.
+    ESRGAN version of 4xNomosWebPhoto_RealPLKSR, trained on the same dataset and in the same way."""],
+
     # DATNet
     "4xNomos8kDAT.pth" : ["https://github.com/Phhofm/models/releases/download/4xNomos8kDAT/4xNomos8kDAT.pth",
         "https://openmodeldb.info/models/4x-Nomos8kDAT",
@@ -264,11 +290,13 @@ example_list = ["images/a01.jpg", "images/a02.jpg", "images/a03.jpg", "images/a0
 def get_model_type(model_name):
     # Define model type mappings based on key parts of the model names
     model_type = "other"
-    if any(value in model_name.lower() for value in ("realesrgan", "realesrnet")):
+    if any(value in model_name.lower() for value in ("4x-animesharp.pth", "sudo-realesrgan")):
+        model_type = "ESRGAN"
+    elif any(value in model_name.lower() for value in ("realesrgan", "realesrnet")):
         model_type = "RRDB"
     elif any(value in model_name.lower() for value in ("realesr", "exposurecorrection", "parimgcompact", "lsdircompact")):
         model_type = "SRVGG"
-    elif "esrgan" in model_name.lower():
+    elif "esrgan" in model_name.lower():
         model_type = "ESRGAN"
     elif "dat" in model_name.lower():
         model_type = "DAT"
@@ -296,13 +324,13 @@ class Upscale:
         self.img_name = os.path.basename(str(img))
         self.basename, self.extension = os.path.splitext(self.img_name)

-        img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED)
+        img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED)

         self.img_mode = "RGBA" if len(img.shape) == 3 and img.shape[2] == 4 else None
         if len(img.shape) == 2:  # for gray inputs
             img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

-
+        self.h_input, self.w_input = img.shape[0:2]

         if face_restoration:
             download_from_url(face_models[face_restoration][0], face_restoration, os.path.join("weights", "face"))
@@ -314,48 +342,45 @@ class Upscale:
             download_from_url(upscale_models[upscale_model][0], upscale_model, os.path.join("weights", "upscale"))
             modelInUse = f"_{os.path.splitext(upscale_model)[0]}"

-            netscale = 4
+            self.netscale = 1 if any(sub in upscale_model for sub in ("x1", "1x")) else (2 if any(sub in upscale_model for sub in ("x2", "2x")) else 4)
             loadnet = None
             model = None
             is_auto_split_upscale = True
             half = True if torch.cuda.is_available() else False
             if upscale_type:
                 from basicsr.archs.rrdbnet_arch import RRDBNet
-                from basicsr.archs.realplksr_arch import realplksr
                 # background enhancer with upscale model
-                if upscale_type == "RRDB":
-                    netscale = 2 if "x2" in upscale_model else 4
-                    num_block = 6 if "6B" in upscale_model else 23
-                    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_block, num_grow_ch=32, scale=netscale)
-                elif upscale_type == "SRVGG":
-                    from realesrgan.archs.srvgg_arch import SRVGGNetCompact
-                    netscale = 1 if "1x" in upscale_model else (2 if "2x" in upscale_model else 4)
-                    num_conv = 16 if any(value in upscale_model for value in ("animevideov3", "ExposureCorrection", "ParimgCompact", "LSDIRCompact")) else 32
-                    model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=netscale, act_type='prelu')
-                elif upscale_type == "ESRGAN":
-                    netscale = 4
-                    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale)
-                    loadnet = {}
+                if any(value == upscale_type for value in ("SRVGG", "RRDB", "ESRGAN")):
                     loadnet_origin = torch.load(os.path.join("weights", "upscale", upscale_model), map_location=torch.device('cpu'), weights_only=True)
-
-
-
-
-
+                    if 'params_ema' in loadnet_origin or 'params' in loadnet_origin:
+                        loadnet_origin = loadnet_origin['params_ema'] if 'params_ema' in loadnet_origin else loadnet_origin['params']
+                    if upscale_type == "SRVGG":
+                        from basicsr.archs.srvgg_arch import SRVGGNetCompact
+                        body_max_num = self.find_max_numbers(loadnet_origin, "body")
+                        num_feat = loadnet_origin["body.0.weight"].shape[0]
+                        num_in_ch = loadnet_origin["body.0.weight"].shape[1]
+                        num_conv = body_max_num // 2 - 1 #16 if any(value in upscale_model for value in ("animevideov3", "ExposureCorrection", "ParimgCompact", "LSDIRCompact")) else 32
+                        model = SRVGGNetCompact(num_in_ch=num_in_ch, num_out_ch=3, num_feat=num_feat, num_conv=num_conv, upscale=self.netscale, act_type='prelu')
+                    elif upscale_type == "RRDB" or upscale_type == "ESRGAN":
+                        if upscale_type == "RRDB":
+                            num_block = 1 + self.find_max_numbers(loadnet_origin, "body")
+                            num_feat = loadnet_origin["conv_first.weight"].shape[0]
+                        else:
+                            num_block = self.find_max_numbers(loadnet_origin, "model.1.sub")
+                            num_feat = loadnet_origin["model.0.weight"].shape[0]
+                        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=num_feat, num_block=num_block, num_grow_ch=32, scale=self.netscale, is_real_esrgan=upscale_type == "RRDB")
                 elif upscale_type == "DAT":
                     from basicsr.archs.dat_arch import DAT
                     half = False
-                    netscale = 4
                     expansion_factor = 2. if "dat2" in upscale_model.lower() else 4.
-                    model = DAT(img_size=64, in_chans=3, embed_dim=180, split_size=[8,32], depth=[6,6,6,6,6,6], num_heads=[6,6,6,6,6,6], expansion_factor=expansion_factor, upscale=netscale)
+                    model = DAT(img_size=64, in_chans=3, embed_dim=180, split_size=[8,32], depth=[6,6,6,6,6,6], num_heads=[6,6,6,6,6,6], expansion_factor=expansion_factor, upscale=self.netscale)
                     # # Speculate on the parameters.
                     # loadnet_origin = torch.load(os.path.join("weights", "upscale", upscale_model), map_location=torch.device('cpu'), weights_only=True)
-                    # inferred_params = self.infer_parameters_from_state_dict_for_dat(loadnet_origin, netscale)
+                    # inferred_params = self.infer_parameters_from_state_dict_for_dat(loadnet_origin, self.netscale)
                     # for param, value in inferred_params.items():
                     #     print(f"{param}: {value}")
                 elif upscale_type == "HAT":
                     half = False
-                    netscale = 4
                     import torch.nn.functional as F
                     from basicsr.archs.hat_arch import HAT
                     class HATWithAutoPadding(HAT):
@@ -422,21 +447,20 @@ class Upscale:
                     mlp_ratio = 2
                     upsampler = "pixelshuffle"
                     model = HATWithAutoPadding(img_size=64, patch_size=1, in_chans=3, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, compress_ratio=compress_ratio,
-                                               squeeze_factor=squeeze_factor, conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=mlp_ratio, upsampler=upsampler, upscale=netscale,)
-                elif
-
-
-
-
-
-
+                                               squeeze_factor=squeeze_factor, conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=mlp_ratio, upsampler=upsampler, upscale=self.netscale,)
+                elif "RealPLKSR" in upscale_type:
+                    from basicsr.archs.realplksr_arch import realplksr
+                    if upscale_type == "RealPLKSR_dysample":
+                        model = realplksr(dim=64, n_blocks=28, kernel_size=17, split_ratio=0.25, upscaling_factor=self.netscale, dysample=True)
+                    elif upscale_type == "RealPLKSR":
+                        half = False if "RealPLSKR" in upscale_model else half
+                        model = realplksr(dim=64, n_blocks=28, kernel_size=17, split_ratio=0.25, upscaling_factor=self.netscale)

-
             self.upsampler = None
             if loadnet:
-                self.upsampler = RealESRGANer(scale=netscale, loadnet=loadnet, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
+                self.upsampler = RealESRGANer(scale=self.netscale, loadnet=loadnet, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
             elif model:
-                self.upsampler = RealESRGANer(scale=netscale, model_path=os.path.join("weights", "upscale", upscale_model), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
+                self.upsampler = RealESRGANer(scale=self.netscale, model_path=os.path.join("weights", "upscale", upscale_model), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
             elif upscale_model:
                 self.upsampler = None
                 import PIL
@@ -475,7 +499,7 @@ class Upscale:
             pil_img = self.cv2pil(img)
             pil_img = self.__call__(pil_img)
             cv_image = self.pil2cv(pil_img)
-            if outscale is not None and outscale != float(netscale):
+            if outscale is not None and outscale != float(self.netscale):
                 cv_image = cv2.resize(
                     cv_image, (
                         int(w_input * outscale),
@@ -514,7 +538,7 @@ class Upscale:
             arch = "GPEN-2048"
             resolution = 2048

-        self.face_enhancer = GFPGANer(model_path=model_path, upscale=self.scale, arch=arch, channel_multiplier=channel_multiplier,
+        self.face_enhancer = GFPGANer(model_path=model_path, upscale=self.scale, arch=arch, channel_multiplier=channel_multiplier, model_rootpath=model_rootpath, det_model=face_detection, resolution=resolution)

         files = []
         if not outputWithModelName:
@@ -522,10 +546,10 @@ class Upscale:

         try:
             bg_upsample_img = None
-            if self.upsampler and self.upsampler
+            if self.upsampler and hasattr(self.upsampler, "enhance"):
                 from utils.dataops import auto_split_upscale
                 bg_upsample_img, _ = auto_split_upscale(img, self.upsampler.enhance, self.scale) if is_auto_split_upscale else self.upsampler.enhance(img, outscale=self.scale)
-
+
             if self.face_enhancer:
                 cropped_faces, restored_aligned, bg_upsample_img = self.face_enhancer.enhance(img, has_aligned=False, only_center_face=face_detection_only_center, paste_back=True, bg_upsample_img=bg_upsample_img, eye_dist_threshold=face_detection_threshold)
                 # save faces
@@ -571,6 +595,19 @@ class Upscale:
             print("global exception", error)
             return None, None

+    def find_max_numbers(self, state_dict, findkeys):
+        if isinstance(findkeys, str):
+            findkeys = [findkeys]
+        max_values = defaultdict(lambda: None)
+        patterns = {findkey: re.compile(rf"^{re.escape(findkey)}\.(\d+)\.") for findkey in findkeys}
+
+        for key in state_dict:
+            for findkey, pattern in patterns.items():
+                if match := pattern.match(key):
+                    num = int(match.group(1))
+                    max_values[findkey] = max(num, max_values[findkey] if max_values[findkey] is not None else num)
+
+        return tuple(max_values[findkey] for findkey in findkeys) if len(findkeys) > 1 else max_values[findkeys[0]]

     def infer_parameters_from_state_dict_for_dat(self, state_dict, upscale=4):
         if "params" in state_dict:
@@ -659,6 +696,10 @@ class Upscale:
 def main():
     if torch.cuda.is_available():
         torch.cuda.set_per_process_memory_fraction(0.975, device='cuda:0')
+        # set torch options to avoid get black image for RTX16xx card
+        # https://github.com/CompVis/stable-diffusion/issues/69#issuecomment-1260722801
+        torch.backends.cudnn.enabled = True
+        torch.backends.cudnn.benchmark = True
     # Ensure the target directory exists
     os.makedirs('output', exist_ok=True)
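The heart of this commit is that RRDB/ESRGAN/SRVGG hyperparameters are now inferred from the checkpoint's state dict rather than hard-coded per filename: find_max_numbers scans keys of the form body.N.… for the largest block index N, and feature counts are read off weight shapes. A minimal standalone sketch of the same idea; the fake state dict and the single-key helper find_max_number are illustrative, not part of the app:

import re
import torch

def find_max_number(state_dict, findkey):
    # Highest N among keys of the form "<findkey>.N.<rest>" -- the same idea
    # as find_max_numbers above, reduced to a single key prefix.
    pattern = re.compile(rf"^{re.escape(findkey)}\.(\d+)\.")
    nums = [int(m.group(1)) for k in state_dict if (m := pattern.match(k))]
    return max(nums) if nums else None

# Illustrative RRDB-style checkpoint: 23 body blocks -> indices 0..22.
fake_state_dict = {f"body.{i}.rdb1.conv1.weight": torch.empty(1) for i in range(23)}
fake_state_dict["conv_first.weight"] = torch.empty(64, 3, 3, 3)

num_block = 1 + find_max_number(fake_state_dict, "body")  # -> 23
num_feat = fake_state_dict["conv_first.weight"].shape[0]  # -> 64
print(num_block, num_feat)

With a real RealESRGAN_x4plus checkpoint the same scan yields num_block=23 and num_feat=64, matching the values the old code hard-coded, while non-standard checkpoints now load with their actual dimensions.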
requirements.txt
CHANGED
@@ -1,11 +1,10 @@
 --extra-index-url https://download.pytorch.org/whl/cu124

-gradio==5.
+gradio==5.15.0

 basicsr @ git+https://github.com/avan06/BasicSR
 facexlib @ git+https://github.com/avan06/facexlib
 gfpgan @ git+https://github.com/avan06/GFPGAN
-realesrgan @ git+https://github.com/avan06/Real-ESRGAN

 numpy
 opencv-python
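The dropped realesrgan dependency pairs with the import change in app.py: the avan06/BasicSR fork now supplies both RealESRGANer and SRVGGNetCompact, so the separate package is no longer needed. A minimal sketch of building an upsampler through the new import path, assuming that fork is installed; the weight filename is illustrative and the constructor arguments mirror the call sites in the diff:

import os
import torch
from basicsr.utils.realesrganer import RealESRGANer
from basicsr.archs.srvgg_arch import SRVGGNetCompact

netscale = 4
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32,
                        upscale=netscale, act_type='prelu')
upsampler = RealESRGANer(scale=netscale,
                         model_path=os.path.join("weights", "upscale", "some_srvgg_4x_model.pth"),
                         model=model, tile=0, tile_pad=10, pre_pad=0,
                         half=torch.cuda.is_available())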