Spaces:
Running
Running
File size: 5,398 Bytes
9584322 54a84be 6896ca1 a314672 0c9b170 6896ca1 54a84be 9584322 54a84be 9584322 54a84be 9584322 54a84be 9584322 54a84be 9584322 54a84be 9584322 54a84be 9584322 54a84be 9584322 54a84be 9584322 54a84be 9584322 54a84be 9584322 54a84be 9584322 54a84be 976a1fb 54a84be 9584322 54a84be 9584322 54a84be 9584322 54a84be 9584322 54a84be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import os
import cv2
import subprocess
import subprocess
import subprocess
# Fix the import statement in basicsr's ``data/degradations.py``: it imports
# ``rgb_to_grayscale`` from ``torchvision.transforms.functional_tensor``; the
# patch points it at ``torchvision.transforms.functional`` instead
# (presumably because the installed torchvision no longer ships the former).
def patch_basicsr_import(package_dir=None):
    """Rewrite basicsr's outdated ``rgb_to_grayscale`` import in place.

    The package is located through the import system instead of a
    hard-coded, interpreter-specific site-packages path, and the file is
    edited in Python so no external ``sed`` binary is required.

    Args:
        package_dir: Directory of the installed ``basicsr`` package. When
            ``None`` it is looked up with ``importlib.util.find_spec``, which
            does not import the (still unpatched) package.

    Returns:
        ``True`` if ``data/degradations.py`` was rewritten; ``False`` if the
        package or file could not be found or the old import is not present
        (for example because the file has already been patched).
    """
    # Local imports: the file's top-level import block does not provide these.
    import importlib.util
    from pathlib import Path

    old_import = "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
    new_import = "from torchvision.transforms.functional import rgb_to_grayscale"

    if package_dir is None:
        spec = importlib.util.find_spec("basicsr")
        locations = list(spec.submodule_search_locations or []) if spec else []
        if not locations:
            # basicsr is not installed; the import further down will report it.
            return False
        package_dir = locations[0]

    target = Path(package_dir) / "data" / "degradations.py"
    if not target.is_file():
        return False

    source = target.read_text(encoding="utf-8")
    if old_import not in source:
        return False

    target.write_text(source.replace(old_import, new_import), encoding="utf-8")
    return True


# Must run before anything imports basicsr.
patch_basicsr_import()
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
import gradio as gr
# Model configuration.
def _rrdbnet_factory(num_block, scale):
    """Return a zero-argument callable that builds an RRDBNet."""
    def build():
        return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_block, num_grow_ch=32, scale=scale)
    return build


def _srvgg_factory(num_conv):
    """Return a zero-argument callable that builds a 4x SRVGGNetCompact."""
    def build():
        return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
    return build


# Each entry maps a model name to:
#   "model"    - zero-argument factory that instantiates the network,
#   "netscale" - upscaling factor of that network,
#   "file_url" - download location of the pretrained weights.
MODEL_OPTIONS = {
    "RealESRGAN_x4plus": dict(
        model=_rrdbnet_factory(num_block=23, scale=4),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    ),
    "RealESRNet_x4plus": dict(
        model=_rrdbnet_factory(num_block=23, scale=4),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
    ),
    "RealESRGAN_x4plus_anime_6B": dict(
        model=_rrdbnet_factory(num_block=6, scale=4),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
    ),
    "RealESRGAN_x2plus": dict(
        model=_rrdbnet_factory(num_block=23, scale=2),
        netscale=2,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
    ),
    "realesr-animevideov3": dict(
        model=_srvgg_factory(num_conv=16),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
    ),
    "realesr-general-x4v3": dict(
        model=_srvgg_factory(num_conv=32),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
    ),
}
def load_model(model_name):
    """Load and initialise the Real-ESRGAN upsampler for ``model_name``.

    Args:
        model_name: A key of ``MODEL_OPTIONS`` selecting the network
            architecture, its scale and the weights to download.

    Returns:
        A ``RealESRGANer`` wrapping the freshly built network and its weights.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_OPTIONS``.
    """
    model_config = MODEL_OPTIONS[model_name]
    model = model_config["model"]()
    netscale = model_config["netscale"]
    file_url = model_config["file_url"]

    # Download the model weights unless they are already in ./weights.
    model_path = os.path.join("weights", f"{model_name}.pth")
    if not os.path.isfile(model_path):
        os.makedirs("weights", exist_ok=True)
        model_path = load_file_from_url(url=file_url, model_dir="weights", progress=True, file_name=None)

    # Initialise RealESRGANer.
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        model=model,
        tile=0,  # no tiling by default
        tile_pad=10,
        pre_pad=0,
        # Use fp16 only when a CUDA device is present; on a CPU-only host
        # fall back to fp32 instead of forcing half precision.
        half=torch.cuda.is_available(),
    )
    return upsampler
def enhance_image(input_image, model_name, outscale, face_enhance):
    """Run super-resolution on an image uploaded through the Gradio UI.

    Args:
        input_image: The uploaded image as an RGB numpy array, or ``None``
            when nothing has been uploaded.
        model_name: Key of ``MODEL_OPTIONS`` passed to ``load_model``.
        outscale: Final upscaling factor requested by the user.
        face_enhance: When true, run GFPGAN with the Real-ESRGAN upsampler as
            its background upsampler instead of calling the upsampler alone.

    Returns:
        The enhanced image as an RGB numpy array for Gradio to display.

    Raises:
        gr.Error: If no image was provided.
    """
    # Clicking the button without an upload passes None; report that to the
    # user instead of letting cv2.cvtColor fail with an opaque error.
    if input_image is None:
        raise gr.Error("请先上传一张图像。")

    # Convert the Gradio (RGB) image to OpenCV's BGR channel order.
    img = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)

    # Load the model.
    upsampler = load_model(model_name)

    # Optionally enhance faces.
    if face_enhance:
        # Imported lazily so gfpgan is only needed when the option is used.
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler
        )
        _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
    else:
        output, _ = upsampler.enhance(img, outscale=outscale)

    # Convert the result back to RGB for Gradio to display.
    output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
    return output
# Gradio interface: controls in the first column, result in the second.
with gr.Blocks(title="Real-ESRGAN 图像超分辨率", theme="NoCrypt/miku") as app:
    gr.Markdown("## Real-ESRGAN 图像超分辨率系统")
    gr.Markdown("上传图像,选择模型和参数,生成高清图像!")

    with gr.Row():
        # Input column: image, model choice, scale, face option, run button.
        with gr.Column():
            input_image = gr.Image(type="numpy", label="输入图像")
            model_dropdown = gr.Dropdown(
                label="选择模型",
                choices=list(MODEL_OPTIONS),
                value="RealESRGAN_x4plus",
            )
            outscale = gr.Slider(
                label="放大倍数",
                minimum=1,
                maximum=4,
                step=0.5,
                value=4,
            )
            face_enhance = gr.Checkbox(value=False, label="启用人脸增强")
            enhance_btn = gr.Button("开始增强", variant="primary")

        # Output column: the enhanced image.
        with gr.Column():
            output_image = gr.Image(type="numpy", label="增强结果")

    # Run enhance_image on click; also exposed under the API name "enhance".
    click_inputs = [input_image, model_dropdown, outscale, face_enhance]
    click_outputs = [output_image]
    enhance_btn.click(
        enhance_image,
        inputs=click_inputs,
        outputs=click_outputs,
        api_name="enhance",
    )
if __name__ == "__main__":
    # When run as a script, serve the app on all interfaces at port 7860.
    app.launch(server_name="0.0.0.0", server_port=7860)