# from metaface_fitting 20221122
import torch
from torch import nn
import numpy as np

from pytorch3d.structures import Meshes
# from pytorch3d.renderer import TexturesVertex
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
    PointLights,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesVertex,
    blending
)
from pytorch3d.loss import (
    # mesh_edge_loss,
    mesh_laplacian_smoothing,
    # mesh_normal_consistency,
)


class FaceVerseModel(nn.Module):
    def __init__(self, model_dict, batch_size=1, device='cuda:0', expr_52=True, **kargs):
        super(FaceVerseModel, self).__init__()

        self.batch_size = batch_size
        self.device = torch.device(device)

        self.rotXYZ = torch.eye(3).view(1, 3, 3).repeat(3, 1, 1).view(3, 1, 3, 3).to(self.device)

        self.renderer = ModelRenderer(device, **kargs)

        self.kp_inds = torch.tensor(model_dict['mediapipe_keypoints'].reshape(-1, 1), requires_grad=False).squeeze().long().to(self.device)
        self.ver_inds = model_dict['ver_inds']
        self.tri_inds = model_dict['tri_inds']

        meanshape = torch.tensor(model_dict['meanshape'].reshape(-1, 3), dtype=torch.float32, requires_grad=False, device=self.device)
        meanshape[:, [1, 2]] *= -1
        meanshape = meanshape * 0.1
        meanshape[:, 1] += 1
        self.meanshape = meanshape.reshape(1, -1)
        self.meantex = torch.tensor(model_dict['meantex'].reshape(1, -1), dtype=torch.float32, requires_grad=False, device=self.device)

        idBase = torch.tensor(model_dict['idBase'].reshape(-1, 3, 150), dtype=torch.float32, requires_grad=False, device=self.device)
        idBase[:, [1, 2]] *= -1
        self.idBase = (idBase * 0.1).reshape(-1, 150)

        self.expr_52 = expr_52
        if expr_52:
            expBase = torch.tensor(np.load('metamodel/v3/exBase_52.npy').reshape(-1, 3, 52), dtype=torch.float32, requires_grad=False, device=self.device)
        else:
            expBase = torch.tensor(model_dict['exBase'].reshape(-1, 3, 171), dtype=torch.float32, requires_grad=False, device=self.device)
        expBase[:, [1, 2]] *= -1
        # Flatten to (3N, exp_dims); the dim must match the loaded basis
        # (52 when expr_52, otherwise 171).
        self.expBase = (expBase * 0.1).reshape(-1, 52 if expr_52 else 171)

        self.texBase = torch.tensor(model_dict['texBase'], dtype=torch.float32, requires_grad=False, device=self.device)

        self.l_eyescale = model_dict['left_eye_exp']
        self.r_eyescale = model_dict['right_eye_exp']

        self.uv = torch.tensor(model_dict['uv'], dtype=torch.float32, requires_grad=False, device=self.device)
        self.tri = torch.tensor(model_dict['tri'], dtype=torch.int64, requires_grad=False, device=self.device)
        self.tri_uv = torch.tensor(model_dict['tri_uv'], dtype=torch.int64, requires_grad=False, device=self.device)
        self.point_buf = torch.tensor(model_dict['point_buf'], dtype=torch.int64, requires_grad=False, device=self.device)

        self.num_vertex = self.meanshape.shape[1] // 3
        self.id_dims = self.idBase.shape[1]
        self.tex_dims = self.texBase.shape[1]
        self.exp_dims = self.expBase.shape[1]
        self.all_dims = self.id_dims + self.tex_dims + self.exp_dims
        self.init_coeff_tensors()

        # for tracking by landmarks
        self.kp_inds_view = torch.cat([self.kp_inds[:, None] * 3, self.kp_inds[:, None] * 3 + 1, self.kp_inds[:, None] * 3 + 2], dim=1).flatten()
        self.idBase_view = self.idBase[self.kp_inds_view, :].detach().clone()
        self.expBase_view = self.expBase[self.kp_inds_view, :].detach().clone()
        self.meanshape_view = self.meanshape[:, self.kp_inds_view].detach().clone()

        self.identity = torch.eye(3, dtype=torch.float32, device=self.device)
        self.point_shift = torch.nn.Parameter(torch.zeros(self.num_vertex, 3, dtype=torch.float32, device=self.device))  # [N, 3]
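
    # Construction sketch (hypothetical checkpoint path; the dict must provide
    # the keys read above, e.g. 'meanshape', 'idBase', 'exBase', 'texBase',
    # 'tri', 'point_buf', 'mediapipe_keypoints'):
    #   model_dict = np.load('metamodel/v3/faceverse_v3.npy', allow_pickle=True).item()
    #   fvm = FaceVerseModel(model_dict, batch_size=1, device='cuda:0', expr_52=True)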

    def set_renderer(self, intr=None, img_size=256, cam_dist=10., render_depth=False, rasterize_blur_radius=0.):
        self.renderer = ModelRenderer(self.device, intr, img_size, cam_dist, render_depth, rasterize_blur_radius)

    def init_coeff_tensors(self, id_coeff=None, tex_coeff=None, exp_coeff=None, gamma_coeff=None, trans_coeff=None, rot_coeff=None, scale_coeff=None, eye_coeff=None):
        if id_coeff is None:
            self.id_tensor = torch.zeros((1, self.id_dims), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            assert id_coeff.shape == (1, self.id_dims)
            self.id_tensor = id_coeff.clone().detach().requires_grad_(True)

        if tex_coeff is None:
            self.tex_tensor = torch.zeros((1, self.tex_dims), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            assert tex_coeff.shape == (1, self.tex_dims)
            self.tex_tensor = tex_coeff.clone().detach().requires_grad_(True)

        if exp_coeff is None:
            self.exp_tensor = torch.zeros((self.batch_size, self.exp_dims), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            assert exp_coeff.shape == (1, self.exp_dims)
            self.exp_tensor = exp_coeff.clone().detach().requires_grad_(True)

        if gamma_coeff is None:
            self.gamma_tensor = torch.zeros((self.batch_size, 27), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            self.gamma_tensor = gamma_coeff.clone().detach().requires_grad_(True)

        if trans_coeff is None:
            self.trans_tensor = torch.zeros((self.batch_size, 3), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            self.trans_tensor = trans_coeff.clone().detach().requires_grad_(True)

        if scale_coeff is None:
            self.scale_tensor = 1.0 * torch.ones((self.batch_size, 1), dtype=torch.float32, device=self.device)
            self.scale_tensor.requires_grad_(True)
        else:
            self.scale_tensor = scale_coeff.clone().detach().requires_grad_(True)

        if rot_coeff is None:
            self.rot_tensor = torch.zeros((self.batch_size, 3), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            self.rot_tensor = rot_coeff.clone().detach().requires_grad_(True)

        if eye_coeff is None:
            self.eye_tensor = torch.zeros((self.batch_size, 4), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            self.eye_tensor = eye_coeff.clone().detach().requires_grad_(True)

    def get_lms(self, vs):
        lms = vs[:, self.kp_inds, :]
        return lms

    def split_coeffs(self, coeffs):
        id_coeff = coeffs[:, :self.id_dims]  # identity (shape) coeff
        exp_coeff = coeffs[:, self.id_dims:self.id_dims + self.exp_dims]  # expression coeff
        tex_coeff = coeffs[:, self.id_dims + self.exp_dims:self.all_dims]  # texture (albedo) coeff
        angles = coeffs[:, self.all_dims:self.all_dims + 3]  # Euler angles (x, y, z) for rotation, dim 3
        gamma = coeffs[:, self.all_dims + 3:self.all_dims + 30]  # lighting coeff for 3-channel SH function, dim 27
        translation = coeffs[:, self.all_dims + 30:self.all_dims + 33]  # translation coeff, dim 3

        if coeffs.shape[1] == self.all_dims + 36:  # no trailing scale term
            eye_coeff = coeffs[:, self.all_dims + 33:]  # eye coeff
            scale = torch.ones_like(coeffs[:, -1:])
        else:  # last entry is the scale term
            eye_coeff = coeffs[:, self.all_dims + 33:-1]  # eye coeff
            scale = coeffs[:, -1:]

        return id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, eye_coeff, scale

    def merge_coeffs(self, id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, eye, scale):
        coeffs = torch.cat([id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, eye, scale], dim=1)
        return coeffs

    def get_packed_tensors(self):
        return self.merge_coeffs(self.id_tensor, self.exp_tensor, self.tex_tensor,
                                 self.rot_tensor, self.gamma_tensor, self.trans_tensor,
                                 self.eye_tensor, self.scale_tensor)
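
    # Packed coefficient layout used by merge_coeffs/split_coeffs (a sketch;
    # the id/exp/tex widths follow the bases loaded in __init__):
    #   [ id (150) | exp (52 or 171) | tex | angles (3) | gamma (27) |
    #     trans (3) | eye (4) | scale (1) ]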

    def get_pytorch3d_mesh(self, coeffs, enable_pts_shift=False):
        id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, eye_coeff, scale = self.split_coeffs(coeffs)
        rotation = self.compute_rotation_matrix(angles)

        vs = self.get_vs(id_coeff, exp_coeff)
        if enable_pts_shift:
            vs = vs + self.point_shift.unsqueeze(0).expand_as(vs)
        vs_t = self.rigid_transform(vs, rotation, translation, torch.abs(scale))

        face_texture = self.get_color(tex_coeff)
        face_norm = self.compute_norm(vs, self.tri, self.point_buf)
        face_norm_r = face_norm.bmm(rotation)
        face_color = self.add_illumination(face_texture, face_norm_r, gamma)

        face_color_tv = TexturesVertex(face_color)
        mesh = Meshes(vs_t, self.tri.repeat(self.batch_size, 1, 1), face_color_tv)
        return mesh

    def cal_laplacian_regularization(self, enable_pts_shift):
        current_mesh = self.get_pytorch3d_mesh(self.get_packed_tensors(), enable_pts_shift=enable_pts_shift)
        disp_reg_loss = mesh_laplacian_smoothing(current_mesh, method="uniform")
        return disp_reg_loss

    def forward(self, coeffs, render=True, camT=None, enable_pts_shift=False):
        id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, eye_coeff, scale = self.split_coeffs(coeffs)
        rotation = self.compute_rotation_matrix(angles)
        if camT is not None:
            rotation2 = camT[:3, :3].permute(1, 0).reshape(1, 3, 3)
            translation2 = camT[:3, 3:].permute(1, 0).reshape(1, 1, 3)
            if torch.allclose(rotation2, self.identity):
                translation = translation + translation2
            else:
                rotation = torch.matmul(rotation, rotation2)
                translation = torch.matmul(translation, rotation2) + translation2

        l_eye_mat = self.compute_eye_rotation_matrix(eye_coeff[:, :2])
        r_eye_mat = self.compute_eye_rotation_matrix(eye_coeff[:, 2:])
        l_eye_mean = self.get_l_eye_center(id_coeff)
        r_eye_mean = self.get_r_eye_center(id_coeff)

        if render:
            vs = self.get_vs(id_coeff, exp_coeff, l_eye_mat, r_eye_mat, l_eye_mean, r_eye_mean)
            if enable_pts_shift:
                vs = vs + self.point_shift.unsqueeze(0).expand_as(vs)
            vs_t = self.rigid_transform(vs, rotation, translation, torch.abs(scale))

            lms_t = self.get_lms(vs_t)
            lms_proj = self.renderer.project_vs(lms_t)
            face_texture = self.get_color(tex_coeff)
            face_norm = self.compute_norm(vs, self.tri, self.point_buf)
            face_norm_r = face_norm.bmm(rotation)
            face_color = self.add_illumination(face_texture, face_norm_r, gamma)

            face_color_tv = TexturesVertex(face_color)
            mesh = Meshes(vs_t, self.tri.repeat(self.batch_size, 1, 1), face_color_tv)
            rendered_img = self.renderer.renderer(mesh)

            return {'rendered_img': rendered_img, 'lms_proj': lms_proj,
                    'face_texture': face_texture, 'vs': vs_t, 'tri': self.tri,
                    'color': face_color, 'lms_t': lms_t}
        else:
            lms = self.get_vs_lms(id_coeff, exp_coeff, l_eye_mat, r_eye_mat, l_eye_mean, r_eye_mean)
            lms_t = self.rigid_transform(lms, rotation, translation, torch.abs(scale))
            lms_proj = self.renderer.project_vs(lms_t)
            return {'lms_proj': lms_proj, 'lms_t': lms_t}

    def get_vs(self, id_coeff, exp_coeff, l_eye_mat=None, r_eye_mat=None, l_eye_mean=None, r_eye_mean=None):
        face_shape = torch.einsum('ij,aj->ai', self.idBase, id_coeff) + \
                     torch.einsum('ij,aj->ai', self.expBase, exp_coeff) + self.meanshape
        face_shape = face_shape.view(self.batch_size, -1, 3)
        if l_eye_mat is not None:
            face_shape[:, self.ver_inds[0]:self.ver_inds[1]] = torch.matmul(
                face_shape[:, self.ver_inds[0]:self.ver_inds[1]] - l_eye_mean, l_eye_mat) + l_eye_mean
            face_shape[:, self.ver_inds[1]:self.ver_inds[2]] = torch.matmul(
                face_shape[:, self.ver_inds[1]:self.ver_inds[2]] - r_eye_mean, r_eye_mat) + r_eye_mean
        return face_shape
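
    # Usage sketch (hedged; 'fvm' stands for a constructed FaceVerseModel and
    # coeffs are packed via get_packed_tensors()):
    #   out = fvm(fvm.get_packed_tensors(), render=True)
    #   out['rendered_img']  # (B, H, W, 4) RGBA image from the PyTorch3D renderer
    #   out['lms_proj']      # (B, n_kp, 2) projected MediaPipe landmarks
    #   fvm(fvm.get_packed_tensors(), render=False)  # landmark-only fast path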

    def get_vs_lms(self, id_coeff, exp_coeff, l_eye_mat, r_eye_mat, l_eye_mean, r_eye_mean):
        face_shape = torch.einsum('ij,aj->ai', self.idBase_view, id_coeff) + \
                     torch.einsum('ij,aj->ai', self.expBase_view, exp_coeff) + self.meanshape_view
        face_shape = face_shape.view(self.batch_size, -1, 3)
        # MediaPipe iris landmarks: 473-477 belong to the left eye, 468-472 to the right eye.
        face_shape[:, 473:478] = torch.matmul(face_shape[:, 473:478] - l_eye_mean, l_eye_mat) + l_eye_mean
        face_shape[:, 468:473] = torch.matmul(face_shape[:, 468:473] - r_eye_mean, r_eye_mat) + r_eye_mean
        return face_shape

    def get_l_eye_center(self, id_coeff):
        eye_shape = torch.einsum('ij,aj->ai', self.idBase, id_coeff) + self.meanshape
        eye_shape = eye_shape.view(self.batch_size, -1, 3)[:, self.ver_inds[0]:self.ver_inds[1]]
        eye_shape[:, :, 2] += 0.005
        return torch.mean(eye_shape, dim=1, keepdim=True)

    def get_r_eye_center(self, id_coeff):
        eye_shape = torch.einsum('ij,aj->ai', self.idBase, id_coeff) + self.meanshape
        eye_shape = eye_shape.view(self.batch_size, -1, 3)[:, self.ver_inds[1]:self.ver_inds[2]]
        eye_shape[:, :, 2] += 0.005
        return torch.mean(eye_shape, dim=1, keepdim=True)

    def get_color(self, tex_coeff):
        face_texture = torch.einsum('ij,aj->ai', self.texBase, tex_coeff) + self.meantex
        face_texture = face_texture.view(self.batch_size, -1, 3)
        return face_texture

    def compute_norm(self, vs, tri, point_buf):
        face_id = tri
        point_id = point_buf
        v1 = vs[:, face_id[:, 0], :]
        v2 = vs[:, face_id[:, 1], :]
        v3 = vs[:, face_id[:, 2], :]
        e1 = v1 - v2
        e2 = v2 - v3
        face_norm = e1.cross(e2, dim=-1)
        v_norm = face_norm[:, point_id, :].sum(2)
        v_norm = v_norm / (v_norm.norm(dim=2).unsqueeze(2) + 1e-9)
        return v_norm

    def project_vs(self, vs):
        # Delegate to the renderer, which owns the projection buffers
        # (p_mat, reverse_xz, camera_pos).
        return self.renderer.project_vs(vs)

    def make_rotMat(self, coeffs=None, angle=None, translation=None, scale=None, no_scale=False):
        # P * rot * scale + trans -> P * T
        if coeffs is not None:
            _, _, _, angle, _, translation, _, scale = self.split_coeffs(coeffs)
        rotation = self.compute_rotation_matrix(angle)
        cam_T = torch.eye(4, dtype=torch.float32).to(angle.device)
        cam_T[:3, :3] = rotation[0] if no_scale else torch.abs(scale[0]) * rotation[0]
        cam_T[-1, :3] = translation[0]
        return cam_T

    def compute_eye_rotation_matrix(self, eye):
        # eye coeff layout in the full 4-dim eye vector:
        #   0: left eye, down(+)/up(-)    1: left eye, right(+)/left(-)
        #   2: right eye, down(+)/up(-)   3: right eye, right(+)/left(-)
        sinx = torch.sin(eye[:, 0])
        siny = torch.sin(eye[:, 1])
        cosx = torch.cos(eye[:, 0])
        cosy = torch.cos(eye[:, 1])

        if self.batch_size != 1:
            rotXYZ = self.rotXYZ.repeat(1, self.batch_size, 1, 1).detach().clone()
        else:
            rotXYZ = self.rotXYZ.detach().clone()

        rotXYZ[0, :, 1, 1] = cosx
        rotXYZ[0, :, 1, 2] = -sinx
        rotXYZ[0, :, 2, 1] = sinx
        rotXYZ[0, :, 2, 2] = cosx
        rotXYZ[1, :, 0, 0] = cosy
        rotXYZ[1, :, 0, 2] = siny
        rotXYZ[1, :, 2, 0] = -siny
        rotXYZ[1, :, 2, 2] = cosy

        rotation = rotXYZ[1].bmm(rotXYZ[0])
        return rotation.permute(0, 2, 1)
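
    # Convention note: the rotation builders above and below compose per-axis
    # matrices (eyes: R = Ry @ Rx; head: R = Rz @ Ry @ Rx) and return the
    # transpose, since vertices are row vectors multiplied on the right in
    # rigid_transform (v' = s * v @ R + t).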

    def compute_rotation_matrix(self, angles):
        sinx = torch.sin(angles[:, 0])
        siny = torch.sin(angles[:, 1])
        sinz = torch.sin(angles[:, 2])
        cosx = torch.cos(angles[:, 0])
        cosy = torch.cos(angles[:, 1])
        cosz = torch.cos(angles[:, 2])

        if self.batch_size != 1:
            rotXYZ = self.rotXYZ.repeat(1, self.batch_size, 1, 1)
        else:
            rotXYZ = self.rotXYZ.detach().clone()

        rotXYZ[0, :, 1, 1] = cosx
        rotXYZ[0, :, 1, 2] = -sinx
        rotXYZ[0, :, 2, 1] = sinx
        rotXYZ[0, :, 2, 2] = cosx
        rotXYZ[1, :, 0, 0] = cosy
        rotXYZ[1, :, 0, 2] = siny
        rotXYZ[1, :, 2, 0] = -siny
        rotXYZ[1, :, 2, 2] = cosy
        rotXYZ[2, :, 0, 0] = cosz
        rotXYZ[2, :, 0, 1] = -sinz
        rotXYZ[2, :, 1, 0] = sinz
        rotXYZ[2, :, 1, 1] = cosz

        rotation = rotXYZ[2].bmm(rotXYZ[1]).bmm(rotXYZ[0])
        return rotation.permute(0, 2, 1)

    def add_illumination(self, face_texture, norm, gamma):
        gamma = gamma.view(-1, 3, 9).clone()
        gamma[:, :, 0] += 0.8
        gamma = gamma.permute(0, 2, 1)

        # Spherical-harmonics constants for the 9-band lighting model.
        a0 = np.pi
        a1 = 2 * np.pi / np.sqrt(3.0)
        a2 = 2 * np.pi / np.sqrt(8.0)
        c0 = 1 / np.sqrt(4 * np.pi)
        c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
        c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
        d0 = 0.5 / np.sqrt(3.0)

        norm = norm.view(-1, 3)
        nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]

        arrH = []
        arrH.append(a0 * c0 * (nx * 0 + 1))
        arrH.append(-a1 * c1 * ny)
        arrH.append(a1 * c1 * nz)
        arrH.append(-a1 * c1 * nx)
        arrH.append(a2 * c2 * nx * ny)
        arrH.append(-a2 * c2 * ny * nz)
        arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
        arrH.append(-a2 * c2 * nx * nz)
        arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))

        H = torch.stack(arrH, 1)
        Y = H.view(self.batch_size, face_texture.shape[1], 9)
        lighting = Y.bmm(gamma)

        face_color = face_texture * lighting
        return face_color

    def rigid_transform(self, vs, rot, trans, scale):
        vs_r = torch.matmul(vs * scale, rot)
        vs_t = vs_r + trans.view(-1, 1, 3)
        return vs_t

    def get_rot_tensor(self):
        return self.rot_tensor

    def get_trans_tensor(self):
        return self.trans_tensor

    def get_exp_tensor(self):
        return self.exp_tensor

    def get_tex_tensor(self):
        return self.tex_tensor

    def get_id_tensor(self):
        return self.id_tensor

    def get_gamma_tensor(self):
        return self.gamma_tensor

    def get_scale_tensor(self):
        return self.scale_tensor
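
# Intrinsics-to-NDC note (a sketch of what ModelRenderer._get_renderer below
# computes, under PyTorch3D's in_ndc=True convention with +X-left/+Y-up axes):
#   fx_ndc = -2 * fx / (W - 1)
#   cx_ndc = -2 * (cx - (W - 1) / 2) / (W - 1)
# and analogously for fy, cy. For the default fx = 1315 at W = 256:
#   fx_ndc = -2 * 1315 / 255 ≈ -10.31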

class ModelRenderer(nn.Module):
    def __init__(self, device='cuda:0', intr=None, img_size=256, cam_dist=10., render_depth=False, rasterize_blur_radius=0.):
        super(ModelRenderer, self).__init__()
        self.render_depth = render_depth
        self.img_size = img_size
        self.device = torch.device(device)
        self.cam_dist = cam_dist
        if intr is None:
            intr = np.eye(3, dtype=np.float32)
            intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2] = 1315, 1315, img_size // 2, img_size // 2
        self.fx, self.fy, self.cx, self.cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2]
        self.renderer = self._get_renderer(self.device, cam_dist, torch.from_numpy(intr),
                                           render_depth=render_depth,
                                           rasterize_blur_radius=rasterize_blur_radius)
        self.p_mat = self._get_p_mat(device)
        self.reverse_xz = self._get_reverse_xz(device)
        self.camera_pos = self._get_camera_pose(device, cam_dist)

    def _get_renderer(self, device, cam_dist=10., K=None, render_depth=False, rasterize_blur_radius=0.):
        R, T = look_at_view_transform(cam_dist, 0, 0)  # camera's position

        # Convert screen-space intrinsics to PyTorch3D NDC (see note above).
        fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
        fx = -fx * 2.0 / (self.img_size - 1)
        fy = -fy * 2.0 / (self.img_size - 1)
        cx = -(cx - (self.img_size - 1) / 2.0) * 2.0 / (self.img_size - 1)
        cy = -(cy - (self.img_size - 1) / 2.0) * 2.0 / (self.img_size - 1)

        cameras = PerspectiveCameras(
            device=device, R=R, T=T,
            focal_length=torch.tensor([[fx, fy]], device=device, dtype=torch.float32),
            principal_point=((cx, cy),),
            in_ndc=True
        )

        lights = PointLights(device=device, location=[[0.0, 0.0, 1e5]],
                             ambient_color=[[1, 1, 1]],
                             specular_color=[[0., 0., 0.]],
                             diffuse_color=[[0., 0., 0.]])

        raster_settings = RasterizationSettings(
            image_size=self.img_size,
            blur_radius=rasterize_blur_radius if render_depth else 0.,
            faces_per_pixel=1,
        )
        blend_params = blending.BlendParams(background_color=[0, 0, 0])

        if not render_depth:
            renderer = MeshRenderer(
                rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
                shader=SoftPhongShader(device=device, cameras=cameras, lights=lights, blend_params=blend_params)
            )
        else:
            renderer = MeshRendererWithDepth(
                rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
                shader=SoftPhongShader(device=device, cameras=cameras, lights=lights, blend_params=blend_params)
            )
        return renderer

    def _get_camera_pose(self, device, cam_dist=10.):
        camera_pos = torch.tensor([0.0, 0.0, cam_dist], device=device).reshape(1, 1, 3)
        return camera_pos

    def _get_p_mat(self, device):
        p_matrix = np.array([self.fx, 0.0, self.cx,
                             0.0, self.fy, self.cy,
                             0.0, 0.0, 1.0], dtype=np.float32).reshape(1, 3, 3)
        return torch.tensor(p_matrix, device=device)

    def _get_reverse_xz(self, device):
        reverse_xz = np.reshape(
            np.array([-1.0, 0, 0, 0, 1, 0, 0, 0, -1.0], dtype=np.float32), [1, 3, 3])
        return torch.tensor(reverse_xz, device=device)

    def project_vs(self, vs):
        batchsize = vs.shape[0]
        vs = torch.matmul(vs, self.reverse_xz.repeat((batchsize, 1, 1))) + self.camera_pos
        aug_projection = torch.matmul(vs, self.p_mat.repeat((batchsize, 1, 1)).permute((0, 2, 1)))
        face_projection = aug_projection[:, :, :2] / torch.reshape(aug_projection[:, :, 2], [batchsize, -1, 1])
        return face_projection


class MeshRendererWithDepth(MeshRenderer):
    def __init__(self, rasterizer, shader):
        super().__init__(rasterizer, shader)

    def forward(self, meshes_world, **kwargs):
        # Returns the shaded images together with the rasterizer's depth buffer.
        fragments = self.rasterizer(meshes_world, **kwargs)
        images = self.shader(fragments, meshes_world, **kwargs)
        return images, fragments.zbuf
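
# Minimal smoke test, as a sketch: assumes a FaceVerse v3 checkpoint (a pickled
# dict with the keys read in FaceVerseModel.__init__) at a hypothetical path,
# and a CUDA device with PyTorch3D installed.
if __name__ == '__main__':
    model_dict = np.load('metamodel/v3/faceverse_v3.npy', allow_pickle=True).item()  # hypothetical path
    fvm = FaceVerseModel(model_dict, batch_size=1, device='cuda:0', expr_52=True)
    with torch.no_grad():
        out = fvm(fvm.get_packed_tensors(), render=True)
    print(out['rendered_img'].shape)  # expected: (1, 256, 256, 4)
    print(out['lms_proj'].shape)      # expected: (1, n_kp, 2)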