Techt3o committed on
Commit c87d1bc · verified · 1 Parent(s): f561f8b

ca5705cc9c8581d916aca37e6759c44f0b1e70429e49ce83e658a0517cd3d6fe

Files changed (50)
  1. LICENSE +21 -0
  2. README.md +120 -11
  3. lib/models/layers/__pycache__/utils.cpython-39.pyc +0 -0
  4. lib/models/layers/modules.py +262 -0
  5. lib/models/layers/utils.py +52 -0
  6. lib/models/preproc/__pycache__/detector.cpython-39.pyc +0 -0
  7. lib/models/preproc/__pycache__/extractor.cpython-39.pyc +0 -0
  8. lib/models/preproc/__pycache__/slam.cpython-39.pyc +0 -0
  9. lib/models/preproc/backbone/__pycache__/hmr2.cpython-39.pyc +0 -0
  10. lib/models/preproc/backbone/__pycache__/pose_transformer.cpython-39.pyc +0 -0
  11. lib/models/preproc/backbone/__pycache__/smpl_head.cpython-39.pyc +0 -0
  12. lib/models/preproc/backbone/__pycache__/t_cond_mlp.cpython-39.pyc +0 -0
  13. lib/models/preproc/backbone/__pycache__/utils.cpython-39.pyc +0 -0
  14. lib/models/preproc/backbone/__pycache__/vit.cpython-39.pyc +0 -0
  15. lib/models/preproc/backbone/hmr2.py +77 -0
  16. lib/models/preproc/backbone/pose_transformer.py +357 -0
  17. lib/models/preproc/backbone/smpl_head.py +128 -0
  18. lib/models/preproc/backbone/t_cond_mlp.py +198 -0
  19. lib/models/preproc/backbone/utils.py +115 -0
  20. lib/models/preproc/backbone/vit.py +348 -0
  21. lib/models/preproc/detector.py +146 -0
  22. lib/models/preproc/extractor.py +112 -0
  23. lib/models/preproc/slam.py +70 -0
  24. lib/models/smpl.py +264 -0
  25. lib/models/smplify/__init__.py +1 -0
  26. lib/models/smplify/__pycache__/__init__.cpython-39.pyc +0 -0
  27. lib/models/smplify/__pycache__/losses.cpython-39.pyc +0 -0
  28. lib/models/smplify/__pycache__/smplify.cpython-39.pyc +0 -0
  29. lib/models/smplify/losses.py +87 -0
  30. lib/models/smplify/smplify.py +83 -0
  31. lib/models/wham.py +210 -0
  32. lib/utils/__pycache__/data_utils.cpython-39.pyc +0 -0
  33. lib/utils/__pycache__/imutils.cpython-39.pyc +0 -0
  34. lib/utils/__pycache__/kp_utils.cpython-39.pyc +0 -0
  35. lib/utils/__pycache__/transforms.cpython-39.pyc +0 -0
  36. lib/utils/data_utils.py +113 -0
  37. lib/utils/imutils.py +363 -0
  38. lib/utils/kp_utils.py +761 -0
  39. lib/utils/transforms.py +828 -0
  40. lib/utils/utils.py +265 -0
  41. lib/vis/__pycache__/renderer.cpython-39.pyc +0 -0
  42. lib/vis/__pycache__/run_vis.cpython-39.pyc +0 -0
  43. lib/vis/__pycache__/tools.cpython-39.pyc +0 -0
  44. lib/vis/renderer.py +313 -0
  45. lib/vis/run_vis.py +92 -0
  46. lib/vis/tools.py +822 -0
  47. output/demo/test19/output.mp4 +0 -0
  48. output/demo/test19/slam_results.pth +3 -0
  49. output/demo/test19/tracking_results.pth +3 -0
  50. output/demo/test19/wham_output.pkl +3 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Soyong Shin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md CHANGED
@@ -1,11 +1,120 @@
- ---
- title: Motionbert Meta Sapiens
- emoji: 🌍
- colorFrom: green
- colorTo: indigo
- sdk: docker
- pinned: false
- short_description: Sapiens
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

# WHAM: Reconstructing World-grounded Humans with Accurate 3D Motion

<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a> [![report](https://img.shields.io/badge/arxiv-report-red)](https://arxiv.org/abs/2312.07531) <a href="https://wham.is.tue.mpg.de/"><img alt="Project" src="https://img.shields.io/badge/-Project%20Page-lightgrey?logo=Google%20Chrome&color=informational&logoColor=white"></a> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ysUtGSwidTQIdBQRhq0hj63KbseFujkn?usp=sharing)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/wham-reconstructing-world-grounded-humans/3d-human-pose-estimation-on-3dpw)](https://paperswithcode.com/sota/3d-human-pose-estimation-on-3dpw?p=wham-reconstructing-world-grounded-humans) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/wham-reconstructing-world-grounded-humans/3d-human-pose-estimation-on-emdb)](https://paperswithcode.com/sota/3d-human-pose-estimation-on-emdb?p=wham-reconstructing-world-grounded-humans)

https://github.com/yohanshin/WHAM/assets/46889727/da4602b4-0597-4e64-8da4-ab06931b23ee

## Introduction
This repository is the official [PyTorch](https://pytorch.org/) implementation of [WHAM: Reconstructing World-grounded Humans with Accurate 3D Motion](https://arxiv.org/abs/2312.07531). For more information, please visit our [project page](https://wham.is.tue.mpg.de/).

## Installation
Please see [Installation](docs/INSTALL.md) for details.

## Quick Demo

### [<img src="https://i.imgur.com/QCojoJk.png" width="30"> Google Colab for the WHAM demo is now available](https://colab.research.google.com/drive/1ysUtGSwidTQIdBQRhq0hj63KbseFujkn?usp=sharing)

### Registration

To download the SMPL body models (Neutral, Female, and Male), you need to register at [SMPL](https://smpl.is.tue.mpg.de/) and [SMPLify](https://smplify.is.tue.mpg.de/). The username and password for both sites will be used when fetching the demo data.

Next, run the following script to fetch the demo data. It downloads all required dependencies, including trained models and demo videos.

```bash
bash fetch_demo_data.sh
```

You can try it with one example video:
```bash
python demo.py --video examples/IMG_9732.mov --visualize
```

We assume the camera focal length following [CLIFF](https://github.com/haofanwang/CLIFF). You can also specify known camera intrinsics [fx fy cx cy] for SLAM, as in the demo example below:
```bash
python demo.py --video examples/drone_video.mp4 --calib examples/drone_calib.txt --visualize
```
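The calibration file holds your camera's four intrinsic values `fx fy cx cy`. One way to create such a file (illustrative values and placeholder paths only):
```bash
# Illustrative intrinsics only -- replace with your camera's actual calibration
echo "1400.0 1400.0 960.0 540.0" > examples/my_calib.txt
python demo.py --video examples/my_video.mp4 --calib examples/my_calib.txt --visualize
```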
You can skip SLAM if you only want to estimate motion in camera coordinates:
```bash
python demo.py --video examples/IMG_9732.mov --visualize --estimate_local_only
```

You can further refine WHAM's results with Temporal SMPLify as a post-processing step, which improves both 2D alignment and 3D accuracy. All you need to do is add the `--run_smplify` flag when running the demo.
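For example, to rerun the first demo video with the refinement enabled:
```bash
python demo.py --video examples/IMG_9732.mov --visualize --run_smplify
```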
## Docker

Please refer to [Docker](docs/DOCKER.md) for details.

## Python API

Please refer to [API](docs/API.md) for details.

## Dataset
Please see [Dataset](docs/DATASET.md) for details.

## Evaluation
```bash
# Evaluate on the 3DPW dataset
python -m lib.eval.evaluate_3dpw --cfg configs/yamls/demo.yaml TRAIN.CHECKPOINT checkpoints/wham_vit_w_3dpw.pth.tar

# Evaluate on the RICH dataset
python -m lib.eval.evaluate_rich --cfg configs/yamls/demo.yaml TRAIN.CHECKPOINT checkpoints/wham_vit_w_3dpw.pth.tar

# Evaluate on the EMDB dataset (also computes W-MPJPE and WA-MPJPE)
python -m lib.eval.evaluate_emdb --cfg configs/yamls/demo.yaml --eval-split 1 TRAIN.CHECKPOINT checkpoints/wham_vit_w_3dpw.pth.tar    # EMDB 1

python -m lib.eval.evaluate_emdb --cfg configs/yamls/demo.yaml --eval-split 2 TRAIN.CHECKPOINT checkpoints/wham_vit_w_3dpw.pth.tar    # EMDB 2
```

## Training
WHAM training involves two stages: (1) 2D-to-SMPL lifting on the AMASS dataset, and (2) fine-tuning with feature integration on the video datasets. Please see [Dataset](docs/DATASET.md) for preprocessing the training datasets.

### Stage 1.
```bash
python train.py --cfg configs/yamls/stage1.yaml
```

### Stage 2.
Training stage 2 requires the pretrained results from stage 1. You can use your own pretrained results, or download the weights from [Google Drive](https://drive.google.com/file/d/1Erjkho7O0bnZFawarntICRUCroaKabRE/view?usp=sharing) and save them as `checkpoints/wham_stage1.tar.pth`.
```bash
python train.py --cfg configs/yamls/stage2.yaml TRAIN.CHECKPOINT <PATH-TO-STAGE1-RESULTS>
```

### Train with BEDLAM
TBD

## Acknowledgement
We sincerely thank Hongwei Yi and Silvia Zuffi for the discussions and proofreading. Part of this work was done when Soyong Shin was an intern at the Max Planck Institute for Intelligent Systems.

The base implementation is largely borrowed from [VIBE](https://github.com/mkocabas/VIBE) and [TCMR](https://github.com/hongsukchoi/TCMR_RELEASE). We use [ViTPose](https://github.com/ViTAE-Transformer/ViTPose) for 2D keypoint detection and [DPVO](https://github.com/princeton-vl/DPVO) / [DROID-SLAM](https://github.com/princeton-vl/DROID-SLAM) for extracting camera motion. Please visit their official websites for more details.

## TODO

- [ ] Data preprocessing

- [x] Training implementation

- [x] Colab demo release

- [x] Demo for custom videos

## Citation
```
@InProceedings{shin2023wham,
  title={WHAM: Reconstructing World-grounded Humans with Accurate 3D Motion},
  author={Shin, Soyong and Kim, Juyong and Halilaj, Eni and Black, Michael J.},
  booktitle={Computer Vision and Pattern Recognition (CVPR)},
  year={2024}
}
```

## License
Please see [License](./LICENSE) for details.

## Contact
Please contact [email protected] for any questions related to this work.
lib/models/layers/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.08 kB).
 
lib/models/layers/modules.py ADDED
@@ -0,0 +1,262 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import numpy as np
from torch import nn
from configs import constants as _C
from .utils import rollout_global_motion
from lib.utils.transforms import axis_angle_to_matrix


class Regressor(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dims, init_dim, layer='LSTM', n_layers=2, n_iters=1):
        super().__init__()
        self.n_outs = len(out_dims)

        self.rnn = getattr(nn, layer.upper())(
            in_dim + init_dim, hid_dim, n_layers,
            bidirectional=False, batch_first=True, dropout=0.3)

        for i, out_dim in enumerate(out_dims):
            setattr(self, 'declayer%d' % i, nn.Linear(hid_dim, out_dim))
            nn.init.xavier_uniform_(getattr(self, 'declayer%d' % i).weight, gain=0.01)

    def forward(self, x, inits, h0):
        xc = torch.cat([x, *inits], dim=-1)
        xc, h0 = self.rnn(xc, h0)

        preds = []
        for j in range(self.n_outs):
            out = getattr(self, 'declayer%d' % j)(xc)
            preds.append(out)

        return preds, xc, h0


class NeuralInitialization(nn.Module):
    def __init__(self, in_dim, hid_dim, layer, n_layers):
        super().__init__()

        out_dim = hid_dim
        self.n_layers = n_layers
        self.num_inits = int(layer.upper() == 'LSTM') + 1
        out_dim *= self.num_inits * n_layers

        self.linear1 = nn.Linear(in_dim, hid_dim)
        self.linear2 = nn.Linear(hid_dim, hid_dim * self.n_layers)
        self.linear3 = nn.Linear(hid_dim * self.n_layers, out_dim)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

    def forward(self, x):
        b = x.shape[0]

        out = self.linear3(self.relu2(self.linear2(self.relu1(self.linear1(x)))))
        out = out.view(b, self.num_inits, self.n_layers, -1).permute(1, 2, 0, 3).contiguous()

        if self.num_inits == 2:
            return tuple([_ for _ in out])
        return out[0]


class Integrator(nn.Module):
    def __init__(self, in_channel, out_channel, hid_channel=1024):
        super().__init__()

        self.layer1 = nn.Linear(in_channel, hid_channel)
        self.relu1 = nn.ReLU()
        self.dr1 = nn.Dropout(0.1)

        self.layer2 = nn.Linear(hid_channel, hid_channel)
        self.relu2 = nn.ReLU()
        self.dr2 = nn.Dropout(0.1)

        self.layer3 = nn.Linear(hid_channel, out_channel)

    def forward(self, x, feat):
        res = x
        mask = (feat != 0).all(dim=-1).all(dim=-1)

        out = torch.cat((x, feat), dim=-1)
        out = self.layer1(out)
        out = self.relu1(out)
        out = self.dr1(out)

        out = self.layer2(out)
        out = self.relu2(out)
        out = self.dr2(out)

        out = self.layer3(out)
        out[mask] = out[mask] + res[mask]

        return out


class MotionEncoder(nn.Module):
    def __init__(self,
                 in_dim,
                 d_embed,
                 pose_dr,
                 rnn_type,
                 n_layers,
                 n_joints):
        super().__init__()

        self.n_joints = n_joints

        self.embed_layer = nn.Linear(in_dim, d_embed)
        self.pos_drop = nn.Dropout(pose_dr)

        # Keypoints initializer
        self.neural_init = NeuralInitialization(n_joints * 3 + in_dim, d_embed, rnn_type, n_layers)

        # 3d keypoints regressor
        self.regressor = Regressor(
            d_embed, d_embed, [n_joints * 3], n_joints * 3, rnn_type, n_layers)

    def forward(self, x, init):
        """ Forward pass of motion encoder.
        """

        self.b, self.f = x.shape[:2]
        x = self.embed_layer(x.reshape(self.b, self.f, -1))
        x = self.pos_drop(x)

        h0 = self.neural_init(init)
        pred_list = [init[..., :self.n_joints * 3]]
        motion_context_list = []

        for i in range(self.f):
            (pred_kp3d, ), motion_context, h0 = self.regressor(x[:, [i]], pred_list[-1:], h0)
            motion_context_list.append(motion_context)
            pred_list.append(pred_kp3d)

        pred_kp3d = torch.cat(pred_list[1:], dim=1).view(self.b, self.f, -1, 3)
        motion_context = torch.cat(motion_context_list, dim=1)

        # Merge 3D keypoints with motion context
        motion_context = torch.cat((motion_context, pred_kp3d.reshape(self.b, self.f, -1)), dim=-1)
        return pred_kp3d, motion_context


class TrajectoryDecoder(nn.Module):
    def __init__(self,
                 d_embed,
                 rnn_type,
                 n_layers):
        super().__init__()

        # Trajectory regressor
        self.regressor = Regressor(
            d_embed, d_embed, [3, 6], 12, rnn_type, n_layers)

    def forward(self, x, root, cam_a, h0=None):
        """ Forward pass of trajectory decoder.
        """

        b, f = x.shape[:2]
        pred_root_list, pred_vel_list = [root[:, :1]], []

        for i in range(f):
            # Global coordinate estimation
            (pred_rootv, pred_rootr), _, h0 = self.regressor(
                x[:, [i]], [pred_root_list[-1], cam_a[:, [i]]], h0)

            pred_root_list.append(pred_rootr)
            pred_vel_list.append(pred_rootv)

        pred_root = torch.cat(pred_root_list, dim=1).view(b, f + 1, -1)
        pred_vel = torch.cat(pred_vel_list, dim=1).view(b, f, -1)

        return pred_root, pred_vel


class MotionDecoder(nn.Module):
    def __init__(self,
                 d_embed,
                 rnn_type,
                 n_layers):
        super().__init__()

        self.n_pose = 24

        # SMPL pose initialization
        self.neural_init = NeuralInitialization(len(_C.BMODEL.MAIN_JOINTS) * 6, d_embed, rnn_type, n_layers)

        # 3d keypoints regressor
        self.regressor = Regressor(
            d_embed, d_embed, [self.n_pose * 6, 10, 3, 4], self.n_pose * 6, rnn_type, n_layers)

    def forward(self, x, init):
        """ Forward pass of motion decoder.
        """
        b, f = x.shape[:2]

        h0 = self.neural_init(init[:, :, _C.BMODEL.MAIN_JOINTS].reshape(b, 1, -1))

        # Recursive prediction of SMPL parameters
        pred_pose_list = [init.reshape(b, 1, -1)]
        pred_shape_list, pred_cam_list, pred_contact_list = [], [], []

        for i in range(f):
            # Camera coordinate estimation
            (pred_pose, pred_shape, pred_cam, pred_contact), _, h0 = self.regressor(x[:, [i]], pred_pose_list[-1:], h0)
            pred_pose_list.append(pred_pose)
            pred_shape_list.append(pred_shape)
            pred_cam_list.append(pred_cam)
            pred_contact_list.append(pred_contact)

        pred_pose = torch.cat(pred_pose_list[1:], dim=1).view(b, f, -1)
        pred_shape = torch.cat(pred_shape_list, dim=1).view(b, f, -1)
        pred_cam = torch.cat(pred_cam_list, dim=1).view(b, f, -1)
        pred_contact = torch.cat(pred_contact_list, dim=1).view(b, f, -1)

        return pred_pose, pred_shape, pred_cam, pred_contact


class TrajectoryRefiner(nn.Module):
    def __init__(self,
                 d_embed,
                 d_hidden,
                 rnn_type,
                 n_layers):
        super().__init__()

        d_input = d_embed + 12
        self.refiner = Regressor(
            d_input, d_hidden, [6, 3], 9, rnn_type, n_layers)

    def forward(self, context, pred_vel, output, cam_angvel, return_y_up):
        b, f = context.shape[:2]

        # Register values
        pred_root = output['poses_root_r6d'].clone().detach()
        feet = output['feet'].clone().detach()
        contact = output['contact'].clone().detach()

        feet_vel = torch.cat((torch.zeros_like(feet[:, :1]), feet[:, 1:] - feet[:, :-1]), dim=1) * 30  # Normalize to 30 times
        feet = (feet_vel * contact.unsqueeze(-1)).reshape(b, f, -1)  # Velocity input
        inpt_feat = torch.cat([context, feet], dim=-1)

        (delta_root, delta_vel), _, _ = self.refiner(inpt_feat, [pred_root[:, 1:], pred_vel], h0=None)
        pred_root[:, 1:] = pred_root[:, 1:] + delta_root
        pred_vel = pred_vel + delta_vel

        # root_world, trans_world = rollout_global_motion(pred_root, pred_vel)

        # if return_y_up:
        #     yup2ydown = axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float().to(root_world.device)
        #     root_world = yup2ydown.mT @ root_world
        #     trans_world = (yup2ydown.mT @ trans_world.unsqueeze(-1)).squeeze(-1)

        output.update({
            'poses_root_r6d_refined': pred_root,
            'vel_root_refined': pred_vel,
            # 'poses_root_world': root_world,
            # 'trans_world': trans_world,
        })

        return output
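For orientation: `Regressor` is the recurrent building block shared by the modules above. Each call consumes one frame of features concatenated with the previous step's prediction(s) and returns one decoded step per output head plus the updated hidden state. Below is a minimal usage sketch (not part of this commit) with made-up sizes; it assumes PyTorch is installed and this repository is on `PYTHONPATH`.

```python
import torch
from lib.models.layers.modules import Regressor

# Two output heads (e.g. a 3-D velocity and a 6-D rotation), conditioned on a
# 9-D previous-step signal. All sizes here are illustrative only.
reg = Regressor(in_dim=512, hid_dim=512, out_dims=[3, 6], init_dim=9,
                layer='LSTM', n_layers=2)

x = torch.randn(2, 1, 512)                           # one frame of features (B, 1, in_dim)
prev = [torch.randn(2, 1, 3), torch.randn(2, 1, 6)]  # previous-step inits, 3 + 6 = init_dim
preds, feat, h0 = reg(x, prev, h0=None)              # h0=None lets the LSTM initialize itself

print([p.shape for p in preds])                      # shapes (2, 1, 3) and (2, 1, 6)
```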
lib/models/layers/utils.py ADDED
@@ -0,0 +1,52 @@
import torch
from lib.utils import transforms


def rollout_global_motion(root_r, root_v, init_trans=None):
    b, f = root_v.shape[:2]
    root = transforms.rotation_6d_to_matrix(root_r[:])
    vel_world = (root[:, :-1] @ root_v.unsqueeze(-1)).squeeze(-1)
    trans = torch.cumsum(vel_world, dim=1)

    if init_trans is not None: trans = trans + init_trans
    return root[:, 1:], trans


def compute_camera_motion(output, root_c_d6d, root_w, trans, pred_cam):
    root_c = transforms.rotation_6d_to_matrix(root_c_d6d)  # Root orient in cam coord
    cam_R = root_c @ root_w.mT
    pelvis_cam = output.full_cam.view_as(pred_cam)
    pelvis_world = (cam_R.mT @ pelvis_cam.unsqueeze(-1)).squeeze(-1)
    cam_T_world = pelvis_world - trans
    cam_T = (cam_R @ cam_T_world.unsqueeze(-1)).squeeze(-1)

    return cam_R, cam_T


def compute_camera_pose(root_c_d6d, root_w):
    root_c = transforms.rotation_6d_to_matrix(root_c_d6d)  # Root orient in cam coord
    cam_R = root_c @ root_w.mT
    return cam_R


def reset_root_velocity(smpl, output, stationary, pred_ori, pred_vel, thr=0.7):
    b, f = pred_vel.shape[:2]

    stationary_mask = (stationary.clone().detach() > thr).unsqueeze(-1).float()
    poses_root = transforms.rotation_6d_to_matrix(pred_ori.clone().detach())
    vel_world = (poses_root[:, 1:] @ pred_vel.clone().detach().unsqueeze(-1)).squeeze(-1)

    output = smpl.get_output(body_pose=output.body_pose.clone().detach(),
                             global_orient=poses_root[:, 1:].reshape(-1, 1, 3, 3),
                             betas=output.betas.clone().detach(),
                             pose2rot=False)
    feet = output.feet.reshape(b, f, 4, 3)
    feet_vel = feet[:, 1:] - feet[:, :-1] + vel_world[:, 1:].unsqueeze(-2)
    feet_vel = torch.cat((torch.zeros_like(feet_vel[:, :1]), feet_vel), dim=1)

    stationary_vel = feet_vel * stationary_mask
    del_vel = stationary_vel.sum(dim=2) / ((stationary_vel != 0).sum(dim=2) + 1e-4)
    vel_world_update = vel_world - del_vel

    vel_root = (poses_root[:, 1:].mT @ vel_world_update.unsqueeze(-1)).squeeze(-1)

    return vel_root
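`rollout_global_motion` above converts per-frame 6-D root orientations and root-frame velocities into a world trajectory: each velocity is rotated into world coordinates and then cumulatively summed. A small sanity-check sketch (illustrative, not part of this commit; the 6-D encoding of the identity rotation is `[1, 0, 0, 0, 1, 0]`):

```python
import torch
from lib.models.layers.utils import rollout_global_motion

B, F = 1, 4
ident_6d = torch.tensor([1., 0., 0., 0., 1., 0.])
root_r = ident_6d.repeat(B, F + 1, 1)                   # F+1 orientations (first one is the initial frame)
root_v = torch.tensor([[0.1, 0., 0.]]).repeat(B, F, 1)  # constant +x velocity in the root frame

root_world, trans = rollout_global_motion(root_r, root_v)
print(trans[0, :, 0])  # tensor([0.1000, 0.2000, 0.3000, 0.4000]) -- cumulative displacement along +x
```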
lib/models/preproc/__pycache__/detector.cpython-39.pyc ADDED
Binary file (4.77 kB).
 
lib/models/preproc/__pycache__/extractor.cpython-39.pyc ADDED
Binary file (3.48 kB).
 
lib/models/preproc/__pycache__/slam.cpython-39.pyc ADDED
Binary file (2.6 kB).
 
lib/models/preproc/backbone/__pycache__/hmr2.cpython-39.pyc ADDED
Binary file (2.43 kB).
 
lib/models/preproc/backbone/__pycache__/pose_transformer.cpython-39.pyc ADDED
Binary file (10.8 kB).
 
lib/models/preproc/backbone/__pycache__/smpl_head.cpython-39.pyc ADDED
Binary file (4.46 kB).
 
lib/models/preproc/backbone/__pycache__/t_cond_mlp.cpython-39.pyc ADDED
Binary file (6.04 kB).
 
lib/models/preproc/backbone/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.59 kB).
 
lib/models/preproc/backbone/__pycache__/vit.cpython-39.pyc ADDED
Binary file (11.2 kB).
 
lib/models/preproc/backbone/hmr2.py ADDED
@@ -0,0 +1,77 @@
import os

import torch
import einops
import torch.nn as nn
# import pytorch_lightning as pl

from yacs.config import CfgNode
from .vit import vit
from .smpl_head import SMPLTransformerDecoderHead


# class HMR2(pl.LightningModule):
class HMR2(nn.Module):

    def __init__(self):
        """
        Setup HMR2 model
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode
        """
        super().__init__()

        # Create backbone feature extractor
        self.backbone = vit()

        # Create SMPL head
        self.smpl_head = SMPLTransformerDecoderHead()

    def decode(self, x):

        batch_size = x.shape[0]
        pred_smpl_params, pred_cam, _ = self.smpl_head(x)

        # Compute model vertices, joints and the projected joints
        pred_smpl_params['global_orient'] = pred_smpl_params['global_orient'].reshape(batch_size, -1, 3, 3)
        pred_smpl_params['body_pose'] = pred_smpl_params['body_pose'].reshape(batch_size, -1, 3, 3)
        pred_smpl_params['betas'] = pred_smpl_params['betas'].reshape(batch_size, -1)
        return pred_smpl_params['global_orient'], pred_smpl_params['body_pose'], pred_smpl_params['betas'], pred_cam

    def forward(self, x, encode=False, **kwargs):
        """
        Run a forward step of the network
        Args:
            batch (Dict): Dictionary containing batch data
            train (bool): Flag indicating whether it is training or validation mode
        Returns:
            Dict: Dictionary containing the regression output
        """

        # Use RGB image as input
        batch_size = x.shape[0]

        # Compute conditioning features using the backbone
        # if using ViT backbone, we need to use a different aspect ratio
        conditioning_feats = self.backbone(x[:,:,:,32:-32])
        if encode:
            conditioning_feats = einops.rearrange(conditioning_feats, 'b c h w -> b (h w) c')
            token = torch.zeros(batch_size, 1, 1).to(x.device)
            token_out = self.smpl_head.transformer(token, context=conditioning_feats)
            return token_out.squeeze(1)

        pred_smpl_params, pred_cam, _ = self.smpl_head(conditioning_feats)

        # Compute model vertices, joints and the projected joints
        pred_smpl_params['global_orient'] = pred_smpl_params['global_orient'].reshape(batch_size, -1, 3, 3)
        pred_smpl_params['body_pose'] = pred_smpl_params['body_pose'].reshape(batch_size, -1, 3, 3)
        pred_smpl_params['betas'] = pred_smpl_params['betas'].reshape(batch_size, -1)
        return pred_smpl_params['global_orient'], pred_smpl_params['body_pose'], pred_smpl_params['betas'], pred_cam


def hmr2(checkpoint_pth):
    model = HMR2()
    if os.path.exists(checkpoint_pth):
        model.load_state_dict(torch.load(checkpoint_pth, map_location='cpu')['state_dict'], strict=False)
        print(f'Load backbone weight: {checkpoint_pth}')
    return model
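`hmr2()` above builds the ViT-based image feature extractor used during preprocessing; calling the model with `encode=True` returns a single SMPL-head token per image instead of decoded SMPL parameters. A rough usage sketch (not part of this commit; the checkpoint path is a placeholder, and constructing the head needs the SMPL mean-parameter file fetched by `fetch_demo_data.sh`):

```python
import torch
from lib.models.preproc.backbone.hmr2 import hmr2

model = hmr2('checkpoints/hmr2_backbone.pth.tar').eval()  # placeholder path; loading is skipped if it does not exist
imgs = torch.randn(2, 3, 256, 256)                        # normalized crops; forward center-crops width 256 -> 192

with torch.no_grad():
    feats = model(imgs, encode=True)  # per-image token from the SMPL-head transformer
print(feats.shape)                    # expected (2, 1024) with the default head width
```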
lib/models/preproc/backbone/pose_transformer.py ADDED
@@ -0,0 +1,357 @@
1
+ from inspect import isfunction
2
+ from typing import Callable, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from einops.layers.torch import Rearrange
7
+ from torch import nn
8
+
9
+ from .t_cond_mlp import (
10
+ AdaptiveLayerNorm1D,
11
+ FrequencyEmbedder,
12
+ normalization_layer,
13
+ )
14
+ # from .vit import Attention, FeedForward
15
+
16
+
17
+ def exists(val):
18
+ return val is not None
19
+
20
+
21
+ def default(val, d):
22
+ if exists(val):
23
+ return val
24
+ return d() if isfunction(d) else d
25
+
26
+
27
+ class PreNorm(nn.Module):
28
+ def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
29
+ super().__init__()
30
+ self.norm = normalization_layer(norm, dim, norm_cond_dim)
31
+ self.fn = fn
32
+
33
+ def forward(self, x: torch.Tensor, *args, **kwargs):
34
+ if isinstance(self.norm, AdaptiveLayerNorm1D):
35
+ return self.fn(self.norm(x, *args), **kwargs)
36
+ else:
37
+ return self.fn(self.norm(x), **kwargs)
38
+
39
+
40
+ class FeedForward(nn.Module):
41
+ def __init__(self, dim, hidden_dim, dropout=0.0):
42
+ super().__init__()
43
+ self.net = nn.Sequential(
44
+ nn.Linear(dim, hidden_dim),
45
+ nn.GELU(),
46
+ nn.Dropout(dropout),
47
+ nn.Linear(hidden_dim, dim),
48
+ nn.Dropout(dropout),
49
+ )
50
+
51
+ def forward(self, x):
52
+ return self.net(x)
53
+
54
+
55
+ class Attention(nn.Module):
56
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
57
+ super().__init__()
58
+ inner_dim = dim_head * heads
59
+ project_out = not (heads == 1 and dim_head == dim)
60
+
61
+ self.heads = heads
62
+ self.scale = dim_head**-0.5
63
+
64
+ self.attend = nn.Softmax(dim=-1)
65
+ self.dropout = nn.Dropout(dropout)
66
+
67
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
68
+
69
+ self.to_out = (
70
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
71
+ if project_out
72
+ else nn.Identity()
73
+ )
74
+
75
+ def forward(self, x):
76
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
77
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
78
+
79
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
80
+
81
+ attn = self.attend(dots)
82
+ attn = self.dropout(attn)
83
+
84
+ out = torch.matmul(attn, v)
85
+ out = rearrange(out, "b h n d -> b n (h d)")
86
+ return self.to_out(out)
87
+
88
+
89
+ class CrossAttention(nn.Module):
90
+ def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
91
+ super().__init__()
92
+ inner_dim = dim_head * heads
93
+ project_out = not (heads == 1 and dim_head == dim)
94
+
95
+ self.heads = heads
96
+ self.scale = dim_head**-0.5
97
+
98
+ self.attend = nn.Softmax(dim=-1)
99
+ self.dropout = nn.Dropout(dropout)
100
+
101
+ context_dim = default(context_dim, dim)
102
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
103
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
104
+
105
+ self.to_out = (
106
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
107
+ if project_out
108
+ else nn.Identity()
109
+ )
110
+
111
+ def forward(self, x, context=None):
112
+ context = default(context, x)
113
+ k, v = self.to_kv(context).chunk(2, dim=-1)
114
+ q = self.to_q(x)
115
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
116
+
117
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
118
+
119
+ attn = self.attend(dots)
120
+ attn = self.dropout(attn)
121
+
122
+ out = torch.matmul(attn, v)
123
+ out = rearrange(out, "b h n d -> b n (h d)")
124
+ return self.to_out(out)
125
+
126
+
127
+ class Transformer(nn.Module):
128
+ def __init__(
129
+ self,
130
+ dim: int,
131
+ depth: int,
132
+ heads: int,
133
+ dim_head: int,
134
+ mlp_dim: int,
135
+ dropout: float = 0.0,
136
+ norm: str = "layer",
137
+ norm_cond_dim: int = -1,
138
+ ):
139
+ super().__init__()
140
+ self.layers = nn.ModuleList([])
141
+ for _ in range(depth):
142
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
143
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
144
+ self.layers.append(
145
+ nn.ModuleList(
146
+ [
147
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
148
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
149
+ ]
150
+ )
151
+ )
152
+
153
+ def forward(self, x: torch.Tensor, *args):
154
+ for attn, ff in self.layers:
155
+ x = attn(x, *args) + x
156
+ x = ff(x, *args) + x
157
+ return x
158
+
159
+
160
+ class TransformerCrossAttn(nn.Module):
161
+ def __init__(
162
+ self,
163
+ dim: int,
164
+ depth: int,
165
+ heads: int,
166
+ dim_head: int,
167
+ mlp_dim: int,
168
+ dropout: float = 0.0,
169
+ norm: str = "layer",
170
+ norm_cond_dim: int = -1,
171
+ context_dim: Optional[int] = None,
172
+ ):
173
+ super().__init__()
174
+ self.layers = nn.ModuleList([])
175
+ for _ in range(depth):
176
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
177
+ ca = CrossAttention(
178
+ dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
179
+ )
180
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
181
+ self.layers.append(
182
+ nn.ModuleList(
183
+ [
184
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
185
+ PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
186
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
187
+ ]
188
+ )
189
+ )
190
+
191
+ def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
192
+ if context_list is None:
193
+ context_list = [context] * len(self.layers)
194
+ if len(context_list) != len(self.layers):
195
+ raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
196
+
197
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
198
+ x = self_attn(x, *args) + x
199
+ x = cross_attn(x, *args, context=context_list[i]) + x
200
+ x = ff(x, *args) + x
201
+ return x
202
+
203
+
204
+ class DropTokenDropout(nn.Module):
205
+ def __init__(self, p: float = 0.1):
206
+ super().__init__()
207
+ if p < 0 or p > 1:
208
+ raise ValueError(
209
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
210
+ )
211
+ self.p = p
212
+
213
+ def forward(self, x: torch.Tensor):
214
+ # x: (batch_size, seq_len, dim)
215
+ if self.training and self.p > 0:
216
+ zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
217
+ # TODO: permutation idx for each batch using torch.argsort
218
+ if zero_mask.any():
219
+ x = x[:, ~zero_mask, :]
220
+ return x
221
+
222
+
223
+ class ZeroTokenDropout(nn.Module):
224
+ def __init__(self, p: float = 0.1):
225
+ super().__init__()
226
+ if p < 0 or p > 1:
227
+ raise ValueError(
228
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
229
+ )
230
+ self.p = p
231
+
232
+ def forward(self, x: torch.Tensor):
233
+ # x: (batch_size, seq_len, dim)
234
+ if self.training and self.p > 0:
235
+ zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
236
+ # Zero-out the masked tokens
237
+ x[zero_mask, :] = 0
238
+ return x
239
+
240
+
241
+ class TransformerEncoder(nn.Module):
242
+ def __init__(
243
+ self,
244
+ num_tokens: int,
245
+ token_dim: int,
246
+ dim: int,
247
+ depth: int,
248
+ heads: int,
249
+ mlp_dim: int,
250
+ dim_head: int = 64,
251
+ dropout: float = 0.0,
252
+ emb_dropout: float = 0.0,
253
+ emb_dropout_type: str = "drop",
254
+ emb_dropout_loc: str = "token",
255
+ norm: str = "layer",
256
+ norm_cond_dim: int = -1,
257
+ token_pe_numfreq: int = -1,
258
+ ):
259
+ super().__init__()
260
+ if token_pe_numfreq > 0:
261
+ token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
262
+ self.to_token_embedding = nn.Sequential(
263
+ Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
264
+ FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
265
+ Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
266
+ nn.Linear(token_dim_new, dim),
267
+ )
268
+ else:
269
+ self.to_token_embedding = nn.Linear(token_dim, dim)
270
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
271
+ if emb_dropout_type == "drop":
272
+ self.dropout = DropTokenDropout(emb_dropout)
273
+ elif emb_dropout_type == "zero":
274
+ self.dropout = ZeroTokenDropout(emb_dropout)
275
+ else:
276
+ raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
277
+ self.emb_dropout_loc = emb_dropout_loc
278
+
279
+ self.transformer = Transformer(
280
+ dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
281
+ )
282
+
283
+ def forward(self, inp: torch.Tensor, *args, **kwargs):
284
+ x = inp
285
+
286
+ if self.emb_dropout_loc == "input":
287
+ x = self.dropout(x)
288
+ x = self.to_token_embedding(x)
289
+
290
+ if self.emb_dropout_loc == "token":
291
+ x = self.dropout(x)
292
+ b, n, _ = x.shape
293
+ x += self.pos_embedding[:, :n]
294
+
295
+ if self.emb_dropout_loc == "token_afterpos":
296
+ x = self.dropout(x)
297
+ x = self.transformer(x, *args)
298
+ return x
299
+
300
+
301
+ class TransformerDecoder(nn.Module):
302
+ def __init__(
303
+ self,
304
+ num_tokens: int,
305
+ token_dim: int,
306
+ dim: int,
307
+ depth: int,
308
+ heads: int,
309
+ mlp_dim: int,
310
+ dim_head: int = 64,
311
+ dropout: float = 0.0,
312
+ emb_dropout: float = 0.0,
313
+ emb_dropout_type: str = 'drop',
314
+ norm: str = "layer",
315
+ norm_cond_dim: int = -1,
316
+ context_dim: Optional[int] = None,
317
+ skip_token_embedding: bool = False,
318
+ ):
319
+ super().__init__()
320
+ if not skip_token_embedding:
321
+ self.to_token_embedding = nn.Linear(token_dim, dim)
322
+ else:
323
+ self.to_token_embedding = nn.Identity()
324
+ if token_dim != dim:
325
+ raise ValueError(
326
+ f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
327
+ )
328
+
329
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
330
+ if emb_dropout_type == "drop":
331
+ self.dropout = DropTokenDropout(emb_dropout)
332
+ elif emb_dropout_type == "zero":
333
+ self.dropout = ZeroTokenDropout(emb_dropout)
334
+ elif emb_dropout_type == "normal":
335
+ self.dropout = nn.Dropout(emb_dropout)
336
+
337
+ self.transformer = TransformerCrossAttn(
338
+ dim,
339
+ depth,
340
+ heads,
341
+ dim_head,
342
+ mlp_dim,
343
+ dropout,
344
+ norm=norm,
345
+ norm_cond_dim=norm_cond_dim,
346
+ context_dim=context_dim,
347
+ )
348
+
349
+ def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
350
+ x = self.to_token_embedding(inp)
351
+ b, n, _ = x.shape
352
+
353
+ x = self.dropout(x)
354
+ x += self.pos_embedding[:, :n]
355
+
356
+ x = self.transformer(x, *args, context=context, context_list=context_list)
357
+ return x
lib/models/preproc/backbone/smpl_head.py ADDED
@@ -0,0 +1,128 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import einops
6
+
7
+ from configs import constants as _C
8
+ from lib.utils.transforms import axis_angle_to_matrix
9
+ from .pose_transformer import TransformerDecoder
10
+
11
+ def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
12
+ """
13
+ Convert 6D rotation representation to 3x3 rotation matrix.
14
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
15
+ Args:
16
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
17
+ Returns:
18
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
19
+ """
20
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
21
+ a1 = x[:, :, 0]
22
+ a2 = x[:, :, 1]
23
+ b1 = F.normalize(a1)
24
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
25
+ b3 = torch.cross(b1, b2)
26
+ return torch.stack((b1, b2, b3), dim=-1)
27
+
28
+ def build_smpl_head(cfg):
29
+ smpl_head_type = 'transformer_decoder'
30
+ if smpl_head_type == 'transformer_decoder':
31
+ return SMPLTransformerDecoderHead(cfg)
32
+ else:
33
+ raise ValueError('Unknown SMPL head type: {}'.format(smpl_head_type))
34
+
35
+ class SMPLTransformerDecoderHead(nn.Module):
36
+ """ Cross-attention based SMPL Transformer decoder
37
+ """
38
+
39
+ def __init__(self):
40
+ super().__init__()
41
+ self.joint_rep_type = '6d'
42
+ self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
43
+ npose = self.joint_rep_dim * 24
44
+ self.npose = npose
45
+ self.input_is_mean_shape = False
46
+ transformer_args = dict(
47
+ num_tokens=1,
48
+ token_dim=(npose + 10 + 3) if self.input_is_mean_shape else 1,
49
+ dim=1024,
50
+ )
51
+ transformer_args_from_cfg = dict(
52
+ depth=6, heads=8, mlp_dim=1024, dim_head=64, dropout=0.0, emb_dropout=0.0, norm='layer', context_dim=1280
53
+ )
54
+ transformer_args = (transformer_args | transformer_args_from_cfg)
55
+ self.transformer = TransformerDecoder(
56
+ **transformer_args
57
+ )
58
+ dim=transformer_args['dim']
59
+ self.decpose = nn.Linear(dim, npose)
60
+ self.decshape = nn.Linear(dim, 10)
61
+ self.deccam = nn.Linear(dim, 3)
62
+
63
+ mean_params = np.load(_C.BMODEL.MEAN_PARAMS)
64
+ init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
65
+ init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
66
+ init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
67
+ self.register_buffer('init_body_pose', init_body_pose)
68
+ self.register_buffer('init_betas', init_betas)
69
+ self.register_buffer('init_cam', init_cam)
70
+
71
+ def forward(self, x, **kwargs):
72
+
73
+ batch_size = x.shape[0]
74
+ # vit pretrained backbone is channel-first. Change to token-first
75
+
76
+ init_body_pose = self.init_body_pose.expand(batch_size, -1)
77
+ init_betas = self.init_betas.expand(batch_size, -1)
78
+ init_cam = self.init_cam.expand(batch_size, -1)
79
+
80
+ # TODO: Convert init_body_pose to aa rep if needed
81
+ if self.joint_rep_type == 'aa':
82
+ raise NotImplementedError
83
+
84
+ pred_body_pose = init_body_pose
85
+ pred_betas = init_betas
86
+ pred_cam = init_cam
87
+ pred_body_pose_list = []
88
+ pred_betas_list = []
89
+ pred_cam_list = []
90
+
91
+ # Input token to transformer is zero token
92
+ if len(x.shape) > 2:
93
+ x = einops.rearrange(x, 'b c h w -> b (h w) c')
94
+ if self.input_is_mean_shape:
95
+ token = torch.cat([pred_body_pose, pred_betas, pred_cam], dim=1)[:,None,:]
96
+ else:
97
+ token = torch.zeros(batch_size, 1, 1).to(x.device)
98
+
99
+ # Pass through transformer
100
+ token_out = self.transformer(token, context=x)
101
+ token_out = token_out.squeeze(1) # (B, C)
102
+ else:
103
+ token_out = x
104
+
105
+ # Readout from token_out
106
+ pred_body_pose = self.decpose(token_out) + pred_body_pose
107
+ pred_betas = self.decshape(token_out) + pred_betas
108
+ pred_cam = self.deccam(token_out) + pred_cam
109
+ pred_body_pose_list.append(pred_body_pose)
110
+ pred_betas_list.append(pred_betas)
111
+ pred_cam_list.append(pred_cam)
112
+
113
+ # Convert self.joint_rep_type -> rotmat
114
+ joint_conversion_fn = {
115
+ '6d': rot6d_to_rotmat,
116
+ 'aa': lambda x: axis_angle_to_matrix(x.view(-1, 3).contiguous())
117
+ }[self.joint_rep_type]
118
+
119
+ pred_smpl_params_list = {}
120
+ pred_smpl_params_list['body_pose'] = torch.cat([joint_conversion_fn(pbp).view(batch_size, -1, 3, 3)[:, 1:, :, :] for pbp in pred_body_pose_list], dim=0)
121
+ pred_smpl_params_list['betas'] = torch.cat(pred_betas_list, dim=0)
122
+ pred_smpl_params_list['cam'] = torch.cat(pred_cam_list, dim=0)
123
+ pred_body_pose = joint_conversion_fn(pred_body_pose).view(batch_size, 24, 3, 3)
124
+
125
+ pred_smpl_params = {'global_orient': pred_body_pose[:, [0]],
126
+ 'body_pose': pred_body_pose[:, 1:],
127
+ 'betas': pred_betas}
128
+ return pred_smpl_params, pred_cam, pred_smpl_params_list
lib/models/preproc/backbone/t_cond_mlp.py ADDED
@@ -0,0 +1,198 @@
1
+ import copy
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+
6
+
7
+ class AdaptiveLayerNorm1D(torch.nn.Module):
8
+ def __init__(self, data_dim: int, norm_cond_dim: int):
9
+ super().__init__()
10
+ if data_dim <= 0:
11
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
12
+ if norm_cond_dim <= 0:
13
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
14
+ self.norm = torch.nn.LayerNorm(
15
+ data_dim
16
+ ) # TODO: Check if elementwise_affine=True is correct
17
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
18
+ torch.nn.init.zeros_(self.linear.weight)
19
+ torch.nn.init.zeros_(self.linear.bias)
20
+
21
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
22
+ # x: (batch, ..., data_dim)
23
+ # t: (batch, norm_cond_dim)
24
+ # return: (batch, data_dim)
25
+ x = self.norm(x)
26
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
27
+
28
+ # Add singleton dimensions to alpha and beta
29
+ if x.dim() > 2:
30
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
31
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
32
+
33
+ return x * (1 + alpha) + beta
34
+
35
+
36
+ class SequentialCond(torch.nn.Sequential):
37
+ def forward(self, input, *args, **kwargs):
38
+ for module in self:
39
+ if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
40
+ # print(f'Passing on args to {module}', [a.shape for a in args])
41
+ input = module(input, *args, **kwargs)
42
+ else:
43
+ # print(f'Skipping passing args to {module}', [a.shape for a in args])
44
+ input = module(input)
45
+ return input
46
+
47
+
48
+ def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
49
+ if norm == "batch":
50
+ return torch.nn.BatchNorm1d(dim)
51
+ elif norm == "layer":
52
+ return torch.nn.LayerNorm(dim)
53
+ elif norm == "ada":
54
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
55
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
56
+ elif norm is None:
57
+ return torch.nn.Identity()
58
+ else:
59
+ raise ValueError(f"Unknown norm: {norm}")
60
+
61
+
62
+ def linear_norm_activ_dropout(
63
+ input_dim: int,
64
+ output_dim: int,
65
+ activation: torch.nn.Module = torch.nn.ReLU(),
66
+ bias: bool = True,
67
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
68
+ dropout: float = 0.0,
69
+ norm_cond_dim: int = -1,
70
+ ) -> SequentialCond:
71
+ layers = []
72
+ layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
73
+ if norm is not None:
74
+ layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
75
+ layers.append(copy.deepcopy(activation))
76
+ if dropout > 0.0:
77
+ layers.append(torch.nn.Dropout(dropout))
78
+ return SequentialCond(*layers)
79
+
80
+
81
+ def create_simple_mlp(
82
+ input_dim: int,
83
+ hidden_dims: List[int],
84
+ output_dim: int,
85
+ activation: torch.nn.Module = torch.nn.ReLU(),
86
+ bias: bool = True,
87
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
88
+ dropout: float = 0.0,
89
+ norm_cond_dim: int = -1,
90
+ ) -> SequentialCond:
91
+ layers = []
92
+ prev_dim = input_dim
93
+ for hidden_dim in hidden_dims:
94
+ layers.extend(
95
+ linear_norm_activ_dropout(
96
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
97
+ )
98
+ )
99
+ prev_dim = hidden_dim
100
+ layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
101
+ return SequentialCond(*layers)
102
+
103
+
104
+ class ResidualMLPBlock(torch.nn.Module):
105
+ def __init__(
106
+ self,
107
+ input_dim: int,
108
+ hidden_dim: int,
109
+ num_hidden_layers: int,
110
+ output_dim: int,
111
+ activation: torch.nn.Module = torch.nn.ReLU(),
112
+ bias: bool = True,
113
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
114
+ dropout: float = 0.0,
115
+ norm_cond_dim: int = -1,
116
+ ):
117
+ super().__init__()
118
+ if not (input_dim == output_dim == hidden_dim):
119
+ raise NotImplementedError(
120
+ f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
121
+ )
122
+
123
+ layers = []
124
+ prev_dim = input_dim
125
+ for i in range(num_hidden_layers):
126
+ layers.append(
127
+ linear_norm_activ_dropout(
128
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
129
+ )
130
+ )
131
+ prev_dim = hidden_dim
132
+ self.model = SequentialCond(*layers)
133
+ self.skip = torch.nn.Identity()
134
+
135
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
136
+ return x + self.model(x, *args, **kwargs)
137
+
138
+
139
+ class ResidualMLP(torch.nn.Module):
140
+ def __init__(
141
+ self,
142
+ input_dim: int,
143
+ hidden_dim: int,
144
+ num_hidden_layers: int,
145
+ output_dim: int,
146
+ activation: torch.nn.Module = torch.nn.ReLU(),
147
+ bias: bool = True,
148
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
149
+ dropout: float = 0.0,
150
+ num_blocks: int = 1,
151
+ norm_cond_dim: int = -1,
152
+ ):
153
+ super().__init__()
154
+ self.input_dim = input_dim
155
+ self.model = SequentialCond(
156
+ linear_norm_activ_dropout(
157
+ input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
158
+ ),
159
+ *[
160
+ ResidualMLPBlock(
161
+ hidden_dim,
162
+ hidden_dim,
163
+ num_hidden_layers,
164
+ hidden_dim,
165
+ activation,
166
+ bias,
167
+ norm,
168
+ dropout,
169
+ norm_cond_dim,
170
+ )
171
+ for _ in range(num_blocks)
172
+ ],
173
+ torch.nn.Linear(hidden_dim, output_dim, bias=bias),
174
+ )
175
+
176
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
177
+ return self.model(x, *args, **kwargs)
178
+
179
+
180
+ class FrequencyEmbedder(torch.nn.Module):
181
+ def __init__(self, num_frequencies, max_freq_log2):
182
+ super().__init__()
183
+ frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
184
+ self.register_buffer("frequencies", frequencies)
185
+
186
+ def forward(self, x):
187
+ # x should be of size (N,) or (N, D)
188
+ N = x.size(0)
189
+ if x.dim() == 1: # (N,)
190
+ x = x.unsqueeze(1) # (N, D) where D=1
191
+ x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
192
+ scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
193
+ s = torch.sin(scaled)
194
+ c = torch.cos(scaled)
195
+ embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
196
+ N, -1
197
+ ) # (N, D * 2 * num_frequencies + D)
198
+ return embedded
lib/models/preproc/backbone/utils.py ADDED
@@ -0,0 +1,115 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import os
6
+ import os.path as osp
7
+ from collections import OrderedDict
8
+
9
+ import cv2
10
+ import numpy as np
11
+ from skimage.filters import gaussian
12
+
13
+
14
+ def get_transform(center, scale, res, rot=0):
15
+ """Generate transformation matrix."""
16
+ # res: (height, width), (rows, cols)
17
+ crop_aspect_ratio = res[0] / float(res[1])
18
+ h = 200 * scale
19
+ w = h / crop_aspect_ratio
20
+ t = np.zeros((3, 3))
21
+ t[0, 0] = float(res[1]) / w
22
+ t[1, 1] = float(res[0]) / h
23
+ t[0, 2] = res[1] * (-float(center[0]) / w + .5)
24
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
25
+ t[2, 2] = 1
26
+ if not rot == 0:
27
+ rot = -rot # To match direction of rotation from cropping
28
+ rot_mat = np.zeros((3, 3))
29
+ rot_rad = rot * np.pi / 180
30
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
31
+ rot_mat[0, :2] = [cs, -sn]
32
+ rot_mat[1, :2] = [sn, cs]
33
+ rot_mat[2, 2] = 1
34
+ # Need to rotate around center
35
+ t_mat = np.eye(3)
36
+ t_mat[0, 2] = -res[1] / 2
37
+ t_mat[1, 2] = -res[0] / 2
38
+ t_inv = t_mat.copy()
39
+ t_inv[:2, 2] *= -1
40
+ t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
41
+ return t
42
+
43
+
44
+ def transform(pt, center, scale, res, invert=0, rot=0):
45
+ """Transform pixel location to different reference."""
46
+ t = get_transform(center, scale, res, rot=rot)
47
+ if invert:
48
+ t = np.linalg.inv(t)
49
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
50
+ new_pt = np.dot(t, new_pt)
51
+ return np.array([round(new_pt[0]), round(new_pt[1])], dtype=int) + 1
52
+
53
+
54
+ def crop(img, center, scale, res):
55
+ """
56
+ Crop image according to the supplied bounding box.
57
+ res: [rows, cols]
58
+ """
59
+ # Upper left point
60
+ ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
61
+ # Bottom right point
62
+ br = np.array(transform([res[1] + 1, res[0] + 1], center, scale, res, invert=1)) - 1
63
+
64
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
65
+ if len(img.shape) > 2:
66
+ new_shape += [img.shape[2]]
67
+ new_img = np.zeros(new_shape, dtype=np.float32)
68
+
69
+ # Range to fill new array
70
+ new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
71
+ new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
72
+ # Range to sample from original image
73
+ old_x = max(0, ul[0]), min(len(img[0]), br[0])
74
+ old_y = max(0, ul[1]), min(len(img), br[1])
75
+ try:
76
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
77
+ except Exception as e:
78
+ print(e)
79
+
80
+ new_img = cv2.resize(new_img, (res[1], res[0])) # (cols, rows)
81
+
82
+ return new_img, ul, br
83
+
84
+
85
+
86
+ def process_image(orig_img_rgb, center, scale, crop_height=256, crop_width=192, blur=False, do_crop=True):
87
+ """
88
+ Read image, do preprocessing and possibly crop it according to the bounding box.
89
+ If there are bounding box annotations, use them to crop the image.
90
+ If no bounding box is specified but openpose detections are available, use them to get the bounding box.
91
+ """
92
+
93
+ if blur:
94
+ # Blur image to avoid aliasing artifacts
95
+ downsampling_factor = ((scale * 200 * 1.0) / crop_height)
96
+ downsampling_factor = downsampling_factor / 2.0
97
+ if downsampling_factor > 1.1:
98
+ orig_img_rgb = gaussian(orig_img_rgb, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True)
99
+
100
+ IMG_NORM_MEAN = [0.485, 0.456, 0.406]
101
+ IMG_NORM_STD = [0.229, 0.224, 0.225]
102
+
103
+ if do_crop:
104
+ img, ul, br = crop(orig_img_rgb, center, scale, (crop_height, crop_width))
105
+ else:
106
+ img = orig_img_rgb.copy()
107
+ crop_img = img.copy()
108
+
109
+ img = img / 255.
110
+ mean = np.array(IMG_NORM_MEAN, dtype=np.float32)
111
+ std = np.array(IMG_NORM_STD, dtype=np.float32)
112
+ norm_img = (img - mean) / std
113
+ norm_img = np.transpose(norm_img, (2, 0, 1))
114
+
115
+ return norm_img, crop_img
lib/models/preproc/backbone/vit.py ADDED
@@ -0,0 +1,348 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+
4
+ import torch
5
+ from functools import partial
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint as checkpoint
9
+
10
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
11
+
12
+ def vit():
13
+ return ViT(
14
+ img_size=(256, 192),
15
+ patch_size=16,
16
+ embed_dim=1280,
17
+ depth=32,
18
+ num_heads=16,
19
+ ratio=1,
20
+ use_checkpoint=False,
21
+ mlp_ratio=4,
22
+ qkv_bias=True,
23
+ drop_path_rate=0.55,
24
+ )
25
+
26
+ def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
27
+ """
28
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
29
+ dimension for the original embeddings.
30
+ Args:
31
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
32
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
33
+ hw (Tuple): size of input image tokens.
34
+
35
+ Returns:
36
+ Absolute positional embeddings after processing with shape (1, H, W, C)
37
+ """
38
+ cls_token = None
39
+ B, L, C = abs_pos.shape
40
+ if has_cls_token:
41
+ cls_token = abs_pos[:, 0:1]
42
+ abs_pos = abs_pos[:, 1:]
43
+
44
+ if ori_h != h or ori_w != w:
45
+ new_abs_pos = F.interpolate(
46
+ abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
47
+ size=(h, w),
48
+ mode="bicubic",
49
+ align_corners=False,
50
+ ).permute(0, 2, 3, 1).reshape(B, -1, C)
51
+
52
+ else:
53
+ new_abs_pos = abs_pos
54
+
55
+ if cls_token is not None:
56
+ new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
57
+ return new_abs_pos
58
+
59
+ class DropPath(nn.Module):
60
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
61
+ """
62
+ def __init__(self, drop_prob=None):
63
+ super(DropPath, self).__init__()
64
+ self.drop_prob = drop_prob
65
+
66
+ def forward(self, x):
67
+ return drop_path(x, self.drop_prob, self.training)
68
+
69
+ def extra_repr(self):
70
+ return 'p={}'.format(self.drop_prob)
71
+
72
+ class Mlp(nn.Module):
73
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
74
+ super().__init__()
75
+ out_features = out_features or in_features
76
+ hidden_features = hidden_features or in_features
77
+ self.fc1 = nn.Linear(in_features, hidden_features)
78
+ self.act = act_layer()
79
+ self.fc2 = nn.Linear(hidden_features, out_features)
80
+ self.drop = nn.Dropout(drop)
81
+
82
+ def forward(self, x):
83
+ x = self.fc1(x)
84
+ x = self.act(x)
85
+ x = self.fc2(x)
86
+ x = self.drop(x)
87
+ return x
88
+
89
+ class Attention(nn.Module):
90
+ def __init__(
91
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
92
+ proj_drop=0., attn_head_dim=None,):
93
+ super().__init__()
94
+ self.num_heads = num_heads
95
+ head_dim = dim // num_heads
96
+ self.dim = dim
97
+
98
+ if attn_head_dim is not None:
99
+ head_dim = attn_head_dim
100
+ all_head_dim = head_dim * self.num_heads
101
+
102
+ self.scale = qk_scale or head_dim ** -0.5
103
+
104
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
105
+
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(all_head_dim, dim)
108
+ self.proj_drop = nn.Dropout(proj_drop)
109
+
110
+ def forward(self, x):
111
+ B, N, C = x.shape
112
+ qkv = self.qkv(x)
113
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
114
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
115
+
116
+ q = q * self.scale
117
+ attn = (q @ k.transpose(-2, -1))
118
+
119
+ attn = attn.softmax(dim=-1)
120
+ attn = self.attn_drop(attn)
121
+
122
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
123
+ x = self.proj(x)
124
+ x = self.proj_drop(x)
125
+
126
+ return x
127
+
128
+ class Block(nn.Module):
129
+
130
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
131
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
132
+ norm_layer=nn.LayerNorm, attn_head_dim=None
133
+ ):
134
+ super().__init__()
135
+
136
+ self.norm1 = norm_layer(dim)
137
+ self.attn = Attention(
138
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
139
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
140
+ )
141
+
142
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
143
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
144
+ self.norm2 = norm_layer(dim)
145
+ mlp_hidden_dim = int(dim * mlp_ratio)
146
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
147
+
148
+ def forward(self, x):
149
+ x = x + self.drop_path(self.attn(self.norm1(x)))
150
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
151
+ return x
152
+
153
+
154
+ class PatchEmbed(nn.Module):
155
+ """ Image to Patch Embedding
156
+ """
157
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
158
+ super().__init__()
159
+ img_size = to_2tuple(img_size)
160
+ patch_size = to_2tuple(patch_size)
161
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
162
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
163
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
164
+ self.img_size = img_size
165
+ self.patch_size = patch_size
166
+ self.num_patches = num_patches
167
+
168
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
169
+
170
+ def forward(self, x, **kwargs):
171
+ B, C, H, W = x.shape
172
+ x = self.proj(x)
173
+ Hp, Wp = x.shape[2], x.shape[3]
174
+
175
+ x = x.flatten(2).transpose(1, 2)
176
+ return x, (Hp, Wp)
177
+
178
+
179
+ class HybridEmbed(nn.Module):
180
+ """ CNN Feature Map Embedding
181
+ Extract feature map from CNN, flatten, project to embedding dim.
182
+ """
183
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
184
+ super().__init__()
185
+ assert isinstance(backbone, nn.Module)
186
+ img_size = to_2tuple(img_size)
187
+ self.img_size = img_size
188
+ self.backbone = backbone
189
+ if feature_size is None:
190
+ with torch.no_grad():
191
+ training = backbone.training
192
+ if training:
193
+ backbone.eval()
194
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
195
+ feature_size = o.shape[-2:]
196
+ feature_dim = o.shape[1]
197
+ backbone.train(training)
198
+ else:
199
+ feature_size = to_2tuple(feature_size)
200
+ feature_dim = self.backbone.feature_info.channels()[-1]
201
+ self.num_patches = feature_size[0] * feature_size[1]
202
+ self.proj = nn.Linear(feature_dim, embed_dim)
203
+
204
+ def forward(self, x):
205
+ x = self.backbone(x)[-1]
206
+ x = x.flatten(2).transpose(1, 2)
207
+ x = self.proj(x)
208
+ return x
209
+
210
+
211
+ class ViT(nn.Module):
212
+
213
+ def __init__(self,
214
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
215
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
216
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
217
+ frozen_stages=-1, ratio=1, last_norm=True,
218
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
219
+ ):
220
+ # Protect mutable default arguments
221
+ super(ViT, self).__init__()
222
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
223
+ self.num_classes = num_classes
224
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
225
+ self.frozen_stages = frozen_stages
226
+ self.use_checkpoint = use_checkpoint
227
+ self.patch_padding = patch_padding
228
+ self.freeze_attn = freeze_attn
229
+ self.freeze_ffn = freeze_ffn
230
+ self.depth = depth
231
+
232
+ if hybrid_backbone is not None:
233
+ self.patch_embed = HybridEmbed(
234
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
235
+ else:
236
+ self.patch_embed = PatchEmbed(
237
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
238
+ num_patches = self.patch_embed.num_patches
239
+
240
+ # since the pretraining model has class token
241
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
242
+
243
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
244
+
245
+ self.blocks = nn.ModuleList([
246
+ Block(
247
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
248
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
249
+ )
250
+ for i in range(depth)])
251
+
252
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
253
+
254
+ if self.pos_embed is not None:
255
+ trunc_normal_(self.pos_embed, std=.02)
256
+
257
+ self._freeze_stages()
258
+
259
+ def _freeze_stages(self):
260
+ """Freeze parameters."""
261
+ if self.frozen_stages >= 0:
262
+ self.patch_embed.eval()
263
+ for param in self.patch_embed.parameters():
264
+ param.requires_grad = False
265
+
266
+ for i in range(1, self.frozen_stages + 1):
267
+ m = self.blocks[i]
268
+ m.eval()
269
+ for param in m.parameters():
270
+ param.requires_grad = False
271
+
272
+ if self.freeze_attn:
273
+ for i in range(0, self.depth):
274
+ m = self.blocks[i]
275
+ m.attn.eval()
276
+ m.norm1.eval()
277
+ for param in m.attn.parameters():
278
+ param.requires_grad = False
279
+ for param in m.norm1.parameters():
280
+ param.requires_grad = False
281
+
282
+ if self.freeze_ffn:
283
+ self.pos_embed.requires_grad = False
284
+ self.patch_embed.eval()
285
+ for param in self.patch_embed.parameters():
286
+ param.requires_grad = False
287
+ for i in range(0, self.depth):
288
+ m = self.blocks[i]
289
+ m.mlp.eval()
290
+ m.norm2.eval()
291
+ for param in m.mlp.parameters():
292
+ param.requires_grad = False
293
+ for param in m.norm2.parameters():
294
+ param.requires_grad = False
295
+
296
+ def init_weights(self):
297
+ """Initialize the weights in backbone.
298
+ Args:
299
+ pretrained (str, optional): Path to pre-trained weights.
300
+ Defaults to None.
301
+ """
302
+ def _init_weights(m):
303
+ if isinstance(m, nn.Linear):
304
+ trunc_normal_(m.weight, std=.02)
305
+ if isinstance(m, nn.Linear) and m.bias is not None:
306
+ nn.init.constant_(m.bias, 0)
307
+ elif isinstance(m, nn.LayerNorm):
308
+ nn.init.constant_(m.bias, 0)
309
+ nn.init.constant_(m.weight, 1.0)
310
+
311
+ self.apply(_init_weights)
312
+
313
+ def get_num_layers(self):
314
+ return len(self.blocks)
315
+
316
+ @torch.jit.ignore
317
+ def no_weight_decay(self):
318
+ return {'pos_embed', 'cls_token'}
319
+
320
+ def forward_features(self, x):
321
+ B, C, H, W = x.shape
322
+ x, (Hp, Wp) = self.patch_embed(x)
323
+
324
+ if self.pos_embed is not None:
325
+ # fit for multiple GPU training
326
+ # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
327
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
328
+
329
+ for blk in self.blocks:
330
+ if self.use_checkpoint:
331
+ x = checkpoint.checkpoint(blk, x)
332
+ else:
333
+ x = blk(x)
334
+
335
+ x = self.last_norm(x)
336
+
337
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
338
+
339
+ return xp
340
+
341
+ def forward(self, x):
342
+ x = self.forward_features(x)
343
+ return x
344
+
345
+ def train(self, mode=True):
346
+ """Convert the model into training mode."""
347
+ super().train(mode)
348
+ self._freeze_stages()
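A quick shape check for the vit() factory above, assuming the package layout of this commit: with img_size=(256, 192) and patch_size=16 the patch grid is 16 x 12, so the feature map comes back as (B, 1280, 16, 12). The weights are random here (and this builds the full ViT-Huge, so it needs several GB of RAM), so the snippet only verifies wiring, not accuracy.

import torch
from lib.models.preproc.backbone.vit import vit

model = vit().eval()
with torch.no_grad():
    feat = model(torch.randn(1, 3, 256, 192))    # (B, C, H, W) person crop
print(feat.shape)                                # torch.Size([1, 1280, 16, 12])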
lib/models/preproc/detector.py ADDED
@@ -0,0 +1,146 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import os.path as osp
5
+ from collections import defaultdict
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import scipy.signal as signal
11
+ from progress.bar import Bar
12
+
13
+ from ultralytics import YOLO
14
+ from mmpose.apis import (
15
+ inference_top_down_pose_model,
16
+ init_pose_model,
17
+ get_track_id,
18
+ vis_pose_result,
19
+ )
20
+
21
+ ROOT_DIR = osp.abspath(f"{__file__}/../../../../")
22
+ VIT_DIR = osp.join(ROOT_DIR, "third-party/ViTPose")
23
+
24
+ VIS_THRESH = 0.3
25
+ BBOX_CONF = 0.5
26
+ TRACKING_THR = 0.1
27
+ MINIMUM_FRAMES = 30
28
+ MINIMUM_JOINTS = 6
29
+
30
+ class DetectionModel(object):
31
+ def __init__(self, device):
32
+
33
+ # ViTPose
34
+ pose_model_cfg = osp.join(VIT_DIR, 'configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_huge_coco_256x192.py')
35
+ pose_model_ckpt = osp.join(ROOT_DIR, 'checkpoints', 'vitpose-h-multi-coco.pth')
36
+ self.pose_model = init_pose_model(pose_model_cfg, pose_model_ckpt, device=device.lower())
37
+
38
+ # YOLO
39
+ bbox_model_ckpt = osp.join(ROOT_DIR, 'checkpoints', 'yolov8x.pt')
40
+ self.bbox_model = YOLO(bbox_model_ckpt)
41
+
42
+ self.device = device
43
+ self.initialize_tracking()
44
+
45
+ def initialize_tracking(self, ):
46
+ self.next_id = 0
47
+ self.frame_id = 0
48
+ self.pose_results_last = []
49
+ self.tracking_results = {
50
+ 'id': [],
51
+ 'frame_id': [],
52
+ 'bbox': [],
53
+ 'keypoints': []
54
+ }
55
+
56
+ def xyxy_to_cxcys(self, bbox, s_factor=1.05):
57
+ cx, cy = bbox[[0, 2]].mean(), bbox[[1, 3]].mean()
58
+ scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 200 * s_factor
59
+ return np.array([[cx, cy, scale]])
60
+
61
+ def compute_bboxes_from_keypoints(self, s_factor=1.2):
62
+ X = self.tracking_results['keypoints'].copy()
63
+ mask = X[..., -1] > VIS_THRESH
64
+
65
+ bbox = np.zeros((len(X), 3))
66
+ for i, (kp, m) in enumerate(zip(X, mask)):
67
+ bb = [kp[m, 0].min(), kp[m, 1].min(),
68
+ kp[m, 0].max(), kp[m, 1].max()]
69
+ cx, cy = [(bb[2]+bb[0])/2, (bb[3]+bb[1])/2]
70
+ bb_w = bb[2] - bb[0]
71
+ bb_h = bb[3] - bb[1]
72
+ s = np.stack((bb_w, bb_h)).max()
73
+ bb = np.array((cx, cy, s))
74
+ bbox[i] = bb
75
+
76
+ bbox[:, 2] = bbox[:, 2] * s_factor / 200.0
77
+ self.tracking_results['bbox'] = bbox
78
+
79
+ def track(self, img, fps, length):
80
+
81
+ # bbox detection
82
+ bboxes = self.bbox_model.predict(
83
+ img, device=self.device, classes=0, conf=BBOX_CONF, save=False, verbose=False
84
+ )[0].boxes.xyxy.detach().cpu().numpy()
85
+ bboxes = [{'bbox': bbox} for bbox in bboxes]
86
+
87
+ # keypoints detection
88
+ pose_results, returned_outputs = inference_top_down_pose_model(
89
+ self.pose_model,
90
+ img,
91
+ person_results=bboxes,
92
+ format='xyxy',
93
+ return_heatmap=False,
94
+ outputs=None)
95
+
96
+ # person identification
97
+ pose_results, self.next_id = get_track_id(
98
+ pose_results,
99
+ self.pose_results_last,
100
+ self.next_id,
101
+ use_oks=False,
102
+ tracking_thr=TRACKING_THR,
103
+ use_one_euro=True,
104
+ fps=fps)
105
+
106
+ for pose_result in pose_results:
107
+ n_valid = (pose_result['keypoints'][:, -1] > VIS_THRESH).sum()
108
+ if n_valid < MINIMUM_JOINTS: continue
109
+
110
+ _id = pose_result['track_id']
111
+ xyxy = pose_result['bbox']
112
+ bbox = self.xyxy_to_cxcys(xyxy)
113
+
114
+ self.tracking_results['id'].append(_id)
115
+ self.tracking_results['frame_id'].append(self.frame_id)
116
+ self.tracking_results['bbox'].append(bbox)
117
+ self.tracking_results['keypoints'].append(pose_result['keypoints'])
118
+
119
+ self.frame_id += 1
120
+ self.pose_results_last = pose_results
121
+
122
+ def process(self, fps):
123
+ for key in ['id', 'frame_id', 'keypoints']:
124
+ self.tracking_results[key] = np.array(self.tracking_results[key])
125
+ self.compute_bboxes_from_keypoints()
126
+
127
+ output = defaultdict(lambda: defaultdict(list))
128
+ ids = np.unique(self.tracking_results['id'])
129
+ for _id in ids:
130
+ idxs = np.where(self.tracking_results['id'] == _id)[0]
131
+ for key, val in self.tracking_results.items():
132
+ if key == 'id': continue
133
+ output[_id][key] = val[idxs]
134
+
135
+ # Smooth bounding box detection
136
+ ids = list(output.keys())
137
+ for _id in ids:
138
+ if len(output[_id]['bbox']) < MINIMUM_FRAMES:
139
+ del output[_id]
140
+ continue
141
+
142
+ kernel = int(int(fps/2) / 2) * 2 + 1
143
+ smoothed_bbox = np.array([signal.medfilt(param, kernel) for param in output[_id]['bbox'].T]).T
144
+ output[_id]['bbox'] = smoothed_bbox
145
+
146
+ return output
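The bounding-box smoothing at the end of process() uses a median filter whose window is derived from the frame rate; the rounding keeps the window odd, as scipy.signal.medfilt requires. A small sketch of that step in isolation, on random data rather than real detections:

import numpy as np
import scipy.signal as signal

for fps in (24, 30, 60):
    kernel = int(int(fps / 2) / 2) * 2 + 1
    print(fps, kernel)                            # 24 -> 13, 30 -> 15, 60 -> 31 (always odd)

bbox = np.random.rand(120, 3)                     # (frames, [cx, cy, scale]) dummy track
smoothed = np.array([signal.medfilt(param, 15) for param in bbox.T]).T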
lib/models/preproc/extractor.py ADDED
@@ -0,0 +1,112 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import os.path as osp
5
+ from collections import defaultdict
6
+
7
+ import cv2
8
+ import torch
9
+ import numpy as np
10
+ import scipy.signal as signal
11
+ from progress.bar import Bar
12
+ from scipy.ndimage.filters import gaussian_filter1d
13
+
14
+ from configs import constants as _C
15
+ from .backbone.hmr2 import hmr2
16
+ from .backbone.utils import process_image
17
+ from ...utils.imutils import flip_kp, flip_bbox
18
+
19
+ ROOT_DIR = osp.abspath(f"{__file__}/../../../../")
20
+
21
+ class FeatureExtractor(object):
22
+ def __init__(self, device, flip_eval=False, max_batch_size=64):
23
+
24
+ self.device = device
25
+ self.flip_eval = flip_eval
26
+ self.max_batch_size = max_batch_size
27
+
28
+ ckpt = osp.join(ROOT_DIR, 'checkpoints', 'hmr2a.ckpt')
29
+ self.model = hmr2(ckpt).to(device).eval()
30
+
31
+ def run(self, video, tracking_results, patch_h=256, patch_w=256):
32
+
33
+ if osp.isfile(video):
34
+ cap = cv2.VideoCapture(video)
35
+ is_video = True
36
+ length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
37
+ width, height = cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
38
+ else: # Image list
39
+ cap = video
40
+ is_video = False
41
+ length = len(video)
42
+ height, width = cv2.imread(video[0]).shape[:2]
43
+
44
+ frame_id = 0
45
+ bar = Bar('Feature extraction ...', fill='#', max=length)
46
+ while True:
47
+ if is_video:
48
+ flag, img = cap.read()
49
+ if not flag:
50
+ break
51
+ else:
52
+ if frame_id >= len(cap):
53
+ break
54
+ img = cv2.imread(cap[frame_id])
55
+
56
+ for _id, val in tracking_results.items():
57
+ if not frame_id in val['frame_id']: continue
58
+
59
+ frame_id2 = np.where(val['frame_id'] == frame_id)[0][0]
60
+ bbox = val['bbox'][frame_id2]
61
+ cx, cy, scale = bbox
62
+
63
+ norm_img, crop_img = process_image(img[..., ::-1], [cx, cy], scale, patch_h, patch_w)
64
+ norm_img = torch.from_numpy(norm_img).unsqueeze(0).to(self.device)
65
+ feature = self.model(norm_img, encode=True)
66
+ tracking_results[_id]['features'].append(feature.cpu())
67
+
68
+ if frame_id2 == 0: # First frame of this subject
69
+ tracking_results = self.predict_init(norm_img, tracking_results, _id, flip_eval=False)
70
+
71
+ if self.flip_eval:
72
+ flipped_bbox = flip_bbox(bbox, width, height)
73
+ tracking_results[_id]['flipped_bbox'].append(flipped_bbox)
74
+
75
+ keypoints = val['keypoints'][frame_id2]
76
+ flipped_keypoints = flip_kp(keypoints, width)
77
+ tracking_results[_id]['flipped_keypoints'].append(flipped_keypoints)
78
+
79
+ flipped_features = self.model(torch.flip(norm_img, (3, )), encode=True)
80
+ tracking_results[_id]['flipped_features'].append(flipped_features.cpu())
81
+
82
+ if frame_id2 == 0:
83
+ tracking_results = self.predict_init(torch.flip(norm_img, (3, )), tracking_results, _id, flip_eval=True)
84
+
85
+ bar.next()
86
+ frame_id += 1
87
+
88
+ return self.process(tracking_results)
89
+
90
+ def predict_init(self, norm_img, tracking_results, _id, flip_eval=False):
91
+ prefix = 'flipped_' if flip_eval else ''
92
+
93
+ pred_global_orient, pred_body_pose, pred_betas, _ = self.model(norm_img, encode=False)
94
+ tracking_results[_id][prefix + 'init_global_orient'] = pred_global_orient.cpu()
95
+ tracking_results[_id][prefix + 'init_body_pose'] = pred_body_pose.cpu()
96
+ tracking_results[_id][prefix + 'init_betas'] = pred_betas.cpu()
97
+ return tracking_results
98
+
99
+ def process(self, tracking_results):
100
+ output = defaultdict(dict)
101
+
102
+ for _id, results in tracking_results.items():
103
+
104
+ for key, val in results.items():
105
+ if isinstance(val, list):
106
+ if isinstance(val[0], torch.Tensor):
107
+ val = torch.cat(val)
108
+ elif isinstance(val[0], np.ndarray):
109
+ val = np.array(val)
110
+ output[_id][key] = val
111
+
112
+ return output
lib/models/preproc/slam.py ADDED
@@ -0,0 +1,70 @@
1
+ import cv2
2
+ import numpy as np
3
+ import glob
4
+ import os.path as osp
5
+ import os
6
+ import time
7
+ import torch
8
+ from pathlib import Path
9
+ from multiprocessing import Process, Queue
10
+
11
+ from dpvo.utils import Timer
12
+ from dpvo.dpvo import DPVO
13
+ from dpvo.config import cfg
14
+ from dpvo.stream import image_stream, video_stream
15
+
16
+ ROOT_DIR = osp.abspath(f"{__file__}/../../../../")
17
+ DPVO_DIR = osp.join(ROOT_DIR, "third-party/DPVO")
18
+
19
+
20
+ class SLAMModel(object):
21
+ def __init__(self, video, output_pth, width, height, calib=None, stride=1, skip=0, buffer=2048):
22
+
23
+ if calib is None or not osp.exists(calib):
24
+ calib = osp.join(output_pth, 'calib.txt')
25
+ if not osp.exists(calib):
26
+ self.estimate_intrinsics(width, height, calib)
27
+
28
+ self.dpvo_cfg = osp.join(DPVO_DIR, 'config/default.yaml')
29
+ self.dpvo_ckpt = osp.join(ROOT_DIR, 'checkpoints', 'dpvo.pth')
30
+
31
+ self.buffer = buffer
32
+ self.times = []
33
+ self.slam = None
34
+ self.queue = Queue(maxsize=8)
35
+ self.reader = Process(target=video_stream, args=(self.queue, video, calib, stride, skip))
36
+ self.reader.start()
37
+
38
+ def estimate_intrinsics(self, width, height, calib):
39
+ focal_length = (height ** 2 + width ** 2) ** 0.5
40
+ center_x = width / 2
41
+ center_y = height / 2
42
+
43
+ with open(calib, 'w') as fopen:
44
+ line = f'{focal_length} {focal_length} {center_x} {center_y}'
45
+ fopen.write(line)
46
+
47
+ def track(self, ):
48
+ (t, image, intrinsics) = self.queue.get()
49
+
50
+ if t < 0: return
51
+
52
+ image = torch.from_numpy(image).permute(2,0,1).cuda()
53
+ intrinsics = torch.from_numpy(intrinsics).cuda()
54
+
55
+ if self.slam is None:
56
+ cfg.merge_from_file(self.dpvo_cfg)
57
+ cfg.BUFFER_SIZE = self.buffer
58
+ self.slam = DPVO(cfg, self.dpvo_ckpt, ht=image.shape[1], wd=image.shape[2], viz=False)
59
+
60
+ with Timer("SLAM", enabled=False):
61
+ t = time.time()
62
+ self.slam(t, image, intrinsics)
63
+ self.times.append(time.time() - t)
64
+
65
+ def process(self, ):
66
+ for _ in range(12):
67
+ self.slam.update()
68
+
69
+ self.reader.join()
70
+ return self.slam.terminate()[0]
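When no calibration file is supplied, estimate_intrinsics above falls back to a pinhole camera whose focal length equals the image diagonal. Worked out for a 1920x1080 video, the calib.txt line would be roughly:

width, height = 1920, 1080
focal_length = (height ** 2 + width ** 2) ** 0.5      # ~2202.91 px
print(f'{focal_length} {focal_length} {width / 2} {height / 2}')
# -> approximately "2202.91 2202.91 960.0 540.0"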
lib/models/smpl.py ADDED
@@ -0,0 +1,264 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import os, sys
6
+
7
+ import torch
8
+ import numpy as np
9
+ from lib.utils import transforms
10
+
11
+ from smplx import SMPL as _SMPL
12
+ from smplx.utils import SMPLOutput as ModelOutput
13
+ from smplx.lbs import vertices2joints
14
+
15
+ from configs import constants as _C
16
+
17
+ class SMPL(_SMPL):
18
+ """ Extension of the official SMPL implementation to support more joints """
19
+
20
+ def __init__(self, *args, **kwargs):
21
+ sys.stdout = open(os.devnull, 'w')
22
+ super(SMPL, self).__init__(*args, **kwargs)
23
+ sys.stdout = sys.__stdout__
24
+
25
+ J_regressor_wham = np.load(_C.BMODEL.JOINTS_REGRESSOR_WHAM)
26
+ J_regressor_eval = np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M)
27
+ self.register_buffer('J_regressor_wham', torch.tensor(
28
+ J_regressor_wham, dtype=torch.float32))
29
+ self.register_buffer('J_regressor_eval', torch.tensor(
30
+ J_regressor_eval, dtype=torch.float32))
31
+ self.register_buffer('J_regressor_feet', torch.from_numpy(
32
+ np.load(_C.BMODEL.JOINTS_REGRESSOR_FEET)
33
+ ).float())
34
+
35
+ def get_local_pose_from_reduced_global_pose(self, reduced_pose):
36
+ full_pose = torch.eye(
37
+ 3, device=reduced_pose.device
38
+ )[(None, ) * 2].repeat(reduced_pose.shape[0], 24, 1, 1)
39
+ full_pose[:, _C.BMODEL.MAIN_JOINTS] = reduced_pose
40
+ return full_pose
41
+
42
+ def forward(self,
43
+ pred_rot6d,
44
+ betas,
45
+ cam=None,
46
+ cam_intrinsics=None,
47
+ bbox=None,
48
+ res=None,
49
+ return_full_pose=False,
50
+ **kwargs):
51
+
52
+ rotmat = transforms.rotation_6d_to_matrix(pred_rot6d.reshape(*pred_rot6d.shape[:2], -1, 6)
53
+ ).reshape(-1, 24, 3, 3)
54
+
55
+ output = self.get_output(body_pose=rotmat[:, 1:],
56
+ global_orient=rotmat[:, :1],
57
+ betas=betas.view(-1, 10),
58
+ pose2rot=False,
59
+ return_full_pose=return_full_pose)
60
+
61
+ if cam is not None:
62
+ joints3d = output.joints.reshape(*cam.shape[:2], -1, 3)
63
+
64
+ # Weak perspective projection (for InstaVariety)
65
+ weak_cam = convert_weak_perspective_to_perspective(cam)
66
+
67
+ weak_joints2d = weak_perspective_projection(
68
+ joints3d,
69
+ rotation=torch.eye(3, device=cam.device).unsqueeze(0).unsqueeze(0).expand(*cam.shape[:2], -1, -1),
70
+ translation=weak_cam,
71
+ focal_length=5000.,
72
+ camera_center=torch.zeros(*cam.shape[:2], 2, device=cam.device)
73
+ )
74
+ output.weak_joints2d = weak_joints2d
75
+
76
+ # Full perspective projection
77
+ full_cam = convert_pare_to_full_img_cam(
78
+ cam,
79
+ bbox[:, :, 2] * 200.,
80
+ bbox[:, :, :2],
81
+ res[:, 0].unsqueeze(-1),
82
+ res[:, 1].unsqueeze(-1),
83
+ focal_length=cam_intrinsics[:, :, 0, 0]
84
+ )
85
+
86
+ full_joints2d = full_perspective_projection(
87
+ joints3d,
88
+ translation=full_cam,
89
+ cam_intrinsics=cam_intrinsics,
90
+ )
91
+ output.full_joints2d = full_joints2d
92
+ output.full_cam = full_cam.reshape(-1, 3)
93
+
94
+ return output
95
+
96
+ def forward_nd(self,
97
+ pred_rot6d,
98
+ root,
99
+ betas,
100
+ return_full_pose=False):
101
+
102
+ rotmat = transforms.rotation_6d_to_matrix(pred_rot6d.reshape(*pred_rot6d.shape[:2], -1, 6)
103
+ ).reshape(-1, 24, 3, 3)
104
+
105
+ output = self.get_output(body_pose=rotmat[:, 1:],
106
+ global_orient=root.reshape(-1, 1, 3, 3),
107
+ betas=betas.view(-1, 10),
108
+ pose2rot=False,
109
+ return_full_pose=return_full_pose)
110
+
111
+ return output
112
+
113
+ def get_output(self, *args, **kwargs):
114
+ kwargs['get_skin'] = True
115
+ smpl_output = super(SMPL, self).forward(*args, **kwargs)
116
+ joints = vertices2joints(self.J_regressor_wham, smpl_output.vertices)
117
+ feet = vertices2joints(self.J_regressor_feet, smpl_output.vertices)
118
+
119
+ offset = joints[..., [11, 12], :].mean(-2)
120
+ if 'transl' in kwargs:
121
+ offset = offset - kwargs['transl']
122
+ vertices = smpl_output.vertices - offset.unsqueeze(-2)
123
+ joints = joints - offset.unsqueeze(-2)
124
+ feet = feet - offset.unsqueeze(-2)
125
+
126
+ output = ModelOutput(vertices=vertices,
127
+ global_orient=smpl_output.global_orient,
128
+ body_pose=smpl_output.body_pose,
129
+ joints=joints,
130
+ betas=smpl_output.betas,
131
+ full_pose=smpl_output.full_pose)
132
+ output.feet = feet
133
+ output.offset = offset
134
+ return output
135
+
136
+ def get_offset(self, *args, **kwargs):
137
+ kwargs['get_skin'] = True
138
+ smpl_output = super(SMPL, self).forward(*args, **kwargs)
139
+ joints = vertices2joints(self.J_regressor_wham, smpl_output.vertices)
140
+
141
+ offset = joints[..., [11, 12], :].mean(-2)
142
+ return offset
143
+
144
+ def get_faces(self):
145
+ return np.array(self.faces)
146
+
147
+
148
+ def convert_weak_perspective_to_perspective(
149
+ weak_perspective_camera,
150
+ focal_length=5000.,
151
+ img_res=224,
152
+ ):
153
+
154
+ perspective_camera = torch.stack(
155
+ [
156
+ weak_perspective_camera[..., 1],
157
+ weak_perspective_camera[..., 2],
158
+ 2 * focal_length / (img_res * weak_perspective_camera[..., 0] + 1e-9)
159
+ ],
160
+ dim=-1
161
+ )
162
+ return perspective_camera
163
+
164
+
165
+ def weak_perspective_projection(
166
+ points,
167
+ rotation,
168
+ translation,
169
+ focal_length,
170
+ camera_center,
171
+ img_res=224,
172
+ normalize_joints2d=True,
173
+ ):
174
+ """
175
+ This function computes the perspective projection of a set of points.
176
+ Input:
177
+ points (b, f, N, 3): 3D points
178
+ rotation (b, f, 3, 3): Camera rotation
179
+ translation (b, f, 3): Camera translation
180
+ focal_length (b, f,) or scalar: Focal length
181
+ camera_center (b, f, 2): Camera center
182
+ """
183
+
184
+ K = torch.zeros([*points.shape[:2], 3, 3], device=points.device)
185
+ K[:,:,0,0] = focal_length
186
+ K[:,:,1,1] = focal_length
187
+ K[:,:,2,2] = 1.
188
+ K[:,:,:-1, -1] = camera_center
189
+
190
+ # Transform points
191
+ points = torch.einsum('bfij,bfkj->bfki', rotation, points)
192
+ points = points + translation.unsqueeze(-2)
193
+
194
+ # Apply perspective distortion
195
+ projected_points = points / points[...,-1].unsqueeze(-1)
196
+
197
+ # Apply camera intrinsics
198
+ projected_points = torch.einsum('bfij,bfkj->bfki', K, projected_points)
199
+
200
+ if normalize_joints2d:
201
+ projected_points = projected_points / (img_res / 2.)
202
+
203
+ return projected_points[..., :-1]
204
+
205
+
206
+ def full_perspective_projection(
207
+ points,
208
+ cam_intrinsics,
209
+ rotation=None,
210
+ translation=None,
211
+ ):
212
+
213
+ K = cam_intrinsics
214
+
215
+ if rotation is not None:
216
+ points = (rotation @ points.transpose(-1, -2)).transpose(-1, -2)
217
+ if translation is not None:
218
+ points = points + translation.unsqueeze(-2)
219
+ projected_points = points / points[..., -1].unsqueeze(-1)
220
+ projected_points = (K @ projected_points.transpose(-1, -2)).transpose(-1, -2)
221
+ return projected_points[..., :-1]
222
+
223
+
224
+ def convert_pare_to_full_img_cam(
225
+ pare_cam,
226
+ bbox_height,
227
+ bbox_center,
228
+ img_w,
229
+ img_h,
230
+ focal_length,
231
+ crop_res=224
232
+ ):
233
+
234
+ s, tx, ty = pare_cam[..., 0], pare_cam[..., 1], pare_cam[..., 2]
235
+ res = crop_res
236
+ r = bbox_height / res
237
+ tz = 2 * focal_length / (r * res * s)
238
+
239
+ cx = 2 * (bbox_center[..., 0] - (img_w / 2.)) / (s * bbox_height)
240
+ cy = 2 * (bbox_center[..., 1] - (img_h / 2.)) / (s * bbox_height)
241
+
242
+ cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
243
+ return cam_t
244
+
245
+
246
+ def cam_crop2full(crop_cam, center, scale, full_img_shape, focal_length):
247
+ """
248
+ convert the camera parameters from the crop camera to the full camera
249
+ :param crop_cam: shape=(N, 3) weak perspective camera in cropped img coordinates (s, tx, ty)
250
+ :param center: shape=(N, 2) bbox coordinates (c_x, c_y)
251
+ :param scale: shape=(N) square bbox resolution (b / 200)
252
+ :param full_img_shape: shape=(N, 2) original image height and width
253
+ :param focal_length: shape=(N,)
254
+ :return:
255
+ """
256
+ img_h, img_w = full_img_shape[:, 0], full_img_shape[:, 1]
257
+ cx, cy, b = center[:, 0], center[:, 1], scale * 200
258
+ w_2, h_2 = img_w / 2., img_h / 2.
259
+ bs = b * crop_cam[:, 0] + 1e-9
260
+ tz = 2 * focal_length / bs
261
+ tx = (2 * (cx - w_2) / bs) + crop_cam[:, 1]
262
+ ty = (2 * (cy - h_2) / bs) + crop_cam[:, 2]
263
+ full_cam = torch.stack([tx, ty, tz], dim=-1)
264
+ return full_cam
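To make the crop-to-full camera conversion above concrete, this is the same arithmetic as cam_crop2full written out for a single, centered 200 px box; the numbers are illustrative, not taken from real data.

s, tx, ty = 1.0, 0.0, 0.0                    # weak-perspective camera in crop coords
cx, cy, b = 960.0, 540.0, 1.0 * 200          # bbox center and edge length in pixels
img_h, img_w, focal = 1080.0, 1920.0, 2000.0

tz = 2 * focal / (b * s)                          # 20.0: depth grows as the box shrinks
full_tx = 2 * (cx - img_w / 2) / (b * s) + tx     # 0.0 for a horizontally centered box
full_ty = 2 * (cy - img_h / 2) / (b * s) + ty     # 0.0
print(full_tx, full_ty, tz)                       # 0.0 0.0 20.0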
lib/models/smplify/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .smplify import TemporalSMPLify
lib/models/smplify/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (241 Bytes).
lib/models/smplify/__pycache__/losses.cpython-39.pyc ADDED
Binary file (2.63 kB).
lib/models/smplify/__pycache__/smplify.cpython-39.pyc ADDED
Binary file (2.09 kB).
lib/models/smplify/losses.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch
2
+
3
+ def gmof(x, sigma):
4
+ """
5
+ Geman-McClure error function
6
+ """
7
+ x_squared = x ** 2
8
+ sigma_squared = sigma ** 2
9
+ return (sigma_squared * x_squared) / (sigma_squared + x_squared)
10
+
11
+
12
+ def compute_jitter(x):
13
+ """
14
+ Compute jitter for the input tensor
15
+ """
16
+ return torch.linalg.norm(x[:, 2:] + x[:, :-2] - 2 * x[:, 1:-1], dim=-1)
17
+
18
+
19
+ class SMPLifyLoss(torch.nn.Module):
20
+ def __init__(self,
21
+ res,
22
+ cam_intrinsics,
23
+ init_pose,
24
+ device,
25
+ **kwargs
26
+ ):
27
+
28
+ super().__init__()
29
+
30
+ self.res = res
31
+ self.cam_intrinsics = cam_intrinsics
32
+ self.init_pose = torch.from_numpy(init_pose).float().to(device)
33
+
34
+ def forward(self, output, params, input_keypoints, bbox,
35
+ reprojection_weight=100., regularize_weight=60.0,
36
+ consistency_weight=10.0, sprior_weight=0.04,
37
+ smooth_weight=20.0, sigma=100):
38
+
39
+ pose, shape, cam = params
40
+ scale = bbox[..., 2:].unsqueeze(-1) * 200.
41
+
42
+ # Loss 1. Data term
43
+ pred_keypoints = output.full_joints2d[..., :17, :]
44
+ joints_conf = input_keypoints[..., -1:]
45
+ reprojection_error = gmof(pred_keypoints - input_keypoints[..., :-1], sigma)
46
+ reprojection_error = ((reprojection_error * joints_conf) / scale).mean()
47
+
48
+ # Loss 2. Regularization term
49
+ regularize_error = torch.linalg.norm(pose - self.init_pose, dim=-1).mean()
50
+
51
+ # Loss 3. Shape prior and consistency error
52
+ consistency_error = shape.std(dim=1).mean()
53
+ sprior_error = torch.linalg.norm(shape, dim=-1).mean()
54
+ shape_error = sprior_weight * sprior_error + consistency_weight * consistency_error
55
+
56
+ # Loss 4. Smooth loss
57
+ pose_diff = compute_jitter(pose).mean()
58
+ cam_diff = compute_jitter(cam).mean()
59
+ smooth_error = pose_diff + cam_diff
60
+
61
+ # Sum up losses
62
+ loss = {
63
+ 'reprojection': reprojection_weight * reprojection_error,
64
+ 'regularize': regularize_weight * regularize_error,
65
+ 'shape': shape_error,
66
+ 'smooth': smooth_weight * smooth_error
67
+ }
68
+
69
+ return loss
70
+
71
+ def create_closure(self,
72
+ optimizer,
73
+ smpl,
74
+ params,
75
+ bbox,
76
+ input_keypoints):
77
+
78
+ def closure():
79
+ optimizer.zero_grad()
80
+ output = smpl(*params, cam_intrinsics=self.cam_intrinsics, bbox=bbox, res=self.res)
81
+
82
+ loss_dict = self.forward(output, params, input_keypoints, bbox)
83
+ loss = sum(loss_dict.values())
84
+ loss.backward()
85
+ return loss
86
+
87
+ return closure
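The data term above relies on the Geman-McClure penalty: it behaves like a squared error for residuals well below sigma and saturates near sigma**2 for large residuals, so a few bad 2D detections cannot dominate the fit. A standalone sketch with the default sigma=100 (the function is re-declared here only to keep the snippet self-contained; it matches the definition above):

import torch

def gmof(x, sigma):
    return (sigma ** 2 * x ** 2) / (sigma ** 2 + x ** 2)

x = torch.tensor([1.0, 10.0, 100.0, 1000.0])
print(gmof(x, sigma=100.0))   # ~[1.0, 99.0, 5000.0, 9901.0], versus x**2 = [1, 100, 1e4, 1e6]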
lib/models/smplify/smplify.py ADDED
@@ -0,0 +1,83 @@
1
+ import os
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ from lib.models import build_body_model
6
+ from .losses import SMPLifyLoss
7
+
8
+ class TemporalSMPLify():
9
+
10
+ def __init__(self,
11
+ smpl=None,
12
+ lr=1e-2,
13
+ num_iters=5,
14
+ num_steps=10,
15
+ img_w=None,
16
+ img_h=None,
17
+ device=None
18
+ ):
19
+
20
+ self.smpl = smpl
21
+ self.lr = lr
22
+ self.num_iters = num_iters
23
+ self.num_steps = num_steps
24
+ self.img_w = img_w
25
+ self.img_h = img_h
26
+ self.device = device
27
+
28
+ def fit(self, init_pred, keypoints, bbox, **kwargs):
29
+
30
+ def to_params(param):
31
+ return torch.from_numpy(param).float().to(self.device).requires_grad_(True)
32
+
33
+ pose = init_pred['pose'].detach().cpu().numpy()
34
+ betas = init_pred['betas'].detach().cpu().numpy()
35
+ cam = init_pred['cam'].detach().cpu().numpy()
36
+ keypoints = torch.from_numpy(keypoints).float().unsqueeze(0).to(self.device)
37
+
38
+ BN = pose.shape[1]
39
+ lr = self.lr
40
+
41
+ # Stage 1. Optimize translation
42
+ params = [to_params(pose), to_params(betas), to_params(cam)]
43
+ optim_params = [params[2]]
44
+
45
+ optimizer = torch.optim.LBFGS(
46
+ optim_params,
47
+ lr=lr,
48
+ max_iter=self.num_iters,
49
+ line_search_fn='strong_wolfe')
50
+
51
+ loss_fn = SMPLifyLoss(init_pose=pose, device=self.device, **kwargs)
52
+
53
+ closure = loss_fn.create_closure(optimizer,
54
+ self.smpl,
55
+ params,
56
+ bbox,
57
+ keypoints)
58
+
59
+ for j in (j_bar := tqdm(range(self.num_steps), leave=False)):
60
+ optimizer.zero_grad()
61
+ loss = optimizer.step(closure)
62
+ msg = f'Loss: {loss.item():.1f}'
63
+ j_bar.set_postfix_str(msg)
64
+
65
+
66
+ # Stage 2. Optimize all params
67
+ optimizer = torch.optim.LBFGS(
68
+ params,
69
+ lr=lr * BN,
70
+ max_iter=self.num_iters,
71
+ line_search_fn='strong_wolfe')
72
+
73
+ for j in (j_bar := tqdm(range(self.num_steps), leave=False)):
74
+ optimizer.zero_grad()
75
+ loss = optimizer.step(closure)
76
+ msg = f'Loss: {loss.item():.1f}'
77
+ j_bar.set_postfix_str(msg)
78
+
79
+ init_pred['pose'] = params[0].detach()
80
+ init_pred['betas'] = params[1].detach()
81
+ init_pred['cam'] = params[2].detach()
82
+
83
+ return init_pred
lib/models/wham.py ADDED
@@ -0,0 +1,210 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import torch
6
+ from torch import nn
7
+ import numpy as np
8
+
9
+ from configs import constants as _C
10
+ from lib.models.layers import (MotionEncoder, MotionDecoder, TrajectoryDecoder, TrajectoryRefiner, Integrator,
11
+ rollout_global_motion, reset_root_velocity, compute_camera_motion)
12
+ from lib.utils.transforms import axis_angle_to_matrix
13
+
14
+
15
+ class Network(nn.Module):
16
+ def __init__(self,
17
+ smpl,
18
+ pose_dr=0.1,
19
+ d_embed=512,
20
+ n_layers=3,
21
+ d_feat=2048,
22
+ rnn_type='LSTM',
23
+ **kwargs
24
+ ):
25
+ super().__init__()
26
+
27
+ n_joints = _C.KEYPOINTS.NUM_JOINTS
28
+ self.smpl = smpl
29
+ in_dim = n_joints * 2 + 3
30
+ d_context = d_embed + n_joints * 3
31
+
32
+ self.mask_embedding = nn.Parameter(torch.zeros(1, 1, n_joints, 2))
33
+
34
+ # Module 1. Motion Encoder
35
+ self.motion_encoder = MotionEncoder(in_dim=in_dim,
36
+ d_embed=d_embed,
37
+ pose_dr=pose_dr,
38
+ rnn_type=rnn_type,
39
+ n_layers=n_layers,
40
+ n_joints=n_joints)
41
+
42
+ self.trajectory_decoder = TrajectoryDecoder(d_embed=d_context,
43
+ rnn_type=rnn_type,
44
+ n_layers=n_layers)
45
+
46
+ # Module 3. Feature Integrator
47
+ self.integrator = Integrator(in_channel=d_feat + d_context,
48
+ out_channel=d_context)
49
+
50
+ # Module 4. Motion Decoder
51
+ self.motion_decoder = MotionDecoder(d_embed=d_context,
52
+ rnn_type=rnn_type,
53
+ n_layers=n_layers)
54
+
55
+ # Module 5. Trajectory Refiner
56
+ self.trajectory_refiner = TrajectoryRefiner(d_embed=d_context,
57
+ d_hidden=d_embed,
58
+ rnn_type=rnn_type,
59
+ n_layers=2)
60
+
61
+ def compute_global_feet(self, root_world, trans):
62
+ # # Compute world-coordinate motion
63
+ cam_R, cam_T = compute_camera_motion(self.output, self.pred_pose[:, :, :6], root_world, trans, self.pred_cam)
64
+ feet_cam = self.output.feet.reshape(self.b, self.f, -1, 3) + self.output.full_cam.reshape(self.b, self.f, 1, 3)
65
+ feet_world = (cam_R.mT @ (feet_cam - cam_T.unsqueeze(-2)).mT).mT
66
+
67
+ return feet_world, cam_R
68
+
69
+ def forward_smpl(self, **kwargs):
70
+ self.output = self.smpl(self.pred_pose,
71
+ self.pred_shape,
72
+ cam=self.pred_cam,
73
+ return_full_pose=not self.training,
74
+ **kwargs,
75
+ )
76
+
77
+ from loguru import logger
78
+ logger.info(f"Output Joints: {self.output.joints}")
79
+ logger.info(f"Output Vertices: {self.output.vertices}")
80
+
81
+ # Save joints and vertices as .npy arrays
82
+
83
+ np.save('joints.npy', self.output.joints.cpu().numpy())
84
+ np.save('vertices.npy', self.output.vertices.cpu().numpy())
85
+
86
+ # Feet location in global coordinate
87
+ root_world, trans = rollout_global_motion(self.pred_root, self.pred_vel)
88
+ feet_world, cam_R = self.compute_global_feet(root_world, trans)
89
+
90
+ # Return output
91
+ output = {'feet': feet_world,
92
+ 'contact': self.pred_contact,
93
+ 'pose': self.pred_pose,
94
+ 'betas': self.pred_shape,
95
+ 'cam': self.pred_cam,
96
+ 'poses_root_cam': self.output.global_orient,
97
+ 'poses_root_r6d': self.pred_root,
98
+ 'vel_root': self.pred_vel,
99
+ 'pose_root': self.pred_root,
100
+ 'verts_cam': self.output.vertices}
101
+
102
+ if self.training:
103
+ output.update({
104
+ 'kp3d': self.output.joints,
105
+ 'kp3d_nn': self.pred_kp3d,
106
+ 'full_kp2d': self.output.full_joints2d,
107
+ 'weak_kp2d': self.output.weak_joints2d,
108
+ 'R': cam_R,
109
+ })
110
+ else:
111
+ output.update({
112
+ 'poses_root_r6d': self.pred_root,
113
+ 'trans_cam': self.output.full_cam,
114
+ 'poses_body': self.output.body_pose})
115
+
116
+ return output
117
+
118
+
119
+ def preprocess(self, x, mask):
120
+ self.b, self.f = x.shape[:2]
121
+
122
+ # Treat masked keypoints
123
+ mask_embedding = mask.unsqueeze(-1) * self.mask_embedding
124
+ _mask = mask.unsqueeze(-1).repeat(1, 1, 1, 2).reshape(self.b, self.f, -1)
125
+ _mask = torch.cat((_mask, torch.zeros_like(_mask[..., :3])), dim=-1)
126
+ _mask_embedding = mask_embedding.reshape(self.b, self.f, -1)
127
+ _mask_embedding = torch.cat((_mask_embedding, torch.zeros_like(_mask_embedding[..., :3])), dim=-1)
128
+ x[_mask] = 0.0
129
+ x = x + _mask_embedding
130
+ return x
131
+
132
+
133
+ def rollout(self, output, pred_root, pred_vel, return_y_up):
134
+ root_world, trans_world = rollout_global_motion(pred_root, pred_vel)
135
+
136
+ if return_y_up:
137
+ yup2ydown = axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float().to(root_world.device)
138
+ root_world = yup2ydown.mT @ root_world
139
+ trans_world = (yup2ydown.mT @ trans_world.unsqueeze(-1)).squeeze(-1)
140
+
141
+ output.update({
142
+ 'poses_root_world': root_world,
143
+ 'trans_world': trans_world,
144
+ })
145
+
146
+ return output
147
+
148
+
149
+ def refine_trajectory(self, output, cam_angvel, return_y_up, **kwargs):
150
+
151
+ # --------- Refine trajectory --------- #
152
+ update_vel = reset_root_velocity(self.smpl, self.output, self.pred_contact, self.pred_root, self.pred_vel, thr=0.5)
153
+ output = self.trajectory_refiner(self.old_motion_context, update_vel, output, cam_angvel, return_y_up=return_y_up)
154
+ # --------- #
155
+
156
+ # Do rollout
157
+ output = self.rollout(output, output['poses_root_r6d_refined'], output['vel_root_refined'], return_y_up)
158
+
159
+ # --------- Compute refined feet --------- #
160
+ if self.training:
161
+ feet_world, cam_R = self.compute_global_feet(output['poses_root_world'], output['trans_world'])
162
+ output.update({'feet_refined': feet_world})
163
+
164
+ return output
165
+
166
+
167
+ def forward(self, x, inits, img_features=None, mask=None, init_root=None, cam_angvel=None,
168
+ cam_intrinsics=None, bbox=None, res=None, return_y_up=False, refine_traj=True, **kwargs):
169
+
170
+ x = self.preprocess(x, mask)
171
+ init_kp, init_smpl = inits
172
+
173
+ # --------- Inference --------- #
174
+ # Stage 1. Encode motion
175
+ pred_kp3d, motion_context = self.motion_encoder(x, init_kp)
176
+ self.old_motion_context = motion_context.detach().clone()
177
+
178
+ # Stage 2. Decode global trajectory
179
+ pred_root, pred_vel = self.trajectory_decoder(motion_context, init_root, cam_angvel)
180
+
181
+ # Stage 3. Integrate features
182
+ if img_features is not None and self.integrator is not None:
183
+ motion_context = self.integrator(motion_context, img_features)
184
+
185
+ # Stage 4. Decode SMPL motion
186
+ pred_pose, pred_shape, pred_cam, pred_contact = self.motion_decoder(motion_context, init_smpl)
187
+ # --------- #
188
+
189
+ # --------- Register predictions --------- #
190
+ self.pred_kp3d = pred_kp3d
191
+ self.pred_root = pred_root
192
+ self.pred_vel = pred_vel
193
+ self.pred_pose = pred_pose
194
+ self.pred_shape = pred_shape
195
+ self.pred_cam = pred_cam
196
+ self.pred_contact = pred_contact
197
+ # --------- #
198
+
199
+ # --------- Build SMPL --------- #
200
+ output = self.forward_smpl(cam_intrinsics=cam_intrinsics, bbox=bbox, res=res)
201
+ # --------- #
202
+
203
+ # --------- Refine trajectory --------- #
204
+ if refine_traj:
205
+ output = self.refine_trajectory(output, cam_angvel, return_y_up)
206
+ else:
207
+ output = self.rollout(output, self.pred_root, self.pred_vel, return_y_up)
208
+ # --------- #
209
+
210
+ return output
lib/utils/__pycache__/data_utils.cpython-39.pyc ADDED
Binary file (3.57 kB).
lib/utils/__pycache__/imutils.cpython-39.pyc ADDED
Binary file (10.6 kB).
lib/utils/__pycache__/kp_utils.cpython-39.pyc ADDED
Binary file (9.99 kB).
lib/utils/__pycache__/transforms.cpython-39.pyc ADDED
Binary file (23.2 kB).
lib/utils/data_utils.py ADDED
@@ -0,0 +1,113 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ from lib.utils import transforms
9
+
10
+
11
+ def make_collate_fn():
12
+ def collate_fn(items):
13
+ items = list(filter(lambda x: x is not None , items))
14
+ batch = dict()
15
+ try: batch['vid'] = [item['vid'] for item in items]
16
+ except: pass
17
+ try: batch['gender'] = [item['gender'] for item in items]
18
+ except: pass
19
+ for key in items[0].keys():
20
+ try: batch[key] = torch.stack([item[key] for item in items])
21
+ except: pass
22
+ return batch
23
+
24
+ return collate_fn
25
+
26
+
27
+ def prepare_keypoints_data(target):
28
+ """Prepare keypoints data"""
29
+
30
+ # Prepare 2D keypoints
31
+ target['init_kp2d'] = target['kp2d'][:1]
32
+ target['kp2d'] = target['kp2d'][1:]
33
+ if 'kp3d' in target:
34
+ target['kp3d'] = target['kp3d'][1:]
35
+
36
+ return target
37
+
38
+
39
+ def prepare_smpl_data(target):
40
+ if 'pose' in target.keys():
41
+ # Use only the main joints
42
+ pose = target['pose'][:]
43
+ # 6-D Rotation representation
44
+ pose6d = transforms.matrix_to_rotation_6d(pose)
45
+ target['pose'] = pose6d[1:]
46
+
47
+ if 'betas' in target.keys():
48
+ target['betas'] = target['betas'][1:]
49
+
50
+ # Translation and shape parameters
51
+ if 'transl' in target.keys():
52
+ target['cam'] = target['transl'][1:]
53
+
54
+ # Initial pose and translation
55
+ target['init_pose'] = transforms.matrix_to_rotation_6d(target['init_pose'])
56
+
57
+ return target
58
+
59
+
60
+ def append_target(target, label, key_list, idx1, idx2=None, pad=True):
61
+ for key in key_list:
62
+ if idx2 is None: data = label[key][idx1]
63
+ else: data = label[key][idx1:idx2+1]
64
+ if not pad: data = data[2:]
65
+ target[key] = data
66
+
67
+ return target
68
+
69
+
70
+ def map_dmpl_to_smpl(pose):
71
+ """ Map AMASS DMPL pose representation to SMPL pose representation
72
+
73
+ Args:
74
+ pose - tensor / array with shape of (n_frames, 156)
75
+
76
+ Return:
77
+ pose - tensor / array with shape of (n_frames, 24, 3)
78
+ """
79
+
80
+ pose = pose.reshape(pose.shape[0], -1, 3)
81
+ pose[:, 23] = pose[:, 37] # right hand
82
+ if isinstance(pose, np.ndarray): pose = pose[:, :24].copy()
83
+ else: pose = pose[:, :24].clone()
84
+ return pose
85
+
86
+
87
+ def transform_global_coordinate(pose, T, transl=None):
88
+ """ Transform global coordinate of dataset with respect to the given matrix.
89
+ Various datasets have different global coordinate system,
90
+ thus we united all datasets to the cronical coordinate system.
91
+
92
+ Args:
93
+ pose - SMPL pose; tensor / array
94
+ T - Transformation matrix
95
+ transl - SMPL translation
96
+ """
97
+
98
+ return_to_numpy = False
99
+ if isinstance(pose, np.ndarray):
100
+ return_to_numpy = True
101
+ pose = torch.from_numpy(pose).float()
102
+ if transl is not None: transl = torch.from_numpy(transl).float()
103
+
104
+ pose = transforms.axis_angle_to_matrix(pose)
105
+ pose[:, 0] = T @ pose[:, 0]
106
+ pose = transforms.matrix_to_axis_angle(pose)
107
+ if transl is not None:
108
+ transl = (T @ transl.T).squeeze().T
109
+
110
+ if return_to_numpy:
111
+ pose = pose.detach().numpy()
112
+ if transl is not None: transl = transl.detach().numpy()
113
+ return pose, transl
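A small shape sketch for map_dmpl_to_smpl above: AMASS stores 52 joints as 156 axis-angle values, and the function keeps the first 24 joints after copying joint 37 into the right-hand slot. The import assumes the repository root is on PYTHONPATH; the input is dummy data.

import numpy as np
from lib.utils.data_utils import map_dmpl_to_smpl

amass_pose = np.zeros((5, 156), dtype=np.float32)    # 5 frames of dummy AMASS poses
smpl_pose = map_dmpl_to_smpl(amass_pose)
print(smpl_pose.shape)                               # (5, 24, 3)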
lib/utils/imutils.py ADDED
@@ -0,0 +1,363 @@
1
+ import cv2
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ from . import transforms
6
+
7
+ def do_augmentation(scale_factor=0.2, trans_factor=0.1):
8
+ scale = random.uniform(1.2 - scale_factor, 1.2 + scale_factor)
9
+ trans_x = random.uniform(-trans_factor, trans_factor)
10
+ trans_y = random.uniform(-trans_factor, trans_factor)
11
+
12
+ return scale, trans_x, trans_y
13
+
14
+ def get_transform(center, scale, res, rot=0):
15
+ """Generate transformation matrix."""
16
+ # res: (height, width), (rows, cols)
17
+ crop_aspect_ratio = res[0] / float(res[1])
18
+ h = 200 * scale
19
+ w = h / crop_aspect_ratio
20
+ t = np.zeros((3, 3))
21
+ t[0, 0] = float(res[1]) / w
22
+ t[1, 1] = float(res[0]) / h
23
+ t[0, 2] = res[1] * (-float(center[0]) / w + .5)
24
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
25
+ t[2, 2] = 1
26
+ if not rot == 0:
27
+ rot = -rot # To match direction of rotation from cropping
28
+ rot_mat = np.zeros((3, 3))
29
+ rot_rad = rot * np.pi / 180
30
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
31
+ rot_mat[0, :2] = [cs, -sn]
32
+ rot_mat[1, :2] = [sn, cs]
33
+ rot_mat[2, 2] = 1
34
+ # Need to rotate around center
35
+ t_mat = np.eye(3)
36
+ t_mat[0, 2] = -res[1] / 2
37
+ t_mat[1, 2] = -res[0] / 2
38
+ t_inv = t_mat.copy()
39
+ t_inv[:2, 2] *= -1
40
+ t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
41
+ return t
42
+
43
+
44
+ def transform(pt, center, scale, res, invert=0, rot=0):
45
+ """Transform pixel location to different reference."""
46
+ t = get_transform(center, scale, res, rot=rot)
47
+ if invert:
48
+ t = np.linalg.inv(t)
49
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
50
+ new_pt = np.dot(t, new_pt)
51
+ return np.array([round(new_pt[0]), round(new_pt[1])], dtype=int) + 1
52
+
53
+
54
+ def crop_cliff(img, center, scale, res):
55
+ """
56
+ Crop image according to the supplied bounding box.
57
+ res: [rows, cols]
58
+ """
59
+ # Upper left point
60
+ ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
61
+ # Bottom right point
62
+ br = np.array(transform([res[1] + 1, res[0] + 1], center, scale, res, invert=1)) - 1
63
+
64
+ # Padding so that when rotated proper amount of context is included
65
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
66
+
67
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
68
+ if len(img.shape) > 2:
69
+ new_shape += [img.shape[2]]
70
+ new_img = np.zeros(new_shape, dtype=np.float32)
71
+
72
+ # Range to fill new array
73
+ new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
74
+ new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
75
+ # Range to sample from original image
76
+ old_x = max(0, ul[0]), min(len(img[0]), br[0])
77
+ old_y = max(0, ul[1]), min(len(img), br[1])
78
+
79
+ try:
80
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
81
+ except Exception as e:
82
+ print(e)
83
+
84
+ new_img = cv2.resize(new_img, (res[1], res[0])) # (cols, rows)
85
+
86
+ return new_img, ul, br
87
+
88
+
89
+ def obtain_bbox(center, scale, res, org_res):
90
+ # Upper left point
91
+ ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
92
+ # Bottom right point
93
+ br = np.array(transform([res[1] + 1, res[0] + 1], center, scale, res, invert=1)) - 1
94
+
95
+ # Padding so that when rotated proper amount of context is included
96
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
97
+
98
+ # Range to sample from original image
99
+ old_x = max(0, ul[0]), min(org_res[0], br[0])
100
+ old_y = max(0, ul[1]), min(org_res[1], br[1])
101
+
102
+ return old_x, old_y
103
+
104
+
105
+ def cam_crop2full(crop_cam, bbox, full_img_shape, focal_length=None):
106
+ """
107
+ convert the camera parameters from the crop camera to the full camera
108
+ :param crop_cam: shape=(N, 3) weak perspective camera in cropped img coordinates (s, tx, ty)
109
+ :param center: shape=(N, 2) bbox coordinates (c_x, c_y)
110
+ :param scale: shape=(N, 1) square bbox resolution (b / 200)
111
+ :param full_img_shape: shape=(N, 2) original image height and width
112
+ :param focal_length: shape=(N,)
113
+ :return:
114
+ """
115
+
116
+ cx = bbox[..., 0].clone(); cy = bbox[..., 1].clone(); b = bbox[..., 2].clone() * 200
117
+ img_h, img_w = full_img_shape[:, 0], full_img_shape[:, 1]
118
+ w_2, h_2 = img_w / 2., img_h / 2.
119
+ bs = b * crop_cam[:, :, 0] + 1e-9
120
+
121
+ if focal_length is None:
122
+ focal_length = (img_w * img_w + img_h * img_h) ** 0.5
123
+
124
+ tz = 2 * focal_length.unsqueeze(-1) / bs
125
+ tx = (2 * (cx - w_2.unsqueeze(-1)) / bs) + crop_cam[:, :, 1]
126
+ ty = (2 * (cy - h_2.unsqueeze(-1)) / bs) + crop_cam[:, :, 2]
127
+ full_cam = torch.stack([tx, ty, tz], dim=-1)
128
+ return full_cam
129
+
130
+
131
+ def cam_pred2full(crop_cam, center, scale, full_img_shape, focal_length=2000.,):
132
+ """
133
+ Reference CLIFF: Carrying Location Information in Full Frames into Human Pose and Shape Estimation
134
+
135
+ convert the camera parameters from the crop camera to the full camera
136
+ :param crop_cam: shape=(N, 3) weak perspective camera in cropped img coordinates (s, tx, ty)
137
+ :param center: shape=(N, 2) bbox coordinates (c_x, c_y)
138
+ :param scale: shape=(N, ) square bbox resolution (b / 200)
139
+ :param full_img_shape: shape=(N, 2) original image height and width
140
+ :param focal_length: shape=(N,)
141
+ :return:
142
+ """
143
+
144
+ # img_h, img_w = full_img_shape[:, 0], full_img_shape[:, 1]
145
+ img_w, img_h = full_img_shape[:, 0], full_img_shape[:, 1]
146
+ cx, cy, b = center[:, 0], center[:, 1], scale * 200
147
+ w_2, h_2 = img_w / 2., img_h / 2.
148
+ bs = b * crop_cam[:, 0] + 1e-9
149
+ tz = 2 * focal_length / bs
150
+ tx = (2 * (cx - w_2) / bs) + crop_cam[:, 1]
151
+ ty = (2 * (cy - h_2) / bs) + crop_cam[:, 2]
152
+ full_cam = torch.stack([tx, ty, tz], dim=-1)
153
+ return full_cam
154
+
155
+
156
+ def cam_full2pred(full_cam, center, scale, full_img_shape, focal_length=2000.):
157
+ # img_h, img_w = full_img_shape[:, 0], full_img_shape[:, 1]
158
+ img_w, img_h = full_img_shape[:, 0], full_img_shape[:, 1]
159
+ cx, cy, b = center[:, 0], center[:, 1], scale * 200
160
+ w_2, h_2 = img_w / 2., img_h / 2.
161
+
162
+ bs = (2 * focal_length / full_cam[:, 2])
163
+ _s = bs / b
164
+ _tx = full_cam[:, 0] - (2 * (cx - w_2) / bs)
165
+ _ty = full_cam[:, 1] - (2 * (cy - h_2) / bs)
166
+ crop_cam = torch.stack([_s, _tx, _ty], dim=-1)
167
+ return crop_cam
168
+
169
+
170
+ def obtain_camera_intrinsics(image_shape, focal_length):
171
+ res_w = image_shape[..., 0].clone()
172
+ res_h = image_shape[..., 1].clone()
173
+ K = torch.eye(3).unsqueeze(0).expand(focal_length.shape[0], -1, -1).to(focal_length.device)
174
+ K[..., 0, 0] = focal_length.clone()
175
+ K[..., 1, 1] = focal_length.clone()
176
+ K[..., 0, 2] = res_w / 2
177
+ K[..., 1, 2] = res_h / 2
178
+
179
+ return K.unsqueeze(1)
180
+
181
+
182
+ def trans_point2d(pt_2d, trans):
183
+ src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T
184
+ dst_pt = np.dot(trans, src_pt)
185
+ return dst_pt[0:2]
186
+
187
+ def rotate_2d(pt_2d, rot_rad):
188
+ x = pt_2d[0]
189
+ y = pt_2d[1]
190
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
191
+ xx = x * cs - y * sn
192
+ yy = x * sn + y * cs
193
+ return np.array([xx, yy], dtype=np.float32)
194
+
195
+ def gen_trans_from_patch_cv(c_x, c_y, src_width, src_height, dst_width, dst_height, scale, rot, inv=False):
196
+ # augment size with scale
197
+ src_w = src_width * scale
198
+ src_h = src_height * scale
199
+ src_center = np.zeros(2)
200
+ src_center[0] = c_x
201
+ src_center[1] = c_y # np.array([c_x, c_y], dtype=np.float32)
202
+ # augment rotation
203
+ rot_rad = np.pi * rot / 180
204
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
205
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
206
+
207
+ dst_w = dst_width
208
+ dst_h = dst_height
209
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
210
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
211
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
212
+
213
+ src = np.zeros((3, 2), dtype=np.float32)
214
+ src[0, :] = src_center
215
+ src[1, :] = src_center + src_downdir
216
+ src[2, :] = src_center + src_rightdir
217
+
218
+ dst = np.zeros((3, 2), dtype=np.float32)
219
+ dst[0, :] = dst_center
220
+ dst[1, :] = dst_center + dst_downdir
221
+ dst[2, :] = dst_center + dst_rightdir
222
+
223
+ if inv:
224
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
225
+ else:
226
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
227
+
228
+ return trans
229
+
230
+ def transform_keypoints(kp_2d, bbox, patch_width, patch_height):
231
+
232
+ center_x, center_y, scale = bbox[:3]
233
+ width = height = scale * 200
234
+ # scale, rot = 1.2, 0
235
+ scale, rot = 1.0, 0
236
+
237
+ # generate transformation
238
+ trans = gen_trans_from_patch_cv(
239
+ center_x,
240
+ center_y,
241
+ width,
242
+ height,
243
+ patch_width,
244
+ patch_height,
245
+ scale,
246
+ rot,
247
+ inv=False,
248
+ )
249
+
250
+ for n_jt in range(kp_2d.shape[0]):
251
+ kp_2d[n_jt] = trans_point2d(kp_2d[n_jt], trans)
252
+
253
+ return kp_2d, trans
254
+
255
+
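A small usage sketch (not part of this commit) for the affine-crop helpers above, assuming this module's numpy/cv2 imports are available and that the import path below is correct:

import numpy as np
from lib.utils.data_utils import transform_keypoints  # assumed import path

bbox = np.array([320., 240., 1.1])                     # (center_x, center_y, scale); scale * 200 = bbox side
kp_2d = np.array([[320., 240.], [350., 200.]], dtype=np.float32)
kp_patch, trans = transform_keypoints(kp_2d.copy(), bbox, 256, 256)
# the bbox center maps to the patch center
assert np.allclose(kp_patch[0], [128., 128.], atol=1e-4)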
256
+ def transform(pt, center, scale, res, invert=0, rot=0):
257
+ """Transform pixel location to different reference."""
258
+ t = get_transform(center, scale, res, rot=rot)
259
+ if invert:
260
+ t = np.linalg.inv(t)
261
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
262
+ new_pt = np.dot(t, new_pt)
263
+ return new_pt[:2].astype(int) + 1
264
+
265
+
266
+ def compute_cam_intrinsics(res):
267
+ img_w, img_h = res
268
+ focal_length = (img_w * img_w + img_h * img_h) ** 0.5
269
+ cam_intrinsics = torch.eye(3).repeat(1, 1, 1).float()
270
+ cam_intrinsics[:, 0, 0] = focal_length
271
+ cam_intrinsics[:, 1, 1] = focal_length
272
+ cam_intrinsics[:, 0, 2] = img_w/2.
273
+ cam_intrinsics[:, 1, 2] = img_h/2.
274
+ return cam_intrinsics
275
+
276
+
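A quick sketch (not part of this commit) of the default intrinsics produced by compute_cam_intrinsics; the resolution and import path are assumed examples:

from lib.utils.data_utils import compute_cam_intrinsics  # assumed import path

K = compute_cam_intrinsics((1920, 1080))     # shape (1, 3, 3)
# focal length defaults to the image diagonal, principal point to the image center
print(K[0, 0, 0].item(), K[0, 0, 2].item(), K[0, 1, 2].item())   # ~2202.9, 960.0, 540.0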
277
+ def flip_kp(kp, img_w=None):
278
+ """Flip keypoints."""
279
+
280
+ flipped_parts = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
281
+ kp = kp[..., flipped_parts, :]
282
+
283
+ if img_w is not None:
284
+ # Assume 2D keypoints
285
+ kp[...,0] = img_w - kp[...,0]
286
+ return kp
287
+
288
+
289
+ def flip_bbox(bbox, img_w, img_h):
290
+ center = bbox[..., :2]
291
+ scale = bbox[..., -1:]
292
+
293
+ WH = np.ones_like(center)
294
+ WH[..., 0] *= img_w
295
+ WH[..., 1] *= img_h
296
+
297
+ center = center - WH/2
298
+ center[...,0] = - center[...,0]
299
+ center = center + WH/2
300
+
301
+ flipped_bbox = np.concatenate((center, scale), axis=-1)
302
+ return flipped_bbox
303
+
304
+
305
+ def flip_pose(rotation, representation='rotation_6d'):
306
+ """Flip pose.
307
+ The flipping is based on SMPL parameters.
308
+ """
309
+
310
+ BN = rotation.shape[0]
311
+
312
+ if representation == 'axis_angle':
313
+ pose = rotation.reshape(BN, -1).transpose(0, 1)
314
+ elif representation == 'matrix':
315
+ pose = transforms.matrix_to_axis_angle(rotation).reshape(BN, -1).transpose(0, 1)
316
+ elif representation == 'rotation_6d':
317
+ pose = transforms.matrix_to_axis_angle(
318
+ transforms.rotation_6d_to_matrix(rotation)
319
+ ).reshape(BN, -1).transpose(0, 1)
320
+ else:
321
+ raise ValueError(f"Unknown representation: {representation}")
322
+
323
+ SMPL_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22]
324
+ SMPL_POSE_FLIP_PERM = []
325
+ for i in SMPL_JOINTS_FLIP_PERM:
326
+ SMPL_POSE_FLIP_PERM.append(3*i)
327
+ SMPL_POSE_FLIP_PERM.append(3*i+1)
328
+ SMPL_POSE_FLIP_PERM.append(3*i+2)
329
+
330
+ pose = pose[SMPL_POSE_FLIP_PERM]
331
+
332
+ # we also negate the second and the third dimension of the axis-angle
333
+ pose[1::3] = -pose[1::3]
334
+ pose[2::3] = -pose[2::3]
335
+ pose = pose.transpose(0, 1).reshape(BN, -1, 3)
336
+
337
+ if representation == 'axis_angle':
338
+ return pose
339
+ elif representation == 'matrix':
340
+ return transforms.axis_angle_to_matrix(pose)
341
+ else:
342
+ return transforms.matrix_to_rotation_6d(
343
+ transforms.axis_angle_to_matrix(pose)
344
+ )
345
+
346
+ def avg_preds(rotation, shape, flipped_rotation, flipped_shape, representation='rotation_6d'):
347
+ # Rotation
348
+ flipped_rotation = flip_pose(flipped_rotation, representation=representation)
349
+
350
+ if representation != 'matrix':
351
+ flipped_rotation = eval(f'transforms.{representation}_to_matrix')(flipped_rotation)
352
+ rotation = eval(f'transforms.{representation}_to_matrix')(rotation)
353
+
354
+ avg_rotation = torch.stack([rotation, flipped_rotation])
355
+ avg_rotation = transforms.avg_rot(avg_rotation)
356
+
357
+ if representation != 'matrix':
358
+ avg_rotation = eval(f'transforms.matrix_to_{representation}')(avg_rotation)
359
+
360
+ # Shape
361
+ avg_shape = (shape + flipped_shape) / 2.0
362
+
363
+ return avg_rotation, avg_shape
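A hypothetical flip-augmentation sketch (not part of this commit) showing how flip_pose and avg_preds fit together; it assumes lib.utils.transforms is importable under the paths below and that the "flipped prediction" is simulated rather than produced by a network:

import torch
from lib.utils import transforms                       # assumed import path
from lib.utils.data_utils import flip_pose, avg_preds  # assumed import path

B = 2
rot = transforms.matrix_to_rotation_6d(transforms.random_rotations(B * 24).reshape(B, 24, 3, 3))
shape = torch.randn(B, 10) * 0.1

# simulate the output of a horizontally flipped pass
flipped_rot = flip_pose(rot, representation='rotation_6d')
avg_rot6d, avg_shape = avg_preds(rot, shape, flipped_rot, shape, representation='rotation_6d')
# un-flipping the simulated prediction recovers the original pose, so the average matches it
assert torch.allclose(avg_rot6d, rot, atol=1e-4)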
lib/utils/kp_utils.py ADDED
@@ -0,0 +1,761 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ from configs import constants as _C
9
+
10
+ def root_centering(X, joint_type='coco'):
11
+ """Center the root joint to the pelvis."""
12
+ if joint_type != 'common' and X.shape[-2] == 14: return X
13
+
14
+ conf = None
15
+ if X.shape[-1] == 4:
16
+ conf = X[..., -1:]
17
+ X = X[..., :-1]
18
+
19
+ if X.shape[-2] == 31:
20
+ X[..., :17, :] = X[..., :17, :] - X[..., [12, 11], :].mean(-2, keepdims=True)
21
+ X[..., 17:, :] = X[..., 17:, :] - X[..., [19, 20], :].mean(-2, keepdims=True)
22
+
23
+ elif joint_type == 'coco':
24
+ X = X - X[..., [12, 11], :].mean(-2, keepdims=True)
25
+
26
+ elif joint_type == 'common':
27
+ X = X - X[..., [2, 3], :].mean(-2, keepdims=True)
28
+
29
+ if conf is not None:
30
+ X = torch.cat((X, conf), dim=-1)
31
+
32
+ return X
33
+
34
+
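A brief sketch (not part of this commit) of root_centering on COCO-ordered joints; the shapes and import path are illustrative assumptions:

import torch
from lib.utils.kp_utils import root_centering  # assumed import path

kp3d = torch.randn(2, 81, 17, 3)                        # (batch, frames, 17 COCO joints, xyz)
centered = root_centering(kp3d.clone(), joint_type='coco')
pelvis = centered[..., [12, 11], :].mean(-2)            # midpoint of the hips
assert torch.allclose(pelvis, torch.zeros_like(pelvis), atol=1e-5)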
35
+ def convert_kps(joints2d, src, dst):
36
+ src_names = eval(f'get_{src}_joint_names')()
37
+ dst_names = eval(f'get_{dst}_joint_names')()
38
+
39
+ if isinstance(joints2d, np.ndarray):
40
+ out_joints2d = np.zeros((*joints2d.shape[:-2], len(dst_names), joints2d.shape[-1]))
41
+ else:
42
+ out_joints2d = torch.zeros((*joints2d.shape[:-2], len(dst_names), joints2d.shape[-1]), device=joints2d.device)
43
+
44
+ for idx, jn in enumerate(dst_names):
45
+ if jn in src_names:
46
+ out_joints2d[..., idx, :] = joints2d[..., src_names.index(jn), :]
47
+
48
+ return out_joints2d
49
+
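A usage sketch (not part of this commit) mapping COCO-ordered keypoints to the 14-joint 'common' set; array shapes and the import path are illustrative assumptions:

import numpy as np
from lib.utils.kp_utils import convert_kps  # assumed import path

coco_kp = np.random.rand(1, 17, 3)                      # (batch, joints, x/y/conf)
common_kp = convert_kps(coco_kp, src='coco', dst='common')
print(common_kp.shape)                                   # (1, 14, 3)
# joints missing from the source (here: neck, headtop) stay zero-filled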
50
+ def get_perm_idxs(src, dst):
51
+ src_names = eval(f'get_{src}_joint_names')()
52
+ dst_names = eval(f'get_{dst}_joint_names')()
53
+ idxs = [src_names.index(h) for h in dst_names if h in src_names]
54
+ return idxs
55
+
56
+ def get_mpii3d_test_joint_names():
57
+ return [
58
+ 'headtop', # 'head_top',
59
+ 'neck',
60
+ 'rshoulder',# 'right_shoulder',
61
+ 'relbow',# 'right_elbow',
62
+ 'rwrist',# 'right_wrist',
63
+ 'lshoulder',# 'left_shoulder',
64
+ 'lelbow', # 'left_elbow',
65
+ 'lwrist', # 'left_wrist',
66
+ 'rhip', # 'right_hip',
67
+ 'rknee', # 'right_knee',
68
+ 'rankle',# 'right_ankle',
69
+ 'lhip',# 'left_hip',
70
+ 'lknee',# 'left_knee',
71
+ 'lankle',# 'left_ankle'
72
+ 'hip',# 'pelvis',
73
+ 'Spine (H36M)',# 'spine',
74
+ 'Head (H36M)',# 'head'
75
+ ]
76
+
77
+ def get_mpii3d_joint_names():
78
+ return [
79
+ 'spine3', # 0,
80
+ 'spine4', # 1,
81
+ 'spine2', # 2,
82
+ 'Spine (H36M)', #'spine', # 3,
83
+ 'hip', # 'pelvis', # 4,
84
+ 'neck', # 5,
85
+ 'Head (H36M)', # 'head', # 6,
86
+ "headtop", # 'head_top', # 7,
87
+ 'left_clavicle', # 8,
88
+ "lshoulder", # 'left_shoulder', # 9,
89
+ "lelbow", # 'left_elbow',# 10,
90
+ "lwrist", # 'left_wrist',# 11,
91
+ 'left_hand',# 12,
92
+ 'right_clavicle',# 13,
93
+ 'rshoulder',# 'right_shoulder',# 14,
94
+ 'relbow',# 'right_elbow',# 15,
95
+ 'rwrist',# 'right_wrist',# 16,
96
+ 'right_hand',# 17,
97
+ 'lhip', # left_hip',# 18,
98
+ 'lknee', # 'left_knee',# 19,
99
+ 'lankle', #left ankle # 20
100
+ 'left_foot', # 21
101
+ 'left_toe', # 22
102
+ "rhip", # 'right_hip',# 23
103
+ "rknee", # 'right_knee',# 24
104
+ "rankle", #'right_ankle', # 25
105
+ 'right_foot',# 26
106
+ 'right_toe' # 27
107
+ ]
108
+
109
+ def get_insta_joint_names():
110
+ return [
111
+ 'OP RHeel',
112
+ 'OP RKnee',
113
+ 'OP RHip',
114
+ 'OP LHip',
115
+ 'OP LKnee',
116
+ 'OP LHeel',
117
+ 'OP RWrist',
118
+ 'OP RElbow',
119
+ 'OP RShoulder',
120
+ 'OP LShoulder',
121
+ 'OP LElbow',
122
+ 'OP LWrist',
123
+ 'OP Neck',
124
+ 'headtop',
125
+ 'OP Nose',
126
+ 'OP LEye',
127
+ 'OP REye',
128
+ 'OP LEar',
129
+ 'OP REar',
130
+ 'OP LBigToe',
131
+ 'OP RBigToe',
132
+ 'OP LSmallToe',
133
+ 'OP RSmallToe',
134
+ 'OP LAnkle',
135
+ 'OP RAnkle',
136
+ ]
137
+
138
+ def get_insta_skeleton():
139
+ return np.array(
140
+ [
141
+ [0 , 1],
142
+ [1 , 2],
143
+ [2 , 3],
144
+ [3 , 4],
145
+ [4 , 5],
146
+ [6 , 7],
147
+ [7 , 8],
148
+ [8 , 9],
149
+ [9 ,10],
150
+ [2 , 8],
151
+ [3 , 9],
152
+ [10,11],
153
+ [8 ,12],
154
+ [9 ,12],
155
+ [12,13],
156
+ [12,14],
157
+ [14,15],
158
+ [14,16],
159
+ [15,17],
160
+ [16,18],
161
+ [0 ,20],
162
+ [20,22],
163
+ [5 ,19],
164
+ [19,21],
165
+ [5 ,23],
166
+ [0 ,24],
167
+ ])
168
+
169
+ def get_staf_skeleton():
170
+ return np.array(
171
+ [
172
+ [0, 1],
173
+ [1, 2],
174
+ [2, 3],
175
+ [3, 4],
176
+ [1, 5],
177
+ [5, 6],
178
+ [6, 7],
179
+ [1, 8],
180
+ [8, 9],
181
+ [9, 10],
182
+ [10, 11],
183
+ [8, 12],
184
+ [12, 13],
185
+ [13, 14],
186
+ [0, 15],
187
+ [0, 16],
188
+ [15, 17],
189
+ [16, 18],
190
+ [2, 9],
191
+ [5, 12],
192
+ [1, 19],
193
+ [20, 19],
194
+ ]
195
+ )
196
+
197
+ def get_staf_joint_names():
198
+ return [
199
+ 'OP Nose', # 0,
200
+ 'OP Neck', # 1,
201
+ 'OP RShoulder', # 2,
202
+ 'OP RElbow', # 3,
203
+ 'OP RWrist', # 4,
204
+ 'OP LShoulder', # 5,
205
+ 'OP LElbow', # 6,
206
+ 'OP LWrist', # 7,
207
+ 'OP MidHip', # 8,
208
+ 'OP RHip', # 9,
209
+ 'OP RKnee', # 10,
210
+ 'OP RAnkle', # 11,
211
+ 'OP LHip', # 12,
212
+ 'OP LKnee', # 13,
213
+ 'OP LAnkle', # 14,
214
+ 'OP REye', # 15,
215
+ 'OP LEye', # 16,
216
+ 'OP REar', # 17,
217
+ 'OP LEar', # 18,
218
+ 'Neck (LSP)', # 19,
219
+ 'Top of Head (LSP)', # 20,
220
+ ]
221
+
222
+ def get_spin_joint_names():
223
+ return [
224
+ 'OP Nose', # 0
225
+ 'OP Neck', # 1
226
+ 'OP RShoulder', # 2
227
+ 'OP RElbow', # 3
228
+ 'OP RWrist', # 4
229
+ 'OP LShoulder', # 5
230
+ 'OP LElbow', # 6
231
+ 'OP LWrist', # 7
232
+ 'OP MidHip', # 8
233
+ 'OP RHip', # 9
234
+ 'OP RKnee', # 10
235
+ 'OP RAnkle', # 11
236
+ 'OP LHip', # 12
237
+ 'OP LKnee', # 13
238
+ 'OP LAnkle', # 14
239
+ 'OP REye', # 15
240
+ 'OP LEye', # 16
241
+ 'OP REar', # 17
242
+ 'OP LEar', # 18
243
+ 'OP LBigToe', # 19
244
+ 'OP LSmallToe', # 20
245
+ 'OP LHeel', # 21
246
+ 'OP RBigToe', # 22
247
+ 'OP RSmallToe', # 23
248
+ 'OP RHeel', # 24
249
+ 'rankle', # 25
250
+ 'rknee', # 26
251
+ 'rhip', # 27
252
+ 'lhip', # 28
253
+ 'lknee', # 29
254
+ 'lankle', # 30
255
+ 'rwrist', # 31
256
+ 'relbow', # 32
257
+ 'rshoulder', # 33
258
+ 'lshoulder', # 34
259
+ 'lelbow', # 35
260
+ 'lwrist', # 36
261
+ 'neck', # 37
262
+ 'headtop', # 38
263
+ 'hip', # 39 'Pelvis (MPII)', # 39
264
+ 'thorax', # 40 'Thorax (MPII)', # 40
265
+ 'Spine (H36M)', # 41
266
+ 'Jaw (H36M)', # 42
267
+ 'Head (H36M)', # 43
268
+ 'nose', # 44
269
+ 'leye', # 45 'Left Eye', # 45
270
+ 'reye', # 46 'Right Eye', # 46
271
+ 'lear', # 47 'Left Ear', # 47
272
+ 'rear', # 48 'Right Ear', # 48
273
+ ]
274
+
275
+ def get_h36m_joint_names():
276
+ return [
277
+ 'hip', # 0
278
+ 'lhip', # 1
279
+ 'lknee', # 2
280
+ 'lankle', # 3
281
+ 'rhip', # 4
282
+ 'rknee', # 5
283
+ 'rankle', # 6
284
+ 'Spine (H36M)', # 7
285
+ 'neck', # 8
286
+ 'Head (H36M)', # 9
287
+ 'headtop', # 10
288
+ 'lshoulder', # 11
289
+ 'lelbow', # 12
290
+ 'lwrist', # 13
291
+ 'rshoulder', # 14
292
+ 'relbow', # 15
293
+ 'rwrist', # 16
294
+ ]
295
+
296
+ # H36M joint order for reference: 'Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head_top', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist'
297
+
298
+ def get_spin_skeleton():
299
+ return np.array(
300
+ [
301
+ [0 , 1],
302
+ [1 , 2],
303
+ [2 , 3],
304
+ [3 , 4],
305
+ [1 , 5],
306
+ [5 , 6],
307
+ [6 , 7],
308
+ [1 , 8],
309
+ [8 , 9],
310
+ [9 ,10],
311
+ [10,11],
312
+ [8 ,12],
313
+ [12,13],
314
+ [13,14],
315
+ [0 ,15],
316
+ [0 ,16],
317
+ [15,17],
318
+ [16,18],
319
+ [21,19],
320
+ [19,20],
321
+ [14,21],
322
+ [11,24],
323
+ [24,22],
324
+ [22,23],
325
+ [0 ,38],
326
+ ]
327
+ )
328
+
329
+ def get_posetrack_joint_names():
330
+ return [
331
+ "nose",
332
+ "neck",
333
+ "headtop",
334
+ "lear",
335
+ "rear",
336
+ "lshoulder",
337
+ "rshoulder",
338
+ "lelbow",
339
+ "relbow",
340
+ "lwrist",
341
+ "rwrist",
342
+ "lhip",
343
+ "rhip",
344
+ "lknee",
345
+ "rknee",
346
+ "lankle",
347
+ "rankle"
348
+ ]
349
+
350
+ def get_posetrack_original_kp_names():
351
+ return [
352
+ 'nose',
353
+ 'head_bottom',
354
+ 'head_top',
355
+ 'left_ear',
356
+ 'right_ear',
357
+ 'left_shoulder',
358
+ 'right_shoulder',
359
+ 'left_elbow',
360
+ 'right_elbow',
361
+ 'left_wrist',
362
+ 'right_wrist',
363
+ 'left_hip',
364
+ 'right_hip',
365
+ 'left_knee',
366
+ 'right_knee',
367
+ 'left_ankle',
368
+ 'right_ankle'
369
+ ]
370
+
371
+ def get_pennaction_joint_names():
372
+ return [
373
+ "headtop", # 0
374
+ "lshoulder", # 1
375
+ "rshoulder", # 2
376
+ "lelbow", # 3
377
+ "relbow", # 4
378
+ "lwrist", # 5
379
+ "rwrist", # 6
380
+ "lhip" , # 7
381
+ "rhip" , # 8
382
+ "lknee", # 9
383
+ "rknee" , # 10
384
+ "lankle", # 11
385
+ "rankle" # 12
386
+ ]
387
+
388
+ def get_common_joint_names():
389
+ return [
390
+ "rankle", # 0 "lankle", # 0
391
+ "rknee", # 1 "lknee", # 1
392
+ "rhip", # 2 "lhip", # 2
393
+ "lhip", # 3 "rhip", # 3
394
+ "lknee", # 4 "rknee", # 4
395
+ "lankle", # 5 "rankle", # 5
396
+ "rwrist", # 6 "lwrist", # 6
397
+ "relbow", # 7 "lelbow", # 7
398
+ "rshoulder", # 8 "lshoulder", # 8
399
+ "lshoulder", # 9 "rshoulder", # 9
400
+ "lelbow", # 10 "relbow", # 10
401
+ "lwrist", # 11 "rwrist", # 11
402
+ "neck", # 12 "neck", # 12
403
+ "headtop", # 13 "headtop", # 13
404
+ ]
405
+
406
+ def get_coco_common_joint_names():
407
+ return [
408
+ "nose", # 0
409
+ "leye", # 1
410
+ "reye", # 2
411
+ "lear", # 3
412
+ "rear", # 4
413
+ "lshoulder", # 5
414
+ "rshoulder", # 6
415
+ "lelbow", # 7
416
+ "relbow", # 8
417
+ "lwrist", # 9
418
+ "rwrist", # 10
419
+ "lhip", # 11
420
+ "rhip", # 12
421
+ "lknee", # 13
422
+ "rknee", # 14
423
+ "lankle", # 15
424
+ "rankle", # 16
425
+ "neck", # 17 "neck", # 12
426
+ "headtop", # 18 "headtop", # 13
427
+ ]
428
+
429
+ def get_common_skeleton():
430
+ return np.array(
431
+ [
432
+ [ 0, 1 ],
433
+ [ 1, 2 ],
434
+ [ 3, 4 ],
435
+ [ 4, 5 ],
436
+ [ 6, 7 ],
437
+ [ 7, 8 ],
438
+ [ 8, 2 ],
439
+ [ 8, 9 ],
440
+ [ 9, 3 ],
441
+ [ 2, 3 ],
442
+ [ 8, 12],
443
+ [ 9, 10],
444
+ [12, 9 ],
445
+ [10, 11],
446
+ [12, 13],
447
+ ]
448
+ )
449
+
450
+ def get_coco_joint_names():
451
+ return [
452
+ "nose", # 0
453
+ "leye", # 1
454
+ "reye", # 2
455
+ "lear", # 3
456
+ "rear", # 4
457
+ "lshoulder", # 5
458
+ "rshoulder", # 6
459
+ "lelbow", # 7
460
+ "relbow", # 8
461
+ "lwrist", # 9
462
+ "rwrist", # 10
463
+ "lhip", # 11
464
+ "rhip", # 12
465
+ "lknee", # 13
466
+ "rknee", # 14
467
+ "lankle", # 15
468
+ "rankle", # 16
469
+ ]
470
+
471
+ def get_coco_skeleton():
472
+ # 0 - nose,
473
+ # 1 - leye,
474
+ # 2 - reye,
475
+ # 3 - lear,
476
+ # 4 - rear,
477
+ # 5 - lshoulder,
478
+ # 6 - rshoulder,
479
+ # 7 - lelbow,
480
+ # 8 - relbow,
481
+ # 9 - lwrist,
482
+ # 10 - rwrist,
483
+ # 11 - lhip,
484
+ # 12 - rhip,
485
+ # 13 - lknee,
486
+ # 14 - rknee,
487
+ # 15 - lankle,
488
+ # 16 - rankle,
489
+ return np.array(
490
+ [
491
+ [15, 13],
492
+ [13, 11],
493
+ [16, 14],
494
+ [14, 12],
495
+ [11, 12],
496
+ [ 5, 11],
497
+ [ 6, 12],
498
+ [ 5, 6 ],
499
+ [ 5, 7 ],
500
+ [ 6, 8 ],
501
+ [ 7, 9 ],
502
+ [ 8, 10],
503
+ [ 1, 2 ],
504
+ [ 0, 1 ],
505
+ [ 0, 2 ],
506
+ [ 1, 3 ],
507
+ [ 2, 4 ],
508
+ [ 3, 5 ],
509
+ [ 4, 6 ]
510
+ ]
511
+ )
512
+
513
+ def get_mpii_joint_names():
514
+ return [
515
+ "rankle", # 0
516
+ "rknee", # 1
517
+ "rhip", # 2
518
+ "lhip", # 3
519
+ "lknee", # 4
520
+ "lankle", # 5
521
+ "hip", # 6
522
+ "thorax", # 7
523
+ "neck", # 8
524
+ "headtop", # 9
525
+ "rwrist", # 10
526
+ "relbow", # 11
527
+ "rshoulder", # 12
528
+ "lshoulder", # 13
529
+ "lelbow", # 14
530
+ "lwrist", # 15
531
+ ]
532
+
533
+ def get_mpii_skeleton():
534
+ # 0 - rankle,
535
+ # 1 - rknee,
536
+ # 2 - rhip,
537
+ # 3 - lhip,
538
+ # 4 - lknee,
539
+ # 5 - lankle,
540
+ # 6 - hip,
541
+ # 7 - thorax,
542
+ # 8 - neck,
543
+ # 9 - headtop,
544
+ # 10 - rwrist,
545
+ # 11 - relbow,
546
+ # 12 - rshoulder,
547
+ # 13 - lshoulder,
548
+ # 14 - lelbow,
549
+ # 15 - lwrist,
550
+ return np.array(
551
+ [
552
+ [ 0, 1 ],
553
+ [ 1, 2 ],
554
+ [ 2, 6 ],
555
+ [ 6, 3 ],
556
+ [ 3, 4 ],
557
+ [ 4, 5 ],
558
+ [ 6, 7 ],
559
+ [ 7, 8 ],
560
+ [ 8, 9 ],
561
+ [ 7, 12],
562
+ [12, 11],
563
+ [11, 10],
564
+ [ 7, 13],
565
+ [13, 14],
566
+ [14, 15]
567
+ ]
568
+ )
569
+
570
+ def get_aich_joint_names():
571
+ return [
572
+ "rshoulder", # 0
573
+ "relbow", # 1
574
+ "rwrist", # 2
575
+ "lshoulder", # 3
576
+ "lelbow", # 4
577
+ "lwrist", # 5
578
+ "rhip", # 6
579
+ "rknee", # 7
580
+ "rankle", # 8
581
+ "lhip", # 9
582
+ "lknee", # 10
583
+ "lankle", # 11
584
+ "headtop", # 12
585
+ "neck", # 13
586
+ ]
587
+
588
+ def get_aich_skeleton():
589
+ # 0 - rshoulder,
590
+ # 1 - relbow,
591
+ # 2 - rwrist,
592
+ # 3 - lshoulder,
593
+ # 4 - lelbow,
594
+ # 5 - lwrist,
595
+ # 6 - rhip,
596
+ # 7 - rknee,
597
+ # 8 - rankle,
598
+ # 9 - lhip,
599
+ # 10 - lknee,
600
+ # 11 - lankle,
601
+ # 12 - headtop,
602
+ # 13 - neck,
603
+ return np.array(
604
+ [
605
+ [ 0, 1 ],
606
+ [ 1, 2 ],
607
+ [ 3, 4 ],
608
+ [ 4, 5 ],
609
+ [ 6, 7 ],
610
+ [ 7, 8 ],
611
+ [ 9, 10],
612
+ [10, 11],
613
+ [12, 13],
614
+ [13, 0 ],
615
+ [13, 3 ],
616
+ [ 0, 6 ],
617
+ [ 3, 9 ]
618
+ ]
619
+ )
620
+
621
+ def get_3dpw_joint_names():
622
+ return [
623
+ "nose", # 0
624
+ "thorax", # 1
625
+ "rshoulder", # 2
626
+ "relbow", # 3
627
+ "rwrist", # 4
628
+ "lshoulder", # 5
629
+ "lelbow", # 6
630
+ "lwrist", # 7
631
+ "rhip", # 8
632
+ "rknee", # 9
633
+ "rankle", # 10
634
+ "lhip", # 11
635
+ "lknee", # 12
636
+ "lankle", # 13
637
+ ]
638
+
639
+ def get_3dpw_skeleton():
640
+ return np.array(
641
+ [
642
+ [ 0, 1 ],
643
+ [ 1, 2 ],
644
+ [ 2, 3 ],
645
+ [ 3, 4 ],
646
+ [ 1, 5 ],
647
+ [ 5, 6 ],
648
+ [ 6, 7 ],
649
+ [ 2, 8 ],
650
+ [ 5, 11],
651
+ [ 8, 11],
652
+ [ 8, 9 ],
653
+ [ 9, 10],
654
+ [11, 12],
655
+ [12, 13]
656
+ ]
657
+ )
658
+
659
+ def get_smplcoco_joint_names():
660
+ return [
661
+ "rankle", # 0
662
+ "rknee", # 1
663
+ "rhip", # 2
664
+ "lhip", # 3
665
+ "lknee", # 4
666
+ "lankle", # 5
667
+ "rwrist", # 6
668
+ "relbow", # 7
669
+ "rshoulder", # 8
670
+ "lshoulder", # 9
671
+ "lelbow", # 10
672
+ "lwrist", # 11
673
+ "neck", # 12
674
+ "headtop", # 13
675
+ "nose", # 14
676
+ "leye", # 15
677
+ "reye", # 16
678
+ "lear", # 17
679
+ "rear", # 18
680
+ ]
681
+
682
+ def get_smplcoco_skeleton():
683
+ return np.array(
684
+ [
685
+ [ 0, 1 ],
686
+ [ 1, 2 ],
687
+ [ 3, 4 ],
688
+ [ 4, 5 ],
689
+ [ 6, 7 ],
690
+ [ 7, 8 ],
691
+ [ 8, 12],
692
+ [12, 9 ],
693
+ [ 9, 10],
694
+ [10, 11],
695
+ [12, 13],
696
+ [14, 15],
697
+ [15, 17],
698
+ [16, 18],
699
+ [14, 16],
700
+ [ 8, 2 ],
701
+ [ 9, 3 ],
702
+ [ 2, 3 ],
703
+ ]
704
+ )
705
+
706
+ def get_smpl_joint_names():
707
+ return [
708
+ 'hips', # 0
709
+ 'leftUpLeg', # 1
710
+ 'rightUpLeg', # 2
711
+ 'spine', # 3
712
+ 'leftLeg', # 4
713
+ 'rightLeg', # 5
714
+ 'spine1', # 6
715
+ 'leftFoot', # 7
716
+ 'rightFoot', # 8
717
+ 'spine2', # 9
718
+ 'leftToeBase', # 10
719
+ 'rightToeBase', # 11
720
+ 'neck', # 12
721
+ 'leftShoulder', # 13
722
+ 'rightShoulder', # 14
723
+ 'head', # 15
724
+ 'leftArm', # 16
725
+ 'rightArm', # 17
726
+ 'leftForeArm', # 18
727
+ 'rightForeArm', # 19
728
+ 'leftHand', # 20
729
+ 'rightHand', # 21
730
+ 'leftHandIndex1', # 22
731
+ 'rightHandIndex1', # 23
732
+ ]
733
+
734
+ def get_smpl_skeleton():
735
+ return np.array(
736
+ [
737
+ [ 0, 1 ],
738
+ [ 0, 2 ],
739
+ [ 0, 3 ],
740
+ [ 1, 4 ],
741
+ [ 2, 5 ],
742
+ [ 3, 6 ],
743
+ [ 4, 7 ],
744
+ [ 5, 8 ],
745
+ [ 6, 9 ],
746
+ [ 7, 10],
747
+ [ 8, 11],
748
+ [ 9, 12],
749
+ [ 9, 13],
750
+ [ 9, 14],
751
+ [12, 15],
752
+ [13, 16],
753
+ [14, 17],
754
+ [16, 18],
755
+ [17, 19],
756
+ [18, 20],
757
+ [19, 21],
758
+ [20, 22],
759
+ [21, 23],
760
+ ]
761
+ )
lib/utils/transforms.py ADDED
@@ -0,0 +1,828 @@
1
+ """This transforms function is mainly borrowed from PyTorch3D"""
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the BSD-style license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ from typing import Optional, Union
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ Device = Union[str, torch.device]
15
+
16
+ """
17
+ The transformation matrices returned from the functions in this file assume
18
+ the points on which the transformation will be applied are column vectors.
19
+ i.e. the R matrix is structured as
20
+
21
+ R = [
22
+ [Rxx, Rxy, Rxz],
23
+ [Ryx, Ryy, Ryz],
24
+ [Rzx, Rzy, Rzz],
25
+ ] # (3, 3)
26
+
27
+ This matrix can be applied to column vectors by post multiplication
28
+ by the points e.g.
29
+
30
+ points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
31
+ transformed_points = R * points
32
+
33
+ To apply the same matrix to points which are row vectors, the R matrix
34
+ can be transposed and pre multiplied by the points:
35
+
36
+ e.g.
37
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
38
+ transformed_points = points * R.transpose(1, 0)
39
+ """
40
+
41
+
42
+ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
43
+ """
44
+ Convert rotations given as quaternions to rotation matrices.
45
+
46
+ Args:
47
+ quaternions: quaternions with real part first,
48
+ as tensor of shape (..., 4).
49
+
50
+ Returns:
51
+ Rotation matrices as tensor of shape (..., 3, 3).
52
+ """
53
+ r, i, j, k = torch.unbind(quaternions, -1)
54
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
55
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
56
+
57
+ o = torch.stack(
58
+ (
59
+ 1 - two_s * (j * j + k * k),
60
+ two_s * (i * j - k * r),
61
+ two_s * (i * k + j * r),
62
+ two_s * (i * j + k * r),
63
+ 1 - two_s * (i * i + k * k),
64
+ two_s * (j * k - i * r),
65
+ two_s * (i * k - j * r),
66
+ two_s * (j * k + i * r),
67
+ 1 - two_s * (i * i + j * j),
68
+ ),
69
+ -1,
70
+ )
71
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
72
+
73
+
74
+
75
+ def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
76
+ """
77
+ Return a tensor where each element has the absolute value taken from the
78
+ corresponding element of a, with sign taken from the corresponding
79
+ element of b. This is like the standard copysign floating-point operation,
80
+ but is not careful about negative 0 and NaN.
81
+
82
+ Args:
83
+ a: source tensor.
84
+ b: tensor whose signs will be used, of the same shape as a.
85
+
86
+ Returns:
87
+ Tensor of the same shape as a with the signs of b.
88
+ """
89
+ signs_differ = (a < 0) != (b < 0)
90
+ return torch.where(signs_differ, -a, a)
91
+
92
+
93
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
94
+ """
95
+ Returns torch.sqrt(torch.max(0, x))
96
+ but with a zero subgradient where x is 0.
97
+ """
98
+ ret = torch.zeros_like(x)
99
+ positive_mask = x > 0
100
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
101
+ return ret
102
+
103
+
104
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
105
+ """
106
+ Convert rotations given as rotation matrices to quaternions.
107
+
108
+ Args:
109
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
110
+
111
+ Returns:
112
+ quaternions with real part first, as tensor of shape (..., 4).
113
+ """
114
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
115
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
116
+
117
+ batch_dim = matrix.shape[:-2]
118
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
119
+ matrix.reshape(batch_dim + (9,)), dim=-1
120
+ )
121
+
122
+ q_abs = _sqrt_positive_part(
123
+ torch.stack(
124
+ [
125
+ 1.0 + m00 + m11 + m22,
126
+ 1.0 + m00 - m11 - m22,
127
+ 1.0 - m00 + m11 - m22,
128
+ 1.0 - m00 - m11 + m22,
129
+ ],
130
+ dim=-1,
131
+ )
132
+ )
133
+
134
+ # we produce the desired quaternion multiplied by each of r, i, j, k
135
+ quat_by_rijk = torch.stack(
136
+ [
137
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
138
+ # `int`.
139
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
140
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
141
+ # `int`.
142
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
143
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
144
+ # `int`.
145
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
146
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
147
+ # `int`.
148
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
149
+ ],
150
+ dim=-2,
151
+ )
152
+
153
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
154
+ # the candidate won't be picked.
155
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
156
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
157
+
158
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
159
+ # forall i; we pick the best-conditioned one (with the largest denominator)
160
+
161
+ return quat_candidates[
162
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
163
+ ].reshape(batch_dim + (4,))
164
+
165
+
166
+
167
+ def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
168
+ """
169
+ Return the rotation matrices for one of the rotations about an axis
170
+ of which Euler angles describe, for each value of the angle given.
171
+
172
+ Args:
173
+ axis: Axis label "X" or "Y or "Z".
174
+ angle: any shape tensor of Euler angles in radians
175
+
176
+ Returns:
177
+ Rotation matrices as tensor of shape (..., 3, 3).
178
+ """
179
+
180
+ cos = torch.cos(angle)
181
+ sin = torch.sin(angle)
182
+ one = torch.ones_like(angle)
183
+ zero = torch.zeros_like(angle)
184
+
185
+ if axis == "X":
186
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
187
+ elif axis == "Y":
188
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
189
+ elif axis == "Z":
190
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
191
+ else:
192
+ raise ValueError("letter must be either X, Y or Z.")
193
+
194
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
195
+
196
+
197
+ def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
198
+ """
199
+ Convert rotations given as Euler angles in radians to rotation matrices.
200
+
201
+ Args:
202
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
203
+ convention: Convention string of three uppercase letters from
204
+ {"X", "Y", and "Z"}.
205
+
206
+ Returns:
207
+ Rotation matrices as tensor of shape (..., 3, 3).
208
+ """
209
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
210
+ raise ValueError("Invalid input euler angles.")
211
+ if len(convention) != 3:
212
+ raise ValueError("Convention must have 3 letters.")
213
+ if convention[1] in (convention[0], convention[2]):
214
+ raise ValueError(f"Invalid convention {convention}.")
215
+ for letter in convention:
216
+ if letter not in ("X", "Y", "Z"):
217
+ raise ValueError(f"Invalid letter {letter} in convention string.")
218
+ matrices = [
219
+ _axis_angle_rotation(c, e)
220
+ for c, e in zip(convention, torch.unbind(euler_angles, -1))
221
+ ]
222
+ # return functools.reduce(torch.matmul, matrices)
223
+ return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
224
+
225
+
226
+
227
+ def _angle_from_tan(
228
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
229
+ ) -> torch.Tensor:
230
+ """
231
+ Extract the first or third Euler angle from the two members of
232
+ the matrix which are positive constant times its sine and cosine.
233
+
234
+ Args:
235
+ axis: Axis label "X" or "Y or "Z" for the angle we are finding.
236
+ other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
237
+ convention.
238
+ data: Rotation matrices as tensor of shape (..., 3, 3).
239
+ horizontal: Whether we are looking for the angle for the third axis,
240
+ which means the relevant entries are in the same row of the
241
+ rotation matrix. If not, they are in the same column.
242
+ tait_bryan: Whether the first and third axes in the convention differ.
243
+
244
+ Returns:
245
+ Euler Angles in radians for each matrix in data as a tensor
246
+ of shape (...).
247
+ """
248
+
249
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
250
+ if horizontal:
251
+ i2, i1 = i1, i2
252
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
253
+ if horizontal == even:
254
+ return torch.atan2(data[..., i1], data[..., i2])
255
+ if tait_bryan:
256
+ return torch.atan2(-data[..., i2], data[..., i1])
257
+ return torch.atan2(data[..., i2], -data[..., i1])
258
+
259
+
260
+ def _index_from_letter(letter: str) -> int:
261
+ if letter == "X":
262
+ return 0
263
+ if letter == "Y":
264
+ return 1
265
+ if letter == "Z":
266
+ return 2
267
+ raise ValueError("letter must be either X, Y or Z.")
268
+
269
+
270
+ def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
271
+ """
272
+ Convert rotations given as rotation matrices to Euler angles in radians.
273
+
274
+ Args:
275
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
276
+ convention: Convention string of three uppercase letters.
277
+
278
+ Returns:
279
+ Euler angles in radians as tensor of shape (..., 3).
280
+ """
281
+ if len(convention) != 3:
282
+ raise ValueError("Convention must have 3 letters.")
283
+ if convention[1] in (convention[0], convention[2]):
284
+ raise ValueError(f"Invalid convention {convention}.")
285
+ for letter in convention:
286
+ if letter not in ("X", "Y", "Z"):
287
+ raise ValueError(f"Invalid letter {letter} in convention string.")
288
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
289
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
290
+ i0 = _index_from_letter(convention[0])
291
+ i2 = _index_from_letter(convention[2])
292
+ tait_bryan = i0 != i2
293
+ if tait_bryan:
294
+ central_angle = torch.asin(
295
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
296
+ )
297
+ else:
298
+ central_angle = torch.acos(matrix[..., i0, i0])
299
+
300
+ o = (
301
+ _angle_from_tan(
302
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
303
+ ),
304
+ central_angle,
305
+ _angle_from_tan(
306
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
307
+ ),
308
+ )
309
+ return torch.stack(o, -1)
310
+
311
+
312
+
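A round-trip sanity check (not part of this commit) for the Euler-angle helpers above, using random_rotations defined later in this file; the import path is an assumption:

import torch
from lib.utils.transforms import euler_angles_to_matrix, matrix_to_euler_angles, random_rotations  # assumed path

R = random_rotations(4)
angles = matrix_to_euler_angles(R, "XYZ")
assert torch.allclose(euler_angles_to_matrix(angles, "XYZ"), R, atol=1e-4)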
313
+ def random_quaternions(
314
+ n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
315
+ ) -> torch.Tensor:
316
+ """
317
+ Generate random quaternions representing rotations,
318
+ i.e. versors with nonnegative real part.
319
+
320
+ Args:
321
+ n: Number of quaternions in a batch to return.
322
+ dtype: Type to return.
323
+ device: Desired device of returned tensor. Default:
324
+ uses the current device for the default tensor type.
325
+
326
+ Returns:
327
+ Quaternions as tensor of shape (N, 4).
328
+ """
329
+ if isinstance(device, str):
330
+ device = torch.device(device)
331
+ o = torch.randn((n, 4), dtype=dtype, device=device)
332
+ s = (o * o).sum(1)
333
+ o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
334
+ return o
335
+
336
+
337
+
338
+ def random_rotations(
339
+ n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
340
+ ) -> torch.Tensor:
341
+ """
342
+ Generate random rotations as 3x3 rotation matrices.
343
+
344
+ Args:
345
+ n: Number of rotation matrices in a batch to return.
346
+ dtype: Type to return.
347
+ device: Device of returned tensor. Default: if None,
348
+ uses the current device for the default tensor type.
349
+
350
+ Returns:
351
+ Rotation matrices as tensor of shape (n, 3, 3).
352
+ """
353
+ quaternions = random_quaternions(n, dtype=dtype, device=device)
354
+ return quaternion_to_matrix(quaternions)
355
+
356
+
357
+
358
+ def random_rotation(
359
+ dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
360
+ ) -> torch.Tensor:
361
+ """
362
+ Generate a single random 3x3 rotation matrix.
363
+
364
+ Args:
365
+ dtype: Type to return
366
+ device: Device of returned tensor. Default: if None,
367
+ uses the current device for the default tensor type
368
+
369
+ Returns:
370
+ Rotation matrix as tensor of shape (3, 3).
371
+ """
372
+ return random_rotations(1, dtype, device)[0]
373
+
374
+
375
+
376
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
377
+ """
378
+ Convert a unit quaternion to a standard form: one in which the real
379
+ part is non negative.
380
+
381
+ Args:
382
+ quaternions: Quaternions with real part first,
383
+ as tensor of shape (..., 4).
384
+
385
+ Returns:
386
+ Standardized quaternions as tensor of shape (..., 4).
387
+ """
388
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
389
+
390
+
391
+
392
+ def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
393
+ """
394
+ Multiply two quaternions.
395
+ Usual torch rules for broadcasting apply.
396
+
397
+ Args:
398
+ a: Quaternions as tensor of shape (..., 4), real part first.
399
+ b: Quaternions as tensor of shape (..., 4), real part first.
400
+
401
+ Returns:
402
+ The product of a and b, a tensor of quaternions shape (..., 4).
403
+ """
404
+ aw, ax, ay, az = torch.unbind(a, -1)
405
+ bw, bx, by, bz = torch.unbind(b, -1)
406
+ ow = aw * bw - ax * bx - ay * by - az * bz
407
+ ox = aw * bx + ax * bw + ay * bz - az * by
408
+ oy = aw * by - ax * bz + ay * bw + az * bx
409
+ oz = aw * bz + ax * by - ay * bx + az * bw
410
+ return torch.stack((ow, ox, oy, oz), -1)
411
+
412
+
413
+
414
+ def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
415
+ """
416
+ Multiply two quaternions representing rotations, returning the quaternion
417
+ representing their composition, i.e. the versor with nonnegative real part.
418
+ Usual torch rules for broadcasting apply.
419
+
420
+ Args:
421
+ a: Quaternions as tensor of shape (..., 4), real part first.
422
+ b: Quaternions as tensor of shape (..., 4), real part first.
423
+
424
+ Returns:
425
+ The product of a and b, a tensor of quaternions of shape (..., 4).
426
+ """
427
+ ab = quaternion_raw_multiply(a, b)
428
+ return standardize_quaternion(ab)
429
+
430
+
431
+
432
+ def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
433
+ """
434
+ Given a quaternion representing rotation, get the quaternion representing
435
+ its inverse.
436
+
437
+ Args:
438
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
439
+ first, which must be versors (unit quaternions).
440
+
441
+ Returns:
442
+ The inverse, a tensor of quaternions of shape (..., 4).
443
+ """
444
+
445
+ scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
446
+ return quaternion * scaling
447
+
448
+
449
+
450
+ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
451
+ """
452
+ Apply the rotation given by a quaternion to a 3D point.
453
+ Usual torch rules for broadcasting apply.
454
+
455
+ Args:
456
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
457
+ point: Tensor of 3D points of shape (..., 3).
458
+
459
+ Returns:
460
+ Tensor of rotated points of shape (..., 3).
461
+ """
462
+ if point.size(-1) != 3:
463
+ raise ValueError(f"Points are not in 3D, {point.shape}.")
464
+ real_parts = point.new_zeros(point.shape[:-1] + (1,))
465
+ point_as_quaternion = torch.cat((real_parts, point), -1)
466
+ out = quaternion_raw_multiply(
467
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
468
+ quaternion_invert(quaternion),
469
+ )
470
+ return out[..., 1:]
471
+
472
+
473
+
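A consistency sketch (not part of this commit): applying a quaternion directly agrees with converting it to a rotation matrix first; the import path is an assumption:

import torch
from lib.utils.transforms import quaternion_apply, quaternion_to_matrix, random_quaternions  # assumed path

q = random_quaternions(5)                # (5, 4), real part first
pts = torch.randn(5, 3)
rot_q = quaternion_apply(q, pts)
rot_m = torch.einsum('bij,bj->bi', quaternion_to_matrix(q), pts)
assert torch.allclose(rot_q, rot_m, atol=1e-5)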
474
+ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
475
+ """
476
+ Convert rotations given as axis/angle to rotation matrices.
477
+
478
+ Args:
479
+ axis_angle: Rotations given as a vector in axis angle form,
480
+ as a tensor of shape (..., 3), where the magnitude is
481
+ the angle turned anticlockwise in radians around the
482
+ vector's direction.
483
+
484
+ Returns:
485
+ Rotation matrices as tensor of shape (..., 3, 3).
486
+ """
487
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
488
+
489
+
490
+
491
+ def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
492
+ """
493
+ Convert rotations given as rotation matrices to axis/angle.
494
+
495
+ Args:
496
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
497
+
498
+ Returns:
499
+ Rotations given as a vector in axis angle form, as a tensor
500
+ of shape (..., 3), where the magnitude is the angle
501
+ turned anticlockwise in radians around the vector's
502
+ direction.
503
+ """
504
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
505
+
506
+
507
+
508
+ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
509
+ """
510
+ Convert rotations given as axis/angle to quaternions.
511
+
512
+ Args:
513
+ axis_angle: Rotations given as a vector in axis angle form,
514
+ as a tensor of shape (..., 3), where the magnitude is
515
+ the angle turned anticlockwise in radians around the
516
+ vector's direction.
517
+
518
+ Returns:
519
+ quaternions with real part first, as tensor of shape (..., 4).
520
+ """
521
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
522
+ half_angles = angles * 0.5
523
+ eps = 1e-6
524
+ small_angles = angles.abs() < eps
525
+ sin_half_angles_over_angles = torch.empty_like(angles)
526
+ sin_half_angles_over_angles[~small_angles] = (
527
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
528
+ )
529
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
530
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
531
+ sin_half_angles_over_angles[small_angles] = (
532
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
533
+ )
534
+ quaternions = torch.cat(
535
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
536
+ )
537
+ return quaternions
538
+
539
+
540
+
541
+ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
542
+ """
543
+ Convert rotations given as quaternions to axis/angle.
544
+
545
+ Args:
546
+ quaternions: quaternions with real part first,
547
+ as tensor of shape (..., 4).
548
+
549
+ Returns:
550
+ Rotations given as a vector in axis angle form, as a tensor
551
+ of shape (..., 3), where the magnitude is the angle
552
+ turned anticlockwise in radians around the vector's
553
+ direction.
554
+ """
555
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
556
+ half_angles = torch.atan2(norms, quaternions[..., :1])
557
+ angles = 2 * half_angles
558
+ eps = 1e-6
559
+ small_angles = angles.abs() < eps
560
+ sin_half_angles_over_angles = torch.empty_like(angles)
561
+ sin_half_angles_over_angles[~small_angles] = (
562
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
563
+ )
564
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
565
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
566
+ sin_half_angles_over_angles[small_angles] = (
567
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
568
+ )
569
+ return quaternions[..., 1:] / sin_half_angles_over_angles
570
+
571
+
572
+
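A round-trip sketch (not part of this commit) for the axis-angle/quaternion pair above, keeping angles below pi so the representation is unique; the import path is an assumption:

import torch
from lib.utils.transforms import axis_angle_to_quaternion, quaternion_to_axis_angle  # assumed path

axis = torch.nn.functional.normalize(torch.randn(16, 3), dim=-1)
angle = torch.rand(16, 1) * 3.0          # radians, below pi
aa = axis * angle
assert torch.allclose(quaternion_to_axis_angle(axis_angle_to_quaternion(aa)), aa, atol=1e-5)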
573
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
574
+ """
575
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
576
+ using Gram--Schmidt orthogonalization per Section B of [1].
577
+ Args:
578
+ d6: 6D rotation representation, of size (*, 6)
579
+
580
+ Returns:
581
+ batch of rotation matrices of size (*, 3, 3)
582
+
583
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
584
+ On the Continuity of Rotation Representations in Neural Networks.
585
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
586
+ Retrieved from http://arxiv.org/abs/1812.07035
587
+ """
588
+
589
+ a1, a2 = d6[..., :3], d6[..., 3:]
590
+ b1 = F.normalize(a1, dim=-1)
591
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
592
+ b2 = F.normalize(b2, dim=-1)
593
+ b3 = torch.cross(b1, b2, dim=-1)
594
+ return torch.stack((b1, b2, b3), dim=-2)
595
+
596
+
597
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
598
+ """
599
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
600
+ by dropping the last row. Note that 6D representation is not unique.
601
+ Args:
602
+ matrix: batch of rotation matrices of size (*, 3, 3)
603
+
604
+ Returns:
605
+ 6D rotation representation, of size (*, 6)
606
+
607
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
608
+ On the Continuity of Rotation Representations in Neural Networks.
609
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
610
+ Retrieved from http://arxiv.org/abs/1812.07035
611
+ """
612
+ batch_dim = matrix.size()[:-2]
613
+ return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
614
+
615
+
616
+ def clean_rotation_6d(d6d: torch.Tensor) -> torch.Tensor:
617
+ """
618
+ Clean a 6D rotation by converting it to a matrix and then back to 6D.
619
+ """
620
+ matrix = rotation_6d_to_matrix(d6d)
621
+ d6d = matrix_to_rotation_6d(matrix)
622
+ return d6d
623
+
624
+
625
+ def rot6d_to_rotmat(x):
626
+ """Convert 6D rotation representation to 3x3 rotation matrix.
627
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
628
+ Input:
629
+ (B,6) Batch of 6-D rotation representations
630
+ Output:
631
+ (B,3,3) Batch of corresponding rotation matrices
632
+ """
633
+ if x.shape[-1] == 6:
634
+ batch_dim = x.size()[:-1]
635
+ else:
636
+ x = x.reshape(*x.shape[:-1], -1, 6)
637
+ batch_dim = x.size()[:-1]
638
+
639
+ x = x.reshape(*batch_dim, 3, 2)
640
+ a1, a2 = x[..., 0], x[..., 1]
641
+
642
+ b1 = F.normalize(a1, dim=-1)
643
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
644
+ b2 = F.normalize(b2, dim=-1)
645
+ b3 = torch.cross(b1, b2, dim=-1)
646
+
647
+ return torch.stack((b1, b2, b3), dim=-1)
648
+
649
+
650
+ def rotmat_to_rot6d(x):
651
+ """Inverse computation of rot6d_to_rotmat."""
652
+ batch_dim = x.size()[:-2]
653
+ return x[..., :2].clone().reshape(batch_dim + (6,))
654
+
655
+
656
+ def convert_rotation_matrix_to_homogeneous(rotation_matrix):
657
+ "Add empty translation vector to Rotation matrix"""
658
+
659
+ transl = torch.zeros_like(rotation_matrix[...,:1])
660
+ rotation_matrix_hom = torch.cat((rotation_matrix, transl), dim=-1)
661
+
662
+ return rotation_matrix_hom
663
+
664
+
665
+ def rotation_matrix_to_angle_axis(rotation_matrix):
666
+ """Convert 3x4 rotation matrix to Rodrigues vector
667
+
668
+ Args:
669
+ rotation_matrix (Tensor): rotation matrix.
670
+
671
+ Returns:
672
+ Tensor: Rodrigues vector transformation.
673
+
674
+ Shape:
675
+ - Input: :math:`(N, 3, 4)`
676
+ - Output: :math:`(N, 3)`
677
+
678
+ Example:
679
+ >>> input = torch.rand(2, 3, 4) # Nx3x4
680
+ >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
681
+ """
682
+
683
+ if rotation_matrix.size(-1) == 3:
684
+ rotation_matrix = convert_rotation_matrix_to_homogeneous(rotation_matrix)
685
+
686
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
687
+ return quaternion_to_angle_axis(quaternion)
688
+
689
+
690
+ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
691
+ """Convert 3x4 rotation matrix to 4d quaternion vector
692
+
693
+ This algorithm is based on algorithm described in
694
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
695
+
696
+ Args:
697
+ rotation_matrix (Tensor): the rotation matrix to convert.
698
+
699
+ Return:
700
+ Tensor: the rotation in quaternion
701
+
702
+ Shape:
703
+ - Input: :math:`(N, 3, 4)`
704
+ - Output: :math:`(N, 4)`
705
+
706
+ Example:
707
+ >>> input = torch.rand(4, 3, 4) # Nx3x4
708
+ >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
709
+ """
710
+ if not torch.is_tensor(rotation_matrix):
711
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
712
+ type(rotation_matrix)))
713
+
714
+ if len(rotation_matrix.shape) > 3:
715
+ raise ValueError(
716
+ "Input size must be a three dimensional tensor. Got {}".format(
717
+ rotation_matrix.shape))
718
+ if not rotation_matrix.shape[-2:] == (3, 4):
719
+ raise ValueError(
720
+ "Input size must be a N x 3 x 4 tensor. Got {}".format(
721
+ rotation_matrix.shape))
722
+
723
+ rmat_t = torch.transpose(rotation_matrix, 1, 2)
724
+
725
+ mask_d2 = rmat_t[:, 2, 2] < eps
726
+
727
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
728
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
729
+
730
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
731
+ q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
732
+ t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
733
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
734
+ t0_rep = t0.repeat(4, 1).t()
735
+
736
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
737
+ q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
738
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
739
+ t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
740
+ t1_rep = t1.repeat(4, 1).t()
741
+
742
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
743
+ q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
744
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
745
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
746
+ t2_rep = t2.repeat(4, 1).t()
747
+
748
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
749
+ q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
750
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
751
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
752
+ t3_rep = t3.repeat(4, 1).t()
753
+
754
+ mask_c0 = mask_d2 * mask_d0_d1
755
+ # mask_c1 = mask_d2 * (1 - mask_d0_d1)
756
+ mask_c1 = mask_d2 * ~mask_d0_d1
757
+ # mask_c2 = (1 - mask_d2) * mask_d0_nd1
758
+ mask_c2 = ~mask_d2 * mask_d0_nd1
759
+ # mask_c3 = (1 - mask_d2) * (1 - mask_d0_nd1)
760
+ mask_c3 = ~mask_d2 * ~mask_d0_nd1
761
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
762
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
763
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
764
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
765
+
766
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
767
+ q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
768
+ t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
769
+ q *= 0.5
770
+ return q
771
+
772
+
773
+ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
774
+ """Convert quaternion vector to angle axis of rotation.
775
+
776
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
777
+
778
+ Args:
779
+ quaternion (torch.Tensor): tensor with quaternions.
780
+
781
+ Return:
782
+ torch.Tensor: tensor with angle axis of rotation.
783
+
784
+ Shape:
785
+ - Input: :math:`(*, 4)` where `*` means, any number of dimensions
786
+ - Output: :math:`(*, 3)`
787
+
788
+ Example:
789
+ >>> quaternion = torch.rand(2, 4) # Nx4
790
+ >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
791
+ """
792
+ if not torch.is_tensor(quaternion):
793
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
794
+ type(quaternion)))
795
+
796
+ if not quaternion.shape[-1] == 4:
797
+ raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
798
+ .format(quaternion.shape))
799
+ # unpack input and compute conversion
800
+ q1: torch.Tensor = quaternion[..., 1]
801
+ q2: torch.Tensor = quaternion[..., 2]
802
+ q3: torch.Tensor = quaternion[..., 3]
803
+ sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
804
+
805
+ sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
806
+ cos_theta: torch.Tensor = quaternion[..., 0]
807
+ two_theta: torch.Tensor = 2.0 * torch.where(
808
+ cos_theta < 0.0,
809
+ torch.atan2(-sin_theta, -cos_theta),
810
+ torch.atan2(sin_theta, cos_theta))
811
+
812
+ k_pos: torch.Tensor = two_theta / sin_theta
813
+ k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
814
+ k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
815
+
816
+ angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
817
+ angle_axis[..., 0] += q1 * k
818
+ angle_axis[..., 1] += q2 * k
819
+ angle_axis[..., 2] += q3 * k
820
+ return angle_axis
821
+
822
+
823
+ def avg_rot(rot):
824
+ # input [B,...,3,3] --> output [...,3,3]
825
+ rot = rot.mean(dim=0)
826
+ U, _, V = torch.svd(rot)
827
+ rot = U @ V.transpose(-1, -2)
828
+ return rot
lib/utils/utils.py ADDED
@@ -0,0 +1,265 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+ import os
18
+ import yaml
19
+ import torch
20
+ import shutil
21
+ import logging
22
+ import operator
23
+ from tqdm import tqdm
24
+ from os import path as osp
25
+ from functools import reduce
26
+ from typing import List, Union
27
+ from collections import OrderedDict
28
+ from torch.optim.lr_scheduler import _LRScheduler
29
+
30
+ class CustomScheduler(_LRScheduler):
31
+ def __init__(self, optimizer, lr_lambda):
32
+ self.lr_lambda = lr_lambda
33
+ super(CustomScheduler, self).__init__(optimizer)
34
+
35
+ def get_lr(self):
36
+ return [base_lr * self.lr_lambda(self.last_epoch)
37
+ for base_lr in self.base_lrs]
38
+
39
+ def lr_decay_fn(epoch):  # note: expects big_epoch, big_decay and small_decay to be defined in the surrounding scope
40
+ if epoch == 0: return 1.0
41
+ if epoch % big_epoch == 0:
42
+ return big_decay
43
+ else:
44
+ return small_decay
45
+
46
+ def save_obj(v, f, file_name='output.obj'):
47
+ obj_file = open(file_name, 'w')
48
+ for i in range(len(v)):
49
+ obj_file.write('v ' + str(v[i][0]) + ' ' + str(v[i][1]) + ' ' + str(v[i][2]) + '\n')
50
+ for i in range(len(f)):
51
+ obj_file.write('f ' + str(f[i][0]+1) + '/' + str(f[i][0]+1) + ' ' + str(f[i][1]+1) + '/' + str(f[i][1]+1) + ' ' + str(f[i][2]+1) + '/' + str(f[i][2]+1) + '\n')
52
+ obj_file.close()
53
+
54
+
55
+ def check_data_pararell(train_weight):
56
+ new_state_dict = OrderedDict()
57
+ for k, v in train_weight.items():
58
+ name = k[7:] if k.startswith('module') else k # remove `module.`
59
+ new_state_dict[name] = v
60
+ return new_state_dict
61
+
62
+
63
+ def get_from_dict(dict, keys):
64
+ return reduce(operator.getitem, keys, dict)
65
+
66
+
67
+ def tqdm_enumerate(iter):
68
+ i = 0
69
+ for y in tqdm(iter):
70
+ yield i, y
71
+ i += 1
72
+
73
+
74
+ def iterdict(d):
75
+ for k,v in d.items():
76
+ if isinstance(v, dict):
77
+ d[k] = dict(v)
78
+ iterdict(v)
79
+ return d
80
+
81
+
82
+ def accuracy(output, target):
83
+ _, pred = output.topk(1)
84
+ pred = pred.view(-1)
85
+
86
+ correct = pred.eq(target).sum()
87
+
88
+ return correct.item(), target.size(0) - correct.item()
89
+
90
+
91
+ def lr_decay(optimizer, step, lr, decay_step, gamma):
92
+ lr = lr * gamma ** (step/decay_step)
93
+ for param_group in optimizer.param_groups:
94
+ param_group['lr'] = lr
95
+ return lr
96
+
97
+
98
+ def step_decay(optimizer, step, lr, decay_step, gamma):
99
+ lr = lr * gamma ** (step / decay_step)
100
+ for param_group in optimizer.param_groups:
101
+ param_group['lr'] = lr
102
+ return lr
103
+
104
+
105
+ def read_yaml(filename):
106
+ return yaml.load(open(filename, 'r'), Loader=yaml.FullLoader)
107
+
108
+
109
+ def write_yaml(filename, object):
110
+ with open(filename, 'w') as f:
111
+ yaml.dump(object, f)
112
+
113
+
114
+ def save_dict_to_yaml(obj, filename, mode='w'):
115
+ with open(filename, mode) as f:
116
+ yaml.dump(obj, f, default_flow_style=False)
117
+
118
+
119
+ def save_to_file(obj, filename, mode='w'):
120
+ with open(filename, mode) as f:
121
+ f.write(obj)
122
+
123
+
124
+ def concatenate_dicts(dict_list, dim=0):
125
+ rdict = dict.fromkeys(dict_list[0].keys())
126
+ for k in rdict.keys():
127
+ rdict[k] = torch.cat([d[k] for d in dict_list], dim=dim)
128
+ return rdict
129
+
130
+
131
+ def bool_to_string(x: Union[List[bool],bool]) -> Union[List[str],str]:
132
+ """
133
+ boolean to string conversion
134
+ :param x: list or bool to be converted
135
+ :return: the boolean value(s) converted to strings
136
+ """
137
+ if isinstance(x, bool):
138
+ return [str(x)]
139
+ for i, j in enumerate(x):
140
+ x[i]=str(j)
141
+ return x
142
+
143
+
144
+ def checkpoint2model(checkpoint, key='gen_state_dict'):
145
+ state_dict = checkpoint[key]
146
+ print(f'Performance of loaded model on 3DPW is {checkpoint["performance"]:.2f}mm')
147
+ # del state_dict['regressor.mean_theta']
148
+ return state_dict
149
+
150
+
151
+ def get_optimizer(cfg, model, optim_type, momentum, stage):
152
+ if stage == 'stage2':
153
+ param_list = [{'params': model.integrator.parameters()}]
154
+ for name, param in model.named_parameters():
155
+ # if 'integrator' not in name and 'motion_encoder' not in name and 'trajectory_decoder' not in name:
156
+ if 'integrator' not in name:
157
+ param_list.append({'params': param, 'lr': cfg.TRAIN.LR_FINETUNE})
158
+ else:
159
+ param_list = [{'params': model.parameters()}]
160
+
161
+ if optim_type in ['sgd', 'SGD']:
162
+ opt = torch.optim.SGD(lr=cfg.TRAIN.LR, params=param_list, momentum=momentum)
163
+ elif optim_type in ['Adam', 'adam', 'ADAM']:
164
+ opt = torch.optim.Adam(lr=cfg.TRAIN.LR, params=param_list, weight_decay=cfg.TRAIN.WD, betas=(0.9, 0.999))
165
+ else:
166
+ raise ModuleNotFoundError
167
+
168
+ return opt
169
+
170
+
171
+ def create_logger(logdir, phase='train'):
172
+ os.makedirs(logdir, exist_ok=True)
173
+
174
+ log_file = osp.join(logdir, f'{phase}_log.txt')
175
+
176
+ head = '%(asctime)-15s %(message)s'
177
+ logging.basicConfig(filename=log_file,
178
+ format=head)
179
+ logger = logging.getLogger()
180
+ logger.setLevel(logging.INFO)
181
+ console = logging.StreamHandler()
182
+ logging.getLogger('').addHandler(console)
183
+
184
+ return logger
185
+
186
+
187
+ class AverageMeter(object):
188
+ def __init__(self):
189
+ self.val = 0
190
+ self.avg = 0
191
+ self.sum = 0
192
+ self.count = 0
193
+
194
+ def update(self, val, n=1):
195
+ self.val = val
196
+ self.sum += val * n
197
+ self.count += n
198
+ self.avg = self.sum / self.count
199
+
200
+
201
+ def prepare_output_dir(cfg, cfg_file):
202
+
203
+ # ==== create logdir
204
+ logdir = osp.join(cfg.OUTPUT_DIR, cfg.EXP_NAME)
205
+ os.makedirs(logdir, exist_ok=True)
206
+ shutil.copy(src=cfg_file, dst=osp.join(cfg.OUTPUT_DIR, 'config.yaml'))
207
+
208
+ cfg.LOGDIR = logdir
209
+
210
+ # save config
211
+ save_dict_to_yaml(cfg, osp.join(cfg.LOGDIR, 'config.yaml'))
212
+
213
+ return cfg
214
+
215
+
216
+ def prepare_groundtruth(batch, device):
217
+ groundtruths = dict()
218
+ gt_keys = ['pose', 'cam', 'betas', 'kp3d', 'bbox'] # Evaluation
219
+ gt_keys += ['pose_root', 'vel_root', 'weak_kp2d', 'verts', # Training
220
+ 'full_kp2d', 'contact', 'R', 'cam_angvel',
221
+ 'has_smpl', 'has_traj', 'has_full_screen', 'has_verts']
222
+ for gt_key in gt_keys:
223
+ if gt_key in batch.keys():
224
+ dtype = torch.float32 if batch[gt_key].dtype == torch.float64 else batch[gt_key].dtype
225
+ groundtruths[gt_key] = batch[gt_key].to(dtype=dtype, device=device)
226
+
227
+ return groundtruths
228
+
229
+ def prepare_auxiliary(batch, device):
230
+ aux = dict()
231
+ aux_keys = ['mask', 'bbox', 'res', 'cam_intrinsics', 'init_root', 'cam_angvel']
232
+ for key in aux_keys:
233
+ if key in batch.keys():
234
+ dtype = torch.float32 if batch[key].dtype == torch.float64 else batch[key].dtype
235
+ aux[key] = batch[key].to(dtype=dtype, device=device)
236
+
237
+ return aux
238
+
239
+ def prepare_input(batch, device, use_features):
240
+ # Input keypoints data
241
+ kp2d = batch['kp2d'].to(device).float()
242
+
243
+ # Input features
244
+ if use_features and 'features' in batch.keys():
245
+ features = batch['features'].to(device).float()
246
+ else:
247
+ features = None
248
+
249
+ # Initial SMPL parameters
250
+ init_smpl = batch['init_pose'].to(device).float()
251
+
252
+ # Initial keypoints
253
+ init_kp = torch.cat((
254
+ batch['init_kp3d'], batch['init_kp2d']
255
+ ), dim=-1).to(device).float()
256
+
257
+ return kp2d, (init_kp, init_smpl), features
258
+
259
+
260
+ def prepare_batch(batch, device, use_features=True):
261
+ x, inits, features = prepare_input(batch, device, use_features)
262
+ aux = prepare_auxiliary(batch, device)
263
+ groundtruths = prepare_groundtruth(batch, device)
264
+
265
+ return x, inits, features, aux, groundtruths
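Editor's note: a minimal usage sketch for the helpers above, not part of the commit. Because lr_decay_fn depends on module-level constants that this chunk never defines, the sketch passes an explicit closure to CustomScheduler; the config is a SimpleNamespace stand-in for the repo's config object and the model is a placeholder.

# --- usage sketch (illustrative only) ---
from types import SimpleNamespace
import torch
from torch import nn
from lib.utils.utils import CustomScheduler, get_optimizer

cfg = SimpleNamespace(TRAIN=SimpleNamespace(LR=1e-4, WD=0.0, LR_FINETUNE=1e-5))  # placeholder values
model = nn.Linear(16, 16)                        # stand-in for the actual network
opt = get_optimizer(cfg, model, optim_type='adam', momentum=0.9, stage='stage1')

# same schedule shape as lr_decay_fn, with the undefined constants made explicit
decay_fn = lambda epoch: 1.0 if epoch == 0 else (0.1 if epoch % 30 == 0 else 0.99)
sched = CustomScheduler(opt, lr_lambda=decay_fn)

for epoch in range(3):
    ...                                          # one training epoch
    sched.step()                                 # sets lr to base_lr * decay_fn(current epoch index)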
lib/vis/__pycache__/renderer.cpython-39.pyc ADDED
Binary file (9.31 kB). View file
 
lib/vis/__pycache__/run_vis.cpython-39.pyc ADDED
Binary file (3.08 kB). View file
 
lib/vis/__pycache__/tools.cpython-39.pyc ADDED
Binary file (15.9 kB). View file
 
lib/vis/renderer.py ADDED
@@ -0,0 +1,313 @@
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+
5
+ from pytorch3d.renderer import (
6
+ PerspectiveCameras,
7
+ TexturesVertex,
8
+ PointLights,
9
+ Materials,
10
+ RasterizationSettings,
11
+ MeshRenderer,
12
+ MeshRasterizer,
13
+ SoftPhongShader,
14
+ )
15
+ from pytorch3d.structures import Meshes
16
+ from pytorch3d.structures.meshes import join_meshes_as_scene
17
+ from pytorch3d.renderer.cameras import look_at_rotation
18
+
19
+ from .tools import get_colors, checkerboard_geometry
20
+
21
+
22
+ def overlay_image_onto_background(image, mask, bbox, background):
23
+ if isinstance(image, torch.Tensor):
24
+ image = image.detach().cpu().numpy()
25
+ if isinstance(mask, torch.Tensor):
26
+ mask = mask.detach().cpu().numpy()
27
+
28
+ out_image = background.copy()
29
+ bbox = bbox[0].int().cpu().numpy().copy()
30
+ roi_image = out_image[bbox[1]:bbox[3], bbox[0]:bbox[2]]
31
+
32
+ roi_image[mask] = image[mask]
33
+ out_image[bbox[1]:bbox[3], bbox[0]:bbox[2]] = roi_image
34
+
35
+ return out_image
36
+
37
+
38
+ def update_intrinsics_from_bbox(K_org, bbox):
39
+ device, dtype = K_org.device, K_org.dtype
40
+
41
+ K = torch.zeros((K_org.shape[0], 4, 4)
42
+ ).to(device=device, dtype=dtype)
43
+ K[:, :3, :3] = K_org.clone()
44
+ K[:, 2, 2] = 0
45
+ K[:, 2, -1] = 1
46
+ K[:, -1, 2] = 1
47
+
48
+ image_sizes = []
49
+ for idx, bbox in enumerate(bbox):
50
+ left, upper, right, lower = bbox
51
+ cx, cy = K[idx, 0, 2], K[idx, 1, 2]
52
+
53
+ new_cx = cx - left
54
+ new_cy = cy - upper
55
+ new_height = max(lower - upper, 1)
56
+ new_width = max(right - left, 1)
57
+ new_cx = new_width - new_cx
58
+ new_cy = new_height - new_cy
59
+
60
+ K[idx, 0, 2] = new_cx
61
+ K[idx, 1, 2] = new_cy
62
+ image_sizes.append((int(new_height), int(new_width)))
63
+
64
+ return K, image_sizes
65
+
66
+
67
+ def perspective_projection(x3d, K, R=None, T=None):
68
+ if R is not None:
69
+ x3d = torch.matmul(R, x3d.transpose(1, 2)).transpose(1, 2)
70
+ if T is not None:
71
+ x3d = x3d + T.transpose(1, 2)
72
+
73
+ x2d = torch.div(x3d, x3d[..., 2:])
74
+ x2d = torch.matmul(K, x2d.transpose(-1, -2)).transpose(-1, -2)[..., :2]
75
+ return x2d
76
+
77
+
78
+ def compute_bbox_from_points(X, img_w, img_h, scaleFactor=1.2):
79
+ left = torch.clamp(X.min(1)[0][:, 0], min=0, max=img_w)
80
+ right = torch.clamp(X.max(1)[0][:, 0], min=0, max=img_w)
81
+ top = torch.clamp(X.min(1)[0][:, 1], min=0, max=img_h)
82
+ bottom = torch.clamp(X.max(1)[0][:, 1], min=0, max=img_h)
83
+
84
+ cx = (left + right) / 2
85
+ cy = (top + bottom) / 2
86
+ width = (right - left)
87
+ height = (bottom - top)
88
+
89
+ new_left = torch.clamp(cx - width/2 * scaleFactor, min=0, max=img_w-1)
90
+ new_right = torch.clamp(cx + width/2 * scaleFactor, min=1, max=img_w)
91
+ new_top = torch.clamp(cy - height / 2 * scaleFactor, min=0, max=img_h-1)
92
+ new_bottom = torch.clamp(cy + height / 2 * scaleFactor, min=1, max=img_h)
93
+
94
+ bbox = torch.stack((new_left.detach(), new_top.detach(),
95
+ new_right.detach(), new_bottom.detach())).int().float().T
96
+
97
+ return bbox
98
+
99
+
100
+ class Renderer():
101
+ def __init__(self, width, height, focal_length, device, faces=None):
102
+
103
+ self.width = width
104
+ self.height = height
105
+ self.focal_length = focal_length
106
+
107
+ self.device = device
108
+ if faces is not None:
109
+ self.faces = torch.from_numpy(
110
+ (faces).astype('int')
111
+ ).unsqueeze(0).to(self.device)
112
+
113
+ self.initialize_camera_params()
114
+ self.lights = PointLights(device=device, location=[[0.0, 0.0, -10.0]])
115
+ self.create_renderer()
116
+
117
+ def create_renderer(self):
118
+ self.renderer = MeshRenderer(
119
+ rasterizer=MeshRasterizer(
120
+ raster_settings=RasterizationSettings(
121
+ image_size=self.image_sizes[0],
122
+ blur_radius=1e-5),
123
+ ),
124
+ shader=SoftPhongShader(
125
+ device=self.device,
126
+ lights=self.lights,
127
+ )
128
+ )
129
+
130
+ def create_camera(self, R=None, T=None):
131
+ if R is not None:
132
+ self.R = R.clone().view(1, 3, 3).to(self.device)
133
+ if T is not None:
134
+ self.T = T.clone().view(1, 3).to(self.device)
135
+
136
+ return PerspectiveCameras(
137
+ device=self.device,
138
+ R=self.R.mT,
139
+ T=self.T,
140
+ K=self.K_full,
141
+ image_size=self.image_sizes,
142
+ in_ndc=False)
143
+
144
+
145
+ def initialize_camera_params(self):
146
+ """Hard coding for camera parameters
147
+ TODO: Do some soft coding"""
148
+
149
+ # Extrinsics
150
+ self.R = torch.diag(
151
+ torch.tensor([1, 1, 1])
152
+ ).float().to(self.device).unsqueeze(0)
153
+
154
+ self.T = torch.tensor(
155
+ [0, 0, 0]
156
+ ).unsqueeze(0).float().to(self.device)
157
+
158
+ # Intrinsics
159
+ self.K = torch.tensor(
160
+ [[self.focal_length, 0, self.width/2],
161
+ [0, self.focal_length, self.height/2],
162
+ [0, 0, 1]]
163
+ ).unsqueeze(0).float().to(self.device)
164
+ self.bboxes = torch.tensor([[0, 0, self.width, self.height]]).float()
165
+ self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes)
166
+ self.cameras = self.create_camera()
167
+
168
+
169
+ def set_ground(self, length, center_x, center_z):
170
+ device = self.device
171
+ v, f, vc, fc = map(torch.from_numpy, checkerboard_geometry(length=length, c1=center_x, c2=center_z, up="y"))
172
+ v, f, vc = v.to(device), f.to(device), vc.to(device)
173
+ self.ground_geometry = [v, f, vc]
174
+
175
+
176
+ def update_bbox(self, x3d, scale=2.0, mask=None):
177
+ """ Update bbox of cameras from the given 3d points
178
+
179
+ x3d: input 3D keypoints (or vertices), (num_frames, num_points, 3)
180
+ """
181
+
182
+ if x3d.size(-1) != 3:
183
+ x2d = x3d.unsqueeze(0)
184
+ else:
185
+ x2d = perspective_projection(x3d.unsqueeze(0), self.K, self.R, self.T.reshape(1, 3, 1))
186
+
187
+ if mask is not None:
188
+ x2d = x2d[:, ~mask]
189
+
190
+ bbox = compute_bbox_from_points(x2d, self.width, self.height, scale)
191
+ self.bboxes = bbox
192
+
193
+ self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
194
+ self.cameras = self.create_camera()
195
+ self.create_renderer()
196
+
197
+ def reset_bbox(self,):
198
+ bbox = torch.zeros((1, 4)).float().to(self.device)
199
+ bbox[0, 2] = self.width
200
+ bbox[0, 3] = self.height
201
+ self.bboxes = bbox
202
+
203
+ self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
204
+ self.cameras = self.create_camera()
205
+ self.create_renderer()
206
+
207
+ def render_mesh(self, vertices, background, colors=[0.8, 0.8, 0.8]):
208
+ self.update_bbox(vertices[::50], scale=1.2)
209
+ vertices = vertices.unsqueeze(0)
210
+
211
+ if colors[0] > 1: colors = [c / 255. for c in colors]
212
+ verts_features = torch.tensor(colors).reshape(1, 1, 3).to(device=vertices.device, dtype=vertices.dtype)
213
+ verts_features = verts_features.repeat(1, vertices.shape[1], 1)
214
+ textures = TexturesVertex(verts_features=verts_features)
215
+
216
+ mesh = Meshes(verts=vertices,
217
+ faces=self.faces,
218
+ textures=textures,)
219
+
220
+ materials = Materials(
221
+ device=self.device,
222
+ specular_color=(colors, ),
223
+ shininess=0
224
+ )
225
+
226
+ results = torch.flip(
227
+ self.renderer(mesh, materials=materials, cameras=self.cameras, lights=self.lights),
228
+ [1, 2]
229
+ )
230
+ image = results[0, ..., :3] * 255
231
+ mask = results[0, ..., -1] > 1e-3
232
+
233
+ image = overlay_image_onto_background(image, mask, self.bboxes, background.copy())
234
+ self.reset_bbox()
235
+ return image
236
+
237
+
238
+ def render_with_ground(self, verts, faces, colors, cameras, lights):
239
+ """
240
+ :param verts (B, V, 3)
241
+ :param faces (F, 3)
242
+ :param colors (B, 3)
243
+ """
244
+
245
+ # (B, V, 3), (B, F, 3), (B, V, 3)
246
+ verts, faces, colors = prep_shared_geometry(verts, faces, colors)
247
+ # (V, 3), (F, 3), (V, 3)
248
+ gv, gf, gc = self.ground_geometry
249
+ verts = list(torch.unbind(verts, dim=0)) + [gv]
250
+ faces = list(torch.unbind(faces, dim=0)) + [gf]
251
+ colors = list(torch.unbind(colors, dim=0)) + [gc[..., :3]]
252
+ mesh = create_meshes(verts, faces, colors)
253
+
254
+ materials = Materials(
255
+ device=self.device,
256
+ shininess=0
257
+ )
258
+
259
+ results = self.renderer(mesh, cameras=cameras, lights=lights, materials=materials)
260
+ image = (results[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
261
+
262
+ return image
263
+
264
+
265
+ def prep_shared_geometry(verts, faces, colors):
266
+ """
267
+ :param verts (B, V, 3)
268
+ :param faces (F, 3)
269
+ :param colors (B, 4)
270
+ """
271
+ B, V, _ = verts.shape
272
+ F, _ = faces.shape
273
+ colors = colors.unsqueeze(1).expand(B, V, -1)[..., :3]
274
+ faces = faces.unsqueeze(0).expand(B, F, -1)
275
+ return verts, faces, colors
276
+
277
+
278
+ def create_meshes(verts, faces, colors):
279
+ """
280
+ :param verts (B, V, 3)
281
+ :param faces (B, F, 3)
282
+ :param colors (B, V, 3)
283
+ """
284
+ textures = TexturesVertex(verts_features=colors)
285
+ meshes = Meshes(verts=verts, faces=faces, textures=textures)
286
+ return join_meshes_as_scene(meshes)
287
+
288
+
289
+ def get_global_cameras(verts, device, distance=5, position=(-5.0, 5.0, 0.0)):
290
+ positions = torch.tensor([position]).repeat(len(verts), 1)
291
+ targets = verts.mean(1)
292
+
293
+ directions = targets - positions
294
+ directions = directions / torch.norm(directions, dim=-1).unsqueeze(-1) * distance
295
+ positions = targets - directions
296
+
297
+ rotation = look_at_rotation(positions, targets, ).mT
298
+ translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1)
299
+
300
+ lights = PointLights(device=device, location=[position])
301
+ return rotation, translation, lights
302
+
303
+
304
+ def _get_global_cameras(verts, device, min_distance=3, chunk_size=100):
305
+
306
+ # split into smaller chunks to visualize
307
+ start_idxs = list(range(0, len(verts), chunk_size))
308
+ end_idxs = [min(start_idx + chunk_size, len(verts)) for start_idx in start_idxs]
309
+
310
+ Rs, Ts = [], []
311
+ for start_idx, end_idx in zip(start_idxs, end_idxs):
312
+ vert = verts[start_idx:end_idx].clone()
313
+ import pdb; pdb.set_trace()
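Editor's note: a minimal sketch of driving the Renderer above on its own, not part of the commit. The face-array path and the vertex tensor are placeholders, and the focal-length heuristic mirrors the one used in run_vis.py.

# --- usage sketch (illustrative only) ---
import numpy as np
import torch
from lib.vis.renderer import Renderer

device = 'cuda'
width, height = 1280, 720
focal_length = (width ** 2 + height ** 2) ** 0.5        # same heuristic as run_vis.py
smpl_faces = np.load('smpl_faces.npy')                  # placeholder: any (F, 3) int face array
renderer = Renderer(width, height, focal_length, device, smpl_faces)

frame = np.zeros((height, width, 3), dtype=np.uint8)    # stand-in for a decoded video frame
verts = torch.rand(6890, 3, device=device) * 0.5        # stand-in for SMPL vertices (camera coords)
verts[:, 2] += 3.0                                      # keep the mesh in front of the camera
out = renderer.render_mesh(verts, frame)                # frame with the mesh overlaid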
lib/vis/run_vis.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import os.path as osp
3
+
4
+ import cv2
5
+ import torch
6
+ import imageio
7
+ import numpy as np
8
+ from progress.bar import Bar
9
+
10
+ from lib.vis.renderer import Renderer, get_global_cameras
11
+
12
+ def run_vis_on_demo(cfg, video, results, output_pth, smpl, vis_global=True):
13
+ # to torch tensor
14
+ tt = lambda x: torch.from_numpy(x).float().to(cfg.DEVICE)
15
+
16
+ cap = cv2.VideoCapture(video)
17
+ fps = cap.get(cv2.CAP_PROP_FPS)
18
+ length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
19
+ width, height = cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
20
+
21
+ # create renderer with cliff focal length estimation
22
+ focal_length = (width ** 2 + height ** 2) ** 0.5
23
+ renderer = Renderer(width, height, focal_length, cfg.DEVICE, smpl.faces)
24
+
25
+ if vis_global:
26
+ # setup global coordinate subject
27
+ # the current implementation only visualizes the subject that appears in the most frames
28
+ n_frames = {k: len(results[k]['frame_ids']) for k in results.keys()}
29
+ sid = max(n_frames, key=n_frames.get)
30
+ global_output = smpl.get_output(
31
+ body_pose=tt(results[sid]['pose_world'][:, 3:]),
32
+ global_orient=tt(results[sid]['pose_world'][:, :3]),
33
+ betas=tt(results[sid]['betas']),
34
+ transl=tt(results[sid]['trans_world']))
35
+ verts_glob = global_output.vertices.cpu()
36
+ verts_glob[..., 1] = verts_glob[..., 1] - verts_glob[..., 1].min()
37
+ cx, cz = (verts_glob.mean(1).max(0)[0] + verts_glob.mean(1).min(0)[0])[[0, 2]] / 2.0
38
+ sx, sz = (verts_glob.mean(1).max(0)[0] - verts_glob.mean(1).min(0)[0])[[0, 2]]
39
+ scale = max(sx.item(), sz.item()) * 1.5
40
+
41
+ # set default ground
42
+ renderer.set_ground(scale, cx.item(), cz.item())
43
+
44
+ # build global camera
45
+ global_R, global_T, global_lights = get_global_cameras(verts_glob, cfg.DEVICE)
46
+
47
+ # build default camera
48
+ default_R, default_T = torch.eye(3), torch.zeros(3)
49
+
50
+ writer = imageio.get_writer(
51
+ osp.join(output_pth, 'output.mp4'),
52
+ fps=fps, mode='I', format='FFMPEG', macro_block_size=1
53
+ )
54
+ bar = Bar('Rendering results ...', fill='#', max=length)
55
+
56
+ frame_i = 0
57
+ _global_R, _global_T = None, None
58
+ # run rendering
59
+ while (cap.isOpened()):
60
+ flag, org_img = cap.read()
61
+ if not flag: break
62
+ img = org_img[..., ::-1].copy()
63
+
64
+ # render onto the input video
65
+ renderer.create_camera(default_R, default_T)
66
+ for _id, val in results.items():
67
+ # render onto the image
68
+ frame_i2 = np.where(val['frame_ids'] == frame_i)[0]
69
+ if len(frame_i2) == 0: continue
70
+ frame_i2 = frame_i2[0]
71
+ img = renderer.render_mesh(torch.from_numpy(val['verts'][frame_i2]).to(cfg.DEVICE), img)
72
+
73
+ if vis_global:
74
+ # render the global coordinate
75
+ if frame_i in results[sid]['frame_ids']:
76
+ frame_i3 = np.where(results[sid]['frame_ids'] == frame_i)[0]
77
+ verts = verts_glob[[frame_i3]].to(cfg.DEVICE)
78
+ faces = renderer.faces.clone().squeeze(0)
79
+ colors = torch.ones((1, 4)).float().to(cfg.DEVICE); colors[..., :3] *= 0.9
80
+
81
+ if _global_R is None:
82
+ _global_R = global_R[frame_i3].clone(); _global_T = global_T[frame_i3].clone()
83
+ cameras = renderer.create_camera(global_R[frame_i3], global_T[frame_i3])
84
+ img_glob = renderer.render_with_ground(verts, faces, colors, cameras, global_lights)
85
+
86
+ try: img = np.concatenate((img, img_glob), axis=1)
87
+ except: img = np.concatenate((img, np.ones_like(img) * 255), axis=1)
88
+
89
+ writer.append_data(img)
90
+ bar.next()
91
+ frame_i += 1
92
+ writer.close()
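Editor's note: a sketch of how this entry point is typically invoked, not part of the commit. The video path is a placeholder, `cfg` (with cfg.DEVICE) and `smpl` (the repo's SMPL wrapper) are assumed to be built by the demo script, and loading the results with joblib is an assumption about how output/demo/test19/wham_output.pkl was serialized.

# --- usage sketch (illustrative only) ---
import joblib
from lib.vis.run_vis import run_vis_on_demo

output_pth = 'output/demo/test19'
results = joblib.load(f'{output_pth}/wham_output.pkl')  # assumed serialization; per-track dicts with
                                                        # 'frame_ids', 'verts', 'pose_world', 'trans_world', 'betas'
run_vis_on_demo(cfg, 'input_video.mp4', results, output_pth, smpl, vis_global=True)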
lib/vis/tools.py ADDED
@@ -0,0 +1,822 @@
1
+ import os
+ import math # used by imshow_keypoints but missing from the original imports
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+
7
+
8
+ def read_image(path, scale=1):
9
+ im = Image.open(path)
10
+ if scale == 1:
11
+ return np.array(im)
12
+ W, H = im.size
13
+ w, h = int(scale * W), int(scale * H)
14
+ return np.array(im.resize((w, h), Image.LANCZOS)) # Image.ANTIALIAS was removed in Pillow 10
15
+
16
+
17
+ def transform_torch3d(T_c2w):
18
+ """
19
+ :param T_c2w (*, 4, 4)
20
+ returns (*, 3, 3), (*, 3)
21
+ """
22
+ R1 = torch.tensor(
23
+ [[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 1.0],], device=T_c2w.device,
24
+ )
25
+ R2 = torch.tensor(
26
+ [[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0],], device=T_c2w.device,
27
+ )
28
+ cam_R, cam_t = T_c2w[..., :3, :3], T_c2w[..., :3, 3]
29
+ cam_R = torch.einsum("...ij,jk->...ik", cam_R, R1)
30
+ cam_t = torch.einsum("ij,...j->...i", R2, cam_t)
31
+ return cam_R, cam_t
32
+
33
+
34
+ def transform_pyrender(T_c2w):
35
+ """
36
+ :param T_c2w (*, 4, 4)
37
+ """
38
+ T_vis = torch.tensor(
39
+ [
40
+ [1.0, 0.0, 0.0, 0.0],
41
+ [0.0, -1.0, 0.0, 0.0],
42
+ [0.0, 0.0, -1.0, 0.0],
43
+ [0.0, 0.0, 0.0, 1.0],
44
+ ],
45
+ device=T_c2w.device,
46
+ )
47
+ return torch.einsum(
48
+ "...ij,jk->...ik", torch.einsum("ij,...jk->...ik", T_vis, T_c2w), T_vis
49
+ )
50
+
51
+
52
+ def smpl_to_geometry(verts, faces, vis_mask=None, track_ids=None):
53
+ """
54
+ :param verts (B, T, V, 3)
55
+ :param faces (F, 3)
56
+ :param vis_mask (optional) (B, T) visibility of each person
57
+ :param track_ids (optional) (B,)
58
+ returns list of T verts (B, V, 3), faces (F, 3), colors (B, 3)
59
+ where B is different depending on the visibility of the people
60
+ """
61
+ B, T = verts.shape[:2]
62
+ device = verts.device
63
+
64
+ # (B, 3)
65
+ colors = (
66
+ track_to_colors(track_ids)
67
+ if track_ids is not None
68
+ else torch.ones(B, 3, device=device) * 0.5
69
+ )
70
+
71
+ # list T (B, V, 3), T (B, 3), T (F, 3)
72
+ return filter_visible_meshes(verts, colors, faces, vis_mask)
73
+
74
+
75
+ def filter_visible_meshes(verts, colors, faces, vis_mask=None, vis_opacity=False):
76
+ """
77
+ :param verts (B, T, V, 3)
78
+ :param colors (B, 3)
79
+ :param faces (F, 3)
80
+ :param vis_mask (optional tensor, default None) (B, T) ternary mask
81
+ -1 if not in frame
82
+ 0 if temporarily occluded
83
+ 1 if visible
84
+ :param vis_opacity (optional bool, default False)
85
+ if True, make occluded people alpha=0.5, otherwise alpha=1
86
+ returns a list of T lists verts (Bi, V, 3), colors (Bi, 4), faces (F, 3)
87
+ """
88
+ # import ipdb; ipdb.set_trace()
89
+ B, T = verts.shape[:2]
90
+ faces = [faces for t in range(T)]
91
+ if vis_mask is None:
92
+ verts = [verts[:, t] for t in range(T)]
93
+ colors = [colors for t in range(T)]
94
+ return verts, colors, faces
95
+
96
+ # render occluded and visible, but not removed
97
+ vis_mask = vis_mask >= 0
98
+ if vis_opacity:
99
+ alpha = 0.5 * (vis_mask[..., None] + 1)
100
+ else:
101
+ alpha = (vis_mask[..., None] >= 0).float()
102
+ vert_list = [verts[vis_mask[:, t], t] for t in range(T)]
103
+ colors = [
104
+ torch.cat([colors[vis_mask[:, t]], alpha[vis_mask[:, t], t]], dim=-1)
105
+ for t in range(T)
106
+ ]
107
+ bounds = get_bboxes(verts, vis_mask)
108
+ return vert_list, colors, faces, bounds
109
+
110
+
111
+ def get_bboxes(verts, vis_mask):
112
+ """
113
+ return bb_min, bb_max, and mean for each track (B, 3) over entire trajectory
114
+ :param verts (B, T, V, 3)
115
+ :param vis_mask (B, T)
116
+ """
117
+ B, T, *_ = verts.shape
118
+ bb_min, bb_max, mean = [], [], []
119
+ for b in range(B):
120
+ v = verts[b, vis_mask[b, :T]] # (Tb, V, 3)
121
+ bb_min.append(v.amin(dim=(0, 1)))
122
+ bb_max.append(v.amax(dim=(0, 1)))
123
+ mean.append(v.mean(dim=(0, 1)))
124
+ bb_min = torch.stack(bb_min, dim=0)
125
+ bb_max = torch.stack(bb_max, dim=0)
126
+ mean = torch.stack(mean, dim=0)
127
+ # point to a track that's long and close to the camera
128
+ zs = mean[:, 2]
129
+ counts = vis_mask[:, :T].sum(dim=-1) # (B,)
130
+ mask = counts < 0.8 * T
131
+ zs[mask] = torch.inf
132
+ sel = torch.argmin(zs)
133
+ return bb_min.amin(dim=0), bb_max.amax(dim=0), mean[sel]
134
+
135
+
136
+ def track_to_colors(track_ids):
137
+ """
138
+ :param track_ids (B)
139
+ """
140
+ color_map = torch.from_numpy(get_colors()).to(track_ids)
141
+ return color_map[track_ids] / 255 # (B, 3)
142
+
143
+
144
+ def get_colors():
145
+ # color_file = os.path.abspath(os.path.join(__file__, "../colors_phalp.txt"))
146
+ color_file = os.path.abspath(os.path.join(__file__, "../colors.txt"))
147
+ RGB_tuples = np.vstack(
148
+ [
149
+ np.loadtxt(color_file, skiprows=0),
150
+ # np.loadtxt(color_file, skiprows=1),
151
+ np.random.uniform(0, 255, size=(10000, 3)),
152
+ [[0, 0, 0]],
153
+ ]
154
+ )
155
+ b = np.where(RGB_tuples == 0)
156
+ RGB_tuples[b] = 1
157
+ return RGB_tuples.astype(np.float32)
158
+
159
+
160
+ def checkerboard_geometry(
161
+ length=12.0,
162
+ color0=[0.8, 0.9, 0.9],
163
+ color1=[0.6, 0.7, 0.7],
164
+ tile_width=0.5,
165
+ alpha=1.0,
166
+ up="y",
167
+ c1=0.0,
168
+ c2=0.0,
169
+ ):
170
+ assert up == "y" or up == "z"
171
+ color0 = np.array(color0 + [alpha])
172
+ color1 = np.array(color1 + [alpha])
173
+ radius = length / 2.0
174
+ num_rows = num_cols = max(2, int(length / tile_width))
175
+ vertices = []
176
+ vert_colors = []
177
+ faces = []
178
+ face_colors = []
179
+ for i in range(num_rows):
180
+ for j in range(num_cols):
181
+ u0, v0 = j * tile_width - radius, i * tile_width - radius
182
+ us = np.array([u0, u0, u0 + tile_width, u0 + tile_width])
183
+ vs = np.array([v0, v0 + tile_width, v0 + tile_width, v0])
184
+ zs = np.zeros(4)
185
+ if up == "y":
186
+ cur_verts = np.stack([us, zs, vs], axis=-1) # (4, 3)
187
+ cur_verts[:, 0] += c1
188
+ cur_verts[:, 2] += c2
189
+ else:
190
+ cur_verts = np.stack([us, vs, zs], axis=-1) # (4, 3)
191
+ cur_verts[:, 0] += c1
192
+ cur_verts[:, 1] += c2
193
+
194
+ cur_faces = np.array(
195
+ [[0, 1, 3], [1, 2, 3], [0, 3, 1], [1, 3, 2]], dtype=np.int64
196
+ )
197
+ cur_faces += 4 * (i * num_cols + j) # the number of previously added verts
198
+ use_color0 = (i % 2 == 0 and j % 2 == 0) or (i % 2 == 1 and j % 2 == 1)
199
+ cur_color = color0 if use_color0 else color1
200
+ cur_colors = np.array([cur_color, cur_color, cur_color, cur_color])
201
+
202
+ vertices.append(cur_verts)
203
+ faces.append(cur_faces)
204
+ vert_colors.append(cur_colors)
205
+ face_colors.append(cur_colors)
206
+
207
+ vertices = np.concatenate(vertices, axis=0).astype(np.float32)
208
+ vert_colors = np.concatenate(vert_colors, axis=0).astype(np.float32)
209
+ faces = np.concatenate(faces, axis=0).astype(np.float32)
210
+ face_colors = np.concatenate(face_colors, axis=0).astype(np.float32)
211
+
212
+ return vertices, faces, vert_colors, face_colors
213
+
214
+
215
+ def camera_marker_geometry(radius, height, up):
216
+ assert up == "y" or up == "z"
217
+ if up == "y":
218
+ vertices = np.array(
219
+ [
220
+ [-radius, -radius, 0],
221
+ [radius, -radius, 0],
222
+ [radius, radius, 0],
223
+ [-radius, radius, 0],
224
+ [0, 0, height],
225
+ ]
226
+ )
227
+ else:
228
+ vertices = np.array(
229
+ [
230
+ [-radius, 0, -radius],
231
+ [radius, 0, -radius],
232
+ [radius, 0, radius],
233
+ [-radius, 0, radius],
234
+ [0, -height, 0],
235
+ ]
236
+ )
237
+
238
+ faces = np.array(
239
+ [[0, 3, 1], [1, 3, 2], [0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4],]
240
+ )
241
+
242
+ face_colors = np.array(
243
+ [
244
+ [1.0, 1.0, 1.0, 1.0],
245
+ [1.0, 1.0, 1.0, 1.0],
246
+ [0.0, 1.0, 0.0, 1.0],
247
+ [1.0, 0.0, 0.0, 1.0],
248
+ [0.0, 1.0, 0.0, 1.0],
249
+ [1.0, 0.0, 0.0, 1.0],
250
+ ]
251
+ )
252
+ return vertices, faces, face_colors
253
+
254
+
255
+ def vis_keypoints(
256
+ keypts_list,
257
+ img_size,
258
+ radius=6,
259
+ thickness=3,
260
+ kpt_score_thr=0.3,
261
+ dataset="TopDownCocoDataset",
262
+ ):
263
+ """
264
+ Visualize keypoints
265
+ From ViTPose/mmpose/apis/inference.py
266
+ """
267
+ palette = np.array(
268
+ [
269
+ [255, 128, 0],
270
+ [255, 153, 51],
271
+ [255, 178, 102],
272
+ [230, 230, 0],
273
+ [255, 153, 255],
274
+ [153, 204, 255],
275
+ [255, 102, 255],
276
+ [255, 51, 255],
277
+ [102, 178, 255],
278
+ [51, 153, 255],
279
+ [255, 153, 153],
280
+ [255, 102, 102],
281
+ [255, 51, 51],
282
+ [153, 255, 153],
283
+ [102, 255, 102],
284
+ [51, 255, 51],
285
+ [0, 255, 0],
286
+ [0, 0, 255],
287
+ [255, 0, 0],
288
+ [255, 255, 255],
289
+ ]
290
+ )
291
+
292
+ if dataset in (
293
+ "TopDownCocoDataset",
294
+ "BottomUpCocoDataset",
295
+ "TopDownOCHumanDataset",
296
+ "AnimalMacaqueDataset",
297
+ ):
298
+ # show the results
299
+ skeleton = [
300
+ [15, 13],
301
+ [13, 11],
302
+ [16, 14],
303
+ [14, 12],
304
+ [11, 12],
305
+ [5, 11],
306
+ [6, 12],
307
+ [5, 6],
308
+ [5, 7],
309
+ [6, 8],
310
+ [7, 9],
311
+ [8, 10],
312
+ [1, 2],
313
+ [0, 1],
314
+ [0, 2],
315
+ [1, 3],
316
+ [2, 4],
317
+ [3, 5],
318
+ [4, 6],
319
+ ]
320
+
321
+ pose_link_color = palette[
322
+ [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]
323
+ ]
324
+ pose_kpt_color = palette[
325
+ [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]
326
+ ]
327
+
328
+ elif dataset == "TopDownCocoWholeBodyDataset":
329
+ # show the results
330
+ skeleton = [
331
+ [15, 13],
332
+ [13, 11],
333
+ [16, 14],
334
+ [14, 12],
335
+ [11, 12],
336
+ [5, 11],
337
+ [6, 12],
338
+ [5, 6],
339
+ [5, 7],
340
+ [6, 8],
341
+ [7, 9],
342
+ [8, 10],
343
+ [1, 2],
344
+ [0, 1],
345
+ [0, 2],
346
+ [1, 3],
347
+ [2, 4],
348
+ [3, 5],
349
+ [4, 6],
350
+ [15, 17],
351
+ [15, 18],
352
+ [15, 19],
353
+ [16, 20],
354
+ [16, 21],
355
+ [16, 22],
356
+ [91, 92],
357
+ [92, 93],
358
+ [93, 94],
359
+ [94, 95],
360
+ [91, 96],
361
+ [96, 97],
362
+ [97, 98],
363
+ [98, 99],
364
+ [91, 100],
365
+ [100, 101],
366
+ [101, 102],
367
+ [102, 103],
368
+ [91, 104],
369
+ [104, 105],
370
+ [105, 106],
371
+ [106, 107],
372
+ [91, 108],
373
+ [108, 109],
374
+ [109, 110],
375
+ [110, 111],
376
+ [112, 113],
377
+ [113, 114],
378
+ [114, 115],
379
+ [115, 116],
380
+ [112, 117],
381
+ [117, 118],
382
+ [118, 119],
383
+ [119, 120],
384
+ [112, 121],
385
+ [121, 122],
386
+ [122, 123],
387
+ [123, 124],
388
+ [112, 125],
389
+ [125, 126],
390
+ [126, 127],
391
+ [127, 128],
392
+ [112, 129],
393
+ [129, 130],
394
+ [130, 131],
395
+ [131, 132],
396
+ ]
397
+
398
+ pose_link_color = palette[
399
+ [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]
400
+ + [16, 16, 16, 16, 16, 16]
401
+ + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
402
+ + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
403
+ ]
404
+ pose_kpt_color = palette[
405
+ [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]
406
+ + [0, 0, 0, 0, 0, 0]
407
+ + [19] * (68 + 42)
408
+ ]
409
+
410
+ elif dataset == "TopDownAicDataset":
411
+ skeleton = [
412
+ [2, 1],
413
+ [1, 0],
414
+ [0, 13],
415
+ [13, 3],
416
+ [3, 4],
417
+ [4, 5],
418
+ [8, 7],
419
+ [7, 6],
420
+ [6, 9],
421
+ [9, 10],
422
+ [10, 11],
423
+ [12, 13],
424
+ [0, 6],
425
+ [3, 9],
426
+ ]
427
+
428
+ pose_link_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 0, 7, 7]]
429
+ pose_kpt_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 0, 0]]
430
+
431
+ elif dataset == "TopDownMpiiDataset":
432
+ skeleton = [
433
+ [0, 1],
434
+ [1, 2],
435
+ [2, 6],
436
+ [6, 3],
437
+ [3, 4],
438
+ [4, 5],
439
+ [6, 7],
440
+ [7, 8],
441
+ [8, 9],
442
+ [8, 12],
443
+ [12, 11],
444
+ [11, 10],
445
+ [8, 13],
446
+ [13, 14],
447
+ [14, 15],
448
+ ]
449
+
450
+ pose_link_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 9, 9, 9, 9, 9, 9]]
451
+ pose_kpt_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 0, 9, 9, 9, 9, 9, 9]]
452
+
453
+ elif dataset == "TopDownMpiiTrbDataset":
454
+ skeleton = [
455
+ [12, 13],
456
+ [13, 0],
457
+ [13, 1],
458
+ [0, 2],
459
+ [1, 3],
460
+ [2, 4],
461
+ [3, 5],
462
+ [0, 6],
463
+ [1, 7],
464
+ [6, 7],
465
+ [6, 8],
466
+ [7, 9],
467
+ [8, 10],
468
+ [9, 11],
469
+ [14, 15],
470
+ [16, 17],
471
+ [18, 19],
472
+ [20, 21],
473
+ [22, 23],
474
+ [24, 25],
475
+ [26, 27],
476
+ [28, 29],
477
+ [30, 31],
478
+ [32, 33],
479
+ [34, 35],
480
+ [36, 37],
481
+ [38, 39],
482
+ ]
483
+
484
+ pose_link_color = palette[[16] * 14 + [19] * 13]
485
+ pose_kpt_color = palette[[16] * 14 + [0] * 26]
486
+
487
+ elif dataset in ("OneHand10KDataset", "FreiHandDataset", "PanopticDataset"):
488
+ skeleton = [
489
+ [0, 1],
490
+ [1, 2],
491
+ [2, 3],
492
+ [3, 4],
493
+ [0, 5],
494
+ [5, 6],
495
+ [6, 7],
496
+ [7, 8],
497
+ [0, 9],
498
+ [9, 10],
499
+ [10, 11],
500
+ [11, 12],
501
+ [0, 13],
502
+ [13, 14],
503
+ [14, 15],
504
+ [15, 16],
505
+ [0, 17],
506
+ [17, 18],
507
+ [18, 19],
508
+ [19, 20],
509
+ ]
510
+
511
+ pose_link_color = palette[
512
+ [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
513
+ ]
514
+ pose_kpt_color = palette[
515
+ [0, 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
516
+ ]
517
+
518
+ elif dataset == "InterHand2DDataset":
519
+ skeleton = [
520
+ [0, 1],
521
+ [1, 2],
522
+ [2, 3],
523
+ [4, 5],
524
+ [5, 6],
525
+ [6, 7],
526
+ [8, 9],
527
+ [9, 10],
528
+ [10, 11],
529
+ [12, 13],
530
+ [13, 14],
531
+ [14, 15],
532
+ [16, 17],
533
+ [17, 18],
534
+ [18, 19],
535
+ [3, 20],
536
+ [7, 20],
537
+ [11, 20],
538
+ [15, 20],
539
+ [19, 20],
540
+ ]
541
+
542
+ pose_link_color = palette[
543
+ [0, 0, 0, 4, 4, 4, 8, 8, 8, 12, 12, 12, 16, 16, 16, 0, 4, 8, 12, 16]
544
+ ]
545
+ pose_kpt_color = palette[
546
+ [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16, 0]
547
+ ]
548
+
549
+ elif dataset == "Face300WDataset":
550
+ # show the results
551
+ skeleton = []
552
+
553
+ pose_link_color = palette[[]]
554
+ pose_kpt_color = palette[[19] * 68]
555
+ kpt_score_thr = 0
556
+
557
+ elif dataset == "FaceAFLWDataset":
558
+ # show the results
559
+ skeleton = []
560
+
561
+ pose_link_color = palette[[]]
562
+ pose_kpt_color = palette[[19] * 19]
563
+ kpt_score_thr = 0
564
+
565
+ elif dataset == "FaceCOFWDataset":
566
+ # show the results
567
+ skeleton = []
568
+
569
+ pose_link_color = palette[[]]
570
+ pose_kpt_color = palette[[19] * 29]
571
+ kpt_score_thr = 0
572
+
573
+ elif dataset == "FaceWFLWDataset":
574
+ # show the results
575
+ skeleton = []
576
+
577
+ pose_link_color = palette[[]]
578
+ pose_kpt_color = palette[[19] * 98]
579
+ kpt_score_thr = 0
580
+
581
+ elif dataset == "AnimalHorse10Dataset":
582
+ skeleton = [
583
+ [0, 1],
584
+ [1, 12],
585
+ [12, 16],
586
+ [16, 21],
587
+ [21, 17],
588
+ [17, 11],
589
+ [11, 10],
590
+ [10, 8],
591
+ [8, 9],
592
+ [9, 12],
593
+ [2, 3],
594
+ [3, 4],
595
+ [5, 6],
596
+ [6, 7],
597
+ [13, 14],
598
+ [14, 15],
599
+ [18, 19],
600
+ [19, 20],
601
+ ]
602
+
603
+ pose_link_color = palette[[4] * 10 + [6] * 2 + [6] * 2 + [7] * 2 + [7] * 2]
604
+ pose_kpt_color = palette[
605
+ [4, 4, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 4, 7, 7, 7, 4, 4, 7, 7, 7, 4]
606
+ ]
607
+
608
+ elif dataset == "AnimalFlyDataset":
609
+ skeleton = [
610
+ [1, 0],
611
+ [2, 0],
612
+ [3, 0],
613
+ [4, 3],
614
+ [5, 4],
615
+ [7, 6],
616
+ [8, 7],
617
+ [9, 8],
618
+ [11, 10],
619
+ [12, 11],
620
+ [13, 12],
621
+ [15, 14],
622
+ [16, 15],
623
+ [17, 16],
624
+ [19, 18],
625
+ [20, 19],
626
+ [21, 20],
627
+ [23, 22],
628
+ [24, 23],
629
+ [25, 24],
630
+ [27, 26],
631
+ [28, 27],
632
+ [29, 28],
633
+ [30, 3],
634
+ [31, 3],
635
+ ]
636
+
637
+ pose_link_color = palette[[0] * 25]
638
+ pose_kpt_color = palette[[0] * 32]
639
+
640
+ elif dataset == "AnimalLocustDataset":
641
+ skeleton = [
642
+ [1, 0],
643
+ [2, 1],
644
+ [3, 2],
645
+ [4, 3],
646
+ [6, 5],
647
+ [7, 6],
648
+ [9, 8],
649
+ [10, 9],
650
+ [11, 10],
651
+ [13, 12],
652
+ [14, 13],
653
+ [15, 14],
654
+ [17, 16],
655
+ [18, 17],
656
+ [19, 18],
657
+ [21, 20],
658
+ [22, 21],
659
+ [24, 23],
660
+ [25, 24],
661
+ [26, 25],
662
+ [28, 27],
663
+ [29, 28],
664
+ [30, 29],
665
+ [32, 31],
666
+ [33, 32],
667
+ [34, 33],
668
+ ]
669
+
670
+ pose_link_color = palette[[0] * 26]
671
+ pose_kpt_color = palette[[0] * 35]
672
+
673
+ elif dataset == "AnimalZebraDataset":
674
+ skeleton = [[1, 0], [2, 1], [3, 2], [4, 2], [5, 7], [6, 7], [7, 2], [8, 7]]
675
+
676
+ pose_link_color = palette[[0] * 8]
677
+ pose_kpt_color = palette[[0] * 9]
678
+
679
+ elif dataset in "AnimalPoseDataset":
680
+ skeleton = [
681
+ [0, 1],
682
+ [0, 2],
683
+ [1, 3],
684
+ [0, 4],
685
+ [1, 4],
686
+ [4, 5],
687
+ [5, 7],
688
+ [6, 7],
689
+ [5, 8],
690
+ [8, 12],
691
+ [12, 16],
692
+ [5, 9],
693
+ [9, 13],
694
+ [13, 17],
695
+ [6, 10],
696
+ [10, 14],
697
+ [14, 18],
698
+ [6, 11],
699
+ [11, 15],
700
+ [15, 19],
701
+ ]
702
+
703
+ pose_link_color = palette[[0] * 20]
704
+ pose_kpt_color = palette[[0] * 20]
705
+ else:
706
+ raise NotImplementedError()
707
+
708
+ img_w, img_h = img_size
709
+ img = 255 * np.ones((img_h, img_w, 3), dtype=np.uint8)
710
+ img = imshow_keypoints(
711
+ img,
712
+ keypts_list,
713
+ skeleton,
714
+ kpt_score_thr,
715
+ pose_kpt_color,
716
+ pose_link_color,
717
+ radius,
718
+ thickness,
719
+ )
720
+ alpha = 255 * (img != 255).any(axis=-1, keepdims=True).astype(np.uint8)
721
+ return np.concatenate([img, alpha], axis=-1)
722
+
723
+
724
+ def imshow_keypoints(
725
+ img,
726
+ pose_result,
727
+ skeleton=None,
728
+ kpt_score_thr=0.3,
729
+ pose_kpt_color=None,
730
+ pose_link_color=None,
731
+ radius=4,
732
+ thickness=1,
733
+ show_keypoint_weight=False,
734
+ ):
735
+ """Draw keypoints and links on an image.
736
+ From ViTPose/mmpose/core/visualization/image.py
737
+
738
+ Args:
739
+ img (H, W, 3) array
740
+ pose_result (list[kpts]): The poses to draw. Each element kpts is
741
+ a set of K keypoints as an Kx3 numpy.ndarray, where each
742
+ keypoint is represented as x, y, score.
743
+ kpt_score_thr (float, optional): Minimum score of keypoints
744
+ to be shown. Default: 0.3.
745
+ pose_kpt_color (np.array[Nx3]): Color of N keypoints. If None,
746
+ the keypoint will not be drawn.
747
+ pose_link_color (np.array[Mx3]): Color of M links. If None, the
748
+ links will not be drawn.
749
+ thickness (int): Thickness of lines.
750
+ show_keypoint_weight (bool): If True, opacity indicates keypoint score
751
+ """
752
+ img_h, img_w, _ = img.shape
753
+ idcs = [0, 16, 15, 18, 17, 5, 2, 6, 3, 7, 4, 12, 9, 13, 10, 14, 11]
754
+ for kpts in pose_result:
755
+ kpts = np.array(kpts, copy=False)[idcs]
756
+
757
+ # draw each point on image
758
+ if pose_kpt_color is not None:
759
+ assert len(pose_kpt_color) == len(kpts)
760
+ for kid, kpt in enumerate(kpts):
761
+ x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
762
+ if kpt_score > kpt_score_thr:
763
+ color = tuple(int(c) for c in pose_kpt_color[kid])
764
+ if show_keypoint_weight:
765
+ img_copy = img.copy()
766
+ cv2.circle(
767
+ img_copy, (int(x_coord), int(y_coord)), radius, color, -1
768
+ )
769
+ transparency = max(0, min(1, kpt_score))
770
+ cv2.addWeighted(
771
+ img_copy, transparency, img, 1 - transparency, 0, dst=img
772
+ )
773
+ else:
774
+ cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
775
+
776
+ # draw links
777
+ if skeleton is not None and pose_link_color is not None:
778
+ assert len(pose_link_color) == len(skeleton)
779
+ for sk_id, sk in enumerate(skeleton):
780
+ pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
781
+ pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
782
+ if (
783
+ pos1[0] > 0
784
+ and pos1[0] < img_w
785
+ and pos1[1] > 0
786
+ and pos1[1] < img_h
787
+ and pos2[0] > 0
788
+ and pos2[0] < img_w
789
+ and pos2[1] > 0
790
+ and pos2[1] < img_h
791
+ and kpts[sk[0], 2] > kpt_score_thr
792
+ and kpts[sk[1], 2] > kpt_score_thr
793
+ ):
794
+ color = tuple(int(c) for c in pose_link_color[sk_id])
795
+ if show_keypoint_weight:
796
+ img_copy = img.copy()
797
+ X = (pos1[0], pos2[0])
798
+ Y = (pos1[1], pos2[1])
799
+ mX = np.mean(X)
800
+ mY = np.mean(Y)
801
+ length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
802
+ angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
803
+ stickwidth = 2
804
+ polygon = cv2.ellipse2Poly(
805
+ (int(mX), int(mY)),
806
+ (int(length / 2), int(stickwidth)),
807
+ int(angle),
808
+ 0,
809
+ 360,
810
+ 1,
811
+ )
812
+ cv2.fillConvexPoly(img_copy, polygon, color)
813
+ transparency = max(
814
+ 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))
815
+ )
816
+ cv2.addWeighted(
817
+ img_copy, transparency, img, 1 - transparency, 0, dst=img
818
+ )
819
+ else:
820
+ cv2.line(img, pos1, pos2, color, thickness=thickness)
821
+
822
+ return img
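Editor's note: a short sketch exercising two helpers from tools.py, not part of the commit. The keypoint array is random placeholder data in an OpenPose-style layout, which imshow_keypoints reindexes to the 17 COCO joints via `idcs`.

# --- usage sketch (illustrative only) ---
import numpy as np
from lib.vis.tools import checkerboard_geometry, vis_keypoints

# ground plane used by Renderer.set_ground: vertices, faces, vertex colors, face colors
v, f, vc, fc = checkerboard_geometry(length=10.0, c1=0.0, c2=0.0, up='y')
print(v.shape, f.shape)                                  # 4 vertices and 4 triangles per tile

kpts = np.random.rand(19, 3) * np.array([640, 480, 1.0])  # random (x, y, score) detections
canvas = vis_keypoints([kpts], img_size=(640, 480))
print(canvas.shape)                                       # (480, 640, 4) RGBA image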
output/demo/test19/output.mp4 ADDED
Binary file (602 kB). View file
 
output/demo/test19/slam_results.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb6e0b47809fe94bdc26bc99318f9eb9beccb005ec81c407887a5bd7223b5b81
3
+ size 2353
output/demo/test19/tracking_results.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d1b3d6e23597e07daaa1b124cba63a1ea91d3909fe2c903c3c9b2b2819ce140
3
+ size 333898
output/demo/test19/wham_output.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1484cca6cf2774c3c0cbefa3d47ed7f2dd04db96b05ebdd14e5a11610a415b3e
3
+ size 3167067