from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os, sys

import torch
import numpy as np
from lib.utils import transforms

from smplx import SMPL as _SMPL
from smplx.utils import SMPLOutput as ModelOutput
from smplx.lbs import vertices2joints

from configs import constants as _C

class SMPL(_SMPL):
    """ Extension of the official SMPL implementation to support more joints """

    def __init__(self, *args, **kwargs):
        # Temporarily silence stdout to suppress smplx's verbose loading messages.
        sys.stdout = open(os.devnull, 'w')
        super(SMPL, self).__init__(*args, **kwargs)
        sys.stdout = sys.__stdout__

        # Extra joint regressors: WHAM's joint set, the H3.6M set used for
        # evaluation, and a regressor for the foot keypoints.
        J_regressor_wham = np.load(_C.BMODEL.JOINTS_REGRESSOR_WHAM)
        J_regressor_eval = np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M)
        self.register_buffer('J_regressor_wham', torch.tensor(
            J_regressor_wham, dtype=torch.float32))
        self.register_buffer('J_regressor_eval', torch.tensor(
            J_regressor_eval, dtype=torch.float32))
        self.register_buffer('J_regressor_feet', torch.from_numpy(
            np.load(_C.BMODEL.JOINTS_REGRESSOR_FEET)
        ).float())

    def get_local_pose_from_reduced_global_pose(self, reduced_pose):
        # Start from identity rotations for all 24 SMPL joints, then scatter
        # the predicted rotations into the MAIN_JOINTS slots.
        full_pose = torch.eye(
            3, device=reduced_pose.device
        )[(None, ) * 2].repeat(reduced_pose.shape[0], 24, 1, 1)
        full_pose[:, _C.BMODEL.MAIN_JOINTS] = reduced_pose
        return full_pose
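
    # Hedged sketch (added): for reduced_pose of shape
    # (B, len(_C.BMODEL.MAIN_JOINTS), 3, 3), this returns (B, 24, 3, 3)
    # rotations that are identity everywhere except the MAIN_JOINTS slots, e.g.
    #
    #   reduced = torch.eye(3).expand(2, len(_C.BMODEL.MAIN_JOINTS), 3, 3)
    #   full = smpl.get_local_pose_from_reduced_global_pose(reduced)
    #   full.shape   # torch.Size([2, 24, 3, 3])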

    def forward(self,
                pred_rot6d,
                betas,
                cam=None,
                cam_intrinsics=None,
                bbox=None,
                res=None,
                return_full_pose=False,
                **kwargs):

        # 6D rotations -> (B * F, 24, 3, 3) rotation matrices.
        rotmat = transforms.rotation_6d_to_matrix(
            pred_rot6d.reshape(*pred_rot6d.shape[:2], -1, 6)
        ).reshape(-1, 24, 3, 3)

        output = self.get_output(body_pose=rotmat[:, 1:],
                                 global_orient=rotmat[:, :1],
                                 betas=betas.view(-1, 10),
                                 pose2rot=False,
                                 return_full_pose=return_full_pose)

        if cam is not None:
            joints3d = output.joints.reshape(*cam.shape[:2], -1, 3)

            # Weak perspective projection (for InstaVariety)
            weak_cam = convert_weak_perspective_to_perspective(cam)

            weak_joints2d = weak_perspective_projection(
                joints3d,
                rotation=torch.eye(3, device=cam.device).unsqueeze(0).unsqueeze(0).expand(*cam.shape[:2], -1, -1),
                translation=weak_cam,
                focal_length=5000.,
                camera_center=torch.zeros(*cam.shape[:2], 2, device=cam.device)
            )
            output.weak_joints2d = weak_joints2d

            # Full perspective projection
            full_cam = convert_pare_to_full_img_cam(
                cam,
                bbox[:, :, 2] * 200.,
                bbox[:, :, :2],
                res[:, 0].unsqueeze(-1),
                res[:, 1].unsqueeze(-1),
                focal_length=cam_intrinsics[:, :, 0, 0]
            )

            full_joints2d = full_perspective_projection(
                joints3d,
                translation=full_cam,
                cam_intrinsics=cam_intrinsics,
            )
            output.full_joints2d = full_joints2d
            output.full_cam = full_cam.reshape(-1, 3)

        return output

    def forward_nd(self,
                   pred_rot6d,
                   root,
                   betas,
                   return_full_pose=False):

        rotmat = transforms.rotation_6d_to_matrix(
            pred_rot6d.reshape(*pred_rot6d.shape[:2], -1, 6)
        ).reshape(-1, 24, 3, 3)

        output = self.get_output(body_pose=rotmat[:, 1:],
                                 global_orient=root.reshape(-1, 1, 3, 3),
                                 betas=betas.view(-1, 10),
                                 pose2rot=False,
                                 return_full_pose=return_full_pose)

        return output

    def get_output(self, *args, **kwargs):
        kwargs['get_skin'] = True
        smpl_output = super(SMPL, self).forward(*args, **kwargs)
        joints = vertices2joints(self.J_regressor_wham, smpl_output.vertices)
        feet = vertices2joints(self.J_regressor_feet, smpl_output.vertices)

        # Root-center all outputs at the midpoint of joints 11 and 12,
        # preserving any explicit translation passed via `transl`.
        offset = joints[..., [11, 12], :].mean(-2)
        if 'transl' in kwargs:
            offset = offset - kwargs['transl']
        vertices = smpl_output.vertices - offset.unsqueeze(-2)
        joints = joints - offset.unsqueeze(-2)
        feet = feet - offset.unsqueeze(-2)

        output = ModelOutput(vertices=vertices,
                             global_orient=smpl_output.global_orient,
                             body_pose=smpl_output.body_pose,
                             joints=joints,
                             betas=smpl_output.betas,
                             full_pose=smpl_output.full_pose)
        output.feet = feet
        output.offset = offset
        return output

    def get_offset(self, *args, **kwargs):
        kwargs['get_skin'] = True
        smpl_output = super(SMPL, self).forward(*args, **kwargs)
        joints = vertices2joints(self.J_regressor_wham, smpl_output.vertices)
        offset = joints[..., [11, 12], :].mean(-2)
        return offset

    def get_faces(self):
        return np.array(self.faces)
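
# Hedged usage sketch (added, not part of the original module): how the wrapper
# above might be driven. The constructor arguments and config key below are
# hypothetical; loading requires the SMPL body-model files on disk.
#
#   smpl = SMPL(_C.BMODEL.FLDR, num_betas=10)            # hypothetical args
#   ident6d = torch.tensor([1., 0., 0., 0., 1., 0.])     # identity rotation in 6D
#   pred_rot6d = ident6d.repeat(2, 8, 24)                # (batch=2, frames=8, 24 * 6)
#   betas = torch.zeros(2, 8, 10)
#   out = smpl(pred_rot6d, betas)                        # cam=None -> 3D outputs only
#   out.joints.shape                                     # (batch * frames, J, 3)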

def convert_weak_perspective_to_perspective(
        weak_perspective_camera,
        focal_length=5000.,
        img_res=224,
):
    # Map a weak-perspective camera (s, tx, ty) to a perspective translation
    # (tx, ty, tz) with tz = 2 * focal_length / (img_res * s); the 1e-9 term
    # guards against division by zero.
    perspective_camera = torch.stack(
        [
            weak_perspective_camera[..., 1],
            weak_perspective_camera[..., 2],
            2 * focal_length / (img_res * weak_perspective_camera[..., 0] + 1e-9)
        ],
        dim=-1
    )
    return perspective_camera
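
# Hedged worked example (added): with the defaults above, a weak-perspective
# scale s = 1 recovers depth tz = 2 * 5000 / (224 * 1) ≈ 44.64:
#
#   convert_weak_perspective_to_perspective(torch.tensor([[1., 0., 0.]]))
#   # tensor([[ 0.0000,  0.0000, 44.6429]])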

def weak_perspective_projection(
        points,
        rotation,
        translation,
        focal_length,
        camera_center,
        img_res=224,
        normalize_joints2d=True,
):
    """
    This function computes the perspective projection of a set of points.
    Input:
        points (b, f, N, 3): 3D points
        rotation (b, f, 3, 3): Camera rotation
        translation (b, f, 3): Camera translation
        focal_length (b, f,) or scalar: Focal length
        camera_center (b, f, 2): Camera center
    """
    # Assemble the pinhole intrinsics matrix K from focal length and center.
    K = torch.zeros([*points.shape[:2], 3, 3], device=points.device)
    K[:, :, 0, 0] = focal_length
    K[:, :, 1, 1] = focal_length
    K[:, :, 2, 2] = 1.
    K[:, :, :-1, -1] = camera_center

    # Transform points
    points = torch.einsum('bfij,bfkj->bfki', rotation, points)
    points = points + translation.unsqueeze(-2)

    # Apply perspective distortion
    projected_points = points / points[..., -1].unsqueeze(-1)

    # Apply camera intrinsics
    projected_points = torch.einsum('bfij,bfkj->bfki', K, projected_points)

    if normalize_joints2d:
        projected_points = projected_points / (img_res / 2.)

    return projected_points[..., :-1]
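
# Hedged shape example (added): with batch size b and sequence length f,
#
#   pts = torch.randn(b, f, 24, 3) + torch.tensor([0., 0., 5.])
#   j2d = weak_perspective_projection(
#       pts,
#       rotation=torch.eye(3).expand(b, f, 3, 3),
#       translation=torch.tensor([0., 0., 40.]).expand(b, f, 3),
#       focal_length=5000.,
#       camera_center=torch.zeros(b, f, 2),
#   )
#   j2d.shape   # (b, f, 24, 2), divided by img_res / 2 when normalize_joints2d=True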

def full_perspective_projection(
        points,
        cam_intrinsics,
        rotation=None,
        translation=None,
):
    # Project 3D points with a full (..., 3, 3) intrinsics matrix; rotation
    # and translation are optional extrinsics applied beforehand.
    K = cam_intrinsics

    if rotation is not None:
        points = (rotation @ points.transpose(-1, -2)).transpose(-1, -2)
    if translation is not None:
        points = points + translation.unsqueeze(-2)
    projected_points = points / points[..., -1].unsqueeze(-1)
    projected_points = (K @ projected_points.transpose(-1, -2)).transpose(-1, -2)
    return projected_points[..., :-1]
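
# Hedged example (added): cam_intrinsics is a standard pinhole matrix; with no
# extrinsics, the points are assumed to already be in camera coordinates:
#
#   K = torch.tensor([[[1000., 0., 960.],
#                      [0., 1000., 540.],
#                      [0., 0., 1.]]])
#   pts = torch.randn(1, 24, 3) + torch.tensor([0., 0., 5.])
#   uv = full_perspective_projection(pts, K)   # pixel coords, shape (1, 24, 2)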

def convert_pare_to_full_img_cam(
        pare_cam,
        bbox_height,
        bbox_center,
        img_w,
        img_h,
        focal_length,
        crop_res=224
):
    # Convert a crop-space weak-perspective camera (s, tx, ty) into a camera
    # translation expressed in full-image coordinates.
    s, tx, ty = pare_cam[..., 0], pare_cam[..., 1], pare_cam[..., 2]
    res = crop_res
    r = bbox_height / res
    tz = 2 * focal_length / (r * res * s)

    cx = 2 * (bbox_center[..., 0] - (img_w / 2.)) / (s * bbox_height)
    cy = 2 * (bbox_center[..., 1] - (img_h / 2.)) / (s * bbox_height)

    cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
    return cam_t
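
# Hedged consistency note (added): since r = bbox_height / res, the depth term
# above simplifies to tz = 2 * focal_length / (bbox_height * s), which matches
# cam_crop2full below when bbox_height == scale * 200 (up to its 1e-9
# stabiliser). For example, bbox_height = 200, s = 1, focal_length = 5000
# gives tz = 50.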

def cam_crop2full(crop_cam, center, scale, full_img_shape, focal_length):
    """
    Convert camera parameters from the crop camera to the full-image camera.
    :param crop_cam: shape=(N, 3) weak perspective camera in cropped img coordinates (s, tx, ty)
    :param center: shape=(N, 2) bbox center coordinates (c_x, c_y)
    :param scale: shape=(N,) square bbox resolution (b / 200)
    :param full_img_shape: shape=(N, 2) original image height and width
    :param focal_length: shape=(N,)
    :return: shape=(N, 3) camera translation (tx, ty, tz) in full-image coordinates
    """
    img_h, img_w = full_img_shape[:, 0], full_img_shape[:, 1]
    cx, cy, b = center[:, 0], center[:, 1], scale * 200
    w_2, h_2 = img_w / 2., img_h / 2.
    bs = b * crop_cam[:, 0] + 1e-9
    tz = 2 * focal_length / bs
    tx = (2 * (cx - w_2) / bs) + crop_cam[:, 1]
    ty = (2 * (cy - h_2) / bs) + crop_cam[:, 2]
    full_cam = torch.stack([tx, ty, tz], dim=-1)
    return full_cam
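

if __name__ == '__main__':
    # Hedged smoke test (added, not part of the original module). It exercises
    # only the standalone camera helpers, which need nothing beyond torch; the
    # shapes below are illustrative.
    b, f, n = 2, 4, 24
    points = torch.randn(b, f, n, 3) + torch.tensor([0., 0., 5.])

    weak_cam = torch.rand(b, f, 3) + 0.5   # (s, tx, ty) with s > 0
    t = convert_weak_perspective_to_perspective(weak_cam)
    j2d = weak_perspective_projection(
        points,
        rotation=torch.eye(3).expand(b, f, 3, 3),
        translation=t,
        focal_length=5000.,
        camera_center=torch.zeros(b, f, 2),
    )
    assert j2d.shape == (b, f, n, 2)

    K = torch.eye(3).repeat(b, f, 1, 1)
    K[..., 0, 0] = K[..., 1, 1] = 1000.
    K[..., 0, 2], K[..., 1, 2] = 960., 540.
    assert full_perspective_projection(points, K, translation=t).shape == (b, f, n, 2)
    print('camera helper smoke test passed')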