Spaces:
Sleeping
Sleeping
from __future__ import absolute_import | |
from __future__ import print_function | |
from __future__ import division | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
from configs import constants as _C | |
from lib.utils import transforms | |
from lib.utils.kp_utils import root_centering | |
class WHAMLoss(nn.Module):
    """Combine the individual training loss terms into one weighted total.

    Per-term weights are read from ``cfg.LOSS``; every term is additionally
    multiplied by the global ``cfg.LOSS.LOSS_WEIGHT``. The camera term is
    disabled for the first ``cfg.LOSS.CAMERA_LOSS_SKIP_EPOCH`` epochs, which
    are counted by calling :meth:`step` once per epoch.
    """

    def __init__(
        self,
        cfg=None,
        device=None,
    ):
        """Read loss weights from ``cfg`` and build the per-joint weight tensors.

        Args:
            cfg: configuration object exposing ``cfg.LOSS.*``. Despite the
                ``None`` default it is required: the weights are read from it
                unconditionally.
            device: device on which the weight tensors are created.
        """
        super(WHAMLoss, self).__init__()

        self.cfg = cfg
        self.n_joints = _C.KEYPOINTS.NUM_JOINTS
        self.criterion = nn.MSELoss()
        self.criterion_noreduce = nn.MSELoss(reduction='none')

        loss_cfg = cfg.LOSS
        self.pose_loss_weight = loss_cfg.POSE_LOSS_WEIGHT
        self.shape_loss_weight = loss_cfg.SHAPE_LOSS_WEIGHT
        self.keypoint_2d_loss_weight = loss_cfg.JOINT2D_LOSS_WEIGHT
        self.keypoint_3d_loss_weight = loss_cfg.JOINT3D_LOSS_WEIGHT
        self.cascaded_loss_weight = loss_cfg.CASCADED_LOSS_WEIGHT
        self.vertices_loss_weight = loss_cfg.VERTS3D_LOSS_WEIGHT
        self.contact_loss_weight = loss_cfg.CONTACT_LOSS_WEIGHT
        self.root_vel_loss_weight = loss_cfg.ROOT_VEL_LOSS_WEIGHT
        self.root_pose_loss_weight = loss_cfg.ROOT_POSE_LOSS_WEIGHT
        self.sliding_loss_weight = loss_cfg.SLIDING_LOSS_WEIGHT
        self.camera_loss_weight = loss_cfg.CAMERA_LOSS_WEIGHT
        self.loss_weight = loss_cfg.LOSS_WEIGHT

        # 31 per-keypoint weights. The group labels are carried over as found
        # and have not been checked against the keypoint ordering.
        kp_weights = [
            0.5, 0.5, 0.5, 0.5, 0.5,        # Face
            1.5, 1.5, 4, 4, 4, 4,           # Arms
            1.5, 1.5, 4, 4, 4, 4,           # Legs
            4, 4, 1.5, 1.5, 4, 4,           # Legs
            4, 4, 1.5, 1.5, 4, 4,           # Arms
            0.5, 0.5,                       # Head
        ]

        # 24 per-joint weights for the pose (theta) term.
        theta_weights = [
            0.1, 1.0, 1.0, 1.0, 1.0,        # pelvis, lhip, rhip, spine1, lknee
            1.0, 1.0, 1.0, 1.0, 1.0,        # rknee, spine2, lankle, rankle, spine3
            0.1, 0.1,                       # Foot
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0,   # neck, lisldr, risldr, head, losldr, rosldr
            1.0, 1.0, 1.0, 1.0,             # lelbow, relbow, lwrist, rwrist
            0.1, 0.1,                       # Hand
        ]

        # Shape (1, 1, 24), rescaled so that the weights average to one.
        theta = torch.tensor([[theta_weights]]).float().to(device)
        self.theta_weights = theta / theta.mean()
        # Shape (1, 31).
        self.kp_weights = torch.tensor([kp_weights]).float().to(device)

        # Start at -1 so the first step() lands on epoch 0.
        self.epoch = -1
        self.step()

    def step(self):
        """Advance the epoch counter and refresh ``skip_camera_loss``.

        ``skip_camera_loss`` is True while the epoch index is below
        ``cfg.LOSS.CAMERA_LOSS_SKIP_EPOCH``.
        """
        self.epoch = self.epoch + 1
        self.skip_camera_loss = self.epoch < self.cfg.LOSS.CAMERA_LOSS_SKIP_EPOCH

    def forward(self, pred, gt):
        """Evaluate every loss term and return their weighted sum.

        Args:
            pred: dict of network outputs (keys read here: ``betas``, ``pose``,
                ``kp3d_nn``, ``kp3d``, ``full_kp2d``, ``weak_kp2d``,
                ``contact``, ``vel_root``, ``poses_root_r6d``,
                ``vel_root_refined``, ``poses_root_r6d_refined``, ``R``,
                ``verts_cam``, ``feet``, ``feet_refined``).
            gt: dict of ground-truth tensors (keys read here: ``kp3d``,
                ``betas``, ``pose``, ``full_kp2d``, ``weak_kp2d``, ``contact``,
                ``vel_root``, ``pose_root``, ``cam_angvel``, ``R``, ``bbox``,
                ``verts``, ``has_verts``, ``has_smpl``, ``has_traj``).

        Returns:
            ``(total, loss_dict)`` where ``loss_dict`` maps a short term name
            to its fully weighted value and ``total`` is the sum of those
            values.
        """
        n_batch, n_frames = gt['kp3d'].shape[:2]
        n_j = self.n_joints
        elementwise = self.criterion_noreduce

        # <======= Predictions and Groundtruths
        pred_betas = pred['betas']
        pred_pose = pred['pose'].reshape(n_batch, n_frames, -1, 6)
        pred_kp3d_nn = pred['kp3d_nn']
        pred_kp3d_smpl = root_centering(pred['kp3d'].reshape(n_batch, n_frames, -1, 3))
        pred_full_kp2d = pred['full_kp2d']
        pred_weak_kp2d = pred['weak_kp2d']
        pred_contact = pred['contact']
        pred_vel_root = pred['vel_root']
        # The first frame of the root orientation sequences is dropped.
        pred_root_r6d = pred['poses_root_r6d'][:, 1:]
        pred_vel_root_ref = pred['vel_root_refined']
        pred_root_r6d_ref = pred['poses_root_r6d_refined'][:, 1:]
        pred_cam_r6d = transforms.matrix_to_rotation_6d(pred['R'])

        gt_betas = gt['betas']
        gt_pose = gt['pose']
        gt_kp3d = root_centering(gt['kp3d'])
        gt_full_kp2d = gt['full_kp2d']
        gt_weak_kp2d = gt['weak_kp2d']
        gt_contact = gt['contact']
        gt_vel_root = gt['vel_root']
        gt_root_r6d = gt['pose_root'][:, 1:]
        gt_cam_angvel = gt['cam_angvel']
        gt_cam_r6d = transforms.matrix_to_rotation_6d(gt['R'][:, 1:])
        bbox = gt['bbox']
        # =======>

        # 2D reprojection terms (full-perspective and weak-perspective).
        kp2d_full = full_projected_keypoint_loss(
            pred_full_kp2d, gt_full_kp2d, bbox, self.kp_weights,
            criterion=elementwise,
        )
        kp2d_weak = weak_projected_keypoint_loss(
            pred_weak_kp2d, gt_weak_kp2d, self.kp_weights,
            criterion=elementwise,
        )

        # 3D keypoint terms: direct network output on the first n_joints
        # keypoints, and the SMPL-derived keypoints on the full set.
        kp3d_nn = keypoint_3d_loss(
            pred_kp3d_nn,
            gt_kp3d[:, :, :n_j],
            self.kp_weights[:, :n_j],
            criterion=elementwise,
        )
        kp3d_smpl = keypoint_3d_loss(
            pred_kp3d_smpl, gt_kp3d, self.kp_weights,
            criterion=elementwise,
        )

        # Cascaded term: the target is the SMPL-derived prediction (not
        # detached here) with the ground-truth confidence channel appended.
        cascaded_target = torch.cat(
            (pred_kp3d_smpl[:, :, :n_j], gt_kp3d[:, :, :n_j, -1:]), dim=-1
        )
        cascaded = keypoint_3d_loss(
            pred_kp3d_nn,
            cascaded_target,
            self.kp_weights[:, :n_j] * 0.5,
            criterion=elementwise,
        )

        verts = vertices_loss(
            pred['verts_cam'], gt['verts'], gt['has_verts'],
            criterion=elementwise,
        )

        # SMPL parameter terms.
        pose_term, betas_term = smpl_losses(
            pred_pose, pred_betas, gt_pose, gt_betas,
            self.theta_weights, gt['has_smpl'],
            criterion=elementwise,
        )

        # Foot contact term.
        contact = contact_loss(pred_contact, gt_contact, elementwise)

        # Root velocity / orientation terms, before and after refinement.
        vel_root, pose_root = root_loss(
            pred_vel_root, pred_root_r6d,
            gt_vel_root, gt_root_r6d, gt_contact, elementwise,
        )
        vel_root_ref, pose_root_ref = root_loss(
            pred_vel_root_ref, pred_root_r6d_ref,
            gt_vel_root, gt_root_r6d, gt_contact, elementwise,
        )

        # Camera term (returns zero while skip_camera_loss is set).
        camera = camera_loss(
            pred_cam_r6d, gt_cam_r6d, gt_cam_angvel[:, 1:], gt['has_traj'],
            elementwise, self.skip_camera_loss,
        )

        # Foot sliding terms, before and after refinement.
        sliding = sliding_loss(pred['feet'], gt_contact)
        sliding_ref = sliding_loss(pred['feet_refined'], gt_contact)

        root = (vel_root * self.root_vel_loss_weight
                + pose_root * self.root_pose_loss_weight)
        root_ref = (vel_root_ref * self.root_vel_loss_weight
                    + pose_root_ref * self.root_pose_loss_weight)

        # Per-term weight first, then the global weight.
        global_w = self.loss_weight
        loss_dict = {
            'pose': pose_term * self.pose_loss_weight * global_w,
            'betas': betas_term * self.shape_loss_weight * global_w,
            '2d': (kp2d_full + kp2d_weak) * self.keypoint_2d_loss_weight * global_w,
            '3d': kp3d_smpl * self.keypoint_3d_loss_weight * global_w,
            '3d_nn': kp3d_nn * self.keypoint_3d_loss_weight * global_w,
            'casc': cascaded * self.cascaded_loss_weight * global_w,
            'v3d': verts * self.vertices_loss_weight * global_w,
            'contact': contact * self.contact_loss_weight * global_w,
            'root': root * global_w,
            'root_ref': root_ref * global_w,
            'sliding': sliding * self.sliding_loss_weight * global_w,
            'camera': camera * self.camera_loss_weight * global_w,
            'sliding_ref': sliding_ref * self.sliding_loss_weight * global_w,
        }

        total = sum(loss_dict.values())
        return total, loss_dict
