| """This script defines the face reconstruction model for Deep3DFaceRecon_pytorch | |
| """ | |
| import numpy as np | |
| import torch | |
| from .base_model import BaseModel | |
| from . import networks | |
| from .cropping import align_img | |
| from .bfm import ParametricFaceModel | |
| from .util.pytorch3d import MeshRenderer | |
| import trimesh | |

class FaceReconModel(BaseModel):

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new model-specific options and set default values for existing options."""
        # net structure and parameters
        parser.add_argument('--net_recon', type=str, default='resnet50',
                            choices=['resnet18', 'resnet34', 'resnet50'], help='network structure')
        parser.add_argument('--init_path', type=str,
                            default='checkpoints/init_model/resnet50-0676ba61.pth')
        parser.set_defaults(
            focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15.
        )
        return parser
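
    # A note on the camera defaults above: center=112 is half the 224x224 crop
    # size this model works on (the renderer rasterizes at 2 * center), and
    # camera_d=10 is the camera-to-origin distance that save_mesh later inverts
    # when mapping depth from camera space back to world space.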

    def __init__(self, opt):
        """Initialize this model class.

        Parameters:
            opt -- training/test options

        A few things can be done here:
        - (required) call the initialization function of BaseModel
        - define loss functions, visualization images, model names, and optimizers
        """
        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel

        self.model_names = ['net_recon']
        self.parallel_names = self.model_names + ['renderer']

        self.net_recon = networks.define_net_recon(
            net_recon=opt.net_recon, use_last_fc=opt.use_last_fc, init_path=opt.init_path
        )

        self.facemodel = ParametricFaceModel(
            bfm_folder=opt.bfm_folder, is_train=self.isTrain)

        # Field of view of the perspective camera, derived from the principal point
        # (opt.center) and focal length (opt.focal) rather than the hard-coded
        # 112 / 1015, so the renderer stays consistent with the parser defaults.
        fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
        self.renderer = MeshRenderer(
            rasterize_fov=fov, znear=0.1, zfar=50, rasterize_size=int(2 * opt.center)
        )
        # Our program will automatically call <model.setup> to define schedulers,
        # load networks, and print networks.
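        # Worked example with the defaults (a sketch, assuming focal=1015 and
        # center=112): fov = 2 * atan(112 / 1015) * 180 / pi, which is about
        # 12.6 degrees: a narrow field of view suited to tight 224x224 face crops.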

    def preprocess_img(self, im, lm, to_tensor=True):
        """Align an RGB PIL image using five landmarks picked from a dense layout."""
        # Indices of the five alignment landmarks (pupils, nose tip, mouth
        # corners) within a 98-point (WFLW-style) landmark layout.
        stand_index = np.array([96, 97, 54, 76, 82])
        _, H = im.size  # PIL size is (width, height); only the height is needed
        lm = lm.reshape([-1, 2])
        lm = lm[stand_index, :]
        lm[:, -1] = H - 1 - lm[:, -1]  # flip y: image coordinates -> Cartesian
        trans_params, im, lm, _ = align_img(im, lm)
        if to_tensor:
            im = torch.tensor(np.array(im) / 255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
            lm = torch.tensor(lm).unsqueeze(0)
        return im, lm, trans_params
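
    # Usage sketch (hedged; `detector` is a hypothetical dense landmark detector,
    # not part of this file):
    #   im = Image.open('face.jpg').convert('RGB')    # PIL image
    #   lm = detector(im)                             # ndarray of (x, y) image coords
    #   im_t, lm_t, trans_params = model.preprocess_img(im, lm)
    #   im_t.shape                                    # torch.Size([1, 3, H, W]), values in [0, 1]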

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input -- a dictionary that contains the data itself and its metadata information.
        """
        input_img = input['imgs']
        lmdks = input['lms']
        aligned_img, aligned_lmdks, trans_params = self.preprocess_img(input_img, lmdks)
        aligned_img = aligned_img.to(self.device)
        return aligned_img, trans_params
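
    # Note: unlike the usual BaseModel contract, in which set_input only caches
    # tensors on `self`, this variant returns the aligned image and alignment
    # parameters so that forward can hand trans_params back to its caller.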

    def split_coeff(self, coeffs):
        """
        Return:
            coeffs_dict -- a dict of torch.tensors

        Parameters:
            coeffs -- torch.tensor, size (B, 257)
        """
        id_coeffs = coeffs[:, :80]       # identity (shape) coefficients
        exp_coeffs = coeffs[:, 80:144]   # expression coefficients
        tex_coeffs = coeffs[:, 144:224]  # texture (albedo) coefficients
        angles = coeffs[:, 224:227]      # rotation as Euler angles
        gammas = coeffs[:, 227:254]      # spherical-harmonics lighting (9 x 3 channels)
        translations = coeffs[:, 254:]   # 3D translation
        return {
            'id': id_coeffs,
            'exp': exp_coeffs,
            'tex': tex_coeffs,
            'angle': angles,
            'gamma': gammas,
            'trans': translations
        }
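
    # Shape check (a sketch; `model` is a constructed FaceReconModel):
    #   parts = model.split_coeff(torch.zeros(1, 257))
    #   {k: v.shape[-1] for k, v in parts.items()}
    #   # -> {'id': 80, 'exp': 64, 'tex': 80, 'angle': 3, 'gamma': 27, 'trans': 3}
    # The slice widths sum to 80 + 64 + 80 + 3 + 27 + 3 = 257.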

    def optimize_parameters(self):
        # No-op: this model is used for inference only, so there is nothing to optimize.
        return None

    def forward(self, input):
        self.input_img, trans_params = self.set_input(input)
        output_coeff = self.net_recon(self.input_img)  # regress the 257-dim coefficient vector
        pred_coeffs_dict = self.split_coeff(output_coeff)
        pred_coeffs = {key: pred_coeffs_dict[key].detach().cpu().numpy()
                       for key in pred_coeffs_dict}
        self.facemodel.to(self.device)
        # Decode the coefficients into per-vertex geometry, texture, shaded color
        # and 2D landmarks.
        self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \
            self.facemodel.compute_for_render(output_coeff)
        # Rasterize the predicted mesh into a face mask and a rendered face image.
        self.pred_mask, _, self.pred_face = self.renderer(
            self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color)
        return pred_coeffs, trans_params
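
    # Note: forward caches pred_vertex, pred_color, pred_mask and pred_face on
    # `self`; save_mesh and compute_visuals below read those attributes, so run
    # forward once before calling either of them.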

    def save_mesh(self, name):
        recon_shape = self.pred_vertex  # reconstructed shape in camera space
        # From camera space to world space: the camera sits at z = camera_d = 10
        # (see the defaults above), so mirror depth about the camera position.
        recon_shape[..., -1] = 10 - recon_shape[..., -1]
        recon_shape = recon_shape.detach().cpu().numpy()[0]
        recon_color = self.pred_color.detach().cpu().numpy()[0]
        tri = self.facemodel.face_buf.cpu().numpy()
        mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri,
                               vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8))
        mesh.export(name)
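
    # Usage: model.save_mesh('face.obj') writes a vertex-colored mesh; trimesh
    # infers the export format (.obj, .ply, ...) from the file extension.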

    def compute_visuals(self):
        with torch.no_grad():
            input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()
            # Composite the rendered face over the input using the rasterized mask.
            output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img
            output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()
            # Put the input and the reconstruction side by side (concatenate along width).
            output_vis_numpy = np.concatenate((input_img_numpy, output_vis_numpy_raw), axis=-2)
            self.output_vis = np.clip(output_vis_numpy, 0, 255)
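

# Minimal end-to-end sketch (hedged): `TestOptions` and the checkpoint layout are
# assumed to follow the Deep3DFaceRecon_pytorch conventions, and `detect_landmarks`
# is a hypothetical dense landmark detector, not provided by this file.
#
#   from PIL import Image
#
#   opt = TestOptions().parse()                  # assumption: the repo's option parser
#   model = FaceReconModel(opt)
#   model.setup(opt)                             # loads networks, per BaseModel
#
#   im = Image.open('face.jpg').convert('RGB')
#   lm = detect_landmarks(im)                    # hypothetical, dense (x, y) landmarks
#   with torch.no_grad():
#       coeffs, trans_params = model.forward({'imgs': im, 'lms': lm})
#   model.compute_visuals()                      # side-by-side numpy visualization
#   model.save_mesh('face_mesh.obj')             # vertex-colored mesh via trimesh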