Commit e01e49325338173592071b44501d91416d6a9072d1040c9d9f5aecf816533bec
- .gitattributes +4 -0
- examples/test16.mov +3 -0
- examples/test17.mov +3 -0
- examples/test18.mov +3 -0
- examples/test19.mov +3 -0
- fetch_demo_data.sh +50 -0
- lib/core/loss.py +438 -0
- lib/core/trainer.py +341 -0
- lib/data/__init__.py +0 -0
- lib/data/__pycache__/__init__.cpython-39.pyc +0 -0
- lib/data/__pycache__/_dataset.cpython-39.pyc +0 -0
- lib/data/_dataset.py +77 -0
- lib/data/dataloader.py +46 -0
- lib/data/datasets/__init__.py +3 -0
- lib/data/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- lib/data/datasets/__pycache__/amass.cpython-39.pyc +0 -0
- lib/data/datasets/__pycache__/bedlam.cpython-39.pyc +0 -0
- lib/data/datasets/__pycache__/dataset2d.cpython-39.pyc +0 -0
- lib/data/datasets/__pycache__/dataset3d.cpython-39.pyc +0 -0
- lib/data/datasets/__pycache__/dataset_custom.cpython-39.pyc +0 -0
- lib/data/datasets/__pycache__/dataset_eval.cpython-39.pyc +0 -0
- lib/data/datasets/__pycache__/mixed_dataset.cpython-39.pyc +0 -0
- lib/data/datasets/__pycache__/videos.cpython-39.pyc +0 -0
- lib/data/datasets/amass.py +173 -0
- lib/data/datasets/bedlam.py +165 -0
- lib/data/datasets/dataset2d.py +140 -0
- lib/data/datasets/dataset3d.py +172 -0
- lib/data/datasets/dataset_custom.py +115 -0
- lib/data/datasets/dataset_eval.py +113 -0
- lib/data/datasets/mixed_dataset.py +61 -0
- lib/data/datasets/videos.py +105 -0
- lib/data/utils/__pycache__/augmentor.cpython-39.pyc +0 -0
- lib/data/utils/__pycache__/normalizer.cpython-39.pyc +0 -0
- lib/data/utils/augmentor.py +292 -0
- lib/data/utils/normalizer.py +105 -0
- lib/data_utils/amass_utils.py +107 -0
- lib/data_utils/emdb_eval_utils.py +189 -0
- lib/data_utils/rich_eval_utils.py +69 -0
- lib/data_utils/threedpw_eval_utils.py +185 -0
- lib/data_utils/threedpw_train_utils.py +146 -0
- lib/eval/eval_utils.py +482 -0
- lib/eval/evaluate_3dpw.py +181 -0
- lib/eval/evaluate_emdb.py +228 -0
- lib/eval/evaluate_rich.py +156 -0
- lib/models/__init__.py +40 -0
- lib/models/__pycache__/__init__.cpython-39.pyc +0 -0
- lib/models/__pycache__/smpl.cpython-39.pyc +0 -0
- lib/models/__pycache__/wham.cpython-39.pyc +0 -0
- lib/models/layers/__init__.py +2 -0
- lib/models/layers/__pycache__/__init__.cpython-39.pyc +0 -0
.gitattributes
CHANGED
@@ -37,3 +37,7 @@ examples/drone_video.mp4 filter=lfs diff=lfs merge=lfs -text
 examples/IMG_9730.mov filter=lfs diff=lfs merge=lfs -text
 examples/IMG_9731.mov filter=lfs diff=lfs merge=lfs -text
 examples/IMG_9732.mov filter=lfs diff=lfs merge=lfs -text
+examples/test16.mov filter=lfs diff=lfs merge=lfs -text
+examples/test17.mov filter=lfs diff=lfs merge=lfs -text
+examples/test18.mov filter=lfs diff=lfs merge=lfs -text
+examples/test19.mov filter=lfs diff=lfs merge=lfs -text
examples/test16.mov
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f068400bf962e732e5517af45397694f84fae0a6592085b9dd3781fdbacaa550
size 1567779
examples/test17.mov
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ce06d8885332fd0b770273010dbd4da20a0867a386dc55925f85198651651253
size 2299497
examples/test18.mov
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:66fc6eb20e1c8525070c8004bed621e0acc2712accace1dbf1eb72fced62bb14
size 2033756
examples/test19.mov
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:878219571dbf0e8ff56f4ba4bf325f90f46a730b57a35a2df91f4f509af616d8
size 1940593
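Note: each of the four .mov entries above is a Git LFS pointer file. Only this three-line stub (spec version, sha256 oid, byte size) is committed; the video bytes themselves live in LFS storage.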
fetch_demo_data.sh
ADDED
@@ -0,0 +1,50 @@
#!/bin/bash
urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; }

# SMPL Neutral model
echo -e "\nYou need to register at https://smplify.is.tue.mpg.de"
read -p "Username (SMPLify):" username
read -p "Password (SMPLify):" password
username=$(urle $username)
password=$(urle $password)

mkdir -p dataset/body_models/smpl
wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplify&resume=1&sfile=mpips_smplify_public_v2.zip' -O './dataset/body_models/smplify.zip' --no-check-certificate --continue
unzip dataset/body_models/smplify.zip -d dataset/body_models/smplify
mv dataset/body_models/smplify/smplify_public/code/models/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl dataset/body_models/smpl/SMPL_NEUTRAL.pkl
rm -rf dataset/body_models/smplify
rm -rf dataset/body_models/smplify.zip

# SMPL Male and Female models
echo -e "\nYou need to register at https://smpl.is.tue.mpg.de"
read -p "Username (SMPL):" username
read -p "Password (SMPL):" password
username=$(urle $username)
password=$(urle $password)

wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smpl&sfile=SMPL_python_v.1.0.0.zip' -O './dataset/body_models/smpl.zip' --no-check-certificate --continue
unzip dataset/body_models/smpl.zip -d dataset/body_models/smpl
mv dataset/body_models/smpl/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl dataset/body_models/smpl/SMPL_FEMALE.pkl
mv dataset/body_models/smpl/smpl/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl dataset/body_models/smpl/SMPL_MALE.pkl
rm -rf dataset/body_models/smpl/smpl
rm -rf dataset/body_models/smpl.zip

# Auxiliary SMPL-related data
wget "https://drive.google.com/uc?id=1pbmzRbWGgae6noDIyQOnohzaVnX_csUZ&export=download&confirm=t" -O 'dataset/body_models.tar.gz'
tar -xvf dataset/body_models.tar.gz -C dataset/
rm -rf dataset/body_models.tar.gz

# Checkpoints
mkdir checkpoints
gdown "https://drive.google.com/uc?id=1i7kt9RlCCCNEW2aYaDWVr-G778JkLNcB&export=download&confirm=t" -O 'checkpoints/wham_vit_w_3dpw.pth.tar'
gdown "https://drive.google.com/uc?id=19qkI-a6xuwob9_RFNSPWf1yWErwVVlks&export=download&confirm=t" -O 'checkpoints/wham_vit_bedlam_w_3dpw.pth.tar'
gdown "https://drive.google.com/uc?id=1J6l8teyZrL0zFzHhzkC7efRhU0ZJ5G9Y&export=download&confirm=t" -O 'checkpoints/hmr2a.ckpt'
gdown "https://drive.google.com/uc?id=1kXTV4EYb-BI3H7J-bkR3Bc4gT9zfnHGT&export=download&confirm=t" -O 'checkpoints/dpvo.pth'
gdown "https://drive.google.com/uc?id=1zJ0KP23tXD42D47cw1Gs7zE2BA_V_ERo&export=download&confirm=t" -O 'checkpoints/yolov8x.pt'
gdown "https://drive.google.com/uc?id=1xyF7F3I7lWtdq82xmEPVQ5zl4HaasBso&export=download&confirm=t" -O 'checkpoints/vitpose-h-multi-coco.pth'

# Demo videos
gdown "https://drive.google.com/uc?id=1KjfODCcOUm_xIMLLR54IcjJtf816Dkc7&export=download&confirm=t" -O 'examples.tar.gz'
tar -xvf examples.tar.gz
rm -rf examples.tar.gz
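For reference, the urle helper above percent-encodes the credentials before they are posted by wget. A minimal Python sketch of the same encoding (not part of this repository; the function name is reused only for illustration):

def urle(s: str) -> str:
    # Mirror the shell helper: leave [a-zA-Z0-9.~-] untouched and
    # percent-encode every other byte as %XX.
    safe = set("abcdefghijklmnopqrstuvwxyz"
               "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
               "0123456789.~-")
    return ''.join(c if c in safe else ''.join(f'%{b:02X}' for b in c.encode())
                   for c in s)

print(urle("p@ss word"))  # p%40ss%20word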
lib/core/loss.py
ADDED
@@ -0,0 +1,438 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
from torch.nn import functional as F

from configs import constants as _C
from lib.utils import transforms
from lib.utils.kp_utils import root_centering

class WHAMLoss(nn.Module):
    def __init__(
        self,
        cfg=None,
        device=None,
    ):
        super(WHAMLoss, self).__init__()

        self.cfg = cfg
        self.n_joints = _C.KEYPOINTS.NUM_JOINTS
        self.criterion = nn.MSELoss()
        self.criterion_noreduce = nn.MSELoss(reduction='none')

        self.pose_loss_weight = cfg.LOSS.POSE_LOSS_WEIGHT
        self.shape_loss_weight = cfg.LOSS.SHAPE_LOSS_WEIGHT
        self.keypoint_2d_loss_weight = cfg.LOSS.JOINT2D_LOSS_WEIGHT
        self.keypoint_3d_loss_weight = cfg.LOSS.JOINT3D_LOSS_WEIGHT
        self.cascaded_loss_weight = cfg.LOSS.CASCADED_LOSS_WEIGHT
        self.vertices_loss_weight = cfg.LOSS.VERTS3D_LOSS_WEIGHT
        self.contact_loss_weight = cfg.LOSS.CONTACT_LOSS_WEIGHT
        self.root_vel_loss_weight = cfg.LOSS.ROOT_VEL_LOSS_WEIGHT
        self.root_pose_loss_weight = cfg.LOSS.ROOT_POSE_LOSS_WEIGHT
        self.sliding_loss_weight = cfg.LOSS.SLIDING_LOSS_WEIGHT
        self.camera_loss_weight = cfg.LOSS.CAMERA_LOSS_WEIGHT
        self.loss_weight = cfg.LOSS.LOSS_WEIGHT

        kp_weights = [
            0.5, 0.5, 0.5, 0.5, 0.5,        # Face
            1.5, 1.5, 4, 4, 4, 4,           # Arms
            1.5, 1.5, 4, 4, 4, 4,           # Legs
            4, 4, 1.5, 1.5, 4, 4,           # Legs
            4, 4, 1.5, 1.5, 4, 4,           # Arms
            0.5, 0.5                        # Head
        ]

        theta_weights = [
            0.1, 1.0, 1.0, 1.0, 1.0,            # pelvis, lhip, rhip, spine1, lknee
            1.0, 1.0, 1.0, 1.0, 1.0,            # rknee, spine2, lankle, rankle, spine3
            0.1, 0.1,                           # Foot
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0,       # neck, lisldr, risldr, head, losldr, rosldr
            1.0, 1.0, 1.0, 1.0,                 # lelbow, relbow, lwrist, rwrist
            0.1, 0.1,                           # Hand
        ]
        self.theta_weights = torch.tensor([[theta_weights]]).float().to(device)
        self.theta_weights /= self.theta_weights.mean()
        self.kp_weights = torch.tensor([kp_weights]).float().to(device)

        self.epoch = -1
        self.step()

    def step(self):
        self.epoch += 1
        self.skip_camera_loss = self.epoch < self.cfg.LOSS.CAMERA_LOSS_SKIP_EPOCH

    def forward(self, pred, gt):

        loss = 0.0
        b, f = gt['kp3d'].shape[:2]

        # <======= Predictions and Groundtruths
        pred_betas = pred['betas']
        pred_pose = pred['pose'].reshape(b, f, -1, 6)
        pred_kp3d_nn = pred['kp3d_nn']
        pred_kp3d_smpl = root_centering(pred['kp3d'].reshape(b, f, -1, 3))
        pred_full_kp2d = pred['full_kp2d']
        pred_weak_kp2d = pred['weak_kp2d']
        pred_contact = pred['contact']
        pred_vel_root = pred['vel_root']
        pred_pose_root = pred['poses_root_r6d'][:, 1:]
        pred_vel_root_ref = pred['vel_root_refined']
        pred_pose_root_ref = pred['poses_root_r6d_refined'][:, 1:]
        pred_cam_r = transforms.matrix_to_rotation_6d(pred['R'])

        gt_betas = gt['betas']
        gt_pose = gt['pose']
        gt_kp3d = root_centering(gt['kp3d'])
        gt_full_kp2d = gt['full_kp2d']
        gt_weak_kp2d = gt['weak_kp2d']
        gt_contact = gt['contact']
        gt_vel_root = gt['vel_root']
        gt_pose_root = gt['pose_root'][:, 1:]
        gt_cam_angvel = gt['cam_angvel']
        gt_cam_r = transforms.matrix_to_rotation_6d(gt['R'][:, 1:])
        bbox = gt['bbox']
        # =======>

        loss_keypoints_full = full_projected_keypoint_loss(
            pred_full_kp2d,
            gt_full_kp2d,
            bbox,
            self.kp_weights,
            criterion=self.criterion_noreduce,
        )

        loss_keypoints_weak = weak_projected_keypoint_loss(
            pred_weak_kp2d,
            gt_weak_kp2d,
            self.kp_weights,
            criterion=self.criterion_noreduce
        )

        # Compute 3D keypoint loss
        loss_keypoints_3d_nn = keypoint_3d_loss(
            pred_kp3d_nn,
            gt_kp3d[:, :, :self.n_joints],
            self.kp_weights[:, :self.n_joints],
            criterion=self.criterion_noreduce,
        )

        loss_keypoints_3d_smpl = keypoint_3d_loss(
            pred_kp3d_smpl,
            gt_kp3d,
            self.kp_weights,
            criterion=self.criterion_noreduce,
        )

        loss_cascaded = keypoint_3d_loss(
            pred_kp3d_nn,
            torch.cat((pred_kp3d_smpl[:, :, :self.n_joints], gt_kp3d[:, :, :self.n_joints, -1:]), dim=-1),
            self.kp_weights[:, :self.n_joints] * 0.5,
            criterion=self.criterion_noreduce,
        )

        loss_vertices = vertices_loss(
            pred['verts_cam'],
            gt['verts'],
            gt['has_verts'],
            criterion=self.criterion_noreduce,
        )

        # Compute loss on SMPL parameters
        smpl_mask = gt['has_smpl']
        loss_regr_pose, loss_regr_betas = smpl_losses(
            pred_pose,
            pred_betas,
            gt_pose,
            gt_betas,
            self.theta_weights,
            smpl_mask,
            criterion=self.criterion_noreduce
        )

        # Compute loss on foot contact
        loss_contact = contact_loss(
            pred_contact,
            gt_contact,
            self.criterion_noreduce
        )

        # Compute loss on root velocity and angular velocity
        loss_vel_root, loss_pose_root = root_loss(
            pred_vel_root,
            pred_pose_root,
            gt_vel_root,
            gt_pose_root,
            gt_contact,
            self.criterion_noreduce
        )

        # Root loss after trajectory refinement
        loss_vel_root_ref, loss_pose_root_ref = root_loss(
            pred_vel_root_ref,
            pred_pose_root_ref,
            gt_vel_root,
            gt_pose_root,
            gt_contact,
            self.criterion_noreduce
        )

        # Camera prediction loss
        loss_camera = camera_loss(
            pred_cam_r,
            gt_cam_r,
            gt_cam_angvel[:, 1:],
            gt['has_traj'],
            self.criterion_noreduce,
            self.skip_camera_loss
        )

        # Foot sliding loss
        loss_sliding = sliding_loss(
            pred['feet'],
            gt_contact,
        )

        # Foot sliding loss after trajectory refinement
        loss_sliding_ref = sliding_loss(
            pred['feet_refined'],
            gt_contact,
        )

        loss_keypoints = loss_keypoints_full + loss_keypoints_weak
        loss_keypoints *= self.keypoint_2d_loss_weight
        loss_keypoints_3d_smpl *= self.keypoint_3d_loss_weight
        loss_keypoints_3d_nn *= self.keypoint_3d_loss_weight
        loss_cascaded *= self.cascaded_loss_weight
        loss_vertices *= self.vertices_loss_weight
        loss_contact *= self.contact_loss_weight
        loss_root = loss_vel_root * self.root_vel_loss_weight + loss_pose_root * self.root_pose_loss_weight
        loss_root_ref = loss_vel_root_ref * self.root_vel_loss_weight + loss_pose_root_ref * self.root_pose_loss_weight

        loss_regr_pose *= self.pose_loss_weight
        loss_regr_betas *= self.shape_loss_weight

        loss_sliding *= self.sliding_loss_weight
        loss_camera *= self.camera_loss_weight
        loss_sliding_ref *= self.sliding_loss_weight

        loss_dict = {
            'pose': loss_regr_pose * self.loss_weight,
            'betas': loss_regr_betas * self.loss_weight,
            '2d': loss_keypoints * self.loss_weight,
            '3d': loss_keypoints_3d_smpl * self.loss_weight,
            '3d_nn': loss_keypoints_3d_nn * self.loss_weight,
            'casc': loss_cascaded * self.loss_weight,
            'v3d': loss_vertices * self.loss_weight,
            'contact': loss_contact * self.loss_weight,
            'root': loss_root * self.loss_weight,
            'root_ref': loss_root_ref * self.loss_weight,
            'sliding': loss_sliding * self.loss_weight,
            'camera': loss_camera * self.loss_weight,
            'sliding_ref': loss_sliding_ref * self.loss_weight,
        }

        loss = sum(loss for loss in loss_dict.values())

        return loss, loss_dict


def root_loss(
    pred_vel_root,
    pred_pose_root,
    gt_vel_root,
    gt_pose_root,
    stationary,
    criterion
):

    mask_r = (gt_pose_root != 0.0).all(dim=-1).all(dim=-1)
    mask_v = (gt_vel_root != 0.0).all(dim=-1).all(dim=-1)
    mask_s = (stationary != -1).any(dim=1).any(dim=1)
    mask_v = mask_v * mask_s

    if mask_r.any():
        loss_r = criterion(pred_pose_root, gt_pose_root)[mask_r].mean()
    else:
        loss_r = torch.FloatTensor(1).fill_(0.).to(gt_pose_root.device)[0]

    if mask_v.any():
        loss_v = 0
        T = gt_vel_root.shape[0]
        ws_list = [1, 3, 9, 27]
        for ws in ws_list:
            tmp_v = 0
            for m in range(T//ws):
                cumulative_v = torch.sum(pred_vel_root[:, m:(m+1)*ws] - gt_vel_root[:, m:(m+1)*ws], dim=1)
                tmp_v += torch.norm(cumulative_v, dim=-1)
            loss_v += tmp_v
        loss_v = loss_v[mask_v].mean()
    else:
        loss_v = torch.FloatTensor(1).fill_(0.).to(gt_vel_root.device)[0]

    return loss_v, loss_r


def contact_loss(
    pred_stationary,
    gt_stationary,
    criterion,
):

    mask = gt_stationary != -1
    if mask.any():
        loss = criterion(pred_stationary, gt_stationary)[mask].mean()
    else:
        loss = torch.FloatTensor(1).fill_(0.).to(gt_stationary.device)[0]
    return loss


def full_projected_keypoint_loss(
    pred_keypoints_2d,
    gt_keypoints_2d,
    bbox,
    weight,
    criterion,
):

    scale = bbox[..., 2:] * 200.
    conf = gt_keypoints_2d[..., -1]

    if (conf > 0).any():
        loss = torch.mean(
            weight * (conf * torch.norm(pred_keypoints_2d - gt_keypoints_2d[..., :2], dim=-1)) / scale,
            dim=1).mean() * conf.mean()
    else:
        loss = torch.FloatTensor(1).fill_(0.).to(gt_keypoints_2d.device)[0]
    return loss


def weak_projected_keypoint_loss(
    pred_keypoints_2d,
    gt_keypoints_2d,
    weight,
    criterion,
):

    conf = gt_keypoints_2d[..., -1]
    if (conf > 0).any():
        loss = torch.mean(
            weight * (conf * torch.norm(pred_keypoints_2d - gt_keypoints_2d[..., :2], dim=-1)),
            dim=1).mean() * conf.mean() * 5
    else:
        loss = torch.FloatTensor(1).fill_(0.).to(gt_keypoints_2d.device)[0]
    return loss


def keypoint_3d_loss(
    pred_keypoints_3d,
    gt_keypoints_3d,
    weight,
    criterion,
):

    conf = gt_keypoints_3d[..., -1]
    if (conf > 0).any():
        if weight.shape[-2] > 17:
            pred_keypoints_3d[..., -14:] = pred_keypoints_3d[..., -14:] - pred_keypoints_3d[..., -14:].mean(dim=-2, keepdims=True)
            gt_keypoints_3d[..., -14:] = gt_keypoints_3d[..., -14:] - gt_keypoints_3d[..., -14:].mean(dim=-2, keepdims=True)

        loss = torch.mean(
            weight * (conf * torch.norm(pred_keypoints_3d - gt_keypoints_3d[..., :3], dim=-1)),
            dim=1).mean() * conf.mean()
    else:
        loss = torch.FloatTensor(1).fill_(0.).to(gt_keypoints_3d.device)[0]
    return loss


def vertices_loss(
    pred_verts,
    gt_verts,
    mask,
    criterion,
):

    if mask.sum() > 0:
        # Align
        pred_verts = pred_verts.view_as(gt_verts)
        pred_verts = pred_verts - pred_verts.mean(-2, True)
        gt_verts = gt_verts - gt_verts.mean(-2, True)

        # loss = criterion(pred_verts, gt_verts).mean() * mask.float().mean()
        # loss = torch.mean(
        #     (torch.norm(pred_verts - gt_verts, dim=-1)[mask]),
        #     dim=1).mean() * mask.float().mean()
        loss = torch.mean(
            (torch.norm(pred_verts - gt_verts, p=1, dim=-1)[mask]),
            dim=1).mean() * mask.float().mean()
    else:
        loss = torch.FloatTensor(1).fill_(0.).to(gt_verts.device)[0]
    return loss


def smpl_losses(
    pred_pose,
    pred_betas,
    gt_pose,
    gt_betas,
    weight,
    mask,
    criterion,
):

    if mask.any().item():
        loss_regr_pose = torch.mean(
            weight * torch.square(pred_pose - gt_pose)[mask].mean(-1)
        ) * mask.float().mean()
        loss_regr_betas = F.mse_loss(pred_betas, gt_betas, reduction='none')[mask].mean() * mask.float().mean()
    else:
        loss_regr_pose = torch.FloatTensor(1).fill_(0.).to(gt_pose.device)[0]
        loss_regr_betas = torch.FloatTensor(1).fill_(0.).to(gt_pose.device)[0]

    return loss_regr_pose, loss_regr_betas


def camera_loss(
    pred_cam_r,
    gt_cam_r,
    cam_angvel,
    mask,
    criterion,
    skip
):
    # mask = (gt_cam_r != 0.0).all(dim=-1).all(dim=-1)

    if mask.any() and not skip:
        # Camera pose loss in 6D representation
        loss_r = criterion(pred_cam_r, gt_cam_r)[mask].mean()

        # Reconstruct camera angular velocity and compute reconstruction loss
        pred_R = transforms.rotation_6d_to_matrix(pred_cam_r)
        cam_angvel_from_R = transforms.matrix_to_rotation_6d(pred_R[:, :-1] @ pred_R[:, 1:].transpose(-1, -2))
        cam_angvel_from_R = (cam_angvel_from_R - torch.tensor([[[1, 0, 0, 0, 1, 0]]]).to(cam_angvel)) * 30
        loss_a = criterion(cam_angvel, cam_angvel_from_R)[mask].mean()

        loss = loss_r + loss_a
    else:
        loss = torch.FloatTensor(1).fill_(0.).to(gt_cam_r.device)[0]

    return loss


def sliding_loss(
    foot_position,
    contact_prob,
):
    """ Compute foot skate loss when the foot is assumed to be in contact with the ground

    foot_position: 3D foot (heel and toe) position, torch.Tensor (B, F, 4, 3)
    contact_prob: contact probability of foot (heel and toe), torch.Tensor (B, F, 4)
    """

    contact_mask = (contact_prob > 0.5).detach().float()
    foot_velocity = foot_position[:, 1:] - foot_position[:, :-1]
    loss = (torch.norm(foot_velocity, dim=-1) * contact_mask[:, 1:]).mean()
    return loss
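As a quick sanity check, sliding_loss above is self-contained and can be exercised on random tensors with the shapes stated in its docstring. A minimal sketch (synthetic tensors, not real model outputs; assumes lib.core.loss is importable):

import torch
from lib.core.loss import sliding_loss

B, F = 2, 16                                                          # batch size, frames
foot_position = torch.cumsum(0.01 * torch.randn(B, F, 4, 3), dim=1)  # synthetic heel/toe trajectories
contact_prob = torch.rand(B, F, 4)                                    # per-joint contact probabilities

# Penalizes frame-to-frame foot displacement wherever contact_prob > 0.5
print(sliding_loss(foot_position, contact_prob).item())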
lib/core/trainer.py
ADDED
@@ -0,0 +1,341 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

import time
import torch
import shutil
import logging
import numpy as np
import os.path as osp
from progress.bar import Bar

from configs import constants as _C
from lib.utils import transforms
from lib.utils.utils import AverageMeter, prepare_batch
from lib.eval.eval_utils import (
    compute_accel,
    compute_error_accel,
    batch_align_by_pelvis,
    batch_compute_similarity_transform_torch,
)
from lib.models import build_body_model

logger = logging.getLogger(__name__)

class Trainer():
    def __init__(self,
                 data_loaders,
                 network,
                 optimizer,
                 criterion=None,
                 train_stage='syn',
                 start_epoch=0,
                 checkpoint=None,
                 end_epoch=999,
                 lr_scheduler=None,
                 device=None,
                 writer=None,
                 debug=False,
                 resume=False,
                 logdir='output',
                 performance_type='min',
                 summary_iter=1,
                 ):

        self.train_loader, self.valid_loader = data_loaders

        # Model and optimizer
        self.network = network
        self.optimizer = optimizer

        # Training parameters
        self.train_stage = train_stage
        self.start_epoch = start_epoch
        self.end_epoch = end_epoch
        self.criterion = criterion
        self.lr_scheduler = lr_scheduler
        self.device = device
        self.writer = writer
        self.debug = debug
        self.resume = resume
        self.logdir = logdir
        self.summary_iter = summary_iter

        self.performance_type = performance_type
        self.train_global_step = 0
        self.valid_global_step = 0
        self.epoch = 0
        self.best_performance = float('inf') if performance_type == 'min' else -float('inf')
        self.summary_loss_keys = ['pose']

        self.evaluation_accumulators = dict.fromkeys(
            ['pred_j3d', 'target_j3d', 'pve'])  # 'pred_verts', 'target_verts'])

        self.J_regressor_eval = torch.from_numpy(
            np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M)
        )[_C.KEYPOINTS.H36M_TO_J14, :].unsqueeze(0).float().to(device)

        if self.writer is None:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=self.logdir)

        if self.device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        if checkpoint is not None:
            self.load_pretrained(checkpoint)

    def train(self, ):
        # Single epoch training routine

        losses = AverageMeter()
        kp_2d_loss = AverageMeter()
        kp_3d_loss = AverageMeter()

        timer = {
            'data': 0,
            'forward': 0,
            'loss': 0,
            'backward': 0,
            'batch': 0,
        }
        self.network.train()
        start = time.time()
        summary_string = ''

        bar = Bar(f'Epoch {self.epoch + 1}/{self.end_epoch}', fill='#', max=len(self.train_loader))
        for i, batch in enumerate(self.train_loader):

            # <======= Feedforward
            x, inits, features, kwargs, gt = prepare_batch(batch, self.device, self.train_stage=='stage2')
            timer['data'] = time.time() - start
            start = time.time()
            pred = self.network(x, inits, features, **kwargs)
            timer['forward'] = time.time() - start
            start = time.time()
            # =======>

            # <======= Backprop
            loss, loss_dict = self.criterion(pred, gt)
            timer['loss'] = time.time() - start
            start = time.time()

            # Clip gradients
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
            self.optimizer.step()
            # =======>

            # <======= Log training info
            total_loss = loss
            losses.update(total_loss.item(), x.size(0))
            kp_2d_loss.update(loss_dict['2d'].item(), x.size(0))
            kp_3d_loss.update(loss_dict['3d'].item(), x.size(0))

            timer['backward'] = time.time() - start
            timer['batch'] = timer['data'] + timer['forward'] + timer['loss'] + timer['backward']
            start = time.time()

            summary_string = f'({i + 1}/{len(self.train_loader)}) | Total: {bar.elapsed_td} ' \
                             f'| loss: {losses.avg:.2f} | 2d: {kp_2d_loss.avg:.2f} ' \
                             f'| 3d: {kp_3d_loss.avg:.2f} '

            for k, v in loss_dict.items():
                if k in self.summary_loss_keys:
                    summary_string += f' | {k}: {v:.2f}'
                if (i + 1) % self.summary_iter == 0:
                    self.writer.add_scalar('train_loss/'+k, v, global_step=self.train_global_step)

            if (i + 1) % self.summary_iter == 0:
                self.writer.add_scalar('train_loss/loss', total_loss.item(), global_step=self.train_global_step)

            self.train_global_step += 1
            bar.suffix = summary_string
            bar.next(1)

            if torch.isnan(total_loss):
                exit('Nan value in loss, exiting!...')
            # =======>

        logger.info(summary_string)
        bar.finish()

    def validate(self, ):
        self.network.eval()

        start = time.time()
        summary_string = ''
        bar = Bar('Validation', fill='#', max=len(self.valid_loader))

        if self.evaluation_accumulators is not None:
            for k, v in self.evaluation_accumulators.items():
                self.evaluation_accumulators[k] = []

        with torch.no_grad():
            for i, batch in enumerate(self.valid_loader):
                x, inits, features, kwargs, gt = prepare_batch(batch, self.device, self.train_stage=='stage2')

                # <======= Feedforward
                pred = self.network(x, inits, features, **kwargs)

                # 3DPW dataset has groundtruth vertices
                # NOTE: Following SPIN, we compute PVE against ground truth from the gendered SMPL mesh
                smpl = build_body_model(self.device, batch_size=len(pred['verts_cam']), gender=batch['gender'][0])
                gt_output = smpl.get_output(
                    body_pose=transforms.rotation_6d_to_matrix(gt['pose'][0, :, 1:]),
                    global_orient=transforms.rotation_6d_to_matrix(gt['pose'][0, :, :1]),
                    betas=gt['betas'][0],
                    pose2rot=False
                )

                pred_j3d = torch.matmul(self.J_regressor_eval, pred['verts_cam']).cpu()
                target_j3d = torch.matmul(self.J_regressor_eval, gt_output.vertices).cpu()
                pred_verts = pred['verts_cam'].cpu()
                target_verts = gt_output.vertices.cpu()

                pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
                    [pred_j3d, target_j3d, pred_verts, target_verts], [2, 3]
                )

                self.evaluation_accumulators['pred_j3d'].append(pred_j3d.numpy())
                self.evaluation_accumulators['target_j3d'].append(target_j3d.numpy())
                pve = np.sqrt(np.sum((target_verts.numpy() - pred_verts.numpy()) ** 2, axis=-1)).mean(-1) * 1e3
                self.evaluation_accumulators['pve'].append(pve[:, None])
                # =======>

                batch_time = time.time() - start

                summary_string = f'({i + 1}/{len(self.valid_loader)}) | batch: {batch_time * 10.0:.4}ms | ' \
                                 f'Total: {bar.elapsed_td} | ETA: {bar.eta_td:}'

                self.valid_global_step += 1
                bar.suffix = summary_string
                bar.next()

        logger.info(summary_string)

        bar.finish()

    def evaluate(self, ):
        for k, v in self.evaluation_accumulators.items():
            self.evaluation_accumulators[k] = np.vstack(v)

        pred_j3ds = self.evaluation_accumulators['pred_j3d']
        target_j3ds = self.evaluation_accumulators['target_j3d']

        pred_j3ds = torch.from_numpy(pred_j3ds).float()
        target_j3ds = torch.from_numpy(target_j3ds).float()

        print(f'Evaluating on {pred_j3ds.shape[0]} poses...')
        errors = torch.sqrt(((pred_j3ds - target_j3ds) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
        S1_hat = batch_compute_similarity_transform_torch(pred_j3ds, target_j3ds)
        errors_pa = torch.sqrt(((S1_hat - target_j3ds) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()

        m2mm = 1000
        accel = np.mean(compute_accel(pred_j3ds)) * m2mm
        accel_err = np.mean(compute_error_accel(joints_pred=pred_j3ds, joints_gt=target_j3ds)) * m2mm
        mpjpe = np.mean(errors) * m2mm
        pa_mpjpe = np.mean(errors_pa) * m2mm

        eval_dict = {
            'mpjpe': mpjpe,
            'pa-mpjpe': pa_mpjpe,
            'accel': accel,
            'accel_err': accel_err
        }

        if 'pred_verts' in self.evaluation_accumulators.keys():
            eval_dict.update({'pve': self.evaluation_accumulators['pve'].mean()})

        log_str = f'Epoch {self.epoch}, '
        log_str += ' '.join([f'{k.upper()}: {v:.4f},' for k, v in eval_dict.items()])
        logger.info(log_str)

        for k, v in eval_dict.items():
            self.writer.add_scalar(f'error/{k}', v, global_step=self.epoch)

        # return (mpjpe + pa_mpjpe) / 2.
        return pa_mpjpe

    def save_model(self, performance, epoch):
        save_dict = {
            'epoch': epoch,
            'model': self.network.state_dict(),
            'performance': performance,
            'optimizer': self.optimizer.state_dict(),
        }

        filename = osp.join(self.logdir, 'checkpoint.pth.tar')
        torch.save(save_dict, filename)

        if self.performance_type == 'min':
            is_best = performance < self.best_performance
        else:
            is_best = performance > self.best_performance

        if is_best:
            logger.info('Best performance achieved, saving it!')
            self.best_performance = performance
            shutil.copyfile(filename, osp.join(self.logdir, 'model_best.pth.tar'))

            with open(osp.join(self.logdir, 'best.txt'), 'w') as f:
                f.write(str(float(performance)))

    def fit(self):
        for epoch in range(self.start_epoch, self.end_epoch):
            self.epoch = epoch
            self.train()
            self.validate()
            performance = self.evaluate()

            self.criterion.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            # log the learning rate
            for param_group in self.optimizer.param_groups[:2]:
                print(f'Learning rate {param_group["lr"]}')
                self.writer.add_scalar('lr', param_group['lr'], global_step=self.epoch)

            logger.info(f'Epoch {epoch+1} performance: {performance:.4f}')

            self.save_model(performance, epoch)
            self.train_loader.dataset.prepare_video_batch()

        self.writer.close()

    def load_pretrained(self, model_path):
        if osp.isfile(model_path):
            checkpoint = torch.load(model_path)

            # network
            ignore_keys = ['smpl.body_pose', 'smpl.betas', 'smpl.global_orient', 'smpl.J_regressor_extra', 'smpl.J_regressor_eval']
            ignore_keys2 = [k for k in checkpoint['model'].keys() if 'integrator' in k]
            ignore_keys.extend(ignore_keys2)
            model_state_dict = {k: v for k, v in checkpoint['model'].items() if k not in ignore_keys}
            model_state_dict = {k: v for k, v in model_state_dict.items() if k in self.network.state_dict().keys()}
            self.network.load_state_dict(model_state_dict, strict=False)

            if self.resume:
                self.start_epoch = checkpoint['epoch']
                self.best_performance = checkpoint['performance']
                self.optimizer.load_state_dict(checkpoint['optimizer'])

            logger.info(f"=> loaded checkpoint '{model_path}' "
                        f"(epoch {self.start_epoch}, performance {self.best_performance})")
        else:
            logger.info(f"=> no checkpoint found at '{model_path}'")
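The checkpoint written by save_model is a plain dict, so it can be inspected or reloaded outside the trainer. A small sketch (the path is illustrative, and network/optimizer stand in for your own instances):

import torch

# Keys mirror the save_dict layout in Trainer.save_model above.
ckpt = torch.load('output/checkpoint.pth.tar', map_location='cpu')
print(ckpt['epoch'], ckpt['performance'])

# network.load_state_dict(ckpt['model'])        # hypothetical model instance
# optimizer.load_state_dict(ckpt['optimizer'])  # hypothetical optimizer instance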
lib/data/__init__.py
ADDED
File without changes
lib/data/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (184 Bytes)
lib/data/__pycache__/_dataset.cpython-39.pyc
ADDED
Binary file (3.18 kB)
lib/data/_dataset.py
ADDED
@@ -0,0 +1,77 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import numpy as np
from skimage.util.shape import view_as_windows

from configs import constants as _C
from .utils.normalizer import Normalizer
from ..utils.imutils import transform

class BaseDataset(torch.utils.data.Dataset):
    def __init__(self, cfg, training=True):
        super(BaseDataset, self).__init__()
        self.epoch = 0
        self.training = training
        self.n_joints = _C.KEYPOINTS.NUM_JOINTS
        self.n_frames = cfg.DATASET.SEQLEN + 1
        self.keypoints_normalizer = Normalizer(cfg)

    def prepare_video_batch(self):
        r = self.epoch % 4

        self.video_indices = []
        vid_name = self.labels['vid']
        if isinstance(vid_name, torch.Tensor): vid_name = vid_name.numpy()
        video_names_unique, group = np.unique(
            vid_name, return_index=True)
        perm = np.argsort(group)
        group_perm = group[perm]
        indices = np.split(
            np.arange(0, self.labels['vid'].shape[0]), group_perm[1:]
        )
        for idx in range(len(video_names_unique)):
            indexes = indices[idx]
            if indexes.shape[0] < self.n_frames: continue
            chunks = view_as_windows(
                indexes, (self.n_frames), step=self.n_frames // 4
            )
            start_finish = chunks[r::4, (0, -1)].tolist()
            self.video_indices += start_finish

        self.epoch += 1

    def __len__(self):
        if self.training:
            return len(self.video_indices)
        else:
            return len(self.labels['kp2d'])

    def __getitem__(self, index):
        return self.get_single_sequence(index)

    def get_single_sequence(self, index):
        NotImplementedError('get_single_sequence is not implemented')

    def get_naive_intrinsics(self, res):
        # Assume 45 degree FOV
        img_w, img_h = res
        self.focal_length = (img_w * img_w + img_h * img_h) ** 0.5
        self.cam_intrinsics = torch.eye(3).repeat(1, 1, 1).float()
        self.cam_intrinsics[:, 0, 0] = self.focal_length
        self.cam_intrinsics[:, 1, 1] = self.focal_length
        self.cam_intrinsics[:, 0, 2] = img_w/2.
        self.cam_intrinsics[:, 1, 2] = img_h/2.

    def j2d_processing(self, kp, bbox):
        center = bbox[..., :2]
        scale = bbox[..., -1:]
        nparts = kp.shape[0]
        for i in range(nparts):
            kp[i, 0:2] = transform(kp[i, 0:2] + 1, center, scale,
                                   [224, 224])
        kp[:, :2] = 2. * kp[:, :2] / 224 - 1.
        kp = kp.astype('float32')
        return kp
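prepare_video_batch slides a window of n_frames over each video with a stride of n_frames // 4 and keeps only the (start, end) index pairs, while the epoch % 4 offset rotates which windows are used each epoch. A standalone sketch of that windowing on a toy index array (values below assume a 40-frame video and n_frames = 8):

import numpy as np
from skimage.util.shape import view_as_windows

indexes = np.arange(40)   # frame indices of one toy video
n_frames = 8
r = 1                     # epoch % 4 offset

chunks = view_as_windows(indexes, (n_frames,), step=n_frames // 4)
start_finish = chunks[r::4, (0, -1)].tolist()
print(start_finish)       # [[2, 9], [10, 17], [18, 25], [26, 33]]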
lib/data/dataloader.py
ADDED
@@ -0,0 +1,46 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch

from .datasets import EvalDataset, DataFactory
from ..utils.data_utils import make_collate_fn


def setup_eval_dataloader(cfg, data, split='test', backbone=None):
    if backbone is None:
        backbone = cfg.MODEL.BACKBONE

    dataset = EvalDataset(cfg, data, split, backbone)
    dloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        num_workers=0,
        shuffle=False,
        pin_memory=True,
        collate_fn=make_collate_fn()
    )
    return dloader


def setup_train_dataloader(cfg, ):
    n_workers = 0 if cfg.DEBUG else cfg.NUM_WORKERS

    train_dataset = DataFactory(cfg, cfg.TRAIN.STAGE)
    dloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=n_workers,
        shuffle=True,
        pin_memory=True,
        collate_fn=make_collate_fn()
    )
    return dloader


def setup_dloaders(cfg, dset='3dpw', split='val'):
    test_dloader = setup_eval_dataloader(cfg, dset, split, cfg.MODEL.BACKBONE)
    train_dloader = setup_train_dataloader(cfg)

    return train_dloader, test_dloader
lib/data/datasets/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .dataset_eval import EvalDataset
from .dataset_custom import CustomDataset
from .mixed_dataset import DataFactory
lib/data/datasets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (342 Bytes)
lib/data/datasets/__pycache__/amass.cpython-39.pyc
ADDED
Binary file (5.66 kB)
lib/data/datasets/__pycache__/bedlam.cpython-39.pyc
ADDED
Binary file (5.39 kB)
lib/data/datasets/__pycache__/dataset2d.cpython-39.pyc
ADDED
Binary file (4.14 kB)
lib/data/datasets/__pycache__/dataset3d.cpython-39.pyc
ADDED
Binary file (5.04 kB)
lib/data/datasets/__pycache__/dataset_custom.cpython-39.pyc
ADDED
Binary file (3.57 kB)
lib/data/datasets/__pycache__/dataset_eval.cpython-39.pyc
ADDED
Binary file (3.93 kB)
lib/data/datasets/__pycache__/mixed_dataset.cpython-39.pyc
ADDED
Binary file (2.99 kB)
lib/data/datasets/__pycache__/videos.cpython-39.pyc
ADDED
Binary file (4.13 kB)
lib/data/datasets/amass.py
ADDED
@@ -0,0 +1,173 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import joblib
from lib.utils import transforms

from configs import constants as _C

from ..utils.augmentor import *
from .._dataset import BaseDataset
from ...models import build_body_model
from ...utils import data_utils as d_utils
from ...utils.kp_utils import root_centering



def compute_contact_label(feet, thr=1e-2, alpha=5):
    vel = torch.zeros_like(feet[..., 0])
    label = torch.zeros_like(feet[..., 0])

    vel[1:-1] = (feet[2:] - feet[:-2]).norm(dim=-1) / 2.0
    vel[0] = vel[1].clone()
    vel[-1] = vel[-2].clone()

    label = 1 / (1 + torch.exp(alpha * (thr ** -1) * (vel - thr)))
    return label


class AMASSDataset(BaseDataset):
    def __init__(self, cfg):
        label_pth = _C.PATHS.AMASS_LABEL
        super(AMASSDataset, self).__init__(cfg, training=True)

        self.supervise_pose = cfg.TRAIN.STAGE == 'stage1'
        self.labels = joblib.load(label_pth)
        self.SequenceAugmentor = SequenceAugmentor(cfg.DATASET.SEQLEN + 1)

        # Load augmentators
        self.VideoAugmentor = VideoAugmentor(cfg)
        self.SMPLAugmentor = SMPLAugmentor(cfg)
        self.d_img_feature = _C.IMG_FEAT_DIM[cfg.MODEL.BACKBONE]

        self.n_frames = int(cfg.DATASET.SEQLEN * self.SequenceAugmentor.l_factor) + 1
        self.smpl = build_body_model('cpu', self.n_frames)
        self.prepare_video_batch()

        # Naive assumption of image intrinsics
        self.img_w, self.img_h = 1000, 1000
        self.get_naive_intrinsics((self.img_w, self.img_h))

        self.CameraAugmentor = CameraAugmentor(cfg.DATASET.SEQLEN + 1, self.img_w, self.img_h, self.focal_length)


    @property
    def __name__(self, ):
        return 'AMASS'

    def get_input(self, target):
        gt_kp3d = target['kp3d']
        inpt_kp3d = self.VideoAugmentor(gt_kp3d[:, :self.n_joints, :-1].clone())
        kp2d = perspective_projection(inpt_kp3d, self.cam_intrinsics)
        mask = self.VideoAugmentor.get_mask()
        kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224)

        target['bbox'] = bbox[1:]
        target['kp2d'] = kp2d
        target['mask'] = mask[1:]
        target['features'] = torch.zeros((self.SMPLAugmentor.n_frames, self.d_img_feature)).float()
        return target

    def get_groundtruth(self, target):
        # GT 1. Joints
        gt_kp3d = target['kp3d']
        gt_kp2d = perspective_projection(gt_kp3d, self.cam_intrinsics)
        target['kp3d'] = torch.cat((gt_kp3d, torch.ones_like(gt_kp3d[..., :1]) * float(self.supervise_pose)), dim=-1)
        target['full_kp2d'] = torch.cat((gt_kp2d, torch.ones_like(gt_kp2d[..., :1]) * float(self.supervise_pose)), dim=-1)[1:]
        target['weak_kp2d'] = torch.zeros_like(target['full_kp2d'])
        target['init_kp3d'] = root_centering(gt_kp3d[:1, :self.n_joints].clone()).reshape(1, -1)
        target['verts'] = torch.zeros((self.SMPLAugmentor.n_frames, 6890, 3)).float()

        # GT 2. Root pose
        vel_world = (target['transl'][1:] - target['transl'][:-1])
        pose_root = target['pose_root'].clone()
        vel_root = (pose_root[:-1].transpose(-1, -2) @ vel_world.unsqueeze(-1)).squeeze(-1)
        target['vel_root'] = vel_root.clone()
        target['pose_root'] = transforms.matrix_to_rotation_6d(pose_root)
        target['init_root'] = target['pose_root'][:1].clone()

        # GT 3. Foot contact
        contact = compute_contact_label(target['feet'])
        if 'tread' in target['vid']:
            target['contact'] = torch.ones_like(contact) * (-1)
        else:
            target['contact'] = contact

        return target

    def forward_smpl(self, target):
        output = self.smpl.get_output(
            body_pose=torch.cat((target['init_pose'][:, 1:], target['pose'][1:, 1:])),
            global_orient=torch.cat((target['init_pose'][:, :1], target['pose'][1:, :1])),
            betas=target['betas'],
            pose2rot=False)

        target['transl'] = target['transl'] - output.offset
        target['transl'] = target['transl'] - target['transl'][0]
        target['kp3d'] = output.joints
        target['feet'] = output.feet[1:] + target['transl'][1:].unsqueeze(-2)

        return target

    def augment_data(self, target):
        # Augmentation 1. SMPL params augmentation
        target = self.SMPLAugmentor(target)

        # Augmentation 2. Sequence speed augmentation
        target = self.SequenceAugmentor(target)

        # Get world-coordinate SMPL
        target = self.forward_smpl(target)

        # Augmentation 3. Virtual camera generation
        target = self.CameraAugmentor(target)

        return target

    def load_amass(self, index, target):
        start_index, end_index = self.video_indices[index]

        # Load AMASS labels
        pose = torch.from_numpy(self.labels['pose'][start_index:end_index+1].copy())
        pose = transforms.axis_angle_to_matrix(pose.reshape(-1, 24, 3))
        transl = torch.from_numpy(self.labels['transl'][start_index:end_index+1].copy())
        betas = torch.from_numpy(self.labels['betas'][start_index:end_index+1].copy())

        # Stack GT
        target.update({'vid': self.labels['vid'][start_index],
                       'pose': pose,
                       'transl': transl,
                       'betas': betas})

        return target

    def get_single_sequence(self, index):
        target = {'res': torch.tensor([self.img_w, self.img_h]).float(),
                  'cam_intrinsics': self.cam_intrinsics.clone(),
                  'has_full_screen': torch.tensor(True),
                  'has_smpl': torch.tensor(self.supervise_pose),
                  'has_traj': torch.tensor(True),
                  'has_verts': torch.tensor(False),}

        target = self.load_amass(index, target)
        target = self.augment_data(target)
        target = self.get_groundtruth(target)
        target = self.get_input(target)

        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)

        return target


def perspective_projection(points, cam_intrinsics, rotation=None, translation=None):
    K = cam_intrinsics
    if rotation is not None:
        points = torch.matmul(rotation, points.transpose(1, 2)).transpose(1, 2)
    if translation is not None:
        points = points + translation.unsqueeze(1)
    projected_points = points / points[:, :, -1].unsqueeze(-1)
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points.float())
    return projected_points[:, :, :-1]
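compute_contact_label above maps per-frame foot speed through a sigmoid centered at thr: slow feet produce labels near 1 (in contact), fast feet near 0. A quick sketch on a synthetic random-walk trajectory (purely illustrative; assumes the module is importable):

import torch
from lib.data.datasets.amass import compute_contact_label

feet = torch.cumsum(0.005 * torch.randn(30, 4, 3), dim=0)  # (frames, heel/toe joints, xyz)
label = compute_contact_label(feet)                        # (frames, 4), values in (0, 1)
print(label.min().item(), label.max().item())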
lib/data/datasets/bedlam.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import joblib
from lib.utils import transforms

from configs import constants as _C

from .amass import compute_contact_label, perspective_projection
from ..utils.augmentor import *
from .._dataset import BaseDataset
from ...models import build_body_model
from ...utils import data_utils as d_utils
from ...utils.kp_utils import root_centering


class BEDLAMDataset(BaseDataset):
    def __init__(self, cfg):
        label_pth = _C.PATHS.BEDLAM_LABEL.replace('backbone', cfg.MODEL.BACKBONE)
        super(BEDLAMDataset, self).__init__(cfg, training=True)

        self.labels = joblib.load(label_pth)

        self.VideoAugmentor = VideoAugmentor(cfg)
        self.SMPLAugmentor = SMPLAugmentor(cfg, False)

        self.smpl = build_body_model('cpu', self.n_frames)
        self.prepare_video_batch()

    @property
    def __name__(self, ):
        return 'BEDLAM'

    def get_inputs(self, index, target, vis_thr=0.6):
        start_index, end_index = self.video_indices[index]

        bbox = self.labels['bbox'][start_index:end_index+1].clone()
        bbox[:, 2] = bbox[:, 2] / 200

        gt_kp3d = target['kp3d']
        inpt_kp3d = self.VideoAugmentor(gt_kp3d[:, :self.n_joints, :-1].clone())
        # kp2d = perspective_projection(inpt_kp3d, target['K'])
        kp2d = perspective_projection(inpt_kp3d, self.cam_intrinsics)
        mask = self.VideoAugmentor.get_mask()
        # kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224, bbox)
        kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224)

        target['bbox'] = bbox[1:]
        target['kp2d'] = kp2d
        target['mask'] = mask[1:]

        # Image features
        target['features'] = self.labels['features'][start_index+1:end_index+1].clone()

        return target

    def get_groundtruth(self, index, target):
        start_index, end_index = self.video_indices[index]

        # GT 1. Joints
        gt_kp3d = target['kp3d']
        # gt_kp2d = perspective_projection(gt_kp3d, target['K'])
        gt_kp2d = perspective_projection(gt_kp3d, self.cam_intrinsics)
        target['kp3d'] = torch.cat((gt_kp3d, torch.ones_like(gt_kp3d[..., :1])), dim=-1)
        # target['full_kp2d'] = torch.cat((gt_kp2d, torch.zeros_like(gt_kp2d[..., :1])), dim=-1)[1:]
        target['full_kp2d'] = torch.cat((gt_kp2d, torch.ones_like(gt_kp2d[..., :1])), dim=-1)[1:]
        target['weak_kp2d'] = torch.zeros_like(target['full_kp2d'])
        target['init_kp3d'] = root_centering(gt_kp3d[:1, :self.n_joints].clone()).reshape(1, -1)

        # GT 2. Root pose
        w_transl = self.labels['w_trans'][start_index:end_index+1]
        pose_root = transforms.axis_angle_to_matrix(self.labels['root'][start_index:end_index+1])
        vel_world = (w_transl[1:] - w_transl[:-1])
        vel_root = (pose_root[:-1].transpose(-1, -2) @ vel_world.unsqueeze(-1)).squeeze(-1)
        target['vel_root'] = vel_root.clone()
        target['pose_root'] = transforms.matrix_to_rotation_6d(pose_root)
        target['init_root'] = target['pose_root'][:1].clone()

        return target

    def forward_smpl(self, target):
        output = self.smpl.get_output(
            body_pose=torch.cat((target['init_pose'][:, 1:], target['pose'][1:, 1:])),
            global_orient=torch.cat((target['init_pose'][:, :1], target['pose'][1:, :1])),
            betas=target['betas'],
            transl=target['transl'],
            pose2rot=False)

        target['kp3d'] = output.joints + output.offset.unsqueeze(1)
        target['feet'] = output.feet[1:] + target['transl'][1:].unsqueeze(-2)
        target['verts'] = output.vertices[1:, ].clone()

        return target

    def augment_data(self, target):
        # Augmentation 1. SMPL params augmentation
        target = self.SMPLAugmentor(target)

        # Get world-coordinate SMPL
        target = self.forward_smpl(target)

        return target

    def load_camera(self, index, target):
        start_index, end_index = self.video_indices[index]

        # Get camera info
        extrinsics = self.labels['extrinsics'][start_index:end_index+1].clone()
        R = extrinsics[:, :3, :3]
        T = extrinsics[:, :3, -1]
        K = self.labels['intrinsics'][start_index:end_index+1].clone()
        width, height = K[0, 0, 2] * 2, K[0, 1, 2] * 2
        target['R'] = R
        target['res'] = torch.tensor([width, height]).float()

        # Compute angular velocity
        cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2))
        cam_angvel = cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)  # Normalize
        target['cam_angvel'] = cam_angvel * 3e1  # BEDLAM is 30-fps

        target['K'] = K  # Use GT camera intrinsics for projecting keypoints
        self.get_naive_intrinsics(target['res'])
        target['cam_intrinsics'] = self.cam_intrinsics

        return target

    def load_params(self, index, target):
        start_index, end_index = self.video_indices[index]

        # Load AMASS labels
        pose = self.labels['pose'][start_index:end_index+1].clone()
        pose = transforms.axis_angle_to_matrix(pose.reshape(-1, 24, 3))
        transl = self.labels['c_trans'][start_index:end_index+1].clone()
        betas = self.labels['betas'][start_index:end_index+1, :10].clone()

        # Stack GT
        target.update({'vid': self.labels['vid'][start_index].clone(),
                       'pose': pose,
                       'transl': transl,
                       'betas': betas})

        return target

    def get_single_sequence(self, index):
        target = {'has_full_screen': torch.tensor(True),
                  'has_smpl': torch.tensor(True),
                  'has_traj': torch.tensor(False),
                  'has_verts': torch.tensor(True),

                  # Null contact label
                  'contact': torch.ones((self.n_frames - 1, 4)) * (-1),
                  }

        target = self.load_params(index, target)
        target = self.load_camera(index, target)
        target = self.augment_data(target)
        target = self.get_groundtruth(index, target)
        target = self.get_inputs(index, target)

        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)

        return target
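The 6D angular-velocity encoding used in load_camera can be reproduced in isolation. A minimal sketch, assuming the common 6D rotation convention (first two rows of the matrix, flattened) for matrix_to_rotation_6d, which the repo's lib.utils.transforms is assumed to follow; the yaw values are synthetic.

import math
import torch

def matrix_to_rotation_6d(R):
    # common 6D convention: first two rows of R, flattened
    return R[..., :2, :].reshape(*R.shape[:-2], 6)

Rs = []
for i in range(5):
    a = 0.02 * i                                   # slow synthetic yaw over 5 frames
    Rs.append(torch.tensor([[math.cos(a), 0., math.sin(a)],
                            [0., 1., 0.],
                            [-math.sin(a), 0., math.cos(a)]]))
R = torch.stack(Rs)

rel = R[:-1] @ R[1:].transpose(-1, -2)             # frame-to-frame relative rotation
angvel = matrix_to_rotation_6d(rel)
angvel = angvel - torch.tensor([[1., 0., 0., 0., 1., 0.]])  # identity maps to zero
angvel = angvel * 30                               # scale by FPS (BEDLAM is 30 fps)
print(angvel.shape)                                # torch.Size([4, 6])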
lib/data/datasets/dataset2d.py
ADDED
@@ -0,0 +1,140 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import joblib

from .._dataset import BaseDataset
from ..utils.augmentor import *
from ...utils import data_utils as d_utils
from ...utils import transforms
from ...models import build_body_model
from ...utils.kp_utils import convert_kps, root_centering


class Dataset2D(BaseDataset):
    def __init__(self, cfg, fname, training):
        super(Dataset2D, self).__init__(cfg, training)

        self.epoch = 0
        self.n_frames = cfg.DATASET.SEQLEN + 1
        self.labels = joblib.load(fname)

        if self.training:
            self.prepare_video_batch()

        self.smpl = build_body_model('cpu', self.n_frames)
        self.SMPLAugmentor = SMPLAugmentor(cfg, False)

    def __getitem__(self, index):
        return self.get_single_sequence(index)

    def get_inputs(self, index, target, vis_thr=0.6):
        start_index, end_index = self.video_indices[index]

        # 2D keypoints detection
        kp2d = self.labels['kp2d'][start_index:end_index+1][..., :2].clone()
        kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], target['cam_intrinsics'], 224, 224, target['bbox'])
        target['bbox'] = bbox[1:]
        target['kp2d'] = kp2d

        # Detection mask
        target['mask'] = ~self.labels['joints2D'][start_index+1:end_index+1][..., -1].clone().bool()

        # Image features
        target['features'] = self.labels['features'][start_index+1:end_index+1].clone()

        return target

    def get_labels(self, index, target):
        start_index, end_index = self.video_indices[index]

        # SMPL parameters
        # NOTE: We use NeuralAnnot labels for Human36m and MPII3D only for the 0th frame input.
        # We do not supervise the network on SMPL parameters.
        target['pose'] = transforms.axis_angle_to_matrix(
            self.labels['pose'][start_index:end_index+1].clone().reshape(-1, 24, 3))
        target['betas'] = self.labels['betas'][start_index:end_index+1].clone()  # No t

        # Apply SMPL augmentor (y-axis rotation and initial frame noise)
        target = self.SMPLAugmentor(target)

        # 2D keypoints
        kp2d = self.labels['kp2d'][start_index:end_index+1].clone().float()[..., :2]
        gt_kp2d = torch.zeros((self.n_frames - 1, 31, 2))
        gt_kp2d[:, :17] = kp2d[1:].clone()

        # Set 0 confidence to the masked keypoints
        mask = torch.zeros((self.n_frames - 1, 31))
        mask[:, :17] = self.labels['joints2D'][start_index+1:end_index+1][..., -1].clone()
        mask = torch.logical_and(gt_kp2d.mean(-1) != 0, mask)
        gt_kp2d = torch.cat((gt_kp2d, mask.float().unsqueeze(-1)), dim=-1)

        _gt_kp2d = gt_kp2d.clone()
        for idx in range(len(_gt_kp2d)):
            _gt_kp2d[idx][..., :2] = torch.from_numpy(
                self.j2d_processing(gt_kp2d[idx][..., :2].numpy().copy(),
                                    target['bbox'][idx].numpy().copy()))

        target['weak_kp2d'] = _gt_kp2d.clone()
        target['full_kp2d'] = torch.zeros_like(gt_kp2d)
        target['kp3d'] = torch.zeros((kp2d.shape[0], 31, 4))

        # No SMPL vertices available
        target['verts'] = torch.zeros((self.n_frames - 1, 6890, 3)).float()
        return target

    def get_init_frame(self, target):
        # Prepare initial frame
        output = self.smpl.get_output(
            body_pose=target['init_pose'][:, 1:],
            global_orient=target['init_pose'][:, :1],
            betas=target['betas'][:1],
            pose2rot=False
        )
        target['init_kp3d'] = root_centering(output.joints[:1, :self.n_joints]).reshape(1, -1)

        return target

    def get_single_sequence(self, index):
        # Camera parameters
        res = (224.0, 224.0)
        bbox = torch.tensor([112.0, 112.0, 1.12])
        res = torch.tensor(res)
        self.get_naive_intrinsics(res)
        bbox = bbox.repeat(self.n_frames, 1)

        # Universal target
        target = {'has_full_screen': torch.tensor(False),
                  'has_smpl': torch.tensor(self.has_smpl),
                  'has_traj': torch.tensor(self.has_traj),
                  'has_verts': torch.tensor(False),
                  'transl': torch.zeros((self.n_frames, 3)),

                  # Camera parameters and bbox
                  'res': res,
                  'cam_intrinsics': self.cam_intrinsics,
                  'bbox': bbox,

                  # Null camera motion
                  'R': torch.eye(3).repeat(self.n_frames, 1, 1),
                  'cam_angvel': torch.zeros((self.n_frames - 1, 6)),

                  # Null root orientation and velocity
                  'pose_root': torch.zeros((self.n_frames, 6)),
                  'vel_root': torch.zeros((self.n_frames - 1, 3)),
                  'init_root': torch.zeros((1, 6)),

                  # Null contact label
                  'contact': torch.ones((self.n_frames - 1, 4)) * (-1)
                  }

        self.get_inputs(index, target)
        self.get_labels(index, target)
        self.get_init_frame(target)

        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)

        return target
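The padded-keypoint convention in get_labels places the 17 detected COCO joints at the front of a 31-joint tensor and appends a confidence channel marking which entries are real. A standalone sketch with fake detections (shapes and values are illustrative):

import torch

n_frames = 4
kp17 = torch.rand(n_frames, 17, 2) * 224                 # fake COCO detections (pixels)
conf17 = (torch.rand(n_frames, 17) > 0.2).float()        # fake visibility flags

gt_kp2d = torch.zeros(n_frames, 31, 2)
gt_kp2d[:, :17] = kp17                                   # detected joints at the front

mask = torch.zeros(n_frames, 31)
mask[:, :17] = conf17
mask = torch.logical_and(gt_kp2d.mean(-1) != 0, mask)    # zeroed joints lose confidence

gt_kp2d = torch.cat((gt_kp2d, mask.float().unsqueeze(-1)), dim=-1)  # (4, 31, 3)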
lib/data/datasets/dataset3d.py
ADDED
@@ -0,0 +1,172 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import joblib
import numpy as np

from .._dataset import BaseDataset
from ..utils.augmentor import *
from ...utils import data_utils as d_utils
from ...utils import transforms
from ...models import build_body_model
from ...utils.kp_utils import convert_kps, root_centering


class Dataset3D(BaseDataset):
    def __init__(self, cfg, fname, training):
        super(Dataset3D, self).__init__(cfg, training)

        self.epoch = 0
        self.labels = joblib.load(fname)
        self.n_frames = cfg.DATASET.SEQLEN + 1

        if self.training:
            self.prepare_video_batch()

        self.smpl = build_body_model('cpu', self.n_frames)
        self.SMPLAugmentor = SMPLAugmentor(cfg, False)
        self.VideoAugmentor = VideoAugmentor(cfg)

    def __getitem__(self, index):
        return self.get_single_sequence(index)

    def get_inputs(self, index, target, vis_thr=0.6):
        start_index, end_index = self.video_indices[index]

        # 2D keypoints detection
        kp2d = self.labels['kp2d'][start_index:end_index+1][..., :2].clone()
        bbox = self.labels['bbox'][start_index:end_index+1][..., [0, 1, -1]].clone()
        bbox[:, 2] = bbox[:, 2] / 200
        kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224, bbox)

        target['bbox'] = bbox[1:]
        target['kp2d'] = kp2d
        target['mask'] = self.labels['kp2d'][start_index+1:end_index+1][..., -1] < vis_thr

        # Image features
        target['features'] = self.labels['features'][start_index+1:end_index+1].clone()

        return target

    def get_labels(self, index, target):
        start_index, end_index = self.video_indices[index]

        # SMPL parameters
        # NOTE: We use NeuralAnnot labels for Human36m and MPII3D only for the 0th frame input.
        # We do not supervise the network on SMPL parameters.
        target['pose'] = transforms.axis_angle_to_matrix(
            self.labels['pose'][start_index:end_index+1].clone().reshape(-1, 24, 3))
        target['betas'] = self.labels['betas'][start_index:end_index+1].clone()  # No t

        # Apply SMPL augmentor (y-axis rotation and initial frame noise)
        target = self.SMPLAugmentor(target)

        # 3D and 2D keypoints
        if self.__name__ == 'ThreeDPW':  # 3DPW has SMPL labels
            gt_kp3d = self.labels['joints3D'][start_index:end_index+1].clone()
            gt_kp2d = self.labels['joints2D'][start_index+1:end_index+1, ..., :2].clone()
            gt_kp3d = root_centering(gt_kp3d.clone())

        else:  # Human36m and MPII do not have SMPL labels
            gt_kp3d = torch.zeros((self.n_frames, self.n_joints + 14, 3))
            gt_kp3d[:, self.n_joints:] = convert_kps(self.labels['joints3D'][start_index:end_index+1], 'spin', 'common')
            gt_kp2d = torch.zeros((self.n_frames - 1, self.n_joints + 14, 2))
            gt_kp2d[:, self.n_joints:] = convert_kps(self.labels['joints2D'][start_index+1:end_index+1, ..., :2], 'spin', 'common')

        conf = self.mask.repeat(self.n_frames, 1).unsqueeze(-1)
        gt_kp2d = torch.cat((gt_kp2d, conf[1:]), dim=-1)
        gt_kp3d = torch.cat((gt_kp3d, conf), dim=-1)
        target['kp3d'] = gt_kp3d
        target['full_kp2d'] = gt_kp2d
        target['weak_kp2d'] = torch.zeros_like(gt_kp2d)

        if self.__name__ != 'ThreeDPW':  # 3DPW does not contain world-coordinate motion
            # Foot ground contact labels for Human36M and MPII3D
            target['contact'] = self.labels['stationaries'][start_index+1:end_index+1].clone()
        else:
            # No foot ground contact label available for 3DPW
            target['contact'] = torch.ones((self.n_frames - 1, 4)) * (-1)

        if self.has_verts:
            # SMPL vertices available for 3DPW
            with torch.no_grad():
                start_index, end_index = self.video_indices[index]
                gender = self.labels['gender'][start_index].item()
                output = self.smpl_gender[gender](
                    body_pose=target['pose'][1:, 1:],
                    global_orient=target['pose'][1:, :1],
                    betas=target['betas'][1:],
                    pose2rot=False,
                )
                target['verts'] = output.vertices.clone()
        else:
            # No SMPL vertices available
            target['verts'] = torch.zeros((self.n_frames - 1, 6890, 3)).float()

        return target

    def get_init_frame(self, target):
        # Prepare initial frame
        output = self.smpl.get_output(
            body_pose=target['init_pose'][:, 1:],
            global_orient=target['init_pose'][:, :1],
            betas=target['betas'][:1],
            pose2rot=False
        )
        target['init_kp3d'] = root_centering(output.joints[:1, :self.n_joints]).reshape(1, -1)

        return target

    def get_camera_info(self, index, target):
        start_index, end_index = self.video_indices[index]

        # Intrinsics
        target['res'] = self.labels['res'][start_index:end_index+1][0].clone()
        self.get_naive_intrinsics(target['res'])
        target['cam_intrinsics'] = self.cam_intrinsics.clone()

        # Extrinsics pose
        R = self.labels['cam_poses'][start_index:end_index+1, :3, :3].clone().float()
        yaw = transforms.axis_angle_to_matrix(torch.tensor([[0, 2 * np.pi * np.random.uniform(), 0]])).float()
        if self.__name__ == 'Human36M':
            # Map Z-up to Y-down coordinate
            zup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[-np.pi/2, 0, 0]])).float()
            zup2ydown = torch.matmul(yaw, zup2ydown)
            R = torch.matmul(R, zup2ydown)
        elif self.__name__ == 'MPII3D':
            # Map Y-up to Y-down coordinate
            yup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float()
            yup2ydown = torch.matmul(yaw, yup2ydown)
            R = torch.matmul(R, yup2ydown)

        return target

    def get_single_sequence(self, index):
        # Universal target
        target = {'has_full_screen': torch.tensor(True),
                  'has_smpl': torch.tensor(self.has_smpl),
                  'has_traj': torch.tensor(self.has_traj),
                  'has_verts': torch.tensor(self.has_verts),
                  'transl': torch.zeros((self.n_frames, 3)),

                  # Null camera motion
                  'R': torch.eye(3).repeat(self.n_frames, 1, 1),
                  'cam_angvel': torch.zeros((self.n_frames - 1, 6)),

                  # Null root orientation and velocity
                  'pose_root': torch.zeros((self.n_frames, 6)),
                  'vel_root': torch.zeros((self.n_frames - 1, 3)),
                  'init_root': torch.zeros((1, 6)),
                  }

        self.get_camera_info(index, target)
        self.get_inputs(index, target)
        self.get_labels(index, target)
        self.get_init_frame(target)

        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)

        return target
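The coordinate remapping in get_camera_info is easier to see in explicit matrix form (the same matrix that amass_utils.py below builds by hand). The exact sign depends on whether the matrix rotates points or is right-composed with camera extrinsics, so treat this as an orientation check, not the repo's derivation:

# Sketch: the Z-up -> Y-down remapping as an explicit matrix (cf. amass_utils.py).
import torch

zup2ydown = torch.tensor([[1., 0., 0.],
                          [0., 0., -1.],
                          [0., 1., 0.]])

up_zup = torch.tensor([0., 0., 1.])   # world "up" in a Z-up convention
print(zup2ydown @ up_zup)             # tensor([ 0., -1.,  0.]): up maps to -y (Y-down)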
lib/data/datasets/dataset_custom.py
ADDED
@@ -0,0 +1,115 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch

from ..utils.normalizer import Normalizer
from ...models import build_body_model
from ...utils import transforms
from ...utils.kp_utils import root_centering
from ...utils.imutils import compute_cam_intrinsics

KEYPOINTS_THR = 0.3

def convert_dpvo_to_cam_angvel(traj, fps):
    """Function to convert DPVO trajectory output to camera angular velocity"""

    # 0 ~ 3: translation, 3 ~ 7: Quaternion
    quat = traj[:, 3:]

    # Convert (x, y, z, w) to (w, x, y, z)
    quat = quat[:, [3, 0, 1, 2]]

    # Quat is the camera-to-world transformation. Convert it to world-to-camera
    world2cam = transforms.quaternion_to_matrix(torch.from_numpy(quat)).float()
    R = world2cam.mT

    # Compute the rotational changes over time.
    cam_angvel = transforms.matrix_to_axis_angle(R[:-1] @ R[1:].transpose(-1, -2))

    # Convert matrix to 6D representation
    cam_angvel = transforms.matrix_to_rotation_6d(transforms.axis_angle_to_matrix(cam_angvel))

    # Normalize 6D angular velocity
    cam_angvel = cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)  # Normalize
    cam_angvel = cam_angvel * fps
    cam_angvel = torch.cat((cam_angvel, cam_angvel[:1]), dim=0)
    return cam_angvel


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, cfg, tracking_results, slam_results, width, height, fps):

        self.tracking_results = tracking_results
        self.slam_results = slam_results
        self.width = width
        self.height = height
        self.fps = fps
        self.res = torch.tensor([width, height]).float()
        self.intrinsics = compute_cam_intrinsics(self.res)

        self.device = cfg.DEVICE.lower()

        self.smpl = build_body_model('cpu')
        self.keypoints_normalizer = Normalizer(cfg)

        self._to = lambda x: x.unsqueeze(0).to(self.device)

    def __len__(self):
        return len(self.tracking_results.keys())

    def load_data(self, index, flip=False):
        if flip:
            self.prefix = 'flipped_'
        else:
            self.prefix = ''

        return self.__getitem__(index)

    def __getitem__(self, _index):
        if _index >= len(self): return

        index = sorted(list(self.tracking_results.keys()))[_index]

        # Process 2D keypoints
        kp2d = torch.from_numpy(self.tracking_results[index][self.prefix + 'keypoints']).float()
        mask = kp2d[..., -1] < KEYPOINTS_THR
        bbox = torch.from_numpy(self.tracking_results[index][self.prefix + 'bbox']).float()

        norm_kp2d, _ = self.keypoints_normalizer(
            kp2d[..., :-1].clone(), self.res, self.intrinsics, 224, 224, bbox
        )

        # Process image features
        features = self.tracking_results[index][self.prefix + 'features']

        # Process initial pose
        init_output = self.smpl.get_output(
            global_orient=self.tracking_results[index][self.prefix + 'init_global_orient'],
            body_pose=self.tracking_results[index][self.prefix + 'init_body_pose'],
            betas=self.tracking_results[index][self.prefix + 'init_betas'],
            pose2rot=False,
            return_full_pose=True
        )
        init_kp3d = root_centering(init_output.joints[:, :17], 'coco')
        init_kp = torch.cat((init_kp3d.reshape(1, -1), norm_kp2d[0].clone().reshape(1, -1)), dim=-1)
        init_smpl = transforms.matrix_to_rotation_6d(init_output.full_pose)
        init_root = transforms.matrix_to_rotation_6d(init_output.global_orient)

        # Process SLAM results
        cam_angvel = convert_dpvo_to_cam_angvel(self.slam_results, self.fps)

        return (
            index,                                         # subject id
            self._to(norm_kp2d),                           # 2d keypoints
            (self._to(init_kp), self._to(init_smpl)),      # initial pose
            self._to(features),                            # image features
            self._to(mask),                                # keypoints mask
            init_root.to(self.device),                     # initial root orientation
            self._to(cam_angvel),                          # camera angular velocity
            self.tracking_results[index]['frame_id'],      # frame indices
            {'cam_intrinsics': self._to(self.intrinsics),  # other keyword arguments
             'bbox': self._to(bbox),
             'res': self._to(self.res)},
        )
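A hypothetical call to convert_dpvo_to_cam_angvel with a synthetic DPVO-style trajectory, assuming the function above is in scope: one row per frame, three translation values followed by an (x, y, z, w) quaternion. Identity quaternions should yield zero angular velocity.

import numpy as np

T = 10
traj = np.zeros((T, 7), dtype=np.float32)   # DPVO-style: 3 translation + (x, y, z, w) quat
traj[:, 6] = 1.0                            # w = 1: identity rotation per frame

cam_angvel = convert_dpvo_to_cam_angvel(traj, fps=30)
print(cam_angvel.shape)                     # torch.Size([10, 6]); last row is duplicated
print(cam_angvel.abs().max())               # ~0: no rotation -> zero angular velocity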
lib/data/datasets/dataset_eval.py
ADDED
@@ -0,0 +1,113 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os
import torch
import joblib

from configs import constants as _C
from .._dataset import BaseDataset
from ...utils import transforms
from ...utils import data_utils as d_utils
from ...utils.kp_utils import root_centering

FPS = 30


class EvalDataset(BaseDataset):
    def __init__(self, cfg, data, split, backbone):
        super(EvalDataset, self).__init__(cfg, False)

        self.prefix = ''
        self.data = data
        parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'{data}_{split}_{backbone}.pth')
        self.labels = joblib.load(parsed_data_path)

    def load_data(self, index, flip=False):
        if flip:
            self.prefix = 'flipped_'
        else:
            self.prefix = ''

        target = self.__getitem__(index)
        for key, val in target.items():
            if isinstance(val, torch.Tensor):
                target[key] = val.unsqueeze(0)
        return target

    def __getitem__(self, index):
        target = {}
        target = self.get_data(index)
        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)

        return target

    def __len__(self):
        return len(self.labels['kp2d'])

    def prepare_labels(self, index, target):
        # Ground truth SMPL parameters
        target['pose'] = transforms.axis_angle_to_matrix(self.labels['pose'][index].reshape(-1, 24, 3))
        target['betas'] = self.labels['betas'][index]
        target['gender'] = self.labels['gender'][index]

        # Sequence information
        target['res'] = self.labels['res'][index][0]
        target['vid'] = self.labels['vid'][index]
        target['frame_id'] = self.labels['frame_id'][index][1:]

        # Camera information
        self.get_naive_intrinsics(target['res'])
        target['cam_intrinsics'] = self.cam_intrinsics
        R = self.labels['cam_poses'][index][:, :3, :3].clone()
        if 'emdb' in self.data.lower():
            # Use groundtruth camera angular velocity.
            # Can be updated with SLAM results if available.
            cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2))
            cam_angvel = (cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)) * FPS
            target['R'] = R
        else:
            cam_angvel = torch.zeros((len(target['pose']) - 1, 6))
        target['cam_angvel'] = cam_angvel
        return target

    def prepare_inputs(self, index, target):
        for key in ['features', 'bbox']:
            data = self.labels[self.prefix + key][index][1:]
            target[key] = data

        bbox = self.labels[self.prefix + 'bbox'][index][..., [0, 1, -1]].clone().float()
        bbox[:, 2] = bbox[:, 2] / 200

        # Normalize keypoints
        kp2d, bbox = self.keypoints_normalizer(
            self.labels[self.prefix + 'kp2d'][index][..., :2].clone().float(),
            target['res'], target['cam_intrinsics'], 224, 224, bbox)
        target['kp2d'] = kp2d
        target['bbox'] = bbox[1:]

        # Mask out low-confidence keypoints
        mask = self.labels[self.prefix + 'kp2d'][index][..., -1] < 0.3
        target['input_kp2d'] = self.labels['kp2d'][index][1:]
        target['input_kp2d'][mask[1:]] *= 0
        target['mask'] = mask[1:]

        return target

    def prepare_initialization(self, index, target):
        # Initial frame per-frame estimation
        target['init_kp3d'] = root_centering(self.labels[self.prefix + 'init_kp3d'][index][:1, :self.n_joints]).reshape(1, -1)
        target['init_pose'] = transforms.axis_angle_to_matrix(self.labels[self.prefix + 'init_pose'][index][:1]).cpu()
        pose_root = target['pose'][:, 0].clone()
        target['init_root'] = transforms.matrix_to_rotation_6d(pose_root)

        return target

    def get_data(self, index):
        target = {}

        target = self.prepare_labels(index, target)
        target = self.prepare_inputs(index, target)
        target = self.prepare_initialization(index, target)

        return target
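The flip switch in load_data selects the 'flipped_*' entries of the parsed labels. A hypothetical call pattern, where eval_dataset is assumed to be an instantiated EvalDataset (this is an illustration, not a claim about the repo's evaluation script):

# Hypothetical usage of the flip switch (eval_dataset: an EvalDataset instance).
target = eval_dataset.load_data(0, flip=False)          # original inputs
target_flipped = eval_dataset.load_data(0, flip=True)   # mirrored 'flipped_*' inputs
# Each tensor comes back with a leading batch dimension of 1 (see load_data above).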
lib/data/datasets/mixed_dataset.py
ADDED
@@ -0,0 +1,61 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import numpy as np

from .amass import AMASSDataset
from .videos import Human36M, ThreeDPW, MPII3D, InstaVariety
from .bedlam import BEDLAMDataset
from lib.utils.data_utils import make_collate_fn


class DataFactory(torch.utils.data.Dataset):
    def __init__(self, cfg, train_stage='syn'):
        super(DataFactory, self).__init__()

        if train_stage == 'stage1':
            self.datasets = [AMASSDataset(cfg)]
            self.dataset_names = ['AMASS']
        elif train_stage == 'stage2':
            self.datasets = [
                AMASSDataset(cfg), ThreeDPW(cfg),
                Human36M(cfg), MPII3D(cfg), InstaVariety(cfg)
            ]
            self.dataset_names = ['AMASS', '3DPW', 'Human36M', 'MPII3D', 'Insta']

            if len(cfg.DATASET.RATIO) == 6:  # Use BEDLAM
                self.datasets.append(BEDLAMDataset(cfg))
                self.dataset_names.append('BEDLAM')

        self._set_partition(cfg.DATASET.RATIO)
        self.lengths = [len(ds) for ds in self.datasets]

    @property
    def __name__(self, ):
        return 'MixedData'

    def prepare_video_batch(self):
        [ds.prepare_video_batch() for ds in self.datasets]
        self.lengths = [len(ds) for ds in self.datasets]

    def _set_partition(self, partition):
        self.partition = partition
        self.ratio = partition
        self.partition = np.array(self.partition).cumsum()
        self.partition /= self.partition[-1]

    def __len__(self):
        return int(np.array([l for l, r in zip(self.lengths, self.ratio) if r > 0]).mean())

    def __getitem__(self, index):
        # Get the dataset to sample from
        p = np.random.rand()
        for i in range(len(self.datasets)):
            if p <= self.partition[i]:
                if len(self.datasets) == 1:
                    return self.datasets[i][index % self.lengths[i]]
                else:
                    d_index = np.random.randint(0, self.lengths[i])
                    return self.datasets[i][d_index]
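The cumulative-ratio sampling in DataFactory.__getitem__ can be verified in isolation: draw u ~ U(0, 1) and pick the first dataset whose cumulative ratio exceeds it. A standalone sketch with made-up ratio values:

import numpy as np

ratios = np.array([0.4, 0.3, 0.2, 0.05, 0.05])  # illustrative DATASET.RATIO values
partition = ratios.cumsum() / ratios.sum()

counts = np.zeros(len(ratios), dtype=int)
rng = np.random.default_rng(0)
for _ in range(10000):
    p = rng.random()
    i = int(np.searchsorted(partition, p))      # same choice as the for-loop above
    counts[i] += 1
print(counts / counts.sum())                    # empirical frequencies approximate ratios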
lib/data/datasets/videos.py
ADDED
@@ -0,0 +1,105 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os
import torch

from configs import constants as _C
from .dataset3d import Dataset3D
from .dataset2d import Dataset2D
from ...utils.kp_utils import convert_kps
from smplx import SMPL


class Human36M(Dataset3D):
    def __init__(self, cfg, dset='train'):
        parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'human36m_{dset}_backbone.pth')
        parsed_data_path = parsed_data_path.replace('backbone', cfg.MODEL.BACKBONE.lower())
        super(Human36M, self).__init__(cfg, parsed_data_path, dset == 'train')

        self.has_3d = True
        self.has_traj = True
        self.has_smpl = False
        self.has_verts = False

        # Among the 31-joint format, the 14 common joints are available
        self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
        self.mask[-14:] = 1

    @property
    def __name__(self, ):
        return 'Human36M'

    def compute_3d_keypoints(self, index):
        return convert_kps(self.labels['joints3D'][index], 'spin', 'h36m'
                           )[:, _C.KEYPOINTS.H36M_TO_J14].float()


class MPII3D(Dataset3D):
    def __init__(self, cfg, dset='train'):
        parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'mpii3d_{dset}_backbone.pth')
        parsed_data_path = parsed_data_path.replace('backbone', cfg.MODEL.BACKBONE.lower())
        super(MPII3D, self).__init__(cfg, parsed_data_path, dset == 'train')

        self.has_3d = True
        self.has_traj = True
        self.has_smpl = False
        self.has_verts = False

        # Among the 31-joint format, the 14 common joints are available
        self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
        self.mask[-14:] = 1

    @property
    def __name__(self, ):
        return 'MPII3D'

    def compute_3d_keypoints(self, index):
        return convert_kps(self.labels['joints3D'][index], 'spin', 'h36m'
                           )[:, _C.KEYPOINTS.H36M_TO_J17].float()


class ThreeDPW(Dataset3D):
    def __init__(self, cfg, dset='train'):
        parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'3dpw_{dset}_backbone.pth')
        parsed_data_path = parsed_data_path.replace('backbone', cfg.MODEL.BACKBONE.lower())
        super(ThreeDPW, self).__init__(cfg, parsed_data_path, dset == 'train')

        self.has_3d = True
        self.has_traj = False
        self.has_smpl = True
        self.has_verts = True  # In testing

        # Among the 31-joint format, the first 17 joints (not the 14 common ones) are available
        self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
        self.mask[:-14] = 1

        self.smpl_gender = {
            0: SMPL(_C.BMODEL.FLDR, gender='male', num_betas=10),
            1: SMPL(_C.BMODEL.FLDR, gender='female', num_betas=10)
        }

    @property
    def __name__(self, ):
        return 'ThreeDPW'

    def compute_3d_keypoints(self, index):
        return self.labels['joints3D'][index]


class InstaVariety(Dataset2D):
    def __init__(self, cfg, dset='train'):
        parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'insta_{dset}_backbone.pth')
        parsed_data_path = parsed_data_path.replace('backbone', cfg.MODEL.BACKBONE.lower())
        super(InstaVariety, self).__init__(cfg, parsed_data_path, dset == 'train')

        self.has_3d = False
        self.has_traj = False
        self.has_smpl = False

        # Among the 31-joint format, the 17 COCO joints are available
        self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
        self.mask[:17] = 1

    @property
    def __name__(self, ):
        return 'InstaVariety'
lib/data/utils/__pycache__/augmentor.cpython-39.pyc
ADDED
Binary file (10.1 kB).
lib/data/utils/__pycache__/normalizer.cpython-39.pyc
ADDED
Binary file (4.13 kB).
lib/data/utils/augmentor.py
ADDED
@@ -0,0 +1,292 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from configs import constants as _C

import torch
import numpy as np
from torch.nn import functional as F

from ...utils import transforms

__all__ = ['VideoAugmentor', 'SMPLAugmentor', 'SequenceAugmentor', 'CameraAugmentor']


num_joints = _C.KEYPOINTS.NUM_JOINTS


class VideoAugmentor():
    def __init__(self, cfg, train=True):
        self.train = train
        self.l = cfg.DATASET.SEQLEN + 1
        self.aug_dict = torch.load(_C.KEYPOINTS.COCO_AUG_DICT)

    def get_jitter(self, ):
        """Gaussian jitter modeling."""
        jittering_noise = torch.normal(
            mean=torch.zeros((self.l, num_joints, 3)),
            std=self.aug_dict['jittering'].reshape(1, num_joints, 1).expand(self.l, -1, 3)
        ) * _C.KEYPOINTS.S_JITTERING
        return jittering_noise

    def get_lfhp(self, ):
        """Low-frequency high-peak noise modeling."""
        def get_peak_noise_mask():
            peak_noise_mask = torch.rand(self.l, num_joints).float() * self.aug_dict['pmask'].squeeze(0)
            peak_noise_mask = peak_noise_mask < _C.KEYPOINTS.S_PEAK_MASK
            return peak_noise_mask

        peak_noise_mask = get_peak_noise_mask()
        peak_noise = peak_noise_mask.float().unsqueeze(-1).repeat(1, 1, 3)
        peak_noise = peak_noise * torch.randn(3) * self.aug_dict['peak'].reshape(1, -1, 1) * _C.KEYPOINTS.S_PEAK
        return peak_noise

    def get_bias(self, ):
        """Bias noise modeling."""
        bias_noise = torch.normal(
            mean=torch.zeros((num_joints, 3)), std=self.aug_dict['bias'].reshape(num_joints, 1)
        ).unsqueeze(0) * _C.KEYPOINTS.S_BIAS
        return bias_noise

    def get_mask(self, scale=None):
        """Mask modeling."""

        if scale is None:
            scale = _C.KEYPOINTS.S_MASK
        # Per-frame and joint
        mask = torch.rand(self.l, num_joints) < scale
        visible = (~mask).clone()
        for child in range(num_joints):
            parent = _C.KEYPOINTS.TREE[child]
            if parent == -1: continue
            if isinstance(parent, list):
                visible[:, child] *= (visible[:, parent[0]] * visible[:, parent[1]])
            else:
                visible[:, child] *= visible[:, parent]
        mask = (~visible).clone()

        return mask

    def __call__(self, keypoints):
        keypoints += self.get_bias() + self.get_jitter() + self.get_lfhp()
        return keypoints


class SMPLAugmentor():
    noise_scale = 1e-2

    def __init__(self, cfg, augment=True):
        self.n_frames = cfg.DATASET.SEQLEN
        self.augment = augment

    def __call__(self, target):
        if not self.augment:
            # Only add initial frame augmentation
            if not 'init_pose' in target:
                target['init_pose'] = target['pose'][:1] @ self.get_initial_pose_augmentation()
            return target

        n_frames = target['pose'].shape[0]

        # Global rotation
        rmat = self.get_global_augmentation()
        target['pose'][:, 0] = rmat @ target['pose'][:, 0]
        target['transl'] = (rmat.squeeze() @ target['transl'].T).T

        # Shape
        shape_noise = self.get_shape_augmentation(n_frames)
        target['betas'] = target['betas'] + shape_noise

        # Initial frames mis-prediction
        target['init_pose'] = target['pose'][:1] @ self.get_initial_pose_augmentation()

        return target

    def get_global_augmentation(self, ):
        """Global coordinate augmentation. Random rotation around y-axis"""

        angle_y = torch.rand(1) * 2 * np.pi * float(self.augment)
        aa = torch.tensor([0.0, angle_y, 0.0]).float().unsqueeze(0)
        rmat = transforms.axis_angle_to_matrix(aa)

        return rmat

    def get_shape_augmentation(self, n_frames):
        """Shape noise modeling."""

        shape_noise = torch.normal(
            mean=torch.zeros((1, 10)),
            std=torch.ones((1, 10)) * 0.1 * float(self.augment)).expand(n_frames, 10)

        return shape_noise

    def get_initial_pose_augmentation(self, ):
        """Initial frame pose noise modeling. Random rotation around all joints."""

        euler = torch.normal(
            mean=torch.zeros((24, 3)),
            std=torch.ones((24, 3))
        ) * self.noise_scale  # * float(self.augment)
        rmat = transforms.axis_angle_to_matrix(euler)

        return rmat.unsqueeze(0)


class SequenceAugmentor:
    """Augment the play speed of the motion sequence"""
    l_factor = 1.5

    def __init__(self, l_default):
        self.l_default = l_default

    def __call__(self, target):
        l = torch.randint(low=int(self.l_default / self.l_factor), high=int(self.l_default * self.l_factor), size=(1, ))

        pose = transforms.matrix_to_rotation_6d(target['pose'])
        resampled_pose = F.interpolate(
            pose[:l].permute(1, 2, 0), self.l_default, mode='linear', align_corners=True
        ).permute(2, 0, 1)
        resampled_pose = transforms.rotation_6d_to_matrix(resampled_pose)

        transl = target['transl'].unsqueeze(1)
        resampled_transl = F.interpolate(
            transl[:l].permute(1, 2, 0), self.l_default, mode='linear', align_corners=True
        ).squeeze(0).T

        target['pose'] = resampled_pose
        target['transl'] = resampled_transl
        target['betas'] = target['betas'][:self.l_default]

        return target


class CameraAugmentor:
    rx_factor = np.pi/8
    ry_factor = np.pi/4
    rz_factor = np.pi/8

    pitch_std = np.pi/8
    pitch_mean = np.pi/36
    roll_std = np.pi/24
    t_factor = 1

    tz_scale = 10
    tz_min = 2

    motion_prob = 0.75
    interp_noise = 0.2

    def __init__(self, l, w, h, f):
        self.l = l
        self.w = w
        self.h = h
        self.f = f
        self.fov_tol = 1.2 * (0.5 ** 0.5)

    def __call__(self, target):

        R, T = self.create_camera(target)

        if np.random.rand() < self.motion_prob:
            R = self.create_rotation_move(R)
            T = self.create_translation_move(T)

        return self.apply(target, R, T)

    def create_camera(self, target):
        """Create the initial frame camera pose"""
        yaw = np.random.rand() * 2 * np.pi
        pitch = np.random.normal(scale=self.pitch_std) + self.pitch_mean
        roll = np.random.normal(scale=self.roll_std)

        yaw_rm = transforms.axis_angle_to_matrix(torch.tensor([[0, yaw, 0]]).float())
        pitch_rm = transforms.axis_angle_to_matrix(torch.tensor([[pitch, 0, 0]]).float())
        roll_rm = transforms.axis_angle_to_matrix(torch.tensor([[0, 0, roll]]).float())
        R = (roll_rm @ pitch_rm @ yaw_rm)

        # Place people in the scene
        tz = np.random.rand() * self.tz_scale + self.tz_min
        max_d = self.w * tz / self.f / 2
        tx = np.random.normal(scale=0.25) * max_d
        ty = np.random.normal(scale=0.25) * max_d
        dist = torch.tensor([tx, ty, tz]).float()
        T = dist - torch.matmul(R, target['transl'][0])

        return R.repeat(self.l, 1, 1), T.repeat(self.l, 1)

    def create_rotation_move(self, R):
        """Create rotational move for the camera"""

        # Create final camera pose
        rx = np.random.normal(scale=self.rx_factor)
        ry = np.random.normal(scale=self.ry_factor)
        rz = np.random.normal(scale=self.rz_factor)
        Rf = R[0] @ transforms.axis_angle_to_matrix(torch.tensor([rx, ry, rz]).float())

        # Inbetweening two poses
        Rs = torch.stack((R[0], Rf))
        rs = transforms.matrix_to_rotation_6d(Rs).numpy()
        rs_move = self.noisy_interpolation(rs)
        R_move = transforms.rotation_6d_to_matrix(torch.from_numpy(rs_move).float())
        return R_move

    def create_translation_move(self, T):
        """Create translational move for the camera"""

        # Create final camera position
        tx = np.random.normal(scale=self.t_factor)
        ty = np.random.normal(scale=self.t_factor)
        tz = np.random.normal(scale=self.t_factor)
        Ts = np.array([[0, 0, 0], [tx, ty, tz]])

        T_move = self.noisy_interpolation(Ts)
        T_move = torch.from_numpy(T_move).float()
        return T_move + T

    def noisy_interpolation(self, data):
        """Non-linear interpolation with noise"""

        dim = data.shape[-1]
        output = np.zeros((self.l, dim))

        linspace = np.stack([np.linspace(0, 1, self.l) for _ in range(dim)])
        noise = (linspace[0, 1] - linspace[0, 0]) * self.interp_noise
        space_noise = np.stack([np.random.uniform(-noise, noise, self.l - 2) for _ in range(dim)])

        linspace[:, 1:-1] = linspace[:, 1:-1] + space_noise
        for i in range(dim):
            output[:, i] = np.interp(linspace[i], np.array([0., 1.,]), data[:, i])
        return output

    def apply(self, target, R, T):
        target['R'] = R
        target['T'] = T

        # Recompute the translation
        transl_cam = torch.matmul(R, target['transl'].unsqueeze(-1)).squeeze(-1)
        transl_cam = transl_cam + T
        if transl_cam[..., 2].min() < 0.5:  # If the person is too close to the camera
            transl_cam[..., 2] = transl_cam[..., 2] + (1.0 - transl_cam[..., 2].min())

        # If the subject is away from the field of view, put the camera behind
        fov = torch.div(transl_cam[..., :2], transl_cam[..., 2:]).abs()
        if fov.max() > self.fov_tol:
            t_max = transl_cam[fov.max(1)[0].max(0)[1].item()]
            z_trg = t_max[:2].abs().max(0)[0] / self.fov_tol
            pad = z_trg - t_max[2]
            transl_cam[..., 2] = transl_cam[..., 2] + pad

        target['transl_cam'] = transl_cam

        # Transform world coordinate to camera coordinate
        target['pose_root'] = target['pose'][:, 0].clone()
        target['pose'][:, 0] = R @ target['pose'][:, 0]  # pose
        target['init_pose'][:, 0] = R[:1] @ target['init_pose'][:, 0]  # init pose

        # Compute angular velocity
        cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2))
        cam_angvel = cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)  # Normalize
        target['cam_angvel'] = cam_angvel * 3e1  # assume 30-fps

        if 'kp3d' in target:
            target['kp3d'] = torch.matmul(R, target['kp3d'].transpose(1, 2)).transpose(1, 2) + target['transl_cam'].unsqueeze(1)

        return target
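The noisy_interpolation idea in CameraAugmentor (linearly in-between two endpoints, but jitter the interior sample positions so the move is not perfectly constant-velocity) can be run standalone. A self-contained sketch with illustrative values:

import numpy as np

L, dim, interp_noise = 8, 3, 0.2
data = np.array([[0., 0., 0.],                 # start pose/position
                 [1., 2., -1.]])               # end pose/position

linspace = np.stack([np.linspace(0, 1, L) for _ in range(dim)])
step = linspace[0, 1] - linspace[0, 0]
noise = np.random.uniform(-step * interp_noise, step * interp_noise, (dim, L - 2))
linspace[:, 1:-1] += noise                     # jitter interior sample times only

out = np.zeros((L, dim))
for i in range(dim):
    out[:, i] = np.interp(linspace[i], [0., 1.], data[:, i])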
lib/data/utils/normalizer.py
ADDED
@@ -0,0 +1,105 @@
import torch
import numpy as np

from ...utils.imutils import transform_keypoints


class Normalizer:
    def __init__(self, cfg):
        pass

    def __call__(self, kp_2d, res, cam_intrinsics, patch_width=224, patch_height=224, bbox=None, mask=None):
        if bbox is None:
            bbox = compute_bbox_from_keypoints(kp_2d, do_augment=True, mask=mask)

        out_kp_2d = self.bbox_normalization(kp_2d, bbox, res, patch_width, patch_height)
        return out_kp_2d, bbox

    def bbox_normalization(self, kp_2d, bbox, res, patch_width, patch_height):
        to_torch = False
        if isinstance(kp_2d, torch.Tensor):
            to_torch = True
            kp_2d = kp_2d.numpy()
            bbox = bbox.numpy()

        out_kp_2d = np.zeros_like(kp_2d)
        for idx in range(len(out_kp_2d)):
            out_kp_2d[idx] = transform_keypoints(kp_2d[idx], bbox[idx][:3], patch_width, patch_height)[0]
            out_kp_2d[idx] = normalize_keypoints_to_patch(out_kp_2d[idx], patch_width)

        if to_torch:
            out_kp_2d = torch.from_numpy(out_kp_2d)
            bbox = torch.from_numpy(bbox)

        centers = normalize_keypoints_to_image(bbox[:, :2].unsqueeze(1), res).squeeze(1)
        scale = bbox[:, 2:] * 200 / res.max()
        location = torch.cat((centers, scale), dim=-1)

        out_kp_2d = out_kp_2d.reshape(out_kp_2d.shape[0], -1)
        out_kp_2d = torch.cat((out_kp_2d, location), dim=-1)
        return out_kp_2d


def normalize_keypoints_to_patch(kp_2d, crop_size=224, inv=False):
    # Normalize keypoints between -1, 1
    if not inv:
        ratio = 1.0 / crop_size
        kp_2d = 2.0 * kp_2d * ratio - 1.0
    else:
        ratio = 1.0 / crop_size
        kp_2d = (kp_2d + 1.0) / (2 * ratio)

    return kp_2d


def normalize_keypoints_to_image(x, res):
    res = res.to(x.device)
    scale = res.max(-1)[0].reshape(-1)
    mean = torch.stack([res[..., 0] / scale, res[..., 1] / scale], dim=-1).to(x.device)
    x = (2 * x / scale.reshape(*[1 for i in range(len(x.shape[1:]))]) -
         mean.reshape(*[1 for i in range(len(x.shape[1:-1]))], -1))
    return x


def compute_bbox_from_keypoints(X, do_augment=False, mask=None):
    def smooth_bbox(bb):
        # Smooth bounding box detection
        import scipy.signal as signal
        smoothed = np.array([signal.medfilt(param, int(30 / 2)) for param in bb])
        return smoothed

    def do_augmentation(scale_factor=0.2, trans_factor=0.05):
        _scaleFactor = np.random.uniform(1.0 - scale_factor, 1.2 + scale_factor)
        _trans_x = np.random.uniform(-trans_factor, trans_factor)
        _trans_y = np.random.uniform(-trans_factor, trans_factor)

        return _scaleFactor, _trans_x, _trans_y

    if do_augment:
        scaleFactor, trans_x, trans_y = do_augmentation()
    else:
        scaleFactor, trans_x, trans_y = 1.2, 0.0, 0.0

    if mask is None:
        bbox = [X[:, :, 0].min(-1)[0], X[:, :, 1].min(-1)[0],
                X[:, :, 0].max(-1)[0], X[:, :, 1].max(-1)[0]]
    else:
        bbox = []
        for x, _mask in zip(X, mask):
            if _mask.sum() > 10:
                _mask[:] = False
            _bbox = [x[~_mask, 0].min(-1)[0], x[~_mask, 1].min(-1)[0],
                     x[~_mask, 0].max(-1)[0], x[~_mask, 1].max(-1)[0]]
            bbox.append(_bbox)
        bbox = torch.tensor(bbox).T

    cx, cy = [(bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2]
    bbox_w = bbox[2] - bbox[0]
    bbox_h = bbox[3] - bbox[1]
    bbox_size = torch.stack((bbox_w, bbox_h)).max(0)[0]
    scale = bbox_size * scaleFactor
    bbox = torch.stack((cx + trans_x * scale, cy + trans_y * scale, scale / 200))

    if do_augment:
        bbox = torch.from_numpy(smooth_bbox(bbox.numpy()))

    return bbox.T
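A quick round-trip check for normalize_keypoints_to_patch, assuming it is imported from this module: the forward pass maps pixel coordinates in a 224-crop to [-1, 1], and inv=True undoes it (the function is pure arithmetic, so a NumPy array works too).

import numpy as np

kp = np.array([[0., 0.], [112., 112.], [224., 224.]])
norm = normalize_keypoints_to_patch(kp, crop_size=224)        # -> [-1, 0, 1] per axis
back = normalize_keypoints_to_patch(norm, crop_size=224, inv=True)
assert np.allclose(back, kp)                                  # inverse recovers pixels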
lib/data_utils/amass_utils.py
ADDED
@@ -0,0 +1,107 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import os
+import os.path as osp
+from collections import defaultdict
+
+import torch
+import joblib
+import numpy as np
+from tqdm import tqdm
+from smplx import SMPL
+
+from configs import constants as _C
+from lib.utils.data_utils import map_dmpl_to_smpl, transform_global_coordinate
+
+
+@torch.no_grad()
+def process_amass():
+    target_fps = 30
+
+    _, seqs, _ = next(os.walk(_C.PATHS.AMASS_PTH))
+
+    zup2ydown = torch.Tensor(
+        [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
+    ).unsqueeze(0).float()
+
+    smpl_dict = {'male': SMPL(model_path=_C.BMODEL.FLDR, gender='male'),
+                 'female': SMPL(model_path=_C.BMODEL.FLDR, gender='female'),
+                 'neutral': SMPL(model_path=_C.BMODEL.FLDR)}
+    processed_data = defaultdict(list)
+
+    for seq in (seq_bar := tqdm(sorted(seqs), leave=True)):
+        seq_bar.set_description(f'Dataset: {seq}')
+        seq_fldr = osp.join(_C.PATHS.AMASS_PTH, seq)
+        _, subjs, _ = next(os.walk(seq_fldr))
+
+        for subj in (subj_bar := tqdm(sorted(subjs), leave=False)):
+            subj_bar.set_description(f'Subject: {subj}')
+            subj_fldr = osp.join(seq_fldr, subj)
+            acts = [x for x in os.listdir(subj_fldr) if x.endswith('.npz')]
+
+            for act in (act_bar := tqdm(sorted(acts), leave=False)):
+                act_bar.set_description(f'Action: {act}')
+
+                # Load data
+                fname = osp.join(subj_fldr, act)
+                if fname.endswith('shape.npz') or fname.endswith('stagei.npz'):
+                    # Skip shape and stagei files
+                    continue
+                data = dict(np.load(fname, allow_pickle=True))
+
+                # Resample data to target_fps
+                key = [k for k in data.keys() if 'mocap_frame' in k][0]
+                mocap_framerate = data[key]
+                retain_freq = int(mocap_framerate / target_fps + 0.5)
+                num_frames = len(data['poses'][::retain_freq])
+
+                # Skip if the sequence is too short
+                if num_frames < 25: continue
+
+                # Get SMPL groundtruth from MoSh fitting
+                pose = map_dmpl_to_smpl(torch.from_numpy(data['poses'][::retain_freq]).float())
+                transl = torch.from_numpy(data['trans'][::retain_freq]).float()
+                betas = torch.from_numpy(
+                    np.repeat(data['betas'][:10][np.newaxis], pose.shape[0], axis=0)).float()
+
+                # Convert Z-up coordinate to Y-down
+                pose, transl = transform_global_coordinate(pose, zup2ydown, transl)
+                pose = pose.reshape(-1, 72)
+
+                # Create SMPL mesh
+                gender = str(data['gender'])
+                if not gender in ['male', 'female', 'neutral']:
+                    if 'female' in gender: gender = 'female'
+                    elif 'neutral' in gender: gender = 'neutral'
+                    elif 'male' in gender: gender = 'male'
+
+                output = smpl_dict[gender](body_pose=pose[:, 3:],
+                                           global_orient=pose[:, :3],
+                                           betas=betas,
+                                           transl=transl)
+                vertices = output.vertices
+
+                # Assume motion starts with 0-height
+                init_height = vertices[0].max(0)[0][1]
+                transl[:, 1] = transl[:, 1] + init_height
+                vertices[:, :, 1] = vertices[:, :, 1] - init_height
+
+                # Append data
+                processed_data['pose'].append(pose.numpy())
+                processed_data['betas'].append(betas.numpy())
+                processed_data['transl'].append(transl.numpy())
+                processed_data['vid'].append(np.array([f'{seq}_{subj}_{act}'] * pose.shape[0]))
+
+    for key, val in processed_data.items():
+        processed_data[key] = np.concatenate(val)
+
+    joblib.dump(processed_data, _C.PATHS.AMASS_LABEL)
+    print('\nDone!')
+
+if __name__ == '__main__':
+    out_path = '/'.join(_C.PATHS.AMASS_LABEL.split('/')[:-1])
+    os.makedirs(out_path, exist_ok=True)
+
+    process_amass()
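process_amass() decimates rather than interpolates: it keeps every retain_freq-th frame, with retain_freq rounded to the nearest integer ratio of the mocap rate to 30 fps. A quick standalone check of that arithmetic:

# Rounding-based frame decimation, as in process_amass() above.
for mocap_framerate in (59.97, 100.0, 120.0, 250.0):
    retain_freq = int(mocap_framerate / 30 + 0.5)
    print(f'{mocap_framerate} Hz -> keep every {retain_freq}th frame')
    # 59.97 -> 2, 100.0 -> 3, 120.0 -> 4, 250.0 -> 8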
lib/data_utils/emdb_eval_utils.py
ADDED
@@ -0,0 +1,189 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import os
+import os.path as osp
+from glob import glob
+from collections import defaultdict
+
+import cv2
+import torch
+import pickle
+import joblib
+import argparse
+import numpy as np
+from loguru import logger
+from progress.bar import Bar
+
+from configs import constants as _C
+from lib.models.smpl import SMPL
+from lib.models.preproc.extractor import FeatureExtractor
+from lib.models.preproc.backbone.utils import process_image
+from lib.utils import transforms
+from lib.utils.imutils import (
+    flip_kp, flip_bbox
+)
+
+dataset = defaultdict(list)
+detection_results_dir = 'dataset/detection_results/EMDB'
+
+def is_dset(emdb_pkl_file, dset):
+    target_dset = 'emdb' + dset
+    with open(emdb_pkl_file, "rb") as f:
+        data = pickle.load(f)
+    return data[target_dset]
+
+@torch.no_grad()
+def preprocess(dset, batch_size):
+
+    tt = lambda x: torch.from_numpy(x).float()
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    save_pth = osp.join(_C.PATHS.PARSED_DATA, f'emdb_{dset}_vit.pth')  # Use ViT feature extractor
+    extractor = FeatureExtractor(device, flip_eval=True, max_batch_size=batch_size)
+
+    all_emdb_pkl_files = sorted(glob(os.path.join(_C.PATHS.EMDB_PTH, "*/*/*_data.pkl")))
+    emdb_sequence_roots = []
+    both = []
+    for emdb_pkl_file in all_emdb_pkl_files:
+        if is_dset(emdb_pkl_file, dset):
+            emdb_sequence_roots.append(os.path.dirname(emdb_pkl_file))
+
+    smpl = {
+        'neutral': SMPL(model_path=_C.BMODEL.FLDR),
+        'male': SMPL(model_path=_C.BMODEL.FLDR, gender='male'),
+        'female': SMPL(model_path=_C.BMODEL.FLDR, gender='female'),
+    }
+
+    for sequence in emdb_sequence_roots:
+        subj, seq = sequence.split('/')[-2:]
+        annot_pth = glob(osp.join(sequence, '*_data.pkl'))[0]
+        annot = pickle.load(open(annot_pth, 'rb'))
+
+        # Get ground truth data
+        gender = annot['gender']
+        masks = annot['good_frames_mask']
+        poses_body = annot["smpl"]["poses_body"]
+        poses_root = annot["smpl"]["poses_root"]
+        betas = np.repeat(annot["smpl"]["betas"].reshape((1, -1)), repeats=annot["n_frames"], axis=0)
+        extrinsics = annot["camera"]["extrinsics"]
+        width, height = annot['camera']['width'], annot['camera']['height']
+        xyxys = annot['bboxes']['bboxes']
+
+        # Map to camera coordinate
+        poses_root_cam = transforms.matrix_to_axis_angle(tt(extrinsics[:, :3, :3]) @ transforms.axis_angle_to_matrix(tt(poses_root)))
+        poses = np.concatenate([poses_root_cam, poses_body], axis=-1)
+
+        pred_kp2d = np.load(osp.join(detection_results_dir, f'{subj}_{seq}.npy'))
+
+        # ======== Extract features ======== #
+        imname_list = sorted(glob(osp.join(sequence, 'images/*')))
+        bboxes, frame_ids, patch_list, features, flipped_features = [], [], [], [], []
+        bar = Bar(f'Load images', fill='#', max=len(imname_list))
+        for idx, (imname, xyxy, mask) in enumerate(zip(imname_list, xyxys, masks)):
+            if not mask: continue
+
+            # ========= Load image ========= #
+            img_rgb = cv2.cvtColor(cv2.imread(imname), cv2.COLOR_BGR2RGB)
+
+            # ========= Load bbox ========= #
+            x1, y1, x2, y2 = xyxy
+            bbox = np.array([(x1 + x2)/2., (y1 + y2)/2., max(x2 - x1, y2 - y1) / 1.1])
+
+            # ========= Process image ========= #
+            norm_img, crop_img = process_image(img_rgb, bbox[:2], bbox[2] / 200, 256, 256)
+
+            patch_list.append(torch.from_numpy(norm_img).unsqueeze(0).float())
+            bboxes.append(bbox)
+            frame_ids.append(idx)
+            bar.next()
+
+        patch_list = torch.split(torch.cat(patch_list), batch_size)
+        bboxes = torch.from_numpy(np.stack(bboxes)).float()
+        for i, patch in enumerate(patch_list):
+            bbox = bboxes[i*batch_size:min((i+1)*batch_size, len(frame_ids))].float().cuda()
+            bbox_center = bbox[:, :2]
+            bbox_scale = bbox[:, 2] / 200
+
+            feature = extractor.model(patch.cuda(), encode=True)
+            features.append(feature.cpu())
+
+            flipped_feature = extractor.model(torch.flip(patch, (3, )).cuda(), encode=True)
+            flipped_features.append(flipped_feature.cpu())
+
+            if i == 0:
+                init_patch = patch[[0]].clone()
+
+        features = torch.cat(features)
+        flipped_features = torch.cat(flipped_features)
+        res_h, res_w = img_rgb.shape[:2]
+
+        # ======== Append data ======== #
+        dataset['gender'].append(gender)
+        dataset['bbox'].append(bboxes)
+        dataset['res'].append(torch.tensor([[width, height]]).repeat(len(frame_ids), 1).float())
+        dataset['vid'].append(f'{subj}_{seq}')
+        dataset['pose'].append(tt(poses)[frame_ids])
+        dataset['betas'].append(tt(betas)[frame_ids])
+        dataset['kp2d'].append(tt(pred_kp2d)[frame_ids])
+        dataset['frame_id'].append(torch.from_numpy(np.array(frame_ids)))
+        dataset['cam_poses'].append(tt(extrinsics)[frame_ids])
+        dataset['features'].append(features)
+        dataset['flipped_features'].append(flipped_features)
+
+        # Flipped data
+        dataset['flipped_bbox'].append(
+            torch.from_numpy(flip_bbox(dataset['bbox'][-1].clone().numpy(), res_w, res_h)).float()
+        )
+        dataset['flipped_kp2d'].append(
+            torch.from_numpy(flip_kp(dataset['kp2d'][-1].clone().numpy(), res_w)).float()
+        )
+        # ======== Append data ======== #
+
+        # Pad 1 frame
+        for key, val in dataset.items():
+            if isinstance(val[-1], torch.Tensor):
+                dataset[key][-1] = torch.cat((val[-1][:1].clone(), val[-1][:]), dim=0)
+
+        # Initial predictions
+        bbox = bboxes[:1].clone().cuda()
+        bbox_center = bbox[:, :2].clone()
+        bbox_scale = bbox[:, 2].clone() / 200
+        kwargs = {'img_w': torch.tensor(res_w).repeat(1).float().cuda(),
+                  'img_h': torch.tensor(res_h).repeat(1).float().cuda(),
+                  'bbox_center': bbox_center, 'bbox_scale': bbox_scale}
+
+        pred_global_orient, pred_pose, pred_shape, _ = extractor.model(init_patch.cuda(), **kwargs)
+        pred_output = smpl['neutral'].get_output(global_orient=pred_global_orient.cpu(),
+                                                 body_pose=pred_pose.cpu(),
+                                                 betas=pred_shape.cpu(),
+                                                 pose2rot=False)
+        init_kp3d = pred_output.joints
+        init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1))
+
+        dataset['init_kp3d'].append(init_kp3d)
+        dataset['init_pose'].append(init_pose.cpu())
+
+        # Flipped initial predictions
+        bbox_center[:, 0] = res_w - bbox_center[:, 0]
+        pred_global_orient, pred_pose, pred_shape, _ = extractor.model(torch.flip(init_patch, (3, )).cuda(), **kwargs)
+        pred_output = smpl['neutral'].get_output(global_orient=pred_global_orient.cpu(),
+                                                 body_pose=pred_pose.cpu(),
+                                                 betas=pred_shape.cpu(),
+                                                 pose2rot=False)
+        init_kp3d = pred_output.joints
+        init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1))
+
+        dataset['flipped_init_kp3d'].append(init_kp3d)
+        dataset['flipped_init_pose'].append(init_pose.cpu())
+
+    joblib.dump(dataset, save_pth)
+    logger.info(f'==> Done !')
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-s', '--split', type=str, choices=['1', '2'], help='Data split')
+    parser.add_argument('-b', '--batch_size', type=int, default=128, help='Batch size')
+    args = parser.parse_args()
+
+    preprocess(args.split, args.batch_size)
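The "Map to camera coordinate" step above is a per-frame left-multiplication of the world-frame root rotation by the camera rotation, done before converting back to axis-angle. The same algebra with plain tensors and dummy data (no repo dependencies; values are illustrative):

import torch

T = 5
extrinsics = torch.eye(4).repeat(T, 1, 1)     # dummy world-to-camera extrinsics
R_root_world = torch.eye(3).repeat(T, 1, 1)   # dummy root rotations in world frame

# Same composition as above: R_cam = E[:, :3, :3] @ R_world.
R_root_cam = extrinsics[:, :3, :3] @ R_root_world
assert R_root_cam.shape == (T, 3, 3)          # identity cameras leave rotations unchanged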
lib/data_utils/rich_eval_utils.py
ADDED
@@ -0,0 +1,69 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import os
+import os.path as osp
+from glob import glob
+from collections import defaultdict
+
+import cv2
+import torch
+import pickle
+import joblib
+import argparse
+import numpy as np
+from loguru import logger
+from progress.bar import Bar
+
+from configs import constants as _C
+from lib.models.smpl import SMPL
+from lib.models.preproc.extractor import FeatureExtractor
+from lib.models.preproc.backbone.utils import process_image
+from lib.utils import transforms
+from lib.utils.imutils import (
+    flip_kp, flip_bbox
+)
+
+dataset = defaultdict(list)
+detection_results_dir = 'dataset/detection_results/RICH'
+
+def extract_cam_param_xml(xml_path='', dtype=torch.float32):
+
+    import xml.etree.ElementTree as ET
+    tree = ET.parse(xml_path)
+
+    extrinsics_mat = [float(s) for s in tree.find('./CameraMatrix/data').text.split()]
+    intrinsics_mat = [float(s) for s in tree.find('./Intrinsics/data').text.split()]
+    # distortion_vec = [float(s) for s in tree.find('./Distortion/data').text.split()]
+
+    focal_length_x = intrinsics_mat[0]
+    focal_length_y = intrinsics_mat[4]
+    center = torch.tensor([[intrinsics_mat[2], intrinsics_mat[5]]], dtype=dtype)
+
+    rotation = torch.tensor([[extrinsics_mat[0], extrinsics_mat[1], extrinsics_mat[2]],
+                             [extrinsics_mat[4], extrinsics_mat[5], extrinsics_mat[6]],
+                             [extrinsics_mat[8], extrinsics_mat[9], extrinsics_mat[10]]], dtype=dtype)
+
+    translation = torch.tensor([[extrinsics_mat[3], extrinsics_mat[7], extrinsics_mat[11]]], dtype=dtype)
+
+    # t = -Rc --> c = -R^Tt
+    cam_center = [-extrinsics_mat[0]*extrinsics_mat[3] - extrinsics_mat[4]*extrinsics_mat[7] - extrinsics_mat[8]*extrinsics_mat[11],
+                  -extrinsics_mat[1]*extrinsics_mat[3] - extrinsics_mat[5]*extrinsics_mat[7] - extrinsics_mat[9]*extrinsics_mat[11],
+                  -extrinsics_mat[2]*extrinsics_mat[3] - extrinsics_mat[6]*extrinsics_mat[7] - extrinsics_mat[10]*extrinsics_mat[11]]
+
+    cam_center = torch.tensor([cam_center], dtype=dtype)
+
+    return focal_length_x, focal_length_y, center, rotation, translation, cam_center
+
+@torch.no_grad()
+def preprocess(dset, batch_size):
+    import pdb; pdb.set_trace()
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-s', '--split', type=str, choices=['1', '2'], help='Data split')
+    parser.add_argument('-b', '--batch_size', type=int, default=128, help='Batch size')
+    args = parser.parse_args()
+
+    preprocess(args.split, args.batch_size)
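extract_cam_param_xml recovers the camera center from a row-major 3x4 [R|t] via t = -Rc, i.e. c = -R^T t; the three hand-expanded dot products above are exactly that product. A sketch with a dummy matrix (values are illustrative only, not a real rotation):

import torch

P = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # dummy row-major [R|t]
R, t = P[:, :3], P[:, 3]
cam_center = -R.T @ t   # equals the expanded sums in the function above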
lib/data_utils/threedpw_eval_utils.py
ADDED
@@ -0,0 +1,185 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import os.path as osp
+from glob import glob
+from collections import defaultdict
+
+import cv2
+import torch
+import pickle
+import joblib
+import argparse
+import numpy as np
+from loguru import logger
+from progress.bar import Bar
+
+from configs import constants as _C
+from lib.models.smpl import SMPL
+from lib.models.preproc.extractor import FeatureExtractor
+from lib.models.preproc.backbone.utils import process_image
+from lib.utils import transforms
+from lib.utils.imutils import (
+    flip_kp, flip_bbox
+)
+
+
+dataset = defaultdict(list)
+detection_results_dir = 'dataset/detection_results/3DPW'
+tcmr_annot_pth = 'dataset/parsed_data/TCMR_preproc/3dpw_dset_db.pt'
+
+@torch.no_grad()
+def preprocess(dset, batch_size):
+
+    if dset == 'val': _dset = 'validation'
+    else: _dset = dset
+
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    save_pth = osp.join(_C.PATHS.PARSED_DATA, f'3pdw_{dset}_vit.pth')  # Use ViT feature extractor
+    extractor = FeatureExtractor(device, flip_eval=True, max_batch_size=batch_size)
+
+    tcmr_data = joblib.load(tcmr_annot_pth.replace('dset', dset))
+    smpl_neutral = SMPL(model_path=_C.BMODEL.FLDR)
+
+    annot_file_list, idxs = np.unique(tcmr_data['vid_name'], return_index=True)
+    idxs = idxs.tolist()
+    annot_file_list = [annot_file_list[idxs.index(idx)] for idx in sorted(idxs)]
+    annot_file_list = [osp.join(_C.PATHS.THREEDPW_PTH, 'sequenceFiles', _dset, annot_file[:-2] + '.pkl') for annot_file in annot_file_list]
+    annot_file_list = list(dict.fromkeys(annot_file_list))
+
+    for annot_file in annot_file_list:
+        seq = annot_file.split('/')[-1].split('.')[0]
+
+        data = pickle.load(open(annot_file, 'rb'), encoding='latin1')
+
+        num_people = len(data['poses'])
+        num_frames = len(data['img_frame_ids'])
+        assert (data['poses2d'][0].shape[0] == num_frames)
+
+        K = torch.from_numpy(data['cam_intrinsics']).unsqueeze(0).float()
+
+        for p_id in range(num_people):
+
+            logger.info(f'==> {seq} {p_id}')
+            gender = {'m': 'male', 'f': 'female'}[data['genders'][p_id]]
+
+            # ======== Add TCMR data ======== #
+            vid_name = f'{seq}_{p_id}'
+            tcmr_ids = [i for i, v in enumerate(tcmr_data['vid_name']) if vid_name in v]
+            frame_ids = tcmr_data['frame_id'][tcmr_ids]
+
+            pose = torch.from_numpy(data['poses'][p_id]).float()[frame_ids]
+            shape = torch.from_numpy(data['betas'][p_id][:10]).float().repeat(pose.size(0), 1)
+            pose = torch.from_numpy(tcmr_data['pose'][tcmr_ids]).float()  # Camera coordinate
+            cam_poses = torch.from_numpy(data['cam_poses'][frame_ids]).float()
+
+            # ======== Get detection results ======== #
+            fname = f'{seq}_{p_id}.npy'
+            pred_kp2d = torch.from_numpy(
+                np.load(osp.join(detection_results_dir, fname))
+            ).float()[frame_ids]
+            # ======== Get detection results ======== #
+
+            img_paths = sorted(glob(osp.join(_C.PATHS.THREEDPW_PTH, 'imageFiles', seq, '*.jpg')))
+            img_paths = [img_path for i, img_path in enumerate(img_paths) if i in frame_ids]
+            img = cv2.imread(img_paths[0]); res_h, res_w = img.shape[:2]
+            vid_idxs = fname.split('.')[0]
+
+            # ======== Append data ======== #
+            dataset['gender'].append(gender)
+            dataset['vid'].append(vid_idxs)
+            dataset['pose'].append(pose)
+            dataset['betas'].append(shape)
+            dataset['cam_poses'].append(cam_poses)
+            dataset['frame_id'].append(torch.from_numpy(frame_ids))
+            dataset['res'].append(torch.tensor([[res_w, res_h]]).repeat(len(frame_ids), 1).float())
+            dataset['bbox'].append(torch.from_numpy(tcmr_data['bbox'][tcmr_ids].copy()).float())
+            dataset['kp2d'].append(pred_kp2d)
+
+            # Flipped data
+            dataset['flipped_bbox'].append(
+                torch.from_numpy(flip_bbox(dataset['bbox'][-1].clone().numpy(), res_w, res_h)).float()
+            )
+            dataset['flipped_kp2d'].append(
+                torch.from_numpy(flip_kp(dataset['kp2d'][-1].clone().numpy(), res_w)).float()
+            )
+            # ======== Append data ======== #
+
+            # ======== Extract features ======== #
+            patch_list = []
+            bboxes = dataset['bbox'][-1].clone().numpy()
+            bar = Bar(f'Load images', fill='#', max=len(img_paths))
+
+            for img_path, bbox in zip(img_paths, bboxes):
+                img_rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
+                norm_img, crop_img = process_image(img_rgb, bbox[:2], bbox[2] / 200, 256, 256)
+                patch_list.append(torch.from_numpy(norm_img).unsqueeze(0).float())
+                bar.next()
+
+            patch_list = torch.split(torch.cat(patch_list), batch_size)
+            features, flipped_features = [], []
+            for i, patch in enumerate(patch_list):
+                feature = extractor.model(patch.cuda(), encode=True)
+                features.append(feature.cpu())
+
+                flipped_feature = extractor.model(torch.flip(patch, (3, )).cuda(), encode=True)
+                flipped_features.append(flipped_feature.cpu())
+
+                if i == 0:
+                    init_patch = patch[[0]].clone()
+
+            features = torch.cat(features)
+            flipped_features = torch.cat(flipped_features)
+            dataset['features'].append(features)
+            dataset['flipped_features'].append(flipped_features)
+            # ======== Extract features ======== #
+
+            # Pad 1 frame
+            for key, val in dataset.items():
+                if isinstance(val[-1], torch.Tensor):
+                    dataset[key][-1] = torch.cat((val[-1][:1].clone(), val[-1][:]), dim=0)
+
+            # Initial predictions
+            bbox = torch.from_numpy(bboxes[:1].copy()).float().cuda()
+            bbox_center = bbox[:, :2].clone()
+            bbox_scale = bbox[:, 2].clone() / 200
+            kwargs = {'img_w': torch.tensor(res_w).repeat(1).float().cuda(),
+                      'img_h': torch.tensor(res_h).repeat(1).float().cuda(),
+                      'bbox_center': bbox_center, 'bbox_scale': bbox_scale}
+
+            pred_global_orient, pred_pose, pred_shape, _ = extractor.model(init_patch.cuda(), **kwargs)
+            pred_output = smpl_neutral.get_output(global_orient=pred_global_orient.cpu(),
+                                                  body_pose=pred_pose.cpu(),
+                                                  betas=pred_shape.cpu(),
+                                                  pose2rot=False)
+            init_kp3d = pred_output.joints
+            init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1))
+
+            dataset['init_kp3d'].append(init_kp3d)
+            dataset['init_pose'].append(init_pose.cpu())
+
+            # Flipped initial predictions
+            bbox_center[:, 0] = res_w - bbox_center[:, 0]
+            pred_global_orient, pred_pose, pred_shape, _ = extractor.model(torch.flip(init_patch, (3, )).cuda(), **kwargs)
+            pred_output = smpl_neutral.get_output(global_orient=pred_global_orient.cpu(),
+                                                  body_pose=pred_pose.cpu(),
+                                                  betas=pred_shape.cpu(),
+                                                  pose2rot=False)
+            init_kp3d = pred_output.joints
+            init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1))
+
+            dataset['flipped_init_kp3d'].append(init_kp3d)
+            dataset['flipped_init_pose'].append(init_pose.cpu())
+
+    joblib.dump(dataset, save_pth)
+    logger.info(f'\n ==> Done !')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-s', '--split', type=str, choices=['val', 'test'], help='Data split')
+    parser.add_argument('-b', '--batch_size', type=int, default=128, help='Batch size')
+    args = parser.parse_args()
+
+    preprocess(args.split, args.batch_size)
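Once the TCMR annotations and detection results sit under the hard-coded dataset/ layout, the script is driven by its __main__ block; an equivalent programmatic call would look like the sketch below (the module import path is an assumption based on the file location in this commit):

# Hypothetical driver; writes dataset/parsed_data/3pdw_test_vit.pth.
from lib.data_utils.threedpw_eval_utils import preprocess

preprocess('test', batch_size=128)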
lib/data_utils/threedpw_train_utils.py
ADDED
@@ -0,0 +1,146 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import os.path as osp
+from glob import glob
+from collections import defaultdict
+
+import cv2
+import torch
+import pickle
+import joblib
+import argparse
+import numpy as np
+from loguru import logger
+from progress.bar import Bar
+
+from configs import constants as _C
+from lib.models.smpl import SMPL
+from lib.models.preproc.extractor import FeatureExtractor
+from lib.models.preproc.backbone.utils import process_image
+
+dataset = defaultdict(list)
+detection_results_dir = 'dataset/detection_results/3DPW'
+tcmr_annot_pth = 'dataset/parsed_data/TCMR_preproc/3dpw_train_db.pt'
+
+
+@torch.no_grad()
+def preprocess(batch_size):
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    save_pth = osp.join(_C.PATHS.PARSED_DATA, f'3pdw_train_vit.pth')  # Use ViT feature extractor
+    extractor = FeatureExtractor(device, flip_eval=True, max_batch_size=batch_size)
+
+    tcmr_data = joblib.load(tcmr_annot_pth)
+
+    annot_file_list, idxs = np.unique(tcmr_data['vid_name'], return_index=True)
+    idxs = idxs.tolist()
+    annot_file_list = [annot_file_list[idxs.index(idx)] for idx in sorted(idxs)]
+    annot_file_list = [osp.join(_C.PATHS.THREEDPW_PTH, 'sequenceFiles', 'train', annot_file[:-2] + '.pkl') for annot_file in annot_file_list]
+    annot_file_list = list(dict.fromkeys(annot_file_list))
+
+    vid_idx = 0
+    for annot_file in annot_file_list:
+        seq = annot_file.split('/')[-1].split('.')[0]
+
+        data = pickle.load(open(annot_file, 'rb'), encoding='latin1')
+
+        num_people = len(data['poses'])
+        num_frames = len(data['img_frame_ids'])
+        assert (data['poses2d'][0].shape[0] == num_frames)
+
+        K = torch.from_numpy(data['cam_intrinsics']).unsqueeze(0).float()
+
+        for p_id in range(num_people):
+
+            logger.info(f'==> {seq} {p_id}')
+            gender = {'m': 'male', 'f': 'female'}[data['genders'][p_id]]
+            smpl_gender = SMPL(model_path=_C.BMODEL.FLDR, gender=gender)
+
+            # ======== Add TCMR data ======== #
+            vid_name = f'{seq}_{p_id}'
+            tcmr_ids = [i for i, v in enumerate(tcmr_data['vid_name']) if vid_name in v]
+            frame_ids = tcmr_data['frame_id'][tcmr_ids]
+
+            pose = torch.from_numpy(data['poses'][p_id]).float()[frame_ids]
+            shape = torch.from_numpy(data['betas'][p_id][:10]).float().repeat(pose.size(0), 1)
+            trans = torch.from_numpy(data['trans'][p_id]).float()[frame_ids]
+            cam_poses = torch.from_numpy(data['cam_poses'][frame_ids]).float()
+
+            # ======== Align the mesh params ======== #
+            Rc = cam_poses[:, :3, :3]
+            Tc = cam_poses[:, :3, 3]
+            org_output = smpl_gender.get_output(betas=shape, body_pose=pose[:, 3:], global_orient=pose[:, :3], transl=trans)
+            org_v0 = (org_output.vertices + org_output.offset.unsqueeze(1)).mean(1)
+            pose = torch.from_numpy(tcmr_data['pose'][tcmr_ids]).float()
+
+            output = smpl_gender.get_output(betas=shape, body_pose=pose[:, 3:], global_orient=pose[:, :3])
+            v0 = (output.vertices + output.offset.unsqueeze(1)).mean(1)
+            trans = (Rc @ org_v0.reshape(-1, 3, 1)).reshape(-1, 3) + Tc - v0
+            j3d = output.joints + (output.offset + trans).unsqueeze(1)
+            j2d = torch.div(j3d, j3d[..., 2:])
+            kp2d = torch.matmul(K, j2d.transpose(-1, -2)).transpose(-1, -2)[..., :2]
+            # ======== Align the mesh params ======== #
+
+            # ======== Get detection results ======== #
+            fname = f'{seq}_{p_id}.npy'
+            pred_kp2d = torch.from_numpy(
+                np.load(osp.join(detection_results_dir, fname))
+            ).float()[frame_ids]
+            # ======== Get detection results ======== #
+
+            img_paths = sorted(glob(osp.join(_C.PATHS.THREEDPW_PTH, 'imageFiles', seq, '*.jpg')))
+            img_paths = [img_path for i, img_path in enumerate(img_paths) if i in frame_ids]
+            img = cv2.imread(img_paths[0]); res_h, res_w = img.shape[:2]
+            vid_idxs = torch.from_numpy(np.array([vid_idx] * len(img_paths)).astype(int))
+            vid_idx += 1
+
+            # ======== Append data ======== #
+            dataset['bbox'].append(torch.from_numpy(tcmr_data['bbox'][tcmr_ids].copy()).float())
+            dataset['res'].append(torch.tensor([[res_w, res_h]]).repeat(len(frame_ids), 1).float())
+            dataset['vid'].append(vid_idxs)
+            dataset['pose'].append(pose)
+            dataset['betas'].append(shape)
+            dataset['transl'].append(trans)
+            dataset['kp2d'].append(pred_kp2d)
+            dataset['joints3D'].append(j3d)
+            dataset['joints2D'].append(kp2d)
+            dataset['frame_id'].append(torch.from_numpy(frame_ids))
+            dataset['cam_poses'].append(cam_poses)
+            dataset['gender'].append(torch.tensor([['male', 'female'].index(gender)]).repeat(len(frame_ids)))
+            # ======== Append data ======== #
+
+            # ======== Extract features ======== #
+            patch_list = []
+            bboxes = dataset['bbox'][-1].clone().numpy()
+            bar = Bar(f'Load images', fill='#', max=len(img_paths))
+
+            for img_path, bbox in zip(img_paths, bboxes):
+                img_rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
+                norm_img, crop_img = process_image(img_rgb, bbox[:2], bbox[2] / 200, 256, 256)
+                patch_list.append(torch.from_numpy(norm_img).unsqueeze(0).float())
+                bar.next()
+
+            patch_list = torch.split(torch.cat(patch_list), batch_size)
+            features = []
+            for i, patch in enumerate(patch_list):
+                pred = extractor.model(patch.cuda(), encode=True)
+                features.append(pred.cpu())
+
+            features = torch.cat(features)
+            dataset['features'].append(features)
+            # ======== Extract features ======== #
+
+    for key in dataset.keys():
+        dataset[key] = torch.cat(dataset[key])
+
+    joblib.dump(dataset, save_pth)
+    logger.info(f'\n ==> Done !')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-b', '--batch_size', type=int, default=128, help='Batch size')
+    args = parser.parse_args()
+
+    preprocess(args.batch_size)
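The ground-truth 2D joints above come from a standard pinhole projection: a perspective divide by depth followed by the intrinsics, exactly the two lines computing j2d and kp2d. A self-contained check with a dummy camera (illustrative values only):

import torch

K = torch.tensor([[[1000., 0., 960.],
                   [0., 1000., 540.],
                   [0., 0., 1.]]])                # dummy intrinsics, (1, 3, 3)
j3d = torch.tensor([[[0.0, 0.0, 2.0],             # point on the optical axis
                     [0.5, 0.0, 2.0]]])           # (1, J, 3), camera coordinates
j2d = torch.div(j3d, j3d[..., 2:])                # perspective divide
kp2d = torch.matmul(K, j2d.transpose(-1, -2)).transpose(-1, -2)[..., :2]
print(kp2d)  # tensor([[[ 960., 540.], [1210., 540.]]])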
lib/eval/eval_utils.py
ADDED
@@ -0,0 +1,482 @@
+# Some functions are borrowed from https://github.com/akanazawa/human_dynamics/blob/master/src/evaluation/eval_util.py
+# Adhere to their licence to use these functions
+from pathlib import Path
+
+import torch
+import numpy as np
+from matplotlib import pyplot as plt
+
+
+def compute_accel(joints):
+    """
+    Computes acceleration of 3D joints.
+    Args:
+        joints (Nx25x3).
+    Returns:
+        Accelerations (N-2).
+    """
+    velocities = joints[1:] - joints[:-1]
+    acceleration = velocities[1:] - velocities[:-1]
+    acceleration_normed = np.linalg.norm(acceleration, axis=2)
+    return np.mean(acceleration_normed, axis=1)
+
+
+def compute_error_accel(joints_gt, joints_pred, vis=None):
+    """
+    Computes acceleration error:
+        1/(n-2) \sum_{i=1}^{n-1} ||X_{i-1} - 2X_i + X_{i+1}||
+    Note that for each frame that is not visible, three entries in the
+    acceleration error should be zero'd out.
+    Args:
+        joints_gt (Nx14x3).
+        joints_pred (Nx14x3).
+        vis (N).
+    Returns:
+        error_accel (N-2).
+    """
+    # (N-2)x14x3
+    accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:]
+    accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:]
+
+    normed = np.linalg.norm(accel_pred - accel_gt, axis=2)
+
+    if vis is None:
+        new_vis = np.ones(len(normed), dtype=bool)
+    else:
+        invis = np.logical_not(vis)
+        invis1 = np.roll(invis, -1)
+        invis2 = np.roll(invis, -2)
+        new_invis = np.logical_or(invis, np.logical_or(invis1, invis2))[:-2]
+        new_vis = np.logical_not(new_invis)
+
+    return np.mean(normed[new_vis], axis=1)
+
+
+def compute_error_verts(pred_verts, target_verts=None, target_theta=None):
+    """
+    Computes MPJPE over 6890 surface vertices.
+    Args:
+        verts_gt (Nx6890x3).
+        verts_pred (Nx6890x3).
+    Returns:
+        error_verts (N).
+    """
+
+    if target_verts is None:
+        from lib.models.smpl import SMPL_MODEL_DIR
+        from lib.models.smpl import SMPL
+        device = 'cpu'
+        smpl = SMPL(
+            SMPL_MODEL_DIR,
+            batch_size=1,  # target_theta.shape[0],
+        ).to(device)
+
+        betas = torch.from_numpy(target_theta[:, 75:]).to(device)
+        pose = torch.from_numpy(target_theta[:, 3:75]).to(device)
+
+        target_verts = []
+        b_ = torch.split(betas, 5000)
+        p_ = torch.split(pose, 5000)
+
+        for b, p in zip(b_, p_):
+            output = smpl(betas=b, body_pose=p[:, 3:], global_orient=p[:, :3], pose2rot=True)
+            target_verts.append(output.vertices.detach().cpu().numpy())
+
+        target_verts = np.concatenate(target_verts, axis=0)
+
+    assert len(pred_verts) == len(target_verts)
+    error_per_vert = np.sqrt(np.sum((target_verts - pred_verts) ** 2, axis=2))
+    return np.mean(error_per_vert, axis=1)
+
+
+def compute_similarity_transform(S1, S2):
+    '''
+    Computes a similarity transform (sR, t) that takes
+    a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
+    where R is an 3x3 rotation matrix, t 3x1 translation, s scale.
+    i.e. solves the orthogonal Procrustes problem.
+    '''
+    transposed = False
+    if S1.shape[0] != 3 and S1.shape[0] != 2:
+        S1 = S1.T
+        S2 = S2.T
+        transposed = True
+    assert (S2.shape[1] == S1.shape[1])
+
+    # 1. Remove mean.
+    mu1 = S1.mean(axis=1, keepdims=True)
+    mu2 = S2.mean(axis=1, keepdims=True)
+    X1 = S1 - mu1
+    X2 = S2 - mu2
+
+    # 2. Compute variance of X1 used for scale.
+    var1 = np.sum(X1**2)
+
+    # 3. The outer product of X1 and X2.
+    K = X1.dot(X2.T)
+
+    # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
+    # singular vectors of K.
+    U, s, Vh = np.linalg.svd(K)
+    V = Vh.T
+    # Construct Z that fixes the orientation of R to get det(R)=1.
+    Z = np.eye(U.shape[0])
+    Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
+    # Construct R.
+    R = V.dot(Z.dot(U.T))
+
+    # 5. Recover scale.
+    scale = np.trace(R.dot(K)) / var1
+
+    # 6. Recover translation.
+    t = mu2 - scale*(R.dot(mu1))
+
+    # 7. Error:
+    S1_hat = scale*R.dot(S1) + t
+
+    if transposed:
+        S1_hat = S1_hat.T
+
+    return S1_hat
+
+
+def compute_similarity_transform_torch(S1, S2):
+    '''
+    Computes a similarity transform (sR, t) that takes
+    a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
+    where R is an 3x3 rotation matrix, t 3x1 translation, s scale.
+    i.e. solves the orthogonal Procrustes problem.
+    '''
+    transposed = False
+    if S1.shape[0] != 3 and S1.shape[0] != 2:
+        S1 = S1.T
+        S2 = S2.T
+        transposed = True
+    assert (S2.shape[1] == S1.shape[1])
+
+    # 1. Remove mean.
+    mu1 = S1.mean(axis=1, keepdims=True)
+    mu2 = S2.mean(axis=1, keepdims=True)
+    X1 = S1 - mu1
+    X2 = S2 - mu2
+
+    # 2. Compute variance of X1 used for scale.
+    var1 = torch.sum(X1 ** 2)
+
+    # 3. The outer product of X1 and X2.
+    K = X1.mm(X2.T)
+
+    # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
+    # singular vectors of K.
+    U, s, V = torch.svd(K)
+    # Construct Z that fixes the orientation of R to get det(R)=1.
+    Z = torch.eye(U.shape[0], device=S1.device)
+    Z[-1, -1] *= torch.sign(torch.det(U @ V.T))
+    # Construct R.
+    R = V.mm(Z.mm(U.T))
+
+    # 5. Recover scale.
+    scale = torch.trace(R.mm(K)) / var1
+
+    # 6. Recover translation.
+    t = mu2 - scale * (R.mm(mu1))
+
+    # 7. Error:
+    S1_hat = scale * R.mm(S1) + t
+
+    if transposed:
+        S1_hat = S1_hat.T
+
+    return S1_hat
+
+
+def batch_compute_similarity_transform_torch(S1, S2):
+    '''
+    Computes a similarity transform (sR, t) that takes
+    a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
+    where R is an 3x3 rotation matrix, t 3x1 translation, s scale.
+    i.e. solves the orthogonal Procrustes problem.
+    '''
+    transposed = False
+    if S1.shape[0] != 3 and S1.shape[0] != 2:
+        S1 = S1.permute(0, 2, 1)
+        S2 = S2.permute(0, 2, 1)
+        transposed = True
+    assert (S2.shape[1] == S1.shape[1])
+
+    # 1. Remove mean.
+    mu1 = S1.mean(axis=-1, keepdims=True)
+    mu2 = S2.mean(axis=-1, keepdims=True)
+
+    X1 = S1 - mu1
+    X2 = S2 - mu2
+
+    # 2. Compute variance of X1 used for scale.
+    var1 = torch.sum(X1**2, dim=1).sum(dim=1)
+
+    # 3. The outer product of X1 and X2.
+    K = X1.bmm(X2.permute(0, 2, 1))
+
+    # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
+    # singular vectors of K.
+    U, s, V = torch.svd(K)
+
+    # Construct Z that fixes the orientation of R to get det(R)=1.
+    Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0)
+    Z = Z.repeat(U.shape[0], 1, 1)
+    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))
+
+    # Construct R.
+    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))
+
+    # 5. Recover scale.
+    scale = torch.cat([torch.trace(x).unsqueeze(0) for x in R.bmm(K)]) / var1
+
+    # 6. Recover translation.
+    t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1)))
+
+    # 7. Error:
+    S1_hat = scale.unsqueeze(-1).unsqueeze(-1) * R.bmm(S1) + t
+
+    if transposed:
+        S1_hat = S1_hat.permute(0, 2, 1)
+
+    return S1_hat
+
+
+def align_by_pelvis(joints):
+    """
+    Assumes joints is 14 x 3 in LSP order.
+    Then hips are: [3, 2]
+    Takes mid point of these points, then subtracts it.
+    """
+
+    left_id = 2
+    right_id = 3
+
+    pelvis = (joints[left_id, :] + joints[right_id, :]) / 2.0
+    return joints - np.expand_dims(pelvis, axis=0)
+
+
+def compute_errors(gt3ds, preds):
+    """
+    Gets MPJPE after pelvis alignment + MPJPE after Procrustes.
+    Evaluates on the 14 common joints.
+    Inputs:
+      - gt3ds: N x 14 x 3
+      - preds: N x 14 x 3
+    """
+    errors, errors_pa = [], []
+    for i, (gt3d, pred) in enumerate(zip(gt3ds, preds)):
+        gt3d = gt3d.reshape(-1, 3)
+        # Root align.
+        gt3d = align_by_pelvis(gt3d)
+        pred3d = align_by_pelvis(pred)
+
+        joint_error = np.sqrt(np.sum((gt3d - pred3d)**2, axis=1))
+        errors.append(np.mean(joint_error))
+
+        # Get PA error.
+        pred3d_sym = compute_similarity_transform(pred3d, gt3d)
+        pa_error = np.sqrt(np.sum((gt3d - pred3d_sym)**2, axis=1))
+        errors_pa.append(np.mean(pa_error))
+
+    return errors, errors_pa
+
+
+def batch_align_by_pelvis(data_list, pelvis_idxs):
+    """
+    Assumes data is given as [pred_j3d, target_j3d, pred_verts, target_verts].
+    Each data is in shape of (frames, num_points, 3)
+    Pelvis is notated as one / two joints indices.
+    Align all data to the corresponding pelvis location.
+    """
+
+    pred_j3d, target_j3d, pred_verts, target_verts = data_list
+
+    pred_pelvis = pred_j3d[:, pelvis_idxs].mean(dim=1, keepdims=True).clone()
+    target_pelvis = target_j3d[:, pelvis_idxs].mean(dim=1, keepdims=True).clone()
+
+    # Align to the pelvis
+    pred_j3d = pred_j3d - pred_pelvis
+    target_j3d = target_j3d - target_pelvis
+    pred_verts = pred_verts - pred_pelvis
+    target_verts = target_verts - target_pelvis
+
+    return (pred_j3d, target_j3d, pred_verts, target_verts)
+
+def compute_jpe(S1, S2):
+    return torch.sqrt(((S1 - S2) ** 2).sum(dim=-1)).mean(dim=-1).numpy()
+
+
+# The functions below are borrowed from SLAHMR official implementation.
+# Reference: https://github.com/vye16/slahmr/blob/main/slahmr/eval/tools.py
+def global_align_joints(gt_joints, pred_joints):
+    """
+    :param gt_joints (T, J, 3)
+    :param pred_joints (T, J, 3)
+    """
+    s_glob, R_glob, t_glob = align_pcl(
+        gt_joints.reshape(-1, 3), pred_joints.reshape(-1, 3)
+    )
+    pred_glob = (
+        s_glob * torch.einsum("ij,tnj->tni", R_glob, pred_joints) + t_glob[None, None]
+    )
+    return pred_glob
+
+
+def first_align_joints(gt_joints, pred_joints):
+    """
+    align the first two frames
+    :param gt_joints (T, J, 3)
+    :param pred_joints (T, J, 3)
+    """
+    # (1, 1), (1, 3, 3), (1, 3)
+    s_first, R_first, t_first = align_pcl(
+        gt_joints[:2].reshape(1, -1, 3), pred_joints[:2].reshape(1, -1, 3)
+    )
+    pred_first = (
+        s_first * torch.einsum("tij,tnj->tni", R_first, pred_joints) + t_first[:, None]
+    )
+    return pred_first
+
+
+def local_align_joints(gt_joints, pred_joints):
+    """
+    :param gt_joints (T, J, 3)
+    :param pred_joints (T, J, 3)
+    """
+    s_loc, R_loc, t_loc = align_pcl(gt_joints, pred_joints)
+    pred_loc = (
+        s_loc[:, None] * torch.einsum("tij,tnj->tni", R_loc, pred_joints)
+        + t_loc[:, None]
+    )
+    return pred_loc
+
+
+def align_pcl(Y, X, weight=None, fixed_scale=False):
+    """Align a similarity transform so that X' = s * R * X + t matches Y,
+    using the Umeyama method.
+    :param Y (*, N, 3) first trajectory
+    :param X (*, N, 3) second trajectory
+    :param weight (*, N, 1) optional weight of valid correspondences
+    :returns s (*, 1), R (*, 3, 3), t (*, 3)
+    """
+    *dims, N, _ = Y.shape
+    N = torch.ones(*dims, 1, 1) * N
+
+    if weight is not None:
+        Y = Y * weight
+        X = X * weight
+        N = weight.sum(dim=-2, keepdim=True)  # (*, 1, 1)
+
+    # subtract mean
+    my = Y.sum(dim=-2) / N[..., 0]  # (*, 3)
+    mx = X.sum(dim=-2) / N[..., 0]
+    y0 = Y - my[..., None, :]  # (*, N, 3)
+    x0 = X - mx[..., None, :]
+
+    if weight is not None:
+        y0 = y0 * weight
+        x0 = x0 * weight
+
+    # correlation
+    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)
+    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)
+
+    S = torch.eye(3).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1)
+    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
+    S[neg, 2, 2] = -1
+
+    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)
+
+    D = torch.diag_embed(D)  # (*, 3, 3)
+    if fixed_scale:
+        s = torch.ones(*dims, 1, device=Y.device, dtype=torch.float32)
+    else:
+        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)
+        s = (
+            torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
+                dim=-1, keepdim=True
+            )
+            / var[..., 0]
+        )  # (*, 1)
+
+    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)
+
+    return s, R, t
+
+
+def compute_foot_sliding(target_output, pred_output, masks, thr=1e-2):
+    """compute foot sliding error
+    The foot ground contact label is computed by the threshold of 1 cm/frame
+    Args:
+        target_output (SMPL ModelOutput).
+        pred_output (SMPL ModelOutput).
+        masks (N).
+    Returns:
+        error (N frames in contact).
+    """
+
+    # Foot vertices idxs
+    foot_idxs = [3216, 3387, 6617, 6787]
+
+    # Compute contact label
+    foot_loc = target_output.vertices[masks][:, foot_idxs]
+    foot_disp = (foot_loc[1:] - foot_loc[:-1]).norm(2, dim=-1)
+    contact = foot_disp[:] < thr
+
+    pred_feet_loc = pred_output.vertices[:, foot_idxs]
+    pred_disp = (pred_feet_loc[1:] - pred_feet_loc[:-1]).norm(2, dim=-1)
+
+    error = pred_disp[contact]
+
+    return error.cpu().numpy()
+
+
+def compute_jitter(pred_output, fps=30):
+    """compute jitter of the motion
+    Args:
+        pred_output (SMPL ModelOutput).
+        fps (float).
+    Returns:
+        jitter (N-3).
+    """
+
+    pred3d = pred_output.joints[:, :24]
+
+    pred_jitter = torch.norm(
+        (pred3d[3:] - 3 * pred3d[2:-1] + 3 * pred3d[1:-2] - pred3d[:-3]) * (fps**3),
+        dim=2,
+    ).mean(dim=-1)
+
+    return pred_jitter.cpu().numpy() / 10.0
+
+
+def compute_rte(target_trans, pred_trans):
+    # Compute the global alignment
+    _, rot, trans = align_pcl(target_trans[None, :], pred_trans[None, :], fixed_scale=True)
+    pred_trans_hat = (
+        torch.einsum("tij,tnj->tni", rot, pred_trans[None, :]) + trans[None, :]
+    )[0]
+
+    # Compute the entire displacement of ground truth trajectory
+    disps, disp = [], 0
+    for p1, p2 in zip(target_trans, target_trans[1:]):
+        delta = (p2 - p1).norm(2, dim=-1)
+        disp += delta
+        disps.append(disp)
+
+    # Compute absolute root-translation-error (RTE)
+    rte = torch.norm(target_trans - pred_trans_hat, 2, dim=-1)
+
+    # Normalize it to the displacement
+    return (rte / disp).numpy()
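Since batch_compute_similarity_transform_torch returns the aligned copy of S1, PA-MPJPE (as computed by the evaluation scripts later in this commit) is a one-liner on top of it. A sketch with synthetic data, where the target is a pure similarity transform of the prediction so the aligned error should be near zero (the import path assumes this module's location in the commit):

import torch
from lib.eval.eval_utils import batch_compute_similarity_transform_torch

pred_j3d = torch.randn(8, 14, 3)          # (frames, joints, 3)
target_j3d = 2.0 * pred_j3d + 0.1         # scale 2, identity rotation, small shift
S1_hat = batch_compute_similarity_transform_torch(pred_j3d, target_j3d)
pa_mpjpe = torch.sqrt(((S1_hat - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1)
print(pa_mpjpe.max())                     # ~0 once scale/rotation/translation are removed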
lib/eval/evaluate_3dpw.py
ADDED
@@ -0,0 +1,181 @@
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import os.path as osp
|
4 |
+
from glob import glob
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import imageio
|
9 |
+
import numpy as np
|
10 |
+
from smplx import SMPL
|
11 |
+
from loguru import logger
|
12 |
+
from progress.bar import Bar
|
13 |
+
|
14 |
+
from configs import constants as _C
|
15 |
+
from configs.config import parse_args
|
16 |
+
from lib.data.dataloader import setup_eval_dataloader
|
17 |
+
from lib.models import build_network, build_body_model
|
18 |
+
from lib.eval.eval_utils import (
|
19 |
+
compute_error_accel,
|
20 |
+
batch_align_by_pelvis,
|
21 |
+
batch_compute_similarity_transform_torch,
|
22 |
+
)
|
23 |
+
from lib.utils import transforms
|
24 |
+
from lib.utils.utils import prepare_output_dir
|
25 |
+
from lib.utils.utils import prepare_batch
|
26 |
+
from lib.utils.imutils import avg_preds
|
27 |
+
|
28 |
+
try:
|
29 |
+
from lib.vis.renderer import Renderer
|
30 |
+
_render = True
|
31 |
+
except:
|
32 |
+
print("PyTorch3D is not properly installed! Cannot render the SMPL mesh")
|
33 |
+
_render = False
|
34 |
+
|
35 |
+
|
36 |
+
m2mm = 1e3
|
37 |
+
@torch.no_grad()
|
38 |
+
def main(cfg, args):
|
39 |
+
torch.backends.cuda.matmul.allow_tf32 = False
|
40 |
+
torch.backends.cudnn.allow_tf32 = False
|
41 |
+
|
42 |
+
logger.info(f'GPU name -> {torch.cuda.get_device_name()}')
|
43 |
+
logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}')
|
44 |
+
|
45 |
+
# ========= Dataloaders ========= #
|
46 |
+
eval_loader = setup_eval_dataloader(cfg, '3dpw', 'test', cfg.MODEL.BACKBONE)
|
47 |
+
logger.info(f'Dataset loaded')
|
48 |
+
|
49 |
+
# ========= Load WHAM ========= #
|
50 |
+
smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN
|
51 |
+
smpl = build_body_model(cfg.DEVICE, smpl_batch_size)
|
52 |
+
network = build_network(cfg, smpl)
|
53 |
+
network.eval()
|
54 |
+
|
55 |
+
# Build SMPL models with each gender
|
56 |
+
smpl = {k: SMPL(_C.BMODEL.FLDR, gender=k).to(cfg.DEVICE) for k in ['male', 'female', 'neutral']}
|
57 |
+
|
58 |
+
# Load vertices -> joints regression matrix to evaluate
|
59 |
+
J_regressor_eval = torch.from_numpy(
|
60 |
+
np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M)
|
61 |
+
)[_C.KEYPOINTS.H36M_TO_J14, :].unsqueeze(0).float().to(cfg.DEVICE)
|
62 |
+
pelvis_idxs = [2, 3]
|
63 |
+
|
64 |
+
accumulator = defaultdict(list)
|
65 |
+
bar = Bar('Inference', fill='#', max=len(eval_loader))
|
66 |
+
with torch.no_grad():
|
67 |
+
for i in range(len(eval_loader)):
|
68 |
+
# Original batch
|
69 |
+
batch = eval_loader.dataset.load_data(i, False)
|
70 |
+
x, inits, features, kwargs, gt = prepare_batch(batch, cfg.DEVICE, cfg.TRAIN.STAGE=='stage2')
|
71 |
+
|
72 |
+
if cfg.FLIP_EVAL:
|
73 |
+
flipped_batch = eval_loader.dataset.load_data(i, True)
|
74 |
+                f_x, f_inits, f_features, f_kwargs, _ = prepare_batch(flipped_batch, cfg.DEVICE, cfg.TRAIN.STAGE == 'stage2')
+
+                # Forward pass with flipped input
+                flipped_pred = network(f_x, f_inits, f_features, **f_kwargs)
+
+            # Forward pass with normal input
+            pred = network(x, inits, features, **kwargs)
+
+            if cfg.FLIP_EVAL:
+                # Merge the two predictions
+                flipped_pose, flipped_shape = flipped_pred['pose'].squeeze(0), flipped_pred['betas'].squeeze(0)
+                pose, shape = pred['pose'].squeeze(0), pred['betas'].squeeze(0)
+                flipped_pose, pose = flipped_pose.reshape(-1, 24, 6), pose.reshape(-1, 24, 6)
+                avg_pose, avg_shape = avg_preds(pose, shape, flipped_pose, flipped_shape)
+                avg_pose = avg_pose.reshape(-1, 144)
+
+                # Refine trajectory with the merged prediction
+                network.pred_pose = avg_pose.view_as(network.pred_pose)
+                network.pred_shape = avg_shape.view_as(network.pred_shape)
+                pred = network.forward_smpl(**kwargs)
+
+            # <======= Build predicted SMPL
+            pred_output = smpl['neutral'](body_pose=pred['poses_body'],
+                                          global_orient=pred['poses_root_cam'],
+                                          betas=pred['betas'].squeeze(0),
+                                          pose2rot=False)
+            pred_verts = pred_output.vertices.cpu()
+            pred_j3d = torch.matmul(J_regressor_eval, pred_output.vertices).cpu()
+            # =======>
+
+            # <======= Build groundtruth SMPL
+            target_output = smpl[batch['gender']](
+                body_pose=transforms.rotation_6d_to_matrix(gt['pose'][0, :, 1:]),
+                global_orient=transforms.rotation_6d_to_matrix(gt['pose'][0, :, :1]),
+                betas=gt['betas'][0],
+                pose2rot=False)
+            target_verts = target_output.vertices.cpu()
+            target_j3d = torch.matmul(J_regressor_eval, target_output.vertices).cpu()
+            # =======>
+
+            # <======= Compute performance of the current sequence
+            pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
+                [pred_j3d, target_j3d, pred_verts, target_verts], pelvis_idxs
+            )
+            S1_hat = batch_compute_similarity_transform_torch(pred_j3d, target_j3d)
+            pa_mpjpe = torch.sqrt(((S1_hat - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm
+            mpjpe = torch.sqrt(((pred_j3d - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm
+            pve = torch.sqrt(((pred_verts - target_verts) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm
+            accel = compute_error_accel(joints_pred=pred_j3d, joints_gt=target_j3d)[1:-1]
+            accel = accel * (30 ** 2)  # convert from per frame^2 to per s^2 (30 fps)
+            # =======>
+
+            summary_string = f'{batch["vid"]} | PA-MPJPE: {pa_mpjpe.mean():.1f} MPJPE: {mpjpe.mean():.1f} PVE: {pve.mean():.1f}'
+            bar.suffix = summary_string
+            bar.next()
+
+            # <======= Accumulate the results over entire sequences
+            accumulator['pa_mpjpe'].append(pa_mpjpe)
+            accumulator['mpjpe'].append(mpjpe)
+            accumulator['pve'].append(pve)
+            accumulator['accel'].append(accel)
+            # =======>
+
+            # <======= (Optional) Render the prediction
+            if not (_render and args.render):
+                # Skip if PyTorch3D is not installed or the rendering argument is not passed.
+                continue
+
+            # Save path
+            viz_pth = osp.join('output', 'visualization')
+            os.makedirs(viz_pth, exist_ok=True)
+
+            # Build Renderer
+            width, height = batch['cam_intrinsics'][0][0, :2, -1].numpy() * 2
+            focal_length = batch['cam_intrinsics'][0][0, 0, 0].item()
+            renderer = Renderer(width, height, focal_length, cfg.DEVICE, smpl['neutral'].faces)
+
+            # Get images and writer
+            frame_list = batch['frame_id'][0].numpy()
+            imname_list = sorted(glob(osp.join(_C.PATHS.THREEDPW_PTH, 'imageFiles', batch['vid'][:-2], '*.jpg')))
+            writer = imageio.get_writer(osp.join(viz_pth, batch['vid'] + '.mp4'),
+                                        mode='I', format='FFMPEG', fps=30, macro_block_size=1)
+
+            # Skip the invalid frames
+            for i, frame in enumerate(frame_list):
+                image = imageio.imread(imname_list[frame])
+
+                # NOTE: pred['verts'] differs from pred_verts because the offset is subtracted from the SMPL mesh.
+                # Check line 121 in lib/models/smpl.py
+                vertices = pred['verts_cam'][i] + pred['trans_cam'][[i]]
+                image = renderer.render_mesh(vertices, image)
+                writer.append_data(image)
+            writer.close()
+            # =======>
+
+    for k, v in accumulator.items():
+        accumulator[k] = np.concatenate(v).mean()
+
+    print('')
+    log_str = 'Evaluation on 3DPW, '
+    log_str += ' '.join([f'{k.upper()}: {v:.4f},' for k, v in accumulator.items()])
+    logger.info(log_str)
+
+if __name__ == '__main__':
+    cfg, cfg_file, args = parse_args(test=True)
+    cfg = prepare_output_dir(cfg, cfg_file)
+
+    main(cfg, args)
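The evaluation scripts in this commit report PA-MPJPE through batch_compute_similarity_transform_torch, i.e. joint error measured after solving for the best similarity transform (scale, rotation, translation) between prediction and ground truth. Below is a minimal, self-contained sketch of that alignment, the classic Umeyama solution; it is shown for reference and is not necessarily the exact implementation in lib/eval/eval_utils.py.

import torch

def batch_procrustes_align(S1, S2):
    """Umeyama similarity alignment of S1 onto S2; both are (B, J, 3)."""
    B = S1.shape[0]
    mu1 = S1.mean(dim=1, keepdim=True)
    mu2 = S2.mean(dim=1, keepdim=True)
    X1, X2 = S1 - mu1, S2 - mu2
    var1 = (X1 ** 2).sum(dim=(1, 2))                       # total variance of S1

    K = X2.transpose(1, 2) @ X1                            # (B, 3, 3) cross-covariance
    U, s, Vh = torch.linalg.svd(K)

    d = torch.ones(B, 3, device=S1.device, dtype=S1.dtype)
    d[:, -1] = torch.sign(torch.linalg.det(U @ Vh))        # reflection correction
    R = U @ torch.diag_embed(d) @ Vh                       # optimal rotation

    scale = (s * d).sum(dim=1) / var1                      # optimal isotropic scale
    t = mu2 - scale.view(B, 1, 1) * (mu1 @ R.transpose(1, 2))
    return scale.view(B, 1, 1) * (S1 @ R.transpose(1, 2)) + t

if __name__ == '__main__':
    pred = torch.randn(2, 24, 3)
    target = 0.7 * pred + 0.1                              # scaled + shifted copy
    aligned = batch_procrustes_align(pred, target)
    print((aligned - target).norm(dim=-1).mean())          # ~0 after alignment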
lib/eval/evaluate_emdb.py
ADDED
@@ -0,0 +1,228 @@
+import os
+import time
+import os.path as osp
+from glob import glob
+from collections import defaultdict
+
+import torch
+import pickle
+import numpy as np
+from smplx import SMPL
+from loguru import logger
+from progress.bar import Bar
+
+from configs import constants as _C
+from configs.config import parse_args
+from lib.data.dataloader import setup_eval_dataloader
+from lib.models import build_network, build_body_model
+from lib.eval.eval_utils import (
+    compute_jpe,
+    compute_rte,
+    compute_jitter,
+    compute_error_accel,
+    compute_foot_sliding,
+    batch_align_by_pelvis,
+    first_align_joints,
+    global_align_joints,
+    batch_compute_similarity_transform_torch,
+)
+from lib.utils import transforms
+from lib.utils.utils import prepare_output_dir
+from lib.utils.utils import prepare_batch
+from lib.utils.imutils import avg_preds
+
+"""
+This is a tentative script to evaluate WHAM on the EMDB dataset.
+The current implementation requires the EMDB dataset to be downloaded at ./datasets/EMDB/
+"""
+
+m2mm = 1e3
+@torch.no_grad()
+def main(cfg, args):
+    torch.backends.cuda.matmul.allow_tf32 = False
+    torch.backends.cudnn.allow_tf32 = False
+
+    logger.info(f'GPU name -> {torch.cuda.get_device_name()}')
+    logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}')
+
+    # ========= Dataloaders ========= #
+    eval_loader = setup_eval_dataloader(cfg, 'emdb', args.eval_split, cfg.MODEL.BACKBONE)
+    logger.info('Dataset loaded')
+
+    # ========= Load WHAM ========= #
+    smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN
+    smpl = build_body_model(cfg.DEVICE, smpl_batch_size)
+    network = build_network(cfg, smpl)
+    network.eval()
+
+    # Build SMPL models for each gender
+    smpl = {k: SMPL(_C.BMODEL.FLDR, gender=k).to(cfg.DEVICE) for k in ['male', 'female', 'neutral']}
+
+    # Pelvis joint indices used for alignment
+    pelvis_idxs = [1, 2]
+
+    # WHAM uses a Y-down coordinate system, while the EMDB dataset uses a Y-up one.
+    yup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float().to(cfg.DEVICE)
+
+    # NumPy -> torch tensor helper
+    tt = lambda x: torch.from_numpy(x).float().to(cfg.DEVICE)
+    accumulator = defaultdict(list)
+    bar = Bar('Inference', fill='#', max=len(eval_loader))
+    with torch.no_grad():
+        for i in range(len(eval_loader)):
+            # Original batch
+            batch = eval_loader.dataset.load_data(i, False)
+            x, inits, features, kwargs, gt = prepare_batch(batch, cfg.DEVICE, cfg.TRAIN.STAGE == 'stage2')
+
+            # Align the initial root orientation with the groundtruth first frame
+            cam2yup = batch['R'][0][:1].to(cfg.DEVICE)
+            cam2ydown = cam2yup @ yup2ydown
+            cam2root = transforms.rotation_6d_to_matrix(inits[1][:, 0, 0])
+            ydown2root = cam2ydown.mT @ cam2root
+            ydown2root = transforms.matrix_to_rotation_6d(ydown2root)
+            kwargs['init_root'][:, 0] = ydown2root
+
+            if cfg.FLIP_EVAL:
+                flipped_batch = eval_loader.dataset.load_data(i, True)
+                f_x, f_inits, f_features, f_kwargs, _ = prepare_batch(flipped_batch, cfg.DEVICE, cfg.TRAIN.STAGE == 'stage2')
+
+                # Forward pass with flipped input
+                flipped_pred = network(f_x, f_inits, f_features, **f_kwargs)
+
+            # Forward pass with normal input
+            pred = network(x, inits, features, **kwargs)
+
+            if cfg.FLIP_EVAL:
+                # Merge the two predictions
+                flipped_pose, flipped_shape = flipped_pred['pose'].squeeze(0), flipped_pred['betas'].squeeze(0)
+                pose, shape = pred['pose'].squeeze(0), pred['betas'].squeeze(0)
+                flipped_pose, pose = flipped_pose.reshape(-1, 24, 6), pose.reshape(-1, 24, 6)
+                avg_pose, avg_shape = avg_preds(pose, shape, flipped_pose, flipped_shape)
+                avg_pose = avg_pose.reshape(-1, 144)
+                avg_contact = (flipped_pred['contact'][..., [2, 3, 0, 1]] + pred['contact']) / 2
+
+                # Refine trajectory with the merged prediction
+                network.pred_pose = avg_pose.view_as(network.pred_pose)
+                network.pred_shape = avg_shape.view_as(network.pred_shape)
+                network.pred_contact = avg_contact.view_as(network.pred_contact)
+                output = network.forward_smpl(**kwargs)
+                pred = network.refine_trajectory(output, return_y_up=True, **kwargs)
+
+            # <======= Prepare groundtruth data
+            subj, seq = batch['vid'][:2], batch['vid'][3:]
+            annot_pth = glob(osp.join(_C.PATHS.EMDB_PTH, subj, seq, '*_data.pkl'))[0]
+            annot = pickle.load(open(annot_pth, 'rb'))
+
+            masks = annot['good_frames_mask']
+            gender = annot['gender']
+            poses_body = annot["smpl"]["poses_body"]
+            poses_root = annot["smpl"]["poses_root"]
+            betas = np.repeat(annot["smpl"]["betas"].reshape((1, -1)), repeats=annot["n_frames"], axis=0)
+            trans = annot["smpl"]["trans"]
+            extrinsics = annot["camera"]["extrinsics"]
+
+            # Map the root orientation to the camera coordinate frame
+            poses_root_cam = transforms.matrix_to_axis_angle(tt(extrinsics[:, :3, :3]) @ transforms.axis_angle_to_matrix(tt(poses_root)))
+
+            # Groundtruth global motion
+            target_glob = smpl[gender](body_pose=tt(poses_body), global_orient=tt(poses_root), betas=tt(betas), transl=tt(trans))
+            target_j3d_glob = target_glob.joints[:, :24][masks]
+
+            # Groundtruth local motion
+            target_cam = smpl[gender](body_pose=tt(poses_body), global_orient=poses_root_cam, betas=tt(betas))
+            target_verts_cam = target_cam.vertices[masks]
+            target_j3d_cam = target_cam.joints[:, :24][masks]
+            # =======>
+
+            # Convert WHAM global orientation back to the Y-up coordinate frame
+            poses_root = pred['poses_root_world'].squeeze(0)
+            pred_trans = pred['trans_world'].squeeze(0)
+            poses_root = yup2ydown.mT @ poses_root
+            pred_trans = (yup2ydown.mT @ pred_trans.unsqueeze(-1)).squeeze(-1)
+
+            # <======= Build predicted motion
+            # Predicted global motion
+            pred_glob = smpl['neutral'](body_pose=pred['poses_body'], global_orient=poses_root.unsqueeze(1), betas=pred['betas'].squeeze(0), transl=pred_trans, pose2rot=False)
+            pred_j3d_glob = pred_glob.joints[:, :24]
+
+            # Predicted local motion
+            pred_cam = smpl['neutral'](body_pose=pred['poses_body'], global_orient=pred['poses_root_cam'], betas=pred['betas'].squeeze(0), pose2rot=False)
+            pred_verts_cam = pred_cam.vertices
+            pred_j3d_cam = pred_cam.joints[:, :24]
+            # =======>
+
+            # <======= Evaluation on the local motion
+            pred_j3d_cam, target_j3d_cam, pred_verts_cam, target_verts_cam = batch_align_by_pelvis(
+                [pred_j3d_cam, target_j3d_cam, pred_verts_cam, target_verts_cam], pelvis_idxs
+            )
+            S1_hat = batch_compute_similarity_transform_torch(pred_j3d_cam, target_j3d_cam)
+            pa_mpjpe = torch.sqrt(((S1_hat - target_j3d_cam) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * m2mm
+            mpjpe = torch.sqrt(((pred_j3d_cam - target_j3d_cam) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * m2mm
+            pve = torch.sqrt(((pred_verts_cam - target_verts_cam) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * m2mm
+            accel = compute_error_accel(joints_pred=pred_j3d_cam.cpu(), joints_gt=target_j3d_cam.cpu())[1:-1]
+            accel = accel * (30 ** 2)  # convert from per frame^2 to per s^2 (30 fps)
+
+            summary_string = f'{batch["vid"]} | PA-MPJPE: {pa_mpjpe.mean():.1f} MPJPE: {mpjpe.mean():.1f} PVE: {pve.mean():.1f}'
+            bar.suffix = summary_string
+            bar.next()
+            # =======>
+
+            # <======= Evaluation on the global motion
+            chunk_length = 100
+            w_mpjpe, wa_mpjpe = [], []
+            for start in range(0, masks.sum(), chunk_length):
+                end = min(masks.sum(), start + chunk_length)
+
+                target_j3d = target_j3d_glob[start:end].clone().cpu()
+                pred_j3d = pred_j3d_glob[start:end].clone().cpu()
+
+                w_j3d = first_align_joints(target_j3d, pred_j3d)
+                wa_j3d = global_align_joints(target_j3d, pred_j3d)
+
+                w_jpe = compute_jpe(target_j3d, w_j3d)
+                wa_jpe = compute_jpe(target_j3d, wa_j3d)
+                w_mpjpe.append(w_jpe)
+                wa_mpjpe.append(wa_jpe)
+
+            w_mpjpe = np.concatenate(w_mpjpe) * m2mm
+            wa_mpjpe = np.concatenate(wa_mpjpe) * m2mm
+
+            # Additional metrics
+            rte = compute_rte(torch.from_numpy(trans[masks]), pred_trans.cpu()) * 1e2
+            jitter = compute_jitter(pred_glob, fps=30)
+            foot_sliding = compute_foot_sliding(target_glob, pred_glob, masks) * m2mm
+            # =======>
+
+            # <======= Accumulate the results over entire sequences
+            accumulator['pa_mpjpe'].append(pa_mpjpe)
+            accumulator['mpjpe'].append(mpjpe)
+            accumulator['pve'].append(pve)
+            accumulator['accel'].append(accel)
+            accumulator['wa_mpjpe'].append(wa_mpjpe)
+            accumulator['w_mpjpe'].append(w_mpjpe)
+            accumulator['RTE'].append(rte)
+            accumulator['jitter'].append(jitter)
+            accumulator['FS'].append(foot_sliding)
+            # =======>
+
+    for k, v in accumulator.items():
+        accumulator[k] = np.concatenate(v).mean()
+
+    print('')
+    log_str = f'Evaluation on EMDB {args.eval_split}, '
+    log_str += ' '.join([f'{k.upper()}: {v:.4f},' for k, v in accumulator.items()])
+    logger.info(log_str)
+
+if __name__ == '__main__':
+    cfg, cfg_file, args = parse_args(test=True)
+    cfg = prepare_output_dir(cfg, cfg_file)
+
+    main(cfg, args)
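A note on the coordinate handling in this script: yup2ydown, built from axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])), is simply a half-turn about the x-axis. A standalone sanity check of that convention flip (illustrative sketch, independent of the repo's transforms module):

import torch

# A rotation by pi about the x-axis negates y and z: it maps a y-up
# world frame to the y-down convention used by WHAM.
yup2ydown = torch.tensor([[1.,  0.,  0.],
                          [0., -1.,  0.],
                          [0.,  0., -1.]])

up = torch.tensor([0., 1., 0.])        # "up" in EMDB's y-up world
print(yup2ydown @ up)                  # tensor([ 0., -1.,  0.])

# The script applies .mT (the transpose) to map predictions back to y-up;
# for this involutory rotation, yup2ydown.mT equals yup2ydown itself.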
lib/eval/evaluate_rich.py
ADDED
@@ -0,0 +1,156 @@
+import os
+import os.path as osp
+from collections import defaultdict
+from time import time
+
+import torch
+import joblib
+import numpy as np
+from loguru import logger
+from smplx import SMPL, SMPLX
+from progress.bar import Bar
+
+from configs import constants as _C
+from configs.config import parse_args
+from lib.data.dataloader import setup_eval_dataloader
+from lib.models import build_network, build_body_model
+from lib.eval.eval_utils import (
+    compute_error_accel,
+    batch_align_by_pelvis,
+    batch_compute_similarity_transform_torch,
+)
+from lib.utils import transforms
+from lib.utils.utils import prepare_output_dir
+from lib.utils.utils import prepare_batch
+from lib.utils.imutils import avg_preds
+
+m2mm = 1e3
+smplx2smpl = torch.from_numpy(joblib.load(_C.BMODEL.SMPLX2SMPL)['matrix']).unsqueeze(0).float().cuda()
+@torch.no_grad()
+def main(cfg, args):
+    torch.backends.cuda.matmul.allow_tf32 = False
+    torch.backends.cudnn.allow_tf32 = False
+
+    logger.info(f'GPU name -> {torch.cuda.get_device_name()}')
+    logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}')
+
+    # ========= Dataloaders ========= #
+    eval_loader = setup_eval_dataloader(cfg, 'rich', 'test', cfg.MODEL.BACKBONE)
+    logger.info('Dataset loaded')
+
+    # ========= Load WHAM ========= #
+    smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN
+    smpl = build_body_model(cfg.DEVICE, smpl_batch_size)
+    network = build_network(cfg, smpl)
+    network.eval()
+
+    # Build a neutral SMPL model for WHAM and gendered SMPLX models for the groundtruth data
+    smpl = SMPL(_C.BMODEL.FLDR, gender='neutral').to(cfg.DEVICE)
+
+    # Load the vertices -> joints regression matrix used for evaluation
+    J_regressor_eval = smpl.J_regressor.clone().unsqueeze(0)
+    pelvis_idxs = [1, 2]
+
+    accumulator = defaultdict(list)
+    bar = Bar('Inference', fill='#', max=len(eval_loader))
+    with torch.no_grad():
+        for i in range(len(eval_loader)):
+            time_dict = {}
+            _t = time()
+
+            # Original batch
+            batch = eval_loader.dataset.load_data(i, False)
+            x, inits, features, kwargs, gt = prepare_batch(batch, cfg.DEVICE, cfg.TRAIN.STAGE == 'stage2')
+
+            # <======= Inference
+            if cfg.FLIP_EVAL:
+                flipped_batch = eval_loader.dataset.load_data(i, True)
+                f_x, f_inits, f_features, f_kwargs, _ = prepare_batch(flipped_batch, cfg.DEVICE, cfg.TRAIN.STAGE == 'stage2')
+
+                # Forward pass with flipped input
+                flipped_pred = network(f_x, f_inits, f_features, **f_kwargs)
+                time_dict['inference_flipped'] = time() - _t; _t = time()
+
+            # Forward pass with normal input
+            pred = network(x, inits, features, **kwargs)
+            time_dict['inference'] = time() - _t; _t = time()
+
+            if cfg.FLIP_EVAL:
+                # Merge the two predictions
+                flipped_pose, flipped_shape = flipped_pred['pose'].squeeze(0), flipped_pred['betas'].squeeze(0)
+                pose, shape = pred['pose'].squeeze(0), pred['betas'].squeeze(0)
+                flipped_pose, pose = flipped_pose.reshape(-1, 24, 6), pose.reshape(-1, 24, 6)
+                avg_pose, avg_shape = avg_preds(pose, shape, flipped_pose, flipped_shape)
+                avg_pose = avg_pose.reshape(-1, 144)
+
+                # Refine trajectory with the merged prediction
+                network.pred_pose = avg_pose.view_as(network.pred_pose)
+                network.pred_shape = avg_shape.view_as(network.pred_shape)
+                pred = network.forward_smpl(**kwargs)
+                time_dict['averaging'] = time() - _t; _t = time()
+            # =======>
+
+            # <======= Build predicted SMPL
+            pred_output = smpl(body_pose=pred['poses_body'],
+                               global_orient=pred['poses_root_cam'],
+                               betas=pred['betas'].squeeze(0),
+                               pose2rot=False)
+            pred_verts = pred_output.vertices.cpu()
+            pred_j3d = torch.matmul(J_regressor_eval, pred_output.vertices).cpu()
+            time_dict['building prediction'] = time() - _t; _t = time()
+            # =======>
+
+            # <======= Build groundtruth SMPL (from SMPLX)
+            smplx = SMPLX(_C.BMODEL.FLDR.replace('smpl', 'smplx'),
+                          gender=batch['gender'],
+                          batch_size=len(pred_verts)
+                          ).to(cfg.DEVICE)
+            gt_pose = transforms.matrix_to_axis_angle(transforms.rotation_6d_to_matrix(gt['pose'][0]))
+            target_output = smplx(
+                body_pose=gt_pose[:, 1:-2].reshape(-1, 63),
+                global_orient=gt_pose[:, 0],
+                betas=gt['betas'][0])
+            target_verts = torch.matmul(smplx2smpl, target_output.vertices.cuda()).cpu()
+            target_j3d = torch.matmul(J_regressor_eval, target_verts.to(cfg.DEVICE)).cpu()
+            time_dict['building target'] = time() - _t; _t = time()
+            # =======>
+
+            # <======= Compute performance of the current sequence
+            pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
+                [pred_j3d, target_j3d, pred_verts, target_verts], pelvis_idxs
+            )
+            S1_hat = batch_compute_similarity_transform_torch(pred_j3d, target_j3d)
+            pa_mpjpe = torch.sqrt(((S1_hat - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm
+            mpjpe = torch.sqrt(((pred_j3d - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm
+            pve = torch.sqrt(((pred_verts - target_verts) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm
+            accel = compute_error_accel(joints_pred=pred_j3d, joints_gt=target_j3d)[1:-1]
+            accel = accel * (30 ** 2)  # convert from per frame^2 to per s^2 (30 fps)
+            time_dict['evaluating'] = time() - _t; _t = time()
+            # =======>
+
+            # summary_string = f'{batch["vid"]} | PA-MPJPE: {pa_mpjpe.mean():.1f} MPJPE: {mpjpe.mean():.1f} PVE: {pve.mean():.1f}'
+            summary_string = f'{batch["vid"]} | ' + ' '.join([f'{k}: {v:.1f} s' for k, v in time_dict.items()])
+            bar.suffix = summary_string
+            bar.next()
+
+            # <======= Accumulate the results over entire sequences
+            accumulator['pa_mpjpe'].append(pa_mpjpe)
+            accumulator['mpjpe'].append(mpjpe)
+            accumulator['pve'].append(pve)
+            accumulator['accel'].append(accel)
+            # =======>
+
+    for k, v in accumulator.items():
+        accumulator[k] = np.concatenate(v).mean()
+
+    print('')
+    log_str = 'Evaluation on RICH, '
+    log_str += ' '.join([f'{k.upper()}: {v:.4f},' for k, v in accumulator.items()])
+    logger.info(log_str)
+
+if __name__ == '__main__':
+    cfg, cfg_file, args = parse_args(test=True)
+    cfg = prepare_output_dir(cfg, cfg_file)
+
+    main(cfg, args)
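The SMPLX-to-SMPL conversion above is a single matrix multiply: each SMPL vertex is a fixed weighted combination of SMPL-X vertices, with the weights loaded from _C.BMODEL.SMPLX2SMPL. A toy sketch of that re-topologization pattern follows; the random weights below are stand-ins for the real precomputed matrix.

import torch

V_SMPL, V_SMPLX = 6890, 10475                  # vertex counts of the two meshes
M = torch.rand(V_SMPL, V_SMPLX)
M = M / M.sum(dim=1, keepdim=True)             # rows act as convex combinations

smplx_verts = torch.randn(4, V_SMPLX, 3)       # (T, V, 3) sequence of SMPL-X meshes
smpl_verts = torch.matmul(M, smplx_verts)      # broadcasts to (T, 6890, 3), as in the script
print(smpl_verts.shape)                        # torch.Size([4, 6890, 3])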
lib/models/__init__.py
ADDED
@@ -0,0 +1,40 @@
+import os, sys
+import yaml
+import torch
+from loguru import logger
+
+from configs import constants as _C
+from .smpl import SMPL
+
+
+def build_body_model(device, batch_size=1, gender='neutral', **kwargs):
+    # Silence the verbose output of the SMPL constructor
+    sys.stdout = open(os.devnull, 'w')
+    body_model = SMPL(
+        model_path=_C.BMODEL.FLDR,
+        gender=gender,
+        batch_size=batch_size,
+        create_transl=False).to(device)
+    sys.stdout = sys.__stdout__
+    return body_model
+
+
+def build_network(cfg, smpl):
+    from .wham import Network
+
+    with open(cfg.MODEL_CONFIG, 'r') as f:
+        model_config = yaml.safe_load(f)
+    model_config.update({'d_feat': _C.IMG_FEAT_DIM[cfg.MODEL.BACKBONE]})
+
+    network = Network(smpl, **model_config).to(cfg.DEVICE)
+
+    # Load checkpoint
+    if os.path.isfile(cfg.TRAIN.CHECKPOINT):
+        checkpoint = torch.load(cfg.TRAIN.CHECKPOINT)
+        ignore_keys = ['smpl.body_pose', 'smpl.betas', 'smpl.global_orient', 'smpl.J_regressor_extra', 'smpl.J_regressor_eval']
+        model_state_dict = {k: v for k, v in checkpoint['model'].items() if k not in ignore_keys}
+        network.load_state_dict(model_state_dict, strict=False)
+        logger.info(f"=> loaded checkpoint '{cfg.TRAIN.CHECKPOINT}'")
+    else:
+        logger.info(f"=> Warning! No checkpoint found at '{cfg.TRAIN.CHECKPOINT}'.")
+
+    return network
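build_network filters several smpl.* entries out of the checkpoint and loads with strict=False; those keys belong to the SMPL submodule and are rebuilt at construction time, so restoring them verbatim is unnecessary and can fail when their shapes depend on the configured batch size (that rationale is an assumption on our part). A self-contained illustration of the same filter-then-load pattern, using a toy module rather than WHAM's network:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.register_buffer('cache', torch.zeros(n))   # shape depends on config

ckpt = Toy(8).state_dict()          # checkpoint saved under one configuration
model = Toy(16)                     # restored under another

ignore_keys = ['cache']
state = {k: v for k, v in ckpt.items() if k not in ignore_keys}
model.load_state_dict(state, strict=False)   # weights load; 'cache' keeps its fresh value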
lib/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (1.72 kB)
lib/models/__pycache__/smpl.cpython-39.pyc
ADDED
Binary file (7 kB)
lib/models/__pycache__/wham.cpython-39.pyc
ADDED
Binary file (4.95 kB)
lib/models/layers/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .modules import MotionEncoder, MotionDecoder, TrajectoryDecoder, TrajectoryRefiner, Integrator
+from .utils import rollout_global_motion, compute_camera_pose, reset_root_velocity, compute_camera_motion
lib/models/layers/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (482 Bytes)