Techt3o's picture
ca5705cc9c8581d916aca37e6759c44f0b1e70429e49ce83e658a0517cd3d6fe
c87d1bc verified
raw
history blame
8.81 kB
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
from torch import nn
import numpy as np
from configs import constants as _C
from lib.models.layers import (MotionEncoder, MotionDecoder, TrajectoryDecoder, TrajectoryRefiner, Integrator,
rollout_global_motion, reset_root_velocity, compute_camera_motion)
from lib.utils.transforms import axis_angle_to_matrix
class Network(nn.Module):
    """Predict SMPL body motion and a global root trajectory from a sequence.

    The forward pass chains five sub-modules: motion encoder, trajectory
    decoder, feature integrator, motion decoder and (optionally) trajectory
    refiner. Intermediate predictions are stored as instance attributes
    (``pred_*``, ``output``, ``b``, ``f``, ``old_motion_context``) and read
    back by the helper methods, so the helpers are only valid after
    ``forward`` has populated them.
    """

    def __init__(self,
                 smpl,
                 pose_dr=0.1,
                 d_embed=512,
                 n_layers=3,
                 d_feat=2048,
                 rnn_type='LSTM',
                 dump_outputs=False,
                 **kwargs
                 ):
        """Build all sub-modules.

        Args:
            smpl: body model called by ``forward_smpl`` as
                ``smpl(pose, shape, cam=..., return_full_pose=..., **kwargs)``.
            pose_dr: forwarded to ``MotionEncoder`` (presumably a dropout
                rate - confirm in ``lib.models.layers``).
            d_embed: motion-encoder embedding size.
            n_layers: recurrent layer count for encoder/decoders.
            d_feat: channel size of the optional image features.
            rnn_type: recurrent cell type forwarded to every sub-module.
            dump_outputs: debug switch; when True each ``forward_smpl`` call
                logs the SMPL joints/vertices and saves them to
                ``joints.npy`` / ``vertices.npy`` in the working directory.
            **kwargs: ignored.
        """
        super().__init__()
        n_joints = _C.KEYPOINTS.NUM_JOINTS
        self.smpl = smpl
        self.dump_outputs = dump_outputs
        # Per-frame input width: 2 values per joint plus 3 extra channels.
        in_dim = n_joints * 2 + 3
        # Context width shared by the downstream modules.
        d_context = d_embed + n_joints * 3
        # Learned 2-value embedding added to masked joints in ``preprocess``.
        self.mask_embedding = nn.Parameter(torch.zeros(1, 1, n_joints, 2))

        # Module 1. Motion Encoder
        self.motion_encoder = MotionEncoder(in_dim=in_dim,
                                            d_embed=d_embed,
                                            pose_dr=pose_dr,
                                            rnn_type=rnn_type,
                                            n_layers=n_layers,
                                            n_joints=n_joints)

        # Module 2. Trajectory Decoder
        self.trajectory_decoder = TrajectoryDecoder(d_embed=d_context,
                                                    rnn_type=rnn_type,
                                                    n_layers=n_layers)

        # Module 3. Feature Integrator
        self.integrator = Integrator(in_channel=d_feat + d_context,
                                     out_channel=d_context)

        # Module 4. Motion Decoder
        self.motion_decoder = MotionDecoder(d_embed=d_context,
                                            rnn_type=rnn_type,
                                            n_layers=n_layers)

        # Module 5. Trajectory Refiner
        self.trajectory_refiner = TrajectoryRefiner(d_embed=d_context,
                                                    d_hidden=d_embed,
                                                    rnn_type=rnn_type,
                                                    n_layers=2)

    def compute_global_feet(self, root_world, trans):
        """Transform the cached SMPL feet positions into world coordinates.

        Args:
            root_world: world root orientation from ``rollout_global_motion``.
            trans: world root translation from ``rollout_global_motion``.

        Returns:
            ``(feet_world, cam_R)``: feet as a ``(b, f, -1, 3)`` tensor in
            world coordinates, and the camera rotation used for the transform.
        """
        # Compute world-coordinate camera motion
        cam_R, cam_T = compute_camera_motion(self.output, self.pred_pose[:, :, :6], root_world, trans, self.pred_cam)
        # Camera-frame feet: SMPL feet offset by the per-frame ``full_cam`` translation.
        feet_cam = self.output.feet.reshape(self.b, self.f, -1, 3) + self.output.full_cam.reshape(self.b, self.f, 1, 3)
        # Per point: R^T (X_cam - T), i.e. undo X_cam = R X_world + T.
        feet_world = (cam_R.mT @ (feet_cam - cam_T.unsqueeze(-2)).mT).mT
        return feet_world, cam_R

    def forward_smpl(self, **kwargs):
        """Run the SMPL model on the cached predictions and build the output dict.

        Keyword arguments (``cam_intrinsics``, ``bbox``, ``res`` from
        ``forward``) are passed through to ``self.smpl``. The raw SMPL result
        is cached in ``self.output`` for the trajectory helpers.

        Returns:
            dict of predictions. Training mode adds 3D/2D keypoints and the
            camera rotation; eval mode adds camera translation and body pose.
        """
        self.output = self.smpl(self.pred_pose,
                                self.pred_shape,
                                cam=self.pred_cam,
                                return_full_pose=not self.training,
                                **kwargs,
                                )
        # Debug dump is opt-in: unconditional writes would overwrite files on
        # every forward pass, and ``.numpy()`` fails on grad-tracking tensors.
        if self.dump_outputs:
            self._dump_smpl_output()

        # Feet location in global coordinate
        root_world, trans = rollout_global_motion(self.pred_root, self.pred_vel)
        feet_world, cam_R = self.compute_global_feet(root_world, trans)

        # Return output
        output = {'feet': feet_world,
                  'contact': self.pred_contact,
                  'pose': self.pred_pose,
                  'betas': self.pred_shape,
                  'cam': self.pred_cam,
                  'poses_root_cam': self.output.global_orient,
                  'poses_root_r6d': self.pred_root,
                  'vel_root': self.pred_vel,
                  # Same tensor as 'poses_root_r6d'; kept for existing consumers.
                  'pose_root': self.pred_root,
                  'verts_cam': self.output.vertices}

        if self.training:
            output.update({
                'kp3d': self.output.joints,
                'kp3d_nn': self.pred_kp3d,
                'full_kp2d': self.output.full_joints2d,
                'weak_kp2d': self.output.weak_joints2d,
                'R': cam_R,
            })
        else:
            output.update({
                'trans_cam': self.output.full_cam,
                'poses_body': self.output.body_pose})

        return output

    def _dump_smpl_output(self):
        """Log the cached SMPL joints/vertices and save them as .npy files.

        Files go to the current working directory and are overwritten on each
        call. Tensors are detached first so this also works during training.
        """
        from loguru import logger
        joints = self.output.joints
        vertices = self.output.vertices
        logger.info(f"Output Joints: {joints}")
        logger.info(f"Output Vertices: {vertices}")
        # Save joints and vertices as .npy arrays
        np.save('joints.npy', joints.detach().cpu().numpy())
        np.save('vertices.npy', vertices.detach().cpu().numpy())

    def preprocess(self, x, mask):
        """Zero masked joint entries and add the learned mask embedding.

        Also records the batch and frame sizes of ``x`` in ``self.b`` and
        ``self.f`` for later reshapes.

        Args:
            x: input tensor whose first two dims are (batch, frames) and whose
                last dim is ``2 * n_joints + 3`` (2 values per joint followed
                by 3 channels that are never masked).
            mask: ``(b, f, n_joints)`` tensor, True where a joint is masked;
                None skips masking.

        Returns:
            A new tensor; ``x`` itself is left unmodified.
        """
        self.b, self.f = x.shape[:2]
        if mask is None:
            return x

        mask = mask.bool()
        # Learned embedding per joint, zero for visible joints.
        joint_embedding = mask.unsqueeze(-1) * self.mask_embedding
        # Expand the joint mask to both per-joint values, then pad the 3 extra
        # channels with False so they are never masked.
        coord_mask = mask.unsqueeze(-1).repeat(1, 1, 1, 2).reshape(self.b, self.f, -1)
        coord_mask = torch.cat((coord_mask, torch.zeros_like(coord_mask[..., :3])), dim=-1)
        flat_embedding = joint_embedding.reshape(self.b, self.f, -1)
        flat_embedding = torch.cat((flat_embedding, torch.zeros_like(flat_embedding[..., :3])), dim=-1)
        # Out-of-place fill so the caller's tensor is not mutated.
        x = x.masked_fill(coord_mask, 0.0)
        return x + flat_embedding

    def rollout(self, output, pred_root, pred_vel, return_y_up):
        """Integrate root orientation/velocity into a world-space trajectory.

        Adds ``poses_root_world`` and ``trans_world`` to ``output`` in place
        and returns it. When ``return_y_up`` is True both are first multiplied
        by the transpose of a pi rotation about the x-axis.
        """
        root_world, trans_world = rollout_global_motion(pred_root, pred_vel)

        if return_y_up:
            yup2ydown = axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float().to(root_world.device)
            root_world = yup2ydown.mT @ root_world
            trans_world = (yup2ydown.mT @ trans_world.unsqueeze(-1)).squeeze(-1)

        output.update({
            'poses_root_world': root_world,
            'trans_world': trans_world,
        })

        return output

    def refine_trajectory(self, output, cam_angvel, return_y_up, **kwargs):
        """Refine the decoded trajectory, then roll it out to world space.

        The root velocity is first adjusted by ``reset_root_velocity`` using
        the predicted contacts (threshold 0.5). It is then passed, together
        with the detached encoder context saved in ``forward``, to the
        trajectory refiner. In training mode the world-space feet of the
        refined trajectory are added as ``feet_refined``.
        """
        # --------- Refine trajectory --------- #
        update_vel = reset_root_velocity(self.smpl, self.output, self.pred_contact, self.pred_root, self.pred_vel, thr=0.5)
        output = self.trajectory_refiner(self.old_motion_context, update_vel, output, cam_angvel, return_y_up=return_y_up)
        # --------- #

        # Do rollout
        output = self.rollout(output, output['poses_root_r6d_refined'], output['vel_root_refined'], return_y_up)

        # --------- Compute refined feet --------- #
        if self.training:
            feet_world, cam_R = self.compute_global_feet(output['poses_root_world'], output['trans_world'])
            output.update({'feet_refined': feet_world})

        return output

    def forward(self, x, inits, img_features=None, mask=None, init_root=None, cam_angvel=None,
                cam_intrinsics=None, bbox=None, res=None, return_y_up=False, refine_traj=True, **kwargs):
        """Predict SMPL motion and the global trajectory for a sequence.

        Args:
            x: input sequence with (batch, frames) as its first two dims.
            inits: ``(init_kp, init_smpl)`` initial states for the motion
                encoder and motion decoder respectively.
            img_features: optional features fused in by the integrator.
            mask: optional joint mask (see ``preprocess``).
            init_root: forwarded to the trajectory decoder.
            cam_angvel: forwarded to the trajectory decoder and refiner.
            cam_intrinsics, bbox, res: forwarded to the SMPL model.
            return_y_up: rotate world outputs (see ``rollout``).
            refine_traj: run the trajectory refiner before the rollout.
            **kwargs: ignored.

        Returns:
            dict from ``forward_smpl`` extended with the world trajectory.
        """
        x = self.preprocess(x, mask)
        init_kp, init_smpl = inits

        # --------- Inference --------- #
        # Stage 1. Encode motion
        pred_kp3d, motion_context = self.motion_encoder(x, init_kp)
        # Detached snapshot taken before integration; consumed by the refiner.
        self.old_motion_context = motion_context.detach().clone()

        # Stage 2. Decode global trajectory
        pred_root, pred_vel = self.trajectory_decoder(motion_context, init_root, cam_angvel)

        # Stage 3. Integrate features
        if img_features is not None and self.integrator is not None:
            motion_context = self.integrator(motion_context, img_features)

        # Stage 4. Decode SMPL motion
        pred_pose, pred_shape, pred_cam, pred_contact = self.motion_decoder(motion_context, init_smpl)
        # --------- #

        # --------- Register predictions --------- #
        self.pred_kp3d = pred_kp3d
        self.pred_root = pred_root
        self.pred_vel = pred_vel
        self.pred_pose = pred_pose
        self.pred_shape = pred_shape
        self.pred_cam = pred_cam
        self.pred_contact = pred_contact
        # --------- #

        # --------- Build SMPL --------- #
        output = self.forward_smpl(cam_intrinsics=cam_intrinsics, bbox=bbox, res=res)
        # --------- #

        # --------- Refine trajectory --------- #
        if refine_traj:
            output = self.refine_trajectory(output, cam_angvel, return_y_up)
        else:
            output = self.rollout(output, self.pred_root, self.pred_vel, return_y_up)
        # --------- #

        return output