File size: 5,235 Bytes
9584322
54a84be
6896ca1
 
 
 
54a84be
9584322
 
54a84be
 
 
9584322
54a84be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9584322
54a84be
 
 
 
 
 
9584322
54a84be
 
9584322
54a84be
 
9584322
54a84be
9584322
 
 
 
54a84be
 
 
 
 
 
 
 
 
 
 
9584322
54a84be
 
 
 
 
 
 
 
 
 
 
9584322
 
 
54a84be
9584322
 
54a84be
 
 
 
 
9584322
54a84be
 
9584322
 
54a84be
 
 
 
9584322
 
 
54a84be
 
 
 
 
 
 
 
 
9584322
54a84be
 
9584322
54a84be
 
 
 
 
 
9584322
 
54a84be
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import cv2
import subprocess

# Start-up hot-patch of a third-party package: overwrite line 9 of pytorchvideo's
# transforms/augmentations.py with `import torchvision.transforms.functional as F_t`.
# NOTE(review): presumably line 9 originally imports a torchvision module that the
# installed torchvision no longer provides (e.g. `functional_tensor`) — confirm
# against the pinned pytorchvideo version before changing the line number.
# The path is specific to a Python 3.11 dist-packages install; when the file is not
# there the patch is skipped instead of aborting start-up through `check=True`.
_PYTORCHVIDEO_AUGMENTATIONS = "/usr/local/lib/python3.11/dist-packages/pytorchvideo/transforms/augmentations.py"
if os.path.isfile(_PYTORCHVIDEO_AUGMENTATIONS):
    subprocess.run(
        ["sed", "-i", "9s/.*/import torchvision.transforms.functional as F_t/", _PYTORCHVIDEO_AUGMENTATIONS],
        check=True,
    )

import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
import gradio as gr

# Registry of selectable models, keyed by the name shown in the UI dropdown.
# Each entry holds:
#   model    - zero-argument factory that builds the network; a lambda so that
#              nothing is constructed until that model is actually requested
#              (the pretrained weights are loaded separately from file_url)
#   netscale - the network's built-in upscale factor (equal to its scale/upscale arg)
#   file_url - where load_model downloads the pretrained weights from
# load_model looks for cached weights at weights/<key>.pth, so every key matches
# the file name at the end of its file_url.
MODEL_OPTIONS = {
    "RealESRGAN_x4plus": dict(
        model=lambda: RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
        ),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    ),
    "RealESRNet_x4plus": dict(
        model=lambda: RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
        ),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
    ),
    # Lighter RRDB variant: 6 blocks instead of 23.
    "RealESRGAN_x4plus_anime_6B": dict(
        model=lambda: RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4
        ),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
    ),
    # The only 2x network in the registry.
    "RealESRGAN_x2plus": dict(
        model=lambda: RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2
        ),
        netscale=2,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
    ),
    # Compact SRVGG networks, with 16 and 32 conv layers respectively.
    "realesr-animevideov3": dict(
        model=lambda: SRVGGNetCompact(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'
        ),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
    ),
    "realesr-general-x4v3": dict(
        model=lambda: SRVGGNetCompact(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'
        ),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
    ),
}

def load_model(model_name, half=None):
    """Build a RealESRGANer for ``model_name``, downloading its weights if needed.

    Args:
        model_name: A key of ``MODEL_OPTIONS``.
        half: Whether to run inference in fp16. ``None`` (the default) enables
            fp16 only when CUDA is available: half-precision inference on CPU
            is unsupported or very slow in many PyTorch builds, so forcing it
            there (as this function used to) breaks CPU-only hosts.

    Returns:
        A ``RealESRGANer`` wrapping the freshly built network, configured with
        no tiling, a tile pad of 10 and no pre-padding.

    Raises:
        KeyError: if ``model_name`` is not in ``MODEL_OPTIONS``.
    """
    model_config = MODEL_OPTIONS[model_name]
    model = model_config["model"]()
    netscale = model_config["netscale"]

    # Reuse previously downloaded weights; otherwise fetch them into weights/.
    model_path = os.path.join("weights", f"{model_name}.pth")
    if not os.path.isfile(model_path):
        os.makedirs("weights", exist_ok=True)
        model_path = load_file_from_url(
            url=model_config["file_url"], model_dir="weights", progress=True, file_name=None
        )

    if half is None:
        half = torch.cuda.is_available()

    return RealESRGANer(
        scale=netscale,
        model_path=model_path,
        model=model,
        tile=0,  # process the whole image at once (no tiling)
        tile_pad=10,
        pre_pad=0,
        half=half,
    )

