FormFighterAIStack / lib /eval /eval_utils.py
Techt3o's picture
e01e49325338173592071b44501d91416d6a9072d1040c9d9f5aecf816533bec
f561f8b verified
raw
history blame
14.7 kB
# Some functions are borrowed from https://github.com/akanazawa/human_dynamics/blob/master/src/evaluation/eval_util.py
# Adhere to their licence to use these functions
from pathlib import Path
import torch
import numpy as np
from matplotlib import pyplot as plt
def compute_accel(joints):
    """
    Mean acceleration magnitude per frame.

    The acceleration is the second finite difference along the frame axis
    (a difference of consecutive velocities); its L2 norm is averaged over
    joints.

    Args:
        joints (N x J x 3), e.g. N x 25 x 3.
    Returns:
        Array of N-2 per-frame mean acceleration magnitudes.
    """
    second_diff = np.diff(joints, n=2, axis=0)  # (N-2) x J x 3
    per_joint = np.linalg.norm(second_diff, axis=2)
    return per_joint.mean(axis=1)
def compute_error_accel(joints_gt, joints_pred, vis=None):
    r"""
    Computes the per-frame acceleration error between two joint sequences.

    The acceleration at frame i is the second finite difference
    X_{i-1} - 2 X_i + X_{i+1}. The error of window i is the L2 norm of
    (accel_pred - accel_gt) averaged over joints, so that
    mean(output) = 1/(n-2) \sum_{i=1}^{n-2} mean_j ||a^{pred}_{i,j} - a^{gt}_{i,j}||.

    Each acceleration entry depends on three consecutive frames. When ``vis``
    is given, every entry whose 3-frame window contains a non-visible frame is
    dropped from the output (it is removed, not zeroed).

    Args:
        joints_gt (N x J x 3), e.g. N x 14 x 3.
        joints_pred (N x J x 3).
        vis (N), optional: per-frame visibility flags.
    Returns:
        error_accel: 1-D array of length N-2 when ``vis`` is None, otherwise
        one entry per fully visible 3-frame window (at most N-2).
    """
    # (N-2) x J x 3
    accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:]
    accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:]
    normed = np.linalg.norm(accel_pred - accel_gt, axis=2)

    if vis is None:
        new_vis = np.ones(len(normed), dtype=bool)
    else:
        invis = np.logical_not(vis)
        # Window i covers frames i, i+1 and i+2; keep it only if all are visible.
        new_vis = np.logical_not(invis[:-2] | invis[1:-1] | invis[2:])
    return np.mean(normed[new_vis], axis=1)
def compute_error_verts(pred_verts, target_verts=None, target_theta=None):
    """
    Computes the mean per-vertex position error for each frame.

    For every frame, the Euclidean distance between corresponding predicted
    and target vertices is averaged over all vertices (typically the 6890
    SMPL surface vertices).

    If ``target_verts`` is not given, it is rebuilt from ``target_theta`` by
    running the project's SMPL model on CPU, 5000 frames at a time.

    Args:
        pred_verts (N x V x 3): predicted vertices.
        target_verts (N x V x 3, optional): ground-truth vertices.
        target_theta (N x D numpy array, optional): only read when
            ``target_verts`` is None. Columns [3:75] are fed to SMPL as the
            pose (the first 3 of those as ``global_orient``, the remaining 69
            as ``body_pose``) and columns [75:] as ``betas``; columns [0:3]
            are not used here.
    Returns:
        error_verts (N): per-frame mean vertex error, in the vertices' units.
    Raises:
        ValueError: if neither ``target_verts`` nor ``target_theta`` is given.
    """
    if target_verts is None:
        if target_theta is None:
            raise ValueError(
                "compute_error_verts needs either target_verts or target_theta"
            )
        from lib.models.smpl import SMPL_MODEL_DIR
        from lib.models.smpl import SMPL

        device = 'cpu'
        smpl = SMPL(
            SMPL_MODEL_DIR,
            batch_size=1,
        ).to(device)

        betas = torch.from_numpy(target_theta[:, 75:]).to(device)
        pose = torch.from_numpy(target_theta[:, 3:75]).to(device)

        # Run SMPL in chunks of 5000 frames (presumably to bound memory).
        # Gradients are never needed for the ground-truth reconstruction.
        target_verts = []
        with torch.no_grad():
            for b, p in zip(torch.split(betas, 5000), torch.split(pose, 5000)):
                output = smpl(
                    betas=b,
                    body_pose=p[:, 3:],
                    global_orient=p[:, :3],
                    pose2rot=True,
                )
                target_verts.append(output.vertices.detach().cpu().numpy())
        target_verts = np.concatenate(target_verts, axis=0)

    assert len(pred_verts) == len(target_verts)
    error_per_vert = np.sqrt(np.sum((target_verts - pred_verts) ** 2, axis=2))
    return np.mean(error_per_vert, axis=1)
