jhj0517
Update default model
c311090
raw
history blame
2.24 kB
import os.path
import gradio as gr
import torch
from PIL import Image
import numpy as np
from typing import Optional
from RealESRGAN import RealESRGAN
from modules.utils.paths import *
from .model_downloader import *
class RealESRGANInferencer:
    """Restore/upscale images with a RealESRGAN model.

    The model is created lazily (on the first ``restore_image`` call or an explicit
    ``load_model`` call); missing weight files are downloaded into ``model_dir``.
    """

    def __init__(self,
                 model_dir: str = MODELS_REAL_ESRGAN_DIR,
                 output_dir: str = OUTPUTS_DIR):
        """
        Args:
            model_dir: Directory that holds (or will receive) the ``.pth`` weights.
            output_dir: Directory for restored images when not overwriting the input.
        """
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.device = self.get_device()
        self.model = None
        # Scale that ``self.model`` was constructed with; None until a model exists.
        self.model_scale = None
        self.available_models = list(MODELS_REALESRGAN_URL.keys())
        self.default_model = self.available_models[0]

    def load_model(self,
                   model_name: Optional[str] = None,
                   scale: int = 1,
                   progress: gr.Progress = gr.Progress()):
        """Load RealESRGAN weights into the model, downloading them if missing.

        Args:
            model_name: Model name with or without the ``.pth`` suffix. Defaults to
                ``self.default_model``. If the file is not present in ``model_dir``,
                the name (without suffix) must be a key of ``MODELS_REALESRGAN_URL``.
            scale: Scale factor passed to the ``RealESRGAN`` constructor.
            progress: Gradio progress tracker used to report the download.

        Raises:
            KeyError: If the weights are not on disk and no download URL is known
                for ``model_name``.
        """
        if model_name is None:
            model_name = self.default_model
        if not model_name.endswith(".pth"):
            model_name += ".pth"

        model_path = os.path.join(self.model_dir, model_name)
        if not os.path.exists(model_path):
            name, _ = os.path.splitext(model_name)
            if name not in MODELS_REALESRGAN_URL:
                raise KeyError(
                    f"Unknown RealESRGAN model '{name}'. "
                    f"Available models: {self.available_models}"
                )
            progress(0, f"Downloading RealESRGAN model to : {model_path}")
            download_resrgan_model(model_path, MODELS_REALESRGAN_URL[name])

        # ``scale`` is fixed when ``RealESRGAN`` is constructed, so an existing model
        # built with a different scale must be rebuilt; otherwise the requested
        # scale would be silently ignored.
        if self.model is None or self.model_scale != scale:
            self.model = RealESRGAN(self.device, scale=scale)
            self.model_scale = scale
        self.model.load_weights(model_path=model_path, download=False)

    def restore_image(self,
                      img_path: str,
                      overwrite: bool = True):
        """Run the model on a single image and save the result.

        Loads the default model first if no model has been loaded yet.

        Args:
            img_path: Path of the input image.
            overwrite: If True, write the result back to ``img_path``; otherwise save
                it as a new ``png`` file under ``self.output_dir``.

        Returns:
            The path the restored image was written to.
        """
        if self.model is None:
            self.load_model()

        # Close the source file promptly: with ``overwrite=True`` we write back to
        # the same path below.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        sr_img = self.model.predict(img)

        if overwrite:
            output_path = img_path
        else:
            output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
        sr_img.save(output_path)
        return output_path

    @staticmethod
    def get_device():
        """Return the torch device name to use: "cuda", then "mps", else "cpu"."""
        if torch.cuda.is_available():
            return "cuda"
        # ``torch.backends.mps`` may be absent on older torch builds.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"