def _to_bgr(image):
    """Return an RGB / RGBA / grayscale array as 3-channel BGR for OpenCV-style code.

    An alpha channel is discarded; this is the same result the former
    ``COLOR_RGB2BGR`` call produced for 4-channel input. Grayscale input
    (2-D or single-channel) is replicated into three channels instead of
    crashing ``cvtColor``.
    """
    if image.ndim == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def enhance_image(input_image, model_name, outscale, face_enhance):
    """Upscale an uploaded image with the selected Real-ESRGAN model.

    Args:
        input_image: Array from the Gradio ``Image`` component (``type="numpy"``),
            or ``None`` when the user has not uploaded anything.
        model_name: Key of ``MODEL_OPTIONS`` selecting the network.
        outscale: Magnification factor chosen on the UI slider.
        face_enhance: If true, restore faces with GFPGAN v1.3 and use the
            Real-ESRGAN model as the background upsampler.

    Returns:
        The enhanced image as a 3-channel RGB array for display in Gradio.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
    """
    if input_image is None:
        raise gr.Error("请先上传一张图像")

    # Gradio delivers RGB(A)/grayscale; the upsamplers here take OpenCV-style BGR.
    img = _to_bgr(input_image)

    upsampler = load_model(model_name)

    if face_enhance:
        # Imported lazily so gfpgan is only needed when face enhancement is used.
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler
        )
        _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
    else:
        output, _ = upsampler.enhance(img, outscale=outscale)

    # Convert back to RGB for display in Gradio.
    return cv2.cvtColor(output, cv2.COLOR_BGR2RGB)

# Gradio UI: a row holding an options column and a result column. The button
# runs enhance_image and is exposed under api_name "enhance".
with gr.Blocks(title="Real-ESRGAN 图像超分辨率") as app:
    gr.Markdown("## Real-ESRGAN 图像超分辨率系统")
    gr.Markdown("上传图像,选择模型和参数,生成高清图像!")

    with gr.Row():
        # First column: input image and processing options.
        with gr.Column():
            input_image = gr.Image(label="输入图像", type="numpy")
            # Choices come straight from the model registry; defaults to the x4plus model.
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_OPTIONS.keys()),
                label="选择模型",
                value="RealESRGAN_x4plus"
            )
            # Magnification 1x-4x in 0.5 steps, passed to enhance_image as `outscale`.
            outscale = gr.Slider(minimum=1, maximum=4, step=0.5, value=4, label="放大倍数")
            # Toggles the GFPGAN face-restoration path in enhance_image.
            face_enhance = gr.Checkbox(label="启用人脸增强", value=False)
            enhance_btn = gr.Button("开始增强", variant="primary")

        # Second column: the enhanced result.
        with gr.Column():
            output_image = gr.Image(label="增强结果", type="numpy")

    # The order of `inputs` must match enhance_image's positional parameters.
    enhance_btn.click(
        fn=enhance_image,
        inputs=[input_image, model_dropdown, outscale, face_enhance],
        outputs=[output_image],
        api_name="enhance"
    )

if __name__ == "__main__":
    # Bind to all network interfaces on a fixed port so the UI is reachable
    # from outside the host or container.
    launch_options = dict(server_name="0.0.0.0", server_port=7860)
    app.launch(**launch_options)