# from metaface_fitting 20221122
from typing import Tuple

import torch
from torch import nn
import numpy as np

from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
    PointLights,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesVertex,
    blending
)
from pytorch3d.loss import (
    # mesh_edge_loss,
    mesh_laplacian_smoothing,
    # mesh_normal_consistency,
)
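
# This module provides:
#   * FaceVerseModel  - a parametric 3D face model driven by identity, expression,
#     albedo, spherical-harmonics lighting, pose, eye-gaze and scale coefficients.
#   * ModelRenderer   - a thin PyTorch3D wrapper that rasterizes the fitted mesh
#     (optionally returning a depth channel via MeshRendererWithDepth).
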
class FaceVerseModel(nn.Module):
    def __init__(self, model_dict, batch_size=1, device='cuda:0', expr_52=True, **kwargs):
        super(FaceVerseModel, self).__init__()
        self.batch_size = batch_size
        self.device = torch.device(device)

        # One 3x3 identity per rotation axis (X, Y, Z); filled in later by
        # compute_rotation_matrix / compute_eye_rotation_matrix.
        self.rotXYZ = torch.eye(3).view(1, 3, 3).repeat(3, 1, 1).view(3, 1, 3, 3).to(self.device)

        self.renderer = ModelRenderer(device, **kwargs)

        self.kp_inds = torch.tensor(model_dict['mediapipe_keypoints'].reshape(-1, 1), requires_grad=False).squeeze().long().to(self.device)
        self.ver_inds = model_dict['ver_inds']
        self.tri_inds = model_dict['tri_inds']

        # Mean shape: flip the Y/Z axes, rescale by 0.1 and offset Y by +1.
        meanshape = torch.tensor(model_dict['meanshape'].reshape(-1, 3), dtype=torch.float32, requires_grad=False, device=self.device)
        meanshape[:, [1, 2]] *= -1
        meanshape = meanshape * 0.1
        meanshape[:, 1] += 1
        self.meanshape = meanshape.reshape(1, -1)
        self.meantex = torch.tensor(model_dict['meantex'].reshape(1, -1), dtype=torch.float32, requires_grad=False, device=self.device)

        # Identity basis (150 components), with the same axis flip and scaling.
        idBase = torch.tensor(model_dict['idBase'].reshape(-1, 3, 150), dtype=torch.float32, requires_grad=False, device=self.device)
        idBase[:, [1, 2]] *= -1
        self.idBase = (idBase * 0.1).reshape(-1, 150)

        # Expression basis: either the 52-dim blendshape set or the full 171-dim set.
        self.expr_52 = expr_52
        if expr_52:
            expBase = torch.tensor(np.load('metamodel/v3/exBase_52.npy').reshape(-1, 3, 52), dtype=torch.float32, requires_grad=False, device=self.device)
        else:
            expBase = torch.tensor(model_dict['exBase'].reshape(-1, 3, 171), dtype=torch.float32, requires_grad=False, device=self.device)
        expBase[:, [1, 2]] *= -1
        # Flatten to (3N, n_exp). The original hard-coded 171 here, which breaks
        # when expr_52 selects the 52-dim basis.
        self.expBase = (expBase * 0.1).reshape(-1, expBase.shape[-1])
        self.texBase = torch.tensor(model_dict['texBase'], dtype=torch.float32, requires_grad=False, device=self.device)

        self.l_eyescale = model_dict['left_eye_exp']
        self.r_eyescale = model_dict['right_eye_exp']

        self.uv = torch.tensor(model_dict['uv'], dtype=torch.float32, requires_grad=False, device=self.device)
        self.tri = torch.tensor(model_dict['tri'], dtype=torch.int64, requires_grad=False, device=self.device)
        self.tri_uv = torch.tensor(model_dict['tri_uv'], dtype=torch.int64, requires_grad=False, device=self.device)
        self.point_buf = torch.tensor(model_dict['point_buf'], dtype=torch.int64, requires_grad=False, device=self.device)

        self.num_vertex = self.meanshape.shape[1] // 3
        self.id_dims = self.idBase.shape[1]
        self.tex_dims = self.texBase.shape[1]
        self.exp_dims = self.expBase.shape[1]
        self.all_dims = self.id_dims + self.tex_dims + self.exp_dims
        self.init_coeff_tensors()

        # for tracking by landmarks
        self.kp_inds_view = torch.cat([self.kp_inds[:, None] * 3, self.kp_inds[:, None] * 3 + 1, self.kp_inds[:, None] * 3 + 2], dim=1).flatten()
        self.idBase_view = self.idBase[self.kp_inds_view, :].detach().clone()
        self.expBase_view = self.expBase[self.kp_inds_view, :].detach().clone()
        self.meanshape_view = self.meanshape[:, self.kp_inds_view].detach().clone()
        self.identity = torch.eye(3, dtype=torch.float32, device=self.device)
        self.point_shift = torch.nn.Parameter(torch.zeros(self.num_vertex, 3, dtype=torch.float32, device=self.device))  # [N, 3]
    def set_renderer(self, intr=None, img_size=256, cam_dist=10., render_depth=False, rasterize_blur_radius=0.):
        self.renderer = ModelRenderer(self.device, intr, img_size, cam_dist, render_depth, rasterize_blur_radius)
    def init_coeff_tensors(self, id_coeff=None, tex_coeff=None, exp_coeff=None, gamma_coeff=None, trans_coeff=None, rot_coeff=None, scale_coeff=None, eye_coeff=None):
        if id_coeff is None:
            self.id_tensor = torch.zeros((1, self.id_dims), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            assert id_coeff.shape == (1, self.id_dims)
            self.id_tensor = id_coeff.clone().detach().requires_grad_(True)

        if tex_coeff is None:
            self.tex_tensor = torch.zeros((1, self.tex_dims), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            assert tex_coeff.shape == (1, self.tex_dims)
            self.tex_tensor = tex_coeff.clone().detach().requires_grad_(True)

        if exp_coeff is None:
            self.exp_tensor = torch.zeros((self.batch_size, self.exp_dims), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            assert exp_coeff.shape == (1, self.exp_dims)
            self.exp_tensor = exp_coeff.clone().detach().requires_grad_(True)

        if gamma_coeff is None:
            self.gamma_tensor = torch.zeros((self.batch_size, 27), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            self.gamma_tensor = gamma_coeff.clone().detach().requires_grad_(True)

        if trans_coeff is None:
            self.trans_tensor = torch.zeros((self.batch_size, 3), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            self.trans_tensor = trans_coeff.clone().detach().requires_grad_(True)

        if scale_coeff is None:
            self.scale_tensor = torch.ones((self.batch_size, 1), dtype=torch.float32, device=self.device)
            self.scale_tensor.requires_grad_(True)
        else:
            self.scale_tensor = scale_coeff.clone().detach().requires_grad_(True)

        if rot_coeff is None:
            self.rot_tensor = torch.zeros((self.batch_size, 3), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            self.rot_tensor = rot_coeff.clone().detach().requires_grad_(True)

        if eye_coeff is None:
            self.eye_tensor = torch.zeros((self.batch_size, 4), dtype=torch.float32, requires_grad=True, device=self.device)
        else:
            self.eye_tensor = eye_coeff.clone().detach().requires_grad_(True)
    def get_lms(self, vs):
        return vs[:, self.kp_inds, :]
    def split_coeffs(self, coeffs):
        id_coeff = coeffs[:, :self.id_dims]  # identity (shape) coeff
        exp_coeff = coeffs[:, self.id_dims:self.id_dims + self.exp_dims]  # expression coeff
        tex_coeff = coeffs[:, self.id_dims + self.exp_dims:self.all_dims]  # texture (albedo) coeff
        angles = coeffs[:, self.all_dims:self.all_dims + 3]  # Euler angles (x, y, z) for rotation, dim 3
        gamma = coeffs[:, self.all_dims + 3:self.all_dims + 30]  # lighting coeff for 3-channel SH function, dim 27
        translation = coeffs[:, self.all_dims + 30:self.all_dims + 33]  # translation coeff, dim 3
        if coeffs.shape[1] == self.all_dims + 36:  # no trailing scale entry appended
            eye_coeff = coeffs[:, self.all_dims + 33:]  # eye coeff
            scale = torch.ones_like(coeffs[:, -1:])
        else:  # trailing scale entry appended
            eye_coeff = coeffs[:, self.all_dims + 33:-1]  # eye coeff
            scale = coeffs[:, -1:]
        return id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, eye_coeff, scale
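    # Packed layout (as produced by merge_coeffs below), per batch row:
    #   [ id (id_dims) | exp (exp_dims) | tex (tex_dims) | angles (3) | gamma (27)
    #     | translation (3) | eye (4: two angles per eye) | scale (1) ]
    # A hedged indexing sketch (names are illustrative, not part of this API):
    #   coeffs = model.get_packed_tensors()  # (B, all_dims + 38)
    #   exp = coeffs[:, model.id_dims:model.id_dims + model.exp_dims]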
    def merge_coeffs(self, id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, eye, scale):
        return torch.cat([id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, eye, scale], dim=1)
    def get_packed_tensors(self):
        return self.merge_coeffs(self.id_tensor,
                                 self.exp_tensor,
                                 self.tex_tensor,
                                 self.rot_tensor, self.gamma_tensor,
                                 self.trans_tensor, self.eye_tensor, self.scale_tensor)
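    # A hedged fitting-loop sketch built only from methods defined in this class;
    # `target_lms` (detected 2D landmarks) and the hyperparameters are assumed:
    #
    #   params = [model.get_exp_tensor(), model.get_rot_tensor(),
    #             model.get_trans_tensor(), model.get_id_tensor()]
    #   opt = torch.optim.Adam(params, lr=1e-2)
    #   for _ in range(200):
    #       pred = model(model.get_packed_tensors(), render=False)
    #       loss = ((pred['lms_proj'] - target_lms) ** 2).mean()
    #       opt.zero_grad(); loss.backward(); opt.step()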
    def get_pytorch3d_mesh(self, coeffs, enable_pts_shift=False):
        # Restored from a commented-out draft so that cal_laplacian_regularization
        # below has a mesh to work on; eye rotation is skipped here, as in the draft.
        id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, eye_coeff, scale = self.split_coeffs(coeffs)
        rotation = self.compute_rotation_matrix(angles)

        vs = self.get_vs(id_coeff, exp_coeff)
        if enable_pts_shift:
            vs = vs + self.point_shift.unsqueeze(0).expand_as(vs)
        vs_t = self.rigid_transform(vs, rotation, translation, torch.abs(scale))

        face_texture = self.get_color(tex_coeff)
        face_norm = self.compute_norm(vs, self.tri, self.point_buf)
        face_norm_r = face_norm.bmm(rotation)
        face_color = self.add_illumination(face_texture, face_norm_r, gamma)

        face_color_tv = TexturesVertex(face_color)
        mesh = Meshes(vs_t, self.tri.repeat(self.batch_size, 1, 1), face_color_tv)
        return mesh
    def cal_laplacian_regularization(self, enable_pts_shift):
        current_mesh = self.get_pytorch3d_mesh(self.get_packed_tensors(), enable_pts_shift=enable_pts_shift)
        disp_reg_loss = mesh_laplacian_smoothing(current_mesh, method="uniform")
        return disp_reg_loss
    def forward(self, coeffs, render=True, camT=None, enable_pts_shift=False):
        id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, eye_coeff, scale = self.split_coeffs(coeffs)
        rotation = self.compute_rotation_matrix(angles)

        if camT is not None:
            # Fold an extra 4x4 camera transform into the pose (row-vector convention).
            rotation2 = camT[:3, :3].permute(1, 0).reshape(1, 3, 3)
            translation2 = camT[:3, 3:].permute(1, 0).reshape(1, 1, 3)
            if torch.allclose(rotation2, self.identity):
                translation = translation + translation2
            else:
                rotation = torch.matmul(rotation, rotation2)
                translation = torch.matmul(translation, rotation2) + translation2

        l_eye_mat = self.compute_eye_rotation_matrix(eye_coeff[:, :2])
        r_eye_mat = self.compute_eye_rotation_matrix(eye_coeff[:, 2:])
        l_eye_mean = self.get_l_eye_center(id_coeff)
        r_eye_mean = self.get_r_eye_center(id_coeff)

        if render:
            vs = self.get_vs(id_coeff, exp_coeff, l_eye_mat, r_eye_mat, l_eye_mean, r_eye_mean)
            if enable_pts_shift:
                vs = vs + self.point_shift.unsqueeze(0).expand_as(vs)
            vs_t = self.rigid_transform(vs, rotation, translation, torch.abs(scale))
            lms_t = self.get_lms(vs_t)
            lms_proj = self.renderer.project_vs(lms_t)
            face_texture = self.get_color(tex_coeff)
            face_norm = self.compute_norm(vs, self.tri, self.point_buf)
            face_norm_r = face_norm.bmm(rotation)
            face_color = self.add_illumination(face_texture, face_norm_r, gamma)
            face_color_tv = TexturesVertex(face_color)
            mesh = Meshes(vs_t, self.tri.repeat(self.batch_size, 1, 1), face_color_tv)
            rendered_img = self.renderer.renderer(mesh)
            return {'rendered_img': rendered_img,
                    'lms_proj': lms_proj,
                    'face_texture': face_texture,
                    'vs': vs_t,
                    'tri': self.tri,
                    'color': face_color, 'lms_t': lms_t}
        else:
            # Landmark-only path: evaluate the bases at the keypoints, no rasterization.
            lms = self.get_vs_lms(id_coeff, exp_coeff, l_eye_mat, r_eye_mat, l_eye_mean, r_eye_mean)
            lms_t = self.rigid_transform(lms, rotation, translation, torch.abs(scale))
            lms_proj = self.renderer.project_vs(lms_t)
            return {'lms_proj': lms_proj, 'lms_t': lms_t}
    def get_vs(self, id_coeff, exp_coeff, l_eye_mat=None, r_eye_mat=None, l_eye_mean=None, r_eye_mean=None):
        # Linear blendshape model: mean shape + identity basis + expression basis.
        face_shape = torch.einsum('ij,aj->ai', self.idBase, id_coeff) + \
                     torch.einsum('ij,aj->ai', self.expBase, exp_coeff) + self.meanshape
        face_shape = face_shape.view(self.batch_size, -1, 3)
        if l_eye_mat is not None:
            # Rotate each eyeball vertex block about its own center.
            face_shape[:, self.ver_inds[0]:self.ver_inds[1]] = torch.matmul(face_shape[:, self.ver_inds[0]:self.ver_inds[1]] - l_eye_mean, l_eye_mat) + l_eye_mean
            face_shape[:, self.ver_inds[1]:self.ver_inds[2]] = torch.matmul(face_shape[:, self.ver_inds[1]:self.ver_inds[2]] - r_eye_mean, r_eye_mat) + r_eye_mean
        return face_shape
    def get_vs_lms(self, id_coeff, exp_coeff, l_eye_mat, r_eye_mat, l_eye_mean, r_eye_mean):
        face_shape = torch.einsum('ij,aj->ai', self.idBase_view, id_coeff) + \
                     torch.einsum('ij,aj->ai', self.expBase_view, exp_coeff) + self.meanshape_view
        face_shape = face_shape.view(self.batch_size, -1, 3)
        # Hard-coded iris blocks of the 478-point MediaPipe landmark topology.
        face_shape[:, 473:478] = torch.matmul(face_shape[:, 473:478] - l_eye_mean, l_eye_mat) + l_eye_mean
        face_shape[:, 468:473] = torch.matmul(face_shape[:, 468:473] - r_eye_mean, r_eye_mat) + r_eye_mean
        return face_shape
    def get_l_eye_center(self, id_coeff):
        eye_shape = torch.einsum('ij,aj->ai', self.idBase, id_coeff) + self.meanshape
        eye_shape = eye_shape.view(self.batch_size, -1, 3)[:, self.ver_inds[0]:self.ver_inds[1]]
        eye_shape[:, :, 2] += 0.005
        return torch.mean(eye_shape, dim=1, keepdim=True)

    def get_r_eye_center(self, id_coeff):
        eye_shape = torch.einsum('ij,aj->ai', self.idBase, id_coeff) + self.meanshape
        eye_shape = eye_shape.view(self.batch_size, -1, 3)[:, self.ver_inds[1]:self.ver_inds[2]]
        eye_shape[:, :, 2] += 0.005
        return torch.mean(eye_shape, dim=1, keepdim=True)
    def get_color(self, tex_coeff):
        face_texture = torch.einsum('ij,aj->ai', self.texBase, tex_coeff) + self.meantex
        return face_texture.view(self.batch_size, -1, 3)
    def compute_norm(self, vs, tri, point_buf):
        # Per-face normals from two edge vectors, then area-weighted accumulation
        # onto vertices via point_buf (the faces adjacent to each vertex).
        v1 = vs[:, tri[:, 0], :]
        v2 = vs[:, tri[:, 1], :]
        v3 = vs[:, tri[:, 2], :]
        e1 = v1 - v2
        e2 = v2 - v3
        face_norm = e1.cross(e2, dim=-1)  # explicit dim; the bare call is deprecated
        v_norm = face_norm[:, point_buf, :].sum(2)
        v_norm = v_norm / (v_norm.norm(dim=2).unsqueeze(2) + 1e-9)
        return v_norm
    def project_vs(self, vs):
        # The original body referenced attributes that exist only on ModelRenderer
        # (reverse_z, p_mat, camera_pos); delegate to the renderer instead.
        return self.renderer.project_vs(vs)
    def make_rotMat(self, coeffs=None, angle=None, translation=None, scale=None, no_scale=False):  # P * rot * scale + trans -> P * T
        if coeffs is not None:
            # split_coeffs returns 8 values; the eye coeff is unused here.
            _, _, _, angle, _, translation, _, scale = self.split_coeffs(coeffs)
        rotation = self.compute_rotation_matrix(angle)

        cam_T = torch.eye(4, dtype=torch.float32).to(angle.device)
        cam_T[:3, :3] = rotation[0] if no_scale else torch.abs(scale[0]) * rotation[0]
        cam_T[-1, :3] = translation[0]
        return cam_T
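    # Hedged usage sketch for make_rotMat (row-vector convention; `verts` is
    # illustrative, not part of this API):
    #   T = model.make_rotMat(model.get_packed_tensors())             # (4, 4)
    #   verts_h = torch.cat([verts, torch.ones_like(verts[..., :1])], dim=-1)
    #   verts_t = (verts_h @ T)[..., :3]  # == verts * rot * scale + trans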
    def compute_eye_rotation_matrix(self, eye):
        # Two angles per eye: eye[:, 0] is pitch (+down / -up), eye[:, 1] is yaw
        # (+right / -left); the packed 4-dim eye coeff is [left pitch, left yaw,
        # right pitch, right yaw].
        sinx = torch.sin(eye[:, 0])
        siny = torch.sin(eye[:, 1])
        cosx = torch.cos(eye[:, 0])
        cosy = torch.cos(eye[:, 1])
        if self.batch_size != 1:
            rotXYZ = self.rotXYZ.repeat(1, self.batch_size, 1, 1).detach().clone()
        else:
            rotXYZ = self.rotXYZ.detach().clone()
        rotXYZ[0, :, 1, 1] = cosx
        rotXYZ[0, :, 1, 2] = -sinx
        rotXYZ[0, :, 2, 1] = sinx
        rotXYZ[0, :, 2, 2] = cosx
        rotXYZ[1, :, 0, 0] = cosy
        rotXYZ[1, :, 0, 2] = siny
        rotXYZ[1, :, 2, 0] = -siny
        rotXYZ[1, :, 2, 2] = cosy
        rotation = rotXYZ[1].bmm(rotXYZ[0])
        return rotation.permute(0, 2, 1)
    def compute_rotation_matrix(self, angles):
        # Euler angles (x, y, z) -> rotation matrix, composed as Rz @ Ry @ Rx and
        # returned transposed for the row-vector convention of rigid_transform.
        sinx = torch.sin(angles[:, 0])
        siny = torch.sin(angles[:, 1])
        sinz = torch.sin(angles[:, 2])
        cosx = torch.cos(angles[:, 0])
        cosy = torch.cos(angles[:, 1])
        cosz = torch.cos(angles[:, 2])
        if self.batch_size != 1:
            rotXYZ = self.rotXYZ.repeat(1, self.batch_size, 1, 1)
        else:
            rotXYZ = self.rotXYZ.detach().clone()
        rotXYZ[0, :, 1, 1] = cosx
        rotXYZ[0, :, 1, 2] = -sinx
        rotXYZ[0, :, 2, 1] = sinx
        rotXYZ[0, :, 2, 2] = cosx
        rotXYZ[1, :, 0, 0] = cosy
        rotXYZ[1, :, 0, 2] = siny
        rotXYZ[1, :, 2, 0] = -siny
        rotXYZ[1, :, 2, 2] = cosy
        rotXYZ[2, :, 0, 0] = cosz
        rotXYZ[2, :, 0, 1] = -sinz
        rotXYZ[2, :, 1, 0] = sinz
        rotXYZ[2, :, 1, 1] = cosz
        rotation = rotXYZ[2].bmm(rotXYZ[1]).bmm(rotXYZ[0])
        return rotation.permute(0, 2, 1)
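    # Hedged sanity-check sketch (illustrative only): for a pure 90-degree Y
    # rotation, a point on +Z should map to +X under the row-vector convention:
    #   R = model.compute_rotation_matrix(torch.tensor([[0., np.pi / 2, 0.]], device=model.device))
    #   torch.tensor([[0., 0., 1.]], device=model.device) @ R[0]  # ~ [1., 0., 0.]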
    def add_illumination(self, face_texture, norm, gamma):
        gamma = gamma.view(-1, 3, 9).clone()
        gamma[:, :, 0] += 0.8
        gamma = gamma.permute(0, 2, 1)

        a0 = np.pi
        a1 = 2 * np.pi / np.sqrt(3.0)
        a2 = 2 * np.pi / np.sqrt(8.0)
        c0 = 1 / np.sqrt(4 * np.pi)
        c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
        c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
        d0 = 0.5 / np.sqrt(3.0)

        norm = norm.view(-1, 3)
        nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]
        arrH = []
        arrH.append(a0 * c0 * (nx * 0 + 1))
        arrH.append(-a1 * c1 * ny)
        arrH.append(a1 * c1 * nz)
        arrH.append(-a1 * c1 * nx)
        arrH.append(a2 * c2 * nx * ny)
        arrH.append(-a2 * c2 * ny * nz)
        arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
        arrH.append(-a2 * c2 * nx * nz)
        arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))

        H = torch.stack(arrH, 1)
        Y = H.view(self.batch_size, face_texture.shape[1], 9)
        lighting = Y.bmm(gamma)

        face_color = face_texture * lighting
        return face_color
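    # add_illumination evaluates the first 9 real spherical-harmonics basis
    # functions at each vertex normal and modulates the albedo per RGB channel c:
    #   color = albedo * (Y(n) @ gamma_c)
    # This is the second-order SH shading commonly used in 3DMM fitting; the +0.8
    # offset on the DC term acts as a constant ambient floor.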
    def rigid_transform(self, vs, rot, trans, scale):
        vs_r = torch.matmul(vs * scale, rot)  # row-vector convention: v' = (s * v) @ R
        vs_t = vs_r + trans.view(-1, 1, 3)
        return vs_t
    def get_rot_tensor(self):
        return self.rot_tensor

    def get_trans_tensor(self):
        return self.trans_tensor

    def get_exp_tensor(self):
        return self.exp_tensor

    def get_tex_tensor(self):
        return self.tex_tensor

    def get_id_tensor(self):
        return self.id_tensor

    def get_gamma_tensor(self):
        return self.gamma_tensor

    def get_scale_tensor(self):
        return self.scale_tensor
class ModelRenderer(nn.Module):
    def __init__(self, device='cuda:0', intr=None, img_size=256, cam_dist=10., render_depth=False, rasterize_blur_radius=0.):
        super(ModelRenderer, self).__init__()
        self.render_depth = render_depth
        self.img_size = img_size
        self.device = torch.device(device)
        self.cam_dist = cam_dist
        if intr is None:
            # Default pinhole intrinsics: f = 1315 px, principal point at the image center.
            intr = np.eye(3, dtype=np.float32)
            intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2] = 1315, 1315, img_size // 2, img_size // 2
        self.fx, self.fy, self.cx, self.cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2]
        self.renderer = self._get_renderer(self.device, cam_dist, torch.from_numpy(intr), render_depth=render_depth, rasterize_blur_radius=rasterize_blur_radius)
        self.p_mat = self._get_p_mat(device)
        self.reverse_xz = self._get_reverse_xz(device)
        self.camera_pos = self._get_camera_pose(device, cam_dist)
    def _get_renderer(self, device, cam_dist=10., K=None, render_depth=False, rasterize_blur_radius=0.):
        R, T = look_at_view_transform(cam_dist, 0, 0)  # camera placed on the +Z axis

        # Convert pixel-space intrinsics to the NDC convention PyTorch3D expects
        # (signs flipped because its +X/+Y axes point left/up in screen space).
        fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
        fx = -fx * 2.0 / (self.img_size - 1)
        fy = -fy * 2.0 / (self.img_size - 1)
        cx = -(cx - (self.img_size - 1) / 2.0) * 2.0 / (self.img_size - 1)
        cy = -(cy - (self.img_size - 1) / 2.0) * 2.0 / (self.img_size - 1)
        cameras = PerspectiveCameras(device=device, R=R, T=T,
                                     focal_length=torch.tensor([[fx, fy]], device=device, dtype=torch.float32),
                                     principal_point=((cx, cy),),
                                     in_ndc=True)

        # Pure ambient light: the SH illumination is already baked into the
        # vertex colors by FaceVerseModel.add_illumination.
        lights = PointLights(device=device, location=[[0.0, 0.0, 1e5]],
                             ambient_color=[[1, 1, 1]],
                             specular_color=[[0., 0., 0.]], diffuse_color=[[0., 0., 0.]])

        raster_settings = RasterizationSettings(
            image_size=self.img_size,
            blur_radius=rasterize_blur_radius if render_depth else 0.,
            faces_per_pixel=1,
        )
        blend_params = blending.BlendParams(background_color=[0, 0, 0])

        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
        shader = SoftPhongShader(device=device, cameras=cameras, lights=lights, blend_params=blend_params)
        renderer_cls = MeshRendererWithDepth if render_depth else MeshRenderer
        return renderer_cls(rasterizer=rasterizer, shader=shader)
    def _get_camera_pose(self, device, cam_dist=10.):
        camera_pos = torch.tensor([0.0, 0.0, cam_dist], device=device).reshape(1, 1, 3)
        return camera_pos

    def _get_p_mat(self, device):
        p_matrix = np.array([self.fx, 0.0, self.cx,
                             0.0, self.fy, self.cy,
                             0.0, 0.0, 1.0], dtype=np.float32).reshape(1, 3, 3)
        return torch.tensor(p_matrix, device=device)

    def _get_reverse_xz(self, device):
        # 3x3 mirror of the X and Z axes, applied before projection.
        reverse_xz = np.reshape(
            np.array([-1.0, 0, 0, 0, 1, 0, 0, 0, -1.0], dtype=np.float32), [1, 3, 3])
        return torch.tensor(reverse_xz, device=device)
    def project_vs(self, vs):
        batchsize = vs.shape[0]
        vs = torch.matmul(vs, self.reverse_xz.repeat((batchsize, 1, 1))) + self.camera_pos
        aug_projection = torch.matmul(
            vs, self.p_mat.repeat((batchsize, 1, 1)).permute((0, 2, 1)))
        face_projection = aug_projection[:, :, :2] / torch.reshape(aug_projection[:, :, 2], [batchsize, -1, 1])
        return face_projection
class MeshRendererWithDepth(MeshRenderer):
    def __init__(self, rasterizer, shader):
        super().__init__(rasterizer, shader)

    def forward(self, meshes_world, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        # Returns the shaded images together with the rasterizer's z-buffer.
        fragments = self.rasterizer(meshes_world, **kwargs)
        images = self.shader(fragments, meshes_world, **kwargs)
        return images, fragments.zbuf
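
# Minimal usage sketch. The model_dict keys are exactly those read in
# FaceVerseModel.__init__, but the checkpoint path below is illustrative, not a
# file shipped with this module:
#
#   model_dict = np.load('faceverse_model.npy', allow_pickle=True).item()
#   model = FaceVerseModel(model_dict, batch_size=1, device='cuda:0', img_size=512)
#   out = model(model.get_packed_tensors(), render=True)
#   print(out['rendered_img'].shape, out['lms_proj'].shape)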