Spaces:
Sleeping
Sleeping
# Some functions are borrowed from https://github.com/akanazawa/human_dynamics/blob/master/src/evaluation/eval_util.py
# Adhere to their licence when using these functions.
from pathlib import Path | |
import torch | |
import numpy as np | |
from matplotlib import pyplot as plt | |
def compute_accel(joints):
    """
    Mean joint acceleration magnitude per frame, from second finite differences.

    Args:
        joints (Nx25x3): 3D joint positions over N consecutive frames.
    Returns:
        Accelerations (N-2): per-frame norm of the second difference along
        the frame axis, averaged over the joint axis.
    """
    # First difference along frames -> velocity; differencing again -> acceleration.
    frame_velocity = np.diff(joints, axis=0)
    frame_accel = np.diff(frame_velocity, axis=0)
    # Euclidean length over the coordinate axis, then average over joints.
    per_joint_magnitude = np.linalg.norm(frame_accel, axis=2)
    return per_joint_magnitude.mean(axis=1)
def compute_error_accel(joints_gt, joints_pred, vis=None):
    r"""
    Computes the acceleration error between predicted and ground-truth joints.

    The acceleration of frame i is the second finite difference
        a_i = X_{i-1} - 2 X_i + X_{i+1},
    and the error of a frame is the joint-averaged norm
        || a_i^{pred} - a_i^{gt} ||.

    Note that each frame that is not visible invalidates the (up to) three
    acceleration entries whose 3-frame window contains it; those entries are
    dropped from the result.

    Args:
        joints_gt (Nx14x3).
        joints_pred (Nx14x3).
        vis (N): optional per-frame visibility flags (truthy = visible).
    Returns:
        error_accel (N-2 when `vis` is None, otherwise one entry per
        fully visible 3-frame window).
    """
    # (N-2)x14x3
    accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:]
    accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:]

    normed = np.linalg.norm(accel_pred - accel_gt, axis=2)

    if vis is None:
        new_vis = np.ones(len(normed), dtype=bool)
    else:
        # Entry k is built from frames k, k+1 and k+2, so it is valid only
        # when all three of them are visible.
        vis = np.asarray(vis, dtype=bool)
        new_vis = vis[:-2] & vis[1:-1] & vis[2:]

    return np.mean(normed[new_vis], axis=1)
def compute_error_verts(pred_verts, target_verts=None, target_theta=None):
    """
    Computes the per-frame mean Euclidean error over mesh vertices.

    Args:
        pred_verts (Nx6890x3): predicted vertices.
        target_verts (Nx6890x3): ground-truth vertices. When omitted they are
            regenerated from `target_theta` with the project's SMPL model.
        target_theta (NxD numpy array): only used when `target_verts` is None.
            Columns 3:75 are read as the 72-dim axis-angle pose (global
            orientation followed by body pose) and columns 75: as the shape
            coefficients. The first three columns are ignored here
            (presumably camera parameters -- confirm against the caller).
    Returns:
        error_verts (N).
    Raises:
        ValueError: if neither `target_verts` nor `target_theta` is given.
    """
    if target_verts is None:
        if target_theta is None:
            raise ValueError(
                "Either target_verts or target_theta must be provided."
            )

        # Imported lazily so that the SMPL dependency is only needed on
        # this code path.
        from lib.models.smpl import SMPL_MODEL_DIR
        from lib.models.smpl import SMPL
        device = 'cpu'
        smpl = SMPL(
            SMPL_MODEL_DIR,
            batch_size=1,  # target_theta.shape[0],
        ).to(device)

        betas = torch.from_numpy(target_theta[:, 75:]).to(device)
        pose = torch.from_numpy(target_theta[:, 3:75]).to(device)

        # Run SMPL in chunks of 5000 frames to bound peak memory.
        target_verts = []
        for b, p in zip(torch.split(betas, 5000), torch.split(pose, 5000)):
            output = smpl(betas=b, body_pose=p[:, 3:], global_orient=p[:, :3], pose2rot=True)
            target_verts.append(output.vertices.detach().cpu().numpy())

        target_verts = np.concatenate(target_verts, axis=0)

    assert len(pred_verts) == len(target_verts)
    error_per_vert = np.sqrt(np.sum((target_verts - pred_verts) ** 2, axis=2))
    return np.mean(error_per_vert, axis=1)