def compute_similarity_transform(S1, S2):
    """
    Align S1 onto S2 with the best-fitting similarity transform.

    Solves the orthogonal Procrustes problem with scale: finds s, a proper
    rotation R (det(R) = +1) and t such that s * R @ S1 + t is closest to S2,
    then returns that transformed copy of S1.

    Points may be given coordinates-first (3 x N or 2 x N) or points-first
    (N x 3 or N x 2). If the first axis of S1 is neither 2 nor 3, both inputs
    are transposed internally and the result is transposed back.

    Args:
        S1: source points.
        S2: target points, same layout as S1.
    Returns:
        S1_hat: aligned source points in the input layout.
    """
    points_first = S1.shape[0] not in (2, 3)
    if points_first:
        S1, S2 = S1.T, S2.T
    assert S2.shape[1] == S1.shape[1]

    # Centre both point sets on their centroids.
    centroid_src = S1.mean(axis=1, keepdims=True)
    centroid_dst = S2.mean(axis=1, keepdims=True)
    src0 = S1 - centroid_src
    dst0 = S2 - centroid_dst

    # Total squared norm of the centred source: the scale denominator.
    src_energy = np.sum(src0 ** 2)

    # Cross-covariance. R = V U^T maximises trace(R' K) for K = U S V^T.
    cov = src0.dot(dst0.T)
    u_mat, _, vh_mat = np.linalg.svd(cov)
    v_mat = vh_mat.T

    # Flip the last singular direction if needed so R is a rotation, not a reflection.
    fix = np.eye(u_mat.shape[0])
    fix[-1, -1] *= np.sign(np.linalg.det(u_mat.dot(v_mat.T)))
    rot = v_mat.dot(fix.dot(u_mat.T))

    scale = np.trace(rot.dot(cov)) / src_energy
    shift = centroid_dst - scale * (rot.dot(centroid_src))

    aligned = scale * rot.dot(S1) + shift
    return aligned.T if points_first else aligned
def compute_similarity_transform_torch(S1, S2):
    '''
    Computes a similarity transform (sR, t) that takes
    a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
    where R is an 3x3 rotation matrix, t 3x1 translation, s scale.
    i.e. solves the orthogonal Procrustes problem.

    Points may also be given points-first (N x 3 / N x 2): if the first axis
    of S1 is neither 2 nor 3, both inputs are transposed internally and the
    result is transposed back.

    Args:
        S1 (torch.Tensor): source points; any floating dtype/device.
        S2 (torch.Tensor): target points, same layout as S1.
    Returns:
        S1_hat (torch.Tensor): s * R @ S1 + t in the input layout, with the
        same dtype and device as S1.
    '''
    transposed = S1.shape[0] not in (2, 3)
    if transposed:
        S1 = S1.T
        S2 = S2.T
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean.
    mu1 = S1.mean(dim=1, keepdim=True)
    mu2 = S2.mean(dim=1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = torch.sum(X1 ** 2)

    # 3. The outer product of X1 and X2.
    K = X1.mm(X2.T)

    # 4. Solution that maximizes trace(R'K) is R = V U', where U, V are the
    # singular vectors of K.
    U, _, Vh = torch.linalg.svd(K)
    V = Vh.T

    # Construct Z that fixes the orientation of R to get det(R) = 1. Created
    # in S1's dtype so float64 inputs do not hit a float32/float64 matmul error.
    Z = torch.eye(U.shape[0], device=S1.device, dtype=S1.dtype)
    Z[-1, -1] *= torch.sign(torch.det(U @ V.T))

    # Construct R.
    R = V.mm(Z.mm(U.T))

    # 5. Recover scale.
    scale = torch.trace(R.mm(K)) / var1

    # 6. Recover translation.
    t = mu2 - scale * (R.mm(mu1))

    # 7. Apply the transform.
    S1_hat = scale * R.mm(S1) + t
    if transposed:
        S1_hat = S1_hat.T
    return S1_hat
def batch_compute_similarity_transform_torch(S1, S2):
    '''
    Batched similarity transform: for each batch item, computes (sR, t) that
    takes a set of 3D points S1 closest to the set of 3D points S2, where R is
    a 3x3 rotation matrix, t a 3x1 translation and s a scale, i.e. solves the
    orthogonal Procrustes problem per item.

    Inputs may be coordinates-first (B x 3 x N / B x 2 x N) or points-first
    (B x N x 3 / B x N x 2). If dim 1 of S1 is neither 2 nor 3, both inputs are
    permuted internally and the result is permuted back.

    Args:
        S1 (torch.Tensor): source points, batched; any floating dtype/device.
        S2 (torch.Tensor): target points, same layout as S1.
    Returns:
        S1_hat (torch.Tensor): per-item s * R @ S1 + t in the input layout.
    '''
    # Inspect dim 1 (the coordinate axis of the B x 3 x N layout), not the
    # batch dim 0: batches of size 2 or 3 must not change the layout decision.
    transposed = S1.shape[1] not in (2, 3)
    if transposed:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean.
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale, one value per batch item.
    var1 = torch.sum(X1 ** 2, dim=1).sum(dim=1)

    # 3. The outer product of X1 and X2.
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. Solution that maximizes trace(R'K) is R = V U', where U, V are the
    # singular vectors of K.
    U, _, Vh = torch.linalg.svd(K)
    V = Vh.transpose(-2, -1)

    # Construct Z that fixes the orientation of R to get det(R) = 1, in the
    # input dtype so float64 batches work.
    Z = torch.eye(U.shape[1], device=S1.device, dtype=S1.dtype).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))

    # Construct R.
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Recover scale: batched trace of R K.
    scale = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(dim=-1) / var1
    scale = scale.unsqueeze(-1).unsqueeze(-1)

    # 6. Recover translation.
    t = mu2 - (scale * (R.bmm(mu1)))

    # 7. Apply the transform.
    S1_hat = scale * R.bmm(S1) + t
    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)
    return S1_hat
