# gavinyuan
# delete: GPEN example images
# 523fb10
"""
@paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
@author: yangxy ([email protected])
"""
import torch
import os
import cv2
import glob
import numpy as np
from torch import nn
import torch.nn.functional as F
from torchvision import transforms, utils
from gpen_model import FullGenerator, FullGenerator_SR
class FaceGAN(object):
    """Single-image inference wrapper around a GPEN generator.

    Builds ``FullGenerator`` when ``in_size == out_size`` and
    ``FullGenerator_SR`` otherwise. Loads its weights from
    ``<base_dir>/weights/<model>.pth`` and converts between OpenCV-style BGR
    NumPy images and the RGB tensors fed to the generator.
    """

    def __init__(
        self,
        base_dir="./",
        in_size=512,
        out_size=512,
        model=None,
        channel_multiplier=2,
        narrow=1,
        key=None,
        is_norm=True,
        device="cuda",
    ):
        """Build the generator and load its weights.

        Args:
            base_dir: Directory that contains the ``weights`` sub-folder.
            in_size: Side length the input image is resized to before
                inference. Also the generator's input resolution.
            out_size: Output resolution given to the generator. When it
                differs from ``in_size``, ``FullGenerator_SR`` is used.
            model: Checkpoint base name, without the ``.pth`` suffix.
                Required despite the ``None`` default.
            channel_multiplier: Forwarded to the generator constructor.
            narrow: Forwarded to the generator constructor as ``narrow=``.
            key: Optional key that selects the state dict inside the loaded
                checkpoint object.
            is_norm: If True, pixels are mapped from [0, 1] to [-1, 1] before
                inference and back afterwards.
            device: Torch device for the model and the input tensors.

        Raises:
            TypeError: If ``model`` is None.
        """
        if model is None:
            # Same exception type that ``None + ".pth"`` used to raise, but
            # with a message that names the actual problem.
            raise TypeError("FaceGAN requires a checkpoint name via `model`")
        self.mfile = os.path.join(base_dir, "weights", model + ".pth")
        self.n_mlp = 8
        self.device = device
        self.is_norm = is_norm
        self.in_resolution = in_size
        self.out_resolution = out_size
        self.key = key
        self.load_model(channel_multiplier, narrow)

    def load_model(self, channel_multiplier=2, narrow=1):
        """Instantiate the generator, load ``self.mfile`` and switch to eval mode.

        Weights are first loaded onto the CPU and then moved to
        ``self.device`` together with the model.
        """
        # NOTE(review): 512 is passed positionally to both generators;
        # presumably the style/latent dimension — confirm in gpen_model.
        if self.in_resolution == self.out_resolution:
            self.model = FullGenerator(
                self.in_resolution,
                512,
                self.n_mlp,
                channel_multiplier,
                narrow=narrow,
                device=self.device,
            )
        else:
            self.model = FullGenerator_SR(
                self.in_resolution,
                self.out_resolution,
                512,
                self.n_mlp,
                channel_multiplier,
                narrow=narrow,
                device=self.device,
            )
        # NOTE(review): torch.load unpickles the file and can execute
        # arbitrary code; only load trusted checkpoints. Consider
        # ``weights_only=True`` if the installed torch version supports it.
        pretrained_dict = torch.load(self.mfile, map_location=torch.device("cpu"))
        if self.key is not None:
            # Some checkpoints wrap the state dict inside a larger object.
            pretrained_dict = pretrained_dict[self.key]
        self.model.load_state_dict(pretrained_dict)
        self.model.to(self.device)
        self.model.eval()

    def process(self, img):
        """Restore a single image.

        Args:
            img: BGR NumPy image. Assumed to be H x W x 3 with values in
                0-255 — confirm with callers.

        Returns:
            The first output of the generator, converted back to a BGR NumPy
            array by ``tensor2img`` (uint8 by default).
        """
        img = cv2.resize(img, (self.in_resolution, self.in_resolution))
        img_t = self.img2tensor(img)
        with torch.no_grad():
            # The generator returns a pair; only the image is used here.
            out, _ = self.model(img_t)
        return self.tensor2img(out)

    def img2tensor(self, img):
        """Convert an H x W x C BGR NumPy image into a 1 x C x H x W RGB tensor.

        Values are divided by 255, so a 0-255 input range is assumed. They
        are then mapped to [-1, 1] when ``self.is_norm`` is set. The result
        is always float32 and lives on ``self.device``.
        """
        # torch.from_numpy rejects negatively-strided views such as
        # img[..., ::-1]. ascontiguousarray makes a copy only when needed.
        img = np.ascontiguousarray(img)
        # Cast to float32 explicitly. A float64 array would otherwise give a
        # float64 tensor that does not match the model's weights.
        img_t = torch.from_numpy(img).to(self.device).float() / 255.0
        if self.is_norm:
            img_t = (img_t - 0.5) / 0.5
        img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1)  # BGR->RGB
        return img_t

    def tensor2img(self, img_t, pmax=255.0, imtype=np.uint8):
        """Convert a 1 x C x H x W RGB tensor into an H x W x C BGR NumPy image.

        Args:
            img_t: Tensor in [-1, 1] if ``self.is_norm`` is set, else in
                [0, 1]. Values outside the range are clipped.
            pmax: Value that 1.0 maps to.
            imtype: Output dtype. Integer dtypes are rounded to nearest.

        Returns:
            NumPy array of dtype ``imtype``.
        """
        if self.is_norm:
            img_t = img_t * 0.5 + 0.5
        img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2)  # RGB->BGR
        # detach() so that tensors which require grad can still go to NumPy.
        img_np = np.clip(img_t.detach().float().cpu().numpy(), 0, 1) * pmax
        if np.issubdtype(np.dtype(imtype), np.integer):
            # Round instead of truncating; truncation biases every pixel
            # down by up to one level.
            img_np = np.rint(img_np)
        return img_np.astype(imtype)