from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
from torch import nn
import numpy as np
from loguru import logger

from configs import constants as _C
from lib.models.layers import (MotionEncoder, MotionDecoder, TrajectoryDecoder, TrajectoryRefiner, Integrator,
                               rollout_global_motion, reset_root_velocity, compute_camera_motion)
from lib.utils.transforms import axis_angle_to_matrix
class Network(nn.Module):
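    """Sequential motion network.

    Encodes per-frame 2D keypoints into a motion context, decodes a global
    root trajectory, optionally fuses image features, decodes SMPL pose,
    shape, camera, and foot-contact predictions, and can refine the
    trajectory with a dedicated refiner module.
    """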
    def __init__(self,
                 smpl,
                 pose_dr=0.1,
                 d_embed=512,
                 n_layers=3,
                 d_feat=2048,
                 rnn_type='LSTM',
                 **kwargs
                 ):
        super().__init__()

        n_joints = _C.KEYPOINTS.NUM_JOINTS
        self.smpl = smpl
        in_dim = n_joints * 2 + 3
        d_context = d_embed + n_joints * 3

        self.mask_embedding = nn.Parameter(torch.zeros(1, 1, n_joints, 2))
        # Module 1. Motion Encoder
        self.motion_encoder = MotionEncoder(in_dim=in_dim,
                                            d_embed=d_embed,
                                            pose_dr=pose_dr,
                                            rnn_type=rnn_type,
                                            n_layers=n_layers,
                                            n_joints=n_joints)
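        # Module 2. Trajectory Decoder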
        self.trajectory_decoder = TrajectoryDecoder(d_embed=d_context,
                                                    rnn_type=rnn_type,
                                                    n_layers=n_layers)

        # Module 3. Feature Integrator
        self.integrator = Integrator(in_channel=d_feat + d_context,
                                     out_channel=d_context)

        # Module 4. Motion Decoder
        self.motion_decoder = MotionDecoder(d_embed=d_context,
                                            rnn_type=rnn_type,
                                            n_layers=n_layers)

        # Module 5. Trajectory Refiner
        self.trajectory_refiner = TrajectoryRefiner(d_embed=d_context,
                                                    d_hidden=d_embed,
                                                    rnn_type=rnn_type,
                                                    n_layers=2)

    def compute_global_feet(self, root_world, trans):
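        """Lift the predicted foot locations from camera to world coordinates
        using the camera motion recovered from the current predictions."""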
        # Compute world-coordinate motion
        cam_R, cam_T = compute_camera_motion(self.output, self.pred_pose[:, :, :6], root_world, trans, self.pred_cam)
        feet_cam = self.output.feet.reshape(self.b, self.f, -1, 3) + self.output.full_cam.reshape(self.b, self.f, 1, 3)
        feet_world = (cam_R.mT @ (feet_cam - cam_T.unsqueeze(-2)).mT).mT

        return feet_world, cam_R

    def forward_smpl(self, **kwargs):
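        """Run the SMPL layer on the registered predictions and assemble the
        output dictionary (with extra supervision terms while training)."""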
        self.output = self.smpl(self.pred_pose,
                                self.pred_shape,
                                cam=self.pred_cam,
                                return_full_pose=not self.training,
                                **kwargs,
                                )

        logger.info(f"Output Joints: {self.output.joints}")
        logger.info(f"Output Vertices: {self.output.vertices}")

        # Save joints and vertices as .npy arrays
        np.save('joints.npy', self.output.joints.detach().cpu().numpy())
        np.save('vertices.npy', self.output.vertices.detach().cpu().numpy())
        # Feet location in global coordinate
        root_world, trans = rollout_global_motion(self.pred_root, self.pred_vel)
        feet_world, cam_R = self.compute_global_feet(root_world, trans)

        # Return output
        output = {'feet': feet_world,
                  'contact': self.pred_contact,
                  'pose': self.pred_pose,
                  'betas': self.pred_shape,
                  'cam': self.pred_cam,
                  'poses_root_cam': self.output.global_orient,
                  'poses_root_r6d': self.pred_root,
                  'vel_root': self.pred_vel,
                  'pose_root': self.pred_root,
                  'verts_cam': self.output.vertices}
        if self.training:
            output.update({
                'kp3d': self.output.joints,
                'kp3d_nn': self.pred_kp3d,
                'full_kp2d': self.output.full_joints2d,
                'weak_kp2d': self.output.weak_joints2d,
                'R': cam_R,
            })
        else:
            output.update({
                'trans_cam': self.output.full_cam,
                'poses_body': self.output.body_pose})

        return output

    def preprocess(self, x, mask):
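        """Zero-out masked keypoints and add a learnable mask embedding in
        their place, leaving the trailing three input channels untouched."""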
        self.b, self.f = x.shape[:2]

        # Treat masked keypoints
        mask_embedding = mask.unsqueeze(-1) * self.mask_embedding
        _mask = mask.unsqueeze(-1).repeat(1, 1, 1, 2).reshape(self.b, self.f, -1)
        _mask = torch.cat((_mask, torch.zeros_like(_mask[..., :3])), dim=-1)
        _mask_embedding = mask_embedding.reshape(self.b, self.f, -1)
        _mask_embedding = torch.cat((_mask_embedding, torch.zeros_like(_mask_embedding[..., :3])), dim=-1)
        x[_mask] = 0.0
        x = x + _mask_embedding

        return x

    def rollout(self, output, pred_root, pred_vel, return_y_up):
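        """Accumulate per-frame root pose and velocity into a world
        trajectory; optionally rotate the result 180 degrees about the
        x-axis to switch between y-down and y-up conventions."""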
        root_world, trans_world = rollout_global_motion(pred_root, pred_vel)

        if return_y_up:
            yup2ydown = axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float().to(root_world.device)
            root_world = yup2ydown.mT @ root_world
            trans_world = (yup2ydown.mT @ trans_world.unsqueeze(-1)).squeeze(-1)

        output.update({
            'poses_root_world': root_world,
            'trans_world': trans_world,
        })

        return output

    def refine_trajectory(self, output, cam_angvel, return_y_up, **kwargs):
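        """Refine the root trajectory with foot-contact constraints, roll it
        out to world coordinates, and (while training) recompute the feet."""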
        # --------- Refine trajectory --------- #
        update_vel = reset_root_velocity(self.smpl, self.output, self.pred_contact, self.pred_root, self.pred_vel, thr=0.5)
        output = self.trajectory_refiner(self.old_motion_context, update_vel, output, cam_angvel, return_y_up=return_y_up)
        # --------- #

        # Do rollout
        output = self.rollout(output, output['poses_root_r6d_refined'], output['vel_root_refined'], return_y_up)

        # --------- Compute refined feet --------- #
        if self.training:
            feet_world, cam_R = self.compute_global_feet(output['poses_root_world'], output['trans_world'])
            output.update({'feet_refined': feet_world})

        return output

    def forward(self, x, inits, img_features=None, mask=None, init_root=None, cam_angvel=None,
                cam_intrinsics=None, bbox=None, res=None, return_y_up=False, refine_traj=True, **kwargs):
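        """Full pipeline: preprocess keypoints, encode motion, decode the
        global trajectory, integrate image features, decode SMPL parameters,
        build the SMPL output, and optionally refine the trajectory."""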
        x = self.preprocess(x, mask)
        init_kp, init_smpl = inits

        # --------- Inference --------- #
        # Stage 1. Encode motion
        pred_kp3d, motion_context = self.motion_encoder(x, init_kp)
        self.old_motion_context = motion_context.detach().clone()

        # Stage 2. Decode global trajectory
        pred_root, pred_vel = self.trajectory_decoder(motion_context, init_root, cam_angvel)

        # Stage 3. Integrate features
        if img_features is not None and self.integrator is not None:
            motion_context = self.integrator(motion_context, img_features)

        # Stage 4. Decode SMPL motion
        pred_pose, pred_shape, pred_cam, pred_contact = self.motion_decoder(motion_context, init_smpl)
        # --------- #

        # --------- Register predictions --------- #
        self.pred_kp3d = pred_kp3d
        self.pred_root = pred_root
        self.pred_vel = pred_vel
        self.pred_pose = pred_pose
        self.pred_shape = pred_shape
        self.pred_cam = pred_cam
        self.pred_contact = pred_contact
        # --------- #

        # --------- Build SMPL --------- #
        output = self.forward_smpl(cam_intrinsics=cam_intrinsics, bbox=bbox, res=res)
        # --------- #

        # --------- Refine trajectory --------- #
        if refine_traj:
            output = self.refine_trajectory(output, cam_angvel, return_y_up)
        else:
            output = self.rollout(output, self.pred_root, self.pred_vel, return_y_up)
        # --------- #

        return output
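# ------------------------------------------------------------------ #
# Minimal usage sketch (illustrative only; the SMPL wrapper builder and the
# initial states `init_kp`, `init_smpl`, `init_root`, `cam_angvel` are
# assumptions about the surrounding pipeline and are not defined here):
#
#   smpl = build_smpl_model()                      # hypothetical SMPL wrapper
#   net = Network(smpl).eval()
#   B, T, J = 1, 81, _C.KEYPOINTS.NUM_JOINTS       # batch, frames (arbitrary), joints
#   x = torch.zeros(B, T, J * 2 + 3)               # 2D keypoints + 3 extra channels
#   mask = torch.zeros(B, T, J, dtype=torch.bool)  # True marks a masked keypoint
#   with torch.no_grad():
#       output = net(x, (init_kp, init_smpl), img_features=None, mask=mask,
#                    init_root=init_root, cam_angvel=cam_angvel)
# ------------------------------------------------------------------ #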