def compute_similarity_transform(S1, S2):
    '''
    Aligns point set S1 to point set S2 with the best similarity transform
    (scale s, 3x3 rotation R, 3x1 translation t), i.e. solves the
    orthogonal Procrustes problem, and returns the transformed S1.

    Points are expected as columns (3 x N). An input whose first dimension
    is neither 3 nor 2 is transposed on entry and the result is transposed
    back before returning.
    '''
    needs_transpose = S1.shape[0] not in (3, 2)
    if needs_transpose:
        S1, S2 = S1.T, S2.T

    assert(S2.shape[1] == S1.shape[1])

    # Centre both point sets on their centroids.
    centroid1 = S1.mean(axis=1, keepdims=True)
    centroid2 = S2.mean(axis=1, keepdims=True)
    centered1 = S1 - centroid1
    centered2 = S2 - centroid2

    # Total variance of the source set, used to recover the scale.
    source_var = np.sum(centered1**2)

    # Cross-covariance between the centred sets.
    cross_cov = centered1.dot(centered2.T)

    # The rotation maximising trace(R'K) comes from the singular vectors of K.
    left, _, right_h = np.linalg.svd(cross_cov)
    right = right_h.T

    # Flip the last axis if needed so that det(R) = 1 (no reflection).
    fix = np.eye(left.shape[0])
    fix[-1, -1] *= np.sign(np.linalg.det(left.dot(right.T)))
    rotation = right.dot(fix.dot(left.T))

    scale = np.trace(rotation.dot(cross_cov)) / source_var
    translation = centroid2 - scale*(rotation.dot(centroid1))

    aligned = scale*rotation.dot(S1) + translation
    return aligned.T if needs_transpose else aligned
def compute_similarity_transform_torch(S1, S2):
    '''
    Computes a similarity transform (sR, t) that takes
    a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
    where R is an 3x3 rotation matrix, t 3x1 translation, s scale.
    i.e. solves the orthogonal Procrustes problem.

    Inputs are floating-point torch tensors. An input whose first dimension
    is neither 3 nor 2 is transposed on entry and the result is transposed
    back. Returns S1 mapped onto S2, with the dtype and device of S1.
    '''
    transposed = False
    if S1.shape[0] != 3 and S1.shape[0] != 2:
        S1 = S1.T
        S2 = S2.T
        transposed = True
    assert (S2.shape[1] == S1.shape[1])

    # 1. Remove mean.
    mu1 = S1.mean(dim=1, keepdim=True)
    mu2 = S2.mean(dim=1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = torch.sum(X1 ** 2)

    # 3. The outer product of X1 and X2.
    K = X1.mm(X2.T)

    # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
    # singular vectors of K. torch.linalg.svd returns V^T directly.
    U, _, Vh = torch.linalg.svd(K)
    V = Vh.T

    # Construct Z that fixes the orientation of R to get det(R)=1.
    # Z follows the dtype of the input; a float32 Z would make the matrix
    # products below fail for float64 inputs.
    Z = torch.eye(U.shape[0], device=S1.device, dtype=S1.dtype)
    Z[-1, -1] *= torch.sign(torch.det(U @ Vh))

    # Construct R.
    R = V.mm(Z.mm(U.T))

    # 5. Recover scale.
    scale = torch.trace(R.mm(K)) / var1

    # 6. Recover translation.
    t = mu2 - scale * (R.mm(mu1))

    # 7. Apply the transform to the (uncentred) source points.
    S1_hat = scale * R.mm(S1) + t

    if transposed:
        S1_hat = S1_hat.T

    return S1_hat
def batch_compute_similarity_transform_torch(S1, S2):
    '''
    Batched version of the similarity (Procrustes) alignment: for every
    sample, computes the similarity transform (sR, t) that takes the point
    set S1 closest to the point set S2, where R is a rotation matrix,
    t a translation and s a scale, and returns the transformed S1.

    Accepted layouts (floating-point torch tensors):
        (B, N, 3) / (B, N, 2): one point per row (transposed internally),
        (B, 3, N) / (B, 2, N): one point per column.
    The layout is decided from the per-sample dimensions, never from the
    batch size. When both trailing dimensions are 2 or 3 the input is
    treated as one point per row.

    Returns a tensor with the same layout, dtype and device as S1.
    '''
    # Points-as-rows input has the coordinate axis (size 2 or 3) last.
    transposed = S1.shape[-1] in (2, 3)
    if transposed:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
    assert(S2.shape[1] == S1.shape[1])

    # 1. Remove mean.
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)

    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = torch.sum(X1**2, dim=1).sum(dim=1)

    # 3. The outer product of X1 and X2.
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
    # singular vectors of K. torch.linalg.svd returns V^T directly.
    U, _, Vh = torch.linalg.svd(K)
    V = Vh.transpose(-2, -1)

    # Construct Z that fixes the orientation of R to get det(R)=1.
    # Z follows the input dtype so that float64 inputs work as well.
    Z = torch.eye(U.shape[1], device=S1.device, dtype=S1.dtype).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(Vh)))

    # Construct R.
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Recover scale (batched trace, no per-sample Python loop).
    scale = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(dim=-1) / var1
    scale = scale.unsqueeze(-1).unsqueeze(-1)

    # 6. Recover translation.
    t = mu2 - (scale * (R.bmm(mu1)))

    # 7. Apply the transform to the (uncentred) source points.
    S1_hat = scale * R.bmm(S1) + t

    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)

    return S1_hat