def align_by_pelvis(joints):
    """
    Root-centre one skeleton on its pelvis.

    Assumes joints is 14 x 3 in LSP order, with the hips at indices 2 and 3.
    The pelvis is taken as their midpoint and subtracted from every joint.

    Args:
        joints (14 x 3): joint positions.
    Returns:
        Array of the same shape with the pelvis midpoint moved to the origin.
    """
    hip_pair = (2, 3)
    pelvis = (joints[hip_pair[0], :] + joints[hip_pair[1], :]) / 2.0
    pelvis_row = np.expand_dims(pelvis, axis=0)  # (1, 3) for broadcasting
    return joints - pelvis_row
def compute_errors(gt3ds, preds):
    """
    Gets MPJPE after pelvis alignment + MPJPE after Procrustes.
    Evaluates on the 14 common joints.

    Each ground-truth / prediction pair is flattened to (J, 3) and
    root-centred with ``align_by_pelvis``. The error is then the mean
    Euclidean joint distance. The Procrustes-aligned (PA) error first fits a
    similarity transform from the prediction onto the ground truth.

    Inputs:
      - gt3ds: N x 14 x 3 (each sample may also be a flat 42-vector)
      - preds: N x 14 x 3 (same)
    Returns:
      (errors, errors_pa): two lists with one mean joint error per sample,
      in the input's units.
    """
    errors, errors_pa = [], []
    for gt3d, pred in zip(gt3ds, preds):
        # Root align; reshape both sides so flat per-sample inputs work too.
        gt3d = align_by_pelvis(gt3d.reshape(-1, 3))
        pred3d = align_by_pelvis(pred.reshape(-1, 3))

        joint_error = np.sqrt(np.sum((gt3d - pred3d) ** 2, axis=1))
        errors.append(np.mean(joint_error))

        # Get PA error.
        pred3d_sym = compute_similarity_transform(pred3d, gt3d)
        pa_error = np.sqrt(np.sum((gt3d - pred3d_sym) ** 2, axis=1))
        errors_pa.append(np.mean(pa_error))
    return errors, errors_pa
