import logging |
import os |
import os.path as osp |
import pickle |
from collections import namedtuple |
from typing import Dict, Optional, Union |
import numpy as np |
import torch |
import torch.nn as nn |
logging.getLogger("smplx").setLevel(logging.ERROR) |
from .lbs import find_dynamic_lmk_idx_and_bcoords, lbs, vertices2landmarks |
from .utils import ( |
Array, |
FLAMEOutput, |
MANOOutput, |
SMPLHOutput, |
SMPLOutput, |
SMPLXOutput, |
Struct, |
Tensor, |
find_joint_kin_chain, |
to_np, |
to_tensor, |
) |
from .vertex_ids import vertex_ids as VERTEX_IDS |
from .vertex_joint_selector import VertexJointSelector |
# Generic container for the outputs of the body models; every field is optional.
ModelOutput = namedtuple(
    "ModelOutput",
    "vertices joints full_pose betas global_orient body_pose "
    "expression left_hand_pose right_hand_pose jaw_pose",
)
# Every field defaults to None so partial outputs can be built by keyword.
ModelOutput.__new__.__defaults__ = tuple(None for _ in ModelOutput._fields)
class SMPL(nn.Module):
    """SMPL body model as a ``torch.nn.Module``.

    The model data (template mesh, shape and pose blend-shape bases, joint
    regressor, kinematic tree, skinning weights) is stored in buffers; the
    optional pose / shape / translation variables are learnable parameters.
    """

    # Non-root joints of the kinematic tree; the root rotation is stored
    # separately in ``global_orient``.
    NUM_JOINTS = 23
    # Joints whose axis-angle rotations form ``body_pose``
    # (``NUM_BODY_JOINTS * 3`` values per sample).
    NUM_BODY_JOINTS = 23
    # Size of the full shape space. Model files with fewer shape directions
    # than this are limited to 10 betas in ``__init__``.
    SHAPE_SPACE_DIM = 300

    def __init__(
        self,
        model_path: str,
        kid_template_path: str = "",
        data_struct: Optional[Struct] = None,
        create_betas: bool = True,
        betas: Optional[Tensor] = None,
        num_betas: int = 10,
        create_global_orient: bool = True,
        global_orient: Optional[Tensor] = None,
        create_body_pose: bool = True,
        body_pose: Optional[Tensor] = None,
        create_transl: bool = True,
        transl: Optional[Tensor] = None,
        dtype=torch.float32,
        batch_size: int = 1,
        joint_mapper=None,
        gender: str = "neutral",
        age: str = "adult",
        vertex_ids: Optional[Dict[str, int]] = None,
        v_template: Optional[Union[Tensor, Array]] = None,
        v_personal: Optional[Union[Tensor, Array]] = None,
        **kwargs,
    ) -> None:
        """SMPL model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored. For a folder, ``SMPL_<GENDER>.pkl`` is
            loaded from it.
        kid_template_path: str, optional
            Path of a ``.npy`` vertex template that is only read when
            ``age == "kid"``. Its offset (after centering) from the model
            template is appended to the shape space as one extra component.
        data_struct: Struct, optional
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_betas: bool, optional
            Flag for creating a member variable for the shape space
            (default = True).
        betas: torch.tensor, optional, BxN_b
            The default value for the shape member variable.
            (default = None, i.e. zeros)
        num_betas: int, optional
            Number of shape components to use, clamped to what the model
            file provides (default = 10).
        create_global_orient: bool, optional
            Flag for creating a member variable for the global orientation
            of the body. (default = True)
        global_orient: torch.tensor, optional, Bx3
            The default value for the global orientation variable.
            (default = None, i.e. zeros)
        create_body_pose: bool, optional
            Flag for creating a member variable for the pose of the body.
            (default = True)
        body_pose: torch.tensor, optional, Bx(Body Joints * 3)
            The default value for the body pose variable.
            (default = None, i.e. zeros)
        create_transl: bool, optional
            Flag for creating a member variable for the translation
            of the body. (default = True)
        transl: torch.tensor, optional, Bx3
            The default value for the transl variable.
            (default = None, i.e. zeros)
        dtype: torch.dtype, optional
            The data type for the created variables and buffers
        batch_size: int, optional
            The batch size used for creating the member variables
        joint_mapper: object, optional
            An object that re-maps the joints. Useful if one wants to
            re-order the SMPL joints to some other convention (e.g. MSCOCO)
            (default = None)
        gender: str, optional
            Which gender to load
        age: str, optional
            ``"adult"`` (default) or ``"kid"``.
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected (default = the SMPL+H vertex ids)
        v_template: torch.tensor or array, optional
            Replaces the template vertices stored in the model file.
        v_personal: torch.tensor or array, optional
            Per-vertex offsets added to the template.
        **kwargs
            Forwarded to ``VertexJointSelector``.
        """
        self.gender = gender
        self.age = age
        if data_struct is None:
            if osp.isdir(model_path):
                model_fn = "SMPL_{}.{ext}".format(gender.upper(), ext="pkl")
                smpl_path = os.path.join(model_path, model_fn)
            else:
                smpl_path = model_path
            assert osp.exists(smpl_path), "Path {} does not exist!".format(smpl_path)
            # NOTE: unpickling can execute arbitrary code; only load trusted model files.
            with open(smpl_path, "rb") as smpl_file:
                data_struct = Struct(**pickle.load(smpl_file, encoding="latin1"))
        super(SMPL, self).__init__()
        self.batch_size = batch_size

        shapedirs = data_struct.shapedirs
        # Model files with fewer than SHAPE_SPACE_DIM shape directions are
        # used with at most 10 betas.
        if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM:
            num_betas = min(num_betas, 10)
        else:
            num_betas = min(num_betas, self.SHAPE_SPACE_DIM)
        if self.age == "kid":
            # Append the (centered) kid-template offset as one more shape direction.
            v_template_smil = np.load(kid_template_path)
            v_template_smil -= np.mean(v_template_smil, axis=0)
            v_template_diff = np.expand_dims(v_template_smil - data_struct.v_template, axis=2)
            shapedirs = np.concatenate((shapedirs[:, :, :num_betas], v_template_diff), axis=2)
            num_betas = num_betas + 1
        self._num_betas = num_betas
        shapedirs = shapedirs[:, :, :num_betas]
        self.register_buffer("shapedirs", to_tensor(to_np(shapedirs), dtype=dtype))

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS["smplh"]
        self.dtype = dtype
        self.joint_mapper = joint_mapper
        self.vertex_joint_selector = VertexJointSelector(vertex_ids=vertex_ids, **kwargs)
        self.faces = data_struct.f
        self.register_buffer(
            "faces_tensor",
            to_tensor(to_np(self.faces, dtype=np.int64), dtype=torch.long),
        )

        if create_betas:
            if betas is None:
                default_betas = torch.zeros([batch_size, self.num_betas], dtype=dtype)
            else:
                if torch.is_tensor(betas):
                    default_betas = betas.clone().detach()
                else:
                    default_betas = torch.tensor(betas, dtype=dtype)
            self.register_parameter("betas", nn.Parameter(default_betas, requires_grad=True))
        if create_global_orient:
            if global_orient is None:
                default_global_orient = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                if torch.is_tensor(global_orient):
                    default_global_orient = global_orient.clone().detach()
                else:
                    default_global_orient = torch.tensor(global_orient, dtype=dtype)
            global_orient = nn.Parameter(default_global_orient, requires_grad=True)
            self.register_parameter("global_orient", global_orient)
        if create_body_pose:
            if body_pose is None:
                default_body_pose = torch.zeros([batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype)
            else:
                if torch.is_tensor(body_pose):
                    default_body_pose = body_pose.clone().detach()
                else:
                    default_body_pose = torch.tensor(body_pose, dtype=dtype)
            self.register_parameter(
                "body_pose", nn.Parameter(default_body_pose, requires_grad=True)
            )
        if create_transl:
            if transl is None:
                default_transl = torch.zeros([batch_size, 3], dtype=dtype, requires_grad=True)
            elif torch.is_tensor(transl):
                # Same copy + cast as ``torch.tensor(transl, dtype=dtype)``
                # without the copy-construct warning for tensor inputs.
                default_transl = transl.clone().detach().to(dtype=dtype)
            else:
                default_transl = torch.tensor(transl, dtype=dtype)
            self.register_parameter("transl", nn.Parameter(default_transl, requires_grad=True))

        if v_template is None:
            v_template = data_struct.v_template
        if not torch.is_tensor(v_template):
            v_template = to_tensor(to_np(v_template), dtype=dtype)
        if v_personal is not None:
            v_personal = to_tensor(to_np(v_personal), dtype=dtype)
            # Out-of-place add: ``v_template`` may be the caller's own tensor
            # and must not be modified.
            v_template = v_template + v_personal
        self.register_buffer("v_template", v_template)

        j_regressor = to_tensor(to_np(data_struct.J_regressor), dtype=dtype)
        self.register_buffer("J_regressor", j_regressor)
        # Flatten the pose blend shapes to (num_pose_basis, V * 3).
        num_pose_basis = data_struct.posedirs.shape[-1]
        posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T
        self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=dtype))
        # The first kintree row holds each joint's parent; the root has none.
        parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer("parents", parents)
        self.register_buffer("lbs_weights", to_tensor(to_np(data_struct.weights), dtype=dtype))

    @property
    def num_betas(self):
        """Number of shape coefficients (including the kid component, if any)."""
        return self._num_betas

    @property
    def num_expression_coeffs(self):
        """SMPL has no expression space."""
        return 0

    def create_mean_pose(self, data_struct) -> Tensor:
        """Plain SMPL has no mean pose; returns None."""
        pass

    def name(self) -> str:
        return "SMPL"

    @torch.no_grad()
    def reset_params(self, **params_dict) -> None:
        """Overwrite every parameter in place: the ones named in
        ``params_dict`` with the given values, all others with zeros."""
        for param_name, param in self.named_parameters():
            if param_name in params_dict:
                param[:] = torch.tensor(params_dict[param_name])
            else:
                param.fill_(0)

    def get_num_verts(self) -> int:
        return self.v_template.shape[0]

    def get_num_faces(self) -> int:
        return self.faces.shape[0]

    def extra_repr(self) -> str:
        msg = [
            f"Gender: {self.gender.upper()}",
            f"Number of joints: {self.J_regressor.shape[0]}",
            f"Betas: {self.num_betas}",
        ]
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts=True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLOutput:
        """Forward pass for the SMPL model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3
            If given, ignore the member variable and use it as the global
            rotation of the body. Useful if someone wishes to predicts this
            with an external model. (default=None)
        betas: torch.tensor, optional, shape BxN_b
            If given, ignore the member variable `betas` and use it
            instead. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        body_pose: torch.tensor, optional, shape Bx(J*3)
            If given, ignore the member variable `body_pose` and use it
            instead. For example, it can used if someone predicts the
            pose of the body joints are predicted from some external model.
            It should be a tensor that contains joint rotations in
            axis-angle format. (default=None)
        transl: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `transl` and use it
            instead. For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full axis-angle pose vector (default=False)
        pose2rot: bool, optional
            Forwarded to ``lbs``. (default=True)

        Returns
        -------
        SMPLOutput
            ``vertices`` (None unless ``return_verts``), ``joints``,
            ``betas``, ``global_orient``, ``body_pose`` and ``full_pose``
            (None unless ``return_full_pose``). When a translation is
            available (argument or member variable) it is added to the
            joints and vertices.
        """
        global_orient = (global_orient if global_orient is not None else self.global_orient)
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas
        apply_trans = transl is not None or hasattr(self, "transl")
        if transl is None and hasattr(self, "transl"):
            transl = self.transl
        full_pose = torch.cat([global_orient, body_pose], dim=1)
        batch_size = max(betas.shape[0], global_orient.shape[0], body_pose.shape[0])
        if betas.shape[0] != batch_size:
            # Broadcast a single shared row of betas over the batch; any other
            # mismatch is rejected by ``expand``.
            betas = betas.expand(batch_size, -1)
        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)
        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)
        output = SMPLOutput(
            vertices=vertices if return_verts else None,
            global_orient=global_orient,
            body_pose=body_pose,
            joints=joints,
            betas=betas,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
class SMPLLayer(SMPL):
    """SMPL without learnable member variables.

    Every model input is passed to ``forward`` explicitly, and all rotations
    are given as rotation matrices.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Disable creation of every learnable member variable; all other
        # arguments go to ``SMPL.__init__`` untouched.
        no_members = dict(
            create_body_pose=False,
            create_betas=False,
            create_global_orient=False,
            create_transl=False,
        )
        super(SMPLLayer, self).__init__(*args, **no_members, **kwargs)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts=True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLOutput:
        """Pose and shape the SMPL mesh from explicitly passed inputs.

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
            Shape coefficients; zeros when omitted.
        body_pose: torch.tensor, optional, shape BxJx3x3
            Rotation matrices of the body joints; identities when omitted.
        global_orient: torch.tensor, optional, shape Bx3x3
            Rotation matrix of the root; identity when omitted.
        transl: torch.tensor, optional, shape Bx3
            Translation added to joints and vertices; zeros when omitted.
        return_verts: bool, optional
            Include the vertices in the output. (default=True)
        return_full_pose: bool, optional
            Include the stacked rotation matrices of all joints,
            shape Bx(J+1)x3x3. (default=False)
        pose2rot: bool, optional
            Ignored; the inputs are always treated as rotation matrices.

        The batch size B used for the defaults is the largest leading
        dimension among the given inputs (1 when nothing is given).

        Returns
        -------
        SMPLOutput
        """
        supplied = [v for v in (betas, global_orient, body_pose, transl) if v is not None]
        batch_size = max([1] + [len(v) for v in supplied])
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        def identity_rotations(num_joints):
            # Contiguous (batch_size, num_joints, 3, 3) stack of identity matrices.
            eye = torch.eye(3, device=device, dtype=dtype)
            return eye.view(1, 1, 3, 3).expand(batch_size, num_joints, -1, -1).contiguous()

        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        rotations = [
            global_orient.reshape(-1, 1, 3, 3),
            body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
        ]
        full_pose = torch.cat(rotations, dim=1)
        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # ``transl`` is always a tensor here (zeros by default).
        offset = transl.unsqueeze(dim=1)
        joints += offset
        vertices += offset

        return SMPLOutput(
            vertices=vertices if return_verts else None,
            global_orient=global_orient,
            body_pose=body_pose,
            joints=joints,
            betas=betas,
            full_pose=full_pose if return_full_pose else None,
        )
class SMPLH(SMPL):
    """SMPL+H: the SMPL body with articulated hands.

    Hand poses are either PCA coefficients (``use_pca=True``) that are
    projected onto the hand components, or per-joint axis-angle values.
    A mean pose (``pose_mean``) is added to the full pose in ``forward``.
    """

    # Body joints besides the root (the root rotation is ``global_orient``).
    NUM_BODY_JOINTS = 21
    # Articulated joints per hand.
    NUM_HAND_JOINTS = 15
    NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS

    def __init__(
        self,
        model_path,
        kid_template_path: str = "",
        data_struct: Optional[Struct] = None,
        create_left_hand_pose: bool = True,
        left_hand_pose: Optional[Tensor] = None,
        create_right_hand_pose: bool = True,
        right_hand_pose: Optional[Tensor] = None,
        use_pca: bool = True,
        num_pca_comps: int = 6,
        flat_hand_mean: bool = False,
        batch_size: int = 1,
        gender: str = "neutral",
        age: str = "adult",
        dtype=torch.float32,
        vertex_ids=None,
        use_compressed: bool = True,
        ext: str = "pkl",
        **kwargs,
    ) -> None:
        """SMPLH model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored. For a folder, ``SMPLH_<GENDER>.<ext>`` is
            loaded from it.
        kid_template_path: str, optional
            Forwarded to ``SMPL``; only used when ``age == "kid"``.
        data_struct: Struct, optional
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_left_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the left
            hand. (default = True)
        left_hand_pose: torch.tensor, optional, BxP
            The default value for the left hand pose member variable.
            (default = None, i.e. zeros)
        create_right_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the right
            hand. (default = True)
        right_hand_pose: torch.tensor, optional, BxP
            The default value for the right hand pose member variable.
            (default = None, i.e. zeros)
        use_pca: bool, optional
            If True, hand poses are ``num_pca_comps`` PCA coefficients;
            otherwise ``NUM_HAND_JOINTS * 3`` axis-angle values. (default = True)
        num_pca_comps: int, optional
            The number of PCA components to use for each hand.
            (default = 6)
        flat_hand_mean: bool, optional
            If True, the hand part of the mean pose is zero (flat hands);
            otherwise the hand means stored in the model file are used.
            (default = False)
        batch_size: int, optional
            The batch size used for creating the member variables
        gender: str, optional
            Which gender to load
        age: str, optional
            ``"adult"`` (default) or ``"kid"``; forwarded to ``SMPL``.
        dtype: torch.dtype, optional
            The data type for the created variables
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected (default = the SMPL+H vertex ids)
        use_compressed: bool, optional
            Forwarded to the parent constructor.
        ext: str, optional
            Model file format, ``"pkl"`` (default) or ``"npz"``.

        Raises
        ------
        ValueError
            If the model has to be read from disk and ``ext`` is neither
            ``"pkl"`` nor ``"npz"``.
        """
        self.num_pca_comps = num_pca_comps
        if data_struct is None:
            if osp.isdir(model_path):
                model_fn = "SMPLH_{}.{ext}".format(gender.upper(), ext=ext)
                smplh_path = os.path.join(model_path, model_fn)
            else:
                smplh_path = model_path
            assert osp.exists(smplh_path), "Path {} does not exist!".format(smplh_path)
            if ext == "pkl":
                # NOTE: unpickling can execute arbitrary code; only load trusted model files.
                with open(smplh_path, "rb") as smplh_file:
                    model_data = pickle.load(smplh_file, encoding="latin1")
                data_struct = Struct(**model_data)
            elif ext == "npz":
                # ``**model_data`` reads every array, so the archive can be
                # closed once the Struct is built.
                with np.load(smplh_path, allow_pickle=True) as model_data:
                    data_struct = Struct(**model_data)
            else:
                raise ValueError("Unknown extension: {}".format(ext))
        if vertex_ids is None:
            vertex_ids = VERTEX_IDS["smplh"]
        super(SMPLH, self).__init__(
            model_path=model_path,
            kid_template_path=kid_template_path,
            data_struct=data_struct,
            batch_size=batch_size,
            vertex_ids=vertex_ids,
            gender=gender,
            age=age,
            use_compressed=use_compressed,
            dtype=dtype,
            ext=ext,
            **kwargs,
        )
        self.use_pca = use_pca
        self.num_pca_comps = num_pca_comps
        self.flat_hand_mean = flat_hand_mean

        left_hand_components = data_struct.hands_componentsl[:num_pca_comps]
        right_hand_components = data_struct.hands_componentsr[:num_pca_comps]
        self.np_left_hand_components = left_hand_components
        self.np_right_hand_components = right_hand_components
        if self.use_pca:
            self.register_buffer(
                "left_hand_components", torch.tensor(left_hand_components, dtype=dtype)
            )
            self.register_buffer(
                "right_hand_components",
                torch.tensor(right_hand_components, dtype=dtype),
            )

        if self.flat_hand_mean:
            left_hand_mean = np.zeros_like(data_struct.hands_meanl)
            right_hand_mean = np.zeros_like(data_struct.hands_meanr)
        else:
            left_hand_mean = data_struct.hands_meanl
            right_hand_mean = data_struct.hands_meanr
        self.register_buffer("left_hand_mean", to_tensor(left_hand_mean, dtype=self.dtype))
        self.register_buffer("right_hand_mean", to_tensor(right_hand_mean, dtype=self.dtype))

        hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS
        if create_left_hand_pose:
            if left_hand_pose is None:
                default_lhand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype)
            else:
                default_lhand_pose = torch.tensor(left_hand_pose, dtype=dtype)
            left_hand_pose_param = nn.Parameter(default_lhand_pose, requires_grad=True)
            self.register_parameter("left_hand_pose", left_hand_pose_param)
        if create_right_hand_pose:
            if right_hand_pose is None:
                default_rhand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype)
            else:
                default_rhand_pose = torch.tensor(right_hand_pose, dtype=dtype)
            right_hand_pose_param = nn.Parameter(default_rhand_pose, requires_grad=True)
            self.register_parameter("right_hand_pose", right_hand_pose_param)

        # Subclasses may return an ndarray here; normalise to a tensor.
        pose_mean_tensor = self.create_mean_pose(data_struct, flat_hand_mean=flat_hand_mean)
        if not torch.is_tensor(pose_mean_tensor):
            pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype)
        self.register_buffer("pose_mean", pose_mean_tensor)

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        """Mean pose: zeros for the root and body, then the left and right
        hand means. ``data_struct`` and ``flat_hand_mean`` are unused here;
        the hand means already reflect ``flat_hand_mean``."""
        global_orient_mean = torch.zeros([3], dtype=self.dtype)
        body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], dtype=self.dtype)
        pose_mean = torch.cat(
            [
                global_orient_mean,
                body_pose_mean,
                self.left_hand_mean,
                self.right_hand_mean,
            ],
            dim=0,
        )
        return pose_mean

    def name(self) -> str:
        return "SMPL+H"

    def extra_repr(self):
        msg = super(SMPLH, self).extra_repr()
        msg = [msg]
        if self.use_pca:
            msg.append(f"Number of PCA components: {self.num_pca_comps}")
        msg.append(f"Flat hand mean: {self.flat_hand_mean}")
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLHOutput:
        """Forward pass for the SMPL+H model

        Every input that is None falls back to the member variable of the
        same name. With ``use_pca`` the hand poses are PCA coefficients that
        are projected onto ``left_hand_components`` / ``right_hand_components``
        first. ``pose_mean`` is added to the concatenated full pose before
        skinning.

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
        global_orient: torch.tensor, optional, shape Bx3
        body_pose: torch.tensor, optional, shape Bx(NUM_BODY_JOINTS*3)
        left_hand_pose, right_hand_pose: torch.tensor, optional
            BxP with P = ``num_pca_comps`` when ``use_pca`` is set, otherwise
            Bx(NUM_HAND_JOINTS*3).
        transl: torch.tensor, optional, shape Bx3
            Added to joints and vertices when given or when a ``transl``
            member variable exists.
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the full pose vector, including the mean pose.
            (default=False)
        pose2rot: bool, optional
            Forwarded to ``lbs``. (default=True)

        Returns
        -------
        SMPLHOutput
            The hand poses in the output are the projected ones when
            ``use_pca`` is set; they do not include the mean pose.
        """
        global_orient = (global_orient if global_orient is not None else self.global_orient)
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas
        left_hand_pose = (left_hand_pose if left_hand_pose is not None else self.left_hand_pose)
        right_hand_pose = (right_hand_pose if right_hand_pose is not None else self.right_hand_pose)
        apply_trans = transl is not None or hasattr(self, "transl")
        if transl is None:
            if hasattr(self, "transl"):
                transl = self.transl
        if self.use_pca:
            left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components])
            right_hand_pose = torch.einsum(
                "bi,ij->bj", [right_hand_pose, self.right_hand_components]
            )
        full_pose = torch.cat([global_orient, body_pose, left_hand_pose, right_hand_pose], dim=1)
        full_pose += self.pose_mean
        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)
        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)
        output = SMPLHOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
class SMPLHLayer(SMPLH):
    """SMPL+H without learnable member variables; all rotations are passed
    to ``forward`` as rotation matrices."""

    def __init__(self, *args, **kwargs) -> None:
        """SMPL+H as a layer model constructor"""
        super(SMPLHLayer, self).__init__(
            create_global_orient=False,
            create_body_pose=False,
            create_left_hand_pose=False,
            create_right_hand_pose=False,
            create_betas=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLHOutput:
        """Forward pass for the SMPL+H model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the body, as a rotation matrix.
            Identity when omitted.
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters; zeros when omitted.
        body_pose: torch.tensor, optional, shape BxJx3x3
            Rotation matrices of the body joints (J = NUM_BODY_JOINTS);
            identities when omitted.
        left_hand_pose: torch.tensor, optional, shape BxHx3x3
            Rotation matrices of the left hand joints (H = NUM_HAND_JOINTS);
            identities when omitted.
        right_hand_pose: torch.tensor, optional, shape BxHx3x3
            Rotation matrices of the right hand joints; identities when
            omitted.
        transl: torch.tensor, optional, shape Bx3
            Translation added to joints and vertices; zeros when omitted.
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the stacked rotation matrices of all joints,
            shape Bx(1+J+2H)x3x3. (default=False)
        pose2rot: bool, optional
            Ignored; inputs are always treated as rotation matrices.

        The batch size used for the defaults is the largest leading dimension
        among the given inputs (1 when nothing is given).

        Returns
        -------
        SMPLHOutput
        """
        model_vars = [
            betas,
            global_orient,
            body_pose,
            transl,
            left_hand_pose,
            right_hand_pose,
        ]
        batch_size = 1
        for var in model_vars:
            if var is None:
                continue
            batch_size = max(batch_size, len(var))
        device, dtype = self.shapedirs.device, self.shapedirs.dtype
        # Defaults use the same joint counts as the reshapes below.
        if global_orient is None:
            global_orient = (
                torch.eye(3, device=device,
                          dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
            )
        if body_pose is None:
            body_pose = (
                torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
                    batch_size, self.NUM_BODY_JOINTS, -1, -1
                ).contiguous()
            )
        if left_hand_pose is None:
            left_hand_pose = (
                torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
                    batch_size, self.NUM_HAND_JOINTS, -1, -1
                ).contiguous()
            )
        if right_hand_pose is None:
            right_hand_pose = (
                torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
                    batch_size, self.NUM_HAND_JOINTS, -1, -1
                ).contiguous()
            )
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            ],
            dim=1,
        )
        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)
        # ``transl`` is always a tensor here (zeros by default).
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)
        output = SMPLHOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
class SMPLX(SMPLH): |
""" |
SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters |
trained jointly for the face, hands and body. |
SMPL-X uses standard vertex based linear blend skinning with learned |
corrective blend shapes, has N=10475 vertices and K=54 joints, |
which includes joints for the neck, jaw, eyeballs and fingers. |
""" |
NECK_IDX = 12 |
def __init__(
    self,
    model_path: str,
    kid_template_path: str = "",
    num_expression_coeffs: int = 10,
    create_expression: bool = True,
    expression: Optional[Tensor] = None,
    create_jaw_pose: bool = True,
    jaw_pose: Optional[Tensor] = None,
    create_leye_pose: bool = True,
    leye_pose: Optional[Tensor] = None,
    create_reye_pose: bool = True,
    reye_pose: Optional[Tensor] = None,
    use_face_contour: bool = False,
    batch_size: int = 1,
    gender: str = "neutral",
    age: str = "adult",
    dtype=torch.float32,
    ext: str = "npz",
    **kwargs,
) -> None:
    """SMPLX model constructor

    Parameters
    ----------
    model_path: str
        Hugging Face Hub repository id. ``models/SMPLX_<GENDER>.<ext>`` is
        downloaded from it using the token in the ``ICON`` environment
        variable (a ``KeyError`` is raised when that variable is unset).
    kid_template_path: str, optional
        Forwarded to the parent constructor; only used when
        ``age == "kid"``.
    num_expression_coeffs: int, optional
        Number of expression components to use, clamped to what the model
        file provides. (default = 10)
    create_expression: bool, optional
        Flag for creating a member variable for the expression space
        (default = True).
    expression: torch.tensor, optional, Bxnum_expression_coeffs
        The default value for the expression member variable.
        (default = None, i.e. zeros)
    create_jaw_pose: bool, optional
        Flag for creating a member variable for the jaw pose.
        (default = True)
    jaw_pose: torch.tensor, optional, Bx3
        The default value for the jaw pose variable.
        (default = None, i.e. zeros)
    create_leye_pose: bool, optional
        Flag for creating a member variable for the left eye pose.
        (default = True)
    leye_pose: torch.tensor, optional, Bx3
        The default value for the left eye pose variable.
        (default = None, i.e. zeros)
    create_reye_pose: bool, optional
        Flag for creating a member variable for the right eye pose.
        (default = True)
    reye_pose: torch.tensor, optional, Bx3
        The default value for the right eye pose variable.
        (default = None, i.e. zeros)
    use_face_contour: bool, optional
        Whether to compute the keypoints that form the facial contour;
        registers the dynamic landmark buffers and the neck kinematic chain.
    batch_size: int, optional
        The batch size used for creating the member variables
    gender: str, optional
        Which gender to load
    age: str, optional
        ``"adult"`` (default) or ``"kid"``.
    dtype: torch.dtype
        The data type for the created variables
    ext: str, optional
        Model file format, ``"npz"`` (default) or ``"pkl"``.
    **kwargs
        Forwarded to the parent constructor.

    Raises
    ------
    ValueError
        If ``ext`` is neither ``"pkl"`` nor ``"npz"``.
    """
    from huggingface_hub import hf_hub_download

    # Reject unsupported formats before any network access.
    if ext not in ("pkl", "npz"):
        raise ValueError("Unknown extension: {}".format(ext))
    model_fn = "SMPLX_{}.{ext}".format(gender.upper(), ext=ext)
    smplx_path = hf_hub_download(
        repo_id=model_path, use_auth_token=os.environ["ICON"], filename=f"models/{model_fn}"
    )
    if ext == "pkl":
        # NOTE: unpickling can execute arbitrary code; the repository must be trusted.
        with open(smplx_path, "rb") as smplx_file:
            model_data = pickle.load(smplx_file, encoding="latin1")
        data_struct = Struct(**model_data)
    else:
        # ``**model_data`` reads every array, so the archive's file handle
        # can be closed once the Struct is built.
        with np.load(smplx_path, allow_pickle=True) as model_data:
            data_struct = Struct(**model_data)

    super(SMPLX, self).__init__(
        model_path=model_path,
        kid_template_path=kid_template_path,
        data_struct=data_struct,
        dtype=dtype,
        batch_size=batch_size,
        vertex_ids=VERTEX_IDS["smplx"],
        gender=gender,
        age=age,
        ext=ext,
        **kwargs,
    )

    # Static facial landmarks: face indices and barycentric coordinates.
    lmk_faces_idx = data_struct.lmk_faces_idx
    self.register_buffer("lmk_faces_idx", torch.tensor(lmk_faces_idx, dtype=torch.long))
    lmk_bary_coords = data_struct.lmk_bary_coords
    self.register_buffer("lmk_bary_coords", torch.tensor(lmk_bary_coords, dtype=dtype))
    self.use_face_contour = use_face_contour
    if self.use_face_contour:
        dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx
        dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, dtype=torch.long)
        self.register_buffer("dynamic_lmk_faces_idx", dynamic_lmk_faces_idx)
        dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords
        dynamic_lmk_bary_coords = torch.tensor(dynamic_lmk_bary_coords, dtype=dtype)
        self.register_buffer("dynamic_lmk_bary_coords", dynamic_lmk_bary_coords)
        neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents)
        self.register_buffer("neck_kin_chain", torch.tensor(neck_kin_chain, dtype=torch.long))

    if create_jaw_pose:
        if jaw_pose is None:
            default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype)
        else:
            default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype)
        jaw_pose_param = nn.Parameter(default_jaw_pose, requires_grad=True)
        self.register_parameter("jaw_pose", jaw_pose_param)
    if create_leye_pose:
        if leye_pose is None:
            default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype)
        else:
            default_leye_pose = torch.tensor(leye_pose, dtype=dtype)
        leye_pose_param = nn.Parameter(default_leye_pose, requires_grad=True)
        self.register_parameter("leye_pose", leye_pose_param)
    if create_reye_pose:
        if reye_pose is None:
            default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype)
        else:
            default_reye_pose = torch.tensor(reye_pose, dtype=dtype)
        reye_pose_param = nn.Parameter(default_reye_pose, requires_grad=True)
        self.register_parameter("reye_pose", reye_pose_param)

    shapedirs = data_struct.shapedirs
    if len(shapedirs.shape) < 3:
        shapedirs = shapedirs[:, :, None]
    if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM:
        # Smaller model file: expression directions start at column 10 and
        # at most 10 of them are used.
        expr_start_idx = 10
        num_expression_coeffs = min(num_expression_coeffs, 10)
    else:
        expr_start_idx = self.SHAPE_SPACE_DIM
        num_expression_coeffs = min(num_expression_coeffs, self.EXPRESSION_SPACE_DIM)
    # Clamp first, then slice: ``expr_dirs`` must have exactly as many
    # directions as the expression vector has coefficients.
    expr_end_idx = expr_start_idx + num_expression_coeffs
    self._num_expression_coeffs = num_expression_coeffs
    expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx]
    self.register_buffer("expr_dirs", to_tensor(to_np(expr_dirs), dtype=dtype))
    if create_expression:
        if expression is None:
            default_expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                             dtype=dtype)
        else:
            default_expression = torch.tensor(expression, dtype=dtype)
        expression_param = nn.Parameter(default_expression, requires_grad=True)
        self.register_parameter("expression", expression_param)
def name(self) -> str:
    """Return the human-readable model name."""
    model_name = "SMPL-X"
    return model_name
@property
def num_expression_coeffs(self):
    """Number of expression coefficients, as clamped in ``__init__``."""
    coeffs = self._num_expression_coeffs
    return coeffs
def create_mean_pose(self, data_struct, flat_hand_mean=False):
    """Build the flattened SMPL-X mean pose vector.

    All components are zero except the two hand blocks, which are copied from
    ``self.left_hand_mean`` and ``self.right_hand_mean``. The result is a
    ``torch.Tensor`` (as in ``MANO.create_mean_pose``), so callers may use
    tensor methods such as ``.clone()`` / ``.to()`` on it.

    Parameters
    ----------
    data_struct: Struct
        Not used by this implementation; kept for interface compatibility.
    flat_hand_mean: bool, optional
        Not used by this implementation; the hand means are read from the
        module attributes as they are. (default = False)

    Returns
    -------
    torch.Tensor
        1-D tensor of dtype ``self.dtype`` laid out as: global orient (3),
        body pose (NUM_BODY_JOINTS * 3), jaw (3), left eye (3), right eye (3),
        left-hand mean, right-hand mean.
    """
    # Accept tensors or array-likes for the hand means and cast them to the
    # model dtype. Keep every piece on one device so torch.cat is valid.
    left_hand_mean = torch.as_tensor(self.left_hand_mean, dtype=self.dtype)
    device = left_hand_mean.device
    right_hand_mean = torch.as_tensor(
        self.right_hand_mean, dtype=self.dtype, device=device
    )

    def zeros(size):
        return torch.zeros([size], dtype=self.dtype, device=device)

    pose_mean = torch.cat(
        [
            zeros(3),  # global orient
            zeros(self.NUM_BODY_JOINTS * 3),  # body pose
            zeros(3),  # jaw
            zeros(3),  # left eye
            zeros(3),  # right eye
            left_hand_mean,
            right_hand_mean,
        ],
        dim=0,
    )
    return pose_mean
def extra_repr(self):
    """Extend the parent module's repr with the expression-space size."""
    parent_repr = super(SMPLX, self).extra_repr()
    expression_line = f"Number of Expression Coefficients: {self.num_expression_coeffs}"
    return f"{parent_repr}\n{expression_line}"
def forward(
    self,
    betas: Optional[Tensor] = None,
    global_orient: Optional[Tensor] = None,
    body_pose: Optional[Tensor] = None,
    left_hand_pose: Optional[Tensor] = None,
    right_hand_pose: Optional[Tensor] = None,
    transl: Optional[Tensor] = None,
    expression: Optional[Tensor] = None,
    jaw_pose: Optional[Tensor] = None,
    leye_pose: Optional[Tensor] = None,
    reye_pose: Optional[Tensor] = None,
    return_verts: bool = True,
    return_full_pose: bool = False,
    pose2rot: bool = True,
    return_joint_transformation: bool = False,
    return_vertex_transformation: bool = False,
    pose_type: str = 'posed',
    **kwargs,
) -> SMPLXOutput:
    """
    Forward pass for the SMPLX model

    Parameters
    ----------
    global_orient: torch.tensor, optional, shape Bx3
        If given, ignore the member variable and use it as the global
        rotation of the body. Useful if someone wishes to predicts this
        with an external model. (default=None)
    betas: torch.tensor, optional, shape BxN_b
        If given, ignore the member variable `betas` and use it
        instead. For example, it can used if shape parameters
        `betas` are predicted from some external model. A batch of one
        is broadcast to the batch size. (default=None)
    expression: torch.tensor, optional, shape BxN_e
        If given, ignore the member variable `expression` and use it
        instead. For example, it can used if expression parameters
        `expression` are predicted from some external model. A batch of
        one is broadcast to the batch size. (default=None)
    body_pose: torch.tensor, optional, shape Bx(J*3)
        If given, ignore the member variable `body_pose` and use it
        instead. For example, it can used if someone predicts the
        pose of the body joints are predicted from some external model.
        It should be a tensor that contains joint rotations in
        axis-angle format. (default=None)
    left_hand_pose: torch.tensor, optional, shape BxP
        If given, ignore the member variable `left_hand_pose` and
        use this instead. It should either contain PCA coefficients or
        joint rotations in axis-angle format.
    right_hand_pose: torch.tensor, optional, shape BxP
        If given, ignore the member variable `right_hand_pose` and
        use this instead. It should either contain PCA coefficients or
        joint rotations in axis-angle format.
    jaw_pose: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `jaw_pose` and
        use this instead. It should contain joint rotations in
        axis-angle format.
    leye_pose, reye_pose: torch.tensor, optional, shape Bx3
        If given, ignore the member variables `leye_pose` / `reye_pose`
        and use these instead.
    transl: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `transl` and use it
        instead. For example, it can used if the translation
        `transl` is predicted from some external model.
        (default=None)
    return_verts: bool, optional
        Return the vertices. (default=True)
    return_full_pose: bool, optional
        Returns the full axis-angle pose vector (default=False)
    pose2rot: bool, optional
        Forwarded to `lbs` and, when face contours are used, to
        `find_dynamic_lmk_idx_and_bcoords`, so both interpret `full_pose`
        the same way. (default=True)
    return_joint_transformation: bool, optional
        Request transformations from `lbs` and return the joint
        transformations. (default=False)
    return_vertex_transformation: bool, optional
        Request transformations from `lbs` and return the vertex
        transformations. (default=False)
    pose_type: str, optional
        'posed' uses the pose as given; 't-pose' zeroes the whole pose;
        'a-pose' / 'da-pose' zero every pose component except fixed
        z-axis rotations of two body joints. Any other value behaves like
        'posed'. (default='posed')

    Returns
    -------
    output: SMPLXOutput
        A named tuple of type `SMPLXOutput`
    """
    # Fall back to the module's own parameters for anything not passed in.
    global_orient = global_orient if global_orient is not None else self.global_orient
    body_pose = body_pose if body_pose is not None else self.body_pose
    betas = betas if betas is not None else self.betas
    left_hand_pose = left_hand_pose if left_hand_pose is not None else self.left_hand_pose
    right_hand_pose = right_hand_pose if right_hand_pose is not None else self.right_hand_pose
    jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
    leye_pose = leye_pose if leye_pose is not None else self.leye_pose
    reye_pose = reye_pose if reye_pose is not None else self.reye_pose
    expression = expression if expression is not None else self.expression

    # Translate if a translation was passed in or the module owns one.
    apply_trans = transl is not None or hasattr(self, "transl")
    if transl is None and hasattr(self, "transl"):
        transl = self.transl

    if self.use_pca:
        # Map PCA coefficients onto the per-joint hand pose basis.
        left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components])
        right_hand_pose = torch.einsum(
            "bi,ij->bj", [right_hand_pose, self.right_hand_components]
        )

    full_pose = torch.cat(
        [
            global_orient,
            body_pose,
            jaw_pose,
            leye_pose,
            reye_pose,
            left_hand_pose,
            right_hand_pose,
        ],
        dim=1,
    )

    if pose_type == "t-pose":
        full_pose *= 0.0
    elif pose_type in ("a-pose", "da-pose"):
        # (body_pose joint index, z-axis rotation in degrees) pairs.
        if pose_type == "a-pose":
            joint_z_degrees = ((15, -45.0), (16, 45.0))
        else:
            joint_z_degrees = ((0, 30.0), (1, -30.0))
        body_pose = torch.zeros_like(body_pose).view(body_pose.shape[0], -1, 3)
        for joint_idx, degrees in joint_z_degrees:
            body_pose[:, joint_idx] = torch.tensor([0., 0., degrees * np.pi / 180.])
        body_pose = body_pose.view(body_pose.shape[0], -1)
        full_pose = torch.cat(
            [
                global_orient * 0.,
                body_pose,
                jaw_pose * 0.,
                leye_pose * 0.,
                reye_pose * 0.,
                left_hand_pose * 0.,
                right_hand_pose * 0.,
            ],
            dim=1,
        )

    batch_size = max(betas.shape[0], global_orient.shape[0], body_pose.shape[0])
    scale = int(batch_size / betas.shape[0])
    if scale > 1:
        betas = betas.expand(scale, -1)
    # Broadcast a single shared expression the same way, otherwise the
    # concatenation with the (expanded) betas below fails.
    if expression.shape[0] == 1 and batch_size > 1:
        expression = expression.expand(batch_size, -1)
    shape_components = torch.cat([betas, expression], dim=-1)
    shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

    if return_joint_transformation or return_vertex_transformation:
        vertices, joints, joint_transformation, vertex_transformation = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
            return_transformation=True,
        )
    else:
        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )

    # Use the per-call batch size (not the construction-time one) so the
    # face indices and barycentric coordinates always agree.
    lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()
    lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)
    if self.use_face_contour:
        # Interpret full_pose exactly as lbs did above.
        lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.neck_kin_chain,
            pose2rot=pose2rot,
        )
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([
            lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
        ], 1)
    landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

    joints = self.vertex_joint_selector(vertices, joints)
    joints = torch.cat([joints, landmarks], dim=1)
    if self.joint_mapper is not None:
        joints = self.joint_mapper(joints=joints, vertices=vertices)

    if apply_trans:
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

    output = SMPLXOutput(
        vertices=vertices if return_verts else None,
        joints=joints,
        betas=betas,
        expression=expression,
        global_orient=global_orient,
        body_pose=body_pose,
        left_hand_pose=left_hand_pose,
        right_hand_pose=right_hand_pose,
        jaw_pose=jaw_pose,
        full_pose=full_pose if return_full_pose else None,
        joint_transformation=joint_transformation if return_joint_transformation else None,
        vertex_transformation=vertex_transformation if return_vertex_transformation else None,
    )
    return output