def align_by_pelvis(joints):
    """
    Centres a skeleton on its pelvis.

    Assumes joints is 14 x 3 in LSP order, where rows 2 and 3 are the hips.
    The pelvis is taken as the midpoint of those two rows and is subtracted
    from every joint.
    """
    hip_a, hip_b = 2, 3
    hip_midpoint = (joints[hip_a, :] + joints[hip_b, :]) / 2.0
    # Add a leading axis so the midpoint broadcasts over all joints.
    return joints - hip_midpoint[np.newaxis, :]
def compute_errors(gt3ds, preds):
    """
    Per-sample MPJPE after pelvis alignment, and after an additional
    Procrustes alignment (PA-MPJPE).
    Evaluates on the 14 common joints.
    Inputs:
    - gt3ds: N x 14 x 3
    - preds: N x 14 x 3
    Returns:
    - errors: list of N pelvis-aligned mean joint errors
    - errors_pa: list of N Procrustes-aligned mean joint errors
    """
    def _mean_joint_distance(a, b):
        # Euclidean distance per joint, averaged over joints.
        return np.mean(np.sqrt(np.sum((a - b)**2, axis=1)))

    errors = []
    errors_pa = []
    for gt_sample, pred_sample in zip(gt3ds, preds):
        # Root-align both skeletons on the pelvis.
        gt_rooted = align_by_pelvis(gt_sample.reshape(-1, 3))
        pred_rooted = align_by_pelvis(pred_sample)
        errors.append(_mean_joint_distance(gt_rooted, pred_rooted))

        # Procrustes-align the prediction to the ground truth before measuring.
        pred_procrustes = compute_similarity_transform(pred_rooted, gt_rooted)
        errors_pa.append(_mean_joint_distance(gt_rooted, pred_procrustes))

    return errors, errors_pa
def batch_align_by_pelvis(data_list, pelvis_idxs):
    """
    Expresses joints and vertices relative to their own pelvis.

    `data_list` is [pred_j3d, target_j3d, pred_verts, target_verts], each of
    shape (frames, num_points, 3). `pelvis_idxs` holds the one or two joint
    indices that define the pelvis; their mean position is the root.
    Predicted data is shifted by the predicted root, target data by the
    target root. Returns the four shifted tensors as a tuple, same order.
    """
    pred_j3d, target_j3d, pred_verts, target_verts = data_list

    def _pelvis_of(j3d):
        # Mean of the pelvis joints, kept as (frames, 1, 3) for broadcasting.
        return j3d[:, pelvis_idxs].mean(dim=1, keepdim=True).clone()

    pred_root = _pelvis_of(pred_j3d)
    target_root = _pelvis_of(target_j3d)

    return (
        pred_j3d - pred_root,
        target_j3d - target_root,
        pred_verts - pred_root,
        target_verts - target_root,
    )
