guetLzy commited on
Commit
54a84be
·
verified ·
1 Parent(s): 5ec8f8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -121
app.py CHANGED
@@ -1,149 +1,129 @@
1
- import gradio as gr
2
- import argparse
3
- from realesrgan import RealESRGANer
4
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
5
  import os
 
 
6
  from basicsr.archs.rrdbnet_arch import RRDBNet
7
  from basicsr.utils.download_util import load_file_from_url
8
- def Generate(img, model_name):
9
-
10
- global output
11
- parser = argparse.ArgumentParser()
12
- parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
13
- parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
14
- parser.add_argument(
15
- '-dn',
16
- '--denoise_strength',
17
- type=float,
18
- default=0.5,
19
- help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. '
20
- 'Only used for the realesr-general-x4v3 model'))
21
- parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
22
- parser.add_argument(
23
- '--model_path', type=str, default=None, help='[Option] Model path. Usually, you do not need to specify it')
24
- parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
25
- parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
26
- parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
27
- parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
28
- parser.add_argument('--face_enhance', action='store_true',help='Use GFPGAN to enhance face')
29
- parser.add_argument(
30
- '--fp32', action='store_true',default=True,help='Use fp32 precision during inference. Default: fp16 (half precision).')
31
- parser.add_argument(
32
- '--alpha_upsampler',
33
- type=str,
34
- default='realesrgan',
35
- help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
36
- parser.add_argument(
37
- '--ext',
38
- type=str,
39
- default='auto',
40
- help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
41
- parser.add_argument(
42
- '-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu')
43
 
44
- args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- if model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
47
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
48
- netscale = 4
49
- file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
50
- elif model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
51
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
52
- netscale = 4
53
- file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
54
- elif model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
55
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
56
- netscale = 4
57
- file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
58
- elif model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
59
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
60
- netscale = 2
61
- file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
62
- elif model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)
63
- model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
64
- netscale = 4
65
- file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
66
- elif model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
67
- model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
68
- netscale = 4
69
- file_url = [
70
- 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
71
- 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
72
- ]
73
 
74
- model_path = os.path.join('weights', model_name + '.pth')
75
- print(model_path)
76
  if not os.path.isfile(model_path):
77
- ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
78
- for url in file_url:
79
- # model_path will be updated
80
- model_path = load_file_from_url(
81
- url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
82
- dni_weight = None
83
- if model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
84
- wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
85
- model_path = [model_path, wdn_model_path]
86
- dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
87
 
88
- # restorer
89
  upsampler = RealESRGANer(
90
  scale=netscale,
91
  model_path=model_path,
92
- dni_weight=dni_weight,
93
  model=model,
94
- tile=args.tile,
95
- tile_pad=args.tile_pad,
96
- pre_pad=args.pre_pad,
97
- half=not args.fp32,
98
- gpu_id=args.gpu_id)
 
 
 
 
 
 
99
 
100
- if args.face_enhance: # Use GFPGAN for face enhancement
 
 
 
 
 
 
 
 
 
 
101
  from gfpgan import GFPGANer
102
  face_enhancer = GFPGANer(
103
  model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
104
- upscale=args.outscale,
105
  arch='clean',
106
  channel_multiplier=2,
107
- bg_upsampler=upsampler)
108
- os.makedirs(args.output, exist_ok=True)
109
-
110
- try:
111
- if args.face_enhance:
112
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
113
- else:
114
- output, _ = upsampler.enhance(img, outscale=args.outscale)
115
- print("生成成功")
116
-
117
- except RuntimeError as error:
118
- print('Error', error)
119
- print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
120
- output = None
121
 
 
 
122
  return output
123
 
124
- with gr.Blocks() as demo:
125
-
126
- gr.Markdown(
127
- """
128
- # <center> Real-ESRGAN 在线体验程序
129
- """)
130
- gr.Markdown("""
131
- 1. **项目模型运行在CPU上,等待时间略长**
132
- 2. **原工程项目旨在对图片就行修复**
133
- 3. **项目源地址为:[Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN)**
134
- """)
135
 
136
  with gr.Row():
137
  with gr.Column():
138
- img = gr.Image(type="numpy",label = "输入图片")
139
- model_name = gr.Dropdown(["RealESRGAN_x4plus","RealESRGAN_x4plus_anime_6B","RealESRGAN_x2plus",
140
- "realesr-animevideov3","realesr-general-x4v3"],info="选择模型")
141
- with gr.Column():
142
- img_out = gr.Image(type="numpy",label = "输出图片")
 
 
 
 
143
 
144
- btn = gr.Button("Generate")
 
145
 
146
- btn.click(Generate, inputs=[img,model_name], outputs=[img_out])
 
 
 
 
 
147
 
148
  if __name__ == "__main__":
149
- demo.launch()
 
 
 
 
 
1
  import os
2
+ import cv2
3
+ import torch
4
  from basicsr.archs.rrdbnet_arch import RRDBNet
5
  from basicsr.utils.download_util import load_file_from_url
6
+ from realesrgan import RealESRGANer
7
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
8
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # 模型配置
11
+ MODEL_OPTIONS = {
12
+ "RealESRGAN_x4plus": {
13
+ "model": lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
14
+ "netscale": 4,
15
+ "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
16
+ },
17
+ "RealESRNet_x4plus": {
18
+ "model": lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
19
+ "netscale": 4,
20
+ "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth"
21
+ },
22
+ "RealESRGAN_x4plus_anime_6B": {
23
+ "model": lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4),
24
+ "netscale": 4,
25
+ "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
26
+ },
27
+ "RealESRGAN_x2plus": {
28
+ "model": lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2),
29
+ "netscale": 2,
30
+ "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
31
+ },
32
+ "realesr-animevideov3": {
33
+ "model": lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'),
34
+ "netscale": 4,
35
+ "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth"
36
+ },
37
+ "realesr-general-x4v3": {
38
+ "model": lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'),
39
+ "netscale": 4,
40
+ "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
41
+ }
42
+ }
43
 