class SMPLXLayer(SMPLX):
    """SMPL-X used as a stateless layer.

    No pose, shape, expression or translation parameters are registered on
    the module. Every quantity is passed to :meth:`forward` or filled there
    with an identity rotation / zero default. Poses are rotation matrices.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Disable creation of every per-instance parameter; forward() gets
        # all of them from its arguments instead.
        super(SMPLXLayer, self).__init__(
            create_global_orient=False,
            create_body_pose=False,
            create_left_hand_pose=False,
            create_right_hand_pose=False,
            create_jaw_pose=False,
            create_leye_pose=False,
            create_reye_pose=False,
            create_betas=False,
            create_expression=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        expression: Optional[Tensor] = None,
        jaw_pose: Optional[Tensor] = None,
        leye_pose: Optional[Tensor] = None,
        reye_pose: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        **kwargs,
    ) -> SMPLXOutput:
        """
        Forward pass for the SMPLX model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the body in rotation matrix format.
            (default=None: identity)
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters. (default=None: zeros)
        expression: torch.tensor, optional, shape BxN_e
            Expression coefficients. (default=None: zeros)
        body_pose: torch.tensor, optional, shape BxJx3x3
            Body joint rotations in rotation matrix format.
            (default=None: identity)
        left_hand_pose: torch.tensor, optional, shape Bx15x3x3
            Left hand joint rotations in rotation matrix format.
            (default=None: identity)
        right_hand_pose: torch.tensor, optional, shape Bx15x3x3
            Right hand joint rotations in rotation matrix format.
            (default=None: identity)
        jaw_pose: torch.tensor, optional, shape Bx3x3
            Jaw rotation in rotation matrix format. (default=None: identity)
        leye_pose, reye_pose: torch.tensor, optional, shape Bx3x3
            Eye rotations in rotation matrix format. (default=None: identity)
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body. (default=None: zeros)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose vector (default=False)

        The batch size is the largest leading dimension among the provided
        inputs (1 if none are given).

        Returns
        -------
        output: SMPLXOutput
            A data class that contains the posed vertices and joints
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every input that was actually provided,
        # including the eye poses, so that the generated defaults match.
        model_vars = [
            betas,
            global_orient,
            body_pose,
            transl,
            expression,
            left_hand_pose,
            right_hand_pose,
            jaw_pose,
            leye_pose,
            reye_pose,
        ]
        batch_size = max([1] + [len(var) for var in model_vars if var is not None])

        def identity_rotations(num_joints):
            # batch_size x num_joints stack of 3x3 identity matrices.
            return (
                torch.eye(3, device=device, dtype=dtype)
                .view(1, 1, 3, 3)
                .expand(batch_size, num_joints, -1, -1)
                .contiguous()
            )

        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)
        if left_hand_pose is None:
            left_hand_pose = identity_rotations(15)
        if right_hand_pose is None:
            right_hand_pose = identity_rotations(15)
        if jaw_pose is None:
            jaw_pose = identity_rotations(1)
        if leye_pose is None:
            leye_pose = identity_rotations(1)
        if reye_pose is None:
            reye_pose = identity_rotations(1)
        if expression is None:
            expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                     dtype=dtype,
                                     device=device)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
                jaw_pose.reshape(-1, 1, 3, 3),
                leye_pose.reshape(-1, 1, 3, 3),
                reye_pose.reshape(-1, 1, 3, 3),
                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            ],
            dim=1,
        )
        shape_components = torch.cat([betas, expression], dim=-1)
        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)
        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)
        if self.use_face_contour:
            lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=False,
            )
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([
                lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
            ], 1)
        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        joints = self.vertex_joint_selector(vertices, joints)
        joints = torch.cat([joints, landmarks], dim=1)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        # transl is always set at this point (defaulted to zeros above).
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

        output = SMPLXOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            expression=expression,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            jaw_pose=jaw_pose,
            transl=transl,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
