from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import numpy as np
from torch import nn

from configs import constants as _C
from .utils import rollout_global_motion
from lib.utils.transforms import axis_angle_to_matrix


class Regressor(nn.Module):
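    """RNN-based iterative regressor.

    Concatenates the input features with auxiliary inputs (typically the
    previous step's predictions), runs them through a uni-directional RNN,
    and decodes the hidden state with one linear head per output.
    """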
    def __init__(self, in_dim, hid_dim, out_dims, init_dim, layer='LSTM', n_layers=2, n_iters=1):
        super().__init__()
        self.n_outs = len(out_dims)

        self.rnn = getattr(nn, layer.upper())(
            in_dim + init_dim, hid_dim, n_layers,
            bidirectional=False, batch_first=True, dropout=0.3)

        for i, out_dim in enumerate(out_dims):
            setattr(self, 'declayer%d' % i, nn.Linear(hid_dim, out_dim))
            nn.init.xavier_uniform_(getattr(self, 'declayer%d' % i).weight, gain=0.01)
    def forward(self, x, inits, h0):
        xc = torch.cat([x, *inits], dim=-1)
        xc, h0 = self.rnn(xc, h0)

        preds = []
        for j in range(self.n_outs):
            out = getattr(self, 'declayer%d' % j)(xc)
            preds.append(out)

        return preds, xc, h0


class NeuralInitialization(nn.Module):
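    """MLP that maps an initial observation to the RNN's initial states.

    Produces both hidden and cell states for an LSTM (num_inits == 2),
    or a single hidden state for other RNN types (num_inits == 1).
    """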
    def __init__(self, in_dim, hid_dim, layer, n_layers):
        super().__init__()

        out_dim = hid_dim
        self.n_layers = n_layers
        self.num_inits = int(layer.upper() == 'LSTM') + 1
        out_dim *= self.num_inits * n_layers

        self.linear1 = nn.Linear(in_dim, hid_dim)
        self.linear2 = nn.Linear(hid_dim, hid_dim * self.n_layers)
        self.linear3 = nn.Linear(hid_dim * self.n_layers, out_dim)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
    def forward(self, x):
        b = x.shape[0]

        out = self.linear3(self.relu2(self.linear2(self.relu1(self.linear1(x)))))
        # Reshape to (num_inits, n_layers, batch, hid_dim) to match the
        # (h0, c0) layout expected by PyTorch RNNs.
        out = out.view(b, self.num_inits, self.n_layers, -1).permute(1, 2, 0, 3).contiguous()

        if self.num_inits == 2:
            return tuple(out)
        return out[0]


class Integrator(nn.Module):
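    """Residual MLP that fuses auxiliary features into the input signal.

    The residual connection is applied only where the auxiliary features
    are present (all non-zero); sequences with all-zero features pass
    through the MLP alone.
    """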
    def __init__(self, in_channel, out_channel, hid_channel=1024):
        super().__init__()

        self.layer1 = nn.Linear(in_channel, hid_channel)
        self.relu1 = nn.ReLU()
        self.dr1 = nn.Dropout(0.1)

        self.layer2 = nn.Linear(hid_channel, hid_channel)
        self.relu2 = nn.ReLU()
        self.dr2 = nn.Dropout(0.1)

        self.layer3 = nn.Linear(hid_channel, out_channel)
    def forward(self, x, feat):
        res = x
        # Gate the residual: only samples whose auxiliary features are
        # entirely non-zero (i.e. actually provided) get the skip connection.
        mask = (feat != 0).all(dim=-1).all(dim=-1)

        out = torch.cat((x, feat), dim=-1)
        out = self.layer1(out)
        out = self.relu1(out)
        out = self.dr1(out)

        out = self.layer2(out)
        out = self.relu2(out)
        out = self.dr2(out)

        out = self.layer3(out)
        out[mask] = out[mask] + res[mask]

        return out


class MotionEncoder(nn.Module):
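    """Encodes an input motion sequence into per-frame 3D keypoints and a
    motion context feature, using an autoregressive RNN regressor whose
    initial hidden state is predicted from the initial observation.
    """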
    def __init__(self,
                 in_dim,
                 d_embed,
                 pose_dr,
                 rnn_type,
                 n_layers,
                 n_joints):
        super().__init__()

        self.n_joints = n_joints

        self.embed_layer = nn.Linear(in_dim, d_embed)
        self.pos_drop = nn.Dropout(pose_dr)

        # Keypoints initializer
        self.neural_init = NeuralInitialization(n_joints * 3 + in_dim, d_embed, rnn_type, n_layers)

        # 3d keypoints regressor
        self.regressor = Regressor(
            d_embed, d_embed, [n_joints * 3], n_joints * 3, rnn_type, n_layers)
    def forward(self, x, init):
        """ Forward pass of motion encoder.
        """

        self.b, self.f = x.shape[:2]
        x = self.embed_layer(x.reshape(self.b, self.f, -1))
        x = self.pos_drop(x)

        h0 = self.neural_init(init)
        pred_list = [init[..., :self.n_joints * 3]]
        motion_context_list = []

        # Autoregressive rollout: each frame is regressed from the current
        # embedding and the previous frame's 3D keypoint prediction.
        for i in range(self.f):
            (pred_kp3d, ), motion_context, h0 = self.regressor(x[:, [i]], pred_list[-1:], h0)
            motion_context_list.append(motion_context)
            pred_list.append(pred_kp3d)

        pred_kp3d = torch.cat(pred_list[1:], dim=1).view(self.b, self.f, -1, 3)
        motion_context = torch.cat(motion_context_list, dim=1)

        # Merge 3D keypoints with motion context
        motion_context = torch.cat((motion_context, pred_kp3d.reshape(self.b, self.f, -1)), dim=-1)
        return pred_kp3d, motion_context


class TrajectoryDecoder(nn.Module):
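    """Decodes the motion context into global root orientation (6D
    rotation) and root velocity, one frame at a time, conditioned on the
    camera angular velocity.
    """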
    def __init__(self,
                 d_embed,
                 rnn_type,
                 n_layers):
        super().__init__()

        # Trajectory regressor
        self.regressor = Regressor(
            d_embed, d_embed, [3, 6], 12, rnn_type, n_layers)
    def forward(self, x, root, cam_a, h0=None):
        """ Forward pass of trajectory decoder.
        """

        b, f = x.shape[:2]
        pred_root_list, pred_vel_list = [root[:, :1]], []

        for i in range(f):
            # Global coordinate estimation
            (pred_rootv, pred_rootr), _, h0 = self.regressor(
                x[:, [i]], [pred_root_list[-1], cam_a[:, [i]]], h0)
            pred_root_list.append(pred_rootr)
            pred_vel_list.append(pred_rootv)

        pred_root = torch.cat(pred_root_list, dim=1).view(b, f + 1, -1)
        pred_vel = torch.cat(pred_vel_list, dim=1).view(b, f, -1)

        return pred_root, pred_vel


class MotionDecoder(nn.Module):
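    """Decodes the motion context into SMPL parameters: per-frame body
    pose (24 joints in 6D rotation), shape coefficients (10), camera
    translation (3), and contact (4).
    """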
    def __init__(self,
                 d_embed,
                 rnn_type,
                 n_layers):
        super().__init__()

        self.n_pose = 24

        # SMPL pose initialization
        self.neural_init = NeuralInitialization(len(_C.BMODEL.MAIN_JOINTS) * 6, d_embed, rnn_type, n_layers)

        # SMPL parameters regressor
        self.regressor = Regressor(
            d_embed, d_embed, [self.n_pose * 6, 10, 3, 4], self.n_pose * 6, rnn_type, n_layers)
    def forward(self, x, init):
        """ Forward pass of motion decoder.
        """
        b, f = x.shape[:2]

        h0 = self.neural_init(init[:, :, _C.BMODEL.MAIN_JOINTS].reshape(b, 1, -1))

        # Recursive prediction of SMPL parameters
        pred_pose_list = [init.reshape(b, 1, -1)]
        pred_shape_list, pred_cam_list, pred_contact_list = [], [], []

        for i in range(f):
            # Camera coordinate estimation
            (pred_pose, pred_shape, pred_cam, pred_contact), _, h0 = self.regressor(
                x[:, [i]], pred_pose_list[-1:], h0)
            pred_pose_list.append(pred_pose)
            pred_shape_list.append(pred_shape)
            pred_cam_list.append(pred_cam)
            pred_contact_list.append(pred_contact)

        pred_pose = torch.cat(pred_pose_list[1:], dim=1).view(b, f, -1)
        pred_shape = torch.cat(pred_shape_list, dim=1).view(b, f, -1)
        pred_cam = torch.cat(pred_cam_list, dim=1).view(b, f, -1)
        pred_contact = torch.cat(pred_contact_list, dim=1).view(b, f, -1)

        return pred_pose, pred_shape, pred_cam, pred_contact


class TrajectoryRefiner(nn.Module):
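    """Refines the decoded root orientation and velocity by regressing
    residual corrections from the motion context and contact-gated foot
    velocities.
    """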
    def __init__(self,
                 d_embed,
                 d_hidden,
                 rnn_type,
                 n_layers):
        super().__init__()

        d_input = d_embed + 12
        self.refiner = Regressor(
            d_input, d_hidden, [6, 3], 9, rnn_type, n_layers)
    def forward(self, context, pred_vel, output, cam_angvel, return_y_up):
        b, f = context.shape[:2]

        # Register values
        pred_root = output['poses_root_r6d'].clone().detach()
        feet = output['feet'].clone().detach()
        contact = output['contact'].clone().detach()

        # Finite-difference foot velocity, scaled by 30 (the sequence frame rate)
        feet_vel = torch.cat((torch.zeros_like(feet[:, :1]), feet[:, 1:] - feet[:, :-1]), dim=1) * 30
        feet = (feet_vel * contact.unsqueeze(-1)).reshape(b, f, -1)  # Contact-gated velocity input

        inpt_feat = torch.cat([context, feet], dim=-1)
        (delta_root, delta_vel), _, _ = self.refiner(inpt_feat, [pred_root[:, 1:], pred_vel], h0=None)

        pred_root[:, 1:] = pred_root[:, 1:] + delta_root
        pred_vel = pred_vel + delta_vel

        # root_world, trans_world = rollout_global_motion(pred_root, pred_vel)

        # if return_y_up:
        #     yup2ydown = axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float().to(root_world.device)
        #     root_world = yup2ydown.mT @ root_world
        #     trans_world = (yup2ydown.mT @ trans_world.unsqueeze(-1)).squeeze(-1)

        output.update({
            'poses_root_r6d_refined': pred_root,
            'vel_root_refined': pred_vel,
            # 'poses_root_world': root_world,
            # 'trans_world': trans_world,
        })

        return output
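

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the pipeline):
    # run the generic Regressor for a single frame with dummy tensors.
    # All dimensions below are arbitrary assumptions, not config values.
    b, f, d = 2, 1, 64
    model = Regressor(d, d, out_dims=[9], init_dim=9, layer='LSTM', n_layers=2)
    x = torch.randn(b, f, d)       # per-frame input feature
    init = torch.randn(b, f, 9)    # stand-in for the previous-step prediction
    preds, ctx, h0 = model(x, [init], h0=None)  # h0=None -> zero-initialized states
    print(preds[0].shape, ctx.shape)  # torch.Size([2, 1, 9]) torch.Size([2, 1, 64])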