from __future__ import absolute_import, division, print_function |
from typing import List, Optional, Tuple |
import numpy as np |
import torch |
import torch.nn.functional as F |
from .utils import Tensor, rot_mat_to_euler |
def find_dynamic_lmk_idx_and_bcoords(
    vertices: Tensor,
    pose: Tensor,
    dynamic_lmk_faces_idx: Tensor,
    dynamic_lmk_b_coords: Tensor,
    neck_kin_chain: List[int],
    pose2rot: bool = True,
) -> Tuple[Tensor, Tensor]:
    """Compute the faces, barycentric coordinates for the dynamic landmarks

    To do so, we first compute the rotation of the neck around the y-axis
    and then use a pre-computed look-up table to find the faces and the
    barycentric coordinates that will be used.

    Special thanks to Soubhik Sanyal ([email protected])
    for providing the original TensorFlow implementation and for the LUT.

    Parameters
    ----------
    vertices: torch.tensor BxVx3
        The tensor of input vertices. Only its batch size, dtype and device
        are used here.
    pose: torch.tensor
        The current pose of the body model: Bx(J*3) axis-angle values when
        ``pose2rot`` is True, otherwise B x J x 3 x 3 rotation matrices (any
        layout that views to that shape).
    dynamic_lmk_faces_idx: torch.tensor, dtype = torch.long
        The look-up table from neck rotation to faces. Indexed along dim 0
        with values in [0, 78], so it needs at least 79 rows.
    dynamic_lmk_b_coords: torch.tensor, dtype = torch.float32
        The look-up table from neck rotation to barycentric coordinates,
        indexed the same way as ``dynamic_lmk_faces_idx``.
    neck_kin_chain: list of int or torch.tensor of long
        The indices of the joints that form the kinematic chain of the neck.
    pose2rot: bool, optional
        Whether ``pose`` holds axis-angle vectors that must be converted to
        rotation matrices (True, default) or already holds rotation matrices.

    Returns
    -------
    dyn_lmk_faces_idx: torch.tensor, dtype = torch.long
        The LUT rows of ``dynamic_lmk_faces_idx`` selected for each batch
        element: the faces used to compute the current dynamic landmarks.
    dyn_lmk_b_coords: torch.tensor, dtype = torch.float32
        The LUT rows of ``dynamic_lmk_b_coords`` selected for each batch
        element: the barycentric coordinates of the dynamic landmarks.
    """
    dtype = vertices.dtype
    batch_size = vertices.shape[0]

    # torch.index_select requires an index *tensor*; accept plain lists too.
    neck_idxs = torch.as_tensor(neck_kin_chain, dtype=torch.long, device=pose.device)

    if pose2rot:
        aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, neck_idxs)
        rot_mats = batch_rodrigues(aa_pose.view(-1, 3)).view(batch_size, -1, 3, 3)
    else:
        rot_mats = torch.index_select(pose.view(batch_size, -1, 3, 3), 1, neck_idxs)

    # Compose the neck chain: rel = R[n-1] @ ... @ R[1] @ R[0].
    rel_rot_mat = (
        torch.eye(3, device=vertices.device, dtype=dtype)
        .unsqueeze_(dim=0)
        .repeat(batch_size, 1, 1)
    )
    for idx in range(neck_idxs.shape[0]):
        rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)

    # rot_mat_to_euler comes from .utils; its output is presumably the
    # y-axis angle in radians (converted to whole degrees here) - confirm.
    y_rot_angle = torch.round(
        torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, max=39)
    ).to(dtype=torch.long)

    # Map the angle to a LUT row: [0, 39] -> itself, [-39, -1] -> 78..40
    # (via 39 - angle), and anything below -39 saturates at row 78.
    neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
    mask = y_rot_angle.lt(-39).to(dtype=torch.long)
    neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
    y_rot_angle = neg_mask * neg_vals + (1 - neg_mask) * y_rot_angle

    dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0, y_rot_angle)
    dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 0, y_rot_angle)
    return dyn_lmk_faces_idx, dyn_lmk_b_coords