class MANO(SMPL):
    """MANO hand model.

    Loads the hand model data (unless ``data_struct`` is given), passes it to
    the SMPL constructor, then registers the hand-specific pieces: an optional
    PCA basis, the mean hand pose, and the hand pose parameter.
    """
    def __init__(
        self,
        model_path: str,
        is_rhand: bool = True,
        data_struct: Optional[Struct] = None,
        create_hand_pose: bool = True,
        hand_pose: Optional[Tensor] = None,
        use_pca: bool = True,
        num_pca_comps: int = 6,
        flat_hand_mean: bool = False,
        batch_size: int = 1,
        dtype=torch.float32,
        vertex_ids=None,
        use_compressed: bool = True,
        ext: str = "pkl",
        **kwargs,
    ) -> None:
        """MANO model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored
        is_rhand: bool, optional
            Selects ``MANO_RIGHT`` (True) or ``MANO_LEFT`` (False) when
            `model_path` is a folder. When `data_struct` is None and
            `model_path` is a file, handedness is instead taken from whether
            the file name contains "RIGHT". (default = True)
        data_struct: Struct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_hand_pose: bool, optional
            Flag for creating a member variable for the hand pose.
            (default = True)
        hand_pose: torch.tensor, optional, BxP
            The default value for the hand pose member variable.
            (default = None)
        use_pca: bool, optional
            Whether the hand pose is given as PCA coefficients.
            ``self.use_pca`` is set to False when `num_pca_comps` is 45.
            (default = True)
        num_pca_comps: int, optional
            The number of PCA components to use for the hand.
            (default = 6)
        flat_hand_mean: bool, optional
            If True, the mean hand pose is all zeros (flat hand); otherwise
            the `hands_mean` stored in the model data is used.
            (default = False)
        batch_size: int, optional
            The batch size used for creating the member variables
        dtype: torch.dtype, optional
            The data type for the created variables
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected. Defaults to ``VERTEX_IDS["smplh"]``.
        use_compressed: bool, optional
            Forwarded to the SMPL constructor. (default = True)
        ext: str, optional
            Model file extension, "pkl" or "npz". (default = "pkl")
        """
        # Stored before the base constructor runs; assigned again below.
        self.num_pca_comps = num_pca_comps
        self.is_rhand = is_rhand
        if data_struct is None:
            # A folder holds MANO_RIGHT/MANO_LEFT files; otherwise model_path
            # is the model file itself.
            if osp.isdir(model_path):
                model_fn = "MANO_{}.{ext}".format("RIGHT" if is_rhand else "LEFT", ext=ext)
                mano_path = os.path.join(model_path, model_fn)
            else:
                mano_path = model_path
                # Handedness is inferred from the file name in this case.
                self.is_rhand = (True if "RIGHT" in os.path.basename(model_path) else False)
            assert osp.exists(mano_path), "Path {} does not exist!".format(mano_path)
            if ext == "pkl":
                # NOTE(review): pickle.load can execute code embedded in the
                # file; only load model files from trusted sources.
                with open(mano_path, "rb") as mano_file:
                    model_data = pickle.load(mano_file, encoding="latin1")
            elif ext == "npz":
                model_data = np.load(mano_path, allow_pickle=True)
            else:
                raise ValueError("Unknown extension: {}".format(ext))
            data_struct = Struct(**model_data)
        if vertex_ids is None:
            vertex_ids = VERTEX_IDS["smplh"]
        super(MANO, self).__init__(
            model_path=model_path,
            data_struct=data_struct,
            batch_size=batch_size,
            vertex_ids=vertex_ids,
            use_compressed=use_compressed,
            dtype=dtype,
            ext=ext,
            **kwargs,
        )
        # Replace the extra joints selected from vertices with the MANO set.
        self.vertex_joint_selector.extra_joints_idxs = to_tensor(
            list(VERTEX_IDS["mano"].values()), dtype=torch.long
        )
        self.use_pca = use_pca
        self.num_pca_comps = num_pca_comps
        # 45 = 15 joints x 3 axis-angle values, so the PCA basis would not
        # reduce anything; treat the pose as a full axis-angle vector.
        if self.num_pca_comps == 45:
            self.use_pca = False
        self.flat_hand_mean = flat_hand_mean
        hand_components = data_struct.hands_components[:num_pca_comps]
        self.np_hand_components = hand_components
        if self.use_pca:
            self.register_buffer("hand_components", torch.tensor(hand_components, dtype=dtype))
        if self.flat_hand_mean:
            hand_mean = np.zeros_like(data_struct.hands_mean)
        else:
            hand_mean = data_struct.hands_mean
        # Must be registered before create_mean_pose(), which reads it.
        self.register_buffer("hand_mean", to_tensor(hand_mean, dtype=self.dtype))
        # NOTE(review): NUM_HAND_JOINTS is not defined in this class body; it
        # must come from a base class when use_pca is False — confirm.
        hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS
        if create_hand_pose:
            if hand_pose is None:
                default_hand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype)
            else:
                default_hand_pose = torch.tensor(hand_pose, dtype=dtype)
            hand_pose_param = nn.Parameter(default_hand_pose, requires_grad=True)
            self.register_parameter("hand_pose", hand_pose_param)
        # Added to the articulated pose on every forward pass.
        pose_mean = self.create_mean_pose(data_struct, flat_hand_mean=flat_hand_mean)
        pose_mean_tensor = pose_mean.clone().to(dtype)
        self.register_buffer("pose_mean", pose_mean_tensor)

    def name(self) -> str:
        """Return the display name of this model."""
        return "MANO"

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        """Return the mean pose: zero global orientation followed by the
        registered ``hand_mean`` buffer, as a 1-D tensor.

        `data_struct` and `flat_hand_mean` are not used here; the flat/data
        mean choice is already baked into ``self.hand_mean``.
        """
        global_orient_mean = torch.zeros([3], dtype=self.dtype)
        pose_mean = torch.cat([global_orient_mean, self.hand_mean], dim=0)
        return pose_mean

    def extra_repr(self):
        """Extend the parent repr with PCA and mean-hand settings."""
        msg = [super(MANO, self).extra_repr()]
        if self.use_pca:
            msg.append(f"Number of PCA components: {self.num_pca_comps}")
        msg.append(f"Flat hand mean: {self.flat_hand_mean}")
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        **kwargs,
    ) -> MANOOutput:
        """Forward pass for the MANO model

        Arguments left as None fall back to the module's own parameters.
        `hand_pose` is projected through the PCA basis when ``self.use_pca``
        is set, and ``self.pose_mean`` is added to the concatenated pose
        before linear blend skinning (axis-angle input, ``pose2rot=True``).
        Note that `joints`, like `vertices`, is only returned when
        `return_verts` is True.
        """
        global_orient = (global_orient if global_orient is not None else self.global_orient)
        betas = betas if betas is not None else self.betas
        hand_pose = hand_pose if hand_pose is not None else self.hand_pose
        # Translate if a translation was passed in or the module owns one.
        apply_trans = transl is not None or hasattr(self, "transl")
        if transl is None:
            if hasattr(self, "transl"):
                transl = self.transl
        if self.use_pca:
            # Map PCA coefficients onto the per-joint hand pose basis.
            hand_pose = torch.einsum("bi,ij->bj", [hand_pose, self.hand_components])
        full_pose = torch.cat([global_orient, hand_pose], dim=1)
        full_pose += self.pose_mean
        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=True,
        )
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)
        if apply_trans:
            joints = joints + transl.unsqueeze(dim=1)
            vertices = vertices + transl.unsqueeze(dim=1)
        output = MANOOutput(
            vertices=vertices if return_verts else None,
            joints=joints if return_verts else None,
            betas=betas,
            global_orient=global_orient,
            hand_pose=hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
class MANOLayer(MANO):
    """MANO used as a stateless layer.

    No pose, shape or translation parameters are registered on the module.
    :meth:`forward` takes rotation-matrix poses and fills any missing input
    with an identity rotation / zero default.
    """

    def __init__(self, *args, **kwargs) -> None:
        """MANO as a layer model constructor"""
        # Disable creation of every per-instance parameter; forward() gets
        # all of them from its arguments instead.
        super(MANOLayer, self).__init__(
            create_global_orient=False,
            create_hand_pose=False,
            create_betas=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    def name(self) -> str:
        """Return the display name of this model."""
        return "MANO"

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        **kwargs,
    ) -> MANOOutput:
        """Forward pass for the MANO model

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters. (default=None: zeros)
        global_orient: torch.tensor, optional, shape Bx1x3x3 or Bx3x3
            Global rotation in rotation matrix format.
            (default=None: identity)
        hand_pose: torch.tensor, optional, shape Bx15x3x3
            Hand joint rotations in rotation matrix format.
            (default=None: identity)
        transl: torch.tensor, optional, shape Bx3
            Translation vector. (default=None: zeros)
        return_verts: bool, optional
            Return the vertices and joints. (default=True)
        return_full_pose: bool, optional
            Return the concatenated rotation-matrix pose. (default=False)

        The batch size is the largest leading dimension among the provided
        inputs (1 if none are given).
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every provided input, not just
        # global_orient, so defaults match e.g. a batched hand_pose.
        provided = [var for var in (betas, global_orient, hand_pose, transl) if var is not None]
        batch_size = max([1] + [len(var) for var in provided])

        if global_orient is None:
            global_orient = (
                torch.eye(3, device=device,
                          dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
            )
        if hand_pose is None:
            hand_pose = (
                torch.eye(3, device=device,
                          dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()
            )
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Normalize both rotation inputs to (B, J, 3, 3) so Bx3x3 global
        # orientations concatenate correctly, as in SMPLXLayer.
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                hand_pose.reshape(hand_pose.shape[0], -1, 3, 3),
            ],
            dim=1,
        )
        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # transl is always set at this point (defaulted to zeros above).
        joints = joints + transl.unsqueeze(dim=1)
        vertices = vertices + transl.unsqueeze(dim=1)

        output = MANOOutput(
            vertices=vertices if return_verts else None,
            joints=joints if return_verts else None,
            betas=betas,
            global_orient=global_orient,
            hand_pose=hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
class FLAME(SMPL): |
NECK_IDX = 0 |
def __init__(
    self,
    model_path: str,
    data_struct=None,
    num_expression_coeffs=10,
    create_expression: bool = True,
    expression: Optional[Tensor] = None,
    create_neck_pose: bool = True,
    neck_pose: Optional[Tensor] = None,
    create_jaw_pose: bool = True,
    jaw_pose: Optional[Tensor] = None,
    create_leye_pose: bool = True,
    leye_pose: Optional[Tensor] = None,
    create_reye_pose=True,
    reye_pose: Optional[Tensor] = None,
    use_face_contour=False,
    batch_size: int = 1,
    gender: str = "neutral",
    dtype: torch.dtype = torch.float32,
    ext="pkl",
    **kwargs,
) -> None:
    """FLAME model constructor

    Parameters
    ----------
    model_path: str
        The path to the folder where the model parameters and the
        landmark embedding files are stored
    data_struct: Struct, optional
        If given, the model parameters are read from this object instead
        of from ``FLAME_<GENDER>.<ext>`` in `model_path`. The landmark
        embeddings are still loaded from `model_path`. (default = None)
    num_expression_coeffs: int, optional
        Number of expression components to use. It is clamped to what the
        model data provides. (default = 10).
    create_expression: bool, optional
        Flag for creating a member variable for the expression space
        (default = True).
    expression: torch.tensor, optional, Bx(num_expression_coeffs)
        The default value for the expression member variable.
        (default = None)
    create_neck_pose: bool, optional
        Flag for creating a member variable for the neck pose.
        (default = True)
    neck_pose: torch.tensor, optional, Bx3
        The default value for the neck pose variable.
        (default = None)
    create_jaw_pose: bool, optional
        Flag for creating a member variable for the jaw pose.
        (default = True)
    jaw_pose: torch.tensor, optional, Bx3
        The default value for the jaw pose variable.
        (default = None)
    create_leye_pose: bool, optional
        Flag for creating a member variable for the left eye pose.
        (default = True)
    leye_pose: torch.tensor, optional, Bx3
        The default value for the left eye pose variable.
        (default = None)
    create_reye_pose: bool, optional
        Flag for creating a member variable for the right eye pose.
        (default = True)
    reye_pose: torch.tensor, optional, Bx3
        The default value for the right eye pose variable.
        (default = None)
    use_face_contour: bool, optional
        Whether to compute the keypoints that form the facial contour
    batch_size: int, optional
        The batch size used for creating the member variables
    gender: str, optional
        Which gender to load
    dtype: torch.dtype
        The data type for the created variables
    ext: str, optional
        Model file extension, "pkl" or "npz". (default = "pkl")
    """
    if data_struct is None:
        model_fn = f"FLAME_{gender.upper()}.{ext}"
        flame_path = os.path.join(model_path, model_fn)
        assert osp.exists(flame_path), "Path {} does not exist!".format(flame_path)
        if ext == "npz":
            file_data = np.load(flame_path, allow_pickle=True)
        elif ext == "pkl":
            # NOTE(review): pickle.load can execute code embedded in the
            # file; only load model files from trusted sources.
            with open(flame_path, "rb") as flame_file:
                file_data = pickle.load(flame_file, encoding="latin1")
        else:
            raise ValueError("Unknown extension: {}".format(ext))
        data_struct = Struct(**file_data)
    super(FLAME, self).__init__(
        model_path=model_path,
        data_struct=data_struct,
        dtype=dtype,
        batch_size=batch_size,
        gender=gender,
        ext=ext,
        **kwargs,
    )
    self.use_face_contour = use_face_contour
    self.vertex_joint_selector.extra_joints_idxs = to_tensor([], dtype=torch.long)

    # Optional per-instance face pose parameters (Bx3 zeros by default),
    # registered in this fixed order.
    face_pose_specs = (
        ("neck_pose", create_neck_pose, neck_pose),
        ("jaw_pose", create_jaw_pose, jaw_pose),
        ("leye_pose", create_leye_pose, leye_pose),
        ("reye_pose", create_reye_pose, reye_pose),
    )
    for param_name, should_create, initial_value in face_pose_specs:
        if not should_create:
            continue
        if initial_value is None:
            default_value = torch.zeros([batch_size, 3], dtype=dtype)
        else:
            default_value = torch.tensor(initial_value, dtype=dtype)
        self.register_parameter(param_name, nn.Parameter(default_value, requires_grad=True))

    shapedirs = data_struct.shapedirs
    if len(shapedirs.shape) < 3:
        shapedirs = shapedirs[:, :, None]
    if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM:
        # Reduced model: the expression directions start at column 10 and
        # at most 10 of them are available.
        expr_start_idx = 10
        num_expression_coeffs = min(num_expression_coeffs, 10)
    else:
        expr_start_idx = self.SHAPE_SPACE_DIM
        num_expression_coeffs = min(num_expression_coeffs, self.EXPRESSION_SPACE_DIM)
    # Clamp first, then slice, so expr_dirs has exactly as many columns as
    # the `expression` parameter created below.
    expr_end_idx = expr_start_idx + num_expression_coeffs
    self._num_expression_coeffs = num_expression_coeffs
    expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx]
    self.register_buffer("expr_dirs", to_tensor(to_np(expr_dirs), dtype=dtype))
    if create_expression:
        if expression is None:
            default_expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                             dtype=dtype)
        else:
            default_expression = torch.tensor(expression, dtype=dtype)
        expression_param = nn.Parameter(default_expression, requires_grad=True)
        self.register_parameter("expression", expression_param)

    # Static landmarks: face indices + barycentric coordinates.
    landmark_bcoord_filename = osp.join(model_path, "flame_static_embedding.pkl")
    with open(landmark_bcoord_filename, "rb") as fp:
        landmarks_data = pickle.load(fp, encoding="latin1")
    lmk_faces_idx = landmarks_data["lmk_face_idx"].astype(np.int64)
    self.register_buffer("lmk_faces_idx", torch.tensor(lmk_faces_idx, dtype=torch.long))
    lmk_bary_coords = landmarks_data["lmk_b_coords"]
    self.register_buffer("lmk_bary_coords", torch.tensor(lmk_bary_coords, dtype=dtype))
    if self.use_face_contour:
        # Dynamic contour landmarks, chosen at forward time from the neck
        # kinematic chain.
        face_contour_path = os.path.join(model_path, "flame_dynamic_embedding.npy")
        contour_embeddings = np.load(face_contour_path, allow_pickle=True,
                                     encoding="latin1")[()]
        dynamic_lmk_faces_idx = np.array(contour_embeddings["lmk_face_idx"], dtype=np.int64)
        dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, dtype=torch.long)
        self.register_buffer("dynamic_lmk_faces_idx", dynamic_lmk_faces_idx)
        dynamic_lmk_b_coords = torch.tensor(contour_embeddings["lmk_b_coords"], dtype=dtype)
        self.register_buffer("dynamic_lmk_bary_coords", dynamic_lmk_b_coords)
        neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents)
        self.register_buffer("neck_kin_chain", torch.tensor(neck_kin_chain, dtype=torch.long))
@property
def num_expression_coeffs(self):
    """Expression-space size chosen (and clamped) in the constructor."""
    coeff_count = self._num_expression_coeffs
    return coeff_count
def name(self) -> str:
    """Return the display name of this face model."""
    label = "FLAME"
    return label
def extra_repr(self):
    """Extend the parent repr with expression-space and contour settings."""
    base_repr = super(FLAME, self).extra_repr()
    details = (
        f"Number of Expression Coefficients: {self.num_expression_coeffs}",
        f"Use face contour: {self.use_face_contour}",
    )
    return "\n".join((base_repr, *details))
def forward(
    self,
    betas: Optional[Tensor] = None,
    global_orient: Optional[Tensor] = None,
    neck_pose: Optional[Tensor] = None,
    transl: Optional[Tensor] = None,
    expression: Optional[Tensor] = None,
    jaw_pose: Optional[Tensor] = None,
    leye_pose: Optional[Tensor] = None,
    reye_pose: Optional[Tensor] = None,
    return_verts: bool = True,
    return_full_pose: bool = False,
    pose2rot: bool = True,
    **kwargs,
) -> FLAMEOutput:
    """
    Forward pass for the FLAME model

    Every shape/pose argument left as ``None`` falls back to the member
    variable of the same name stored on the module.

    Parameters
    ----------
    betas: torch.tensor, optional, shape Bx10
        If given, ignore the member variable `betas` and use it
        instead. For example, it can be used if shape parameters
        `betas` are predicted from some external model. A batch of
        size 1 is expanded to the batch size of the other inputs.
        (default=None)
    global_orient: torch.tensor, optional, shape Bx3
        If given, ignore the member variable and use it as the global
        rotation. Useful if someone wishes to predict this with an
        external model. (default=None)
    neck_pose: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `neck_pose` and use it
        instead. (default=None)
    transl: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `transl` and use it
        instead. Translation is only applied when this argument is given
        or the module has a `transl` attribute. (default=None)
    expression: torch.tensor, optional, shape Bx10
        If given, ignore the member variable `expression` and use it
        instead. A batch of size 1 is expanded to the batch size of the
        other inputs. (default=None)
    jaw_pose: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `jaw_pose` and use it
        instead. (default=None)
    leye_pose: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `leye_pose` and use it
        instead. (default=None)
    reye_pose: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `reye_pose` and use it
        instead. (default=None)
    return_verts: bool, optional
        Return the vertices. (default=True)
    return_full_pose: bool, optional
        Return the concatenated pose tensor (default=False)
    pose2rot: bool, optional
        If True the pose inputs are axis-angle vectors that are converted
        to rotation matrices; if False they are treated as rotation
        matrices. The flag is forwarded both to `lbs` and to the dynamic
        face-contour landmark lookup. (default=True)
    **kwargs:
        Ignored; accepted for interface compatibility.

    Returns
    -------
    output: FLAMEOutput
        Named tuple with the (optionally translated) vertices, the joints
        (regressed joints + selected vertices + facial landmarks), the
        shape/expression coefficients and the poses that were used.
    """
    global_orient = global_orient if global_orient is not None else self.global_orient
    jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
    neck_pose = neck_pose if neck_pose is not None else self.neck_pose
    leye_pose = leye_pose if leye_pose is not None else self.leye_pose
    reye_pose = reye_pose if reye_pose is not None else self.reye_pose
    betas = betas if betas is not None else self.betas
    expression = expression if expression is not None else self.expression

    apply_trans = transl is not None or hasattr(self, "transl")
    if transl is None and hasattr(self, "transl"):
        transl = self.transl

    full_pose = torch.cat([global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1)

    batch_size = max(betas.shape[0], global_orient.shape[0], jaw_pose.shape[0])
    # Shape/expression coefficients stored with a batch of one are broadcast
    # to the batch size of the pose inputs before being concatenated.
    scale = int(batch_size / betas.shape[0])
    if scale > 1:
        betas = betas.expand(scale, -1)
    if expression.shape[0] == 1 and batch_size > 1:
        expression = expression.expand(batch_size, -1)

    shape_components = torch.cat([betas, expression], dim=-1)
    shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)
    vertices, joints = lbs(
        shape_components,
        full_pose,
        self.v_template,
        shapedirs,
        self.posedirs,
        self.J_regressor,
        self.parents,
        self.lbs_weights,
        pose2rot=pose2rot,
    )

    lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()
    # Use the batch size of this call, not the one the module was built with,
    # so the barycentric coords line up with `lmk_faces_idx` and `vertices`.
    lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)
    if self.use_face_contour:
        # The neck rotation must be decoded with the same pose convention
        # that was used for skinning above.
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.neck_kin_chain,
            pose2rot=pose2rot,
        )
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)

    landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

    joints = self.vertex_joint_selector(vertices, joints)
    joints = torch.cat([joints, landmarks], dim=1)
    if self.joint_mapper is not None:
        joints = self.joint_mapper(joints=joints, vertices=vertices)

    if apply_trans:
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

    output = FLAMEOutput(
        vertices=vertices if return_verts else None,
        joints=joints,
        betas=betas,
        expression=expression,
        global_orient=global_orient,
        neck_pose=neck_pose,
        jaw_pose=jaw_pose,
        full_pose=full_pose if return_full_pose else None,
    )
    return output
class FLAMELayer(FLAME):
    """FLAME variant whose forward pass takes every parameter as an input.

    The parent model is constructed with the ``create_*`` flags for betas,
    expression and all poses set to False. :meth:`forward` fills any
    omitted input with identity rotations or zeros and always treats poses
    as rotation matrices.
    """

    def __init__(self, *args, **kwargs) -> None:
        """ FLAME as a layer model constructor """
        super(FLAMELayer, self).__init__(
            create_betas=False,
            create_expression=False,
            create_global_orient=False,
            create_neck_pose=False,
            create_jaw_pose=False,
            create_leye_pose=False,
            create_reye_pose=False,
            *args,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        neck_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        expression: Optional[Tensor] = None,
        jaw_pose: Optional[Tensor] = None,
        leye_pose: Optional[Tensor] = None,
        reye_pose: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> FLAMEOutput:
        """
        Forward pass for the FLAME layer model

        All pose inputs are rotation matrices. Each one is concatenated with
        the defaults along dim 1, so it must have the same layout as the
        identity defaults built here: Bx1x3x3. The batch size B is the
        largest leading dimension among the supplied tensors (1 if none
        is supplied). Omitted inputs default to identity rotations, zero
        shape/expression coefficients and zero translation.

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
            Shape coefficients. (default=None -> zeros)
        global_orient: torch.tensor, optional, shape Bx1x3x3
            Global rotation, as a rotation matrix. (default=None -> identity)
        neck_pose: torch.tensor, optional, shape Bx1x3x3
            Neck rotation matrix. (default=None -> identity)
        transl: torch.tensor, optional, shape Bx3
            Translation added to the vertices and joints.
            (default=None -> zeros)
        expression: torch.tensor, optional, shape BxN_e
            Expression coefficients. (default=None -> zeros)
        jaw_pose: torch.tensor, optional, shape Bx1x3x3
            Jaw rotation matrix. (default=None -> identity)
        leye_pose: torch.tensor, optional, shape Bx1x3x3
            Left eye rotation matrix. (default=None -> identity)
        reye_pose: torch.tensor, optional, shape Bx1x3x3
            Right eye rotation matrix. (default=None -> identity)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the concatenated pose rotation matrices (default=False)
        pose2rot: bool, optional
            Ignored: this layer always treats poses as rotation matrices.
            Kept for interface compatibility with the parent class.
        **kwargs:
            Ignored; accepted for interface compatibility.

        Returns
        -------
        output: FLAMEOutput
            Named tuple with the translated vertices, the joints
            (regressed joints + selected vertices + facial landmarks), the
            shape/expression coefficients and the poses that were used.
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every tensor the caller supplied, so the
        # defaults created below match it whichever inputs were given.
        supplied = [
            global_orient, neck_pose, jaw_pose, leye_pose, reye_pose, betas, expression, transl
        ]
        batch_size = max((var.shape[0] for var in supplied if var is not None), default=1)

        def identity_rotations() -> Tensor:
            # One 3x3 identity per sample, shaped (batch_size, 1, 3, 3).
            eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
            return eye.expand(batch_size, -1, -1, -1).contiguous()

        if global_orient is None:
            global_orient = identity_rotations()
        if neck_pose is None:
            neck_pose = identity_rotations()
        if jaw_pose is None:
            jaw_pose = identity_rotations()
        if leye_pose is None:
            leye_pose = identity_rotations()
        if reye_pose is None:
            reye_pose = identity_rotations()
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if expression is None:
            expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                     dtype=dtype,
                                     device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat([global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1)
        shape_components = torch.cat([betas, expression], dim=-1)
        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)
        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()
        # Use the batch size of this call, not the one the module was built
        # with, so the barycentric coords line up with `lmk_faces_idx`.
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)
        if self.use_face_contour:
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=False,
            )
            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        joints = self.vertex_joint_selector(vertices, joints)
        joints = torch.cat([joints, landmarks], dim=1)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

        output = FLAMEOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            expression=expression,
            global_orient=global_orient,
            neck_pose=neck_pose,
            jaw_pose=jaw_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
