from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
from torch import nn
import numpy as np
from loguru import logger
from configs import constants as _C
from lib.models.layers import (MotionEncoder, MotionDecoder, TrajectoryDecoder, TrajectoryRefiner, Integrator,
rollout_global_motion, reset_root_velocity, compute_camera_motion)
from lib.utils.transforms import axis_angle_to_matrix
class Network(nn.Module):
def __init__(self,
smpl,
pose_dr=0.1,
d_embed=512,
n_layers=3,
d_feat=2048,
rnn_type='LSTM',
**kwargs
):
super().__init__()
n_joints = _C.KEYPOINTS.NUM_JOINTS
self.smpl = smpl
in_dim = n_joints * 2 + 3
d_context = d_embed + n_joints * 3
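        # Learned embedding that stands in for keypoints masked out during preprocessing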
self.mask_embedding = nn.Parameter(torch.zeros(1, 1, n_joints, 2))
# Module 1. Motion Encoder
self.motion_encoder = MotionEncoder(in_dim=in_dim,
d_embed=d_embed,
pose_dr=pose_dr,
rnn_type=rnn_type,
n_layers=n_layers,
n_joints=n_joints)
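        # Module 2. Trajectory Decoder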
self.trajectory_decoder = TrajectoryDecoder(d_embed=d_context,
rnn_type=rnn_type,
n_layers=n_layers)
# Module 3. Feature Integrator
self.integrator = Integrator(in_channel=d_feat + d_context,
out_channel=d_context)
# Module 4. Motion Decoder
self.motion_decoder = MotionDecoder(d_embed=d_context,
rnn_type=rnn_type,
n_layers=n_layers)
# Module 5. Trajectory Refiner
self.trajectory_refiner = TrajectoryRefiner(d_embed=d_context,
d_hidden=d_embed,
rnn_type=rnn_type,
n_layers=2)
def compute_global_feet(self, root_world, trans):
        # Recover the camera motion, then lift the camera-frame feet into world coordinates
cam_R, cam_T = compute_camera_motion(self.output, self.pred_pose[:, :, :6], root_world, trans, self.pred_cam)
feet_cam = self.output.feet.reshape(self.b, self.f, -1, 3) + self.output.full_cam.reshape(self.b, self.f, 1, 3)
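        # Invert the camera transform: X_world = R^T (X_cam - T)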
feet_world = (cam_R.mT @ (feet_cam - cam_T.unsqueeze(-2)).mT).mT
return feet_world, cam_R
def forward_smpl(self, **kwargs):
self.output = self.smpl(self.pred_pose,
self.pred_shape,
cam=self.pred_cam,
return_full_pose=not self.training,
**kwargs,
)
        # Debug output: log the predicted joints and vertices and dump them to .npy files
        logger.info(f"Output Joints: {self.output.joints}")
        logger.info(f"Output Vertices: {self.output.vertices}")
        np.save('joints.npy', self.output.joints.detach().cpu().numpy())
        np.save('vertices.npy', self.output.vertices.detach().cpu().numpy())
# Feet location in global coordinate
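        # Integrate the per-frame root pose and velocity into a world-frame trajectory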
root_world, trans = rollout_global_motion(self.pred_root, self.pred_vel)
feet_world, cam_R = self.compute_global_feet(root_world, trans)
# Return output
output = {'feet': feet_world,
'contact': self.pred_contact,
'pose': self.pred_pose,
'betas': self.pred_shape,
'cam': self.pred_cam,
'poses_root_cam': self.output.global_orient,
'poses_root_r6d': self.pred_root,
'vel_root': self.pred_vel,
'pose_root': self.pred_root,
'verts_cam': self.output.vertices}
if self.training:
output.update({
'kp3d': self.output.joints,
'kp3d_nn': self.pred_kp3d,
'full_kp2d': self.output.full_joints2d,
'weak_kp2d': self.output.weak_joints2d,
'R': cam_R,
})
else:
            output.update({
                'trans_cam': self.output.full_cam,
                'poses_body': self.output.body_pose})
return output
def preprocess(self, x, mask):
self.b, self.f = x.shape[:2]
        # Replace masked (missing) keypoints with the learned mask embedding
mask_embedding = mask.unsqueeze(-1) * self.mask_embedding
_mask = mask.unsqueeze(-1).repeat(1, 1, 1, 2).reshape(self.b, self.f, -1)
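        # Extend the mask with zeros for the trailing 3 non-keypoint channels so they are never masked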
_mask = torch.cat((_mask, torch.zeros_like(_mask[..., :3])), dim=-1)
_mask_embedding = mask_embedding.reshape(self.b, self.f, -1)
_mask_embedding = torch.cat((_mask_embedding, torch.zeros_like(_mask_embedding[..., :3])), dim=-1)
x[_mask] = 0.0
x = x + _mask_embedding
return x
def rollout(self, output, pred_root, pred_vel, return_y_up):
root_world, trans_world = rollout_global_motion(pred_root, pred_vel)
if return_y_up:
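            # A 180-degree rotation about the x-axis converts the y-up world frame to y-down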
yup2ydown = axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float().to(root_world.device)
root_world = yup2ydown.mT @ root_world
trans_world = (yup2ydown.mT @ trans_world.unsqueeze(-1)).squeeze(-1)
output.update({
'poses_root_world': root_world,
'trans_world': trans_world,
})
return output
def refine_trajectory(self, output, cam_angvel, return_y_up, **kwargs):
# --------- Refine trajectory --------- #
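        # Re-estimate the root velocity from predicted foot contacts (contact probability threshold 0.5)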
update_vel = reset_root_velocity(self.smpl, self.output, self.pred_contact, self.pred_root, self.pred_vel, thr=0.5)
output = self.trajectory_refiner(self.old_motion_context, update_vel, output, cam_angvel, return_y_up=return_y_up)
# --------- #
# Do rollout
output = self.rollout(output, output['poses_root_r6d_refined'], output['vel_root_refined'], return_y_up)
# --------- Compute refined feet --------- #
if self.training:
feet_world, cam_R = self.compute_global_feet(output['poses_root_world'], output['trans_world'])
output.update({'feet_refined': feet_world})
return output
def forward(self, x, inits, img_features=None, mask=None, init_root=None, cam_angvel=None,
cam_intrinsics=None, bbox=None, res=None, return_y_up=False, refine_traj=True, **kwargs):
x = self.preprocess(x, mask)
init_kp, init_smpl = inits
# --------- Inference --------- #
# Stage 1. Encode motion
pred_kp3d, motion_context = self.motion_encoder(x, init_kp)
self.old_motion_context = motion_context.detach().clone()
# Stage 2. Decode global trajectory
pred_root, pred_vel = self.trajectory_decoder(motion_context, init_root, cam_angvel)
# Stage 3. Integrate features
if img_features is not None and self.integrator is not None:
motion_context = self.integrator(motion_context, img_features)
# Stage 4. Decode SMPL motion
pred_pose, pred_shape, pred_cam, pred_contact = self.motion_decoder(motion_context, init_smpl)
# --------- #
# --------- Register predictions --------- #
self.pred_kp3d = pred_kp3d
self.pred_root = pred_root
self.pred_vel = pred_vel
self.pred_pose = pred_pose
self.pred_shape = pred_shape
self.pred_cam = pred_cam
self.pred_contact = pred_contact
# --------- #
# --------- Build SMPL --------- #
output = self.forward_smpl(cam_intrinsics=cam_intrinsics, bbox=bbox, res=res)
# --------- #
# --------- Refine trajectory --------- #
if refine_traj:
output = self.refine_trajectory(output, cam_angvel, return_y_up)
else:
output = self.rollout(output, self.pred_root, self.pred_vel, return_y_up)
# --------- #
        return output
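

# Example usage (a minimal sketch, not part of the original module; the SMPL
# wrapper import path, variable names, and tensor shapes below are assumptions):
#
#   from lib.models.smpl import SMPL   # hypothetical import path
#   smpl = SMPL()
#   net = Network(smpl).eval()
#   # x: 2D keypoints plus 3 extra channels (e.g. bbox info), shape (B, F, n_joints * 2 + 3)
#   # mask: boolean mask of missing keypoints, shape (B, F, n_joints)
#   with torch.no_grad():
#       output = net(x, inits=(init_kp, init_smpl), img_features=img_features,
#                    mask=mask, init_root=init_root, cam_angvel=cam_angvel,
#                    cam_intrinsics=K, bbox=bbox, res=res, return_y_up=True)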