# Some functions are borrowed from
# https://github.com/akanazawa/human_dynamics/blob/master/src/evaluation/eval_util.py
# Adhere to their licence to use these functions.

from pathlib import Path

import torch
import numpy as np
from matplotlib import pyplot as plt


def compute_accel(joints):
    """
    Computes acceleration of 3D joints via second-order finite differences.
    Args:
        joints (Nx25x3).
    Returns:
        Accelerations (N-2).
    """
    velocities = joints[1:] - joints[:-1]
    acceleration = velocities[1:] - velocities[:-1]
    acceleration_normed = np.linalg.norm(acceleration, axis=2)
    return np.mean(acceleration_normed, axis=1)


def compute_error_accel(joints_gt, joints_pred, vis=None):
    r"""
    Computes acceleration error:
        \frac{1}{n-2} \sum_{i=1}^{n-2} \| x_{i-1} - 2 x_i + x_{i+1} \|_2
    Note that for each frame that is not visible, three entries in the
    acceleration error should be zeroed out.
    Args:
        joints_gt (Nx14x3).
        joints_pred (Nx14x3).
        vis (N).
    Returns:
        error_accel (N-2).
    """
    # (N-2)x14x3
    accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:]
    accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:]

    normed = np.linalg.norm(accel_pred - accel_gt, axis=2)

    if vis is None:
        new_vis = np.ones(len(normed), dtype=bool)
    else:
        # The acceleration at frame i is valid only if frames i, i+1, i+2
        # are all visible.
        invis = np.logical_not(vis)
        invis1 = np.roll(invis, -1)
        invis2 = np.roll(invis, -2)
        new_invis = np.logical_or(invis, np.logical_or(invis1, invis2))[:-2]
        new_vis = np.logical_not(new_invis)

    return np.mean(normed[new_vis], axis=1)


def compute_error_verts(pred_verts, target_verts=None, target_theta=None):
    """
    Computes per-vertex error over the 6890 SMPL surface vertices.
    Args:
        pred_verts (Nx6890x3).
        target_verts (Nx6890x3, optional). Regenerated from target_theta if None.
        target_theta (Nx85, optional). SMPL parameters [cam(3) | pose(72) | betas(10)].
    Returns:
        error_verts (N).
    """
    if target_verts is None:
        from lib.models.smpl import SMPL_MODEL_DIR
        from lib.models.smpl import SMPL
        device = 'cpu'
        smpl = SMPL(
            SMPL_MODEL_DIR,
            batch_size=1,  # target_theta.shape[0],
        ).to(device)

        betas = torch.from_numpy(target_theta[:, 75:]).to(device)
        pose = torch.from_numpy(target_theta[:, 3:75]).to(device)

        # Run SMPL in chunks to bound memory usage.
        target_verts = []
        b_ = torch.split(betas, 5000)
        p_ = torch.split(pose, 5000)

        for b, p in zip(b_, p_):
            output = smpl(betas=b, body_pose=p[:, 3:], global_orient=p[:, :3], pose2rot=True)
            target_verts.append(output.vertices.detach().cpu().numpy())

        target_verts = np.concatenate(target_verts, axis=0)

    assert len(pred_verts) == len(target_verts)
    error_per_vert = np.sqrt(np.sum((target_verts - pred_verts) ** 2, axis=2))
    return np.mean(error_per_vert, axis=1)
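
# Usage sketch for the acceleration metrics above (illustrative only; the
# sequence length and joint count below are assumptions, not fixed by the API):
#
#   joints_gt = np.random.randn(100, 14, 3)                  # hypothetical GT sequence
#   joints_pred = joints_gt + 1e-3 * np.random.randn(100, 14, 3)
#   accel = compute_accel(joints_pred)                       # shape (98,) == (N-2,)
#   accel_err = compute_error_accel(joints_gt, joints_pred)  # shape (98,)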

def compute_similarity_transform(S1, S2):
    '''
    Computes a similarity transform (sR, t) that takes
    a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
    where R is a 3x3 rotation matrix, t a 3x1 translation, and s a scale,
    i.e. solves the orthogonal Procrustes problem.
    '''
    transposed = False
    if S1.shape[0] != 3 and S1.shape[0] != 2:
        S1 = S1.T
        S2 = S2.T
        transposed = True
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean.
    mu1 = S1.mean(axis=1, keepdims=True)
    mu2 = S2.mean(axis=1, keepdims=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = np.sum(X1 ** 2)

    # 3. The outer product of X1 and X2.
    K = X1.dot(X2.T)

    # 4. Solution that maximizes trace(R'K) is R=U*V', where U, V are
    # singular vectors of K.
    U, s, Vh = np.linalg.svd(K)
    V = Vh.T
    # Construct Z that fixes the orientation of R to get det(R)=1.
    Z = np.eye(U.shape[0])
    Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
    # Construct R.
    R = V.dot(Z.dot(U.T))

    # 5. Recover scale.
    scale = np.trace(R.dot(K)) / var1

    # 6. Recover translation.
    t = mu2 - scale * (R.dot(mu1))

    # 7. Apply the transform to S1.
    S1_hat = scale * R.dot(S1) + t

    if transposed:
        S1_hat = S1_hat.T

    return S1_hat


def compute_similarity_transform_torch(S1, S2):
    '''
    Torch version of compute_similarity_transform: computes a similarity
    transform (sR, t) that takes a set of 3D points S1 (3 x N) closest to a
    set of 3D points S2, i.e. solves the orthogonal Procrustes problem.
    '''
    transposed = False
    if S1.shape[0] != 3 and S1.shape[0] != 2:
        S1 = S1.T
        S2 = S2.T
        transposed = True
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean.
    mu1 = S1.mean(dim=1, keepdim=True)
    mu2 = S2.mean(dim=1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = torch.sum(X1 ** 2)

    # 3. The outer product of X1 and X2.
    K = X1.mm(X2.T)

    # 4. Solution that maximizes trace(R'K) is R=U*V', where U, V are
    # singular vectors of K.
    U, s, V = torch.svd(K)
    # Construct Z that fixes the orientation of R to get det(R)=1.
    Z = torch.eye(U.shape[0], device=S1.device)
    Z[-1, -1] *= torch.sign(torch.det(U @ V.T))
    # Construct R.
    R = V.mm(Z.mm(U.T))

    # 5. Recover scale.
    scale = torch.trace(R.mm(K)) / var1

    # 6. Recover translation.
    t = mu2 - scale * (R.mm(mu1))

    # 7. Apply the transform to S1.
    S1_hat = scale * R.mm(S1) + t

    if transposed:
        S1_hat = S1_hat.T

    return S1_hat


def batch_compute_similarity_transform_torch(S1, S2):
    '''
    Batched torch version: computes, per batch element, a similarity
    transform (sR, t) that takes a set of 3D points S1 (B x 3 x N) closest
    to a set of 3D points S2, i.e. solves the orthogonal Procrustes problem.
    '''
    transposed = False
    if S1.shape[1] != 3 and S1.shape[1] != 2:
        # Note: the check is on dim 1 (not dim 0, which is the batch size),
        # so inputs of shape (B, N, 3) are permuted to (B, 3, N).
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
        transposed = True
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean.
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = torch.sum(X1 ** 2, dim=1).sum(dim=1)

    # 3. The outer product of X1 and X2.
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. Solution that maximizes trace(R'K) is R=U*V', where U, V are
    # singular vectors of K.
    U, s, V = torch.svd(K)

    # Construct Z that fixes the orientation of R to get det(R)=1.
    Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))

    # Construct R.
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Recover scale.
    scale = torch.cat([torch.trace(x).unsqueeze(0) for x in R.bmm(K)]) / var1

    # 6. Recover translation.
    t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1)))

    # 7. Apply the transform to S1.
    S1_hat = scale.unsqueeze(-1).unsqueeze(-1) * R.bmm(S1) + t

    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)

    return S1_hat


def align_by_pelvis(joints):
    """
    Assumes joints is 14 x 3 in LSP order,
    so the hips are joints [3, 2] (left, right).
    Takes the midpoint of the hips, then subtracts it.
    """
    left_id = 3
    right_id = 2

    pelvis = (joints[left_id, :] + joints[right_id, :]) / 2.0
    return joints - np.expand_dims(pelvis, axis=0)
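
# Sanity-check sketch for the Procrustes solvers above (illustrative only):
# aligning S1 to a known similarity transform of itself should reproduce the
# target, since the solver enforces det(R) = +1, i.e. a proper rotation.
#
#   S1 = np.random.randn(14, 3)                       # hypothetical joint set
#   theta = 0.3                                       # rotation about the z-axis
#   R = np.array([[np.cos(theta), -np.sin(theta), 0.],
#                 [np.sin(theta),  np.cos(theta), 0.],
#                 [0., 0., 1.]])
#   S2 = 0.7 * S1.dot(R.T) + np.array([0.1, -0.2, 0.3])
#   S1_hat = compute_similarity_transform(S1, S2)     # S1_hat ≈ S2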

def compute_errors(gt3ds, preds):
    """
    Gets MPJPE after pelvis alignment and MPJPE after Procrustes alignment.
    Evaluates on the 14 common joints.
    Inputs:
      - gt3ds: N x 14 x 3
      - preds: N x 14 x 3
    """
    errors, errors_pa = [], []
    for gt3d, pred in zip(gt3ds, preds):
        gt3d = gt3d.reshape(-1, 3)
        # Root align.
        gt3d = align_by_pelvis(gt3d)
        pred3d = align_by_pelvis(pred)

        joint_error = np.sqrt(np.sum((gt3d - pred3d) ** 2, axis=1))
        errors.append(np.mean(joint_error))

        # Get PA error.
        pred3d_sym = compute_similarity_transform(pred3d, gt3d)
        pa_error = np.sqrt(np.sum((gt3d - pred3d_sym) ** 2, axis=1))
        errors_pa.append(np.mean(pa_error))

    return errors, errors_pa


def batch_align_by_pelvis(data_list, pelvis_idxs):
    """
    Assumes data is given as [pred_j3d, target_j3d, pred_verts, target_verts].
    Each data is in shape of (frames, num_points, 3).
    Pelvis is notated as one / two joint indices.
    Aligns all data to the corresponding pelvis location.
    """
    pred_j3d, target_j3d, pred_verts, target_verts = data_list

    pred_pelvis = pred_j3d[:, pelvis_idxs].mean(dim=1, keepdim=True).clone()
    target_pelvis = target_j3d[:, pelvis_idxs].mean(dim=1, keepdim=True).clone()

    # Align to the pelvis.
    pred_j3d = pred_j3d - pred_pelvis
    target_j3d = target_j3d - target_pelvis
    pred_verts = pred_verts - pred_pelvis
    target_verts = target_verts - target_pelvis

    return (pred_j3d, target_j3d, pred_verts, target_verts)


def compute_jpe(S1, S2):
    # Per-frame mean Euclidean joint position error.
    return torch.sqrt(((S1 - S2) ** 2).sum(dim=-1)).mean(dim=-1).numpy()


# The functions below are borrowed from the SLAHMR official implementation.
# Reference: https://github.com/vye16/slahmr/blob/main/slahmr/eval/tools.py
def global_align_joints(gt_joints, pred_joints):
    """
    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    """
    s_glob, R_glob, t_glob = align_pcl(
        gt_joints.reshape(-1, 3), pred_joints.reshape(-1, 3)
    )
    pred_glob = (
        s_glob * torch.einsum("ij,tnj->tni", R_glob, pred_joints) + t_glob[None, None]
    )
    return pred_glob


def first_align_joints(gt_joints, pred_joints):
    """
    Align using only the first two frames.
    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    """
    # (1, 1), (1, 3, 3), (1, 3)
    s_first, R_first, t_first = align_pcl(
        gt_joints[:2].reshape(1, -1, 3), pred_joints[:2].reshape(1, -1, 3)
    )
    pred_first = (
        s_first * torch.einsum("tij,tnj->tni", R_first, pred_joints) + t_first[:, None]
    )
    return pred_first


def local_align_joints(gt_joints, pred_joints):
    """
    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    """
    s_loc, R_loc, t_loc = align_pcl(gt_joints, pred_joints)
    pred_loc = (
        s_loc[:, None] * torch.einsum("tij,tnj->tni", R_loc, pred_joints)
        + t_loc[:, None]
    )
    return pred_loc


def align_pcl(Y, X, weight=None, fixed_scale=False):
    """
    Aligns a similarity transform to align X with Y using the Umeyama method,
    so that X' = s * R * X + t is aligned with Y.
    :param Y (*, N, 3) first trajectory
    :param X (*, N, 3) second trajectory
    :param weight (*, N, 1) optional weight of valid correspondences
    :returns s (*, 1), R (*, 3, 3), t (*, 3)
    """
    *dims, N, _ = Y.shape
    N = torch.ones(*dims, 1, 1) * N

    if weight is not None:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True)  # (*, 1, 1)

    # Subtract mean.
    my = Y.sum(dim=-2) / N[..., 0]  # (*, 3)
    mx = X.sum(dim=-2) / N[..., 0]
    y0 = Y - my[..., None, :]  # (*, N, 3)
    x0 = X - mx[..., None, :]

    if weight is not None:
        y0 = y0 * weight
        x0 = x0 * weight

    # Correlation.
    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)
    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)

    # Fix the orientation so that det(R) = 1.
    S = torch.eye(3).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1)
    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
    S[neg, 2, 2] = -1

    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)

    D = torch.diag_embed(D)  # (*, 3, 3)
    if fixed_scale:
        s = torch.ones(*dims, 1, device=Y.device, dtype=torch.float32)
    else:
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)
        s = (
            torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
                dim=-1, keepdim=True
            )
            / var[..., 0]
        )  # (*, 1)

    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)

    return s, R, t
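
# Usage sketch for the SLAHMR-style alignment helpers above (illustrative
# shapes; T=50 frames and J=24 joints are assumptions for demonstration):
#
#   gt = torch.randn(50, 24, 3)                    # hypothetical (T, J, 3) GT joints
#   pred = 0.9 * gt + torch.tensor([0., 0.1, 0.])  # scaled and shifted prediction
#   pred_glob = global_align_joints(gt, pred)      # one transform for the whole sequence
#   pred_first = first_align_joints(gt, pred)      # transform fit to the first two frames
#   pred_loc = local_align_joints(gt, pred)        # per-frame transforms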

def compute_foot_sliding(target_output, pred_output, masks, thr=1e-2):
    """
    Computes foot sliding error.
    The foot-ground contact label is computed with a threshold of 1 cm/frame.
    Args:
        target_output (SMPL ModelOutput).
        pred_output (SMPL ModelOutput).
        masks (N).
    Returns:
        error (N frames in contact).
    """
    # Foot vertex indices.
    foot_idxs = [3216, 3387, 6617, 6787]

    # Compute contact labels from the ground-truth foot displacement.
    foot_loc = target_output.vertices[masks][:, foot_idxs]
    foot_disp = (foot_loc[1:] - foot_loc[:-1]).norm(2, dim=-1)
    contact = foot_disp < thr

    # Measure how much the predicted feet move during GT contact frames.
    pred_feet_loc = pred_output.vertices[:, foot_idxs]
    pred_disp = (pred_feet_loc[1:] - pred_feet_loc[:-1]).norm(2, dim=-1)
    error = pred_disp[contact]

    return error.cpu().numpy()


def compute_jitter(pred_output, fps=30):
    """
    Computes jitter of the motion (third-order finite difference of the joints).
    Args:
        pred_output (SMPL ModelOutput).
        fps (float).
    Returns:
        jitter (N-3).
    """
    pred3d = pred_output.joints[:, :24]

    pred_jitter = torch.norm(
        (pred3d[3:] - 3 * pred3d[2:-1] + 3 * pred3d[1:-2] - pred3d[:-3]) * (fps ** 3),
        dim=2,
    ).mean(dim=-1)

    return pred_jitter.cpu().numpy() / 10.0


def compute_rte(target_trans, pred_trans):
    """
    Computes root translation error (RTE) after rigid alignment,
    normalized by the total displacement of the ground-truth trajectory.
    """
    # Compute the global alignment (rotation and translation only).
    _, rot, trans = align_pcl(target_trans[None, :], pred_trans[None, :], fixed_scale=True)
    pred_trans_hat = (
        torch.einsum("tij,tnj->tni", rot, pred_trans[None, :]) + trans[None, :]
    )[0]

    # Compute the entire displacement of the ground-truth trajectory.
    disps, disp = [], 0
    for p1, p2 in zip(target_trans, target_trans[1:]):
        delta = (p2 - p1).norm(2, dim=-1)
        disp += delta
        disps.append(disp)

    # Compute the absolute root translation error (RTE).
    rte = torch.norm(target_trans - pred_trans_hat, 2, dim=-1)

    # Normalize it by the total displacement.
    return (rte / disp).numpy()
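
if __name__ == "__main__":
    # Minimal smoke test on synthetic data (a sketch, not part of the original
    # evaluation pipeline): checks output shapes and that the Procrustes solver
    # recovers a known similarity transform. All shapes here are assumptions.
    rng = np.random.default_rng(0)
    joints_gt = rng.standard_normal((20, 14, 3))
    joints_pred = joints_gt + 1e-3 * rng.standard_normal((20, 14, 3))

    assert compute_accel(joints_gt).shape == (18,)
    assert compute_error_accel(joints_gt, joints_pred).shape == (18,)

    # Procrustes round trip: S2 is an exact similarity transform of S1.
    theta = 0.5
    R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                  [np.sin(theta), np.cos(theta), 0.0],
                  [0.0, 0.0, 1.0]])
    S1 = rng.standard_normal((14, 3))
    S2 = 1.3 * S1.dot(R.T) + np.array([0.2, -0.1, 0.4])
    assert np.allclose(compute_similarity_transform(S1, S2), S2, atol=1e-6)

    gt = torch.randn(10, 24, 3)
    pred = 0.8 * gt + torch.tensor([0.0, 0.1, 0.0])
    assert global_align_joints(gt, pred).shape == (10, 24, 3)
    print("eval_utils smoke test passed.")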