def build_layer(model_path: str,
                model_type: str = "smpl",
                **kwargs) -> Union[SMPLLayer, SMPLHLayer, SMPLXLayer, MANOLayer, FLAMELayer]:
    """Create the ``*Layer`` variant of a body model from a path and a model type

    Parameters
    ----------
    model_path: str
        Either the path to the model you wish to load or a folder,
        where each subfolder contains the different types, i.e.:
        model_path:
        |
        |-- smpl
        |-- smplh
        |-- smplx
        |-- mano
        |-- flame
    model_type: str, optional
        When model_path is a folder, this names the subfolder to load
        and selects the model class. When model_path is a file, this
        argument is ignored and the type is taken from the file name
        prefix before the first underscore (e.g. ``SMPLX_NEUTRAL.npz``
        -> ``smplx``). Matching is case-insensitive.
    **kwargs: dict
        Keyword arguments forwarded to the layer constructor

    Returns
    -------
    body_model: nn.Module
        The PyTorch module that implements the corresponding body model

    Raises
    ------
    ValueError: In case the model type is not one of SMPL, SMPLH, SMPLX,
        MANO or FLAME
    """
    if osp.isdir(model_path):
        # The sub-folder name keeps the caller's casing.
        model_path = osp.join(model_path, model_type)
    else:
        model_type = osp.basename(model_path).split("_")[0].lower()

    kind = model_type.lower()
    if kind == "smpl":
        return SMPLLayer(model_path, **kwargs)
    if kind == "smplh":
        return SMPLHLayer(model_path, **kwargs)
    if kind == "smplx":
        return SMPLXLayer(model_path, **kwargs)
    # MANO / FLAME files carry extra suffixes in their names, hence substring checks.
    if "mano" in kind:
        return MANOLayer(model_path, **kwargs)
    if "flame" in kind:
        return FLAMELayer(model_path, **kwargs)
    raise ValueError(f"Unknown model type {model_type}, exiting!")