def root_loss(
    pred_vel_root,
    pred_pose_root,
    gt_vel_root,
    gt_pose_root,
    stationary,
    criterion
):
    """Compute the root velocity and root orientation losses.

    The velocity term is a multi-scale accumulated-error loss: for every
    window size in (1, 3, 9, 27) the frame axis (axis 1) is cut into
    consecutive, non-overlapping windows, the velocity error is summed inside
    each window, and the norms of those sums are accumulated.

    Args:
        pred_vel_root: predicted root velocity, indexed (batch, frame, ...).
        pred_pose_root: predicted root orientation.
        gt_vel_root: ground-truth root velocity, same layout as the prediction.
        gt_pose_root: ground-truth root orientation.
        stationary: contact labels; ``-1`` marks an unlabeled entry.
        criterion: element-wise (``reduction='none'``) loss used for the
            orientation term.

    Returns:
        ``(loss_v, loss_r)``: scalar velocity and orientation losses. A term
        is a zero scalar when no sample in the batch is valid for it.
    """
    # A sample is valid only when its ground truth contains no exact zeros
    # over the last two axes (all-zero entries appear to act as "missing").
    mask_r = (gt_pose_root != 0.0).all(dim=-1).all(dim=-1)
    mask_v = (gt_vel_root != 0.0).all(dim=-1).all(dim=-1)
    # The velocity term additionally needs at least one labeled contact entry.
    mask_s = (stationary != -1).any(dim=1).any(dim=1)
    mask_v = mask_v * mask_s

    if mask_r.any():
        loss_r = criterion(pred_pose_root, gt_pose_root)[mask_r].mean()
    else:
        loss_r = torch.zeros((), dtype=torch.float32, device=gt_pose_root.device)

    if mask_v.any():
        # The windows slide along axis 1, so the window count must be derived
        # from that axis (not from the batch axis).
        n_frames = gt_vel_root.shape[1]
        loss_v = 0
        for ws in (1, 3, 9, 27):
            for m in range(n_frames // ws):
                # Consecutive chunk [m*ws, (m+1)*ws) of length ws.
                start, stop = m * ws, (m + 1) * ws
                cumulative_v = torch.sum(
                    pred_vel_root[:, start:stop] - gt_vel_root[:, start:stop],
                    dim=1,
                )
                loss_v = loss_v + torch.norm(cumulative_v, dim=-1)
        loss_v = loss_v[mask_v].mean()
    else:
        loss_v = torch.zeros((), dtype=torch.float32, device=gt_vel_root.device)

    return loss_v, loss_r
def contact_loss(
    pred_stationary,
    gt_stationary,
    criterion,
):
    """Mean element-wise loss between predicted and labeled contact.

    Entries whose ground-truth value is ``-1`` are excluded. Returns a zero
    scalar when every entry is excluded.

    Args:
        pred_stationary: predicted contact values.
        gt_stationary: ground-truth contact values, ``-1`` where unlabeled.
        criterion: element-wise (``reduction='none'``) loss.
    """
    labeled = gt_stationary != -1
    if not labeled.any():
        return torch.zeros((), dtype=torch.float32, device=gt_stationary.device)

    per_entry = criterion(pred_stationary, gt_stationary)
    return per_entry[labeled].mean()
def full_projected_keypoint_loss(
    pred_keypoints_2d,
    gt_keypoints_2d,
    bbox,
    weight,
    criterion,
):
    """Confidence-weighted 2D keypoint distance, normalized by bbox scale.

    Args:
        pred_keypoints_2d: predicted 2D keypoints (last axis: x, y).
        gt_keypoints_2d: ground-truth keypoints; the last channel is the
            confidence and the first two are the coordinates.
        bbox: bounding boxes; channels from index 2 onward, times 200, are
            used as the normalizing scale.
        weight: per-keypoint weights, broadcast against the distances.
        criterion: unused; kept for a uniform call signature.

    Returns:
        Scalar loss, or a zero scalar when no keypoint has positive confidence.
    """
    scale = bbox[..., 2:] * 200.
    conf = gt_keypoints_2d[..., -1]

    if not (conf > 0).any():
        return torch.zeros((), dtype=torch.float32, device=gt_keypoints_2d.device)

    distance = torch.norm(pred_keypoints_2d - gt_keypoints_2d[..., :2], dim=-1)
    weighted = weight * (conf * distance) / scale
    # Average over axis 1 first, then over everything else; the result is
    # further scaled by the mean confidence.
    return torch.mean(weighted, dim=1).mean() * conf.mean()
def weak_projected_keypoint_loss(
    pred_keypoints_2d,
    gt_keypoints_2d,
    weight,
    criterion,
):
    """Confidence-weighted 2D keypoint distance with a fixed gain of 5.

    Args:
        pred_keypoints_2d: predicted 2D keypoints (last axis: x, y).
        gt_keypoints_2d: ground-truth keypoints; the last channel is the
            confidence and the first two are the coordinates.
        weight: per-keypoint weights, broadcast against the distances.
        criterion: unused; kept for a uniform call signature.

    Returns:
        Scalar loss, or a zero scalar when no keypoint has positive confidence.
    """
    conf = gt_keypoints_2d[..., -1]

    if not (conf > 0).any():
        return torch.zeros((), dtype=torch.float32, device=gt_keypoints_2d.device)

    distance = torch.norm(pred_keypoints_2d - gt_keypoints_2d[..., :2], dim=-1)
    weighted = weight * (conf * distance)
    # Average over axis 1 first, then over everything else; scaled by the
    # mean confidence and the constant factor 5.
    return torch.mean(weighted, dim=1).mean() * conf.mean() * 5
def keypoint_3d_loss(
    pred_keypoints_3d,
    gt_keypoints_3d,
    weight,
    criterion,
):
    """Confidence-weighted 3D keypoint distance.

    Args:
        pred_keypoints_3d: predicted 3D keypoints (last axis: x, y, z).
        gt_keypoints_3d: ground-truth keypoints; the last channel is the
            confidence and the first three are the coordinates.
        weight: per-keypoint weights, broadcast against the distances.
        criterion: unused; kept for a uniform call signature.

    Returns:
        Scalar loss, or a zero scalar when no keypoint has positive confidence.
    """
    conf = gt_keypoints_3d[..., -1]

    if not (conf > 0).any():
        return torch.zeros((), dtype=torch.float32, device=gt_keypoints_3d.device)

    # NOTE(review): this branch tests the second-to-last axis of `weight` and
    # slices the *last* axis of the keypoints, writing into the caller's
    # tensors in place. With a weight shaped (1, J) it is never entered.
    # Kept exactly as found; confirm the intent before relying on it.
    if weight.shape[-2] > 17:
        pred_tail = pred_keypoints_3d[..., -14:]
        pred_keypoints_3d[..., -14:] = pred_tail - pred_tail.mean(dim=-2, keepdims=True)
        gt_tail = gt_keypoints_3d[..., -14:]
        gt_keypoints_3d[..., -14:] = gt_tail - gt_tail.mean(dim=-2, keepdims=True)

    distance = torch.norm(pred_keypoints_3d - gt_keypoints_3d[..., :3], dim=-1)
    weighted = weight * (conf * distance)
    # Average over axis 1 first, then over everything else; the result is
    # further scaled by the mean confidence.
    return torch.mean(weighted, dim=1).mean() * conf.mean()
def vertices_loss(
    pred_verts,
    gt_verts,
    mask,
    criterion,
):
    """L1 vertex distance after removing each mesh's mean vertex position.

    Args:
        pred_verts: predicted vertices; viewed with the shape of ``gt_verts``.
        gt_verts: ground-truth vertices (second-to-last axis: vertices,
            last axis: coordinates).
        mask: selects the entries that have ground-truth vertices.
        criterion: unused; kept for a uniform call signature.

    Returns:
        Scalar loss scaled by the fraction of selected entries, or a zero
        scalar when nothing is selected.
    """
    if not mask.sum() > 0:
        return torch.zeros((), dtype=torch.float32, device=gt_verts.device)

    # Align: subtract the per-mesh mean over the vertex axis on both sides.
    pred_aligned = pred_verts.view_as(gt_verts)
    pred_aligned = pred_aligned - pred_aligned.mean(-2, True)
    gt_aligned = gt_verts - gt_verts.mean(-2, True)

    l1_per_vertex = torch.norm(pred_aligned - gt_aligned, p=1, dim=-1)
    selected = l1_per_vertex[mask]
    return torch.mean(selected, dim=1).mean() * mask.float().mean()
def smpl_losses(
    pred_pose,
    pred_betas,
    gt_pose,
    gt_betas,
    weight,
    mask,
    criterion,
):
    """Squared-error losses on the SMPL pose and shape parameters.

    Args:
        pred_pose: predicted pose parameters.
        pred_betas: predicted shape parameters.
        gt_pose: ground-truth pose parameters.
        gt_betas: ground-truth shape parameters.
        weight: per-joint weights, broadcast against the per-joint pose error.
        mask: selects the entries that have SMPL ground truth.
        criterion: unused; kept for a uniform call signature.

    Returns:
        ``(loss_regr_pose, loss_regr_betas)``, each scaled by the fraction of
        selected entries; both are zero scalars when nothing is selected.
    """
    if not mask.any():
        zero_pose = torch.zeros((), dtype=torch.float32, device=gt_pose.device)
        zero_betas = torch.zeros((), dtype=torch.float32, device=gt_pose.device)
        return zero_pose, zero_betas

    valid_fraction = mask.float().mean()

    # Squared error averaged over the last axis, then weighted per joint.
    per_joint_err = torch.square(pred_pose - gt_pose)[mask].mean(-1)
    loss_regr_pose = torch.mean(weight * per_joint_err) * valid_fraction

    betas_err = F.mse_loss(pred_betas, gt_betas, reduction='none')
    loss_regr_betas = betas_err[mask].mean() * valid_fraction

    return loss_regr_pose, loss_regr_betas
def camera_loss(
    pred_cam_r,
    gt_cam_r,
    cam_angvel,
    mask,
    criterion,
    skip
):
    """Camera rotation loss plus an angular-velocity consistency term.

    Args:
        pred_cam_r: predicted camera rotation in 6D representation.
        gt_cam_r: ground-truth camera rotation in 6D representation.
        cam_angvel: camera angular velocity to be matched by the value
            reconstructed from consecutive predicted rotations.
        mask: selects the entries that contribute to the loss.
        criterion: element-wise (``reduction='none'``) loss.
        skip: when true the loss is disabled and a zero scalar is returned.

    Returns:
        Scalar loss; zero when ``skip`` is set or ``mask`` selects nothing.
    """
    if not mask.any() or skip:
        return torch.zeros((), dtype=torch.float32, device=gt_cam_r.device)

    # Camera pose loss in 6D representation.
    rotation_term = criterion(pred_cam_r, gt_cam_r)[mask].mean()

    # Rebuild the angular velocity from consecutive predicted rotations:
    # relative rotation between neighbouring frames, expressed in 6D, minus
    # the 6D form of the identity, times 30 (presumably the frame rate --
    # TODO confirm).
    pred_R = transforms.rotation_6d_to_matrix(pred_cam_r)
    relative_R = pred_R[:, :-1] @ pred_R[:, 1:].transpose(-1, -2)
    identity_6d = torch.tensor([[[1, 0, 0, 0, 1, 0]]]).to(cam_angvel)
    angvel_from_R = (transforms.matrix_to_rotation_6d(relative_R) - identity_6d) * 30
    angvel_term = criterion(cam_angvel, angvel_from_R)[mask].mean()

    return rotation_term + angvel_term
def sliding_loss(
    foot_position,
    contact_prob,
):
    """Penalize foot motion on frames where the foot is taken to be in contact.

    foot_position: 3D foot (heel and toe) position, torch.Tensor (B, F, 4, 3)
    contact_prob: contact probability of foot (heel and toe), torch.Tensor (B, F, 4)

    A point counts as in contact when its probability exceeds 0.5; the
    resulting mask is detached, so no gradient reaches ``contact_prob``.
    Returns the mean masked frame-to-frame displacement norm.
    """
    in_contact = (contact_prob > 0.5).detach().float()

    # Displacement between consecutive frames, paired with the contact state
    # of the later frame.
    displacement = foot_position[:, 1:] - foot_position[:, :-1]
    speed = torch.norm(displacement, dim=-1)

    return (speed * in_contact[:, 1:]).mean()