def vertices2landmarks(
    vertices: Tensor, faces: Tensor, lmk_faces_idx: Tensor, lmk_bary_coords: Tensor
) -> Tensor:
    """Calculates landmarks by barycentric interpolation

    Parameters
    ----------
    vertices: torch.tensor BxVx3, dtype = torch.float32
        The tensor of input vertices
    faces: torch.tensor Fx3, dtype = torch.long
        The faces of the mesh
    lmk_faces_idx: torch.tensor L or BxL, dtype = torch.long
        The tensor with the indices of the faces used to calculate the
        landmarks. An unbatched (L) tensor is shared by every mesh in the
        batch.
    lmk_bary_coords: torch.tensor Lx3 or BxLx3, dtype = torch.float32
        The tensor of barycentric coordinates that are used to interpolate
        the landmarks. An unbatched (Lx3) tensor is shared by every mesh.

    Returns
    -------
    landmarks: torch.tensor BxLx3, dtype = torch.float32
        The coordinates of the landmarks for each mesh in the batch
    """
    batch_size, num_verts = vertices.shape[:2]
    device = vertices.device

    # Broadcast unbatched look-ups over the batch so both layouts work.
    if lmk_faces_idx.dim() == 1:
        lmk_faces_idx = lmk_faces_idx.unsqueeze(0).expand(batch_size, -1)
    if lmk_bary_coords.dim() == 2:
        lmk_bary_coords = lmk_bary_coords.unsqueeze(0).expand(batch_size, -1, -1)

    # reshape (not view): the expanded index tensor is not contiguous.
    lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.reshape(-1)).view(
        batch_size, -1, 3
    )
    # Offset vertex indices per batch element so they address the flattened
    # (B*V)x3 vertex array below.
    lmk_faces = lmk_faces + (
        torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1)
        * num_verts
    )

    lmk_vertices = vertices.reshape(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3)
    # Weighted sum of the three face corners for every landmark.
    return torch.einsum("blfi,blf->bli", lmk_vertices, lmk_bary_coords)
def lbs(
    betas: Tensor,
    pose: Tensor,
    v_template: Tensor,
    shapedirs: Tensor,
    posedirs: Tensor,
    J_regressor: Tensor,
    parents: Tensor,
    lbs_weights: Tensor,
    pose2rot: bool = True,
    return_transformation: bool = False,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
    """Performs Linear Blend Skinning with the given shape and pose parameters

    Parameters
    ----------
    betas : torch.tensor BxNB
        The tensor of shape parameters. A batch of 1 is shared by every pose.
    pose : torch.tensor Bx(J * 3)
        The pose parameters in axis-angle format. A batch of 1 is shared by
        every shape.
    v_template : torch.tensor VxV3 / 1xVx3 / BxVx3
        The template mesh that will be deformed (must broadcast against the
        BxVx3 shape blend-shape output).
    shapedirs : torch.tensor Vx3xNB
        The tensor of PCA shape displacements
    posedirs : torch.tensor Px(V * 3), P = (J - 1) * 9
        The pose PCA coefficients, driven by every joint except the root
    J_regressor : torch.tensor JxV
        The regressor array that is used to calculate the joints from
        the position of the vertices
    parents: torch.tensor J
        The array that describes the kinematic tree for the model
    lbs_weights: torch.tensor VxJ
        The linear blend skinning weights that represent how much the
        rotation matrix of each part affects each vertex
    pose2rot: bool, optional
        Flag on whether to convert the input pose tensor to rotation
        matrices. The default value is True. If False, then the pose tensor
        should already contain rotation matrices (any layout that reshapes
        to BxJx3x3, e.g. BxJx9).
    return_transformation: bool, optional
        If True, additionally return the per-joint and per-vertex transforms.

    Returns
    -------
    verts: torch.tensor BxVx3
        The vertices of the mesh after applying the shape and pose
        displacements.
    joints: torch.tensor BxJx3
        The joints of the model
    A: torch.tensor BxJx4x4 (only if ``return_transformation``)
        The per-joint transforms returned by ``batch_rigid_transform``.
    T: torch.tensor BxVx4x4 (only if ``return_transformation``)
        The per-vertex transforms blended with ``lbs_weights``.
    """
    batch_size = max(betas.shape[0], pose.shape[0])
    device, dtype = betas.device, betas.dtype

    # Shape blend shapes, then regress the rest-pose joints from that mesh.
    v_shaped = v_template + blend_shapes(betas, shapedirs)
    # Broadcast a single shape over the batch so the joints line up with
    # rot_mats inside batch_rigid_transform (expand is a no-op otherwise).
    J = vertices2joints(J_regressor, v_shaped).expand(batch_size, -1, -1)

    if pose2rot:
        rot_mats = batch_rodrigues(pose.reshape(-1, 3)).view(pose.shape[0], -1, 3, 3)
    else:
        rot_mats = pose.reshape(pose.shape[0], -1, 3, 3)
    # Broadcast a single pose over the batch of shapes.
    rot_mats = rot_mats.expand(batch_size, -1, -1, -1)

    # Pose-corrective offsets come from every joint except the root; derive
    # them from rot_mats so flat and per-joint pose layouts both work.
    ident = torch.eye(3, dtype=dtype, device=device)
    pose_feature = (rot_mats[:, 1:] - ident).reshape(batch_size, -1)
    pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3)
    v_posed = pose_offsets + v_shaped

    J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)

    # Blend the per-joint 4x4 transforms into one transform per vertex.
    W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
    num_joints = J_regressor.shape[0]
    T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)

    homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], dtype=dtype, device=device)
    v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
    v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
    verts = v_homo[:, :, :3, 0]

    if return_transformation:
        return verts, J_transformed, A, T
    return verts, J_transformed