def create(model_path: str,
           model_type: str = "smpl",
           **kwargs) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]:
    """Create a body model from a path and a model type

    Parameters
    ----------
    model_path: str
        Either the path to the model you wish to load or a folder,
        where each subfolder contains the different types, i.e.:
        model_path:
        |
        |-- smpl
        |-- smplh
        |-- smplx
        |-- mano
        |-- flame
    model_type: str, optional
        When model_path is a folder, this names the subfolder to load
        and selects the model class. When model_path is a file, this
        argument is ignored and the type is taken from the file name
        prefix before the first underscore (e.g. ``SMPLX_NEUTRAL.npz``
        -> ``smplx``). Matching is case-insensitive.
    **kwargs: dict
        Keyword arguments forwarded to the model constructor

    Returns
    -------
    body_model: nn.Module
        The PyTorch module that implements the corresponding body model

    Raises
    ------
    ValueError: In case the model type is not one of SMPL, SMPLH, SMPLX,
        MANO or FLAME
    """
    if osp.isdir(model_path):
        # The sub-folder name keeps the caller's casing.
        model_path = osp.join(model_path, model_type)
    else:
        model_type = osp.basename(model_path).split("_")[0].lower()

    kind = model_type.lower()
    if kind == "smpl":
        return SMPL(model_path, **kwargs)
    if kind == "smplh":
        return SMPLH(model_path, **kwargs)
    if kind == "smplx":
        return SMPLX(model_path, **kwargs)
    # MANO / FLAME files carry extra suffixes in their names, hence substring checks.
    if "mano" in kind:
        return MANO(model_path, **kwargs)
    if "flame" in kind:
        return FLAME(model_path, **kwargs)
    raise ValueError(f"Unknown model type {model_type}, exiting!")