from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os
import sys

import torch
import numpy as np

from lib.utils import transforms
from smplx import SMPL as _SMPL
from smplx.utils import SMPLOutput as ModelOutput
from smplx.lbs import vertices2joints

from configs import constants as _C


class SMPL(_SMPL):
    """Extension of the official SMPL implementation to support more joints."""

    def __init__(self, *args, **kwargs):
        # Suppress the banner the smplx constructor prints to stdout.
        sys.stdout = open(os.devnull, 'w')
        super(SMPL, self).__init__(*args, **kwargs)
        sys.stdout = sys.__stdout__

        J_regressor_wham = np.load(_C.BMODEL.JOINTS_REGRESSOR_WHAM)
        J_regressor_eval = np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M)
        self.register_buffer('J_regressor_wham', torch.tensor(
            J_regressor_wham, dtype=torch.float32))
        self.register_buffer('J_regressor_eval', torch.tensor(
            J_regressor_eval, dtype=torch.float32))
        self.register_buffer('J_regressor_feet', torch.from_numpy(
            np.load(_C.BMODEL.JOINTS_REGRESSOR_FEET)
        ).float())

    def get_local_pose_from_reduced_global_pose(self, reduced_pose):
        # Start from identity rotations for all 24 SMPL joints, then fill in
        # the predicted rotations at the main joints.
        full_pose = torch.eye(
            3, device=reduced_pose.device
        )[(None, ) * 2].repeat(reduced_pose.shape[0], 24, 1, 1)
        full_pose[:, _C.BMODEL.MAIN_JOINTS] = reduced_pose
        return full_pose

    def forward(self,
                pred_rot6d,
                betas,
                cam=None,
                cam_intrinsics=None,
                bbox=None,
                res=None,
                return_full_pose=False,
                **kwargs):

        # Decode the 6-D rotation representation into (B*T, 24, 3, 3) matrices.
        rotmat = transforms.rotation_6d_to_matrix(
            pred_rot6d.reshape(*pred_rot6d.shape[:2], -1, 6)
        ).reshape(-1, 24, 3, 3)

        output = self.get_output(body_pose=rotmat[:, 1:],
                                 global_orient=rotmat[:, :1],
                                 betas=betas.view(-1, 10),
                                 pose2rot=False,
                                 return_full_pose=return_full_pose)

        if cam is not None:
            joints3d = output.joints.reshape(*cam.shape[:2], -1, 3)

            # Weak perspective projection (for InstaVariety)
            weak_cam = convert_weak_perspective_to_perspective(cam)

            weak_joints2d = weak_perspective_projection(
                joints3d,
                rotation=torch.eye(3, device=cam.device
                                   ).unsqueeze(0).unsqueeze(0).expand(*cam.shape[:2], -1, -1),
                translation=weak_cam,
                focal_length=5000.,
                camera_center=torch.zeros(*cam.shape[:2], 2, device=cam.device)
            )
            output.weak_joints2d = weak_joints2d

            # Full perspective projection (bbox scale * 200 gives the bbox
            # height in pixels; see cam_crop2full below).
            full_cam = convert_pare_to_full_img_cam(
                cam,
                bbox[:, :, 2] * 200.,
                bbox[:, :, :2],
                res[:, 0].unsqueeze(-1),
                res[:, 1].unsqueeze(-1),
                focal_length=cam_intrinsics[:, :, 0, 0]
            )

            full_joints2d = full_perspective_projection(
                joints3d,
                translation=full_cam,
                cam_intrinsics=cam_intrinsics,
            )
            output.full_joints2d = full_joints2d
            output.full_cam = full_cam.reshape(-1, 3)

        return output

    def forward_nd(self,
                   pred_rot6d,
                   root,
                   betas,
                   return_full_pose=False):

        rotmat = transforms.rotation_6d_to_matrix(
            pred_rot6d.reshape(*pred_rot6d.shape[:2], -1, 6)
        ).reshape(-1, 24, 3, 3)

        output = self.get_output(body_pose=rotmat[:, 1:],
                                 global_orient=root.reshape(-1, 1, 3, 3),
                                 betas=betas.view(-1, 10),
                                 pose2rot=False,
                                 return_full_pose=return_full_pose)

        return output

    def get_output(self, *args, **kwargs):
        kwargs['get_skin'] = True
        smpl_output = super(SMPL, self).forward(*args, **kwargs)
        joints = vertices2joints(self.J_regressor_wham, smpl_output.vertices)
        feet = vertices2joints(self.J_regressor_feet, smpl_output.vertices)

        # Re-center everything at the midpoint of the two hip joints
        # (indices 11 and 12 in the WHAM joint set).
        offset = joints[..., [11, 12], :].mean(-2)
        if 'transl' in kwargs:
            offset = offset - kwargs['transl']
        vertices = smpl_output.vertices - offset.unsqueeze(-2)
        joints = joints - offset.unsqueeze(-2)
        feet = feet - offset.unsqueeze(-2)

        output = ModelOutput(vertices=vertices,
                             global_orient=smpl_output.global_orient,
                             body_pose=smpl_output.body_pose,
                             joints=joints,
                             betas=smpl_output.betas,
                             full_pose=smpl_output.full_pose)
        output.feet = feet
        output.offset = offset
        return output

    def get_offset(self, *args, **kwargs):
        kwargs['get_skin'] = True
        smpl_output = super(SMPL, self).forward(*args, **kwargs)
        joints = vertices2joints(self.J_regressor_wham, smpl_output.vertices)
        offset = joints[..., [11, 12], :].mean(-2)
        return offset

    def get_faces(self):
        return np.array(self.faces)
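

# ---------------------------------------------------------------------------
# Reference sketch (not called anywhere in this module): `forward` above
# assumes that `lib.utils.transforms.rotation_6d_to_matrix` implements the
# continuous 6-D rotation representation of Zhou et al. (CVPR 2019). Under
# that assumption, a minimal Gram-Schmidt version would look like this; the
# name `_rotation_6d_to_matrix_sketch` is hypothetical, for illustration only.
def _rotation_6d_to_matrix_sketch(d6):
    """Map (..., 6) 6-D rotation vectors to (..., 3, 3) rotation matrices."""
    a1, a2 = d6[..., :3], d6[..., 3:]
    # First row: normalize the first 3-vector.
    b1 = torch.nn.functional.normalize(a1, dim=-1)
    # Second row: Gram-Schmidt-orthogonalize the second 3-vector against b1.
    b2 = torch.nn.functional.normalize(
        a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
    # Third row: cross product completes a right-handed orthonormal frame.
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)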


def convert_weak_perspective_to_perspective(
        weak_perspective_camera,
        focal_length=5000.,
        img_res=224,
):
    # A weak-perspective camera (s, tx, ty) maps to a perspective translation
    # (tx, ty, tz), where tz is recovered from the scale via
    # tz = 2 * focal_length / (img_res * s).
    perspective_camera = torch.stack(
        [
            weak_perspective_camera[..., 1],
            weak_perspective_camera[..., 2],
            2 * focal_length / (img_res * weak_perspective_camera[..., 0] + 1e-9)
        ],
        dim=-1
    )
    return perspective_camera


def weak_perspective_projection(
        points,
        rotation,
        translation,
        focal_length,
        camera_center,
        img_res=224,
        normalize_joints2d=True,
):
    """
    Computes the perspective projection of a set of points using a
    translation recovered from a weak-perspective camera.
    Input:
        points (b, f, N, 3): 3D points
        rotation (b, f, 3, 3): Camera rotation
        translation (b, f, 3): Camera translation
        focal_length (b, f,) or scalar: Focal length
        camera_center (b, f, 2): Camera center
    """
    # Build the intrinsic matrix K from the focal length and principal point.
    K = torch.zeros([*points.shape[:2], 3, 3], device=points.device)
    K[:, :, 0, 0] = focal_length
    K[:, :, 1, 1] = focal_length
    K[:, :, 2, 2] = 1.
    K[:, :, :-1, -1] = camera_center

    # Transform points to camera coordinates
    points = torch.einsum('bfij,bfkj->bfki', rotation, points)
    points = points + translation.unsqueeze(-2)

    # Apply perspective distortion
    projected_points = points / points[..., -1].unsqueeze(-1)

    # Apply camera intrinsics
    projected_points = torch.einsum('bfij,bfkj->bfki', K, projected_points)

    if normalize_joints2d:
        projected_points = projected_points / (img_res / 2.)

    return projected_points[..., :-1]


def full_perspective_projection(
        points,
        cam_intrinsics,
        rotation=None,
        translation=None,
):
    K = cam_intrinsics

    if rotation is not None:
        points = (rotation @ points.transpose(-1, -2)).transpose(-1, -2)
    if translation is not None:
        points = points + translation.unsqueeze(-2)
    projected_points = points / points[..., -1].unsqueeze(-1)
    projected_points = (K @ projected_points.transpose(-1, -2)).transpose(-1, -2)
    return projected_points[..., :-1]


def convert_pare_to_full_img_cam(
        pare_cam, bbox_height, bbox_center,
        img_w, img_h, focal_length, crop_res=224
):
    # Convert a PARE-style weak-perspective camera, predicted in crop
    # coordinates, to a translation in the full image frame.
    s, tx, ty = pare_cam[..., 0], pare_cam[..., 1], pare_cam[..., 2]
    res = crop_res
    r = bbox_height / res
    tz = 2 * focal_length / (r * res * s)

    cx = 2 * (bbox_center[..., 0] - (img_w / 2.)) / (s * bbox_height)
    cy = 2 * (bbox_center[..., 1] - (img_h / 2.)) / (s * bbox_height)

    cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
    return cam_t


def cam_crop2full(crop_cam, center, scale, full_img_shape, focal_length):
    """
    Convert camera parameters from the crop camera to the full-image camera.
    :param crop_cam: shape=(N, 3) weak perspective camera in cropped
        img coordinates (s, tx, ty)
    :param center: shape=(N, 2) bbox center coordinates (c_x, c_y)
    :param scale: shape=(N,) square bbox resolution (b / 200)
    :param full_img_shape: shape=(N, 2) original image height and width
    :param focal_length: shape=(N,)
    :return: full_cam: shape=(N, 3) full-image camera translation (tx, ty, tz)
    """
    img_h, img_w = full_img_shape[:, 0], full_img_shape[:, 1]
    cx, cy, b = center[:, 0], center[:, 1], scale * 200
    w_2, h_2 = img_w / 2., img_h / 2.
    bs = b * crop_cam[:, 0] + 1e-9
    tz = 2 * focal_length / bs
    tx = (2 * (cx - w_2) / bs) + crop_cam[:, 1]
    ty = (2 * (cy - h_2) / bs) + crop_cam[:, 2]
    full_cam = torch.stack([tx, ty, tz], dim=-1)
    return full_cam
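

# ---------------------------------------------------------------------------
# Minimal smoke test for the camera utilities above. This is an illustrative
# sketch: the batch/frame/joint sizes and the intrinsics below are
# hypothetical values chosen for the demo, not values used in training.
if __name__ == '__main__':
    b, f, n = 2, 4, 17  # hypothetical batch, frame, and joint counts
    points = torch.randn(b, f, n, 3) * 0.2 + torch.tensor([0., 0., 5.])
    weak_cam = torch.tensor([1., 0., 0.]).expand(b, f, 3)  # (s, tx, ty)

    # Weak-perspective path: recover a translation, then project.
    translation = convert_weak_perspective_to_perspective(weak_cam)
    joints2d = weak_perspective_projection(
        points,
        rotation=torch.eye(3).expand(b, f, 3, 3),
        translation=translation,
        focal_length=5000.,
        camera_center=torch.zeros(b, f, 2),
    )
    print('weak_perspective_projection:', tuple(joints2d.shape))  # (2, 4, 17, 2)

    # Full-perspective path with a hypothetical pinhole intrinsic matrix.
    K = torch.zeros(b, f, 3, 3)
    K[..., 0, 0] = K[..., 1, 1] = 1000.      # focal length in pixels
    K[..., 0, 2], K[..., 1, 2] = 640., 360.  # principal point
    K[..., 2, 2] = 1.
    joints2d_full = full_perspective_projection(points, K, translation=translation)
    print('full_perspective_projection:', tuple(joints2d_full.shape))  # (2, 4, 17, 2)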