# QZFantasies's picture
# add wheels
# c614b0f
# Multi-HMR
# Copyright (c) 2024-present NAVER Corp.
# CC BY-NC-SA 4.0 license
import torch
from torch import nn
from torch import nn
import smplx
import torch
import numpy as np
import pose_utils
from pose_utils import inverse_perspective_projection, perspective_projection
import roma
import pickle
import os
from pose_utils.constants_service import SMPLX_DIR
from pose_utils.rot6d import rotation_6d_to_matrix
from smplx.lbs import vertices2joints
class SMPL_Layer(nn.Module):
    """
    Extension of the SMPL-X layer with camera information for (inverse)
    projection between 3D space and the camera plane.
    """

    def __init__(
        self,
        smpl_dir,
        type="smplx",  # NOTE: shadows the builtin `type`; name kept for caller compatibility
        gender="neutral",
        num_betas=10,
        kid=False,
        person_center=None,
        *args,
        **kwargs,
    ):
        """
        Args:
            - smpl_dir: directory containing the SMPL-X model files
            - type: body-model family; only "smplx" is supported (asserted)
            - gender: body-model gender passed to smplx.create
            - num_betas: number of shape components
            - kid: stored but not used in this class — presumably consumed by callers; TODO confirm
            - person_center: optional joint name used as the "person center";
              its index is looked up in the SMPL-X joint-name list
        """
        super().__init__()
        # Args
        assert type == "smplx"
        self.type = type
        self.kid = kid
        self.num_betas = num_betas
        # Underlying parametric body model (full axis-angle hands, flat hand mean).
        self.bm_x = smplx.create(
            smpl_dir, "smplx", gender=gender, use_pca=False, flat_hand_mean=True, num_betas=num_betas
        )
        # Primary keypoint - root.
        # getattr instead of eval: identical lookup of pose_utils.get_smplx_joint_names
        # without evaluating an arbitrary code string.
        self.joint_names = getattr(pose_utils, f"get_{self.type}_joint_names")()
        self.person_center = person_center
        self.person_center_idx = None
        if self.person_center is not None:
            self.person_center_idx = self.joint_names.index(self.person_center)

    def forward(
        self,
        pose,
        shape,
        loc,
        dist,
        transl,
        K,
        expression=None,  # facial expression
        rot6d=False,
        j_regressor=None,
    ):
        """
        Pose the body model and project it into the camera plane.

        Args:
            - pose: pose in axis-angle (or 6D when rot6d=True) - torch.Tensor [bs,53,3|6]
              laid out as [root_orient, body(21), lhand(15), rhand(15), jaw]
            - shape: betas - torch.Tensor [bs,num_betas] (or [bs,num_betas+1])
            - loc: 2D location of the primary keypoint in pixel space - torch.Tensor [bs,2]
            - dist: distance of the primary keypoint from the camera in m - torch.Tensor [bs,1]
            - transl: optional precomputed 3D translation; if None it is recovered
              from (loc, dist, K) via inverse perspective projection
            - K: camera intrinsics - torch.Tensor [bs,3,3]
            - expression: optional facial expression - torch.Tensor [bs,10]
            - rot6d: whether pose uses the 6D rotation parametrization
            - j_regressor: optional regressor mapping posed vertices to joints (smplify)
        Return:
            - dict containing a bunch of useful information about each person
              (3D/2D vertices and joints in camera space, translations)
        """
        if loc is not None and dist is not None:
            assert pose.shape[0] == shape.shape[0] == loc.shape[0] == dist.shape[0]

        POSE_TYPE_LENGTH = 6 if rot6d else 3
        if self.type == "smpl":
            assert len(pose.shape) == 3 and list(pose.shape[1:]) == [24, POSE_TYPE_LENGTH]
        elif self.type == "smplx":
            # taking root_orient, body_pose, lhand, rhand and jaw for the moment
            assert len(pose.shape) == 3 and list(pose.shape[1:]) == [
                53,
                POSE_TYPE_LENGTH,
            ]
        else:
            raise NameError(f"unknown body-model type: {self.type!r}")
        assert len(shape.shape) == 2 and (
            list(shape.shape[1:]) == [self.num_betas] or list(shape.shape[1:]) == [self.num_betas + 1]
        )
        if loc is not None and dist is not None:
            assert len(loc.shape) == 2 and list(loc.shape[1:]) == [2]
            assert len(dist.shape) == 2 and list(dist.shape[1:]) == [1]

        bs = pose.shape[0]
        out = {}

        # No humans detected: nothing to compute.
        if bs == 0:
            return {}

        # Low dimensional parameters.
        kwargs_pose = {
            "betas": shape,
        }
        # Global orientation is applied manually after the SMPL-X forward below,
        # so the layer runs with its default (identity) global_orient here.
        kwargs_pose["global_orient"] = self.bm_x.global_orient.repeat(bs, 1)
        kwargs_pose["body_pose"] = pose[:, 1:22].flatten(1)
        kwargs_pose["left_hand_pose"] = pose[:, 22:37].flatten(1)
        kwargs_pose["right_hand_pose"] = pose[:, 37:52].flatten(1)
        kwargs_pose["jaw_pose"] = pose[:, 52:53].flatten(1)
        if expression is not None:
            kwargs_pose["expression"] = expression.flatten(1)  # [bs,10]
        else:
            kwargs_pose["expression"] = self.bm_x.expression.repeat(bs, 1)
        # default - to be generalized
        kwargs_pose["leye_pose"] = self.bm_x.leye_pose.repeat(bs, 1)
        kwargs_pose["reye_pose"] = self.bm_x.reye_pose.repeat(bs, 1)

        # Forward using the parametric 3d model SMPL-X layer.
        output = self.bm_x(pose2rot=not rot6d, **kwargs_pose)
        verts = output.vertices
        j3d = output.joints  # 45 joints

        # Global orientation as a rotation matrix.
        if rot6d:
            R = rotation_6d_to_matrix(pose[:, 0])
        else:
            R = roma.rotvec_to_rotmat(pose[:, 0])

        # Apply global orientation around the pelvis on joints and vertices.
        pelvis = j3d[:, [0]]
        j3d = (R.unsqueeze(1) @ (j3d - pelvis).unsqueeze(-1)).squeeze(-1)
        verts = (R.unsqueeze(1) @ (verts - pelvis).unsqueeze(-1)).squeeze(-1)

        # Location of the person in 3D.
        if transl is None:
            if K.dtype == torch.float16:
                # because of torch.inverse - not working with float16 at the moment
                transl = inverse_perspective_projection(
                    loc.unsqueeze(1).float(), K.float(), dist.unsqueeze(1).float()
                )[:, 0]
                transl = transl.half()
            else:
                transl = inverse_perspective_projection(loc.unsqueeze(1), K, dist.unsqueeze(1))[:, 0]

        # Updating transl if we choose a certain person center.
        transl_up = transl.clone()
        # BUGFIX: person_center was previously assigned only in the else branch,
        # so reading it for 'transl_pelvis' below raised a NameError whenever
        # person_center_idx is None. Default to a zero offset in that case
        # (no person-center shift); the else branch overwrites it as before.
        person_center = torch.zeros_like(pelvis)
        # Definition of the translation depends on the args:
        # 1) vanilla SMPL - 2) computed from a given joint.
        if self.person_center_idx is None:
            # Add pelvis to transl - standard way for the SMPL-X layer.
            transl_up = transl_up + pelvis[:, 0]
        else:
            # Center around the joint because the translation is computed from this joint.
            person_center = j3d[:, [self.person_center_idx]]
            verts = verts - person_center
            j3d = j3d - person_center

        # Moving into the camera coordinate system.
        j3d_cam = j3d + transl_up.unsqueeze(1)
        verts_cam = verts + transl_up.unsqueeze(1)

        # Projection in camera plane.
        if j_regressor is not None:
            # for smplify
            j3d_cam = vertices2joints(j_regressor, verts_cam)
        j2d = perspective_projection(j3d_cam, K)
        v2d = perspective_projection(verts_cam, K)

        out.update(
            {
                "v3d": verts_cam,  # in 3d camera space
                "j3d": j3d_cam,  # in 3d camera space
                "j2d": j2d,
                "v2d": v2d,
                "transl": transl,  # translation of the primary keypoint
                "transl_pelvis": transl.unsqueeze(1) - person_center - pelvis,  # root=pelvis
                "j3d_world": output.joints,
            }
        )
        return out

    def forward_local(self, pose, shape):
        """
        Run the SMPL-X layer in its local frame (no camera handling, no manual
        global-orientation post-processing).

        Args:
            - pose: torch.Tensor [N,J,3] with either
              J==53: [root, body(21), lhand(15), rhand(15), jaw] or
              J==55: [root, body(21), jaw, eyes?(2), lhand(15), rhand(15)]
              (the 55-layout indices follow the slicing below; eye slots at
              23-24 are ignored here — TODO confirm against callers)
            - shape: betas - torch.Tensor [N,num_betas]
        Return:
            - raw smplx layer output, or None when the batch is empty
        Raises:
            - ValueError: if J is neither 53 nor 55
        """
        N, J, L = pose.shape
        if N < 1:
            return None
        kwargs_pose = {
            "betas": shape,
        }
        if J == 53:
            kwargs_pose["global_orient"] = self.bm_x.global_orient.repeat(N, 1)
            kwargs_pose["body_pose"] = pose[:, 1:22].flatten(1)
            kwargs_pose["left_hand_pose"] = pose[:, 22:37].flatten(1)
            kwargs_pose["right_hand_pose"] = pose[:, 37:52].flatten(1)
            kwargs_pose["jaw_pose"] = pose[:, 52:53].flatten(1)
        elif J == 55:
            kwargs_pose["global_orient"] = self.bm_x.global_orient.repeat(N, 1)
            kwargs_pose["body_pose"] = pose[:, 1:22].flatten(1)
            kwargs_pose["left_hand_pose"] = pose[:, 25:40].flatten(1)
            kwargs_pose["right_hand_pose"] = pose[:, 40:55].flatten(1)
            kwargs_pose["jaw_pose"] = pose[:, 22:23].flatten(1)
        else:
            raise ValueError(f"pose dim error, should be 53 or 55, but got {J}")
        kwargs_pose["expression"] = self.bm_x.expression.repeat(N, 1)
        # default - to be generalized
        kwargs_pose["leye_pose"] = self.bm_x.leye_pose.repeat(N, 1)
        kwargs_pose["reye_pose"] = self.bm_x.reye_pose.repeat(N, 1)
        output = self.bm_x(**kwargs_pose)
        return output

    def convert_standard_pose(self, poses):
        """
        Convert a 53-joint pose layout to the standard 55-joint SMPL-X layout
        by moving the jaw after the body and inserting the model's default
        eye poses.

        Args:
            - poses: torch.Tensor [N,53,3]
              ([root+body(22), lhand+rhand(30), jaw(1)])
        Return:
            - torch.Tensor [N,55,3]
              ([root+body(22), jaw(1), leye(1), reye(1), lhand+rhand(30)])
        """
        n = poses.shape[0]
        poses = torch.cat(
            [
                poses[:, :22],
                poses[:, 52:53],
                self.bm_x.leye_pose.repeat(n, 1, 1),
                self.bm_x.reye_pose.repeat(n, 1, 1),
                poses[:, 22:52],
            ],
            dim=1,
        )
        return poses