44
+ def load_model(model_name):
45
+ """加载并初始化 Real-ESRGAN 模型"""
46
+ model_config = MODEL_OPTIONS[model_name]
47
+ model = model_config["model"]()
48
+ netscale = model_config["netscale"]
49
+ file_url = model_config["file_url"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ # 下载模型权重
52
+ model_path = os.path.join("weights", f"{model_name}.pth")
53
  if not os.path.isfile(model_path):
54
+ os.makedirs("weights", exist_ok=True)
55
+ model_path = load_file_from_url(url=file_url, model_dir="weights", progress=True, file_name=None)
 
 
 
 
 
 
 
 
56
 
57
+ # 初始化 RealESRGANer
58
  upsampler = RealESRGANer(
59
  scale=netscale,
60
  model_path=model_path,
 
61
  model=model,
62
+ tile=0, # 默认无分块
63
+ tile_pad=10,
64
+ pre_pad=0,
65
+ half=True # 默认使用 fp16
66
+ )
67
+ return upsampler
68
+
69
+ def enhance_image(input_image, model_name, outscale, face_enhance):
70
+ """执行图像超分辨率增强"""
71
+ # 将 Gradio 上传的图像转换为 OpenCV 格式
72
+ img = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
73
 
74
+ # 确定图像模式
75
+ if len(img.shape) == 3 and img.shape[2] == 4:
76
+ img_mode = 'RGBA'
77
+ else:
78
+ img_mode = None
79
+
80
+ # 加载模型
81
+ upsampler = load_model(model_name)
82
+
83
+ # 是否使用人脸增强
84
+ if face_enhance:
85
  from gfpgan import GFPGANer
86
  face_enhancer = GFPGANer(
87
  model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
88
+ upscale=outscale,
89
  arch='clean',
90
  channel_multiplier=2,
91
+ bg_upsampler=upsampler
92
+ )
93
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
94
+ else:
95
+ output, _ = upsampler.enhance(img, outscale=outscale)
 
 
 
 
 
 
 
 
 
96
 
97
+ # 将结果转换回 RGB 格式以供 Gradio 显示
98
+ output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
99
  return output
100
 
101
+ # Gradio 界面
102
+ with gr.Blocks(title="Real-ESRGAN 图像超分辨率") as app:
103
+ gr.Markdown("## Real-ESRGAN 图像超分辨率系统")
104
+ gr.Markdown("上传图像,选择模型和参数,生成高清图像!")
 
 
 
 
 
 
 
105
 
106
  with gr.Row():
107
  with gr.Column():
108
+ input_image = gr.Image(label="输入图像", type="numpy")
109
+ model_dropdown = gr.Dropdown(
110
+ choices=list(MODEL_OPTIONS.keys()),
111
+ label="选择模型",
112
+ value="RealESRGAN_x4plus"
113
+ )
114
+ outscale = gr.Slider(minimum=1, maximum=4, step=0.5, value=4, label="放大倍数")
115
+ face_enhance = gr.Checkbox(label="启用人脸增强", value=False)
116
+ enhance_btn = gr.Button("开始增强", variant="primary")
117
 
118
+ with gr.Column():
119
+ output_image = gr.Image(label="增强结果", type="numpy")
120
 
121
+ enhance_btn.click(
122
+ fn=enhance_image,
123
+ inputs=[input_image, model_dropdown, outscale, face_enhance],
124
+ outputs=[output_image],
125
+ api_name="enhance"
126
+ )
127
 
128
  if __name__ == "__main__":
129
+ app.launch(server_name="0.0.0.0", server_port=7860)