def compute_jpe(S1, S2):
    """
    Mean per-joint position error between two joint tensors.

    Takes the Euclidean distance over the last (coordinate) axis, averages
    it over the remaining last (joint) axis, and returns a NumPy array.
    The result is detached and moved to the CPU first, so inputs that live
    on an accelerator or track gradients are accepted.
    """
    per_joint = torch.sqrt(((S1 - S2) ** 2).sum(dim=-1))
    return per_joint.mean(dim=-1).detach().cpu().numpy()
# The functions below are borrowed from the official SLAHMR implementation.
# Reference: https://github.com/vye16/slahmr/blob/main/slahmr/eval/tools.py
def global_align_joints(gt_joints, pred_joints):
    """
    Align the whole predicted sequence to the ground truth with a single
    similarity transform estimated over all frames and joints at once.

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns aligned prediction (T, J, 3)
    """
    # Flatten frames and joints into one point cloud per sequence.
    gt_cloud = gt_joints.reshape(-1, 3)
    pred_cloud = pred_joints.reshape(-1, 3)
    s_glob, R_glob, t_glob = align_pcl(gt_cloud, pred_cloud)

    rotated = torch.einsum("ij,tnj->tni", R_glob, pred_joints)
    return s_glob * rotated + t_glob[None, None]
def first_align_joints(gt_joints, pred_joints):
    """
    Align the predicted sequence to the ground truth using a similarity
    transform estimated from the first two frames only.

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns aligned prediction (T, J, 3)
    """
    # Merge the first two frames into a single point cloud with batch size 1.
    gt_head = gt_joints[:2].reshape(1, -1, 3)
    pred_head = pred_joints[:2].reshape(1, -1, 3)

    # Shapes: (1, 1), (1, 3, 3), (1, 3)
    s_first, R_first, t_first = align_pcl(gt_head, pred_head)

    rotated = torch.einsum("tij,tnj->tni", R_first, pred_joints)
    return s_first * rotated + t_first[:, None]
def local_align_joints(gt_joints, pred_joints):
    """
    Align every predicted frame to its ground-truth frame independently,
    with one similarity transform per frame.

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns aligned prediction (T, J, 3)
    """
    # The frame axis acts as the batch axis of align_pcl.
    s_loc, R_loc, t_loc = align_pcl(gt_joints, pred_joints)

    rotated = torch.einsum("tij,tnj->tni", R_loc, pred_joints)
    scaled = s_loc[:, None] * rotated
    return scaled + t_loc[:, None]
def align_pcl(Y, X, weight=None, fixed_scale=False):
    """align similarity transform to align X with Y using umeyama method
    X' = s * R * X + t is aligned with Y

    All helper tensors are created with the dtype and device of Y, so the
    function works for inputs that are not float32 CPU tensors.

    :param Y (*, N, 3) first trajectory
    :param X (*, N, 3) second trajectory
    :param weight (*, N, 1) optional weight of valid correspondences
    :param fixed_scale if True, the scale is fixed to 1 instead of estimated
    :returns s (*, 1), R (*, 3, 3), t (*, 3)
    """
    *dims, N, _ = Y.shape
    N = torch.full((*dims, 1, 1), float(N), dtype=Y.dtype, device=Y.device)

    if weight is not None:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True)  # (*, 1, 1)

    # subtract mean
    my = Y.sum(dim=-2) / N[..., 0]  # (*, 3)
    mx = X.sum(dim=-2) / N[..., 0]
    y0 = Y - my[..., None, :]  # (*, N, 3)
    x0 = X - mx[..., None, :]

    if weight is not None:
        y0 = y0 * weight
        x0 = x0 * weight

    # correlation
    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)
    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)

    # S is the identity, except that its last diagonal entry is flipped to
    # -1 wherever U V^T would be a reflection, which forces det(R) = +1.
    S = torch.eye(3, dtype=Y.dtype, device=Y.device).expand(*dims, 3, 3).clone()
    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0  # (*)
    S[..., 2, 2] = 1.0 - 2.0 * neg.to(Y.dtype)

    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)

    if fixed_scale:
        s = torch.ones(*dims, 1, device=Y.device, dtype=Y.dtype)
    else:
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)
        # trace(diag(D) @ S) equals the sum of D_i * S_ii.
        s = (
            (D * torch.diagonal(S, dim1=-2, dim2=-1)).sum(dim=-1, keepdim=True)
            / var[..., 0]
        )  # (*, 1)

    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)

    return s, R, t