def batch_align_by_pelvis(data_list, pelvis_idxs):
    """
    Root-centre joints and vertices of a batch of frames on the pelvis.

    Assumes data is given as [pred_j3d, target_j3d, pred_verts, target_verts].
    Each data is in shape of (frames, num_points, 3).
    The pelvis is the mean of the joints selected by ``pelvis_idxs``, which may
    be a single index or a sequence of indices (e.g. two hips). Predictions are
    centred on the predicted pelvis and targets on the target pelvis.

    Returns:
        tuple (pred_j3d, target_j3d, pred_verts, target_verts), each shifted by
        its own pelvis location. The inputs are not modified.
    """
    # A scalar index would drop the joint axis, making the mean below average
    # x/y/z instead of joints; wrap it so it selects exactly one joint.
    if np.ndim(pelvis_idxs) == 0:
        pelvis_idxs = [int(pelvis_idxs)]

    pred_j3d, target_j3d, pred_verts, target_verts = data_list
    pred_pelvis = pred_j3d[:, pelvis_idxs].mean(dim=1, keepdim=True)  # (frames, 1, 3)
    target_pelvis = target_j3d[:, pelvis_idxs].mean(dim=1, keepdim=True)

    # Align to the pelvis (subtraction allocates new tensors).
    return (
        pred_j3d - pred_pelvis,
        target_j3d - target_pelvis,
        pred_verts - pred_pelvis,
        target_verts - target_pelvis,
    )
def compute_jpe(S1, S2):
    """
    Mean per-joint position error.

    Args:
        S1, S2 (torch.Tensor): (..., J, 3) point sets of identical shape.
    Returns:
        numpy array of shape (...): Euclidean distance averaged over the J
        points. The result is detached and moved to CPU first, so tensors
        that require grad or live on a GPU are accepted.
    """
    dist = torch.sqrt(((S1 - S2) ** 2).sum(dim=-1))
    return dist.mean(dim=-1).detach().cpu().numpy()
# The functions below are borrowed from SLAHMR official implementation.
# Reference: https://github.com/vye16/slahmr/blob/main/slahmr/eval/tools.py
def global_align_joints(gt_joints, pred_joints):
    """
    Align a whole predicted joint sequence to ground truth with a single
    similarity transform fitted jointly over all frames and joints.

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns (T, J, 3) predicted joints after the shared s * R * x + t
    """
    flat_gt = gt_joints.reshape(-1, 3)
    flat_pred = pred_joints.reshape(-1, 3)
    scale, rot, offset = align_pcl(flat_gt, flat_pred)
    rotated = torch.einsum("ij,tnj->tni", rot, pred_joints)
    return scale * rotated + offset[None, None]
def first_align_joints(gt_joints, pred_joints):
    """
    Align the predicted sequence using a similarity transform fitted on the
    first two frames only, then applied to every frame.

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns (T, J, 3) transformed predicted joints
    """
    head_gt = gt_joints[:2].reshape(1, -1, 3)
    head_pred = pred_joints[:2].reshape(1, -1, 3)
    # Shapes: scale (1, 1), rot (1, 3, 3), offset (1, 3)
    scale, rot, offset = align_pcl(head_gt, head_pred)
    rotated = torch.einsum("tij,tnj->tni", rot, pred_joints)
    return scale * rotated + offset[:, None]
def local_align_joints(gt_joints, pred_joints):
    """
    Align each predicted frame to its ground-truth frame with its own
    per-frame similarity transform.

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns (T, J, 3) per-frame aligned predicted joints
    """
    # Shapes: scale (T, 1), rot (T, 3, 3), offset (T, 3)
    scale, rot, offset = align_pcl(gt_joints, pred_joints)
    rotated = torch.einsum("tij,tnj->tni", rot, pred_joints)
    return scale[:, None] * rotated + offset[:, None]
def align_pcl(Y, X, weight=None, fixed_scale=False):
    """align similarity transform to align X with Y using umeyama method
    X' = s * R * X + t is aligned with Y
    :param Y (*, N, 3) first trajectory
    :param X (*, N, 3) second trajectory
    :param weight (*, N, 1) optional weight of valid correspondences
    :param fixed_scale if True, s is fixed to 1 (rigid alignment)
    :returns s (*, 1), R (*, 3, 3), t (*, 3)

    All helper tensors are created on Y's device and dtype, so CUDA and
    float64 inputs are supported.
    """
    *dims, n_pts, _ = Y.shape
    factory = dict(device=Y.device, dtype=Y.dtype)
    N = torch.ones(*dims, 1, 1, **factory) * n_pts

    if weight is not None:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True)  # (*, 1, 1)

    # subtract mean
    my = Y.sum(dim=-2) / N[..., 0]  # (*, 3)
    mx = X.sum(dim=-2) / N[..., 0]
    y0 = Y - my[..., None, :]  # (*, N, 3)
    x0 = X - mx[..., None, :]

    if weight is not None:
        y0 = y0 * weight
        x0 = x0 * weight

    # correlation
    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)
    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)

    # Flip the last axis where needed so R is a proper rotation (det = +1).
    S = torch.eye(3, **factory).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1)
    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
    S[neg, 2, 2] = -1

    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)

    D = torch.diag_embed(D)  # (*, 3, 3)
    if fixed_scale:
        s = torch.ones(*dims, 1, **factory)
    else:
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)
        s = (
            torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
                dim=-1, keepdim=True
            )
            / var[..., 0]
        )  # (*, 1)

    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)

    return s, R, t
