import os
import os.path as osp
import time
import pickle
from glob import glob
from collections import defaultdict

import torch
import numpy as np
from smplx import SMPL
from loguru import logger
from progress.bar import Bar

from configs import constants as _C
from configs.config import parse_args
from lib.data.dataloader import setup_eval_dataloader
from lib.models import build_network, build_body_model
from lib.eval.eval_utils import (
    compute_jpe,
    compute_rte,
    compute_jitter,
    compute_error_accel,
    compute_foot_sliding,
    batch_align_by_pelvis,
    first_align_joints,
    global_align_joints,
    batch_compute_similarity_transform_torch,
)
from lib.utils import transforms
from lib.utils.utils import prepare_output_dir, prepare_batch
from lib.utils.imutils import avg_preds
""" | |
This is a tentative script to evaluate WHAM on EMDB dataset. | |
Current implementation requires EMDB dataset downloaded at ./datasets/EMDB/ | |
""" | |
m2mm = 1e3

def main(cfg, args):
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    logger.info(f'GPU name -> {torch.cuda.get_device_name()}')
    logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}')

    # ========= Dataloaders ========= #
    eval_loader = setup_eval_dataloader(cfg, 'emdb', args.eval_split, cfg.MODEL.BACKBONE)
    logger.info('Dataset loaded')

    # ========= Load WHAM ========= #
    smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN
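    # One SMPL batch covers every frame in a batch of sequences (BATCH_SIZE * SEQLEN).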
    smpl = build_body_model(cfg.DEVICE, smpl_batch_size)
    network = build_network(cfg, smpl)
    network.eval()

    # Build SMPL models for each gender
    smpl = {k: SMPL(_C.BMODEL.FLDR, gender=k).to(cfg.DEVICE) for k in ['male', 'female', 'neutral']}
    # Hip joint indices used to align predictions and groundtruth at the pelvis
    pelvis_idxs = [1, 2]
    # WHAM uses a Y-down coordinate system, while the EMDB dataset uses a Y-up one.
    yup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float().to(cfg.DEVICE)
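    # (A 180-degree rotation about the x-axis converts between Y-up and Y-down.)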
    # Helper: numpy array -> float tensor on the target device
    tt = lambda x: torch.from_numpy(x).float().to(cfg.DEVICE)
    accumulator = defaultdict(list)
    bar = Bar('Inference', fill='#', max=len(eval_loader))

    with torch.no_grad():
        for i in range(len(eval_loader)):
            # Original batch
            batch = eval_loader.dataset.load_data(i, False)
            x, inits, features, kwargs, gt = prepare_batch(batch, cfg.DEVICE, cfg.TRAIN.STAGE == 'stage2')
            # Align the initial root orientation to the groundtruth first frame
            cam2yup = batch['R'][0][:1].to(cfg.DEVICE)
            cam2ydown = cam2yup @ yup2ydown
            cam2root = transforms.rotation_6d_to_matrix(inits[1][:, 0, 0])
            ydown2root = cam2ydown.mT @ cam2root
            ydown2root = transforms.matrix_to_rotation_6d(ydown2root)
            kwargs['init_root'][:, 0] = ydown2root
            if cfg.FLIP_EVAL:
                flipped_batch = eval_loader.dataset.load_data(i, True)
                f_x, f_inits, f_features, f_kwargs, _ = prepare_batch(flipped_batch, cfg.DEVICE, cfg.TRAIN.STAGE == 'stage2')

                # Forward pass with flipped input
                flipped_pred = network(f_x, f_inits, f_features, **f_kwargs)

            # Forward pass with normal input
            pred = network(x, inits, features, **kwargs)
            if cfg.FLIP_EVAL:
                # Merge the two predictions
                flipped_pose, flipped_shape = flipped_pred['pose'].squeeze(0), flipped_pred['betas'].squeeze(0)
                pose, shape = pred['pose'].squeeze(0), pred['betas'].squeeze(0)
                flipped_pose, pose = flipped_pose.reshape(-1, 24, 6), pose.reshape(-1, 24, 6)
                avg_pose, avg_shape = avg_preds(pose, shape, flipped_pose, flipped_shape)
                avg_pose = avg_pose.reshape(-1, 144)
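                # The horizontal flip swaps left and right, so the flipped foot-contact
                # channels are reordered with [2, 3, 0, 1] before averaging.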
                avg_contact = (flipped_pred['contact'][..., [2, 3, 0, 1]] + pred['contact']) / 2

                # Refine trajectory with merged prediction
                network.pred_pose = avg_pose.view_as(network.pred_pose)
                network.pred_shape = avg_shape.view_as(network.pred_shape)
                network.pred_contact = avg_contact.view_as(network.pred_contact)
                output = network.forward_smpl(**kwargs)
                pred = network.refine_trajectory(output, return_y_up=True, **kwargs)
            # <======= Prepare groundtruth data
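            # The video ID is formatted '<subject>_<sequence>' with a two-character subject (e.g., 'P0')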
            subj, seq = batch['vid'][:2], batch['vid'][3:]
            annot_pth = glob(osp.join(_C.PATHS.EMDB_PTH, subj, seq, '*_data.pkl'))[0]
            annot = pickle.load(open(annot_pth, 'rb'))

            masks = annot['good_frames_mask']
            gender = annot['gender']
            poses_body = annot['smpl']['poses_body']
            poses_root = annot['smpl']['poses_root']
            betas = np.repeat(annot['smpl']['betas'].reshape((1, -1)), repeats=annot['n_frames'], axis=0)
            trans = annot['smpl']['trans']
            extrinsics = annot['camera']['extrinsics']
            # Map the root orientation to the camera coordinate frame
            poses_root_cam = transforms.matrix_to_axis_angle(tt(extrinsics[:, :3, :3]) @ transforms.axis_angle_to_matrix(tt(poses_root)))
            # Groundtruth global motion
            target_glob = smpl[gender](body_pose=tt(poses_body), global_orient=tt(poses_root), betas=tt(betas), transl=tt(trans))
            target_j3d_glob = target_glob.joints[:, :24][masks]

            # Groundtruth local motion
            target_cam = smpl[gender](body_pose=tt(poses_body), global_orient=poses_root_cam, betas=tt(betas))
            target_verts_cam = target_cam.vertices[masks]
            target_j3d_cam = target_cam.joints[:, :24][masks]
            # =======>
            # Convert WHAM global orient to Y-up coordinate
            poses_root = pred['poses_root_world'].squeeze(0)
            pred_trans = pred['trans_world'].squeeze(0)
            poses_root = yup2ydown.mT @ poses_root
            pred_trans = (yup2ydown.mT @ pred_trans.unsqueeze(-1)).squeeze(-1)
            # <======= Build predicted motion
            # Predicted global motion
            pred_glob = smpl['neutral'](body_pose=pred['poses_body'], global_orient=poses_root.unsqueeze(1), betas=pred['betas'].squeeze(0), transl=pred_trans, pose2rot=False)
            pred_j3d_glob = pred_glob.joints[:, :24]

            # Predicted local motion
            pred_cam = smpl['neutral'](body_pose=pred['poses_body'], global_orient=pred['poses_root_cam'], betas=pred['betas'].squeeze(0), pose2rot=False)
            pred_verts_cam = pred_cam.vertices
            pred_j3d_cam = pred_cam.joints[:, :24]
            # =======>
            # <======= Evaluation on the local motion
            pred_j3d_cam, target_j3d_cam, pred_verts_cam, target_verts_cam = batch_align_by_pelvis(
                [pred_j3d_cam, target_j3d_cam, pred_verts_cam, target_verts_cam], pelvis_idxs
            )
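            # Procrustes-align the predicted joints to the target for PA-MPJPE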
            S1_hat = batch_compute_similarity_transform_torch(pred_j3d_cam, target_j3d_cam)
            pa_mpjpe = torch.sqrt(((S1_hat - target_j3d_cam) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * m2mm
            mpjpe = torch.sqrt(((pred_j3d_cam - target_j3d_cam) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * m2mm
            pve = torch.sqrt(((pred_verts_cam - target_verts_cam) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * m2mm
            accel = compute_error_accel(joints_pred=pred_j3d_cam.cpu(), joints_gt=target_j3d_cam.cpu())[1:-1]
            accel = accel * (30 ** 2)  # per frame^2 to per s^2 (30-fps sequences)

            summary_string = f'{batch["vid"]} | PA-MPJPE: {pa_mpjpe.mean():.1f} MPJPE: {mpjpe.mean():.1f} PVE: {pve.mean():.1f}'
            bar.suffix = summary_string
            bar.next()
            # =======>
            # <======= Evaluation on the global motion
            chunk_length = 100
            w_mpjpe, wa_mpjpe = [], []
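            # Evaluate in 100-frame chunks: W-MPJPE aligns each chunk to the target
            # at its first frames only; WA-MPJPE aligns the whole chunk globally.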
            for start in range(0, masks.sum(), chunk_length):
                end = min(masks.sum(), start + chunk_length)
                target_j3d = target_j3d_glob[start:end].clone().cpu()
                pred_j3d = pred_j3d_glob[start:end].clone().cpu()

                w_j3d = first_align_joints(target_j3d, pred_j3d)
                wa_j3d = global_align_joints(target_j3d, pred_j3d)

                w_jpe = compute_jpe(target_j3d, w_j3d)
                wa_jpe = compute_jpe(target_j3d, wa_j3d)
                w_mpjpe.append(w_jpe)
                wa_mpjpe.append(wa_jpe)

            w_mpjpe = np.concatenate(w_mpjpe) * m2mm
            wa_mpjpe = np.concatenate(wa_mpjpe) * m2mm

            # Additional metrics
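            # Root translation error (scaled by 1e2), jitter at 30 fps, and foot sliding (in mm)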
            rte = compute_rte(torch.from_numpy(trans[masks]), pred_trans.cpu()) * 1e2
            jitter = compute_jitter(pred_glob, fps=30)
            foot_sliding = compute_foot_sliding(target_glob, pred_glob, masks) * m2mm
            # =======>
            # <======= Accumulate the results over entire sequences
            accumulator['pa_mpjpe'].append(pa_mpjpe)
            accumulator['mpjpe'].append(mpjpe)
            accumulator['pve'].append(pve)
            accumulator['accel'].append(accel)
            accumulator['wa_mpjpe'].append(wa_mpjpe)
            accumulator['w_mpjpe'].append(w_mpjpe)
            accumulator['RTE'].append(rte)
            accumulator['jitter'].append(jitter)
            accumulator['FS'].append(foot_sliding)
            # =======>
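    # Average each metric over all frames of all sequences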
    for k, v in accumulator.items():
        accumulator[k] = np.concatenate(v).mean()

    print('')
    log_str = f'Evaluation on EMDB {args.eval_split}, '
    log_str += ' '.join([f'{k.upper()}: {v:.4f},' for k, v in accumulator.items()])
    logger.info(log_str)

if __name__ == '__main__':
    cfg, cfg_file, args = parse_args(test=True)
    cfg = prepare_output_dir(cfg, cfg_file)
    main(cfg, args)