def general_lbs(
    pose: Tensor,
    v_template: Tensor,
    posedirs: Tensor,
    J_regressor: Tensor,
    parents: Tensor,
    lbs_weights: Tensor,
    pose2rot: bool = True,
) -> Tuple[Tensor, Tensor]:
    """Performs Linear Blend Skinning of a fixed template with the given pose

    Unlike ``lbs`` there are no shape parameters: ``v_template`` is skinned
    directly.

    Parameters
    ----------
    pose : torch.tensor Bx(J * 3)
        The pose parameters in axis-angle format
    v_template : torch.tensor 1xVx3 or BxVx3
        The template mesh that will be deformed. A batch of 1 is shared by
        every pose.
    posedirs : torch.tensor Px(V * 3), P = (J - 1) * 9
        The pose PCA coefficients, driven by every joint except the root
    J_regressor : torch.tensor JxV
        The regressor array that is used to calculate the joints from
        the position of the vertices
    parents: torch.tensor J
        The array that describes the kinematic tree for the model
    lbs_weights: torch.tensor VxJ
        The linear blend skinning weights that represent how much the
        rotation matrix of each part affects each vertex
    pose2rot: bool, optional
        Flag on whether to convert the input pose tensor to rotation
        matrices. The default value is True. If False, then the pose tensor
        should already contain rotation matrices (any layout that reshapes
        to BxJx3x3, e.g. BxJx9).

    Returns
    -------
    verts: torch.tensor BxVx3
        The vertices of the mesh after applying the pose displacements.
    joints: torch.tensor BxJx3
        The joints of the model
    """
    batch_size = pose.shape[0]
    device, dtype = pose.device, pose.dtype

    # Broadcast a single template's joints to the pose batch so they line up
    # with rot_mats inside batch_rigid_transform.
    J = vertices2joints(J_regressor, v_template).expand(batch_size, -1, -1)

    if pose2rot:
        rot_mats = batch_rodrigues(pose.reshape(-1, 3)).view(batch_size, -1, 3, 3)
    else:
        rot_mats = pose.reshape(batch_size, -1, 3, 3)

    # Pose-corrective offsets from every joint except the root; derived from
    # rot_mats so flat and per-joint pose layouts both work.
    ident = torch.eye(3, dtype=dtype, device=device)
    pose_feature = (rot_mats[:, 1:] - ident).reshape(batch_size, -1)
    pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3)
    v_posed = pose_offsets + v_template

    J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)

    # Blend the per-joint 4x4 transforms into one transform per vertex.
    W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
    num_joints = J_regressor.shape[0]
    T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)

    homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], dtype=dtype, device=device)
    v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
    v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
    verts = v_homo[:, :, :3, 0]
    return verts, J_transformed
def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:
    """Regress 3D joint locations as linear combinations of mesh vertices.

    Parameters
    ----------
    J_regressor : torch.tensor JxV
        Per-joint weights over the mesh vertices.
    vertices : torch.tensor BxVx3
        The batch of mesh vertices.

    Returns
    -------
    torch.tensor BxJx3
        Joint positions: joints[b, j] = sum_v J_regressor[j, v] * vertices[b, v].
    """
    joints = torch.einsum("jv,bvc->bjc", J_regressor, vertices)
    return joints
def blend_shapes(betas: Tensor, shape_disps: Tensor) -> Tensor:
    """Compute per-vertex displacements as a weighted sum of blend shapes.

    Parameters
    ----------
    betas : torch.tensor Bx(num_betas)
        Weight of each blend shape, per batch element.
    shape_disps : torch.tensor Vx3x(num_betas)
        The blend-shape displacement basis.

    Returns
    -------
    torch.tensor BxVx3
        out[b, v, c] = sum_l betas[b, l] * shape_disps[v, c, l].
    """
    return torch.einsum("vcl,bl->bvc", shape_disps, betas)
