Spaces:
Running
Running
| ''' | |
| @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) | |
| @author: yangxy ([email protected]) | |
| https://github.com/yangxy/GPEN | |
| @inproceedings{Yang2021GPEN, | |
| title={GAN Prior Embedded Network for Blind Face Restoration in the Wild}, | |
| author={Tao Yang, Peiran Ren, Xuansong Xie, and Lei Zhang}, | |
| booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, | |
| year={2021} | |
| } | |
| © Alibaba, 2021. For academic and non-commercial use only. | |
| ================================================== | |
| slightly modified by Kai Zhang (2021-06-03) | |
| https://github.com/cszn/KAIR | |
| How to run: | |
| step 1: Download <RetinaFace-R50.pth> model and <GPEN-512.pth> model and put them into `model_zoo`. | |
| RetinaFace-R50.pth: https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth | |
| GPEN-512.pth: https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-512.pth | |
| step 2: Install ninja by `pip install ninja`; set <inputdir> for your own testing images | |
| step 3: `python main_test_face_enhancement.py` | |
| ================================================== | |
| ''' | |
| import os | |
| import cv2 | |
| import glob | |
| import numpy as np | |
| import torch | |
| from utils.utils_alignfaces import warp_and_crop_face, get_reference_facial_points | |
| from utils import utils_image as util | |
| from retinaface.retinaface_detection import RetinaFaceDetection | |
| from models.network_faceenhancer import FullGenerator as enhancer_net | |
class faceenhancer(object):
    """Run the GPEN generator on a single (already cropped / aligned) face image.

    The network is built as ``enhancer_net(size, 512, 8, channel_multiplier)``
    and its weights are read from ``model_path``. Inference runs on CUDA when
    available, otherwise on CPU.
    """

    def __init__(self, model_path='model_zoo/GPEN-512.pth', size=512, channel_multiplier=2):
        """
        Args:
            model_path: path to the GPEN state dict.
            size: side length the input face is resized to before inference
                (also passed to ``enhancer_net`` as its resolution).
            channel_multiplier: forwarded to ``enhancer_net``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        self.size = size
        self.model = enhancer_net(self.size, 512, 8, channel_multiplier).to(self.device)
        # map_location makes a checkpoint that was saved from GPU tensors loadable
        # on a CPU-only machine; without it torch.load tries to restore the
        # tensors onto the CUDA device they were saved from.
        state_dict = torch.load(self.model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def process(self, img):
        """Enhance one face image.

        Args:
            img: uint8 RGB image, shape (H, W, 3).

        Returns:
            uint8 RGB image. The input is resized to ``size x size`` before
            inference and the result is NOT resized back, so the output is
            presumably (size, size, 3) as produced by ``util.tensor2uint``.
        """
        img = cv2.resize(img, (self.size, self.size))
        img = util.uint2tensor4(img)
        # Map to [-1, 1] (assumes uint2tensor4 yields values in [0, 1] -- TODO confirm).
        img = (img - 0.5) / 0.5
        img = img.to(self.device)
        with torch.no_grad():
            out, _ = self.model(img)
        # Undo the [-1, 1] normalisation before converting back to uint8.
        out = util.tensor2uint(out * 0.5 + 0.5)
        return out
class faceenhancer_with_detection_alignment(object):
    """Detect faces, enhance each aligned crop with GPEN, and paste them back.

    Pipeline per image: RetinaFace detection -> 5-point alignment crop ->
    ``faceenhancer.process`` -> inverse-warp the enhanced crop into the full
    frame -> feathered blend with the original image.
    """

    def __init__(self, model_path, size=512, channel_multiplier=2,
                 detector_path='model_zoo/RetinaFace-R50.pth', threshold=0.9):
        """
        Args:
            model_path: path to the GPEN state dict.
            size: side length of the aligned face crop fed to the enhancer.
            channel_multiplier: forwarded to ``faceenhancer``.
            detector_path: path to the RetinaFace weights (previously hard-coded).
            threshold: detections with ``faceb[4]`` below this are skipped
                (``faceb[4]`` is presumably the detection confidence).
        """
        self.facedetector = RetinaFaceDetection(detector_path)
        self.faceenhancer = faceenhancer(model_path, size, channel_multiplier)
        self.size = size
        self.threshold = threshold

        # Feathered blending mask: a filled rectangle inset 26 px from each
        # border of a 512x512 canvas, Gaussian-blurred twice so the pasted face
        # fades smoothly into the surrounding image. It is resized to the
        # enhanced crop's size at use time.
        self.mask = np.zeros((512, 512), np.float32)
        cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, cv2.LINE_AA)
        self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11)
        self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11)

        # 3x3 binomial smoothing kernel (weights sum to 1), applied to enhanced
        # crops of small faces before they are pasted back.
        self.kernel = np.array((
            [0.0625, 0.125, 0.0625],
            [0.125, 0.25, 0.125],
            [0.0625, 0.125, 0.0625]), dtype="float32")

        # Reference positions of the 5 landmarks inside the aligned crop.
        default_square = True
        inner_padding_factor = 0.25
        outer_padding = (0, 0)
        self.reference_5pts = get_reference_facial_points(
            (self.size, self.size), inner_padding_factor, outer_padding, default_square)

    def process(self, img):
        """Enhance every detected face in ``img``.

        Args:
            img: uint8 RGB image, shape (H, W, 3).

        Returns:
            (img, orig_faces, enhanced_faces): the blended uint8 RGB image of the
            same size as the input, plus lists of the aligned input crops and
            their (unsmoothed) enhanced versions, one entry per accepted face.
        """
        # The detector is fed a BGR copy; ``img`` itself stays RGB for the whole
        # method. (Previously ``img`` was converted back and forth inside the
        # loop, which swapped channels on every other face and returned a BGR
        # image whenever an even number of faces -- including zero -- was found.)
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        facebs, landms = self.facedetector.detect(img_bgr)

        orig_faces, enhanced_faces = [], []
        height, width = img.shape[:2]
        full_mask = np.zeros((height, width), dtype=np.float32)
        full_img = np.zeros(img.shape, dtype=np.uint8)

        for faceb, facial5points in zip(facebs, landms):
            if faceb[4] < self.threshold:
                continue
            # Box height / width, assuming faceb[:4] is (x1, y1, x2, y2).
            fh, fw = (faceb[3] - faceb[1]), (faceb[2] - faceb[0])

            facial5points = np.reshape(facial5points, (2, 5))

            of, tfm_inv = warp_and_crop_face(img, facial5points, reference_pts=self.reference_5pts,
                                             crop_size=(self.size, self.size))

            # Enhance the face image!
            ef = self.faceenhancer.process(of)

            orig_faces.append(of)
            enhanced_faces.append(ef)

            # cv2.resize takes dsize as (width, height).
            ef_h, ef_w = ef.shape[:2]
            tmp_mask = cv2.resize(self.mask, (ef_w, ef_h))
            tmp_mask = cv2.warpAffine(tmp_mask, tfm_inv, (width, height), flags=cv2.INTER_AREA)

            if min(fh, fw) < 100:  # Gaussian filter for small face
                ef = cv2.filter2D(ef, -1, self.kernel)

            tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=cv2.INTER_AREA)

            # Where faces overlap, keep the pixel from whichever face has the
            # larger mask weight there.
            update = tmp_mask > full_mask
            full_mask[update] = tmp_mask[update]
            full_img[update] = tmp_img[update]

        # Alpha-blend the pasted faces over the original image.
        full_mask = full_mask[:, :, np.newaxis]
        img = cv2.convertScaleAbs(img * (1 - full_mask) + full_img * full_mask)

        return img, orig_faces, enhanced_faces
if __name__ == '__main__':
    # Enhance every image in <inputdir> and write results to <outdir>.
    inputdir = os.path.join('testsets', 'real_faces')
    outdir = os.path.join('testsets', 'real_faces_results')
    os.makedirs(outdir, exist_ok=True)
    model_path = os.path.join('model_zoo', 'GPEN-512.pth')

    # whether to use face detection & alignment or not
    need_face_detection = True

    if need_face_detection:
        enhancer = faceenhancer_with_detection_alignment(model_path=model_path, size=512, channel_multiplier=2)
    else:
        enhancer = faceenhancer(model_path=model_path, size=512, channel_multiplier=2)

    for idx, img_file in enumerate(util.get_image_paths(inputdir)):
        img_name, ext = os.path.splitext(os.path.basename(img_file))
        img_L = util.imread_uint(img_file, n_channels=3)

        print('{:->4d} --> {:<s}'.format(idx+1, img_name+ext))

        # Upsample the input 2x before enhancement.
        img_L = cv2.resize(img_L, (0,0), fx=2, fy=2)

        if need_face_detection:
            # do the enhancement
            img_H, orig_faces, enhanced_faces = enhancer.process(img_L)

            util.imsave(np.hstack((img_L, img_H)), os.path.join(outdir, img_name+'_comparison.png'))
            util.imsave(img_H, os.path.join(outdir, img_name+'_enhanced.png'))
            for m, (ef, of) in enumerate(zip(enhanced_faces, orig_faces)):
                # cv2.resize takes dsize as (width, height), not shape[:2] = (h, w).
                of = cv2.resize(of, (ef.shape[1], ef.shape[0]))
                util.imsave(np.hstack((of, ef)), os.path.join(outdir, img_name+'_face%02d'%m+'.png'))
        else:
            # do the enhancement
            img_H = enhancer.process(img_L)

            util.imsave(img_H, os.path.join(outdir, img_name+'_enhanced_without_detection.png'))