# FLAME face model, modified from https://github.com/radekd91/emoca
import torch
import torch.nn as nn
import numpy as np
import pickle
import torch.nn.functional as F
from .lbs import lbs, batch_rodrigues, vertices2landmarks


def to_tensor(array, dtype=torch.float32):
    """Convert ``array`` to a ``torch.Tensor`` of the requested ``dtype``.

    :param array: anything accepted by ``torch.tensor`` or an existing tensor
    :param dtype: target torch dtype
    :return: a new tensor; tensor inputs are copied and detached
    """
    if torch.is_tensor(array):
        # Same result as torch.tensor(tensor) (a detached copy) without the
        # UserWarning that constructor emits for tensor inputs.
        return array.clone().detach().to(dtype=dtype)
    return torch.tensor(array, dtype=dtype)


def to_np(array, dtype=np.float32):
    """Convert ``array`` (densifying scipy sparse matrices) to a numpy array."""
    if 'scipy.sparse' in str(type(array)):
        array = array.todense()
    return np.array(array, dtype=dtype)


class Struct(object):
    """Simple attribute bag: every keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        for key, val in kwargs.items():
            setattr(self, key, val)


def rot_mat_to_euler(rot_mats):
    """Return one Euler angle per rotation matrix in the batch.

    Computes ``atan2(-R[2, 0], sqrt(R[0, 0]^2 + R[1, 0]^2))`` for every
    matrix along the first dimension.
    """
    # Calculates rotation matrix to euler angles
    # Careful for extreme cases of euler angles like [0.0, pi, 0.0]
    sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
                    rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
    return torch.atan2(-rot_mats[:, 2, 0], sy)


class FLAME(nn.Module):
    """
    Given FLAME parameters this class generates a differentiable FLAME function
    which outputs a mesh and 2D/3D facial landmarks
    """

    def __init__(self, config, flame_full=False):
        """Load the FLAME model and landmark embeddings and register buffers.

        :param config: object providing ``flame_model_path``,
            ``flame_lmk_embedding_path``, ``n_shape`` and ``n_exp``
        :param flame_full: when True keep shape components ``[:300]`` and
            expression components ``[300:400]`` instead of the first
            ``config.n_shape`` / ``config.n_exp`` ones
        """
        super(FLAME, self).__init__()
        print("creating the FLAME Decoder")
        # NOTE: pickle can execute arbitrary code while loading; only load
        # model files obtained from a trusted source.
        with open(config.flame_model_path, 'rb') as f:
            # flame_model = Struct(**pickle.load(f, encoding='latin1'))
            ss = pickle.load(f, encoding='latin1')
            flame_model = Struct(**ss)

        self.dtype = torch.float32
        self.register_buffer('faces_tensor',
                             to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long))
        # The vertices of the template model
        self.register_buffer('v_template',
                             to_tensor(to_np(flame_model.v_template), dtype=self.dtype))
        # The shape components and expression
        shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
        if not flame_full:
            shapedirs = torch.cat([shapedirs[:, :, :config.n_shape],
                                   shapedirs[:, :, 300:300 + config.n_exp]], 2)
        else:
            shapedirs = torch.cat([shapedirs[:, :, :300], shapedirs[:, :, 300:400]], 2)
        self.register_buffer('shapedirs', shapedirs)
        # The pose components
        num_pose_basis = flame_model.posedirs.shape[-1]
        posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype))
        #
        # forward() passes self.J_regressor to lbs(), so this buffer must exist.
        self.register_buffer('J_regressor',
                             to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype))
        parents = to_tensor(to_np(flame_model.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer('parents', parents)
        self.register_buffer('lbs_weights',
                             to_tensor(to_np(flame_model.weights), dtype=self.dtype))

        # Fixing Eyeball and neck rotation
        default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False)
        self.register_parameter('eye_pose',
                                nn.Parameter(default_eyball_pose, requires_grad=False))
        default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False)
        self.register_parameter('neck_pose',
                                nn.Parameter(default_neck_pose, requires_grad=False))

        # Static and Dynamic Landmark embeddings for FLAME
        # NOTE: allow_pickle=True unpickles objects; only load trusted files.
        lmk_embeddings = np.load(config.flame_lmk_embedding_path,
                                 allow_pickle=True, encoding='latin1')
        # Indexing with an empty tuple unwraps the 0-d object array.
        lmk_embeddings = lmk_embeddings[()]
        self.register_buffer('lmk_faces_idx',
                             torch.tensor(lmk_embeddings['static_lmk_faces_idx'], dtype=torch.long))
        self.register_buffer('lmk_bary_coords',
                             torch.tensor(lmk_embeddings['static_lmk_bary_coords'], dtype=self.dtype))
        self.register_buffer('dynamic_lmk_faces_idx',
                             torch.tensor(lmk_embeddings['dynamic_lmk_faces_idx'], dtype=torch.long))
        self.register_buffer('dynamic_lmk_bary_coords',
                             torch.tensor(lmk_embeddings['dynamic_lmk_bary_coords'], dtype=self.dtype))
        self.register_buffer('full_lmk_faces_idx',
                             torch.tensor(lmk_embeddings['full_lmk_faces_idx'], dtype=torch.long))
        self.register_buffer('full_lmk_bary_coords',
                             torch.tensor(lmk_embeddings['full_lmk_bary_coords'], dtype=self.dtype))

        # Walk from the neck joint up through its parents until the root (-1).
        neck_kin_chain = []
        NECK_IDX = 1
        curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
        while curr_idx != -1:
            neck_kin_chain.append(curr_idx)
            curr_idx = self.parents[curr_idx]
        self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain))

        # ----------------------------------
        # clip rotation angles within a narrow range if needed
        # Limits are (min, max) pairs in degrees, converted to radians below.
        eye_limits = ((-50, 50), (-50, 50), (-0.1, 0.1))
        neck_limits = ((-90, 90), (-60, 60), (-80, 80))
        jaw_limits = ((-5, 60), (-0.1, 0.1), (-0.1, 0.1))
        global_limits = ((-20, 20), (-90, 90), (-20, 20))
        global_limits = torch.tensor(global_limits).float() / 180 * np.pi
        self.register_buffer('global_limits', global_limits)
        neck_limits = torch.tensor(neck_limits).float() / 180 * np.pi
        self.register_buffer('neck_limits', neck_limits)
        jaw_limits = torch.tensor(jaw_limits).float() / 180 * np.pi
        self.register_buffer('jaw_limits', jaw_limits)
        eye_limits = torch.tensor(eye_limits).float() / 180 * np.pi
        self.register_buffer('eye_limits', eye_limits)

    def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx,
                                          dynamic_lmk_b_coords,
                                          neck_kin_chain, dtype=torch.float32):
        """
        Selects the face contour depending on the relative position of the head
        Input:
            pose: N X full pose
            dynamic_lmk_faces_idx: The list of contour face indexes
            dynamic_lmk_b_coords: The list of contour barycentric weights
            neck_kin_chain: The tree to consider for the relative rotation
            dtype: Data type
        return:
            The contour face indexes and the corresponding barycentric weights
        """
        batch_size = pose.shape[0]

        aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
                                     neck_kin_chain)
        rot_mats = batch_rodrigues(
            aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)

        # Accumulate the rotations along the kinematic chain.
        rel_rot_mat = torch.eye(3, device=pose.device,
                                dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1)
        for idx in range(len(neck_kin_chain)):
            rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)

        # Angle in whole degrees, capped at 39.
        y_rot_angle = torch.round(
            torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
                        max=39)).to(dtype=torch.long)

        # Map negative angles to table rows: below -39 -> 78, else 39 - angle.
        neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
        mask = y_rot_angle.lt(-39).to(dtype=torch.long)
        neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
        y_rot_angle = (neg_mask * neg_vals +
                       (1 - neg_mask) * y_rot_angle)

        dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
                                               0, y_rot_angle)
        dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
                                              0, y_rot_angle)
        return dyn_lmk_faces_idx, dyn_lmk_b_coords

    def _apply_rotation_limit(self, rotation, limit):
        """Squash unbounded values into ``[min, max]`` per axis via ``tanh``.

        :param rotation: tensor broadcastable against shape 1 x 3
        :param limit: tensor of shape 3 x 2 (min, max)
        :return: ``min + (tanh(rotation) + 1) / 2 * (max - min)``
        """
        r_min, r_max = limit[:, 0].view(1, 3), limit[:, 1].view(1, 3)
        diff = r_max - r_min
        return r_min + (torch.tanh(rotation) + 1) / 2 * diff

    def apply_rotation_limits(self, neck=None, jaw=None):
        """
        method to call for applying rotation limits. Don't use _apply_rotation_limit() in other methods as this
        might cause some bugs if we change which poses are affected by rotation limits. For this reason, in this method,
        all affected poses are limited within one function so that if we add more restricted poses, they can just be
        updated here
        :param neck: unconstrained neck values or None
        :param jaw: unconstrained jaw values or None
        :return: the single limited tensor if exactly one argument was given,
            otherwise a list of the limited tensors (neck first, then jaw)
        """
        neck = self._apply_rotation_limit(neck, self.neck_limits) if neck is not None else None
        jaw = self._apply_rotation_limit(jaw, self.jaw_limits) if jaw is not None else None

        ret = [i for i in [neck, jaw] if i is not None]

        return ret[0] if len(ret) == 1 else ret

    def _revert_rotation_limit(self, rotation, limit):
        """
        inverse function of _apply_rotation_limit()
        from rotation angle vector (rodriguez) -> scalars from -inf ... inf
        :param rotation: tensor of shape N x 3
        :param limit: tensor of shape 3 x 2 (min, max)
        :return: tensor of the same shape as ``rotation``
        """
        r_min, r_max = limit[:, 0].view(1, 3), limit[:, 1].view(1, 3)
        diff = r_max - r_min
        rotation = rotation.clone()
        # Keep values 1% of the range away from both limits so atanh stays finite.
        for i in range(3):
            rotation[:, i] = torch.clip(rotation[:, i],
                                        min=r_min[0, i] + diff[0, i] * .01,
                                        max=r_max[0, i] - diff[0, i] * .01)
        return torch.atanh((rotation - r_min) / diff * 2 - 1)

    def revert_rotation_limits(self, neck, jaw):
        """
        inverse function of apply_rotation_limits()
        from rotation angle vector (rodriguez) -> scalars from -inf ... inf
        :param neck: limited neck rotation, N x 3
        :param jaw: limited jaw rotation, N x 3
        :return: tuple ``(neck, jaw)`` of unconstrained values
        """
        neck = self._revert_rotation_limit(neck, self.neck_limits)
        jaw = self._revert_rotation_limit(jaw, self.jaw_limits)
        return neck, jaw

    def get_neutral_joint_rotations(self):
        """Return, per joint group, the unconstrained value that maps to zero.

        For each of ``neck``, ``jaw``, ``global`` and ``eyes`` this solves
        ``r_min + (tanh(x) + 1) / 2 * (r_max - r_min) == 0`` for ``x``.

        :return: dict mapping the group name to a tensor of 3 values
        """
        res = {}
        for name, limit in zip(['neck', 'jaw', 'global', 'eyes'],
                               [self.neck_limits, self.jaw_limits,
                                self.global_limits, self.eye_limits]):
            r_min, r_max = limit[:, 0], limit[:, 1]
            diff = r_max - r_min
            res[name] = torch.atanh(-2 * r_min / diff - 1)
            # assert (r_min + (torch.tanh(res[name]) + 1) / 2 * diff) < 1e-7
        return res

    def _pose2rot(self, pose):
        """Convert axis-angle ``pose`` to rotation matrices via batch_rodrigues.

        The result is viewed as ``[pose.shape[0], 3, 3]``, so ``pose`` must
        hold exactly one 3-vector per batch element.
        """
        rot_mats = batch_rodrigues(
            pose.view(-1, 3), dtype=pose.dtype).view([pose.shape[0], 3, 3])

        return rot_mats

    def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords):
        """
        Calculates landmarks by barycentric interpolation
        Input:
            vertices: torch.tensor NxVx3, dtype = torch.float32
                The tensor of input vertices
            faces: torch.tensor (N*F)x3, dtype = torch.long
                The faces of the mesh
            lmk_faces_idx: torch.tensor N X L, dtype = torch.long
                The tensor with the indices of the faces used to calculate the
                landmarks.
            lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32
                The tensor of barycentric coordinates that are used to interpolate
                the landmarks

        Returns:
            landmarks: torch.tensor NxLx3, dtype = torch.float32
                The coordinates of the landmarks for each mesh in the batch
        """
        # Extract the indices of the vertices for each face
        # NxLx3
        batch_size, num_verts = vertices.shape[:2]
        lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
            1, -1, 3).view(batch_size, lmk_faces_idx.shape[1], -1)

        # Offset the vertex ids so they index into the flattened batch below.
        lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to(
            device=vertices.device) * num_verts

        lmk_vertices = vertices.view(-1, 3)[lmk_faces]
        landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
        return landmarks

    def seletec_3d68(self, vertices):
        """Compute landmarks from ``vertices`` using the ``full_lmk_*`` embedding.

        :param vertices: batch of mesh vertices; the first dimension is the
            batch size
        :return: result of ``vertices2landmarks`` for the full embedding
        """
        landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
                                         self.full_lmk_faces_idx.repeat(vertices.shape[0], 1),
                                         self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1))
        return landmarks3d

    def forward(self, shape_params=None, expression_params=None, pose_params=None,
                eye_pose_params=None, use_rotation_limits=False):
        """
        Input:
            shape_params: N X number of shape parameters (required despite the
                None default: the batch size is read from it)
            expression_params: N X number of expression parameters (required)
            pose_params: N X number of pose parameters (6); the first 3 are
                passed as ``neck`` and the last 3 as ``jaw`` to the rotation
                limits. Defaults to ``self.eye_pose`` expanded to the batch.
            eye_pose_params: N X 6 eye pose; defaults to ``self.eye_pose``
                expanded to the batch
            use_rotation_limits: when True, pose and eye pose are treated as
                unconstrained values and squashed into the registered limits
        return:
            vertices: N X V X 3
            landmarks2d: landmarks from the dynamic contour + static embedding
            landmarks3d: landmarks from the full embedding
            J_transformed: second output of ``lbs``
        """
        batch_size = shape_params.shape[0]
        if pose_params is None:
            # eye_pose (1 x 6, initialised to zeros) is reused as the default
            # for the 6-dimensional pose as well.
            pose_params = self.eye_pose.expand(batch_size, -1)
        if eye_pose_params is None:
            eye_pose_params = self.eye_pose.expand(batch_size, -1)

        betas = torch.cat([shape_params, expression_params], dim=1)

        if use_rotation_limits:
            neck_pose, jaw_pose = self.apply_rotation_limits(neck=pose_params[:, :3],
                                                             jaw=pose_params[:, 3:])
            pose_params = torch.cat([neck_pose, jaw_pose], dim=-1)

            eye_pose_params = torch.cat(
                [self._apply_rotation_limit(eye_pose_params[:, :3], self.eye_limits),
                 self._apply_rotation_limit(eye_pose_params[:, 3:], self.eye_limits)], dim=1)

        # set global rotation to zero
        full_pose = torch.cat(
            [torch.zeros_like(pose_params[:, :3]), pose_params[:, :3],
             pose_params[:, 3:], eye_pose_params], dim=1)
        # full_pose = torch.cat(
        #     [pose_params[:, :3], torch.zeros_like(pose_params[:, :3]), pose_params[:, 3:], eye_pose_params], dim=1)

        template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)

        vertices, J_transformed = lbs(betas, full_pose, template_vertices,
                                      self.shapedirs, self.posedirs,
                                      self.J_regressor, self.parents,
                                      self.lbs_weights, dtype=self.dtype)

        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)

        dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
            full_pose, self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.neck_kin_chain, dtype=self.dtype)
        # Dynamic contour landmarks come first, followed by the static ones.
        lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1)

        landmarks2d = vertices2landmarks(vertices, self.faces_tensor,
                                         lmk_faces_idx,
                                         lmk_bary_coords)
        bz = vertices.shape[0]
        landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
                                         self.full_lmk_faces_idx.repeat(bz, 1),
                                         self.full_lmk_bary_coords.repeat(bz, 1, 1))

        return vertices, landmarks2d, landmarks3d, J_transformed


class FLAMETex(nn.Module):
    """
    current FLAME texture:
    https://github.com/TimoBolkart/TF_FLAME/blob/ade0ab152300ec5f0e8555d6765411555c5ed43d/sample_texture.py#L64
    tex_path: '/ps/scratch/yfeng/Data/FLAME/texture/albedoModel2020_FLAME_albedoPart.npz'
    ## adapted from BFM
    tex_path: '/ps/scratch/yfeng/Data/FLAME/texture/FLAME_albedo_from_BFM.npz'
    """

    def __init__(self, config):
        """Load the linear texture space selected by ``config.tex_type``.

        :param config: object providing ``tex_type`` ('BFM' or 'FLAME'),
            ``n_tex`` and ``tex_path`` (BFM) or ``flame_tex_path`` (FLAME)
        :raises ValueError: if ``config.tex_type`` is not a supported type
        """
        super(FLAMETex, self).__init__()
        if config.tex_type == 'BFM':
            mu_key = 'MU'
            pc_key = 'PC'
            n_pc = 199
            tex_path = config.tex_path
            tex_space = np.load(tex_path)
            texture_mean = tex_space[mu_key].reshape(1, -1)
            texture_basis = tex_space[pc_key].reshape(-1, n_pc)

        elif config.tex_type == 'FLAME':
            mu_key = 'mean'
            pc_key = 'tex_dir'
            n_pc = 200
            tex_path = config.flame_tex_path
            tex_space = np.load(tex_path)
            texture_mean = tex_space[mu_key].reshape(1, -1) / 255.
            texture_basis = tex_space[pc_key].reshape(-1, n_pc) / 255.
        else:
            # Raise instead of print + exit(): exit() would terminate the
            # whole process with status 0 from inside library code.
            raise ValueError(
                "texture type {!r} not exist! expected 'BFM' or 'FLAME'".format(config.tex_type))

        n_tex = config.n_tex
        texture_mean = torch.from_numpy(texture_mean).float()[None, ...]
        texture_basis = torch.from_numpy(texture_basis[:, :n_tex]).float()[None, ...]
        self.register_buffer('texture_mean', texture_mean)
        self.register_buffer('texture_basis', texture_basis)

    def forward(self, texcode):
        """Decode texture codes into 3 x 256 x 256 texture images.

        :param texcode: N x n_tex texture coefficients
        :return: N x 3 x 256 x 256 tensor; the channel order of the decoded
            texture is reversed (presumably BGR <-> RGB -- confirm)
        """
        texture = self.texture_mean + (self.texture_basis * texcode[:, None, :]).sum(-1)
        texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0, 3, 1, 2)
        texture = F.interpolate(texture, [256, 256])
        texture = texture[:, [2, 1, 0], :, :]
        return texture


class FLAMETex_trainable(nn.Module):
    """
    current FLAME texture:
    https://github.com/TimoBolkart/TF_FLAME/blob/ade0ab152300ec5f0e8555d6765411555c5ed43d/sample_texture.py#L64
    tex_path: '/ps/scratch/yfeng/Data/FLAME/texture/albedoModel2020_FLAME_albedoPart.npz'
    ## adapted from BFM
    tex_path: '/ps/scratch/yfeng/Data/FLAME/texture/FLAME_albedo_from_BFM.npz'
    """

    def __init__(self, config):
        """Load the texture space and add a trainable correction to its basis.

        :param config: object providing ``tex_path`` (npz file with 'MU' and
            'PC', optionally 'specMU' and 'specPC') and ``tex_params`` (number
            of components to keep)
        """
        super(FLAMETex_trainable, self).__init__()
        tex_params = config.tex_params
        texture_model = np.load(config.tex_path)

        num_tex_pc = texture_model['PC'].shape[-1]

        MU = torch.from_numpy(np.reshape(texture_model['MU'], (1, -1))).float()[None, ...]
        PC = torch.from_numpy(
            np.reshape(texture_model['PC'], (-1, num_tex_pc))[:, :tex_params]).float()[None, ...]
        self.register_buffer('MU', MU)
        self.register_buffer('PC', PC)

        if 'specMU' in texture_model.files:
            specMU = torch.from_numpy(
                np.reshape(texture_model['specMU'], (1, -1))).float()[None, ...]
            specPC = torch.from_numpy(
                np.reshape(texture_model['specPC'], (-1, num_tex_pc)))[:, :tex_params].float()[None, ...]
            self.register_buffer('specMU', specMU)
            self.register_buffer('specPC', specPC)
            self.isspec = True
        else:
            self.isspec = False

        # Learnable additive correction to the basis, initialised to zero.
        self.register_parameter('PC_correction', nn.Parameter(torch.zeros_like(PC)))

    def forward(self, texcode):
        """Decode texture codes into 3 x 256 x 256 texture images.

        :param texcode: N x tex_params texture coefficients
        :return: N x 3 x 256 x 256 tensor (diffuse plus specular albedo when
            the model file provides specular components); the channel order
            of the decoded texture is reversed
        """
        diff_albedo = self.MU + (self.PC * texcode[:, None, :]).sum(-1) + (
            self.PC_correction * texcode[:, None, :]).sum(-1)
        if self.isspec:
            spec_albedo = self.specMU + (self.specPC * texcode[:, None, :]).sum(-1)
            texture = (diff_albedo + spec_albedo)
            # torch.pow(0.6*(diff_albedo + spec_albedo), 1.0/2.2)
        else:
            texture = diff_albedo
        texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0, 3, 1, 2)
        texture = F.interpolate(texture, [256, 256])
        texture = texture[:, [2, 1, 0], :, :]
        return texture