def batch_rodrigues(
    rot_vecs: Tensor,
    epsilon: float = 1e-8,
) -> Tensor:
    """Calculates the rotation matrices for a batch of rotation vectors

    Uses Rodrigues' formula R = I + sin(a) K + (1 - cos(a)) K @ K, where a is
    the rotation angle and K the skew-symmetric matrix of the unit axis.

    Parameters
    ----------
    rot_vecs: torch.tensor Nx3
        array of N axis-angle vectors
    epsilon: float, optional
        Added to every component of ``rot_vecs`` before taking the norm so a
        zero vector does not cause a division by zero (default 1e-8).

    Returns
    -------
    R: torch.tensor Nx3x3
        The rotation matrices for the given axis-angle parameters
    """
    batch_size = rot_vecs.shape[0]
    device, dtype = rot_vecs.device, rot_vecs.dtype

    angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)
    rot_dir = rot_vecs / angle

    # Nx1x1 so they broadcast against the Nx3x3 matrices below.
    cos = torch.unsqueeze(torch.cos(angle), dim=1)
    sin = torch.unsqueeze(torch.sin(angle), dim=1)

    # Skew-symmetric cross-product matrix of the rotation axis.
    rx, ry, rz = torch.split(rot_dir, 1, dim=1)
    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(
        (batch_size, 3, 3)
    )

    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    return ident + sin * K + (1 - cos) * torch.bmm(K, K)
def transform_mat(R: Tensor, t: Tensor) -> Tensor:
    """Assemble homogeneous 4x4 transforms from rotations and translations.

    Args:
        - R: Bx3x3 batch of rotation matrices
        - t: Bx3x1 batch of translation vectors
    Returns:
        - T: Bx4x4 matrices [[R, t], [0, 0, 0, 1]]
    """
    # Bx4x3: rotation with an all-zero row appended underneath.
    rotation_part = F.pad(R, [0, 0, 0, 1])
    # Bx4x1: translation with a trailing 1 appended underneath.
    translation_part = F.pad(t, [0, 0, 0, 1], value=1)
    return torch.cat([rotation_part, translation_part], dim=2)
def batch_rigid_transform(
    rot_mats: Tensor, joints: Tensor, parents: Tensor, dtype=torch.float32
) -> Tuple[Tensor, Tensor]:
    """
    Applies a batch of rigid transformations to the joints

    Parameters
    ----------
    rot_mats : torch.tensor BxNx3x3
        Tensor of rotation matrices
    joints : torch.tensor BxNx3
        Locations of joints
    parents : torch.tensor N
        The kinematic tree shared by the batch: ``parents[i]`` is the parent
        of joint i. Every parent must precede its child (parents[i] < i),
        because the chain is built in index order. ``parents[0]`` is ignored.
    dtype : torch.dtype, optional
        Unused; kept for backward compatibility.

    Returns
    -------
    posed_joints : torch.tensor BxNx3
        The locations of the joints after applying the pose rotations
    rel_transforms : torch.tensor BxNx4x4
        For each joint, its accumulated transform [R | t] with R @ rest_joint
        subtracted from the translation, so it maps rest-pose coordinates
        x to R (x - rest_joint) + t.
    """
    joints = torch.unsqueeze(joints, dim=-1)

    # Express each joint relative to its parent; the root stays absolute.
    rel_joints = joints.clone()
    rel_joints[:, 1:] -= joints[:, parents[1:]]

    transforms_mat = transform_mat(
        rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)
    ).reshape(-1, joints.shape[1], 4, 4)

    # Compose local transforms down the tree (relies on parents[i] < i).
    transform_chain = [transforms_mat[:, 0]]
    for i in range(1, parents.shape[0]):
        curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, i])
        transform_chain.append(curr_res)

    transforms = torch.stack(transform_chain, dim=1)

    # The translation column of each world transform is the posed joint.
    posed_joints = transforms[:, :, :3, 3]

    # Homogeneous coordinate 0: the product yields only R @ rest_joint, which
    # is then padded into the translation column and subtracted.
    joints_homogen = F.pad(joints, [0, 0, 0, 1])
    rel_transforms = transforms - F.pad(
        torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]
    )
    return posed_joints, rel_transforms