Techt3o committed on
Commit f561f8b · verified · 1 Parent(s): 57ae837

e01e49325338173592071b44501d91416d6a9072d1040c9d9f5aecf816533bec

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +4 -0
  2. examples/test16.mov +3 -0
  3. examples/test17.mov +3 -0
  4. examples/test18.mov +3 -0
  5. examples/test19.mov +3 -0
  6. fetch_demo_data.sh +50 -0
  7. lib/core/loss.py +438 -0
  8. lib/core/trainer.py +341 -0
  9. lib/data/__init__.py +0 -0
  10. lib/data/__pycache__/__init__.cpython-39.pyc +0 -0
  11. lib/data/__pycache__/_dataset.cpython-39.pyc +0 -0
  12. lib/data/_dataset.py +77 -0
  13. lib/data/dataloader.py +46 -0
  14. lib/data/datasets/__init__.py +3 -0
  15. lib/data/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  16. lib/data/datasets/__pycache__/amass.cpython-39.pyc +0 -0
  17. lib/data/datasets/__pycache__/bedlam.cpython-39.pyc +0 -0
  18. lib/data/datasets/__pycache__/dataset2d.cpython-39.pyc +0 -0
  19. lib/data/datasets/__pycache__/dataset3d.cpython-39.pyc +0 -0
  20. lib/data/datasets/__pycache__/dataset_custom.cpython-39.pyc +0 -0
  21. lib/data/datasets/__pycache__/dataset_eval.cpython-39.pyc +0 -0
  22. lib/data/datasets/__pycache__/mixed_dataset.cpython-39.pyc +0 -0
  23. lib/data/datasets/__pycache__/videos.cpython-39.pyc +0 -0
  24. lib/data/datasets/amass.py +173 -0
  25. lib/data/datasets/bedlam.py +165 -0
  26. lib/data/datasets/dataset2d.py +140 -0
  27. lib/data/datasets/dataset3d.py +172 -0
  28. lib/data/datasets/dataset_custom.py +115 -0
  29. lib/data/datasets/dataset_eval.py +113 -0
  30. lib/data/datasets/mixed_dataset.py +61 -0
  31. lib/data/datasets/videos.py +105 -0
  32. lib/data/utils/__pycache__/augmentor.cpython-39.pyc +0 -0
  33. lib/data/utils/__pycache__/normalizer.cpython-39.pyc +0 -0
  34. lib/data/utils/augmentor.py +292 -0
  35. lib/data/utils/normalizer.py +105 -0
  36. lib/data_utils/amass_utils.py +107 -0
  37. lib/data_utils/emdb_eval_utils.py +189 -0
  38. lib/data_utils/rich_eval_utils.py +69 -0
  39. lib/data_utils/threedpw_eval_utils.py +185 -0
  40. lib/data_utils/threedpw_train_utils.py +146 -0
  41. lib/eval/eval_utils.py +482 -0
  42. lib/eval/evaluate_3dpw.py +181 -0
  43. lib/eval/evaluate_emdb.py +228 -0
  44. lib/eval/evaluate_rich.py +156 -0
  45. lib/models/__init__.py +40 -0
  46. lib/models/__pycache__/__init__.cpython-39.pyc +0 -0
  47. lib/models/__pycache__/smpl.cpython-39.pyc +0 -0
  48. lib/models/__pycache__/wham.cpython-39.pyc +0 -0
  49. lib/models/layers/__init__.py +2 -0
  50. lib/models/layers/__pycache__/__init__.cpython-39.pyc +0 -0
.gitattributes CHANGED
@@ -37,3 +37,7 @@ examples/drone_video.mp4 filter=lfs diff=lfs merge=lfs -text
37
  examples/IMG_9730.mov filter=lfs diff=lfs merge=lfs -text
38
  examples/IMG_9731.mov filter=lfs diff=lfs merge=lfs -text
39
  examples/IMG_9732.mov filter=lfs diff=lfs merge=lfs -text
40
+ examples/test16.mov filter=lfs diff=lfs merge=lfs -text
41
+ examples/test17.mov filter=lfs diff=lfs merge=lfs -text
42
+ examples/test18.mov filter=lfs diff=lfs merge=lfs -text
43
+ examples/test19.mov filter=lfs diff=lfs merge=lfs -text
examples/test16.mov ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f068400bf962e732e5517af45397694f84fae0a6592085b9dd3781fdbacaa550
3
+ size 1567779
examples/test17.mov ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce06d8885332fd0b770273010dbd4da20a0867a386dc55925f85198651651253
3
+ size 2299497
examples/test18.mov ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66fc6eb20e1c8525070c8004bed621e0acc2712accace1dbf1eb72fced62bb14
3
+ size 2033756
examples/test19.mov ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:878219571dbf0e8ff56f4ba4bf325f90f46a730b57a35a2df91f4f509af616d8
3
+ size 1940593
fetch_demo_data.sh ADDED
@@ -0,0 +1,50 @@
1
+ #!/bin/bash
2
+ urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; }
3
+
4
+ # SMPL Neutral model
5
+ echo -e "\nYou need to register at https://smplify.is.tue.mpg.de"
6
+ read -p "Username (SMPLify):" username
7
+ read -p "Password (SMPLify):" password
8
+ username=$(urle $username)
9
+ password=$(urle $password)
10
+
11
+ mkdir -p dataset/body_models/smpl
12
+ wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplify&resume=1&sfile=mpips_smplify_public_v2.zip' -O './dataset/body_models/smplify.zip' --no-check-certificate --continue
13
+ unzip dataset/body_models/smplify.zip -d dataset/body_models/smplify
14
+ mv dataset/body_models/smplify/smplify_public/code/models/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl dataset/body_models/smpl/SMPL_NEUTRAL.pkl
15
+ rm -rf dataset/body_models/smplify
16
+ rm -rf dataset/body_models/smplify.zip
17
+
18
+ # SMPL Male and Female model
19
+ echo -e "\nYou need to register at https://smpl.is.tue.mpg.de"
20
+ read -p "Username (SMPL):" username
21
+ read -p "Password (SMPL):" password
22
+ username=$(urle $username)
23
+ password=$(urle $password)
24
+
25
+ wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smpl&sfile=SMPL_python_v.1.0.0.zip' -O './dataset/body_models/smpl.zip' --no-check-certificate --continue
26
+ unzip dataset/body_models/smpl.zip -d dataset/body_models/smpl
27
+ mv dataset/body_models/smpl/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl dataset/body_models/smpl/SMPL_FEMALE.pkl
28
+ mv dataset/body_models/smpl/smpl/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl dataset/body_models/smpl/SMPL_MALE.pkl
29
+ rm -rf dataset/body_models/smpl/smpl
30
+ rm -rf dataset/body_models/smpl.zip
31
+
32
+ # Auxiliary SMPL-related data
33
+ wget "https://drive.google.com/uc?id=1pbmzRbWGgae6noDIyQOnohzaVnX_csUZ&export=download&confirm=t" -O 'dataset/body_models.tar.gz'
34
+ tar -xvf dataset/body_models.tar.gz -C dataset/
35
+ rm -rf dataset/body_models.tar.gz
36
+
37
+ # Checkpoints
38
+ mkdir checkpoints
39
+ gdown "https://drive.google.com/uc?id=1i7kt9RlCCCNEW2aYaDWVr-G778JkLNcB&export=download&confirm=t" -O 'checkpoints/wham_vit_w_3dpw.pth.tar'
40
+ gdown "https://drive.google.com/uc?id=19qkI-a6xuwob9_RFNSPWf1yWErwVVlks&export=download&confirm=t" -O 'checkpoints/wham_vit_bedlam_w_3dpw.pth.tar'
41
+ gdown "https://drive.google.com/uc?id=1J6l8teyZrL0zFzHhzkC7efRhU0ZJ5G9Y&export=download&confirm=t" -O 'checkpoints/hmr2a.ckpt'
42
+ gdown "https://drive.google.com/uc?id=1kXTV4EYb-BI3H7J-bkR3Bc4gT9zfnHGT&export=download&confirm=t" -O 'checkpoints/dpvo.pth'
43
+ gdown "https://drive.google.com/uc?id=1zJ0KP23tXD42D47cw1Gs7zE2BA_V_ERo&export=download&confirm=t" -O 'checkpoints/yolov8x.pt'
44
+ gdown "https://drive.google.com/uc?id=1xyF7F3I7lWtdq82xmEPVQ5zl4HaasBso&export=download&confirm=t" -O 'checkpoints/vitpose-h-multi-coco.pth'
45
+
46
+ # Demo videos
47
+ gdown "https://drive.google.com/uc?id=1KjfODCcOUm_xIMLLR54IcjJtf816Dkc7&export=download&confirm=t" -O 'examples.tar.gz'
48
+ tar -xvf examples.tar.gz
49
+ rm -rf examples.tar.gz
50
+
lib/core/loss.py ADDED
@@ -0,0 +1,438 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import functional as F
8
+
9
+ from configs import constants as _C
10
+ from lib.utils import transforms
11
+ from lib.utils.kp_utils import root_centering
12
+
13
+ class WHAMLoss(nn.Module):
14
+ def __init__(
15
+ self,
16
+ cfg=None,
17
+ device=None,
18
+ ):
19
+ super(WHAMLoss, self).__init__()
20
+
21
+ self.cfg = cfg
22
+ self.n_joints = _C.KEYPOINTS.NUM_JOINTS
23
+ self.criterion = nn.MSELoss()
24
+ self.criterion_noreduce = nn.MSELoss(reduction='none')
25
+
26
+ self.pose_loss_weight = cfg.LOSS.POSE_LOSS_WEIGHT
27
+ self.shape_loss_weight = cfg.LOSS.SHAPE_LOSS_WEIGHT
28
+ self.keypoint_2d_loss_weight = cfg.LOSS.JOINT2D_LOSS_WEIGHT
29
+ self.keypoint_3d_loss_weight = cfg.LOSS.JOINT3D_LOSS_WEIGHT
30
+ self.cascaded_loss_weight = cfg.LOSS.CASCADED_LOSS_WEIGHT
31
+ self.vertices_loss_weight = cfg.LOSS.VERTS3D_LOSS_WEIGHT
32
+ self.contact_loss_weight = cfg.LOSS.CONTACT_LOSS_WEIGHT
33
+ self.root_vel_loss_weight = cfg.LOSS.ROOT_VEL_LOSS_WEIGHT
34
+ self.root_pose_loss_weight = cfg.LOSS.ROOT_POSE_LOSS_WEIGHT
35
+ self.sliding_loss_weight = cfg.LOSS.SLIDING_LOSS_WEIGHT
36
+ self.camera_loss_weight = cfg.LOSS.CAMERA_LOSS_WEIGHT
37
+ self.loss_weight = cfg.LOSS.LOSS_WEIGHT
38
+
39
+ kp_weights = [
40
+ 0.5, 0.5, 0.5, 0.5, 0.5, # Face
41
+ 1.5, 1.5, 4, 4, 4, 4, # Arms
42
+ 1.5, 1.5, 4, 4, 4, 4, # Legs
43
+ 4, 4, 1.5, 1.5, 4, 4, # Legs
44
+ 4, 4, 1.5, 1.5, 4, 4, # Arms
45
+ 0.5, 0.5 # Head
46
+ ]
47
+
48
+ theta_weights = [
49
+ 0.1, 1.0, 1.0, 1.0, 1.0, # pelvis, lhip, rhip, spine1, lknee
50
+ 1.0, 1.0, 1.0, 1.0, 1.0, # rknee, spine2, lankle, rankle, spine3
51
+ 0.1, 0.1, # Foot
52
+ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, # neck, lcollar, rcollar, head, lshoulder, rshoulder
53
+ 1.0, 1.0, 1.0, 1.0, # lelbow, relbow, lwrist, rwrist
54
+ 0.1, 0.1, # Hand
55
+ ]
56
+ self.theta_weights = torch.tensor([[theta_weights]]).float().to(device)
57
+ self.theta_weights /= self.theta_weights.mean()
58
+ self.kp_weights = torch.tensor([kp_weights]).float().to(device)
59
+
60
+ self.epoch = -1
61
+ self.step()
62
+
63
+ def step(self):
64
+ self.epoch += 1
65
+ self.skip_camera_loss = self.epoch < self.cfg.LOSS.CAMERA_LOSS_SKIP_EPOCH
66
+
67
+ def forward(self, pred, gt):
68
+
69
+ loss = 0.0
70
+ b, f = gt['kp3d'].shape[:2]
71
+
72
+ # <======= Predictions and Groundtruths
73
+ pred_betas = pred['betas']
74
+ pred_pose = pred['pose'].reshape(b, f, -1, 6)
75
+ pred_kp3d_nn = pred['kp3d_nn']
76
+ pred_kp3d_smpl = root_centering(pred['kp3d'].reshape(b, f, -1, 3))
77
+ pred_full_kp2d = pred['full_kp2d']
78
+ pred_weak_kp2d = pred['weak_kp2d']
79
+ pred_contact = pred['contact']
80
+ pred_vel_root = pred['vel_root']
81
+ pred_pose_root = pred['poses_root_r6d'][:, 1:]
82
+ pred_vel_root_ref = pred['vel_root_refined']
83
+ pred_pose_root_ref = pred['poses_root_r6d_refined'][:, 1:]
84
+ pred_cam_r = transforms.matrix_to_rotation_6d(pred['R'])
85
+
86
+ gt_betas = gt['betas']
87
+ gt_pose = gt['pose']
88
+ gt_kp3d = root_centering(gt['kp3d'])
89
+ gt_full_kp2d = gt['full_kp2d']
90
+ gt_weak_kp2d = gt['weak_kp2d']
91
+ gt_contact = gt['contact']
92
+ gt_vel_root = gt['vel_root']
93
+ gt_pose_root = gt['pose_root'][:, 1:]
94
+ gt_cam_angvel = gt['cam_angvel']
95
+ gt_cam_r = transforms.matrix_to_rotation_6d(gt['R'][:, 1:])
96
+ bbox = gt['bbox']
97
+ # =======>
98
+
99
+ loss_keypoints_full = full_projected_keypoint_loss(
100
+ pred_full_kp2d,
101
+ gt_full_kp2d,
102
+ bbox,
103
+ self.kp_weights,
104
+ criterion=self.criterion_noreduce,
105
+ )
106
+
107
+ loss_keypoints_weak = weak_projected_keypoint_loss(
108
+ pred_weak_kp2d,
109
+ gt_weak_kp2d,
110
+ self.kp_weights,
111
+ criterion=self.criterion_noreduce
112
+ )
113
+
114
+ # Compute 3D keypoint loss
115
+ loss_keypoints_3d_nn = keypoint_3d_loss(
116
+ pred_kp3d_nn,
117
+ gt_kp3d[:, :, :self.n_joints],
118
+ self.kp_weights[:, :self.n_joints],
119
+ criterion=self.criterion_noreduce,
120
+ )
121
+
122
+ loss_keypoints_3d_smpl = keypoint_3d_loss(
123
+ pred_kp3d_smpl,
124
+ gt_kp3d,
125
+ self.kp_weights,
126
+ criterion=self.criterion_noreduce,
127
+ )
128
+
129
+ loss_cascaded = keypoint_3d_loss(
130
+ pred_kp3d_nn,
131
+ torch.cat((pred_kp3d_smpl[:, :, :self.n_joints], gt_kp3d[:, :, :self.n_joints, -1:]), dim=-1),
132
+ self.kp_weights[:, :self.n_joints] * 0.5,
133
+ criterion=self.criterion_noreduce,
134
+ )
135
+
136
+ loss_vertices = vertices_loss(
137
+ pred['verts_cam'],
138
+ gt['verts'],
139
+ gt['has_verts'],
140
+ criterion=self.criterion_noreduce,
141
+ )
142
+
143
+ # Compute loss on SMPL parameters
144
+ smpl_mask = gt['has_smpl']
145
+ loss_regr_pose, loss_regr_betas = smpl_losses(
146
+ pred_pose,
147
+ pred_betas,
148
+ gt_pose,
149
+ gt_betas,
150
+ self.theta_weights,
151
+ smpl_mask,
152
+ criterion=self.criterion_noreduce
153
+ )
154
+
155
+ # Compute loss on foot contact
156
+ loss_contact = contact_loss(
157
+ pred_contact,
158
+ gt_contact,
159
+ self.criterion_noreduce
160
+ )
161
+
162
+ # Compute loss on root velocity and angular velocity
163
+ loss_vel_root, loss_pose_root = root_loss(
164
+ pred_vel_root,
165
+ pred_pose_root,
166
+ gt_vel_root,
167
+ gt_pose_root,
168
+ gt_contact,
169
+ self.criterion_noreduce
170
+ )
171
+
172
+ # Root loss after trajectory refinement
173
+ loss_vel_root_ref, loss_pose_root_ref = root_loss(
174
+ pred_vel_root_ref,
175
+ pred_pose_root_ref,
176
+ gt_vel_root,
177
+ gt_pose_root,
178
+ gt_contact,
179
+ self.criterion_noreduce
180
+ )
181
+
182
+ # Camera prediction loss
183
+ loss_camera = camera_loss(
184
+ pred_cam_r,
185
+ gt_cam_r,
186
+ gt_cam_angvel[:, 1:],
187
+ gt['has_traj'],
188
+ self.criterion_noreduce,
189
+ self.skip_camera_loss
190
+ )
191
+
192
+ # Foot sliding loss
193
+ loss_sliding = sliding_loss(
194
+ pred['feet'],
195
+ gt_contact,
196
+ )
197
+
198
+ # Foot sliding loss
199
+ loss_sliding_ref = sliding_loss(
200
+ pred['feet_refined'],
201
+ gt_contact,
202
+ )
203
+
204
+ loss_keypoints = loss_keypoints_full + loss_keypoints_weak
205
+ loss_keypoints *= self.keypoint_2d_loss_weight
206
+ loss_keypoints_3d_smpl *= self.keypoint_3d_loss_weight
207
+ loss_keypoints_3d_nn *= self.keypoint_3d_loss_weight
208
+ loss_cascaded *= self.cascaded_loss_weight
209
+ loss_vertices *= self.vertices_loss_weight
210
+ loss_contact *= self.contact_loss_weight
211
+ loss_root = loss_vel_root * self.root_vel_loss_weight + loss_pose_root * self.root_pose_loss_weight
212
+ loss_root_ref = loss_vel_root_ref * self.root_vel_loss_weight + loss_pose_root_ref * self.root_pose_loss_weight
213
+
214
+ loss_regr_pose *= self.pose_loss_weight
215
+ loss_regr_betas *= self.shape_loss_weight
216
+
217
+ loss_sliding *= self.sliding_loss_weight
218
+ loss_camera *= self.camera_loss_weight
219
+ loss_sliding_ref *= self.sliding_loss_weight
220
+
221
+ loss_dict = {
222
+ 'pose': loss_regr_pose * self.loss_weight,
223
+ 'betas': loss_regr_betas * self.loss_weight,
224
+ '2d': loss_keypoints * self.loss_weight,
225
+ '3d': loss_keypoints_3d_smpl * self.loss_weight,
226
+ '3d_nn': loss_keypoints_3d_nn * self.loss_weight,
227
+ 'casc': loss_cascaded * self.loss_weight,
228
+ 'v3d': loss_vertices * self.loss_weight,
229
+ 'contact': loss_contact * self.loss_weight,
230
+ 'root': loss_root * self.loss_weight,
231
+ 'root_ref': loss_root_ref * self.loss_weight,
232
+ 'sliding': loss_sliding * self.loss_weight,
233
+ 'camera': loss_camera * self.loss_weight,
234
+ 'sliding_ref': loss_sliding_ref * self.loss_weight,
235
+ }
236
+
237
+ loss = sum(loss for loss in loss_dict.values())
238
+
239
+ return loss, loss_dict
240
+
241
+
242
+ def root_loss(
243
+ pred_vel_root,
244
+ pred_pose_root,
245
+ gt_vel_root,
246
+ gt_pose_root,
247
+ stationary,
248
+ criterion
249
+ ):
250
+
251
+ mask_r = (gt_pose_root != 0.0).all(dim=-1).all(dim=-1)
252
+ mask_v = (gt_vel_root != 0.0).all(dim=-1).all(dim=-1)
253
+ mask_s = (stationary != -1).any(dim=1).any(dim=1)
254
+ mask_v = mask_v * mask_s
255
+
256
+ if mask_r.any():
257
+ loss_r = criterion(pred_pose_root, gt_pose_root)[mask_r].mean()
258
+ else:
259
+ loss_r = torch.FloatTensor(1).fill_(0.).to(gt_pose_root.device)[0]
260
+
261
+ if mask_v.any():
262
+ loss_v = 0
263
+ T = gt_vel_root.shape[0]
264
+ ws_list = [1, 3, 9, 27]
265
+ for ws in ws_list:
266
+ tmp_v = 0
267
+ for m in range(T//ws):
268
+ cumulative_v = torch.sum(pred_vel_root[:, m:(m+1)*ws] - gt_vel_root[:, m:(m+1)*ws], dim=1)
269
+ tmp_v += torch.norm(cumulative_v, dim=-1)
270
+ loss_v += tmp_v
271
+ loss_v = loss_v[mask_v].mean()
272
+ else:
273
+ loss_v = torch.FloatTensor(1).fill_(0.).to(gt_vel_root.device)[0]
274
+
275
+ return loss_v, loss_r
276
+
277
+
278
+ def contact_loss(
279
+ pred_stationary,
280
+ gt_stationary,
281
+ criterion,
282
+ ):
283
+
284
+ mask = gt_stationary != -1
285
+ if mask.any():
286
+ loss = criterion(pred_stationary, gt_stationary)[mask].mean()
287
+ else:
288
+ loss = torch.FloatTensor(1).fill_(0.).to(gt_stationary.device)[0]
289
+ return loss
290
+
291
+
292
+
293
+ def full_projected_keypoint_loss(
294
+ pred_keypoints_2d,
295
+ gt_keypoints_2d,
296
+ bbox,
297
+ weight,
298
+ criterion,
299
+ ):
300
+
301
+ scale = bbox[..., 2:] * 200.
302
+ conf = gt_keypoints_2d[..., -1]
303
+
304
+ if (conf > 0).any():
305
+ loss = torch.mean(
306
+ weight * (conf * torch.norm(pred_keypoints_2d - gt_keypoints_2d[..., :2], dim=-1)
307
+ ) / scale, dim=1).mean() * conf.mean()
308
+ else:
309
+ loss = torch.FloatTensor(1).fill_(0.).to(gt_keypoints_2d.device)[0]
310
+ return loss
311
+
312
+
313
+ def weak_projected_keypoint_loss(
314
+ pred_keypoints_2d,
315
+ gt_keypoints_2d,
316
+ weight,
317
+ criterion,
318
+ ):
319
+
320
+ conf = gt_keypoints_2d[..., -1]
321
+ if (conf > 0).any():
322
+ loss = torch.mean(
323
+ weight * (conf * torch.norm(pred_keypoints_2d - gt_keypoints_2d[..., :2], dim=-1)
324
+ ), dim=1).mean() * conf.mean() * 5
325
+ else:
326
+ loss = torch.FloatTensor(1).fill_(0.).to(gt_keypoints_2d.device)[0]
327
+ return loss
328
+
329
+
330
+ def keypoint_3d_loss(
331
+ pred_keypoints_3d,
332
+ gt_keypoints_3d,
333
+ weight,
334
+ criterion,
335
+ ):
336
+
337
+ conf = gt_keypoints_3d[..., -1]
338
+ if (conf > 0).any():
339
+ if weight.shape[-2] > 17:
340
+ pred_keypoints_3d[..., -14:] = pred_keypoints_3d[..., -14:] - pred_keypoints_3d[..., -14:].mean(dim=-2, keepdims=True)
341
+ gt_keypoints_3d[..., -14:] = gt_keypoints_3d[..., -14:] - gt_keypoints_3d[..., -14:].mean(dim=-2, keepdims=True)
342
+
343
+ loss = torch.mean(
344
+ weight * (conf * torch.norm(pred_keypoints_3d - gt_keypoints_3d[..., :3], dim=-1)
345
+ ), dim=1).mean() * conf.mean()
346
+ else:
347
+ loss = torch.FloatTensor(1).fill_(0.).to(gt_keypoints_3d.device)[0]
348
+ return loss
349
+
350
+
351
+ def vertices_loss(
352
+ pred_verts,
353
+ gt_verts,
354
+ mask,
355
+ criterion,
356
+ ):
357
+
358
+ if mask.sum() > 0:
359
+ # Align
360
+ pred_verts = pred_verts.view_as(gt_verts)
361
+ pred_verts = pred_verts - pred_verts.mean(-2, True)
362
+ gt_verts = gt_verts - gt_verts.mean(-2, True)
363
+
364
+ # loss = criterion(pred_verts, gt_verts).mean() * mask.float().mean()
365
+ # loss = torch.mean(
366
+ # (torch.norm(pred_verts - gt_verts, dim=-1)[mask]
367
+ # ), dim=1).mean() * mask.float().mean()
368
+ loss = torch.mean(
369
+ (torch.norm(pred_verts - gt_verts, p=1, dim=-1)[mask]
370
+ ), dim=1).mean() * mask.float().mean()
371
+ else:
372
+ loss = torch.FloatTensor(1).fill_(0.).to(gt_verts.device)[0]
373
+ return loss
374
+
375
+
376
+ def smpl_losses(
377
+ pred_pose,
378
+ pred_betas,
379
+ gt_pose,
380
+ gt_betas,
381
+ weight,
382
+ mask,
383
+ criterion,
384
+ ):
385
+
386
+ if mask.any().item():
387
+ loss_regr_pose = torch.mean(
388
+ weight * torch.square(pred_pose - gt_pose)[mask].mean(-1)
389
+ ) * mask.float().mean()
390
+ loss_regr_betas = F.mse_loss(pred_betas, gt_betas, reduction='none')[mask].mean() * mask.float().mean()
391
+ else:
392
+ loss_regr_pose = torch.FloatTensor(1).fill_(0.).to(gt_pose.device)[0]
393
+ loss_regr_betas = torch.FloatTensor(1).fill_(0.).to(gt_pose.device)[0]
394
+
395
+ return loss_regr_pose, loss_regr_betas
396
+
397
+
398
+ def camera_loss(
399
+ pred_cam_r,
400
+ gt_cam_r,
401
+ cam_angvel,
402
+ mask,
403
+ criterion,
404
+ skip
405
+ ):
406
+ # mask = (gt_cam_r != 0.0).all(dim=-1).all(dim=-1)
407
+
408
+ if mask.any() and not skip:
409
+ # Camera pose loss in 6D representation
410
+ loss_r = criterion(pred_cam_r, gt_cam_r)[mask].mean()
411
+
412
+ # Reconstruct camera angular velocity and compute reconstruction loss
413
+ pred_R = transforms.rotation_6d_to_matrix(pred_cam_r)
414
+ cam_angvel_from_R = transforms.matrix_to_rotation_6d(pred_R[:, :-1] @ pred_R[:, 1:].transpose(-1, -2))
415
+ cam_angvel_from_R = (cam_angvel_from_R - torch.tensor([[[1, 0, 0, 0, 1, 0]]]).to(cam_angvel)) * 30
416
+ loss_a = criterion(cam_angvel, cam_angvel_from_R)[mask].mean()
417
+
418
+ loss = loss_r + loss_a
419
+ else:
420
+ loss = torch.FloatTensor(1).fill_(0.).to(gt_cam_r.device)[0]
421
+
422
+ return loss
423
+
424
+
425
+ def sliding_loss(
426
+ foot_position,
427
+ contact_prob,
428
+ ):
429
+ """ Compute foot skate loss when foot is assumed to be on contact with ground
430
+
431
+ foot_position: 3D foot (heel and toe) position, torch.Tensor (B, F, 4, 3)
432
+ contact_prob: contact probability of foot (heel and toe), torch.Tensor (B, F, 4)
433
+ """
434
+
435
+ contact_mask = (contact_prob > 0.5).detach().float()
436
+ foot_velocity = foot_position[:, 1:] - foot_position[:, :-1]
437
+ loss = (torch.norm(foot_velocity, dim=-1) * contact_mask[:, 1:]).mean()
438
+ return loss
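A minimal usage sketch for the sliding_loss term above (illustrative only, not part of the commit; it assumes lib.core.loss is importable from the repository root and follows the tensor shapes given in the function's docstring):
import torch
from lib.core.loss import sliding_loss  # assumes the repository root is on PYTHONPATH
B, F = 2, 16
# Slowly drifting foot positions for heel/toe of both feet: (B, F, 4, 3)
foot_position = torch.cumsum(torch.randn(B, F, 4, 3) * 1e-3, dim=1)
# Per-frame contact probabilities: (B, F, 4)
contact_prob = torch.rand(B, F, 4)
# Penalizes frame-to-frame foot displacement only where contact probability > 0.5
loss = sliding_loss(foot_position, contact_prob)
print(float(loss))  # scalar foot-skating penalty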
lib/core/trainer.py ADDED
@@ -0,0 +1,341 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+ import time
18
+ import torch
19
+ import shutil
20
+ import logging
21
+ import numpy as np
22
+ import os.path as osp
23
+ from progress.bar import Bar
24
+
25
+ from configs import constants as _C
26
+ from lib.utils import transforms
27
+ from lib.utils.utils import AverageMeter, prepare_batch
28
+ from lib.eval.eval_utils import (
29
+ compute_accel,
30
+ compute_error_accel,
31
+ batch_align_by_pelvis,
32
+ batch_compute_similarity_transform_torch,
33
+ )
34
+ from lib.models import build_body_model
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+ class Trainer():
39
+ def __init__(self,
40
+ data_loaders,
41
+ network,
42
+ optimizer,
43
+ criterion=None,
44
+ train_stage='syn',
45
+ start_epoch=0,
46
+ checkpoint=None,
47
+ end_epoch=999,
48
+ lr_scheduler=None,
49
+ device=None,
50
+ writer=None,
51
+ debug=False,
52
+ resume=False,
53
+ logdir='output',
54
+ performance_type='min',
55
+ summary_iter=1,
56
+ ):
57
+
58
+ self.train_loader, self.valid_loader = data_loaders
59
+
60
+ # Model and optimizer
61
+ self.network = network
62
+ self.optimizer = optimizer
63
+
64
+ # Training parameters
65
+ self.train_stage = train_stage
66
+ self.start_epoch = start_epoch
67
+ self.end_epoch = end_epoch
68
+ self.criterion = criterion
69
+ self.lr_scheduler = lr_scheduler
70
+ self.device = device
71
+ self.writer = writer
72
+ self.debug = debug
73
+ self.resume = resume
74
+ self.logdir = logdir
75
+ self.summary_iter = summary_iter
76
+
77
+ self.performance_type = performance_type
78
+ self.train_global_step = 0
79
+ self.valid_global_step = 0
80
+ self.epoch = 0
81
+ self.best_performance = float('inf') if performance_type == 'min' else -float('inf')
82
+ self.summary_loss_keys = ['pose']
83
+
84
+ self.evaluation_accumulators = dict.fromkeys(
85
+ ['pred_j3d', 'target_j3d', 'pve'])# 'pred_verts', 'target_verts'])
86
+
87
+ self.J_regressor_eval = torch.from_numpy(
88
+ np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M)
89
+ )[_C.KEYPOINTS.H36M_TO_J14, :].unsqueeze(0).float().to(device)
90
+
91
+ if self.writer is None:
92
+ from torch.utils.tensorboard import SummaryWriter
93
+ self.writer = SummaryWriter(log_dir=self.logdir)
94
+
95
+ if self.device is None:
96
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
97
+
98
+ if checkpoint is not None:
99
+ self.load_pretrained(checkpoint)
100
+
101
+ def train(self, ):
102
+ # Single epoch training routine
103
+
104
+ losses = AverageMeter()
105
+ kp_2d_loss = AverageMeter()
106
+ kp_3d_loss = AverageMeter()
107
+
108
+ timer = {
109
+ 'data': 0,
110
+ 'forward': 0,
111
+ 'loss': 0,
112
+ 'backward': 0,
113
+ 'batch': 0,
114
+ }
115
+ self.network.train()
116
+ start = time.time()
117
+ summary_string = ''
118
+
119
+ bar = Bar(f'Epoch {self.epoch + 1}/{self.end_epoch}', fill='#', max=len(self.train_loader))
120
+ for i, batch in enumerate(self.train_loader):
121
+
122
+ # <======= Feedforward
123
+ x, inits, features, kwargs, gt = prepare_batch(batch, self.device, self.train_stage=='stage2')
124
+ timer['data'] = time.time() - start
125
+ start = time.time()
126
+ pred = self.network(x, inits, features, **kwargs)
127
+ timer['forward'] = time.time() - start
128
+ start = time.time()
129
+ # =======>
130
+
131
+ # <======= Backprop
132
+ loss, loss_dict = self.criterion(pred, gt)
133
+ timer['loss'] = time.time() - start
134
+ start = time.time()
135
+
136
+ # Clip gradients
137
+ self.optimizer.zero_grad()
138
+ loss.backward()
139
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
140
+ self.optimizer.step()
141
+ # =======>
142
+
143
+ # <======= Log training info
144
+ total_loss = loss
145
+ losses.update(total_loss.item(), x.size(0))
146
+ kp_2d_loss.update(loss_dict['2d'].item(), x.size(0))
147
+ kp_3d_loss.update(loss_dict['3d'].item(), x.size(0))
148
+
149
+ timer['backward'] = time.time() - start
150
+ timer['batch'] = timer['data'] + timer['forward'] + timer['loss'] + timer['backward']
151
+ start = time.time()
152
+
153
+ summary_string = f'({i + 1}/{len(self.train_loader)}) | Total: {bar.elapsed_td} ' \
154
+ f'| loss: {losses.avg:.2f} | 2d: {kp_2d_loss.avg:.2f} ' \
155
+ f'| 3d: {kp_3d_loss.avg:.2f} '
156
+
157
+ for k, v in loss_dict.items():
158
+ if k in self.summary_loss_keys:
159
+ summary_string += f' | {k}: {v:.2f}'
160
+ if (i + 1) % self.summary_iter == 0:
161
+ self.writer.add_scalar('train_loss/'+k, v, global_step=self.train_global_step)
162
+
163
+ if (i + 1) % self.summary_iter == 0:
164
+ self.writer.add_scalar('train_loss/loss', total_loss.item(), global_step=self.train_global_step)
165
+
166
+ self.train_global_step += 1
167
+ bar.suffix = summary_string
168
+ bar.next(1)
169
+
170
+ if torch.isnan(total_loss):
171
+ exit('Nan value in loss, exiting!...')
172
+ # =======>
173
+
174
+ logger.info(summary_string)
175
+ bar.finish()
176
+
177
+ def validate(self, ):
178
+ self.network.eval()
179
+
180
+ start = time.time()
181
+ summary_string = ''
182
+ bar = Bar('Validation', fill='#', max=len(self.valid_loader))
183
+
184
+ if self.evaluation_accumulators is not None:
185
+ for k,v in self.evaluation_accumulators.items():
186
+ self.evaluation_accumulators[k] = []
187
+
188
+ with torch.no_grad():
189
+ for i, batch in enumerate(self.valid_loader):
190
+ x, inits, features, kwargs, gt = prepare_batch(batch, self.device, self.train_stage=='stage2')
191
+
192
+ # <======= Feedforward
193
+ pred = self.network(x, inits, features, **kwargs)
194
+
195
+ # 3DPW dataset has groundtruth vertices
196
+ # NOTE: Following SPIN, we compute PVE against ground truth from Gendered SMPL mesh
197
+ smpl = build_body_model(self.device, batch_size=len(pred['verts_cam']), gender=batch['gender'][0])
198
+ gt_output = smpl.get_output(
199
+ body_pose=transforms.rotation_6d_to_matrix(gt['pose'][0, :, 1:]),
200
+ global_orient=transforms.rotation_6d_to_matrix(gt['pose'][0, :, :1]),
201
+ betas=gt['betas'][0],
202
+ pose2rot=False
203
+ )
204
+
205
+ pred_j3d = torch.matmul(self.J_regressor_eval, pred['verts_cam']).cpu()
206
+ target_j3d = torch.matmul(self.J_regressor_eval, gt_output.vertices).cpu()
207
+ pred_verts = pred['verts_cam'].cpu()
208
+ target_verts = gt_output.vertices.cpu()
209
+
210
+ pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
211
+ [pred_j3d, target_j3d, pred_verts, target_verts], [2, 3]
212
+ )
213
+
214
+ self.evaluation_accumulators['pred_j3d'].append(pred_j3d.numpy())
215
+ self.evaluation_accumulators['target_j3d'].append(target_j3d.numpy())
216
+ pve = np.sqrt(np.sum((target_verts.numpy() - pred_verts.numpy()) ** 2, axis=-1)).mean(-1) * 1e3
217
+ self.evaluation_accumulators['pve'].append(pve[:, None])
218
+ # =======>
219
+
220
+ batch_time = time.time() - start
221
+
222
+ summary_string = f'({i + 1}/{len(self.valid_loader)}) | batch: {batch_time * 10.0:.4}ms | ' \
223
+ f'Total: {bar.elapsed_td} | ETA: {bar.eta_td:}'
224
+
225
+ self.valid_global_step += 1
226
+ bar.suffix = summary_string
227
+ bar.next()
228
+
229
+ logger.info(summary_string)
230
+
231
+ bar.finish()
232
+
233
+ def evaluate(self, ):
234
+ for k, v in self.evaluation_accumulators.items():
235
+ self.evaluation_accumulators[k] = np.vstack(v)
236
+
237
+ pred_j3ds = self.evaluation_accumulators['pred_j3d']
238
+ target_j3ds = self.evaluation_accumulators['target_j3d']
239
+
240
+ pred_j3ds = torch.from_numpy(pred_j3ds).float()
241
+ target_j3ds = torch.from_numpy(target_j3ds).float()
242
+
243
+ print(f'Evaluating on {pred_j3ds.shape[0]} number of poses...')
244
+ errors = torch.sqrt(((pred_j3ds - target_j3ds) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
245
+ S1_hat = batch_compute_similarity_transform_torch(pred_j3ds, target_j3ds)
246
+ errors_pa = torch.sqrt(((S1_hat - target_j3ds) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
247
+
248
+ m2mm = 1000
249
+ accel = np.mean(compute_accel(pred_j3ds)) * m2mm
250
+ accel_err = np.mean(compute_error_accel(joints_pred=pred_j3ds, joints_gt=target_j3ds)) * m2mm
251
+ mpjpe = np.mean(errors) * m2mm
252
+ pa_mpjpe = np.mean(errors_pa) * m2mm
253
+
254
+ eval_dict = {
255
+ 'mpjpe': mpjpe,
256
+ 'pa-mpjpe': pa_mpjpe,
257
+ 'accel': accel,
258
+ 'accel_err': accel_err
259
+ }
260
+
261
+ if 'pred_verts' in self.evaluation_accumulators.keys():
262
+ eval_dict.update({'pve': self.evaluation_accumulators['pve'].mean()})
263
+
264
+ log_str = f'Epoch {self.epoch}, '
265
+ log_str += ' '.join([f'{k.upper()}: {v:.4f},'for k,v in eval_dict.items()])
266
+ logger.info(log_str)
267
+
268
+ for k,v in eval_dict.items():
269
+ self.writer.add_scalar(f'error/{k}', v, global_step=self.epoch)
270
+
271
+ # return (mpjpe + pa_mpjpe) / 2.
272
+ return pa_mpjpe
273
+
274
+ def save_model(self, performance, epoch):
275
+ save_dict = {
276
+ 'epoch': epoch,
277
+ 'model': self.network.state_dict(),
278
+ 'performance': performance,
279
+ 'optimizer': self.optimizer.state_dict(),
280
+ }
281
+
282
+ filename = osp.join(self.logdir, 'checkpoint.pth.tar')
283
+ torch.save(save_dict, filename)
284
+
285
+ if self.performance_type == 'min':
286
+ is_best = performance < self.best_performance
287
+ else:
288
+ is_best = performance > self.best_performance
289
+
290
+ if is_best:
291
+ logger.info('Best performance achieved, saving it!')
292
+ self.best_performance = performance
293
+ shutil.copyfile(filename, osp.join(self.logdir, 'model_best.pth.tar'))
294
+
295
+ with open(osp.join(self.logdir, 'best.txt'), 'w') as f:
296
+ f.write(str(float(performance)))
297
+
298
+ def fit(self):
299
+ for epoch in range(self.start_epoch, self.end_epoch):
300
+ self.epoch = epoch
301
+ self.train()
302
+ self.validate()
303
+ performance = self.evaluate()
304
+
305
+ self.criterion.step()
306
+ if self.lr_scheduler is not None:
307
+ self.lr_scheduler.step()
308
+
309
+ # log the learning rate
310
+ for param_group in self.optimizer.param_groups[:2]:
311
+ print(f'Learning rate {param_group["lr"]}')
312
+ self.writer.add_scalar('lr', param_group['lr'], global_step=self.epoch)
313
+
314
+ logger.info(f'Epoch {epoch+1} performance: {performance:.4f}')
315
+
316
+ self.save_model(performance, epoch)
317
+ self.train_loader.dataset.prepare_video_batch()
318
+
319
+ self.writer.close()
320
+
321
+ def load_pretrained(self, model_path):
322
+ if osp.isfile(model_path):
323
+ checkpoint = torch.load(model_path)
324
+
325
+ # network
326
+ ignore_keys = ['smpl.body_pose', 'smpl.betas', 'smpl.global_orient', 'smpl.J_regressor_extra', 'smpl.J_regressor_eval']
327
+ ignore_keys2 = [k for k in checkpoint['model'].keys() if 'integrator' in k]
328
+ ignore_keys.extend(ignore_keys2)
329
+ model_state_dict = {k: v for k, v in checkpoint['model'].items() if k not in ignore_keys}
330
+ model_state_dict = {k: v for k, v in model_state_dict.items() if k in self.network.state_dict().keys()}
331
+ self.network.load_state_dict(model_state_dict, strict=False)
332
+
333
+ if self.resume:
334
+ self.start_epoch = checkpoint['epoch']
335
+ self.best_performance = checkpoint['performance']
336
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
337
+
338
+ logger.info(f"=> loaded checkpoint '{model_path}' "
339
+ f"(epoch {self.start_epoch}, performance {self.best_performance})")
340
+ else:
341
+ logger.info(f"=> no checkpoint found at '{model_path}'")
lib/data/__init__.py ADDED
File without changes
lib/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (184 Bytes).
 
lib/data/__pycache__/_dataset.cpython-39.pyc ADDED
Binary file (3.18 kB).
 
lib/data/_dataset.py ADDED
@@ -0,0 +1,77 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import torch
6
+ import numpy as np
7
+ from skimage.util.shape import view_as_windows
8
+
9
+ from configs import constants as _C
10
+ from .utils.normalizer import Normalizer
11
+ from ..utils.imutils import transform
12
+
13
+ class BaseDataset(torch.utils.data.Dataset):
14
+ def __init__(self, cfg, training=True):
15
+ super(BaseDataset, self).__init__()
16
+ self.epoch = 0
17
+ self.training = training
18
+ self.n_joints = _C.KEYPOINTS.NUM_JOINTS
19
+ self.n_frames = cfg.DATASET.SEQLEN + 1
20
+ self.keypoints_normalizer = Normalizer(cfg)
21
+
22
+ def prepare_video_batch(self):
23
+ r = self.epoch % 4
24
+
25
+ self.video_indices = []
26
+ vid_name = self.labels['vid']
27
+ if isinstance(vid_name, torch.Tensor): vid_name = vid_name.numpy()
28
+ video_names_unique, group = np.unique(
29
+ vid_name, return_index=True)
30
+ perm = np.argsort(group)
31
+ group_perm = group[perm]
32
+ indices = np.split(
33
+ np.arange(0, self.labels['vid'].shape[0]), group_perm[1:]
34
+ )
35
+ for idx in range(len(video_names_unique)):
36
+ indexes = indices[idx]
37
+ if indexes.shape[0] < self.n_frames: continue
38
+ chunks = view_as_windows(
39
+ indexes, (self.n_frames), step=self.n_frames // 4
40
+ )
41
+ start_finish = chunks[r::4, (0, -1)].tolist()
42
+ self.video_indices += start_finish
43
+
44
+ self.epoch += 1
45
+
46
+ def __len__(self):
47
+ if self.training:
48
+ return len(self.video_indices)
49
+ else:
50
+ return len(self.labels['kp2d'])
51
+
52
+ def __getitem__(self, index):
53
+ return self.get_single_sequence(index)
54
+
55
+ def get_single_sequence(self, index):
56
+ raise NotImplementedError('get_single_sequence is not implemented')
57
+
58
+ def get_naive_intrinsics(self, res):
59
+ # Assume 45 degree FOV
60
+ img_w, img_h = res
61
+ self.focal_length = (img_w * img_w + img_h * img_h) ** 0.5
62
+ self.cam_intrinsics = torch.eye(3).repeat(1, 1, 1).float()
63
+ self.cam_intrinsics[:, 0, 0] = self.focal_length
64
+ self.cam_intrinsics[:, 1, 1] = self.focal_length
65
+ self.cam_intrinsics[:, 0, 2] = img_w/2.
66
+ self.cam_intrinsics[:, 1, 2] = img_h/2.
67
+
68
+ def j2d_processing(self, kp, bbox):
69
+ center = bbox[..., :2]
70
+ scale = bbox[..., -1:]
71
+ nparts = kp.shape[0]
72
+ for i in range(nparts):
73
+ kp[i, 0:2] = transform(kp[i, 0:2] + 1, center, scale,
74
+ [224, 224])
75
+ kp[:, :2] = 2. * kp[:, :2] / 224 - 1.
76
+ kp = kp.astype('float32')
77
+ return kp
lib/data/dataloader.py ADDED
@@ -0,0 +1,46 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import torch
6
+
7
+ from .datasets import EvalDataset, DataFactory
8
+ from ..utils.data_utils import make_collate_fn
9
+
10
+
11
+ def setup_eval_dataloader(cfg, data, split='test', backbone=None):
12
+ if backbone is None:
13
+ backbone = cfg.MODEL.BACKBONE
14
+
15
+ dataset = EvalDataset(cfg, data, split, backbone)
16
+ dloader = torch.utils.data.DataLoader(
17
+ dataset,
18
+ batch_size=1,
19
+ num_workers=0,
20
+ shuffle=False,
21
+ pin_memory=True,
22
+ collate_fn=make_collate_fn()
23
+ )
24
+ return dloader
25
+
26
+
27
+ def setup_train_dataloader(cfg, ):
28
+ n_workers = 0 if cfg.DEBUG else cfg.NUM_WORKERS
29
+
30
+ train_dataset = DataFactory(cfg, cfg.TRAIN.STAGE)
31
+ dloader = torch.utils.data.DataLoader(
32
+ train_dataset,
33
+ batch_size=cfg.TRAIN.BATCH_SIZE,
34
+ num_workers=n_workers,
35
+ shuffle=True,
36
+ pin_memory=True,
37
+ collate_fn=make_collate_fn()
38
+ )
39
+ return dloader
40
+
41
+
42
+ def setup_dloaders(cfg, dset='3dpw', split='val'):
43
+ test_dloader = setup_eval_dataloader(cfg, dset, split, cfg.MODEL.BACKBONE)
44
+ train_dloader = setup_train_dataloader(cfg)
45
+
46
+ return train_dloader, test_dloader
lib/data/datasets/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .dataset_eval import EvalDataset
2
+ from .dataset_custom import CustomDataset
3
+ from .mixed_dataset import DataFactory
lib/data/datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (342 Bytes).
 
lib/data/datasets/__pycache__/amass.cpython-39.pyc ADDED
Binary file (5.66 kB).
 
lib/data/datasets/__pycache__/bedlam.cpython-39.pyc ADDED
Binary file (5.39 kB).
 
lib/data/datasets/__pycache__/dataset2d.cpython-39.pyc ADDED
Binary file (4.14 kB).
 
lib/data/datasets/__pycache__/dataset3d.cpython-39.pyc ADDED
Binary file (5.04 kB).
 
lib/data/datasets/__pycache__/dataset_custom.cpython-39.pyc ADDED
Binary file (3.57 kB).
 
lib/data/datasets/__pycache__/dataset_eval.cpython-39.pyc ADDED
Binary file (3.93 kB).
 
lib/data/datasets/__pycache__/mixed_dataset.cpython-39.pyc ADDED
Binary file (2.99 kB).
 
lib/data/datasets/__pycache__/videos.cpython-39.pyc ADDED
Binary file (4.13 kB).
 
lib/data/datasets/amass.py ADDED
@@ -0,0 +1,173 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import torch
6
+ import joblib
7
+ from lib.utils import transforms
8
+
9
+ from configs import constants as _C
10
+
11
+ from ..utils.augmentor import *
12
+ from .._dataset import BaseDataset
13
+ from ...models import build_body_model
14
+ from ...utils import data_utils as d_utils
15
+ from ...utils.kp_utils import root_centering
16
+
17
+
18
+
19
+ def compute_contact_label(feet, thr=1e-2, alpha=5):
20
+ vel = torch.zeros_like(feet[..., 0])
21
+ label = torch.zeros_like(feet[..., 0])
22
+
23
+ vel[1:-1] = (feet[2:] - feet[:-2]).norm(dim=-1) / 2.0
24
+ vel[0] = vel[1].clone()
25
+ vel[-1] = vel[-2].clone()
26
+
27
+ label = 1 / (1 + torch.exp(alpha * (thr ** -1) * (vel - thr)))
28
+ return label
29
+
30
+
31
+ class AMASSDataset(BaseDataset):
32
+ def __init__(self, cfg):
33
+ label_pth = _C.PATHS.AMASS_LABEL
34
+ super(AMASSDataset, self).__init__(cfg, training=True)
35
+
36
+ self.supervise_pose = cfg.TRAIN.STAGE == 'stage1'
37
+ self.labels = joblib.load(label_pth)
38
+ self.SequenceAugmentor = SequenceAugmentor(cfg.DATASET.SEQLEN + 1)
39
+
40
+ # Load augmentators
41
+ self.VideoAugmentor = VideoAugmentor(cfg)
42
+ self.SMPLAugmentor = SMPLAugmentor(cfg)
43
+ self.d_img_feature = _C.IMG_FEAT_DIM[cfg.MODEL.BACKBONE]
44
+
45
+ self.n_frames = int(cfg.DATASET.SEQLEN * self.SequenceAugmentor.l_factor) + 1
46
+ self.smpl = build_body_model('cpu', self.n_frames)
47
+ self.prepare_video_batch()
48
+
49
+ # Naive assumption of image intrinsics
50
+ self.img_w, self.img_h = 1000, 1000
51
+ self.get_naive_intrinsics((self.img_w, self.img_h))
52
+
53
+ self.CameraAugmentor = CameraAugmentor(cfg.DATASET.SEQLEN + 1, self.img_w, self.img_h, self.focal_length)
54
+
55
+
56
+ @property
57
+ def __name__(self, ):
58
+ return 'AMASS'
59
+
60
+ def get_input(self, target):
61
+ gt_kp3d = target['kp3d']
62
+ inpt_kp3d = self.VideoAugmentor(gt_kp3d[:, :self.n_joints, :-1].clone())
63
+ kp2d = perspective_projection(inpt_kp3d, self.cam_intrinsics)
64
+ mask = self.VideoAugmentor.get_mask()
65
+ kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224)
66
+
67
+ target['bbox'] = bbox[1:]
68
+ target['kp2d'] = kp2d
69
+ target['mask'] = mask[1:]
70
+ target['features'] = torch.zeros((self.SMPLAugmentor.n_frames, self.d_img_feature)).float()
71
+ return target
72
+
73
+ def get_groundtruth(self, target):
74
+ # GT 1. Joints
75
+ gt_kp3d = target['kp3d']
76
+ gt_kp2d = perspective_projection(gt_kp3d, self.cam_intrinsics)
77
+ target['kp3d'] = torch.cat((gt_kp3d, torch.ones_like(gt_kp3d[..., :1]) * float(self.supervise_pose)), dim=-1)
78
+ target['full_kp2d'] = torch.cat((gt_kp2d, torch.ones_like(gt_kp2d[..., :1]) * float(self.supervise_pose)), dim=-1)[1:]
79
+ target['weak_kp2d'] = torch.zeros_like(target['full_kp2d'])
80
+ target['init_kp3d'] = root_centering(gt_kp3d[:1, :self.n_joints].clone()).reshape(1, -1)
81
+ target['verts'] = torch.zeros((self.SMPLAugmentor.n_frames, 6890, 3)).float()
82
+
83
+ # GT 2. Root pose
84
+ vel_world = (target['transl'][1:] - target['transl'][:-1])
85
+ pose_root = target['pose_root'].clone()
86
+ vel_root = (pose_root[:-1].transpose(-1, -2) @ vel_world.unsqueeze(-1)).squeeze(-1)
87
+ target['vel_root'] = vel_root.clone()
88
+ target['pose_root'] = transforms.matrix_to_rotation_6d(pose_root)
89
+ target['init_root'] = target['pose_root'][:1].clone()
90
+
91
+ # GT 3. Foot contact
92
+ contact = compute_contact_label(target['feet'])
93
+ if 'tread' in target['vid']:
94
+ target['contact'] = torch.ones_like(contact) * (-1)
95
+ else:
96
+ target['contact'] = contact
97
+
98
+ return target
99
+
100
+ def forward_smpl(self, target):
101
+ output = self.smpl.get_output(
102
+ body_pose=torch.cat((target['init_pose'][:, 1:], target['pose'][1:, 1:])),
103
+ global_orient=torch.cat((target['init_pose'][:, :1], target['pose'][1:, :1])),
104
+ betas=target['betas'],
105
+ pose2rot=False)
106
+
107
+ target['transl'] = target['transl'] - output.offset
108
+ target['transl'] = target['transl'] - target['transl'][0]
109
+ target['kp3d'] = output.joints
110
+ target['feet'] = output.feet[1:] + target['transl'][1:].unsqueeze(-2)
111
+
112
+ return target
113
+
114
+ def augment_data(self, target):
115
+ # Augmentation 1. SMPL params augmentation
116
+ target = self.SMPLAugmentor(target)
117
+
118
+ # Augmentation 2. Sequence speed augmentation
119
+ target = self.SequenceAugmentor(target)
120
+
121
+ # Get world-coordinate SMPL
122
+ target = self.forward_smpl(target)
123
+
124
+ # Augmentation 3. Virtual camera generation
125
+ target = self.CameraAugmentor(target)
126
+
127
+ return target
128
+
129
+ def load_amass(self, index, target):
130
+ start_index, end_index = self.video_indices[index]
131
+
132
+ # Load AMASS labels
133
+ pose = torch.from_numpy(self.labels['pose'][start_index:end_index+1].copy())
134
+ pose = transforms.axis_angle_to_matrix(pose.reshape(-1, 24, 3))
135
+ transl = torch.from_numpy(self.labels['transl'][start_index:end_index+1].copy())
136
+ betas = torch.from_numpy(self.labels['betas'][start_index:end_index+1].copy())
137
+
138
+ # Stack GT
139
+ target.update({'vid': self.labels['vid'][start_index],
140
+ 'pose': pose,
141
+ 'transl': transl,
142
+ 'betas': betas})
143
+
144
+ return target
145
+
146
+ def get_single_sequence(self, index):
147
+ target = {'res': torch.tensor([self.img_w, self.img_h]).float(),
148
+ 'cam_intrinsics': self.cam_intrinsics.clone(),
149
+ 'has_full_screen': torch.tensor(True),
150
+ 'has_smpl': torch.tensor(self.supervise_pose),
151
+ 'has_traj': torch.tensor(True),
152
+ 'has_verts': torch.tensor(False),}
153
+
154
+ target = self.load_amass(index, target)
155
+ target = self.augment_data(target)
156
+ target = self.get_groundtruth(target)
157
+ target = self.get_input(target)
158
+
159
+ target = d_utils.prepare_keypoints_data(target)
160
+ target = d_utils.prepare_smpl_data(target)
161
+
162
+ return target
163
+
164
+
165
+ def perspective_projection(points, cam_intrinsics, rotation=None, translation=None):
166
+ K = cam_intrinsics
167
+ if rotation is not None:
168
+ points = torch.matmul(rotation, points.transpose(1, 2)).transpose(1, 2)
169
+ if translation is not None:
170
+ points = points + translation.unsqueeze(1)
171
+ projected_points = points / points[:, :, -1].unsqueeze(-1)
172
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points.float())
173
+ return projected_points[:, :, :-1]
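A minimal sketch of the perspective_projection helper above, using the same naive intrinsics that BaseDataset.get_naive_intrinsics builds (illustrative only, not part of the commit; it assumes lib.data.datasets.amass is importable from the repository root):
import torch
from lib.data.datasets.amass import perspective_projection  # assumes the repository root is on PYTHONPATH
# Naive intrinsics: focal length = image diagonal, principal point at the image center
img_w, img_h = 1000, 1000
focal = (img_w ** 2 + img_h ** 2) ** 0.5
K = torch.eye(3).repeat(1, 1, 1).float()
K[:, 0, 0] = focal
K[:, 1, 1] = focal
K[:, 0, 2] = img_w / 2.
K[:, 1, 2] = img_h / 2.
# Random joints roughly 5 m in front of the camera: (B, N, 3)
points = torch.randn(1, 17, 3) * 0.5 + torch.tensor([0., 0., 5.])
kp2d = perspective_projection(points, K)
print(kp2d.shape)  # torch.Size([1, 17, 2]) pixel coordinates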
lib/data/datasets/bedlam.py ADDED
@@ -0,0 +1,165 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import torch
6
+ import joblib
7
+ from lib.utils import transforms
8
+
9
+ from configs import constants as _C
10
+
11
+ from .amass import compute_contact_label, perspective_projection
12
+ from ..utils.augmentor import *
13
+ from .._dataset import BaseDataset
14
+ from ...models import build_body_model
15
+ from ...utils import data_utils as d_utils
16
+ from ...utils.kp_utils import root_centering
17
+
18
+ class BEDLAMDataset(BaseDataset):
19
+ def __init__(self, cfg):
20
+ label_pth = _C.PATHS.BEDLAM_LABEL.replace('backbone', cfg.MODEL.BACKBONE)
21
+ super(BEDLAMDataset, self).__init__(cfg, training=True)
22
+
23
+ self.labels = joblib.load(label_pth)
24
+
25
+ self.VideoAugmentor = VideoAugmentor(cfg)
26
+ self.SMPLAugmentor = SMPLAugmentor(cfg, False)
27
+
28
+ self.smpl = build_body_model('cpu', self.n_frames)
29
+ self.prepare_video_batch()
30
+
31
+ @property
32
+ def __name__(self, ):
33
+ return 'BEDLAM'
34
+
35
+ def get_inputs(self, index, target, vis_thr=0.6):
36
+ start_index, end_index = self.video_indices[index]
37
+
38
+ bbox = self.labels['bbox'][start_index:end_index+1].clone()
39
+ bbox[:, 2] = bbox[:, 2] / 200
40
+
41
+ gt_kp3d = target['kp3d']
42
+ inpt_kp3d = self.VideoAugmentor(gt_kp3d[:, :self.n_joints, :-1].clone())
43
+ # kp2d = perspective_projection(inpt_kp3d, target['K'])
44
+ kp2d = perspective_projection(inpt_kp3d, self.cam_intrinsics)
45
+ mask = self.VideoAugmentor.get_mask()
46
+ # kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224, bbox)
47
+ kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224)
48
+
49
+ target['bbox'] = bbox[1:]
50
+ target['kp2d'] = kp2d
51
+ target['mask'] = mask[1:]
52
+
53
+ # Image features
54
+ target['features'] = self.labels['features'][start_index+1:end_index+1].clone()
55
+
56
+ return target
57
+
58
+ def get_groundtruth(self, index, target):
59
+ start_index, end_index = self.video_indices[index]
60
+
61
+ # GT 1. Joints
62
+ gt_kp3d = target['kp3d']
63
+ # gt_kp2d = perspective_projection(gt_kp3d, target['K'])
64
+ gt_kp2d = perspective_projection(gt_kp3d, self.cam_intrinsics)
65
+ target['kp3d'] = torch.cat((gt_kp3d, torch.ones_like(gt_kp3d[..., :1])), dim=-1)
66
+ # target['full_kp2d'] = torch.cat((gt_kp2d, torch.zeros_like(gt_kp2d[..., :1])), dim=-1)[1:]
67
+ target['full_kp2d'] = torch.cat((gt_kp2d, torch.ones_like(gt_kp2d[..., :1])), dim=-1)[1:]
68
+ target['weak_kp2d'] = torch.zeros_like(target['full_kp2d'])
69
+ target['init_kp3d'] = root_centering(gt_kp3d[:1, :self.n_joints].clone()).reshape(1, -1)
70
+
71
+ # GT 2. Root pose
72
+ w_transl = self.labels['w_trans'][start_index:end_index+1]
73
+ pose_root = transforms.axis_angle_to_matrix(self.labels['root'][start_index:end_index+1])
74
+ vel_world = (w_transl[1:] - w_transl[:-1])
75
+ vel_root = (pose_root[:-1].transpose(-1, -2) @ vel_world.unsqueeze(-1)).squeeze(-1)
76
+ target['vel_root'] = vel_root.clone()
77
+ target['pose_root'] = transforms.matrix_to_rotation_6d(pose_root)
78
+ target['init_root'] = target['pose_root'][:1].clone()
79
+
80
+ return target
81
+
82
+ def forward_smpl(self, target):
83
+ output = self.smpl.get_output(
84
+ body_pose=torch.cat((target['init_pose'][:, 1:], target['pose'][1:, 1:])),
85
+ global_orient=torch.cat((target['init_pose'][:, :1], target['pose'][1:, :1])),
86
+ betas=target['betas'],
87
+ transl=target['transl'],
88
+ pose2rot=False)
89
+
90
+ target['kp3d'] = output.joints + output.offset.unsqueeze(1)
91
+ target['feet'] = output.feet[1:] + target['transl'][1:].unsqueeze(-2)
92
+ target['verts'] = output.vertices[1:, ].clone()
93
+
94
+ return target
95
+
96
+ def augment_data(self, target):
97
+ # Augmentation 1. SMPL params augmentation
98
+ target = self.SMPLAugmentor(target)
99
+
100
+ # Get world-coordinate SMPL
101
+ target = self.forward_smpl(target)
102
+
103
+ return target
104
+
105
+ def load_camera(self, index, target):
106
+ start_index, end_index = self.video_indices[index]
107
+
108
+ # Get camera info
109
+ extrinsics = self.labels['extrinsics'][start_index:end_index+1].clone()
110
+ R = extrinsics[:, :3, :3]
111
+ T = extrinsics[:, :3, -1]
112
+ K = self.labels['intrinsics'][start_index:end_index+1].clone()
113
+ width, height = K[0, 0, 2] * 2, K[0, 1, 2] * 2
114
+ target['R'] = R
115
+ target['res'] = torch.tensor([width, height]).float()
116
+
117
+ # Compute angular velocity
118
+ cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2))
119
+ cam_angvel = cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel) # Normalize
120
+ target['cam_angvel'] = cam_angvel * 3e1 # BEDLAM is 30-fps
121
+
122
+ target['K'] = K # Use GT camera intrinsics for projecting keypoints
123
+ self.get_naive_intrinsics(target['res'])
124
+ target['cam_intrinsics'] = self.cam_intrinsics
125
+
126
+ return target
127
+
128
+ def load_params(self, index, target):
129
+ start_index, end_index = self.video_indices[index]
130
+
131
+ # Load AMASS labels
132
+ pose = self.labels['pose'][start_index:end_index+1].clone()
133
+ pose = transforms.axis_angle_to_matrix(pose.reshape(-1, 24, 3))
134
+ transl = self.labels['c_trans'][start_index:end_index+1].clone()
135
+ betas = self.labels['betas'][start_index:end_index+1, :10].clone()
136
+
137
+ # Stack GT
138
+ target.update({'vid': self.labels['vid'][start_index].clone(),
139
+ 'pose': pose,
140
+ 'transl': transl,
141
+ 'betas': betas})
142
+
143
+ return target
144
+
145
+
146
+ def get_single_sequence(self, index):
147
+ target = {'has_full_screen': torch.tensor(True),
148
+ 'has_smpl': torch.tensor(True),
149
+ 'has_traj': torch.tensor(False),
150
+ 'has_verts': torch.tensor(True),
151
+
152
+ # Null contact label
153
+ 'contact': torch.ones((self.n_frames - 1, 4)) * (-1),
154
+ }
155
+
156
+ target = self.load_params(index, target)
157
+ target = self.load_camera(index, target)
158
+ target = self.augment_data(target)
159
+ target = self.get_groundtruth(index, target)
160
+ target = self.get_inputs(index, target)
161
+
162
+ target = d_utils.prepare_keypoints_data(target)
163
+ target = d_utils.prepare_smpl_data(target)
164
+
165
+ return target
lib/data/datasets/dataset2d.py ADDED
@@ -0,0 +1,140 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import torch
6
+ import joblib
7
+
8
+ from .._dataset import BaseDataset
9
+ from ..utils.augmentor import *
10
+ from ...utils import data_utils as d_utils
11
+ from ...utils import transforms
12
+ from ...models import build_body_model
13
+ from ...utils.kp_utils import convert_kps, root_centering
14
+
15
+
16
+ class Dataset2D(BaseDataset):
17
+ def __init__(self, cfg, fname, training):
18
+ super(Dataset2D, self).__init__(cfg, training)
19
+
20
+ self.epoch = 0
21
+ self.n_frames = cfg.DATASET.SEQLEN + 1
22
+ self.labels = joblib.load(fname)
23
+
24
+ if self.training:
25
+ self.prepare_video_batch()
26
+
27
+ self.smpl = build_body_model('cpu', self.n_frames)
28
+ self.SMPLAugmentor = SMPLAugmentor(cfg, False)
29
+
30
+ def __getitem__(self, index):
31
+ return self.get_single_sequence(index)
32
+
33
+ def get_inputs(self, index, target, vis_thr=0.6):
34
+ start_index, end_index = self.video_indices[index]
35
+
36
+ # 2D keypoints detection
37
+ kp2d = self.labels['kp2d'][start_index:end_index+1][..., :2].clone()
38
+ kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], target['cam_intrinsics'], 224, 224, target['bbox'])
39
+ target['bbox'] = bbox[1:]
40
+ target['kp2d'] = kp2d
41
+
42
+ # Detection mask
43
+ target['mask'] = ~self.labels['joints2D'][start_index+1:end_index+1][..., -1].clone().bool()
44
+
45
+ # Image features
46
+ target['features'] = self.labels['features'][start_index+1:end_index+1].clone()
47
+
48
+ return target
49
+
50
+ def get_labels(self, index, target):
51
+ start_index, end_index = self.video_indices[index]
52
+
53
+ # SMPL parameters
54
+ # NOTE: We use NeuralAnnot labels for Human36m and MPII3D only for the 0th frame input.
55
+ # We do not supervise the network on SMPL parameters.
56
+ target['pose'] = transforms.axis_angle_to_matrix(
57
+ self.labels['pose'][start_index:end_index+1].clone().reshape(-1, 24, 3))
58
+ target['betas'] = self.labels['betas'][start_index:end_index+1].clone() # No t
59
+
60
+ # Apply SMPL augmentor (y-axis rotation and initial frame noise)
61
+ target = self.SMPLAugmentor(target)
62
+
63
+ # 2D keypoints
64
+ kp2d = self.labels['kp2d'][start_index:end_index+1].clone().float()[..., :2]
65
+ gt_kp2d = torch.zeros((self.n_frames - 1, 31, 2))
66
+ gt_kp2d[:, :17] = kp2d[1:].clone()
67
+
68
+ # Set 0 confidence to the masked keypoints
69
+ mask = torch.zeros((self.n_frames - 1, 31))
70
+ mask[:, :17] = self.labels['joints2D'][start_index+1:end_index+1][..., -1].clone()
71
+ mask = torch.logical_and(gt_kp2d.mean(-1) != 0, mask)
72
+ gt_kp2d = torch.cat((gt_kp2d, mask.float().unsqueeze(-1)), dim=-1)
73
+
74
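+ # Transform the keypoints into the bbox crop space; these become the weak 2D supervision target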
+ _gt_kp2d = gt_kp2d.clone()
75
+ for idx in range(len(_gt_kp2d)):
76
+ _gt_kp2d[idx][..., :2] = torch.from_numpy(
77
+ self.j2d_processing(gt_kp2d[idx][..., :2].numpy().copy(),
78
+ target['bbox'][idx].numpy().copy()))
79
+
80
+ target['weak_kp2d'] = _gt_kp2d.clone()
81
+ target['full_kp2d'] = torch.zeros_like(gt_kp2d)
82
+ target['kp3d'] = torch.zeros((kp2d.shape[0], 31, 4))
83
+
84
+ # No SMPL vertices available
85
+ target['verts'] = torch.zeros((self.n_frames - 1, 6890, 3)).float()
86
+ return target
87
+
88
+ def get_init_frame(self, target):
89
+ # Prepare initial frame
90
+ output = self.smpl.get_output(
91
+ body_pose=target['init_pose'][:, 1:],
92
+ global_orient=target['init_pose'][:, :1],
93
+ betas=target['betas'][:1],
94
+ pose2rot=False
95
+ )
96
+ target['init_kp3d'] = root_centering(output.joints[:1, :self.n_joints]).reshape(1, -1)
97
+
98
+ return target
99
+
100
+ def get_single_sequence(self, index):
101
+ # Camera parameters
102
+ res = (224.0, 224.0)
103
+ bbox = torch.tensor([112.0, 112.0, 1.12])
104
+ res = torch.tensor(res)
105
+ self.get_naive_intrinsics(res)
106
+ bbox = bbox.repeat(self.n_frames, 1)
107
+
108
+ # Universal target
109
+ target = {'has_full_screen': torch.tensor(False),
110
+ 'has_smpl': torch.tensor(self.has_smpl),
111
+ 'has_traj': torch.tensor(self.has_traj),
112
+ 'has_verts': torch.tensor(False),
113
+ 'transl': torch.zeros((self.n_frames, 3)),
114
+
115
+ # Camera parameters and bbox
116
+ 'res': res,
117
+ 'cam_intrinsics': self.cam_intrinsics,
118
+ 'bbox': bbox,
119
+
120
+ # Null camera motion
121
+ 'R': torch.eye(3).repeat(self.n_frames, 1, 1),
122
+ 'cam_angvel': torch.zeros((self.n_frames - 1, 6)),
123
+
124
+ # Null root orientation and velocity
125
+ 'pose_root': torch.zeros((self.n_frames, 6)),
126
+ 'vel_root': torch.zeros((self.n_frames - 1, 3)),
127
+ 'init_root': torch.zeros((1, 6)),
128
+
129
+ # Null contact label
130
+ 'contact': torch.ones((self.n_frames - 1, 4)) * (-1)
131
+ }
132
+
133
+ self.get_inputs(index, target)
134
+ self.get_labels(index, target)
135
+ self.get_init_frame(target)
136
+
137
+ target = d_utils.prepare_keypoints_data(target)
138
+ target = d_utils.prepare_smpl_data(target)
139
+
140
+ return target
lib/data/datasets/dataset3d.py ADDED
@@ -0,0 +1,172 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import torch
6
+ import joblib
7
+ import numpy as np
8
+
9
+ from .._dataset import BaseDataset
10
+ from ..utils.augmentor import *
11
+ from ...utils import data_utils as d_utils
12
+ from ...utils import transforms
13
+ from ...models import build_body_model
14
+ from ...utils.kp_utils import convert_kps, root_centering
15
+
16
+
17
+ class Dataset3D(BaseDataset):
18
+ def __init__(self, cfg, fname, training):
19
+ super(Dataset3D, self).__init__(cfg, training)
20
+
21
+ self.epoch = 0
22
+ self.labels = joblib.load(fname)
23
+ self.n_frames = cfg.DATASET.SEQLEN + 1
24
+
25
+ if self.training:
26
+ self.prepare_video_batch()
27
+
28
+ self.smpl = build_body_model('cpu', self.n_frames)
29
+ self.SMPLAugmentor = SMPLAugmentor(cfg, False)
30
+ self.VideoAugmentor = VideoAugmentor(cfg)
31
+
32
+ def __getitem__(self, index):
33
+ return self.get_single_sequence(index)
34
+
35
+ def get_inputs(self, index, target, vis_thr=0.6):
36
+ start_index, end_index = self.video_indices[index]
37
+
38
+ # 2D keypoints detection
39
+ kp2d = self.labels['kp2d'][start_index:end_index+1][..., :2].clone()
40
+ bbox = self.labels['bbox'][start_index:end_index+1][..., [0, 1, -1]].clone()
41
+ bbox[:, 2] = bbox[:, 2] / 200
42
+ kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224, bbox)
43
+
44
+ target['bbox'] = bbox[1:]
45
+ target['kp2d'] = kp2d
46
+ target['mask'] = self.labels['kp2d'][start_index+1:end_index+1][..., -1] < vis_thr
47
+
48
+ # Image features
49
+ target['features'] = self.labels['features'][start_index+1:end_index+1].clone()
50
+
51
+ return target
52
+
53
+ def get_labels(self, index, target):
54
+ start_index, end_index = self.video_indices[index]
55
+
56
+ # SMPL parameters
57
+ # NOTE: We use NeuralAnnot labels for Human36m and MPII3D only for the 0th frame input.
58
+ # We do not supervise the network on SMPL parameters.
59
+ target['pose'] = transforms.axis_angle_to_matrix(
60
+ self.labels['pose'][start_index:end_index+1].clone().reshape(-1, 24, 3))
61
+ target['betas'] = self.labels['betas'][start_index:end_index+1].clone() # No t
62
+
63
+ # Apply SMPL augmentor (y-axis rotation and initial frame noise)
64
+ target = self.SMPLAugmentor(target)
65
+
66
+ # 3D and 2D keypoints
67
+ if self.__name__ == 'ThreeDPW': # 3DPW has SMPL labels
68
+ gt_kp3d = self.labels['joints3D'][start_index:end_index+1].clone()
69
+ gt_kp2d = self.labels['joints2D'][start_index+1:end_index+1, ..., :2].clone()
70
+ gt_kp3d = root_centering(gt_kp3d.clone())
71
+
72
+ else: # Human36m and MPII do not have SMPL labels
73
+ gt_kp3d = torch.zeros((self.n_frames, self.n_joints + 14, 3))
74
+ gt_kp3d[:, self.n_joints:] = convert_kps(self.labels['joints3D'][start_index:end_index+1], 'spin', 'common')
75
+ gt_kp2d = torch.zeros((self.n_frames - 1, self.n_joints + 14, 2))
76
+ gt_kp2d[:, self.n_joints:] = convert_kps(self.labels['joints2D'][start_index+1:end_index+1, ..., :2], 'spin', 'common')
77
+
78
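+ # The per-dataset joint validity mask is appended as the confidence channel of the 2D/3D keypoints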
+ conf = self.mask.repeat(self.n_frames, 1).unsqueeze(-1)
79
+ gt_kp2d = torch.cat((gt_kp2d, conf[1:]), dim=-1)
80
+ gt_kp3d = torch.cat((gt_kp3d, conf), dim=-1)
81
+ target['kp3d'] = gt_kp3d
82
+ target['full_kp2d'] = gt_kp2d
83
+ target['weak_kp2d'] = torch.zeros_like(gt_kp2d)
84
+
85
+ if self.__name__ != 'ThreeDPW': # 3DPW does not contain world-coordinate motion
86
+ # Foot ground contact labels for Human36M and MPII3D
87
+ target['contact'] = self.labels['stationaries'][start_index+1:end_index+1].clone()
88
+ else:
89
+ # No foot ground contact label available for 3DPW
90
+ target['contact'] = torch.ones((self.n_frames - 1, 4)) * (-1)
91
+
92
+ if self.has_verts:
93
+ # SMPL vertices available for 3DPW
94
+ with torch.no_grad():
95
+ start_index, end_index = self.video_indices[index]
96
+ gender = self.labels['gender'][start_index].item()
97
+ output = self.smpl_gender[gender](
98
+ body_pose=target['pose'][1:, 1:],
99
+ global_orient=target['pose'][1:, :1],
100
+ betas=target['betas'][1:],
101
+ pose2rot=False,
102
+ )
103
+ target['verts'] = output.vertices.clone()
104
+ else:
105
+ # No SMPL vertices available
106
+ target['verts'] = torch.zeros((self.n_frames - 1, 6890, 3)).float()
107
+
108
+ return target
109
+
110
+ def get_init_frame(self, target):
111
+ # Prepare initial frame
112
+ output = self.smpl.get_output(
113
+ body_pose=target['init_pose'][:, 1:],
114
+ global_orient=target['init_pose'][:, :1],
115
+ betas=target['betas'][:1],
116
+ pose2rot=False
117
+ )
118
+ target['init_kp3d'] = root_centering(output.joints[:1, :self.n_joints]).reshape(1, -1)
119
+
120
+ return target
121
+
122
+ def get_camera_info(self, index, target):
123
+ start_index, end_index = self.video_indices[index]
124
+
125
+ # Intrinsics
126
+ target['res'] = self.labels['res'][start_index:end_index+1][0].clone()
127
+ self.get_naive_intrinsics(target['res'])
128
+ target['cam_intrinsics'] = self.cam_intrinsics.clone()
129
+
130
+ # Extrinsics pose
131
+ R = self.labels['cam_poses'][start_index:end_index+1, :3, :3].clone().float()
132
+ yaw = transforms.axis_angle_to_matrix(torch.tensor([[0, 2 * np.pi * np.random.uniform(), 0]])).float()
133
+ if self.__name__ == 'Human36M':
134
+ # Map Z-up to Y-down coordinate
135
+ zup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[-np.pi/2, 0, 0]])).float()
136
+ zup2ydown = torch.matmul(yaw, zup2ydown)
137
+ R = torch.matmul(R, zup2ydown)
138
+ elif self.__name__ == 'MPII3D':
139
+ # Map Y-up to Y-down coordinate
140
+ yup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float()
141
+ yup2ydown = torch.matmul(yaw, yup2ydown)
142
+ R = torch.matmul(R, yup2ydown)
143
+
144
+ return target
145
+
146
+ def get_single_sequence(self, index):
147
+ # Universal target
148
+ target = {'has_full_screen': torch.tensor(True),
149
+ 'has_smpl': torch.tensor(self.has_smpl),
150
+ 'has_traj': torch.tensor(self.has_traj),
151
+ 'has_verts': torch.tensor(self.has_verts),
152
+ 'transl': torch.zeros((self.n_frames, 3)),
153
+
154
+ # Null camera motion
155
+ 'R': torch.eye(3).repeat(self.n_frames, 1, 1),
156
+ 'cam_angvel': torch.zeros((self.n_frames - 1, 6)),
157
+
158
+ # Null root orientation and velocity
159
+ 'pose_root': torch.zeros((self.n_frames, 6)),
160
+ 'vel_root': torch.zeros((self.n_frames - 1, 3)),
161
+ 'init_root': torch.zeros((1, 6)),
162
+ }
163
+
164
+ self.get_camera_info(index, target)
165
+ self.get_inputs(index, target)
166
+ self.get_labels(index, target)
167
+ self.get_init_frame(target)
168
+
169
+ target = d_utils.prepare_keypoints_data(target)
170
+ target = d_utils.prepare_smpl_data(target)
171
+
172
+ return target
lib/data/datasets/dataset_custom.py ADDED
@@ -0,0 +1,115 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import torch
6
+
7
+ from ..utils.normalizer import Normalizer
8
+ from ...models import build_body_model
9
+ from ...utils import transforms
10
+ from ...utils.kp_utils import root_centering
11
+ from ...utils.imutils import compute_cam_intrinsics
12
+
13
+ KEYPOINTS_THR = 0.3
14
+
15
+ def convert_dpvo_to_cam_angvel(traj, fps):
16
+ """Function to convert DPVO trajectory output to camera angular velocity"""
17
+
18
+ # 0 ~ 3: translation, 3 ~ 7: Quaternion
19
+ quat = traj[:, 3:]
20
+
21
+ # Reorder the quaternion from (x, y, z, w) to (w, x, y, z)
22
+ quat = quat[:, [3, 0, 1, 2]]
23
+
24
+ # The quaternion encodes the camera-to-world transform; convert it to world-to-camera
25
+ world2cam = transforms.quaternion_to_matrix(torch.from_numpy(quat)).float()
26
+ R = world2cam.mT
27
+
28
+ # Compute the rotational changes over time.
29
+ cam_angvel = transforms.matrix_to_axis_angle(R[:-1] @ R[1:].transpose(-1, -2))
30
+
31
+ # Convert matrix to 6D representation
32
+ cam_angvel = transforms.matrix_to_rotation_6d(transforms.axis_angle_to_matrix(cam_angvel))
33
+
34
+ # Normalize 6D angular velocity
35
+ cam_angvel = cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel) # Normalize
36
+ cam_angvel = cam_angvel * fps
37
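+ # Pad with a copy of the first entry so the output length matches the number of frames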
+ cam_angvel = torch.cat((cam_angvel, cam_angvel[:1]), dim=0)
38
+ return cam_angvel
39
+
40
+
41
+ class CustomDataset(torch.utils.data.Dataset):
42
+ def __init__(self, cfg, tracking_results, slam_results, width, height, fps):
43
+
44
+ self.tracking_results = tracking_results
45
+ self.slam_results = slam_results
46
+ self.width = width
47
+ self.height = height
48
+ self.fps = fps
49
+ self.res = torch.tensor([width, height]).float()
50
+ self.intrinsics = compute_cam_intrinsics(self.res)
51
+
52
+ self.device = cfg.DEVICE.lower()
53
+
54
+ self.smpl = build_body_model('cpu')
55
+ self.keypoints_normalizer = Normalizer(cfg)
56
+
57
+ self._to = lambda x: x.unsqueeze(0).to(self.device)
58
+
59
+ def __len__(self):
60
+ return len(self.tracking_results.keys())
61
+
62
+ def load_data(self, index, flip=False):
63
+ if flip:
64
+ self.prefix = 'flipped_'
65
+ else:
66
+ self.prefix = ''
67
+
68
+ return self.__getitem__(index)
69
+
70
+ def __getitem__(self, _index):
71
+ if _index >= len(self): return
72
+
73
+ index = sorted(list(self.tracking_results.keys()))[_index]
74
+
75
+ # Process 2D keypoints
76
+ kp2d = torch.from_numpy(self.tracking_results[index][self.prefix + 'keypoints']).float()
77
+ mask = kp2d[..., -1] < KEYPOINTS_THR
78
+ bbox = torch.from_numpy(self.tracking_results[index][self.prefix + 'bbox']).float()
79
+
80
+ norm_kp2d, _ = self.keypoints_normalizer(
81
+ kp2d[..., :-1].clone(), self.res, self.intrinsics, 224, 224, bbox
82
+ )
83
+
84
+ # Process image features
85
+ features = self.tracking_results[index][self.prefix + 'features']
86
+
87
+ # Process initial pose
88
+ init_output = self.smpl.get_output(
89
+ global_orient=self.tracking_results[index][self.prefix + 'init_global_orient'],
90
+ body_pose=self.tracking_results[index][self.prefix + 'init_body_pose'],
91
+ betas=self.tracking_results[index][self.prefix + 'init_betas'],
92
+ pose2rot=False,
93
+ return_full_pose=True
94
+ )
95
+ init_kp3d = root_centering(init_output.joints[:, :17], 'coco')
96
+ init_kp = torch.cat((init_kp3d.reshape(1, -1), norm_kp2d[0].clone().reshape(1, -1)), dim=-1)
97
+ init_smpl = transforms.matrix_to_rotation_6d(init_output.full_pose)
98
+ init_root = transforms.matrix_to_rotation_6d(init_output.global_orient)
99
+
100
+ # Process SLAM results
101
+ cam_angvel = convert_dpvo_to_cam_angvel(self.slam_results, self.fps)
102
+
103
+ return (
104
+ index, # subject id
105
+ self._to(norm_kp2d), # 2d keypoints
106
+ (self._to(init_kp), self._to(init_smpl)), # initial pose
107
+ self._to(features), # image features
108
+ self._to(mask), # keypoints mask
109
+ init_root.to(self.device), # initial root orientation
110
+ self._to(cam_angvel), # camera angular velocity
111
+ self.tracking_results[index]['frame_id'], # frame indices
112
+ {'cam_intrinsics': self._to(self.intrinsics), # other keyword arguments
113
+ 'bbox': self._to(bbox),
114
+ 'res': self._to(self.res)},
115
+ )
lib/data/datasets/dataset_eval.py ADDED
@@ -0,0 +1,113 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import os
6
+ import torch
7
+ import joblib
8
+
9
+ from configs import constants as _C
10
+ from .._dataset import BaseDataset
11
+ from ...utils import transforms
12
+ from ...utils import data_utils as d_utils
13
+ from ...utils.kp_utils import root_centering
14
+
15
+ FPS = 30
16
+ class EvalDataset(BaseDataset):
17
+ def __init__(self, cfg, data, split, backbone):
18
+ super(EvalDataset, self).__init__(cfg, False)
19
+
20
+ self.prefix = ''
21
+ self.data = data
22
+ parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'{data}_{split}_{backbone}.pth')
23
+ self.labels = joblib.load(parsed_data_path)
24
+
25
+ def load_data(self, index, flip=False):
26
+ if flip:
27
+ self.prefix = 'flipped_'
28
+ else:
29
+ self.prefix = ''
30
+
31
+ target = self.__getitem__(index)
32
+ for key, val in target.items():
33
+ if isinstance(val, torch.Tensor):
34
+ target[key] = val.unsqueeze(0)
35
+ return target
36
+
37
+ def __getitem__(self, index):
38
+ target = {}
39
+ target = self.get_data(index)
40
+ target = d_utils.prepare_keypoints_data(target)
41
+ target = d_utils.prepare_smpl_data(target)
42
+
43
+ return target
44
+
45
+ def __len__(self):
46
+ return len(self.labels['kp2d'])
47
+
48
+ def prepare_labels(self, index, target):
49
+ # Ground truth SMPL parameters
50
+ target['pose'] = transforms.axis_angle_to_matrix(self.labels['pose'][index].reshape(-1, 24, 3))
51
+ target['betas'] = self.labels['betas'][index]
52
+ target['gender'] = self.labels['gender'][index]
53
+
54
+ # Sequence information
55
+ target['res'] = self.labels['res'][index][0]
56
+ target['vid'] = self.labels['vid'][index]
57
+ target['frame_id'] = self.labels['frame_id'][index][1:]
58
+
59
+ # Camera information
60
+ self.get_naive_intrinsics(target['res'])
61
+ target['cam_intrinsics'] = self.cam_intrinsics
62
+ R = self.labels['cam_poses'][index][:, :3, :3].clone()
63
+ if 'emdb' in self.data.lower():
64
+ # Use groundtruth camera angular velocity.
65
+ # Can be updated with SLAM results if you have it.
66
+ cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2))
67
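+ # (1, 0, 0, 0, 1, 0) is the 6D encoding of the identity rotation, so zero camera motion maps to zero velocity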
+ cam_angvel = (cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)) * FPS
68
+ target['R'] = R
69
+ else:
70
+ cam_angvel = torch.zeros((len(target['pose']) - 1, 6))
71
+ target['cam_angvel'] = cam_angvel
72
+ return target
73
+
74
+ def prepare_inputs(self, index, target):
75
+ for key in ['features', 'bbox']:
76
+ data = self.labels[self.prefix + key][index][1:]
77
+ target[key] = data
78
+
79
+ bbox = self.labels[self.prefix + 'bbox'][index][..., [0, 1, -1]].clone().float()
80
+ bbox[:, 2] = bbox[:, 2] / 200
81
+
82
+ # Normalize keypoints
83
+ kp2d, bbox = self.keypoints_normalizer(
84
+ self.labels[self.prefix + 'kp2d'][index][..., :2].clone().float(),
85
+ target['res'], target['cam_intrinsics'], 224, 224, bbox)
86
+ target['kp2d'] = kp2d
87
+ target['bbox'] = bbox[1:]
88
+
89
+ # Masking out low confident keypoints
90
+ mask = self.labels[self.prefix + 'kp2d'][index][..., -1] < 0.3
91
+ target['input_kp2d'] = self.labels['kp2d'][index][1:]
92
+ target['input_kp2d'][mask[1:]] *= 0
93
+ target['mask'] = mask[1:]
94
+
95
+ return target
96
+
97
+ def prepare_initialization(self, index, target):
98
+ # Initial frame per-frame estimation
99
+ target['init_kp3d'] = root_centering(self.labels[self.prefix + 'init_kp3d'][index][:1, :self.n_joints]).reshape(1, -1)
100
+ target['init_pose'] = transforms.axis_angle_to_matrix(self.labels[self.prefix + 'init_pose'][index][:1]).cpu()
101
+ pose_root = target['pose'][:, 0].clone()
102
+ target['init_root'] = transforms.matrix_to_rotation_6d(pose_root)
103
+
104
+ return target
105
+
106
+ def get_data(self, index):
107
+ target = {}
108
+
109
+ target = self.prepare_labels(index, target)
110
+ target = self.prepare_inputs(index, target)
111
+ target = self.prepare_initialization(index, target)
112
+
113
+ return target
lib/data/datasets/mixed_dataset.py ADDED
@@ -0,0 +1,61 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ from .amass import AMASSDataset
9
+ from .videos import Human36M, ThreeDPW, MPII3D, InstaVariety
10
+ from .bedlam import BEDLAMDataset
11
+ from lib.utils.data_utils import make_collate_fn
12
+
13
+
14
+ class DataFactory(torch.utils.data.Dataset):
15
+ def __init__(self, cfg, train_stage='syn'):
16
+ super(DataFactory, self).__init__()
17
+
18
+ if train_stage == 'stage1':
19
+ self.datasets = [AMASSDataset(cfg)]
20
+ self.dataset_names = ['AMASS']
21
+ elif train_stage == 'stage2':
22
+ self.datasets = [
23
+ AMASSDataset(cfg), ThreeDPW(cfg),
24
+ Human36M(cfg), MPII3D(cfg), InstaVariety(cfg)
25
+ ]
26
+ self.dataset_names = ['AMASS', '3DPW', 'Human36M', 'MPII3D', 'Insta']
27
+
28
+ if len(cfg.DATASET.RATIO) == 6: # Use BEDLAM
29
+ self.datasets.append(BEDLAMDataset(cfg))
30
+ self.dataset_names.append('BEDLAM')
31
+
32
+ self._set_partition(cfg.DATASET.RATIO)
33
+ self.lengths = [len(ds) for ds in self.datasets]
34
+
35
+ @property
36
+ def __name__(self, ):
37
+ return 'MixedData'
38
+
39
+ def prepare_video_batch(self):
40
+ [ds.prepare_video_batch() for ds in self.datasets]
41
+ self.lengths = [len(ds) for ds in self.datasets]
42
+
43
+ def _set_partition(self, partition):
44
+ self.partition = partition
45
+ self.ratio = partition
46
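+ # Turn the per-dataset sampling ratios into a normalized cumulative distribution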
+ self.partition = np.array(self.partition).cumsum()
47
+ self.partition /= self.partition[-1]
48
+
49
+ def __len__(self):
50
+ return int(np.array([l for l, r in zip(self.lengths, self.ratio) if r > 0]).mean())
51
+
52
+ def __getitem__(self, index):
53
+ # Get the dataset to sample from
54
+ p = np.random.rand()
55
+ for i in range(len(self.datasets)):
56
+ if p <= self.partition[i]:
57
+ if len(self.datasets) == 1:
58
+ return self.datasets[i][index % self.lengths[i]]
59
+ else:
60
+ d_index = np.random.randint(0, self.lengths[i])
61
+ return self.datasets[i][d_index]
lib/data/datasets/videos.py ADDED
@@ -0,0 +1,105 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import os
6
+ import torch
7
+
8
+ from configs import constants as _C
9
+ from .dataset3d import Dataset3D
10
+ from .dataset2d import Dataset2D
11
+ from ...utils.kp_utils import convert_kps
12
+ from smplx import SMPL
13
+
14
+
15
+ class Human36M(Dataset3D):
16
+ def __init__(self, cfg, dset='train'):
17
+ parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'human36m_{dset}_backbone.pth')
18
+ parsed_data_path = parsed_data_path.replace('backbone', cfg.MODEL.BACKBONE.lower())
19
+ super(Human36M, self).__init__(cfg, parsed_data_path, dset=='train')
20
+
21
+ self.has_3d = True
22
+ self.has_traj = True
23
+ self.has_smpl = False
24
+ self.has_verts = False
25
+
26
+ # Among the 31 + 14 joint format, only the 14 common joints are available
27
+ self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
28
+ self.mask[-14:] = 1
29
+
30
+ @property
31
+ def __name__(self, ):
32
+ return 'Human36M'
33
+
34
+ def compute_3d_keypoints(self, index):
35
+ return convert_kps(self.labels['joints3D'][index], 'spin', 'h36m'
36
+ )[:, _C.KEYPOINTS.H36M_TO_J14].float()
37
+
38
+ class MPII3D(Dataset3D):
39
+ def __init__(self, cfg, dset='train'):
40
+ parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'mpii3d_{dset}_backbone.pth')
41
+ parsed_data_path = parsed_data_path.replace('backbone', cfg.MODEL.BACKBONE.lower())
42
+ super(MPII3D, self).__init__(cfg, parsed_data_path, dset=='train')
43
+
44
+ self.has_3d = True
45
+ self.has_traj = True
46
+ self.has_smpl = False
47
+ self.has_verts = False
48
+
49
+ # Among the 31 + 14 joint format, only the 14 common joints are available
50
+ self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
51
+ self.mask[-14:] = 1
52
+
53
+ @property
54
+ def __name__(self, ):
55
+ return 'MPII3D'
56
+
57
+ def compute_3d_keypoints(self, index):
58
+ return convert_kps(self.labels['joints3D'][index], 'spin', 'h36m'
59
+ )[:, _C.KEYPOINTS.H36M_TO_J17].float()
60
+
61
+ class ThreeDPW(Dataset3D):
62
+ def __init__(self, cfg, dset='train'):
63
+ parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'3dpw_{dset}_backbone.pth')
64
+ parsed_data_path = parsed_data_path.replace('backbone', cfg.MODEL.BACKBONE.lower())
65
+ super(ThreeDPW, self).__init__(cfg, parsed_data_path, dset=='train')
66
+
67
+ self.has_3d = True
68
+ self.has_traj = False
69
+ self.has_smpl = True
70
+ self.has_verts = True # In testing
71
+
72
+ # Among the 31 + 14 joint format, the 31 model joints are available (3DPW provides SMPL labels)
73
+ self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
74
+ self.mask[:-14] = 1
75
+
76
+ self.smpl_gender = {
77
+ 0: SMPL(_C.BMODEL.FLDR, gender='male', num_betas=10),
78
+ 1: SMPL(_C.BMODEL.FLDR, gender='female', num_betas=10)
79
+ }
80
+
81
+ @property
82
+ def __name__(self, ):
83
+ return 'ThreeDPW'
84
+
85
+ def compute_3d_keypoints(self, index):
86
+ return self.labels['joints3D'][index]
87
+
88
+
89
+ class InstaVariety(Dataset2D):
90
+ def __init__(self, cfg, dset='train'):
91
+ parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'insta_{dset}_backbone.pth')
92
+ parsed_data_path = parsed_data_path.replace('backbone', cfg.MODEL.BACKBONE.lower())
93
+ super(InstaVariety, self).__init__(cfg, parsed_data_path, dset=='train')
94
+
95
+ self.has_3d = False
96
+ self.has_traj = False
97
+ self.has_smpl = False
98
+
99
+ # Among the 31 + 14 joint format, the 17 COCO joints are available
100
+ self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
101
+ self.mask[:17] = 1
102
+
103
+ @property
104
+ def __name__(self, ):
105
+ return 'InstaVariety'
lib/data/utils/__pycache__/augmentor.cpython-39.pyc ADDED
Binary file (10.1 kB). View file
 
lib/data/utils/__pycache__/normalizer.cpython-39.pyc ADDED
Binary file (4.13 kB). View file
 
lib/data/utils/augmentor.py ADDED
@@ -0,0 +1,292 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ from configs import constants as _C
6
+
7
+ import torch
8
+ import numpy as np
9
+ from torch.nn import functional as F
10
+
11
+ from ...utils import transforms
12
+
13
+ __all__ = ['VideoAugmentor', 'SMPLAugmentor', 'SequenceAugmentor', 'CameraAugmentor']
14
+
15
+
16
+ num_joints = _C.KEYPOINTS.NUM_JOINTS
17
+ class VideoAugmentor():
18
+ def __init__(self, cfg, train=True):
19
+ self.train = train
20
+ self.l = cfg.DATASET.SEQLEN + 1
21
+ self.aug_dict = torch.load(_C.KEYPOINTS.COCO_AUG_DICT)
22
+
23
+ def get_jitter(self, ):
24
+ """Guassian jitter modeling."""
25
+ jittering_noise = torch.normal(
26
+ mean=torch.zeros((self.l, num_joints, 3)),
27
+ std=self.aug_dict['jittering'].reshape(1, num_joints, 1).expand(self.l, -1, 3)
28
+ ) * _C.KEYPOINTS.S_JITTERING
29
+ return jittering_noise
30
+
31
+ def get_lfhp(self, ):
32
+ """Low-frequency high-peak noise modeling."""
33
+ def get_peak_noise_mask():
34
+ peak_noise_mask = torch.rand(self.l, num_joints).float() * self.aug_dict['pmask'].squeeze(0)
35
+ peak_noise_mask = peak_noise_mask < _C.KEYPOINTS.S_PEAK_MASK
36
+ return peak_noise_mask
37
+
38
+ peak_noise_mask = get_peak_noise_mask()
39
+ peak_noise = peak_noise_mask.float().unsqueeze(-1).repeat(1, 1, 3)
40
+ peak_noise = peak_noise * torch.randn(3) * self.aug_dict['peak'].reshape(1, -1, 1) * _C.KEYPOINTS.S_PEAK
41
+ return peak_noise
42
+
43
+ def get_bias(self, ):
44
+ """Bias noise modeling."""
45
+ bias_noise = torch.normal(
46
+ mean=torch.zeros((num_joints, 3)), std=self.aug_dict['bias'].reshape(num_joints, 1)
47
+ ).unsqueeze(0) * _C.KEYPOINTS.S_BIAS
48
+ return bias_noise
49
+
50
+ def get_mask(self, scale=None):
51
+ """Mask modeling."""
52
+
53
+ if scale is None:
54
+ scale = _C.KEYPOINTS.S_MASK
55
+ # Per-frame and joint
56
+ mask = torch.rand(self.l, num_joints) < scale
57
+ visible = (~mask).clone()
58
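+ # Propagate occlusion down the kinematic tree: a joint stays visible only if its parent joints are visible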
+ for child in range(num_joints):
59
+ parent = _C.KEYPOINTS.TREE[child]
60
+ if parent == -1: continue
61
+ if isinstance(parent, list):
62
+ visible[:, child] *= (visible[:, parent[0]] * visible[:, parent[1]])
63
+ else:
64
+ visible[:, child] *= visible[:, parent]
65
+ mask = (~visible).clone()
66
+
67
+ return mask
68
+
69
+ def __call__(self, keypoints):
70
+ keypoints += self.get_bias() + self.get_jitter() + self.get_lfhp()
71
+ return keypoints
72
+
73
+
74
+ class SMPLAugmentor():
75
+ noise_scale = 1e-2
76
+
77
+ def __init__(self, cfg, augment=True):
78
+ self.n_frames = cfg.DATASET.SEQLEN
79
+ self.augment = augment
80
+
81
+ def __call__(self, target):
82
+ if not self.augment:
83
+ # Only add initial frame augmentation
84
+ if not 'init_pose' in target:
85
+ target['init_pose'] = target['pose'][:1] @ self.get_initial_pose_augmentation()
86
+ return target
87
+
88
+ n_frames = target['pose'].shape[0]
89
+
90
+ # Global rotation
91
+ rmat = self.get_global_augmentation()
92
+ target['pose'][:, 0] = rmat @ target['pose'][:, 0]
93
+ target['transl'] = (rmat.squeeze() @ target['transl'].T).T
94
+
95
+ # Shape
96
+ shape_noise = self.get_shape_augmentation(n_frames)
97
+ target['betas'] = target['betas'] + shape_noise
98
+
99
+ # Initial frames mis-prediction
100
+ target['init_pose'] = target['pose'][:1] @ self.get_initial_pose_augmentation()
101
+
102
+ return target
103
+
104
+ def get_global_augmentation(self, ):
105
+ """Global coordinate augmentation. Random rotation around y-axis"""
106
+
107
+ angle_y = torch.rand(1) * 2 * np.pi * float(self.augment)
108
+ aa = torch.tensor([0.0, angle_y, 0.0]).float().unsqueeze(0)
109
+ rmat = transforms.axis_angle_to_matrix(aa)
110
+
111
+ return rmat
112
+
113
+ def get_shape_augmentation(self, n_frames):
114
+ """Shape noise modeling."""
115
+
116
+ shape_noise = torch.normal(
117
+ mean=torch.zeros((1, 10)),
118
+ std=torch.ones((1, 10)) * 0.1 * float(self.augment)).expand(n_frames, 10)
119
+
120
+ return shape_noise
121
+
122
+ def get_initial_pose_augmentation(self, ):
123
+ """Initial frame pose noise modeling. Random rotation around all joints."""
124
+
125
+ euler = torch.normal(
126
+ mean=torch.zeros((24, 3)),
127
+ std=torch.ones((24, 3))
128
+ ) * self.noise_scale #* float(self.augment)
129
+ rmat = transforms.axis_angle_to_matrix(euler)
130
+
131
+ return rmat.unsqueeze(0)
132
+
133
+
134
+ class SequenceAugmentor:
135
+ """Augment the play speed of the motion sequence"""
136
+ l_factor = 1.5
137
+ def __init__(self, l_default):
138
+ self.l_default = l_default
139
+
140
+ def __call__(self, target):
141
+ l = torch.randint(low=int(self.l_default / self.l_factor), high=int(self.l_default * self.l_factor), size=(1, ))
142
+
143
+ pose = transforms.matrix_to_rotation_6d(target['pose'])
144
+ resampled_pose = F.interpolate(
145
+ pose[:l].permute(1, 2, 0), self.l_default, mode='linear', align_corners=True
146
+ ).permute(2, 0, 1)
147
+ resampled_pose = transforms.rotation_6d_to_matrix(resampled_pose)
148
+
149
+ transl = target['transl'].unsqueeze(1)
150
+ resampled_transl = F.interpolate(
151
+ transl[:l].permute(1, 2, 0), self.l_default, mode='linear', align_corners=True
152
+ ).squeeze(0).T
153
+
154
+ target['pose'] = resampled_pose
155
+ target['transl'] = resampled_transl
156
+ target['betas'] = target['betas'][:self.l_default]
157
+
158
+ return target
159
+
160
+
161
+ class CameraAugmentor:
162
+ rx_factor = np.pi/8
163
+ ry_factor = np.pi/4
164
+ rz_factor = np.pi/8
165
+
166
+ pitch_std = np.pi/8
167
+ pitch_mean = np.pi/36
168
+ roll_std = np.pi/24
169
+ t_factor = 1
170
+
171
+ tz_scale = 10
172
+ tz_min = 2
173
+
174
+ motion_prob = 0.75
175
+ interp_noise = 0.2
176
+
177
+ def __init__(self, l, w, h, f):
178
+ self.l = l
179
+ self.w = w
180
+ self.h = h
181
+ self.f = f
182
+ self.fov_tol = 1.2 * (0.5 ** 0.5)
183
+
184
+ def __call__(self, target):
185
+
186
+ R, T = self.create_camera(target)
187
+
188
+ if np.random.rand() < self.motion_prob:
189
+ R = self.create_rotation_move(R)
190
+ T = self.create_translation_move(T)
191
+
192
+ return self.apply(target, R, T)
193
+
194
+ def create_camera(self, target):
195
+ """Create the initial frame camera pose"""
196
+ yaw = np.random.rand() * 2 * np.pi
197
+ pitch = np.random.normal(scale=self.pitch_std) + self.pitch_mean
198
+ roll = np.random.normal(scale=self.roll_std)
199
+
200
+ yaw_rm = transforms.axis_angle_to_matrix(torch.tensor([[0, yaw, 0]]).float())
201
+ pitch_rm = transforms.axis_angle_to_matrix(torch.tensor([[pitch, 0, 0]]).float())
202
+ roll_rm = transforms.axis_angle_to_matrix(torch.tensor([[0, 0, roll]]).float())
203
+ R = (roll_rm @ pitch_rm @ yaw_rm)
204
+
205
+ # Place people in the scene
206
+ tz = np.random.rand() * self.tz_scale + self.tz_min
207
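+ # Half the visible width at depth tz; keeps the random offset roughly inside the frame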
+ max_d = self.w * tz / self.f / 2
208
+ tx = np.random.normal(scale=0.25) * max_d
209
+ ty = np.random.normal(scale=0.25) * max_d
210
+ dist = torch.tensor([tx, ty, tz]).float()
211
+ T = dist - torch.matmul(R, target['transl'][0])
212
+
213
+ return R.repeat(self.l, 1, 1), T.repeat(self.l, 1)
214
+
215
+ def create_rotation_move(self, R):
216
+ """Create rotational move for the camera"""
217
+
218
+ # Create final camera pose
219
+ rx = np.random.normal(scale=self.rx_factor)
220
+ ry = np.random.normal(scale=self.ry_factor)
221
+ rz = np.random.normal(scale=self.rz_factor)
222
+ Rf = R[0] @ transforms.axis_angle_to_matrix(torch.tensor([rx, ry, rz]).float())
223
+
224
+ # Inbetweening two poses
225
+ Rs = torch.stack((R[0], Rf))
226
+ rs = transforms.matrix_to_rotation_6d(Rs).numpy()
227
+ rs_move = self.noisy_interpolation(rs)
228
+ R_move = transforms.rotation_6d_to_matrix(torch.from_numpy(rs_move).float())
229
+ return R_move
230
+
231
+ def create_translation_move(self, T):
232
+ """Create translational move for the camera"""
233
+
234
+ # Create final camera position
235
+ tx = np.random.normal(scale=self.t_factor)
236
+ ty = np.random.normal(scale=self.t_factor)
237
+ tz = np.random.normal(scale=self.t_factor)
238
+ Ts = np.array([[0, 0, 0], [tx, ty, tz]])
239
+
240
+ T_move = self.noisy_interpolation(Ts)
241
+ T_move = torch.from_numpy(T_move).float()
242
+ return T_move + T
243
+
244
+ def noisy_interpolation(self, data):
245
+ """Non-linear interpolation with noise"""
246
+
247
+ dim = data.shape[-1]
248
+ output = np.zeros((self.l, dim))
249
+
250
+ linspace = np.stack([np.linspace(0, 1, self.l) for _ in range(dim)])
251
+ noise = (linspace[0, 1] - linspace[0, 0]) * self.interp_noise
252
+ space_noise = np.stack([np.random.uniform(-noise, noise, self.l - 2) for _ in range(dim)])
253
+
254
+ linspace[:, 1:-1] = linspace[:, 1:-1] + space_noise
255
+ for i in range(dim):
256
+ output[:, i] = np.interp(linspace[i], np.array([0., 1.,]), data[:, i])
257
+ return output
258
+
259
+ def apply(self, target, R, T):
260
+ target['R'] = R
261
+ target['T'] = T
262
+
263
+ # Recompute the translation
264
+ transl_cam = torch.matmul(R, target['transl'].unsqueeze(-1)).squeeze(-1)
265
+ transl_cam = transl_cam + T
266
+ if transl_cam[..., 2].min() < 0.5: # If the person is too close to the camera
267
+ transl_cam[..., 2] = transl_cam[..., 2] + (1.0 - transl_cam[..., 2].min())
268
+
269
+ # If the subject is away from the field of view, put the camera behind
270
+ fov = torch.div(transl_cam[..., :2], transl_cam[..., 2:]).abs()
271
+ if fov.max() > self.fov_tol:
272
+ t_max = transl_cam[fov.max(1)[0].max(0)[1].item()]
273
+ z_trg = t_max[:2].abs().max(0)[0] / self.fov_tol
274
+ pad = z_trg - t_max[2]
275
+ transl_cam[..., 2] = transl_cam[..., 2] + pad
276
+
277
+ target['transl_cam'] = transl_cam
278
+
279
+ # Transform world coordinate to camera coordinate
280
+ target['pose_root'] = target['pose'][:, 0].clone()
281
+ target['pose'][:, 0] = R @ target['pose'][:, 0] # pose
282
+ target['init_pose'][:, 0] = R[:1] @ target['init_pose'][:, 0] # init pose
283
+
284
+ # Compute angular velocity
285
+ cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2))
286
+ cam_angvel = cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel) # Normalize
287
+ target['cam_angvel'] = cam_angvel * 3e1 # assume 30-fps
288
+
289
+ if 'kp3d' in target:
290
+ target['kp3d'] = torch.matmul(R, target['kp3d'].transpose(1, 2)).transpose(1, 2) + target['transl_cam'].unsqueeze(1)
291
+
292
+ return target
lib/data/utils/normalizer.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ from ...utils.imutils import transform_keypoints
5
+
6
+ class Normalizer:
7
+ def __init__(self, cfg):
8
+ pass
9
+
10
+ def __call__(self, kp_2d, res, cam_intrinsics, patch_width=224, patch_height=224, bbox=None, mask=None):
11
+ if bbox is None:
12
+ bbox = compute_bbox_from_keypoints(kp_2d, do_augment=True, mask=mask)
13
+
14
+ out_kp_2d = self.bbox_normalization(kp_2d, bbox, res, patch_width, patch_height)
15
+ return out_kp_2d, bbox
16
+
17
+ def bbox_normalization(self, kp_2d, bbox, res, patch_width, patch_height):
18
+ to_torch = False
19
+ if isinstance(kp_2d, torch.Tensor):
20
+ to_torch = True
21
+ kp_2d = kp_2d.numpy()
22
+ bbox = bbox.numpy()
23
+
24
+ out_kp_2d = np.zeros_like(kp_2d)
25
+ for idx in range(len(out_kp_2d)):
26
+ out_kp_2d[idx] = transform_keypoints(kp_2d[idx], bbox[idx][:3], patch_width, patch_height)[0]
27
+ out_kp_2d[idx] = normalize_keypoints_to_patch(out_kp_2d[idx], patch_width)
28
+
29
+ if to_torch:
30
+ out_kp_2d = torch.from_numpy(out_kp_2d)
31
+ bbox = torch.from_numpy(bbox)
32
+
33
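+ # Encode the bbox location (normalized center and scale) to append to the flattened keypoints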
+ centers = normalize_keypoints_to_image(bbox[:, :2].unsqueeze(1), res).squeeze(1)
34
+ scale = bbox[:, 2:] * 200 / res.max()
35
+ location = torch.cat((centers, scale), dim=-1)
36
+
37
+ out_kp_2d = out_kp_2d.reshape(out_kp_2d.shape[0], -1)
38
+ out_kp_2d = torch.cat((out_kp_2d, location), dim=-1)
39
+ return out_kp_2d
40
+
41
+
42
+ def normalize_keypoints_to_patch(kp_2d, crop_size=224, inv=False):
43
+ # Normalize keypoints between -1, 1
44
+ if not inv:
45
+ ratio = 1.0 / crop_size
46
+ kp_2d = 2.0 * kp_2d * ratio - 1.0
47
+ else:
48
+ ratio = 1.0 / crop_size
49
+ kp_2d = (kp_2d + 1.0)/(2*ratio)
50
+
51
+ return kp_2d
52
+
53
+
54
+ def normalize_keypoints_to_image(x, res):
55
+ res = res.to(x.device)
56
+ scale = res.max(-1)[0].reshape(-1)
57
+ mean = torch.stack([res[..., 0] / scale, res[..., 1] / scale], dim=-1).to(x.device)
58
+ x = (2 * x / scale.reshape(*[1 for i in range(len(x.shape[1:]))]) - \
59
+ mean.reshape(*[1 for i in range(len(x.shape[1:-1]))], -1))
60
+ return x
61
+
62
+
63
+ def compute_bbox_from_keypoints(X, do_augment=False, mask=None):
64
+ def smooth_bbox(bb):
65
+ # Smooth bounding box detection
66
+ import scipy.signal as signal
67
+ smoothed = np.array([signal.medfilt(param, int(30 / 2)) for param in bb])
68
+ return smoothed
69
+
70
+ def do_augmentation(scale_factor=0.2, trans_factor=0.05):
71
+ _scaleFactor = np.random.uniform(1.0 - scale_factor, 1.2 + scale_factor)
72
+ _trans_x = np.random.uniform(-trans_factor, trans_factor)
73
+ _trans_y = np.random.uniform(-trans_factor, trans_factor)
74
+
75
+ return _scaleFactor, _trans_x, _trans_y
76
+
77
+ if do_augment:
78
+ scaleFactor, trans_x, trans_y = do_augmentation()
79
+ else:
80
+ scaleFactor, trans_x, trans_y = 1.2, 0.0, 0.0
81
+
82
+ if mask is None:
83
+ bbox = [X[:, :, 0].min(-1)[0], X[:, :, 1].min(-1)[0],
84
+ X[:, :, 0].max(-1)[0], X[:, :, 1].max(-1)[0]]
85
+ else:
86
+ bbox = []
87
+ for x, _mask in zip(X, mask):
88
+ if _mask.sum() > 10:
89
+ _mask[:] = False
90
+ _bbox = [x[~_mask, 0].min(-1)[0], x[~_mask, 1].min(-1)[0],
91
+ x[~_mask, 0].max(-1)[0], x[~_mask, 1].max(-1)[0]]
92
+ bbox.append(_bbox)
93
+ bbox = torch.tensor(bbox).T
94
+
95
+ cx, cy = [(bbox[2]+bbox[0])/2, (bbox[3]+bbox[1])/2]
96
+ bbox_w = bbox[2] - bbox[0]
97
+ bbox_h = bbox[3] - bbox[1]
98
+ bbox_size = torch.stack((bbox_w, bbox_h)).max(0)[0]
99
+ scale = bbox_size * scaleFactor
100
+ bbox = torch.stack((cx + trans_x * scale, cy + trans_y * scale, scale / 200))
101
+
102
+ if do_augment:
103
+ bbox = torch.from_numpy(smooth_bbox(bbox.numpy()))
104
+
105
+ return bbox.T
lib/data_utils/amass_utils.py ADDED
@@ -0,0 +1,107 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import os
6
+ import os.path as osp
7
+ from collections import defaultdict
8
+
9
+ import torch
10
+ import joblib
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+ from smplx import SMPL
14
+
15
+ from configs import constants as _C
16
+ from lib.utils.data_utils import map_dmpl_to_smpl, transform_global_coordinate
17
+
18
+
19
+ @torch.no_grad()
20
+ def process_amass():
21
+ target_fps = 30
22
+
23
+ _, seqs, _ = next(os.walk(_C.PATHS.AMASS_PTH))
24
+
25
+ zup2ydown = torch.Tensor(
26
+ [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
27
+ ).unsqueeze(0).float()
28
+
29
+ smpl_dict = {'male': SMPL(model_path=_C.BMODEL.FLDR, gender='male'),
30
+ 'female': SMPL(model_path=_C.BMODEL.FLDR, gender='female'),
31
+ 'neutral': SMPL(model_path=_C.BMODEL.FLDR)}
32
+ processed_data = defaultdict(list)
33
+
34
+ for seq in (seq_bar := tqdm(sorted(seqs), leave=True)):
35
+ seq_bar.set_description(f'Dataset: {seq}')
36
+ seq_fldr = osp.join(_C.PATHS.AMASS_PTH, seq)
37
+ _, subjs, _ = next(os.walk(seq_fldr))
38
+
39
+ for subj in (subj_bar := tqdm(sorted(subjs), leave=False)):
40
+ subj_bar.set_description(f'Subject: {subj}')
41
+ subj_fldr = osp.join(seq_fldr, subj)
42
+ acts = [x for x in os.listdir(subj_fldr) if x.endswith('.npz')]
43
+
44
+ for act in (act_bar := tqdm(sorted(acts), leave=False)):
45
+ act_bar.set_description(f'Action: {act}')
46
+
47
+ # Load data
48
+ fname = osp.join(subj_fldr, act)
49
+ if fname.endswith('shape.npz') or fname.endswith('stagei.npz'):
50
+ # Skip shape and stagei files
51
+ continue
52
+ data = dict(np.load(fname, allow_pickle=True))
53
+
54
+ # Resample data to target_fps
55
+ key = [k for k in data.keys() if 'mocap_frame' in k][0]
56
+ mocap_framerate = data[key]
57
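+ # Keep every retain_freq-th frame to downsample the mocap sequence to roughly 30 fps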
+ retain_freq = int(mocap_framerate / target_fps + 0.5)
58
+ num_frames = len(data['poses'][::retain_freq])
59
+
60
+ # Skip if the sequence is too short
61
+ if num_frames < 25: continue
62
+
63
+ # Get SMPL groundtruth from MoSh fitting
64
+ pose = map_dmpl_to_smpl(torch.from_numpy(data['poses'][::retain_freq]).float())
65
+ transl = torch.from_numpy(data['trans'][::retain_freq]).float()
66
+ betas = torch.from_numpy(
67
+ np.repeat(data['betas'][:10][np.newaxis], pose.shape[0], axis=0)).float()
68
+
69
+ # Convert Z-up coordinate to Y-down
70
+ pose, transl = transform_global_coordinate(pose, zup2ydown, transl)
71
+ pose = pose.reshape(-1, 72)
72
+
73
+ # Create SMPL mesh
74
+ gender = str(data['gender'])
75
+ if not gender in ['male', 'female', 'neutral']:
76
+ if 'female' in gender: gender = 'female'
77
+ elif 'neutral' in gender: gender = 'neutral'
78
+ elif 'male' in gender: gender = 'male'
79
+
80
+ output = smpl_dict[gender](body_pose=pose[:, 3:],
81
+ global_orient=pose[:, :3],
82
+ betas=betas,
83
+ transl=transl)
84
+ vertices = output.vertices
85
+
86
+ # Shift the sequence so the motion starts at zero height (feet on the ground)
87
+ init_height = vertices[0].max(0)[0][1]
88
+ transl[:, 1] = transl[:, 1] + init_height
89
+ vertices[:, :, 1] = vertices[:, :, 1] - init_height
90
+
91
+ # Append data
92
+ processed_data['pose'].append(pose.numpy())
93
+ processed_data['betas'].append(betas.numpy())
94
+ processed_data['transl'].append(transl.numpy())
95
+ processed_data['vid'].append(np.array([f'{seq}_{subj}_{act}'] * pose.shape[0]))
96
+
97
+ for key, val in processed_data.items():
98
+ processed_data[key] = np.concatenate(val)
99
+
100
+ joblib.dump(processed_data, _C.PATHS.AMASS_LABEL)
101
+ print('\nDone!')
102
+
103
+ if __name__ == '__main__':
104
+ out_path = '/'.join(_C.PATHS.AMASS_LABEL.split('/')[:-1])
105
+ os.makedirs(out_path, exist_ok=True)
106
+
107
+ process_amass()
lib/data_utils/emdb_eval_utils.py ADDED
@@ -0,0 +1,189 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import os
6
+ import os.path as osp
7
+ from glob import glob
8
+ from collections import defaultdict
9
+
10
+ import cv2
11
+ import torch
12
+ import pickle
13
+ import joblib
14
+ import argparse
15
+ import numpy as np
16
+ from loguru import logger
17
+ from progress.bar import Bar
18
+
19
+ from configs import constants as _C
20
+ from lib.models.smpl import SMPL
21
+ from lib.models.preproc.extractor import FeatureExtractor
22
+ from lib.models.preproc.backbone.utils import process_image
23
+ from lib.utils import transforms
24
+ from lib.utils.imutils import (
25
+ flip_kp, flip_bbox
26
+ )
27
+
28
+ dataset = defaultdict(list)
29
+ detection_results_dir = 'dataset/detection_results/EMDB'
30
+
31
+ def is_dset(emdb_pkl_file, dset):
32
+ target_dset = 'emdb' + dset
33
+ with open(emdb_pkl_file, "rb") as f:
34
+ data = pickle.load(f)
35
+ return data[target_dset]
36
+
37
+ @torch.no_grad()
38
+ def preprocess(dset, batch_size):
39
+
40
+ tt = lambda x: torch.from_numpy(x).float()
41
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
42
+ save_pth = osp.join(_C.PATHS.PARSED_DATA, f'emdb_{dset}_vit.pth') # Use ViT feature extractor
43
+ extractor = FeatureExtractor(device, flip_eval=True, max_batch_size=batch_size)
44
+
45
+ all_emdb_pkl_files = sorted(glob(os.path.join(_C.PATHS.EMDB_PTH, "*/*/*_data.pkl")))
46
+ emdb_sequence_roots = []
47
+ both = []
48
+ for emdb_pkl_file in all_emdb_pkl_files:
49
+ if is_dset(emdb_pkl_file, dset):
50
+ emdb_sequence_roots.append(os.path.dirname(emdb_pkl_file))
51
+
52
+ smpl = {
53
+ 'neutral': SMPL(model_path=_C.BMODEL.FLDR),
54
+ 'male': SMPL(model_path=_C.BMODEL.FLDR, gender='male'),
55
+ 'female': SMPL(model_path=_C.BMODEL.FLDR, gender='female'),
56
+ }
57
+
58
+ for sequence in emdb_sequence_roots:
59
+ subj, seq = sequence.split('/')[-2:]
60
+ annot_pth = glob(osp.join(sequence, '*_data.pkl'))[0]
61
+ annot = pickle.load(open(annot_pth, 'rb'))
62
+
63
+ # Get ground truth data
64
+ gender = annot['gender']
65
+ masks = annot['good_frames_mask']
66
+ poses_body = annot["smpl"]["poses_body"]
67
+ poses_root = annot["smpl"]["poses_root"]
68
+ betas = np.repeat(annot["smpl"]["betas"].reshape((1, -1)), repeats=annot["n_frames"], axis=0)
69
+ extrinsics = annot["camera"]["extrinsics"]
70
+ width, height = annot['camera']['width'], annot['camera']['height']
71
+ xyxys = annot['bboxes']['bboxes']
72
+
73
+ # Map to camera coordinates
74
+ poses_root_cam = transforms.matrix_to_axis_angle(tt(extrinsics[:, :3, :3]) @ transforms.axis_angle_to_matrix(tt(poses_root)))
75
+ poses = np.concatenate([poses_root_cam, poses_body], axis=-1)
76
+
77
+ pred_kp2d = np.load(osp.join(detection_results_dir, f'{subj}_{seq}.npy'))
78
+
79
+ # ======== Extract features ======== #
80
+ imname_list = sorted(glob(osp.join(sequence, 'images/*')))
81
+ bboxes, frame_ids, patch_list, features, flipped_features = [], [], [], [], []
82
+ bar = Bar(f'Load images', fill='#', max=len(imname_list))
83
+ for idx, (imname, xyxy, mask) in enumerate(zip(imname_list, xyxys, masks)):
84
+ if not mask: continue
85
+
86
+ # ========= Load image ========= #
87
+ img_rgb = cv2.cvtColor(cv2.imread(imname), cv2.COLOR_BGR2RGB)
88
+
89
+ # ========= Load bbox ========= #
90
+ x1, y1, x2, y2 = xyxy
91
+ bbox = np.array([(x1 + x2)/2., (y1 + y2)/2., max(x2 - x1, y2 - y1) / 1.1])
92
+
93
+ # ========= Process image ========= #
94
+ norm_img, crop_img = process_image(img_rgb, bbox[:2], bbox[2] / 200, 256, 256)
95
+
96
+ patch_list.append(torch.from_numpy(norm_img).unsqueeze(0).float())
97
+ bboxes.append(bbox)
98
+ frame_ids.append(idx)
99
+ bar.next()
100
+
101
+ patch_list = torch.split(torch.cat(patch_list), batch_size)
102
+ bboxes = torch.from_numpy(np.stack(bboxes)).float()
103
+ for i, patch in enumerate(patch_list):
104
+ bbox = bboxes[i*batch_size:min((i+1)*batch_size, len(frame_ids))].float().cuda()
105
+ bbox_center = bbox[:, :2]
106
+ bbox_scale = bbox[:, 2] / 200
107
+
108
+ feature = extractor.model(patch.cuda(), encode=True)
109
+ features.append(feature.cpu())
110
+
111
+ flipped_feature = extractor.model(torch.flip(patch, (3, )).cuda(), encode=True)
112
+ flipped_features.append(flipped_feature.cpu())
113
+
114
+ if i == 0:
115
+ init_patch = patch[[0]].clone()
116
+
117
+ features = torch.cat(features)
118
+ flipped_features = torch.cat(flipped_features)
119
+ res_h, res_w = img_rgb.shape[:2]
120
+
121
+ # ======== Append data ======== #
122
+ dataset['gender'].append(gender)
123
+ dataset['bbox'].append(bboxes)
124
+ dataset['res'].append(torch.tensor([[width, height]]).repeat(len(frame_ids), 1).float())
125
+ dataset['vid'].append(f'{subj}_{seq}')
126
+ dataset['pose'].append(tt(poses)[frame_ids])
127
+ dataset['betas'].append(tt(betas)[frame_ids])
128
+ dataset['kp2d'].append(tt(pred_kp2d)[frame_ids])
129
+ dataset['frame_id'].append(torch.from_numpy(np.array(frame_ids)))
130
+ dataset['cam_poses'].append(tt(extrinsics)[frame_ids])
131
+ dataset['features'].append(features)
132
+ dataset['flipped_features'].append(flipped_features)
133
+
134
+ # Flipped data
135
+ dataset['flipped_bbox'].append(
136
+ torch.from_numpy(flip_bbox(dataset['bbox'][-1].clone().numpy(), res_w, res_h)).float()
137
+ )
138
+ dataset['flipped_kp2d'].append(
139
+ torch.from_numpy(flip_kp(dataset['kp2d'][-1].clone().numpy(), res_w)).float()
140
+ )
141
+ # ======== Append data ======== #
142
+
143
+ # Pad 1 frame
144
+ for key, val in dataset.items():
145
+ if isinstance(val[-1], torch.Tensor):
146
+ dataset[key][-1] = torch.cat((val[-1][:1].clone(), val[-1][:]), dim=0)
147
+
148
+ # Initial predictions
149
+ bbox = bboxes[:1].clone().cuda()
150
+ bbox_center = bbox[:, :2].clone()
151
+ bbox_scale = bbox[:, 2].clone() / 200
152
+ kwargs = {'img_w': torch.tensor(res_w).repeat(1).float().cuda(),
153
+ 'img_h': torch.tensor(res_h).repeat(1).float().cuda(),
154
+ 'bbox_center': bbox_center, 'bbox_scale': bbox_scale}
155
+
156
+ pred_global_orient, pred_pose, pred_shape, _ = extractor.model(init_patch.cuda(), **kwargs)
157
+ pred_output = smpl['neutral'].get_output(global_orient=pred_global_orient.cpu(),
158
+ body_pose=pred_pose.cpu(),
159
+ betas=pred_shape.cpu(),
160
+ pose2rot=False)
161
+ init_kp3d = pred_output.joints
162
+ init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1))
163
+
164
+ dataset['init_kp3d'].append(init_kp3d)
165
+ dataset['init_pose'].append(init_pose.cpu())
166
+
167
+ # Flipped initial predictions
168
+ bbox_center[:, 0] = res_w - bbox_center[:, 0]
169
+ pred_global_orient, pred_pose, pred_shape, _ = extractor.model(torch.flip(init_patch, (3, )).cuda(), **kwargs)
170
+ pred_output = smpl['neutral'].get_output(global_orient=pred_global_orient.cpu(),
171
+ body_pose=pred_pose.cpu(),
172
+ betas=pred_shape.cpu(),
173
+ pose2rot=False)
174
+ init_kp3d = pred_output.joints
175
+ init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1))
176
+
177
+ dataset['flipped_init_kp3d'].append(init_kp3d)
178
+ dataset['flipped_init_pose'].append(init_pose.cpu())
179
+
180
+ joblib.dump(dataset, save_pth)
181
+ logger.info(f'==> Done !')
182
+
183
+ if __name__ == '__main__':
184
+ parser = argparse.ArgumentParser()
185
+ parser.add_argument('-s', '--split', type=str, choices=['1', '2'], help='Data split')
186
+ parser.add_argument('-b', '--batch_size', type=int, default=128, help='Data split')
187
+ args = parser.parse_args()
188
+
189
+ preprocess(args.split, args.batch_size)
lib/data_utils/rich_eval_utils.py ADDED
@@ -0,0 +1,69 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import os
6
+ import os.path as osp
7
+ from glob import glob
8
+ from collections import defaultdict
9
+
10
+ import cv2
11
+ import torch
12
+ import pickle
13
+ import joblib
14
+ import argparse
15
+ import numpy as np
16
+ from loguru import logger
17
+ from progress.bar import Bar
18
+
19
+ from configs import constants as _C
20
+ from lib.models.smpl import SMPL
21
+ from lib.models.preproc.extractor import FeatureExtractor
22
+ from lib.models.preproc.backbone.utils import process_image
23
+ from lib.utils import transforms
24
+ from lib.utils.imutils import (
25
+ flip_kp, flip_bbox
26
+ )
27
+
28
+ dataset = defaultdict(list)
29
+ detection_results_dir = 'dataset/detection_results/RICH'
30
+
31
+ def extract_cam_param_xml(xml_path='', dtype=torch.float32):
32
+
33
+ import xml.etree.ElementTree as ET
34
+ tree = ET.parse(xml_path)
35
+
36
+ extrinsics_mat = [float(s) for s in tree.find('./CameraMatrix/data').text.split()]
37
+ intrinsics_mat = [float(s) for s in tree.find('./Intrinsics/data').text.split()]
38
+ # distortion_vec = [float(s) for s in tree.find('./Distortion/data').text.split()]
39
+
40
+ focal_length_x = intrinsics_mat[0]
41
+ focal_length_y = intrinsics_mat[4]
42
+ center = torch.tensor([[intrinsics_mat[2], intrinsics_mat[5]]], dtype=dtype)
43
+
44
+ rotation = torch.tensor([[extrinsics_mat[0], extrinsics_mat[1], extrinsics_mat[2]],
45
+ [extrinsics_mat[4], extrinsics_mat[5], extrinsics_mat[6]],
46
+ [extrinsics_mat[8], extrinsics_mat[9], extrinsics_mat[10]]], dtype=dtype)
47
+
48
+ translation = torch.tensor([[extrinsics_mat[3], extrinsics_mat[7], extrinsics_mat[11]]], dtype=dtype)
49
+
50
+ # t = -Rc --> c = -R^Tt
51
+ cam_center = [ -extrinsics_mat[0]*extrinsics_mat[3] - extrinsics_mat[4]*extrinsics_mat[7] - extrinsics_mat[8]*extrinsics_mat[11],
52
+ -extrinsics_mat[1]*extrinsics_mat[3] - extrinsics_mat[5]*extrinsics_mat[7] - extrinsics_mat[9]*extrinsics_mat[11],
53
+ -extrinsics_mat[2]*extrinsics_mat[3] - extrinsics_mat[6]*extrinsics_mat[7] - extrinsics_mat[10]*extrinsics_mat[11]]
54
+
55
+ cam_center = torch.tensor([cam_center], dtype=dtype)
56
+
57
+ return focal_length_x, focal_length_y, center, rotation, translation, cam_center
58
+
59
+ @torch.no_grad()
60
+ def preprocess(dset, batch_size):
61
+ import pdb; pdb.set_trace()
62
+
63
+ if __name__ == '__main__':
64
+ parser = argparse.ArgumentParser()
65
+ parser.add_argument('-s', '--split', type=str, choices=['1', '2'], help='Data split')
66
+ parser.add_argument('-b', '--batch_size', type=int, default=128, help='Data split')
67
+ args = parser.parse_args()
68
+
69
+ preprocess(args.split, args.batch_size)
lib/data_utils/threedpw_eval_utils.py ADDED
@@ -0,0 +1,185 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import os.path as osp
6
+ from glob import glob
7
+ from collections import defaultdict
8
+
9
+ import cv2
10
+ import torch
11
+ import pickle
12
+ import joblib
13
+ import argparse
14
+ import numpy as np
15
+ from loguru import logger
16
+ from progress.bar import Bar
17
+
18
+ from configs import constants as _C
19
+ from lib.models.smpl import SMPL
20
+ from lib.models.preproc.extractor import FeatureExtractor
21
+ from lib.models.preproc.backbone.utils import process_image
22
+ from lib.utils import transforms
23
+ from lib.utils.imutils import (
24
+ flip_kp, flip_bbox
25
+ )
26
+
27
+
28
+ dataset = defaultdict(list)
29
+ detection_results_dir = 'dataset/detection_results/3DPW'
30
+ tcmr_annot_pth = 'dataset/parsed_data/TCMR_preproc/3dpw_dset_db.pt'
31
+
32
+ @torch.no_grad()
33
+ def preprocess(dset, batch_size):
34
+
35
+ if dset == 'val': _dset = 'validation'
36
+ else: _dset = dset
37
+
38
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
39
+ save_pth = osp.join(_C.PATHS.PARSED_DATA, f'3pdw_{dset}_vit.pth') # Use ViT feature extractor
40
+ extractor = FeatureExtractor(device, flip_eval=True, max_batch_size=batch_size)
41
+
42
+ tcmr_data = joblib.load(tcmr_annot_pth.replace('dset', dset))
43
+ smpl_neutral = SMPL(model_path=_C.BMODEL.FLDR)
44
+
45
+ annot_file_list, idxs = np.unique(tcmr_data['vid_name'], return_index=True)
46
+ idxs = idxs.tolist()
47
+ annot_file_list = [annot_file_list[idxs.index(idx)] for idx in sorted(idxs)]
48
+ annot_file_list = [osp.join(_C.PATHS.THREEDPW_PTH, 'sequenceFiles', _dset, annot_file[:-2] + '.pkl') for annot_file in annot_file_list]
49
+ annot_file_list = list(dict.fromkeys(annot_file_list))
50
+
51
+ for annot_file in annot_file_list:
52
+ seq = annot_file.split('/')[-1].split('.')[0]
53
+
54
+ data = pickle.load(open(annot_file, 'rb'), encoding='latin1')
55
+
56
+ num_people = len(data['poses'])
57
+ num_frames = len(data['img_frame_ids'])
58
+ assert (data['poses2d'][0].shape[0] == num_frames)
59
+
60
+ K = torch.from_numpy(data['cam_intrinsics']).unsqueeze(0).float()
61
+
62
+ for p_id in range(num_people):
63
+
64
+ logger.info(f'==> {seq} {p_id}')
65
+ gender = {'m': 'male', 'f': 'female'}[data['genders'][p_id]]
66
+
67
+ # ======== Add TCMR data ======== #
68
+ vid_name = f'{seq}_{p_id}'
69
+ tcmr_ids = [i for i, v in enumerate(tcmr_data['vid_name']) if vid_name in v]
70
+ frame_ids = tcmr_data['frame_id'][tcmr_ids]
71
+
72
+ # World-coordinate SMPL pose from the annotation; only used to size the betas below,
+ # since it is overwritten by the TCMR camera-coordinate pose two lines down.
+ pose = torch.from_numpy(data['poses'][p_id]).float()[frame_ids]
73
+ shape = torch.from_numpy(data['betas'][p_id][:10]).float().repeat(pose.size(0), 1)
74
+ pose = torch.from_numpy(tcmr_data['pose'][tcmr_ids]).float() # Camera coordinate
75
+ cam_poses = torch.from_numpy(data['cam_poses'][frame_ids]).float()
76
+
77
+ # ======== Get detection results ======== #
78
+ fname = f'{seq}_{p_id}.npy'
79
+ pred_kp2d = torch.from_numpy(
80
+ np.load(osp.join(detection_results_dir, fname))
81
+ ).float()[frame_ids]
82
+ # ======== Get detection results ======== #
83
+
84
+ img_paths = sorted(glob(osp.join(_C.PATHS.THREEDPW_PTH, 'imageFiles', seq, '*.jpg')))
85
+ img_paths = [img_path for i, img_path in enumerate(img_paths) if i in frame_ids]
86
+ img = cv2.imread(img_paths[0]); res_h, res_w = img.shape[:2]
87
+ vid_idxs = fname.split('.')[0]
88
+
89
+ # ======== Append data ======== #
90
+ dataset['gender'].append(gender)
91
+ dataset['vid'].append(vid_idxs)
92
+ dataset['pose'].append(pose)
93
+ dataset['betas'].append(shape)
94
+ dataset['cam_poses'].append(cam_poses)
95
+ dataset['frame_id'].append(torch.from_numpy(frame_ids))
96
+ dataset['res'].append(torch.tensor([[res_w, res_h]]).repeat(len(frame_ids), 1).float())
97
+ dataset['bbox'].append(torch.from_numpy(tcmr_data['bbox'][tcmr_ids].copy()).float())
98
+ dataset['kp2d'].append(pred_kp2d)
99
+
100
+ # Flipped data
101
+ dataset['flipped_bbox'].append(
102
+ torch.from_numpy(flip_bbox(dataset['bbox'][-1].clone().numpy(), res_w, res_h)).float()
103
+ )
104
+ dataset['flipped_kp2d'].append(
105
+ torch.from_numpy(flip_kp(dataset['kp2d'][-1].clone().numpy(), res_w)).float()
106
+ )
107
+ # ======== Append data ======== #
108
+
109
+ # ======== Extract features ======== #
110
+ patch_list = []
111
+ bboxes = dataset['bbox'][-1].clone().numpy()
112
+ bar = Bar(f'Load images', fill='#', max=len(img_paths))
113
+
114
+ for img_path, bbox in zip(img_paths, bboxes):
115
+ img_rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
116
+ norm_img, crop_img = process_image(img_rgb, bbox[:2], bbox[2] / 200, 256, 256)
117
+ patch_list.append(torch.from_numpy(norm_img).unsqueeze(0).float())
118
+ bar.next()
119
+
120
+ patch_list = torch.split(torch.cat(patch_list), batch_size)
121
+ features, flipped_features = [], []
122
+ for i, patch in enumerate(patch_list):
123
+ feature = extractor.model(patch.cuda(), encode=True)
124
+ features.append(feature.cpu())
125
+
126
+ flipped_feature = extractor.model(torch.flip(patch, (3, )).cuda(), encode=True)
127
+ flipped_features.append(flipped_feature.cpu())
128
+
129
+ if i == 0:
130
+ init_patch = patch[[0]].clone()
131
+
132
+ features = torch.cat(features)
133
+ flipped_features = torch.cat(flipped_features)
134
+ dataset['features'].append(features)
135
+ dataset['flipped_features'].append(flipped_features)
136
+ # ======== Extract features ======== #
137
+
138
+ # Pad 1 frame
139
+ for key, val in dataset.items():
140
+ if isinstance(val[-1], torch.Tensor):
141
+ dataset[key][-1] = torch.cat((val[-1][:1].clone(), val[-1][:]), dim=0)
142
+
143
+ # Initial predictions
144
+ bbox = torch.from_numpy(bboxes[:1].copy()).float().cuda()
145
+ bbox_center = bbox[:, :2].clone()
146
+ bbox_scale = bbox[:, 2].clone() / 200
147
+ kwargs = {'img_w': torch.tensor(res_w).repeat(1).float().cuda(),
148
+ 'img_h': torch.tensor(res_h).repeat(1).float().cuda(),
149
+ 'bbox_center': bbox_center, 'bbox_scale': bbox_scale}
150
+
151
+ pred_global_orient, pred_pose, pred_shape, _ = extractor.model(init_patch.cuda(), **kwargs)
152
+ pred_output = smpl_neutral.get_output(global_orient=pred_global_orient.cpu(),
153
+ body_pose=pred_pose.cpu(),
154
+ betas=pred_shape.cpu(),
155
+ pose2rot=False)
156
+ init_kp3d = pred_output.joints
157
+ init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1))
158
+
159
+ dataset['init_kp3d'].append(init_kp3d)
160
+ dataset['init_pose'].append(init_pose.cpu())
161
+
162
+ # Flipped initial predictions
163
+ bbox_center[:, 0] = res_w - bbox_center[:, 0]
164
+ pred_global_orient, pred_pose, pred_shape, _ = extractor.model(torch.flip(init_patch, (3, )).cuda(), **kwargs)
165
+ pred_output = smpl_neutral.get_output(global_orient=pred_global_orient.cpu(),
166
+ body_pose=pred_pose.cpu(),
167
+ betas=pred_shape.cpu(),
168
+ pose2rot=False)
169
+ init_kp3d = pred_output.joints
170
+ init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1))
171
+
172
+ dataset['flipped_init_kp3d'].append(init_kp3d)
173
+ dataset['flipped_init_pose'].append(init_pose.cpu())
174
+
175
+ joblib.dump(dataset, save_pth)
176
+ logger.info(f'\n ==> Done !')
177
+
178
+
179
+ if __name__ == '__main__':
180
+ parser = argparse.ArgumentParser()
181
+ parser.add_argument('-s', '--split', type=str, choices=['val', 'test'], help='Data split')
182
+ parser.add_argument('-b', '--batch_size', type=int, default=128, help='Batch size')
183
+ args = parser.parse_args()
184
+
185
+ preprocess(args.split, args.batch_size)
lib/data_utils/threedpw_train_utils.py ADDED
@@ -0,0 +1,146 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import os.path as osp
6
+ from glob import glob
7
+ from collections import defaultdict
8
+
9
+ import cv2
10
+ import torch
11
+ import pickle
12
+ import joblib
13
+ import argparse
14
+ import numpy as np
15
+ from loguru import logger
16
+ from progress.bar import Bar
17
+
18
+ from configs import constants as _C
19
+ from lib.models.smpl import SMPL
20
+ from lib.models.preproc.extractor import FeatureExtractor
21
+ from lib.models.preproc.backbone.utils import process_image
22
+
23
+ dataset = defaultdict(list)
24
+ detection_results_dir = 'dataset/detection_results/3DPW'
25
+ tcmr_annot_pth = 'dataset/parsed_data/TCMR_preproc/3dpw_train_db.pt'
26
+
27
+
28
+ @torch.no_grad()
29
+ def preprocess(batch_size):
30
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
31
+ save_pth = osp.join(_C.PATHS.PARSED_DATA, f'3pdw_train_vit.pth') # Use ViT feature extractor
32
+ extractor = FeatureExtractor(device, flip_eval=True, max_batch_size=batch_size)
33
+
34
+ tcmr_data = joblib.load(tcmr_annot_pth)
35
+
36
+ annot_file_list, idxs = np.unique(tcmr_data['vid_name'], return_index=True)
37
+ idxs = idxs.tolist()
38
+ annot_file_list = [annot_file_list[idxs.index(idx)] for idx in sorted(idxs)]
39
+ annot_file_list = [osp.join(_C.PATHS.THREEDPW_PTH, 'sequenceFiles', 'train', annot_file[:-2] + '.pkl') for annot_file in annot_file_list]
40
+ annot_file_list = list(dict.fromkeys(annot_file_list))
41
+
42
+ vid_idx = 0
43
+ for annot_file in annot_file_list:
44
+ seq = annot_file.split('/')[-1].split('.')[0]
45
+
46
+ data = pickle.load(open(annot_file, 'rb'), encoding='latin1')
47
+
48
+ num_people = len(data['poses'])
49
+ num_frames = len(data['img_frame_ids'])
50
+ assert (data['poses2d'][0].shape[0] == num_frames)
51
+
52
+ K = torch.from_numpy(data['cam_intrinsics']).unsqueeze(0).float()
53
+
54
+ for p_id in range(num_people):
55
+
56
+ logger.info(f'==> {seq} {p_id}')
57
+ gender = {'m': 'male', 'f': 'female'}[data['genders'][p_id]]
58
+ smpl_gender = SMPL(model_path=_C.BMODEL.FLDR, gender=gender)
59
+
60
+ # ======== Add TCMR data ======== #
61
+ vid_name = f'{seq}_{p_id}'
62
+ tcmr_ids = [i for i, v in enumerate(tcmr_data['vid_name']) if vid_name in v]
63
+ frame_ids = tcmr_data['frame_id'][tcmr_ids]
64
+
65
+ pose = torch.from_numpy(data['poses'][p_id]).float()[frame_ids]
66
+ shape = torch.from_numpy(data['betas'][p_id][:10]).float().repeat(pose.size(0), 1)
67
+ trans = torch.from_numpy(data['trans'][p_id]).float()[frame_ids]
68
+ cam_poses = torch.from_numpy(data['cam_poses'][frame_ids]).float()
69
+
70
+ # ======== Align the mesh params ======== #
71
+ Rc = cam_poses[:, :3, :3]
72
+ Tc = cam_poses[:, :3, 3]
73
+ org_output = smpl_gender.get_output(betas=shape, body_pose=pose[:,3:], global_orient=pose[:,:3], transl=trans)
74
+ org_v0 = (org_output.vertices + org_output.offset.unsqueeze(1)).mean(1)
75
+ pose = torch.from_numpy(tcmr_data['pose'][tcmr_ids]).float()
76
+
77
+ output = smpl_gender.get_output(betas=shape, body_pose=pose[:,3:], global_orient=pose[:,:3])
78
+ v0 = (output.vertices + output.offset.unsqueeze(1)).mean(1)
79
+ trans = (Rc @ org_v0.reshape(-1, 3, 1)).reshape(-1, 3) + Tc - v0
80
+ j3d = output.joints + (output.offset + trans).unsqueeze(1)
81
+ j2d = torch.div(j3d, j3d[..., 2:])
82
+ kp2d = torch.matmul(K, j2d.transpose(-1, -2)).transpose(-1, -2)[..., :2]
83
+ # ======== Align the mesh params ======== #
84
+
85
+ # ======== Get detection results ======== #
86
+ fname = f'{seq}_{p_id}.npy'
87
+ pred_kp2d = torch.from_numpy(
88
+ np.load(osp.join(detection_results_dir, fname))
89
+ ).float()[frame_ids]
90
+ # ======== Get detection results ======== #
91
+
92
+ img_paths = sorted(glob(osp.join(_C.PATHS.THREEDPW_PTH, 'imageFiles', seq, '*.jpg')))
93
+ img_paths = [img_path for i, img_path in enumerate(img_paths) if i in frame_ids]
94
+ img = cv2.imread(img_paths[0]); res_h, res_w = img.shape[:2]
95
+ vid_idxs = torch.from_numpy(np.array([vid_idx] * len(img_paths)).astype(int))
96
+ vid_idx += 1
97
+
98
+ # ======== Append data ======== #
99
+ dataset['bbox'].append(torch.from_numpy(tcmr_data['bbox'][tcmr_ids].copy()).float())
100
+ dataset['res'].append(torch.tensor([[res_w, res_h]]).repeat(len(frame_ids), 1).float())
101
+ dataset['vid'].append(vid_idxs)
102
+ dataset['pose'].append(pose)
103
+ dataset['betas'].append(shape)
104
+ dataset['transl'].append(trans)
105
+ dataset['kp2d'].append(pred_kp2d)
106
+ dataset['joints3D'].append(j3d)
107
+ dataset['joints2D'].append(kp2d)
108
+ dataset['frame_id'].append(torch.from_numpy(frame_ids))
109
+ dataset['cam_poses'].append(cam_poses)
110
+ dataset['gender'].append(torch.tensor([['male','female'].index(gender)]).repeat(len(frame_ids)))
111
+ # ======== Append data ======== #
112
+
113
+ # ======== Extract features ======== #
114
+ patch_list = []
115
+ bboxes = dataset['bbox'][-1].clone().numpy()
116
+ bar = Bar(f'Load images', fill='#', max=len(img_paths))
117
+
118
+ for img_path, bbox in zip(img_paths, bboxes):
119
+ img_rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
120
+ norm_img, crop_img = process_image(img_rgb, bbox[:2], bbox[2] / 200, 256, 256)
121
+ patch_list.append(torch.from_numpy(norm_img).unsqueeze(0).float())
122
+ bar.next()
123
+
124
+ patch_list = torch.split(torch.cat(patch_list), batch_size)
125
+ features = []
126
+ for i, patch in enumerate(patch_list):
127
+ pred = extractor.model(patch.cuda(), encode=True)
128
+ features.append(pred.cpu())
129
+
130
+ features = torch.cat(features)
131
+ dataset['features'].append(features)
132
+ # ======== Extract features ======== #
133
+
134
+ for key in dataset.keys():
135
+ dataset[key] = torch.cat(dataset[key])
136
+
137
+ joblib.dump(dataset, save_pth)
138
+ logger.info(f'\n ==> Done !')
139
+
140
+
141
+ if __name__ == '__main__':
142
+ parser = argparse.ArgumentParser()
143
+ parser.add_argument('-b', '--batch_size', type=int, default=128, help='Batch size')
144
+ args = parser.parse_args()
145
+
146
+ preprocess(args.batch_size)
lib/eval/eval_utils.py ADDED
@@ -0,0 +1,482 @@
1
+ # Some functions are borrowed from https://github.com/akanazawa/human_dynamics/blob/master/src/evaluation/eval_util.py
2
+ # Adhere to their licence to use these functions
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import numpy as np
7
+ from matplotlib import pyplot as plt
8
+
9
+
10
+ def compute_accel(joints):
11
+ """
12
+ Computes acceleration of 3D joints.
13
+ Args:
14
+ joints (Nx25x3).
15
+ Returns:
16
+ Accelerations (N-2).
17
+ """
18
+ velocities = joints[1:] - joints[:-1]
19
+ acceleration = velocities[1:] - velocities[:-1]
20
+ acceleration_normed = np.linalg.norm(acceleration, axis=2)
21
+ return np.mean(acceleration_normed, axis=1)
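+ # Usage note (sketch): for joints of shape (N, J, 3) the result has shape (N-2,)
+ # and is expressed per frame^2; multiply by fps**2 to report it per second^2.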
22
+
23
+
24
+ def compute_error_accel(joints_gt, joints_pred, vis=None):
25
+ """
26
+ Computes acceleration error:
27
+ 1/(n-2) \sum_{i=1}^{n-1} X_{i-1} - 2X_i + X_{i+1}
28
+ Note that for each frame that is not visible, three entries in the
29
+ acceleration error should be zero'd out.
30
+ Args:
31
+ joints_gt (Nx14x3).
32
+ joints_pred (Nx14x3).
33
+ vis (N).
34
+ Returns:
35
+ error_accel (N-2).
36
+ """
37
+ # (N-2)x14x3
38
+ accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:]
39
+ accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:]
40
+
41
+ normed = np.linalg.norm(accel_pred - accel_gt, axis=2)
42
+
43
+ if vis is None:
44
+ new_vis = np.ones(len(normed), dtype=bool)
45
+ else:
46
+ invis = np.logical_not(vis)
47
+ invis1 = np.roll(invis, -1)
48
+ invis2 = np.roll(invis, -2)
49
+ new_invis = np.logical_or(invis, np.logical_or(invis1, invis2))[:-2]
50
+ new_vis = np.logical_not(new_invis)
51
+
52
+ return np.mean(normed[new_vis], axis=1)
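+ # Example (sketch; gt_j3d / pred_j3d are illustrative names for (N, 14, 3) arrays):
+ # err = compute_error_accel(gt_j3d, pred_j3d) # (N-2,), per frame^2
+ # err_per_s2 = err * (30 ** 2) # convert to per second^2 at 30 fps, as in the eval scripts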
53
+
54
+
55
+ def compute_error_verts(pred_verts, target_verts=None, target_theta=None):
56
+ """
57
+ Computes mean per-vertex error over the 6890 SMPL surface vertices.
59
+ Args:
60
+ pred_verts (Nx6890x3).
61
+ target_verts (Nx6890x3), or target_theta (SMPL parameters) used to build target_verts when it is None.
61
+ Returns:
62
+ error_verts (N).
63
+ """
64
+
65
+ if target_verts is None:
66
+ from lib.models.smpl import SMPL_MODEL_DIR
67
+ from lib.models.smpl import SMPL
68
+ device = 'cpu'
69
+ smpl = SMPL(
70
+ SMPL_MODEL_DIR,
71
+ batch_size=1, # target_theta.shape[0],
72
+ ).to(device)
73
+
74
+ betas = torch.from_numpy(target_theta[:,75:]).to(device)
75
+ pose = torch.from_numpy(target_theta[:,3:75]).to(device)
76
+
77
+ target_verts = []
78
+ b_ = torch.split(betas, 5000)
79
+ p_ = torch.split(pose, 5000)
80
+
81
+ for b,p in zip(b_,p_):
82
+ output = smpl(betas=b, body_pose=p[:, 3:], global_orient=p[:, :3], pose2rot=True)
83
+ target_verts.append(output.vertices.detach().cpu().numpy())
84
+
85
+ target_verts = np.concatenate(target_verts, axis=0)
86
+
87
+ assert len(pred_verts) == len(target_verts)
88
+ error_per_vert = np.sqrt(np.sum((target_verts - pred_verts) ** 2, axis=2))
89
+ return np.mean(error_per_vert, axis=1)
90
+
91
+
92
+ def compute_similarity_transform(S1, S2):
93
+ '''
94
+ Computes a similarity transform (sR, t) that takes
95
+ a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
96
+ where R is a 3x3 rotation matrix, t 3x1 translation, s scale.
97
+ i.e. solves the orthogonal Procrustes problem.
98
+ '''
99
+ transposed = False
100
+ if S1.shape[0] != 3 and S1.shape[0] != 2:
101
+ S1 = S1.T
102
+ S2 = S2.T
103
+ transposed = True
104
+ assert(S2.shape[1] == S1.shape[1])
105
+
106
+ # 1. Remove mean.
107
+ mu1 = S1.mean(axis=1, keepdims=True)
108
+ mu2 = S2.mean(axis=1, keepdims=True)
109
+ X1 = S1 - mu1
110
+ X2 = S2 - mu2
111
+
112
+ # 2. Compute variance of X1 used for scale.
113
+ var1 = np.sum(X1**2)
114
+
115
+ # 3. The outer product of X1 and X2.
116
+ K = X1.dot(X2.T)
117
+
118
+ # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
119
+ # singular vectors of K.
120
+ U, s, Vh = np.linalg.svd(K)
121
+ V = Vh.T
122
+ # Construct Z that fixes the orientation of R to get det(R)=1.
123
+ Z = np.eye(U.shape[0])
124
+ Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
125
+ # Construct R.
126
+ R = V.dot(Z.dot(U.T))
127
+
128
+ # 5. Recover scale.
129
+ scale = np.trace(R.dot(K)) / var1
130
+
131
+ # 6. Recover translation.
132
+ t = mu2 - scale*(R.dot(mu1))
133
+
134
+ # 7. Error:
135
+ S1_hat = scale*R.dot(S1) + t
136
+
137
+ if transposed:
138
+ S1_hat = S1_hat.T
139
+
140
+ return S1_hat
141
+
142
+
143
+ def compute_similarity_transform_torch(S1, S2):
144
+ '''
145
+ Computes a similarity transform (sR, t) that takes
146
+ a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
147
+ where R is a 3x3 rotation matrix, t 3x1 translation, s scale.
148
+ i.e. solves the orthogonal Procrustes problem.
149
+ '''
150
+ transposed = False
151
+ if S1.shape[0] != 3 and S1.shape[0] != 2:
152
+ S1 = S1.T
153
+ S2 = S2.T
154
+ transposed = True
155
+ assert (S2.shape[1] == S1.shape[1])
156
+
157
+ # 1. Remove mean.
158
+ mu1 = S1.mean(axis=1, keepdims=True)
159
+ mu2 = S2.mean(axis=1, keepdims=True)
160
+ X1 = S1 - mu1
161
+ X2 = S2 - mu2
162
+
163
+ # print('X1', X1.shape)
164
+
165
+ # 2. Compute variance of X1 used for scale.
166
+ var1 = torch.sum(X1 ** 2)
167
+
168
+ # print('var', var1.shape)
169
+
170
+ # 3. The outer product of X1 and X2.
171
+ K = X1.mm(X2.T)
172
+
173
+ # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
174
+ # singular vectors of K.
175
+ U, s, V = torch.svd(K)
176
+ # V = Vh.T
177
+ # Construct Z that fixes the orientation of R to get det(R)=1.
178
+ Z = torch.eye(U.shape[0], device=S1.device)
179
+ Z[-1, -1] *= torch.sign(torch.det(U @ V.T))
180
+ # Construct R.
181
+ R = V.mm(Z.mm(U.T))
182
+
183
+ # print('R', X1.shape)
184
+
185
+ # 5. Recover scale.
186
+ scale = torch.trace(R.mm(K)) / var1
187
+ # print(R.shape, mu1.shape)
188
+ # 6. Recover translation.
189
+ t = mu2 - scale * (R.mm(mu1))
190
+ # print(t.shape)
191
+
192
+ # 7. Error:
193
+ S1_hat = scale * R.mm(S1) + t
194
+
195
+ if transposed:
196
+ S1_hat = S1_hat.T
197
+
198
+ return S1_hat
199
+
200
+
201
+ def batch_compute_similarity_transform_torch(S1, S2):
202
+ '''
203
+ Computes a similarity transform (sR, t) that takes
204
+ a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
205
+ where R is a 3x3 rotation matrix, t 3x1 translation, s scale.
206
+ i.e. solves the orthogonal Procrustes problem.
207
+ '''
208
+ transposed = False
209
+ if S1.shape[0] != 3 and S1.shape[0] != 2:
210
+ S1 = S1.permute(0,2,1)
211
+ S2 = S2.permute(0,2,1)
212
+ transposed = True
213
+ assert(S2.shape[1] == S1.shape[1])
214
+
215
+ # 1. Remove mean.
216
+ mu1 = S1.mean(axis=-1, keepdims=True)
217
+ mu2 = S2.mean(axis=-1, keepdims=True)
218
+
219
+ X1 = S1 - mu1
220
+ X2 = S2 - mu2
221
+
222
+ # 2. Compute variance of X1 used for scale.
223
+ var1 = torch.sum(X1**2, dim=1).sum(dim=1)
224
+
225
+ # 3. The outer product of X1 and X2.
226
+ K = X1.bmm(X2.permute(0,2,1))
227
+
228
+ # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
229
+ # singular vectors of K.
230
+ U, s, V = torch.svd(K)
231
+
232
+ # Construct Z that fixes the orientation of R to get det(R)=1.
233
+ Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0)
234
+ Z = Z.repeat(U.shape[0],1,1)
235
+ Z[:,-1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0,2,1))))
236
+
237
+ # Construct R.
238
+ R = V.bmm(Z.bmm(U.permute(0,2,1)))
239
+
240
+ # 5. Recover scale.
241
+ scale = torch.cat([torch.trace(x).unsqueeze(0) for x in R.bmm(K)]) / var1
242
+
243
+ # 6. Recover translation.
244
+ t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1)))
245
+
246
+ # 7. Error:
247
+ S1_hat = scale.unsqueeze(-1).unsqueeze(-1) * R.bmm(S1) + t
248
+
249
+ if transposed:
250
+ S1_hat = S1_hat.permute(0,2,1)
251
+
252
+ return S1_hat
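+ # Example (sketch): PA-MPJPE in millimeters, as used in the evaluation scripts,
+ # for (T, J, 3) tensors pred_j3d and target_j3d given in meters:
+ # S1_hat = batch_compute_similarity_transform_torch(pred_j3d, target_j3d)
+ # pa_mpjpe = torch.sqrt(((S1_hat - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1) * 1e3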
253
+
254
+
255
+ def align_by_pelvis(joints):
256
+ """
257
+ Assumes joints is 14 x 3 in LSP order.
258
+ Then hips are: [3, 2]
259
+ Takes mid point of these points, then subtracts it.
260
+ """
261
+
262
+ left_id = 2
263
+ right_id = 3
264
+
265
+ pelvis = (joints[left_id, :] + joints[right_id, :]) / 2.0
266
+ return joints - np.expand_dims(pelvis, axis=0)
267
+
268
+
269
+ def compute_errors(gt3ds, preds):
270
+ """
271
+ Gets MPJPE after pelvis alignment + MPJPE after Procrustes.
272
+ Evaluates on the 14 common joints.
273
+ Inputs:
274
+ - gt3ds: N x 14 x 3
275
+ - preds: N x 14 x 3
276
+ """
277
+ errors, errors_pa = [], []
278
+ for i, (gt3d, pred) in enumerate(zip(gt3ds, preds)):
279
+ gt3d = gt3d.reshape(-1, 3)
280
+ # Root align.
281
+ gt3d = align_by_pelvis(gt3d)
282
+ pred3d = align_by_pelvis(pred)
283
+
284
+ joint_error = np.sqrt(np.sum((gt3d - pred3d)**2, axis=1))
285
+ errors.append(np.mean(joint_error))
286
+
287
+ # Get PA error.
288
+ pred3d_sym = compute_similarity_transform(pred3d, gt3d)
289
+ pa_error = np.sqrt(np.sum((gt3d - pred3d_sym)**2, axis=1))
290
+ errors_pa.append(np.mean(pa_error))
291
+
292
+ return errors, errors_pa
293
+
294
+
295
+ def batch_align_by_pelvis(data_list, pelvis_idxs):
296
+ """
297
+ Assumes data is given as [pred_j3d, target_j3d, pred_verts, target_verts].
298
+ Each data is in shape of (frames, num_points, 3)
299
+ Pelvis is notated as one / two joints indices.
300
+ Align all data to the corresponding pelvis location.
301
+ """
302
+
303
+ pred_j3d, target_j3d, pred_verts, target_verts = data_list
304
+
305
+ pred_pelvis = pred_j3d[:, pelvis_idxs].mean(dim=1, keepdims=True).clone()
306
+ target_pelvis = target_j3d[:, pelvis_idxs].mean(dim=1, keepdims=True).clone()
307
+
308
+ # Align to the pelvis
309
+ pred_j3d = pred_j3d - pred_pelvis
310
+ target_j3d = target_j3d - target_pelvis
311
+ pred_verts = pred_verts - pred_pelvis
312
+ target_verts = target_verts - target_pelvis
313
+
314
+ return (pred_j3d, target_j3d, pred_verts, target_verts)
315
+
316
+ def compute_jpe(S1, S2):
317
+ return torch.sqrt(((S1 - S2) ** 2).sum(dim=-1)).mean(dim=-1).numpy()
318
+
319
+
320
+ # The functions below are borrowed from SLAHMR official implementation.
321
+ # Reference: https://github.com/vye16/slahmr/blob/main/slahmr/eval/tools.py
322
+ def global_align_joints(gt_joints, pred_joints):
323
+ """
324
+ :param gt_joints (T, J, 3)
325
+ :param pred_joints (T, J, 3)
326
+ """
327
+ s_glob, R_glob, t_glob = align_pcl(
328
+ gt_joints.reshape(-1, 3), pred_joints.reshape(-1, 3)
329
+ )
330
+ pred_glob = (
331
+ s_glob * torch.einsum("ij,tnj->tni", R_glob, pred_joints) + t_glob[None, None]
332
+ )
333
+ return pred_glob
334
+
335
+
336
+ def first_align_joints(gt_joints, pred_joints):
337
+ """
338
+ align the first two frames
339
+ :param gt_joints (T, J, 3)
340
+ :param pred_joints (T, J, 3)
341
+ """
342
+ # (1, 1), (1, 3, 3), (1, 3)
343
+ s_first, R_first, t_first = align_pcl(
344
+ gt_joints[:2].reshape(1, -1, 3), pred_joints[:2].reshape(1, -1, 3)
345
+ )
346
+ pred_first = (
347
+ s_first * torch.einsum("tij,tnj->tni", R_first, pred_joints) + t_first[:, None]
348
+ )
349
+ return pred_first
350
+
351
+
352
+ def local_align_joints(gt_joints, pred_joints):
353
+ """
354
+ :param gt_joints (T, J, 3)
355
+ :param pred_joints (T, J, 3)
356
+ """
357
+ s_loc, R_loc, t_loc = align_pcl(gt_joints, pred_joints)
358
+ pred_loc = (
359
+ s_loc[:, None] * torch.einsum("tij,tnj->tni", R_loc, pred_joints)
360
+ + t_loc[:, None]
361
+ )
362
+ return pred_loc
363
+
364
+
365
+ def align_pcl(Y, X, weight=None, fixed_scale=False):
366
+ """align similarity transform to align X with Y using umeyama method
367
+ X' = s * R * X + t is aligned with Y
368
+ :param Y (*, N, 3) first trajectory
369
+ :param X (*, N, 3) second trajectory
370
+ :param weight (*, N, 1) optional weight of valid correspondences
371
+ :returns s (*, 1), R (*, 3, 3), t (*, 3)
372
+ """
373
+ *dims, N, _ = Y.shape
374
+ N = torch.ones(*dims, 1, 1) * N
375
+
376
+ if weight is not None:
377
+ Y = Y * weight
378
+ X = X * weight
379
+ N = weight.sum(dim=-2, keepdim=True) # (*, 1, 1)
380
+
381
+ # subtract mean
382
+ my = Y.sum(dim=-2) / N[..., 0] # (*, 3)
383
+ mx = X.sum(dim=-2) / N[..., 0]
384
+ y0 = Y - my[..., None, :] # (*, N, 3)
385
+ x0 = X - mx[..., None, :]
386
+
387
+ if weight is not None:
388
+ y0 = y0 * weight
389
+ x0 = x0 * weight
390
+
391
+ # correlation
392
+ C = torch.matmul(y0.transpose(-1, -2), x0) / N # (*, 3, 3)
393
+ U, D, Vh = torch.linalg.svd(C) # (*, 3, 3), (*, 3), (*, 3, 3)
394
+
395
+ S = torch.eye(3).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1)
396
+ neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
397
+ S[neg, 2, 2] = -1
398
+
399
+ R = torch.matmul(U, torch.matmul(S, Vh)) # (*, 3, 3)
400
+
401
+ D = torch.diag_embed(D) # (*, 3, 3)
402
+ if fixed_scale:
403
+ s = torch.ones(*dims, 1, device=Y.device, dtype=torch.float32)
404
+ else:
405
+ var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N # (*, 1, 1)
406
+ s = (
407
+ torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
408
+ dim=-1, keepdim=True
409
+ )
410
+ / var[..., 0]
411
+ ) # (*, 1)
412
+
413
+ t = my - s * torch.matmul(R, mx[..., None])[..., 0] # (*, 3)
414
+
415
+ return s, R, t
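+ # Example (sketch): rigidly align a predicted root trajectory (T, 3) to the ground truth
+ # with a fixed scale of 1, mirroring compute_rte below:
+ # s, R, t = align_pcl(target_trans[None], pred_trans[None], fixed_scale=True)
+ # pred_aligned = (torch.einsum("tij,tnj->tni", R, pred_trans[None]) + t[:, None])[0]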
416
+
417
+
418
+ def compute_foot_sliding(target_output, pred_output, masks, thr=1e-2):
419
+ """compute foot sliding error
420
+ The foot ground contact label is computed by the threshold of 1 cm/frame
421
+ Args:
422
+ target_output (SMPL ModelOutput).
423
+ pred_output (SMPL ModelOutput).
424
+ masks (N).
425
+ Returns:
426
+ error (N frames in contact).
427
+ """
428
+
429
+ # Foot vertices idxs
430
+ foot_idxs = [3216, 3387, 6617, 6787]
431
+
432
+ # Compute contact label
433
+ foot_loc = target_output.vertices[masks][:, foot_idxs]
434
+ foot_disp = (foot_loc[1:] - foot_loc[:-1]).norm(2, dim=-1)
435
+ contact = foot_disp[:] < thr
436
+
437
+ pred_feet_loc = pred_output.vertices[:, foot_idxs]
438
+ pred_disp = (pred_feet_loc[1:] - pred_feet_loc[:-1]).norm(2, dim=-1)
439
+
440
+ error = pred_disp[contact]
441
+
442
+ return error.cpu().numpy()
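+ # Note (sketch): contact frames are detected on the ground-truth feet (displacement
+ # below 1 cm per frame) and the returned values are the predicted foot displacements
+ # on those frames, e.g. fs_mm = compute_foot_sliding(target_glob, pred_glob, masks).mean() * 1e3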
443
+
444
+
445
+ def compute_jitter(pred_output, fps=30):
446
+ """compute jitter of the motion
447
+ Args:
448
+ pred_output (SMPL ModelOutput).
449
+ fps (float).
450
+ Returns:
451
+ jitter (N-3).
452
+ """
453
+
454
+ pred3d = pred_output.joints[:, :24]
455
+
456
+ pred_jitter = torch.norm(
457
+ (pred3d[3:] - 3 * pred3d[2:-1] + 3 * pred3d[1:-2] - pred3d[:-3]) * (fps**3),
458
+ dim=2,
459
+ ).mean(dim=-1)
460
+
461
+ return pred_jitter.cpu().numpy() / 10.0
462
+
463
+
464
+ def compute_rte(target_trans, pred_trans):
465
+ # Compute the global alignment
466
+ _, rot, trans = align_pcl(target_trans[None, :], pred_trans[None, :], fixed_scale=True)
467
+ pred_trans_hat = (
468
+ torch.einsum("tij,tnj->tni", rot, pred_trans[None, :]) + trans[None, :]
469
+ )[0]
470
+
471
+ # Compute the entire displacement of ground truth trajectory
472
+ disps, disp = [], 0
473
+ for p1, p2 in zip(target_trans, target_trans[1:]):
474
+ delta = (p2 - p1).norm(2, dim=-1)
475
+ disp += delta
476
+ disps.append(disp)
477
+
478
+ # Compute absolute root-translation-error (RTE)
479
+ rte = torch.norm(target_trans - pred_trans_hat, 2, dim=-1)
480
+
481
+ # Normalize it to the displacement
482
+ return (rte / disp).numpy()
lib/eval/evaluate_3dpw.py ADDED
@@ -0,0 +1,181 @@
1
+ import os
2
+ import time
3
+ import os.path as osp
4
+ from glob import glob
5
+ from collections import defaultdict
6
+
7
+ import torch
8
+ import imageio
9
+ import numpy as np
10
+ from smplx import SMPL
11
+ from loguru import logger
12
+ from progress.bar import Bar
13
+
14
+ from configs import constants as _C
15
+ from configs.config import parse_args
16
+ from lib.data.dataloader import setup_eval_dataloader
17
+ from lib.models import build_network, build_body_model
18
+ from lib.eval.eval_utils import (
19
+ compute_error_accel,
20
+ batch_align_by_pelvis,
21
+ batch_compute_similarity_transform_torch,
22
+ )
23
+ from lib.utils import transforms
24
+ from lib.utils.utils import prepare_output_dir
25
+ from lib.utils.utils import prepare_batch
26
+ from lib.utils.imutils import avg_preds
27
+
28
+ try:
29
+ from lib.vis.renderer import Renderer
30
+ _render = True
31
+ except:
32
+ print("PyTorch3D is not properly installed! Cannot render the SMPL mesh")
33
+ _render = False
34
+
35
+
36
+ m2mm = 1e3
37
+ @torch.no_grad()
38
+ def main(cfg, args):
39
+ torch.backends.cuda.matmul.allow_tf32 = False
40
+ torch.backends.cudnn.allow_tf32 = False
41
+
42
+ logger.info(f'GPU name -> {torch.cuda.get_device_name()}')
43
+ logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}')
44
+
45
+ # ========= Dataloaders ========= #
46
+ eval_loader = setup_eval_dataloader(cfg, '3dpw', 'test', cfg.MODEL.BACKBONE)
47
+ logger.info(f'Dataset loaded')
48
+
49
+ # ========= Load WHAM ========= #
50
+ smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN
51
+ smpl = build_body_model(cfg.DEVICE, smpl_batch_size)
52
+ network = build_network(cfg, smpl)
53
+ network.eval()
54
+
55
+ # Build SMPL models with each gender
56
+ smpl = {k: SMPL(_C.BMODEL.FLDR, gender=k).to(cfg.DEVICE) for k in ['male', 'female', 'neutral']}
57
+
58
+ # Load vertices -> joints regression matrix to evaluate
59
+ J_regressor_eval = torch.from_numpy(
60
+ np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M)
61
+ )[_C.KEYPOINTS.H36M_TO_J14, :].unsqueeze(0).float().to(cfg.DEVICE)
62
+ pelvis_idxs = [2, 3]
63
+
64
+ accumulator = defaultdict(list)
65
+ bar = Bar('Inference', fill='#', max=len(eval_loader))
66
+ with torch.no_grad():
67
+ for i in range(len(eval_loader)):
68
+ # Original batch
69
+ batch = eval_loader.dataset.load_data(i, False)
70
+ x, inits, features, kwargs, gt = prepare_batch(batch, cfg.DEVICE, cfg.TRAIN.STAGE=='stage2')
71
+
72
+ if cfg.FLIP_EVAL:
73
+ flipped_batch = eval_loader.dataset.load_data(i, True)
74
+ f_x, f_inits, f_features, f_kwargs, _ = prepare_batch(flipped_batch, cfg.DEVICE, cfg.TRAIN.STAGE=='stage2')
75
+
76
+ # Forward pass with flipped input
77
+ flipped_pred = network(f_x, f_inits, f_features, **f_kwargs)
78
+
79
+ # Forward pass with normal input
80
+ pred = network(x, inits, features, **kwargs)
81
+
82
+ if cfg.FLIP_EVAL:
83
+ # Merge two predictions
84
+ flipped_pose, flipped_shape = flipped_pred['pose'].squeeze(0), flipped_pred['betas'].squeeze(0)
85
+ pose, shape = pred['pose'].squeeze(0), pred['betas'].squeeze(0)
86
+ flipped_pose, pose = flipped_pose.reshape(-1, 24, 6), pose.reshape(-1, 24, 6)
87
+ avg_pose, avg_shape = avg_preds(pose, shape, flipped_pose, flipped_shape)
88
+ avg_pose = avg_pose.reshape(-1, 144)
89
+
90
+ # Refine trajectory with merged prediction
91
+ network.pred_pose = avg_pose.view_as(network.pred_pose)
92
+ network.pred_shape = avg_shape.view_as(network.pred_shape)
93
+ pred = network.forward_smpl(**kwargs)
94
+
95
+ # <======= Build predicted SMPL
96
+ pred_output = smpl['neutral'](body_pose=pred['poses_body'],
97
+ global_orient=pred['poses_root_cam'],
98
+ betas=pred['betas'].squeeze(0),
99
+ pose2rot=False)
100
+ pred_verts = pred_output.vertices.cpu()
101
+ pred_j3d = torch.matmul(J_regressor_eval, pred_output.vertices).cpu()
102
+ # =======>
103
+
104
+ # <======= Build groundtruth SMPL
105
+ target_output = smpl[batch['gender']](
106
+ body_pose=transforms.rotation_6d_to_matrix(gt['pose'][0, :, 1:]),
107
+ global_orient=transforms.rotation_6d_to_matrix(gt['pose'][0, :, :1]),
108
+ betas=gt['betas'][0],
109
+ pose2rot=False)
110
+ target_verts = target_output.vertices.cpu()
111
+ target_j3d = torch.matmul(J_regressor_eval, target_output.vertices).cpu()
112
+ # =======>
113
+
114
+ # <======= Compute performance of the current sequence
115
+ pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
116
+ [pred_j3d, target_j3d, pred_verts, target_verts], pelvis_idxs
117
+ )
118
+ S1_hat = batch_compute_similarity_transform_torch(pred_j3d, target_j3d)
119
+ pa_mpjpe = torch.sqrt(((S1_hat - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm
120
+ mpjpe = torch.sqrt(((pred_j3d - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm
121
+ pve = torch.sqrt(((pred_verts - target_verts) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm
122
+ accel = compute_error_accel(joints_pred=pred_j3d, joints_gt=target_j3d)[1:-1]
123
+ accel = accel * (30 ** 2) # per frame^2 to per s^2
124
+ # =======>
125
+
126
+ summary_string = f'{batch["vid"]} | PA-MPJPE: {pa_mpjpe.mean():.1f} MPJPE: {mpjpe.mean():.1f} PVE: {pve.mean():.1f}'
127
+ bar.suffix = summary_string
128
+ bar.next()
129
+
130
+ # <======= Accumulate the results over entire sequences
131
+ accumulator['pa_mpjpe'].append(pa_mpjpe)
132
+ accumulator['mpjpe'].append(mpjpe)
133
+ accumulator['pve'].append(pve)
134
+ accumulator['accel'].append(accel)
135
+ # =======>
136
+
137
+ # <======= (Optional) Render the prediction
138
+ if not (_render and args.render):
139
+ # Skip if PyTorch3D is not installed or rendering argument is not parsed.
140
+ continue
141
+
142
+ # Save path
143
+ viz_pth = osp.join('output', 'visualization')
144
+ os.makedirs(viz_pth, exist_ok=True)
145
+
146
+ # Build Renderer
147
+ width, height = batch['cam_intrinsics'][0][0, :2, -1].numpy() * 2
148
+ focal_length = batch['cam_intrinsics'][0][0, 0, 0].item()
149
+ renderer = Renderer(width, height, focal_length, cfg.DEVICE, smpl['neutral'].faces)
150
+
151
+ # Get images and writer
152
+ frame_list = batch['frame_id'][0].numpy()
153
+ imname_list = sorted(glob(osp.join(_C.PATHS.THREEDPW_PTH, 'imageFiles', batch['vid'][:-2], '*.jpg')))
154
+ writer = imageio.get_writer(osp.join(viz_pth, batch['vid'] + '.mp4'),
155
+ mode='I', format='FFMPEG', fps=30, macro_block_size=1)
156
+
157
+ # Skip the invalid frames
158
+ for i, frame in enumerate(frame_list):
159
+ image = imageio.imread(imname_list[frame])
160
+
161
+ # NOTE: pred['verts'] is different from pred_verts as we subtracted the offset from the SMPL mesh.
162
+ # Check line 121 in lib/models/smpl.py
163
+ vertices = pred['verts_cam'][i] + pred['trans_cam'][[i]]
164
+ image = renderer.render_mesh(vertices, image)
165
+ writer.append_data(image)
166
+ writer.close()
167
+ # =======>
168
+
169
+ for k, v in accumulator.items():
170
+ accumulator[k] = np.concatenate(v).mean()
171
+
172
+ print('')
173
+ log_str = 'Evaluation on 3DPW, '
174
+ log_str += ' '.join([f'{k.upper()}: {v:.4f},'for k,v in accumulator.items()])
175
+ logger.info(log_str)
176
+
177
+ if __name__ == '__main__':
178
+ cfg, cfg_file, args = parse_args(test=True)
179
+ cfg = prepare_output_dir(cfg, cfg_file)
180
+
181
+ main(cfg, args)
lib/eval/evaluate_emdb.py ADDED
@@ -0,0 +1,228 @@
1
+ import os
2
+ import time
3
+ import os.path as osp
4
+ from glob import glob
5
+ from collections import defaultdict
6
+
7
+ import torch
8
+ import pickle
9
+ import numpy as np
10
+ from smplx import SMPL
11
+ from loguru import logger
12
+ from progress.bar import Bar
13
+
14
+ from configs import constants as _C
15
+ from configs.config import parse_args
16
+ from lib.data.dataloader import setup_eval_dataloader
17
+ from lib.models import build_network, build_body_model
18
+ from lib.eval.eval_utils import (
19
+ compute_jpe,
20
+ compute_rte,
21
+ compute_jitter,
22
+ compute_error_accel,
23
+ compute_foot_sliding,
24
+ batch_align_by_pelvis,
25
+ first_align_joints,
26
+ global_align_joints,
27
+ compute_rte,
28
+ compute_jitter,
29
+ compute_foot_sliding
30
+ batch_compute_similarity_transform_torch,
31
+ )
32
+ from lib.utils import transforms
33
+ from lib.utils.utils import prepare_output_dir
34
+ from lib.utils.utils import prepare_batch
35
+ from lib.utils.imutils import avg_preds
36
+
37
+ """
38
+ This is a tentative script to evaluate WHAM on the EMDB dataset.
39
+ The current implementation requires the EMDB dataset to be downloaded at ./datasets/EMDB/
40
+ """
41
+
42
+ m2mm = 1e3
43
+ @torch.no_grad()
44
+ def main(cfg, args):
45
+ torch.backends.cuda.matmul.allow_tf32 = False
46
+ torch.backends.cudnn.allow_tf32 = False
47
+
48
+ logger.info(f'GPU name -> {torch.cuda.get_device_name()}')
49
+ logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}')
50
+
51
+ # ========= Dataloaders ========= #
52
+ eval_loader = setup_eval_dataloader(cfg, 'emdb', args.eval_split, cfg.MODEL.BACKBONE)
53
+ logger.info(f'Dataset loaded')
54
+
55
+ # ========= Load WHAM ========= #
56
+ smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN
57
+ smpl = build_body_model(cfg.DEVICE, smpl_batch_size)
58
+ network = build_network(cfg, smpl)
59
+ network.eval()
60
+
61
+ # Build SMPL models with each gender
62
+ smpl = {k: SMPL(_C.BMODEL.FLDR, gender=k).to(cfg.DEVICE) for k in ['male', 'female', 'neutral']}
63
+
64
+ # Load vertices -> joints regression matrix to evaluate
65
+ pelvis_idxs = [1, 2]
66
+
67
+ # WHAM uses Y-down coordinate system, while EMDB dataset uses Y-up one.
68
+ yup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float().to(cfg.DEVICE)
69
+
70
+ # To torch tensor function
71
+ tt = lambda x: torch.from_numpy(x).float().to(cfg.DEVICE)
72
+ accumulator = defaultdict(list)
73
+ bar = Bar('Inference', fill='#', max=len(eval_loader))
74
+ with torch.no_grad():
75
+ for i in range(len(eval_loader)):
76
+ # Original batch
77
+ batch = eval_loader.dataset.load_data(i, False)
78
+ x, inits, features, kwargs, gt = prepare_batch(batch, cfg.DEVICE, cfg.TRAIN.STAGE == 'stage2')
79
+
80
+ # Align with groundtruth data to the first frame
81
+ cam2yup = batch['R'][0][:1].to(cfg.DEVICE)
82
+ cam2ydown = cam2yup @ yup2ydown
83
+ cam2root = transforms.rotation_6d_to_matrix(inits[1][:, 0, 0])
84
+ ydown2root = cam2ydown.mT @ cam2root
85
+ ydown2root = transforms.matrix_to_rotation_6d(ydown2root)
86
+ kwargs['init_root'][:, 0] = ydown2root
87
+
88
+ if cfg.FLIP_EVAL:
89
+ flipped_batch = eval_loader.dataset.load_data(i, True)
90
+ f_x, f_inits, f_features, f_kwargs, _ = prepare_batch(flipped_batch, cfg.DEVICE, cfg.TRAIN.STAGE == 'stage2')
91
+
92
+ # Forward pass with flipped input
93
+ flipped_pred = network(f_x, f_inits, f_features, **f_kwargs)
94
+
95
+ # Forward pass with normal input
96
+ pred = network(x, inits, features, **kwargs)
97
+
98
+ if cfg.FLIP_EVAL:
99
+ # Merge two predictions
100
+ flipped_pose, flipped_shape = flipped_pred['pose'].squeeze(0), flipped_pred['betas'].squeeze(0)
101
+ pose, shape = pred['pose'].squeeze(0), pred['betas'].squeeze(0)
102
+ flipped_pose, pose = flipped_pose.reshape(-1, 24, 6), pose.reshape(-1, 24, 6)
103
+ avg_pose, avg_shape = avg_preds(pose, shape, flipped_pose, flipped_shape)
104
+ avg_pose = avg_pose.reshape(-1, 144)
105
+ avg_contact = (flipped_pred['contact'][..., [2, 3, 0, 1]] + pred['contact']) / 2
106
+
107
+ # Refine trajectory with merged prediction
108
+ network.pred_pose = avg_pose.view_as(network.pred_pose)
109
+ network.pred_shape = avg_shape.view_as(network.pred_shape)
110
+ network.pred_contact = avg_contact.view_as(network.pred_contact)
111
+ output = network.forward_smpl(**kwargs)
112
+ pred = network.refine_trajectory(output, return_y_up=True, **kwargs)
113
+
114
+ # <======= Prepare groundtruth data
115
+ subj, seq = batch['vid'][:2], batch['vid'][3:]
116
+ annot_pth = glob(osp.join(_C.PATHS.EMDB_PTH, subj, seq, '*_data.pkl'))[0]
117
+ annot = pickle.load(open(annot_pth, 'rb'))
118
+
119
+ masks = annot['good_frames_mask']
120
+ gender = annot['gender']
121
+ poses_body = annot["smpl"]["poses_body"]
122
+ poses_root = annot["smpl"]["poses_root"]
123
+ betas = np.repeat(annot["smpl"]["betas"].reshape((1, -1)), repeats=annot["n_frames"], axis=0)
124
+ trans = annot["smpl"]["trans"]
125
+ extrinsics = annot["camera"]["extrinsics"]
126
+
127
+ # # Map to camear coordinate
128
+ poses_root_cam = transforms.matrix_to_axis_angle(tt(extrinsics[:, :3, :3]) @ transforms.axis_angle_to_matrix(tt(poses_root)))
129
+
130
+ # Groundtruth global motion
131
+ target_glob = smpl[gender](body_pose=tt(poses_body), global_orient=tt(poses_root), betas=tt(betas), transl=tt(trans))
132
+ target_j3d_glob = target_glob.joints[:, :24][masks]
133
+
134
+ # Groundtruth local motion
135
+ target_cam = smpl[gender](body_pose=tt(poses_body), global_orient=poses_root_cam, betas=tt(betas))
136
+ target_verts_cam = target_cam.vertices[masks]
137
+ target_j3d_cam = target_cam.joints[:, :24][masks]
138
+ # =======>
139
+
140
+ # Convert WHAM global orient to Y-up coordinate
141
+ poses_root = pred['poses_root_world'].squeeze(0)
142
+ pred_trans = pred['trans_world'].squeeze(0)
143
+ poses_root = yup2ydown.mT @ poses_root
144
+ pred_trans = (yup2ydown.mT @ pred_trans.unsqueeze(-1)).squeeze(-1)
145
+
146
+ # <======= Build predicted motion
147
+ # Predicted global motion
148
+ pred_glob = smpl['neutral'](body_pose=pred['poses_body'], global_orient=poses_root.unsqueeze(1), betas=pred['betas'].squeeze(0), transl=pred_trans, pose2rot=False)
149
+ pred_j3d_glob = pred_glob.joints[:, :24]
150
+
151
+ # Predicted local motion
152
+ pred_cam = smpl['neutral'](body_pose=pred['poses_body'], global_orient=pred['poses_root_cam'], betas=pred['betas'].squeeze(0), pose2rot=False)
153
+ pred_verts_cam = pred_cam.vertices
154
+ pred_j3d_cam = pred_cam.joints[:, :24]
155
+ # =======>
156
+
157
+ # <======= Evaluation on the local motion
158
+ pred_j3d_cam, target_j3d_cam, pred_verts_cam, target_verts_cam = batch_align_by_pelvis(
159
+ [pred_j3d_cam, target_j3d_cam, pred_verts_cam, target_verts_cam], pelvis_idxs
160
+ )
161
+ S1_hat = batch_compute_similarity_transform_torch(pred_j3d_cam, target_j3d_cam)
162
+ pa_mpjpe = torch.sqrt(((S1_hat - target_j3d_cam) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * m2mm
163
+ mpjpe = torch.sqrt(((pred_j3d_cam - target_j3d_cam) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * m2mm
164
+ pve = torch.sqrt(((pred_verts_cam - target_verts_cam) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * m2mm
165
+ accel = compute_error_accel(joints_pred=pred_j3d_cam.cpu(), joints_gt=target_j3d_cam.cpu())[1:-1]
166
+ accel = accel * (30 ** 2) # per frame^2 to per s^2
167
+
168
+ summary_string = f'{batch["vid"]} | PA-MPJPE: {pa_mpjpe.mean():.1f} MPJPE: {mpjpe.mean():.1f} PVE: {pve.mean():.1f}'
169
+ bar.suffix = summary_string
170
+ bar.next()
171
+ # =======>
172
+
173
+ # <======= Evaluation on the global motion
174
+ chunk_length = 100
175
+ w_mpjpe, wa_mpjpe = [], []
176
+ for start in range(0, masks.sum(), chunk_length):
177
+ end = min(masks.sum(), start + chunk_length)
178
+
179
+ target_j3d = target_j3d_glob[start:end].clone().cpu()
180
+ pred_j3d = pred_j3d_glob[start:end].clone().cpu()
181
+
182
+ w_j3d = first_align_joints(target_j3d, pred_j3d)
183
+ wa_j3d = global_align_joints(target_j3d, pred_j3d)
184
+
185
+ w_jpe = compute_jpe(target_j3d, w_j3d)
186
+ wa_jpe = compute_jpe(target_j3d, wa_j3d)
187
+ w_mpjpe.append(w_jpe)
188
+ wa_mpjpe.append(wa_jpe)
189
+
190
+ w_mpjpe = np.concatenate(w_mpjpe) * m2mm
191
+ wa_mpjpe = np.concatenate(wa_mpjpe) * m2mm
192
+
193
+ # Additional metrics
194
+ rte = compute_rte(torch.from_numpy(trans[masks]), pred_trans.cpu()) * 1e2
195
+ jitter = compute_jitter(pred_glob, fps=30)
196
+ foot_sliding = compute_foot_sliding(target_glob, pred_glob, masks) * m2mm
197
+ # =======>
198
+
203
+
204
+ # <======= Accumulate the results over entire sequences
205
+ accumulator['pa_mpjpe'].append(pa_mpjpe)
206
+ accumulator['mpjpe'].append(mpjpe)
207
+ accumulator['pve'].append(pve)
208
+ accumulator['accel'].append(accel)
209
+ accumulator['wa_mpjpe'].append(wa_mpjpe)
210
+ accumulator['w_mpjpe'].append(w_mpjpe)
211
+ accumulator['RTE'].append(rte)
212
+ accumulator['jitter'].append(jitter)
213
+ accumulator['FS'].append(foot_sliding)
214
+ # =======>
215
+
216
+ for k, v in accumulator.items():
217
+ accumulator[k] = np.concatenate(v).mean()
218
+
219
+ print('')
220
+ log_str = f'Evaluation on EMDB {args.eval_split}, '
221
+ log_str += ' '.join([f'{k.upper()}: {v:.4f},'for k,v in accumulator.items()])
222
+ logger.info(log_str)
223
+
224
+ if __name__ == '__main__':
225
+ cfg, cfg_file, args = parse_args(test=True)
226
+ cfg = prepare_output_dir(cfg, cfg_file)
227
+
228
+ main(cfg, args)
lib/eval/evaluate_rich.py ADDED
@@ -0,0 +1,156 @@
1
+ import os
2
+ import os.path as osp
3
+ from collections import defaultdict
4
+ from time import time
5
+
6
+ import torch
7
+ import joblib
8
+ import numpy as np
9
+ from loguru import logger
10
+ from smplx import SMPL, SMPLX
11
+ from progress.bar import Bar
12
+
13
+ from configs import constants as _C
14
+ from configs.config import parse_args
15
+ from lib.data.dataloader import setup_eval_dataloader
16
+ from lib.models import build_network, build_body_model
17
+ from lib.eval.eval_utils import (
18
+ compute_error_accel,
19
+ batch_align_by_pelvis,
20
+ batch_compute_similarity_transform_torch,
21
+ )
22
+ from lib.utils import transforms
23
+ from lib.utils.utils import prepare_output_dir
24
+ from lib.utils.utils import prepare_batch
25
+ from lib.utils.imutils import avg_preds
26
+
27
+ m2mm = 1e3
28
+ smplx2smpl = torch.from_numpy(joblib.load(_C.BMODEL.SMPLX2SMPL)['matrix']).unsqueeze(0).float().cuda()
29
+ @torch.no_grad()
30
+ def main(cfg, args):
31
+ torch.backends.cuda.matmul.allow_tf32 = False
32
+ torch.backends.cudnn.allow_tf32 = False
33
+
34
+ logger.info(f'GPU name -> {torch.cuda.get_device_name()}')
35
+ logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}')
36
+
37
+ # ========= Dataloaders ========= #
38
+ eval_loader = setup_eval_dataloader(cfg, 'rich', 'test', cfg.MODEL.BACKBONE)
39
+ logger.info(f'Dataset loaded')
40
+
41
+ # ========= Load WHAM ========= #
42
+ smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN
43
+ smpl = build_body_model(cfg.DEVICE, smpl_batch_size)
44
+ network = build_network(cfg, smpl)
45
+ network.eval()
46
+
47
+ # Build neutral SMPL model for WHAM and gendered SMPLX models for the groundtruth data
48
+ smpl = SMPL(_C.BMODEL.FLDR, gender='neutral').to(cfg.DEVICE)
49
+
50
+ # Load vertices -> joints regression matrix to evaluate
51
+ J_regressor_eval = smpl.J_regressor.clone().unsqueeze(0)
52
+ pelvis_idxs = [1, 2]
53
+
54
+ accumulator = defaultdict(list)
55
+ bar = Bar('Inference', fill='#', max=len(eval_loader))
56
+ with torch.no_grad():
57
+ for i in range(len(eval_loader)):
58
+ time_dict = {}
59
+ _t = time()
60
+
61
+ # Original batch
62
+ batch = eval_loader.dataset.load_data(i, False)
63
+ x, inits, features, kwargs, gt = prepare_batch(batch, cfg.DEVICE, cfg.TRAIN.STAGE=='stage2')
64
+
65
+ # <======= Inference
66
+ if cfg.FLIP_EVAL:
67
+ flipped_batch = eval_loader.dataset.load_data(i, True)
68
+ f_x, f_inits, f_features, f_kwargs, _ = prepare_batch(flipped_batch, cfg.DEVICE, cfg.TRAIN.STAGE=='stage2')
69
+
70
+ # Forward pass with flipped input
71
+ flipped_pred = network(f_x, f_inits, f_features, **f_kwargs)
72
+ time_dict['inference_flipped'] = time() - _t; _t = time()
73
+
74
+ # Forward pass with normal input
75
+ pred = network(x, inits, features, **kwargs)
76
+ time_dict['inference'] = time() - _t; _t = time()
77
+
78
+ if cfg.FLIP_EVAL:
79
+ # Merge two predictions
80
+ flipped_pose, flipped_shape = flipped_pred['pose'].squeeze(0), flipped_pred['betas'].squeeze(0)
81
+ pose, shape = pred['pose'].squeeze(0), pred['betas'].squeeze(0)
82
+ flipped_pose, pose = flipped_pose.reshape(-1, 24, 6), pose.reshape(-1, 24, 6)
83
+ avg_pose, avg_shape = avg_preds(pose, shape, flipped_pose, flipped_shape)
84
+ avg_pose = avg_pose.reshape(-1, 144)
85
+
86
+ # Refine trajectory with merged prediction
87
+ network.pred_pose = avg_pose.view_as(network.pred_pose)
88
+ network.pred_shape = avg_shape.view_as(network.pred_shape)
89
+ pred = network.forward_smpl(**kwargs)
90
+ time_dict['averaging'] = time() - _t; _t = time()
91
+ # =======>
92
+
93
+ # <======= Build predicted SMPL
94
+ pred_output = smpl(body_pose=pred['poses_body'],
95
+ global_orient=pred['poses_root_cam'],
96
+ betas=pred['betas'].squeeze(0),
97
+ pose2rot=False)
98
+ pred_verts = pred_output.vertices.cpu()
99
+ pred_j3d = torch.matmul(J_regressor_eval, pred_output.vertices).cpu()
100
+ time_dict['building prediction'] = time() - _t; _t = time()
101
+ # =======>
102
+
103
+ # <======= Build groundtruth SMPL (from SMPLX)
104
+ smplx = SMPLX(_C.BMODEL.FLDR.replace('smpl', 'smplx'),
105
+ gender=batch['gender'],
106
+ batch_size=len(pred_verts)
107
+ ).to(cfg.DEVICE)
108
+ gt_pose = transforms.matrix_to_axis_angle(transforms.rotation_6d_to_matrix(gt['pose'][0]))
109
+ target_output = smplx(
110
+ body_pose=gt_pose[:, 1:-2].reshape(-1, 63),
111
+ global_orient=gt_pose[:, 0],
112
+ betas=gt['betas'][0])
113
+ target_verts = torch.matmul(smplx2smpl, target_output.vertices.cuda()).cpu()
114
+ target_j3d = torch.matmul(J_regressor_eval, target_verts.to(cfg.DEVICE)).cpu()
115
+ time_dict['building target'] = time() - _t; _t = time()
116
+ # =======>
117
+
118
+ # <======= Compute performance of the current sequence
119
+ pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
120
+ [pred_j3d, target_j3d, pred_verts, target_verts], pelvis_idxs
121
+ )
122
+ S1_hat = batch_compute_similarity_transform_torch(pred_j3d, target_j3d)
123
+ pa_mpjpe = torch.sqrt(((S1_hat - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm
124
+ mpjpe = torch.sqrt(((pred_j3d - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm
125
+ pve = torch.sqrt(((pred_verts - target_verts) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm
126
+ accel = compute_error_accel(joints_pred=pred_j3d, joints_gt=target_j3d)[1:-1]
127
+ accel = accel * (30 ** 2) # per frame^2 to per s^2
128
+ time_dict['evaluating'] = time() - _t; _t = time()
129
+ # =======>
130
+
131
+ # summary_string = f'{batch["vid"]} | PA-MPJPE: {pa_mpjpe.mean():.1f} MPJPE: {mpjpe.mean():.1f} PVE: {pve.mean():.1f}'
132
+ summary_string = f'{batch["vid"]} | ' + ' '.join([f'{k}: {v:.1f} s' for k, v in time_dict.items()])
133
+ bar.suffix = summary_string
134
+ bar.next()
135
+
136
+ # <======= Accumulate the results over entire sequences
137
+ accumulator['pa_mpjpe'].append(pa_mpjpe)
138
+ accumulator['mpjpe'].append(mpjpe)
139
+ accumulator['pve'].append(pve)
140
+ accumulator['accel'].append(accel)
141
+
142
+ # =======>
143
+
144
+ for k, v in accumulator.items():
145
+ accumulator[k] = np.concatenate(v).mean()
146
+
147
+ print('')
148
+ log_str = 'Evaluation on RICH, '
149
+ log_str += ' '.join([f'{k.upper()}: {v:.4f},'for k,v in accumulator.items()])
150
+ logger.info(log_str)
151
+
152
+ if __name__ == '__main__':
153
+ cfg, cfg_file, args = parse_args(test=True)
154
+ cfg = prepare_output_dir(cfg, cfg_file)
155
+
156
+ main(cfg, args)
lib/models/__init__.py ADDED
@@ -0,0 +1,40 @@
1
+ import os, sys
2
+ import yaml
3
+ import torch
4
+ from loguru import logger
5
+
6
+ from configs import constants as _C
7
+ from .smpl import SMPL
8
+
9
+
10
+ def build_body_model(device, batch_size=1, gender='neutral', **kwargs):
11
+ sys.stdout = open(os.devnull, 'w')
12
+ body_model = SMPL(
13
+ model_path=_C.BMODEL.FLDR,
14
+ gender=gender,
15
+ batch_size=batch_size,
16
+ create_transl=False).to(device)
17
+ sys.stdout = sys.__stdout__
18
+ return body_model
19
+
20
+
21
+ def build_network(cfg, smpl):
22
+ from .wham import Network
23
+
24
+ with open(cfg.MODEL_CONFIG, 'r') as f:
25
+ model_config = yaml.safe_load(f)
26
+ model_config.update({'d_feat': _C.IMG_FEAT_DIM[cfg.MODEL.BACKBONE]})
27
+
28
+ network = Network(smpl, **model_config).to(cfg.DEVICE)
29
+
30
+ # Load Checkpoint
31
+ if os.path.isfile(cfg.TRAIN.CHECKPOINT):
32
+ checkpoint = torch.load(cfg.TRAIN.CHECKPOINT)
33
+ ignore_keys = ['smpl.body_pose', 'smpl.betas', 'smpl.global_orient', 'smpl.J_regressor_extra', 'smpl.J_regressor_eval']
34
+ model_state_dict = {k: v for k, v in checkpoint['model'].items() if k not in ignore_keys}
35
+ network.load_state_dict(model_state_dict, strict=False)
36
+ logger.info(f"=> loaded checkpoint '{cfg.TRAIN.CHECKPOINT}' ")
37
+ else:
38
+ logger.info(f"=> Warning! no checkpoint found at '{cfg.TRAIN.CHECKPOINT}'.")
39
+
40
+ return network
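+ # Usage sketch (mirrors the evaluation scripts in lib/eval):
+ # smpl = build_body_model(cfg.DEVICE, cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN)
+ # network = build_network(cfg, smpl)
+ # network.eval()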
lib/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.72 kB). View file
 
lib/models/__pycache__/smpl.cpython-39.pyc ADDED
Binary file (7 kB). View file
 
lib/models/__pycache__/wham.cpython-39.pyc ADDED
Binary file (4.95 kB). View file
 
lib/models/layers/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .modules import MotionEncoder, MotionDecoder, TrajectoryDecoder, TrajectoryRefiner, Integrator
2
+ from .utils import rollout_global_motion, compute_camera_pose, reset_root_velocity, compute_camera_motion
lib/models/layers/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (482 Bytes). View file