# Source: jhj0517, commit f73fef2 — "Migrate to official RealESRGAN version"
import os.path
import gradio as gr
import torch
from PIL import Image
import numpy as np
from typing import Optional
from modules.utils.paths import *
from .model_downloader import download_resrgan_model, MODELS_REALESRGAN_URL
class RealESRGANInferencer:
    """Upscale images with a RealESRGAN model, downloading weights on demand.

    Weights are stored in ``model_dir`` as ``<model name>.pth`` and fetched
    from ``MODELS_REALESRGAN_URL`` when missing. Upscaled images are written
    either over the input file or as a new PNG in ``output_dir``.
    """

    def __init__(self,
                 model_dir: str = MODELS_REAL_ESRGAN_DIR,
                 output_dir: str = OUTPUTS_DIR):
        """Set up paths and the compute device; no model is loaded yet.

        Args:
            model_dir: Directory holding (or receiving) ``.pth`` weight files.
            output_dir: Directory for results when not overwriting the input.
        """
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.device = self.get_device()
        self.model = None
        # Scale that ``self.model`` was constructed with; lets ``load_model``
        # detect when a different scale needs a freshly built model.
        self.model_scale = None
        self.up_sampler = None
        self.face_enhancer = None
        self.available_models = list(MODELS_REALESRGAN_URL.keys())
        self.default_model = self.available_models[0]

    def load_model(self,
                   model_name: Optional[str] = None,
                   scale: int = 1,
                   progress: gr.Progress = gr.Progress()):
        """Load model weights, downloading them first if they are not on disk.

        Args:
            model_name: Key of ``MODELS_REALESRGAN_URL``, with or without the
                ``.pth`` suffix. Defaults to the first available model.
            scale: Upscaling factor passed to the ``RealESRGAN`` constructor.
                The model is rebuilt whenever this differs from the scale
                the current model was built with.
            progress: Gradio progress tracker used to report the download.

        Raises:
            KeyError: If the weights file is missing locally and
                ``model_name`` has no entry in ``MODELS_REALESRGAN_URL``.
        """
        if model_name is None:
            model_name = self.default_model
        if not model_name.endswith(".pth"):
            model_name += ".pth"

        model_path = os.path.join(self.model_dir, model_name)
        if not os.path.exists(model_path):
            name = os.path.splitext(model_name)[0]
            if name not in MODELS_REALESRGAN_URL:
                raise KeyError(
                    f"No RealESRGAN model file at {model_path} and no download "
                    f"URL for {name!r}; available: {self.available_models}"
                )
            progress(0, f"Downloading RealESRGAN model to : {model_path}")
            download_resrgan_model(model_path, MODELS_REALESRGAN_URL[name])

        # ``scale`` is fixed when RealESRGAN is constructed, so a model built
        # for one scale must not be reused for a different one.
        # NOTE(review): ``RealESRGAN`` is not among the visible imports;
        # presumably it is provided elsewhere — confirm it resolves.
        if self.model is None or self.model_scale != scale:
            self.model = RealESRGAN(self.device, scale=scale)
            self.model_scale = scale
        self.model.load_weights(model_path=model_path, download=False)

    def restore_image(self,
                      img_path: str,
                      overwrite: bool = True):
        """Upscale the image at ``img_path`` and save the result to disk.

        Loads the default model first if no model has been loaded yet.

        Args:
            img_path: Path of the input image; it is converted to RGB first.
            overwrite: If True, write the result over ``img_path``; otherwise
                save a new PNG under ``self.output_dir``.

        Returns:
            Path of the saved, upscaled image.
        """
        if self.model is None:
            self.load_model()

        # Read the input fully and release the file handle before saving,
        # since with ``overwrite`` the output path is the same file.
        with Image.open(img_path) as src_img:
            img = src_img.convert('RGB')
        sr_img = self.model.predict(img)

        if overwrite:
            output_path = img_path
        else:
            output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
        sr_img.save(output_path)
        return output_path

    @staticmethod
    def get_device():
        """Return the torch device name: "cuda", else "mps", else "cpu"."""
        if torch.cuda.is_available():
            return "cuda"
        # ``torch.backends.mps`` may be missing on older torch builds.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"