def compute_foot_sliding(target_output, pred_output, masks, thr=1e-2):
    """compute foot sliding error

    A foot vertex counts as "in ground contact" over a frame transition when
    its ground-truth displacement between consecutive frames is below ``thr``
    (1 cm/frame for the default, assuming vertices are in metres). The error
    is the predicted displacement of that same vertex over those transitions.

    Args:
        target_output (SMPL ModelOutput): ground truth; only ``.vertices``
            (frames x V x 3) is read.
        pred_output (SMPL ModelOutput): prediction; only ``.vertices`` is read.
        masks (N): frame selector applied to the ground-truth vertices.
        thr (float): contact threshold on per-frame displacement.
    Returns:
        error: 1-D numpy array with one predicted displacement per
        (transition, foot vertex) pair that is in contact.
    """
    foot_vertex_ids = [3216, 3387, 6617, 6787]

    gt_feet = target_output.vertices[masks][:, foot_vertex_ids]
    gt_step = (gt_feet[1:] - gt_feet[:-1]).norm(2, dim=-1)
    in_contact = gt_step < thr

    # NOTE(review): masks is applied to the ground truth only; pred_output is
    # presumably already restricted to the same frames. Confirm with callers.
    pred_feet = pred_output.vertices[:, foot_vertex_ids]
    pred_step = (pred_feet[1:] - pred_feet[:-1]).norm(2, dim=-1)

    return pred_step[in_contact].cpu().numpy()
def compute_jitter(pred_output, fps=30):
    """compute jitter of the motion

    Jitter is the third-order finite difference of the first 24 joints,
    x[t+3] - 3 x[t+2] + 3 x[t+1] - x[t], scaled by fps**3. Its L2 norm is
    averaged over joints and the result is divided by 10.

    Args:
        pred_output (SMPL ModelOutput): only ``.joints`` (frames x J x 3) is read.
        fps (float): frame rate used to scale the finite difference.
    Returns:
        jitter (N-3): per-window jitter as a numpy array.
    """
    joints = pred_output.joints[:, :24]
    third_diff = joints[3:] - 3 * joints[2:-1] + 3 * joints[1:-2] - joints[:-3]
    scaled = third_diff * (fps ** 3)
    per_window = torch.norm(scaled, dim=2).mean(dim=-1)
    # NOTE(review): the fixed /10 rescaling's unit intent is not documented here.
    return per_window.cpu().numpy() / 10.0
def compute_rte(target_trans, pred_trans):
    """
    Root-translation error (RTE) normalised by the ground-truth path length.

    The predicted trajectory is first rigidly aligned to the ground truth
    (rotation + translation, scale fixed to 1) with ``align_pcl``.

    Args:
        target_trans (T, 3): ground-truth root trajectory.
        pred_trans (T, 3): predicted root trajectory.
    Returns:
        numpy array (T,): per-frame ||gt - aligned pred|| divided by the total
        length of the ground-truth path. A stationary ground truth (zero path
        length) yields inf/nan entries.
    """
    # Compute the global alignment
    _, rot, trans = align_pcl(target_trans[None, :], pred_trans[None, :], fixed_scale=True)
    pred_trans_hat = (
        torch.einsum("tij,tnj->tni", rot, pred_trans[None, :]) + trans[None, :]
    )[0]

    # Total length of the ground-truth trajectory (sum of per-step displacements).
    path_length = (target_trans[1:] - target_trans[:-1]).norm(2, dim=-1).sum()

    # Compute absolute root-translation-error (RTE)
    rte = torch.norm(target_trans - pred_trans_hat, 2, dim=-1)

    # Normalize it to the displacement; detach/cpu so GPU or grad tensors work.
    return (rte / path_length).detach().cpu().numpy()