"""Gradio web UI for Real-ESRGAN image super-resolution.

The user uploads an image, picks one of the checkpoints in ``MODEL_OPTIONS``,
chooses an output scale, and can optionally run GFPGAN face restoration on top
of the Real-ESRGAN-upscaled background.
"""

import os
import threading
from functools import lru_cache

import cv2
import gradio as gr
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

# Model registry: name -> network factory, the network's native upscale
# factor ("netscale"), and the URL of its pretrained weights.
MODEL_OPTIONS = {
    "RealESRGAN_x4plus": {
        "model": lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
        "netscale": 4,
        "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    },
    "RealESRNet_x4plus": {
        "model": lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
        "netscale": 4,
        "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
    },
    "RealESRGAN_x4plus_anime_6B": {
        "model": lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4),
        "netscale": 4,
        "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
    },
    "RealESRGAN_x2plus": {
        "model": lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2),
        "netscale": 2,
        "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
    },
    "realesr-animevideov3": {
        "model": lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'),
        "netscale": 4,
        "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
    },
    "realesr-general-x4v3": {
        "model": lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'),
        "netscale": 4,
        "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
    },
}

# Local directory where downloaded Real-ESRGAN weights are kept between runs.
WEIGHTS_DIR = "weights"

# GFPGAN checkpoint used when face enhancement is enabled.
GFPGAN_MODEL_URL = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'

# The upsampler / face-enhancer objects are cached and shared between requests.
# NOTE(review): in the realesrgan/gfpgan packages, enhance() appears to keep
# per-call intermediates on the instance, so a shared instance is treated as
# not thread-safe and all inference is serialized through this lock.
_INFERENCE_LOCK = threading.Lock()


@lru_cache(maxsize=2)
def load_model(model_name):
    """Build a RealESRGANer for ``model_name``, downloading weights if needed.

    The result is cached (at most two models) so repeated requests reuse the
    already-initialized network instead of rebuilding it and reloading the
    weights every time.

    Args:
        model_name: A key of ``MODEL_OPTIONS``.

    Returns:
        A ready-to-use ``RealESRGANer`` instance.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_OPTIONS``.
    """
    model_config = MODEL_OPTIONS[model_name]
    model = model_config["model"]()
    netscale = model_config["netscale"]
    file_url = model_config["file_url"]

    # Reuse previously downloaded weights; otherwise fetch them into WEIGHTS_DIR.
    model_path = os.path.join(WEIGHTS_DIR, f"{model_name}.pth")
    if not os.path.isfile(model_path):
        os.makedirs(WEIGHTS_DIR, exist_ok=True)
        model_path = load_file_from_url(url=file_url, model_dir=WEIGHTS_DIR, progress=True, file_name=None)

    return RealESRGANer(
        scale=netscale,
        model_path=model_path,
        model=model,
        tile=0,  # no tiling: the whole image is processed in one pass
        tile_pad=10,
        pre_pad=0,
        # fp16 is meant for CUDA; many PyTorch builds lack half-precision CPU
        # conv kernels, so only enable it when a GPU is available.
        half=torch.cuda.is_available(),
    )


@lru_cache(maxsize=2)
def _load_face_enhancer(model_name, outscale):
    """Build (and cache) a GFPGANer that uses ``model_name`` as background upsampler."""
    # Imported lazily so GFPGAN is only required when face enhancement is used.
    from gfpgan import GFPGANer

    return GFPGANer(
        model_path=GFPGAN_MODEL_URL,
        upscale=outscale,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=load_model(model_name),
    )


def _to_bgr(image):
    """Convert a Gradio RGB / RGBA / grayscale array to OpenCV channel order."""
    if image.ndim == 2:
        return image  # grayscale: nothing to reorder
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)  # keep the alpha channel
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def _to_rgb(image):
    """Inverse of ``_to_bgr``: convert an OpenCV-ordered result back for Gradio."""
    if image.ndim == 2:
        return image
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def enhance_image(input_image, model_name, outscale, face_enhance):
    """Run super-resolution on an uploaded image.

    Args:
        input_image: Array from ``gr.Image(type="numpy")`` in RGB order (RGBA
            and single-channel arrays are also accepted), or ``None`` when the
            user has not uploaded anything.
        model_name: A key of ``MODEL_OPTIONS``.
        outscale: Final upscale factor relative to the input size.
        face_enhance: If true, restore faces with GFPGAN after upscaling.

    Returns:
        The enhanced image, in the same channel layout that was passed in.

    Raises:
        gr.Error: If no image was uploaded, the model name is unknown, or
            inference raised a RuntimeError (e.g. CUDA out of memory). Gradio
            shows the message to the user.
    """
    if input_image is None:
        raise gr.Error("请先上传输入图像")
    if model_name not in MODEL_OPTIONS:
        # The endpoint is exposed via api_name, so the name may not come from the dropdown.
        raise gr.Error(f"未知模型: {model_name}")

    # Gradio delivers RGB(A); Real-ESRGAN and GFPGAN operate on OpenCV BGR(A).
    img = _to_bgr(input_image)

    try:
        with _INFERENCE_LOCK:
            if face_enhance:
                face_enhancer = _load_face_enhancer(model_name, outscale)
                _, _, output = face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = load_model(model_name).enhance(img, outscale=outscale)
    except RuntimeError as err:
        # Typically CUDA out-of-memory on large inputs, since tile=0 processes
        # the whole image at once.
        raise gr.Error(f"图像增强失败: {err}") from err

    # Convert back to RGB(A) for display in Gradio.
    return _to_rgb(output)


# Gradio UI
with gr.Blocks(title="Real-ESRGAN 图像超分辨率") as app:
    gr.Markdown("## Real-ESRGAN 图像超分辨率系统")
    gr.Markdown("上传图像,选择模型和参数,生成高清图像!")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="输入图像", type="numpy")
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_OPTIONS.keys()),
                label="选择模型",
                value="RealESRGAN_x4plus",
            )
            outscale = gr.Slider(minimum=1, maximum=4, step=0.5, value=4, label="放大倍数")
            face_enhance = gr.Checkbox(label="启用人脸增强", value=False)
            enhance_btn = gr.Button("开始增强", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="增强结果", type="numpy")

    enhance_btn.click(
        fn=enhance_image,
        inputs=[input_image, model_dropdown, outscale, face_enhance],
        outputs=[output_image],
        api_name="enhance",
    )

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)