"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch """ import numpy as np import torch from .base_model import BaseModel from . import networks from .cropping import align_img from .bfm import ParametricFaceModel from .util.pytorch3d import MeshRenderer import trimesh class FaceReconModel(BaseModel): @staticmethod def modify_commandline_options(parser, is_train=True): """ Configures options specific for CUT model """ # net structure and parameters parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure') parser.add_argument('--init_path', type=str, default='checkpoints/init_model/resnet50-0676ba61.pth') opt, _ = parser.parse_known_args() parser.set_defaults( focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15. ) return parser def __init__(self, opt): """Initialize this model class. Parameters: opt -- training/test options A few things can be done here. - (required) call the initialization function of BaseModel - define loss function, visualization images, model names, and optimizers """ BaseModel.__init__(self, opt) # call the initialization method of BaseModel self.model_names = ['net_recon'] self.parallel_names = self.model_names + ['renderer'] self.net_recon = networks.define_net_recon( net_recon=opt.net_recon, use_last_fc=opt.use_last_fc, init_path=opt.init_path ) self.facemodel = ParametricFaceModel( bfm_folder=opt.bfm_folder, is_train=self.isTrain) fov = 2 * np.arctan(112 / 1015) * 180 / np.pi self.renderer = MeshRenderer( rasterize_fov=fov, znear=0.1, zfar=50, rasterize_size=int(2 * 112) ) # Our program will automatically call to define schedulers, load networks, and print networks def preproces_img(self, im, lm, to_tensor=True): # to RGB stand_index = np.array([96, 97, 54, 76, 82]) W,H = im.size lm = lm.reshape([-1, 2]) lm = lm[stand_index,:] lm[:, -1] = H - 1 - lm[:, -1] trans_params, im, lm, _ = align_img(im, lm) if to_tensor: im = torch.tensor(np.array(im)/255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) lm = torch.tensor(lm).unsqueeze(0) return im, lm, trans_params def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input: a dictionary that contains the data itself and its metadata information. """ input_img = input['imgs'] lmdks = input['lms'] align_img, align_lmdks, trans_params = self.preproces_img(input_img, lmdks) align_img = align_img.to(self.device) return align_img, trans_params def split_coeff(self, coeffs): """ Return: coeffs_dict -- a dict of torch.tensors Parameters: coeffs -- torch.tensor, size (B, 256) """ id_coeffs = coeffs[:, :80] exp_coeffs = coeffs[:, 80: 144] tex_coeffs = coeffs[:, 144: 224] angles = coeffs[:, 224: 227] gammas = coeffs[:, 227: 254] translations = coeffs[:, 254:] return { 'id': id_coeffs, 'exp': exp_coeffs, 'tex': tex_coeffs, 'angle': angles, 'gamma': gammas, 'trans': translations } def optimize_parameters(self): return None def forward(self, input): self.input_img, trans_params = self.set_input(input) output_coeff = self.net_recon(self.input_img) pred_coeffs_dict = self.split_coeff(output_coeff) pred_coeffs = {key:pred_coeffs_dict[key].cpu().numpy() for key in pred_coeffs_dict} self.facemodel.to(self.device) self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \ self.facemodel.compute_for_render(output_coeff) self.pred_mask, _, self.pred_face = self.renderer( self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color) return pred_coeffs, trans_params def save_mesh(self, name): recon_shape = self.pred_vertex # get reconstructed shape recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space recon_shape = recon_shape.cpu().numpy()[0] recon_color = self.pred_color recon_color = recon_color.cpu().numpy()[0] tri = self.facemodel.face_buf.cpu().numpy() mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8)) # mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri) mesh.export(name) def compute_visuals(self): with torch.no_grad(): input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy() output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy() output_vis_numpy = np.concatenate((input_img_numpy, output_vis_numpy_raw), axis=-2) output_vis_numpy = np.clip(output_vis_numpy, 0, 255) # self.output_vis = torch.tensor( # output_vis_numpy / 255., dtype=torch.float32 # ).permute(0, 3, 1, 2).to(self.device) self.output_vis = output_vis_numpy