def compute_foot_sliding(target_output, pred_output, masks, thr=1e-2):
    """compute foot sliding error
    Ground contact is labelled from the target: a foot vertex counts as in
    contact between two consecutive frames when it moves less than `thr`
    (default 1e-2, i.e. 1 cm/frame if vertices are in metres). The error is
    the displacement of the predicted foot vertices over those same
    frame/vertex pairs.

    Args:
        target_output (SMPL ModelOutput): only `.vertices` is read, indexed by `masks`.
        pred_output (SMPL ModelOutput): only `.vertices` is read, not masked here.
        masks (N): frame selector applied to the target vertices.
        thr (float): contact threshold on per-frame displacement.
    Returns:
        error (one value per foot vertex and frame pair in contact).
    """
    # Foot vertices idxs
    foot_idxs = [3216, 3387, 6617, 6787]

    # Contact labels from the frame-to-frame motion of the target feet.
    target_feet = target_output.vertices[masks][:, foot_idxs]
    target_step = torch.norm(target_feet[1:] - target_feet[:-1], p=2, dim=-1)
    in_contact = target_step < thr

    # Predicted motion of the same vertices, kept only where in contact.
    pred_feet = pred_output.vertices[:, foot_idxs]
    pred_step = torch.norm(pred_feet[1:] - pred_feet[:-1], p=2, dim=-1)

    return pred_step[in_contact].cpu().numpy()
def compute_jitter(pred_output, fps=30):
    """compute jitter of the motion
    Jitter is the magnitude of the third finite difference (jerk) of the
    first 24 joints, scaled by fps**3, averaged over joints and divided by 10.

    Args:
        pred_output (SMPL ModelOutput): only `.joints` is read.
        fps (float).
    Returns:
        jitter (N-3).
    """
    joints = pred_output.joints[:, :24]

    # Third-order finite difference along the frame axis.
    third_diff = joints[3:] - 3 * joints[2:-1] + 3 * joints[1:-2] - joints[:-3]
    jerk = third_diff * (fps**3)

    # Norm over coordinates, then mean over joints.
    per_frame = torch.norm(jerk, dim=2).mean(dim=-1)

    return per_frame.cpu().numpy() / 10.0
def compute_rte(target_trans, pred_trans):
    """compute root translation error (RTE)
    The predicted trajectory is rigidly aligned (rotation and translation,
    scale fixed to 1) to the target trajectory. The per-frame distance
    between the two is then divided by the total path length of the target
    trajectory.

    Args:
        target_trans (N x 3): ground-truth root translations (assumed shape,
            confirm against the caller).
        pred_trans (N x 3): predicted root translations.
    Returns:
        rte (N): normalized per-frame error as a NumPy array.
    """
    # Compute the global alignment
    _, rot, trans = align_pcl(target_trans[None, :], pred_trans[None, :], fixed_scale=True)
    pred_trans_hat = (
        torch.einsum("tij,tnj->tni", rot, pred_trans[None, :]) + trans[None, :]
    )[0]

    # Total displacement of the ground-truth trajectory: the sum of the
    # distances between consecutive frames.
    disp = (target_trans[1:] - target_trans[:-1]).norm(2, dim=-1).sum()

    # Compute absolute root-translation-error (RTE)
    rte = torch.norm(target_trans - pred_trans_hat, 2, dim=-1)

    # Normalize it to the displacement
    return (rte / disp).detach().cpu().numpy()