Spaces:
Sleeping
Sleeping
ca5705cc9c8581d916aca37e6759c44f0b1e70429e49ce83e658a0517cd3d6fe
Browse files- LICENSE +21 -0
- README.md +120 -11
- lib/models/layers/__pycache__/utils.cpython-39.pyc +0 -0
- lib/models/layers/modules.py +262 -0
- lib/models/layers/utils.py +52 -0
- lib/models/preproc/__pycache__/detector.cpython-39.pyc +0 -0
- lib/models/preproc/__pycache__/extractor.cpython-39.pyc +0 -0
- lib/models/preproc/__pycache__/slam.cpython-39.pyc +0 -0
- lib/models/preproc/backbone/__pycache__/hmr2.cpython-39.pyc +0 -0
- lib/models/preproc/backbone/__pycache__/pose_transformer.cpython-39.pyc +0 -0
- lib/models/preproc/backbone/__pycache__/smpl_head.cpython-39.pyc +0 -0
- lib/models/preproc/backbone/__pycache__/t_cond_mlp.cpython-39.pyc +0 -0
- lib/models/preproc/backbone/__pycache__/utils.cpython-39.pyc +0 -0
- lib/models/preproc/backbone/__pycache__/vit.cpython-39.pyc +0 -0
- lib/models/preproc/backbone/hmr2.py +77 -0
- lib/models/preproc/backbone/pose_transformer.py +357 -0
- lib/models/preproc/backbone/smpl_head.py +128 -0
- lib/models/preproc/backbone/t_cond_mlp.py +198 -0
- lib/models/preproc/backbone/utils.py +115 -0
- lib/models/preproc/backbone/vit.py +348 -0
- lib/models/preproc/detector.py +146 -0
- lib/models/preproc/extractor.py +112 -0
- lib/models/preproc/slam.py +70 -0
- lib/models/smpl.py +264 -0
- lib/models/smplify/__init__.py +1 -0
- lib/models/smplify/__pycache__/__init__.cpython-39.pyc +0 -0
- lib/models/smplify/__pycache__/losses.cpython-39.pyc +0 -0
- lib/models/smplify/__pycache__/smplify.cpython-39.pyc +0 -0
- lib/models/smplify/losses.py +87 -0
- lib/models/smplify/smplify.py +83 -0
- lib/models/wham.py +210 -0
- lib/utils/__pycache__/data_utils.cpython-39.pyc +0 -0
- lib/utils/__pycache__/imutils.cpython-39.pyc +0 -0
- lib/utils/__pycache__/kp_utils.cpython-39.pyc +0 -0
- lib/utils/__pycache__/transforms.cpython-39.pyc +0 -0
- lib/utils/data_utils.py +113 -0
- lib/utils/imutils.py +363 -0
- lib/utils/kp_utils.py +761 -0
- lib/utils/transforms.py +828 -0
- lib/utils/utils.py +265 -0
- lib/vis/__pycache__/renderer.cpython-39.pyc +0 -0
- lib/vis/__pycache__/run_vis.cpython-39.pyc +0 -0
- lib/vis/__pycache__/tools.cpython-39.pyc +0 -0
- lib/vis/renderer.py +313 -0
- lib/vis/run_vis.py +92 -0
- lib/vis/tools.py +822 -0
- output/demo/test19/output.mp4 +0 -0
- output/demo/test19/slam_results.pth +3 -0
- output/demo/test19/tracking_results.pth +3 -0
- output/demo/test19/wham_output.pkl +3 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Soyong Shin
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,11 +1,120 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# WHAM: Reconstructing World-grounded Humans with Accurate 3D Motion
|
2 |
+
|
3 |
+
<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a> [](https://arxiv.org/abs/2312.07531) <a href="https://wham.is.tue.mpg.de/"><img alt="Project" src="https://img.shields.io/badge/-Project%20Page-lightgrey?logo=Google%20Chrome&color=informational&logoColor=white"></a> [](https://colab.research.google.com/drive/1ysUtGSwidTQIdBQRhq0hj63KbseFujkn?usp=sharing)
|
4 |
+
[](https://paperswithcode.com/sota/3d-human-pose-estimation-on-3dpw?p=wham-reconstructing-world-grounded-humans) [](https://paperswithcode.com/sota/3d-human-pose-estimation-on-emdb?p=wham-reconstructing-world-grounded-humans)
|
5 |
+
|
6 |
+
|
7 |
+
https://github.com/yohanshin/WHAM/assets/46889727/da4602b4-0597-4e64-8da4-ab06931b23ee
|
8 |
+
|
9 |
+
|
10 |
+
## Introduction
|
11 |
+
This repository is the official [Pytorch](https://pytorch.org/) implementation of [WHAM: Reconstructing World-grounded Humans with Accurate 3D Motion](https://arxiv.org/abs/2312.07531). For more information, please visit our [project page](https://wham.is.tue.mpg.de/).
|
12 |
+
|
13 |
+
|
14 |
+
## Installation
|
15 |
+
Please see [Installation](docs/INSTALL.md) for details.
|
16 |
+
|
17 |
+
|
18 |
+
## Quick Demo
|
19 |
+
|
20 |
+
### [<img src="https://i.imgur.com/QCojoJk.png" width="30"> Google Colab for WHAM demo is now available](https://colab.research.google.com/drive/1ysUtGSwidTQIdBQRhq0hj63KbseFujkn?usp=sharing)
|
21 |
+
|
22 |
+
### Registration
|
23 |
+
|
24 |
+
To download SMPL body models (Neutral, Female, and Male), you need to register for [SMPL](https://smpl.is.tue.mpg.de/) and [SMPLify](https://smplify.is.tue.mpg.de/). The username and password for both homepages will be used while fetching the demo data.
|
25 |
+
|
26 |
+
Next, run the following script to fetch demo data. This script will download all the required dependencies including trained models and demo videos.
|
27 |
+
|
28 |
+
```bash
|
29 |
+
bash fetch_demo_data.sh
|
30 |
+
```
|
31 |
+
|
32 |
+
You can try with one examplar video:
|
33 |
+
```
|
34 |
+
python demo.py --video examples/IMG_9732.mov --visualize
|
35 |
+
```
|
36 |
+
|
37 |
+
We assume camera focal length following [CLIFF](https://github.com/haofanwang/CLIFF). You can specify known camera intrinsics [fx fy cx cy] for SLAM as the demo example below:
|
38 |
+
```
|
39 |
+
python demo.py --video examples/drone_video.mp4 --calib examples/drone_calib.txt --visualize
|
40 |
+
```
|
41 |
+
|
42 |
+
You can skip SLAM if you only want to get camera-coordinate motion. You can run as:
|
43 |
+
```
|
44 |
+
python demo.py --video examples/IMG_9732.mov --visualize --estimate_local_only
|
45 |
+
```
|
46 |
+
|
47 |
+
You can further refine the results of WHAM using Temporal SMPLify as a post processing. This will allow better 2D alignment as well as 3D accuracy. All you need to do is add `--run_smplify` flag when running demo.
|
48 |
+
|
49 |
+
## Docker
|
50 |
+
|
51 |
+
Please refer to [Docker](docs/DOCKER.md) for details.
|
52 |
+
|
53 |
+
## Python API
|
54 |
+
|
55 |
+
Please refer to [API](docs/API.md) for details.
|
56 |
+
|
57 |
+
## Dataset
|
58 |
+
Please see [Dataset](docs/DATASET.md) for details.
|
59 |
+
|
60 |
+
## Evaluation
|
61 |
+
```bash
|
62 |
+
# Evaluate on 3DPW dataset
|
63 |
+
python -m lib.eval.evaluate_3dpw --cfg configs/yamls/demo.yaml TRAIN.CHECKPOINT checkpoints/wham_vit_w_3dpw.pth.tar
|
64 |
+
|
65 |
+
# Evaluate on RICH dataset
|
66 |
+
python -m lib.eval.evaluate_rich --cfg configs/yamls/demo.yaml TRAIN.CHECKPOINT checkpoints/wham_vit_w_3dpw.pth.tar
|
67 |
+
|
68 |
+
# Evaluate on EMDB dataset (also computes W-MPJPE and WA-MPJPE)
|
69 |
+
python -m lib.eval.evaluate_emdb --cfg configs/yamls/demo.yaml --eval-split 1 TRAIN.CHECKPOINT checkpoints/wham_vit_w_3dpw.pth.tar # EMDB 1
|
70 |
+
|
71 |
+
python -m lib.eval.evaluate_emdb --cfg configs/yamls/demo.yaml --eval-split 2 TRAIN.CHECKPOINT checkpoints/wham_vit_w_3dpw.pth.tar # EMDB 2
|
72 |
+
```
|
73 |
+
|
74 |
+
## Training
|
75 |
+
WHAM training involves into two different stages; (1) 2D to SMPL lifting through AMASS dataset and (2) finetuning with feature integration using the video datasets. Please see [Dataset](docs/DATASET.md) for preprocessing the training datasets.
|
76 |
+
|
77 |
+
### Stage 1.
|
78 |
+
```bash
|
79 |
+
python train.py --cfg configs/yamls/stage1.yaml
|
80 |
+
```
|
81 |
+
|
82 |
+
### Stage 2.
|
83 |
+
Training stage 2 requires pretrained results from the stage 1. You can use your pretrained results, or download the weight from [Google Drive](https://drive.google.com/file/d/1Erjkho7O0bnZFawarntICRUCroaKabRE/view?usp=sharing) save as `checkpoints/wham_stage1.tar.pth`.
|
84 |
+
```bash
|
85 |
+
python train.py --cfg configs/yamls/stage2.yaml TRAIN.CHECKPOINT <PATH-TO-STAGE1-RESULTS>
|
86 |
+
```
|
87 |
+
|
88 |
+
### Train with BEDLAM
|
89 |
+
TBD
|
90 |
+
|
91 |
+
## Acknowledgement
|
92 |
+
We would like to sincerely appreciate Hongwei Yi and Silvia Zuffi for the discussion and proofreading. Part of this work was done when Soyong Shin was an intern at the Max Planck Institute for Intelligence System.
|
93 |
+
|
94 |
+
The base implementation is largely borrowed from [VIBE](https://github.com/mkocabas/VIBE) and [TCMR](https://github.com/hongsukchoi/TCMR_RELEASE). We use [ViTPose](https://github.com/ViTAE-Transformer/ViTPose) for 2D keypoints detection and [DPVO](https://github.com/princeton-vl/DPVO), [DROID-SLAM](https://github.com/princeton-vl/DROID-SLAM) for extracting camera motion. Please visit their official websites for more details.
|
95 |
+
|
96 |
+
## TODO
|
97 |
+
|
98 |
+
- [ ] Data preprocessing
|
99 |
+
|
100 |
+
- [x] Training implementation
|
101 |
+
|
102 |
+
- [x] Colab demo release
|
103 |
+
|
104 |
+
- [x] Demo for custom videos
|
105 |
+
|
106 |
+
## Citation
|
107 |
+
```
|
108 |
+
@InProceedings{shin2023wham,
|
109 |
+
title={WHAM: Reconstructing World-grounded Humans with Accurate 3D Motion},
|
110 |
+
author={Shin, Soyong and Kim, Juyong and Halilaj, Eni and Black, Michael J.},
|
111 |
+
booktitle={Computer Vision and Pattern Recognition (CVPR)},
|
112 |
+
year={2024}
|
113 |
+
}
|
114 |
+
```
|
115 |
+
|
116 |
+
## License
|
117 |
+
Please see [License](./LICENSE) for details.
|
118 |
+
|
119 |
+
## Contact
|
120 |
+
Please contact [email protected] for any questions related to this work.
|
lib/models/layers/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (2.08 kB). View file
|
|
lib/models/layers/modules.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import print_function
|
3 |
+
from __future__ import division
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from torch import nn
|
8 |
+
from configs import constants as _C
|
9 |
+
from .utils import rollout_global_motion
|
10 |
+
from lib.utils.transforms import axis_angle_to_matrix
|
11 |
+
|
12 |
+
|
13 |
+
class Regressor(nn.Module):
|
14 |
+
def __init__(self, in_dim, hid_dim, out_dims, init_dim, layer='LSTM', n_layers=2, n_iters=1):
|
15 |
+
super().__init__()
|
16 |
+
self.n_outs = len(out_dims)
|
17 |
+
|
18 |
+
self.rnn = getattr(nn, layer.upper())(
|
19 |
+
in_dim + init_dim, hid_dim, n_layers,
|
20 |
+
bidirectional=False, batch_first=True, dropout=0.3)
|
21 |
+
|
22 |
+
for i, out_dim in enumerate(out_dims):
|
23 |
+
setattr(self, 'declayer%d'%i, nn.Linear(hid_dim, out_dim))
|
24 |
+
nn.init.xavier_uniform_(getattr(self, 'declayer%d'%i).weight, gain=0.01)
|
25 |
+
|
26 |
+
def forward(self, x, inits, h0):
|
27 |
+
xc = torch.cat([x, *inits], dim=-1)
|
28 |
+
xc, h0 = self.rnn(xc, h0)
|
29 |
+
|
30 |
+
preds = []
|
31 |
+
for j in range(self.n_outs):
|
32 |
+
out = getattr(self, 'declayer%d'%j)(xc)
|
33 |
+
preds.append(out)
|
34 |
+
|
35 |
+
return preds, xc, h0
|
36 |
+
|
37 |
+
|
38 |
+
class NeuralInitialization(nn.Module):
|
39 |
+
def __init__(self, in_dim, hid_dim, layer, n_layers):
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
out_dim = hid_dim
|
43 |
+
self.n_layers = n_layers
|
44 |
+
self.num_inits = int(layer.upper() == 'LSTM') + 1
|
45 |
+
out_dim *= self.num_inits * n_layers
|
46 |
+
|
47 |
+
self.linear1 = nn.Linear(in_dim, hid_dim)
|
48 |
+
self.linear2 = nn.Linear(hid_dim, hid_dim * self.n_layers)
|
49 |
+
self.linear3 = nn.Linear(hid_dim * self.n_layers, out_dim)
|
50 |
+
self.relu1 = nn.ReLU()
|
51 |
+
self.relu2 = nn.ReLU()
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
b = x.shape[0]
|
55 |
+
|
56 |
+
out = self.linear3(self.relu2(self.linear2(self.relu1(self.linear1(x)))))
|
57 |
+
out = out.view(b, self.num_inits, self.n_layers, -1).permute(1, 2, 0, 3).contiguous()
|
58 |
+
|
59 |
+
if self.num_inits == 2:
|
60 |
+
return tuple([_ for _ in out])
|
61 |
+
return out[0]
|
62 |
+
|
63 |
+
|
64 |
+
class Integrator(nn.Module):
|
65 |
+
def __init__(self, in_channel, out_channel, hid_channel=1024):
|
66 |
+
super().__init__()
|
67 |
+
|
68 |
+
self.layer1 = nn.Linear(in_channel, hid_channel)
|
69 |
+
self.relu1 = nn.ReLU()
|
70 |
+
self.dr1 = nn.Dropout(0.1)
|
71 |
+
|
72 |
+
self.layer2 = nn.Linear(hid_channel, hid_channel)
|
73 |
+
self.relu2 = nn.ReLU()
|
74 |
+
self.dr2 = nn.Dropout(0.1)
|
75 |
+
|
76 |
+
self.layer3 = nn.Linear(hid_channel, out_channel)
|
77 |
+
|
78 |
+
|
79 |
+
def forward(self, x, feat):
|
80 |
+
res = x
|
81 |
+
mask = (feat != 0).all(dim=-1).all(dim=-1)
|
82 |
+
|
83 |
+
out = torch.cat((x, feat), dim=-1)
|
84 |
+
out = self.layer1(out)
|
85 |
+
out = self.relu1(out)
|
86 |
+
out = self.dr1(out)
|
87 |
+
|
88 |
+
out = self.layer2(out)
|
89 |
+
out = self.relu2(out)
|
90 |
+
out = self.dr2(out)
|
91 |
+
|
92 |
+
out = self.layer3(out)
|
93 |
+
out[mask] = out[mask] + res[mask]
|
94 |
+
|
95 |
+
return out
|
96 |
+
|
97 |
+
|
98 |
+
class MotionEncoder(nn.Module):
|
99 |
+
def __init__(self,
|
100 |
+
in_dim,
|
101 |
+
d_embed,
|
102 |
+
pose_dr,
|
103 |
+
rnn_type,
|
104 |
+
n_layers,
|
105 |
+
n_joints):
|
106 |
+
super().__init__()
|
107 |
+
|
108 |
+
self.n_joints = n_joints
|
109 |
+
|
110 |
+
self.embed_layer = nn.Linear(in_dim, d_embed)
|
111 |
+
self.pos_drop = nn.Dropout(pose_dr)
|
112 |
+
|
113 |
+
# Keypoints initializer
|
114 |
+
self.neural_init = NeuralInitialization(n_joints * 3 + in_dim, d_embed, rnn_type, n_layers)
|
115 |
+
|
116 |
+
# 3d keypoints regressor
|
117 |
+
self.regressor = Regressor(
|
118 |
+
d_embed, d_embed, [n_joints * 3], n_joints * 3, rnn_type, n_layers)
|
119 |
+
|
120 |
+
def forward(self, x, init):
|
121 |
+
""" Forward pass of motion encoder.
|
122 |
+
"""
|
123 |
+
|
124 |
+
self.b, self.f = x.shape[:2]
|
125 |
+
x = self.embed_layer(x.reshape(self.b, self.f, -1))
|
126 |
+
x = self.pos_drop(x)
|
127 |
+
|
128 |
+
h0 = self.neural_init(init)
|
129 |
+
pred_list = [init[..., :self.n_joints * 3]]
|
130 |
+
motion_context_list = []
|
131 |
+
|
132 |
+
for i in range(self.f):
|
133 |
+
(pred_kp3d, ), motion_context, h0 = self.regressor(x[:, [i]], pred_list[-1:], h0)
|
134 |
+
motion_context_list.append(motion_context)
|
135 |
+
pred_list.append(pred_kp3d)
|
136 |
+
|
137 |
+
pred_kp3d = torch.cat(pred_list[1:], dim=1).view(self.b, self.f, -1, 3)
|
138 |
+
motion_context = torch.cat(motion_context_list, dim=1)
|
139 |
+
|
140 |
+
# Merge 3D keypoints with motion context
|
141 |
+
motion_context = torch.cat((motion_context, pred_kp3d.reshape(self.b, self.f, -1)), dim=-1)
|
142 |
+
return pred_kp3d, motion_context
|
143 |
+
|
144 |
+
|
145 |
+
class TrajectoryDecoder(nn.Module):
|
146 |
+
def __init__(self,
|
147 |
+
d_embed,
|
148 |
+
rnn_type,
|
149 |
+
n_layers):
|
150 |
+
super().__init__()
|
151 |
+
|
152 |
+
# Trajectory regressor
|
153 |
+
self.regressor = Regressor(
|
154 |
+
d_embed, d_embed, [3, 6], 12, rnn_type, n_layers, )
|
155 |
+
|
156 |
+
def forward(self, x, root, cam_a, h0=None):
|
157 |
+
""" Forward pass of trajectory decoder.
|
158 |
+
"""
|
159 |
+
|
160 |
+
b, f = x.shape[:2]
|
161 |
+
pred_root_list, pred_vel_list = [root[:, :1]], []
|
162 |
+
|
163 |
+
for i in range(f):
|
164 |
+
# Global coordinate estimation
|
165 |
+
(pred_rootv, pred_rootr), _, h0 = self.regressor(
|
166 |
+
x[:, [i]], [pred_root_list[-1], cam_a[:, [i]]], h0)
|
167 |
+
|
168 |
+
pred_root_list.append(pred_rootr)
|
169 |
+
pred_vel_list.append(pred_rootv)
|
170 |
+
|
171 |
+
pred_root = torch.cat(pred_root_list, dim=1).view(b, f + 1, -1)
|
172 |
+
pred_vel = torch.cat(pred_vel_list, dim=1).view(b, f, -1)
|
173 |
+
|
174 |
+
return pred_root, pred_vel
|
175 |
+
|
176 |
+
|
177 |
+
class MotionDecoder(nn.Module):
|
178 |
+
def __init__(self,
|
179 |
+
d_embed,
|
180 |
+
rnn_type,
|
181 |
+
n_layers):
|
182 |
+
super().__init__()
|
183 |
+
|
184 |
+
self.n_pose = 24
|
185 |
+
|
186 |
+
# SMPL pose initialization
|
187 |
+
self.neural_init = NeuralInitialization(len(_C.BMODEL.MAIN_JOINTS) * 6, d_embed, rnn_type, n_layers)
|
188 |
+
|
189 |
+
# 3d keypoints regressor
|
190 |
+
self.regressor = Regressor(
|
191 |
+
d_embed, d_embed, [self.n_pose * 6, 10, 3, 4], self.n_pose * 6, rnn_type, n_layers)
|
192 |
+
|
193 |
+
def forward(self, x, init):
|
194 |
+
""" Forward pass of motion decoder.
|
195 |
+
"""
|
196 |
+
b, f = x.shape[:2]
|
197 |
+
|
198 |
+
h0 = self.neural_init(init[:, :, _C.BMODEL.MAIN_JOINTS].reshape(b, 1, -1))
|
199 |
+
|
200 |
+
# Recursive prediction of SMPL parameters
|
201 |
+
pred_pose_list = [init.reshape(b, 1, -1)]
|
202 |
+
pred_shape_list, pred_cam_list, pred_contact_list = [], [], []
|
203 |
+
|
204 |
+
for i in range(f):
|
205 |
+
# Camera coordinate estimation
|
206 |
+
(pred_pose, pred_shape, pred_cam, pred_contact), _, h0 = self.regressor(x[:, [i]], pred_pose_list[-1:], h0)
|
207 |
+
pred_pose_list.append(pred_pose)
|
208 |
+
pred_shape_list.append(pred_shape)
|
209 |
+
pred_cam_list.append(pred_cam)
|
210 |
+
pred_contact_list.append(pred_contact)
|
211 |
+
|
212 |
+
pred_pose = torch.cat(pred_pose_list[1:], dim=1).view(b, f, -1)
|
213 |
+
pred_shape = torch.cat(pred_shape_list, dim=1).view(b, f, -1)
|
214 |
+
pred_cam = torch.cat(pred_cam_list, dim=1).view(b, f, -1)
|
215 |
+
pred_contact = torch.cat(pred_contact_list, dim=1).view(b, f, -1)
|
216 |
+
|
217 |
+
return pred_pose, pred_shape, pred_cam, pred_contact
|
218 |
+
|
219 |
+
|
220 |
+
class TrajectoryRefiner(nn.Module):
|
221 |
+
def __init__(self,
|
222 |
+
d_embed,
|
223 |
+
d_hidden,
|
224 |
+
rnn_type,
|
225 |
+
n_layers):
|
226 |
+
super().__init__()
|
227 |
+
|
228 |
+
d_input = d_embed + 12
|
229 |
+
self.refiner = Regressor(
|
230 |
+
d_input, d_hidden, [6, 3], 9, rnn_type, n_layers)
|
231 |
+
|
232 |
+
def forward(self, context, pred_vel, output, cam_angvel, return_y_up):
|
233 |
+
b, f = context.shape[:2]
|
234 |
+
|
235 |
+
# Register values
|
236 |
+
pred_root = output['poses_root_r6d'].clone().detach()
|
237 |
+
feet = output['feet'].clone().detach()
|
238 |
+
contact = output['contact'].clone().detach()
|
239 |
+
|
240 |
+
feet_vel = torch.cat((torch.zeros_like(feet[:, :1]), feet[:, 1:] - feet[:, :-1]), dim=1) * 30 # Normalize to 30 times
|
241 |
+
feet = (feet_vel * contact.unsqueeze(-1)).reshape(b, f, -1) # Velocity input
|
242 |
+
inpt_feat = torch.cat([context, feet], dim=-1)
|
243 |
+
|
244 |
+
(delta_root, delta_vel), _, _ = self.refiner(inpt_feat, [pred_root[:, 1:], pred_vel], h0=None)
|
245 |
+
pred_root[:, 1:] = pred_root[:, 1:] + delta_root
|
246 |
+
pred_vel = pred_vel + delta_vel
|
247 |
+
|
248 |
+
# root_world, trans_world = rollout_global_motion(pred_root, pred_vel)
|
249 |
+
|
250 |
+
# if return_y_up:
|
251 |
+
# yup2ydown = axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float().to(root_world.device)
|
252 |
+
# root_world = yup2ydown.mT @ root_world
|
253 |
+
# trans_world = (yup2ydown.mT @ trans_world.unsqueeze(-1)).squeeze(-1)
|
254 |
+
|
255 |
+
output.update({
|
256 |
+
'poses_root_r6d_refined': pred_root,
|
257 |
+
'vel_root_refined': pred_vel,
|
258 |
+
# 'poses_root_world': root_world,
|
259 |
+
# 'trans_world': trans_world,
|
260 |
+
})
|
261 |
+
|
262 |
+
return output
|
lib/models/layers/utils.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from lib.utils import transforms
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
def rollout_global_motion(root_r, root_v, init_trans=None):
|
7 |
+
b, f = root_v.shape[:2]
|
8 |
+
root = transforms.rotation_6d_to_matrix(root_r[:])
|
9 |
+
vel_world = (root[:, :-1] @ root_v.unsqueeze(-1)).squeeze(-1)
|
10 |
+
trans = torch.cumsum(vel_world, dim=1)
|
11 |
+
|
12 |
+
if init_trans is not None: trans = trans + init_trans
|
13 |
+
return root[:, 1:], trans
|
14 |
+
|
15 |
+
def compute_camera_motion(output, root_c_d6d, root_w, trans, pred_cam):
|
16 |
+
root_c = transforms.rotation_6d_to_matrix(root_c_d6d) # Root orient in cam coord
|
17 |
+
cam_R = root_c @ root_w.mT
|
18 |
+
pelvis_cam = output.full_cam.view_as(pred_cam)
|
19 |
+
pelvis_world = (cam_R.mT @ pelvis_cam.unsqueeze(-1)).squeeze(-1)
|
20 |
+
cam_T_world = pelvis_world - trans
|
21 |
+
cam_T = (cam_R @ cam_T_world.unsqueeze(-1)).squeeze(-1)
|
22 |
+
|
23 |
+
return cam_R, cam_T
|
24 |
+
|
25 |
+
def compute_camera_pose(root_c_d6d, root_w):
|
26 |
+
root_c = transforms.rotation_6d_to_matrix(root_c_d6d) # Root orient in cam coord
|
27 |
+
cam_R = root_c @ root_w.mT
|
28 |
+
return cam_R
|
29 |
+
|
30 |
+
|
31 |
+
def reset_root_velocity(smpl, output, stationary, pred_ori, pred_vel, thr=0.7):
|
32 |
+
b, f = pred_vel.shape[:2]
|
33 |
+
|
34 |
+
stationary_mask = (stationary.clone().detach() > thr).unsqueeze(-1).float()
|
35 |
+
poses_root = transforms.rotation_6d_to_matrix(pred_ori.clone().detach())
|
36 |
+
vel_world = (poses_root[:, 1:] @ pred_vel.clone().detach().unsqueeze(-1)).squeeze(-1)
|
37 |
+
|
38 |
+
output = smpl.get_output(body_pose=output.body_pose.clone().detach(),
|
39 |
+
global_orient=poses_root[:, 1:].reshape(-1, 1, 3, 3),
|
40 |
+
betas=output.betas.clone().detach(),
|
41 |
+
pose2rot=False)
|
42 |
+
feet = output.feet.reshape(b, f, 4, 3)
|
43 |
+
feet_vel = feet[:, 1:] - feet[:, :-1] + vel_world[:, 1:].unsqueeze(-2)
|
44 |
+
feet_vel = torch.cat((torch.zeros_like(feet_vel[:, :1]), feet_vel), dim=1)
|
45 |
+
|
46 |
+
stationary_vel = feet_vel * stationary_mask
|
47 |
+
del_vel = stationary_vel.sum(dim=2) / ((stationary_vel != 0).sum(dim=2) + 1e-4)
|
48 |
+
vel_world_update = vel_world - del_vel
|
49 |
+
|
50 |
+
vel_root = (poses_root[:, 1:].mT @ vel_world_update.unsqueeze(-1)).squeeze(-1)
|
51 |
+
|
52 |
+
return vel_root
|
lib/models/preproc/__pycache__/detector.cpython-39.pyc
ADDED
Binary file (4.77 kB). View file
|
|
lib/models/preproc/__pycache__/extractor.cpython-39.pyc
ADDED
Binary file (3.48 kB). View file
|
|
lib/models/preproc/__pycache__/slam.cpython-39.pyc
ADDED
Binary file (2.6 kB). View file
|
|
lib/models/preproc/backbone/__pycache__/hmr2.cpython-39.pyc
ADDED
Binary file (2.43 kB). View file
|
|
lib/models/preproc/backbone/__pycache__/pose_transformer.cpython-39.pyc
ADDED
Binary file (10.8 kB). View file
|
|
lib/models/preproc/backbone/__pycache__/smpl_head.cpython-39.pyc
ADDED
Binary file (4.46 kB). View file
|
|
lib/models/preproc/backbone/__pycache__/t_cond_mlp.cpython-39.pyc
ADDED
Binary file (6.04 kB). View file
|
|
lib/models/preproc/backbone/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (3.59 kB). View file
|
|
lib/models/preproc/backbone/__pycache__/vit.cpython-39.pyc
ADDED
Binary file (11.2 kB). View file
|
|
lib/models/preproc/backbone/hmr2.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import einops
|
5 |
+
import torch.nn as nn
|
6 |
+
# import pytorch_lightning as pl
|
7 |
+
|
8 |
+
from yacs.config import CfgNode
|
9 |
+
from .vit import vit
|
10 |
+
from .smpl_head import SMPLTransformerDecoderHead
|
11 |
+
|
12 |
+
# class HMR2(pl.LightningModule):
|
13 |
+
class HMR2(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self):
|
16 |
+
"""
|
17 |
+
Setup HMR2 model
|
18 |
+
Args:
|
19 |
+
cfg (CfgNode): Config file as a yacs CfgNode
|
20 |
+
"""
|
21 |
+
super().__init__()
|
22 |
+
|
23 |
+
# Create backbone feature extractor
|
24 |
+
self.backbone = vit()
|
25 |
+
|
26 |
+
# Create SMPL head
|
27 |
+
self.smpl_head = SMPLTransformerDecoderHead()
|
28 |
+
|
29 |
+
|
30 |
+
def decode(self, x):
|
31 |
+
|
32 |
+
batch_size = x.shape[0]
|
33 |
+
pred_smpl_params, pred_cam, _ = self.smpl_head(x)
|
34 |
+
|
35 |
+
# Compute model vertices, joints and the projected joints
|
36 |
+
pred_smpl_params['global_orient'] = pred_smpl_params['global_orient'].reshape(batch_size, -1, 3, 3)
|
37 |
+
pred_smpl_params['body_pose'] = pred_smpl_params['body_pose'].reshape(batch_size, -1, 3, 3)
|
38 |
+
pred_smpl_params['betas'] = pred_smpl_params['betas'].reshape(batch_size, -1)
|
39 |
+
return pred_smpl_params['global_orient'], pred_smpl_params['body_pose'], pred_smpl_params['betas'], pred_cam
|
40 |
+
|
41 |
+
def forward(self, x, encode=False, **kwargs):
|
42 |
+
"""
|
43 |
+
Run a forward step of the network
|
44 |
+
Args:
|
45 |
+
batch (Dict): Dictionary containing batch data
|
46 |
+
train (bool): Flag indicating whether it is training or validation mode
|
47 |
+
Returns:
|
48 |
+
Dict: Dictionary containing the regression output
|
49 |
+
"""
|
50 |
+
|
51 |
+
# Use RGB image as input
|
52 |
+
batch_size = x.shape[0]
|
53 |
+
|
54 |
+
# Compute conditioning features using the backbone
|
55 |
+
# if using ViT backbone, we need to use a different aspect ratio
|
56 |
+
conditioning_feats = self.backbone(x[:,:,:,32:-32])
|
57 |
+
if encode:
|
58 |
+
conditioning_feats = einops.rearrange(conditioning_feats, 'b c h w -> b (h w) c')
|
59 |
+
token = torch.zeros(batch_size, 1, 1).to(x.device)
|
60 |
+
token_out = self.smpl_head.transformer(token, context=conditioning_feats)
|
61 |
+
return token_out.squeeze(1)
|
62 |
+
|
63 |
+
pred_smpl_params, pred_cam, _ = self.smpl_head(conditioning_feats)
|
64 |
+
|
65 |
+
# Compute model vertices, joints and the projected joints
|
66 |
+
pred_smpl_params['global_orient'] = pred_smpl_params['global_orient'].reshape(batch_size, -1, 3, 3)
|
67 |
+
pred_smpl_params['body_pose'] = pred_smpl_params['body_pose'].reshape(batch_size, -1, 3, 3)
|
68 |
+
pred_smpl_params['betas'] = pred_smpl_params['betas'].reshape(batch_size, -1)
|
69 |
+
return pred_smpl_params['global_orient'], pred_smpl_params['body_pose'], pred_smpl_params['betas'], pred_cam
|
70 |
+
|
71 |
+
|
72 |
+
def hmr2(checkpoint_pth):
|
73 |
+
model = HMR2()
|
74 |
+
if os.path.exists(checkpoint_pth):
|
75 |
+
model.load_state_dict(torch.load(checkpoint_pth, map_location='cpu')['state_dict'], strict=False)
|
76 |
+
print(f'Load backbone weight: {checkpoint_pth}')
|
77 |
+
return model
|
lib/models/preproc/backbone/pose_transformer.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from inspect import isfunction
|
2 |
+
from typing import Callable, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
from einops.layers.torch import Rearrange
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
from .t_cond_mlp import (
|
10 |
+
AdaptiveLayerNorm1D,
|
11 |
+
FrequencyEmbedder,
|
12 |
+
normalization_layer,
|
13 |
+
)
|
14 |
+
# from .vit import Attention, FeedForward
|
15 |
+
|
16 |
+
|
17 |
+
def exists(val):
|
18 |
+
return val is not None
|
19 |
+
|
20 |
+
|
21 |
+
def default(val, d):
|
22 |
+
if exists(val):
|
23 |
+
return val
|
24 |
+
return d() if isfunction(d) else d
|
25 |
+
|
26 |
+
|
27 |
+
class PreNorm(nn.Module):
|
28 |
+
def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
|
29 |
+
super().__init__()
|
30 |
+
self.norm = normalization_layer(norm, dim, norm_cond_dim)
|
31 |
+
self.fn = fn
|
32 |
+
|
33 |
+
def forward(self, x: torch.Tensor, *args, **kwargs):
|
34 |
+
if isinstance(self.norm, AdaptiveLayerNorm1D):
|
35 |
+
return self.fn(self.norm(x, *args), **kwargs)
|
36 |
+
else:
|
37 |
+
return self.fn(self.norm(x), **kwargs)
|
38 |
+
|
39 |
+
|
40 |
+
class FeedForward(nn.Module):
|
41 |
+
def __init__(self, dim, hidden_dim, dropout=0.0):
|
42 |
+
super().__init__()
|
43 |
+
self.net = nn.Sequential(
|
44 |
+
nn.Linear(dim, hidden_dim),
|
45 |
+
nn.GELU(),
|
46 |
+
nn.Dropout(dropout),
|
47 |
+
nn.Linear(hidden_dim, dim),
|
48 |
+
nn.Dropout(dropout),
|
49 |
+
)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
return self.net(x)
|
53 |
+
|
54 |
+
|
55 |
+
class Attention(nn.Module):
|
56 |
+
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
|
57 |
+
super().__init__()
|
58 |
+
inner_dim = dim_head * heads
|
59 |
+
project_out = not (heads == 1 and dim_head == dim)
|
60 |
+
|
61 |
+
self.heads = heads
|
62 |
+
self.scale = dim_head**-0.5
|
63 |
+
|
64 |
+
self.attend = nn.Softmax(dim=-1)
|
65 |
+
self.dropout = nn.Dropout(dropout)
|
66 |
+
|
67 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
68 |
+
|
69 |
+
self.to_out = (
|
70 |
+
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
|
71 |
+
if project_out
|
72 |
+
else nn.Identity()
|
73 |
+
)
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
77 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
|
78 |
+
|
79 |
+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
80 |
+
|
81 |
+
attn = self.attend(dots)
|
82 |
+
attn = self.dropout(attn)
|
83 |
+
|
84 |
+
out = torch.matmul(attn, v)
|
85 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
86 |
+
return self.to_out(out)
|
87 |
+
|
88 |
+
|
89 |
+
class CrossAttention(nn.Module):
|
90 |
+
def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
91 |
+
super().__init__()
|
92 |
+
inner_dim = dim_head * heads
|
93 |
+
project_out = not (heads == 1 and dim_head == dim)
|
94 |
+
|
95 |
+
self.heads = heads
|
96 |
+
self.scale = dim_head**-0.5
|
97 |
+
|
98 |
+
self.attend = nn.Softmax(dim=-1)
|
99 |
+
self.dropout = nn.Dropout(dropout)
|
100 |
+
|
101 |
+
context_dim = default(context_dim, dim)
|
102 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
|
103 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
104 |
+
|
105 |
+
self.to_out = (
|
106 |
+
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
|
107 |
+
if project_out
|
108 |
+
else nn.Identity()
|
109 |
+
)
|
110 |
+
|
111 |
+
def forward(self, x, context=None):
|
112 |
+
context = default(context, x)
|
113 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
114 |
+
q = self.to_q(x)
|
115 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
|
116 |
+
|
117 |
+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
118 |
+
|
119 |
+
attn = self.attend(dots)
|
120 |
+
attn = self.dropout(attn)
|
121 |
+
|
122 |
+
out = torch.matmul(attn, v)
|
123 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
124 |
+
return self.to_out(out)
|
125 |
+
|
126 |
+
|
127 |
+
class Transformer(nn.Module):
|
128 |
+
def __init__(
|
129 |
+
self,
|
130 |
+
dim: int,
|
131 |
+
depth: int,
|
132 |
+
heads: int,
|
133 |
+
dim_head: int,
|
134 |
+
mlp_dim: int,
|
135 |
+
dropout: float = 0.0,
|
136 |
+
norm: str = "layer",
|
137 |
+
norm_cond_dim: int = -1,
|
138 |
+
):
|
139 |
+
super().__init__()
|
140 |
+
self.layers = nn.ModuleList([])
|
141 |
+
for _ in range(depth):
|
142 |
+
sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
|
143 |
+
ff = FeedForward(dim, mlp_dim, dropout=dropout)
|
144 |
+
self.layers.append(
|
145 |
+
nn.ModuleList(
|
146 |
+
[
|
147 |
+
PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
|
148 |
+
PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
|
149 |
+
]
|
150 |
+
)
|
151 |
+
)
|
152 |
+
|
153 |
+
def forward(self, x: torch.Tensor, *args):
|
154 |
+
for attn, ff in self.layers:
|
155 |
+
x = attn(x, *args) + x
|
156 |
+
x = ff(x, *args) + x
|
157 |
+
return x
|
158 |
+
|
159 |
+
|
160 |
+
class TransformerCrossAttn(nn.Module):
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
dim: int,
|
164 |
+
depth: int,
|
165 |
+
heads: int,
|
166 |
+
dim_head: int,
|
167 |
+
mlp_dim: int,
|
168 |
+
dropout: float = 0.0,
|
169 |
+
norm: str = "layer",
|
170 |
+
norm_cond_dim: int = -1,
|
171 |
+
context_dim: Optional[int] = None,
|
172 |
+
):
|
173 |
+
super().__init__()
|
174 |
+
self.layers = nn.ModuleList([])
|
175 |
+
for _ in range(depth):
|
176 |
+
sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
|
177 |
+
ca = CrossAttention(
|
178 |
+
dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
|
179 |
+
)
|
180 |
+
ff = FeedForward(dim, mlp_dim, dropout=dropout)
|
181 |
+
self.layers.append(
|
182 |
+
nn.ModuleList(
|
183 |
+
[
|
184 |
+
PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
|
185 |
+
PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
|
186 |
+
PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
|
187 |
+
]
|
188 |
+
)
|
189 |
+
)
|
190 |
+
|
191 |
+
def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
|
192 |
+
if context_list is None:
|
193 |
+
context_list = [context] * len(self.layers)
|
194 |
+
if len(context_list) != len(self.layers):
|
195 |
+
raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
|
196 |
+
|
197 |
+
for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
|
198 |
+
x = self_attn(x, *args) + x
|
199 |
+
x = cross_attn(x, *args, context=context_list[i]) + x
|
200 |
+
x = ff(x, *args) + x
|
201 |
+
return x
|
202 |
+
|
203 |
+
|
204 |
+
class DropTokenDropout(nn.Module):
|
205 |
+
def __init__(self, p: float = 0.1):
|
206 |
+
super().__init__()
|
207 |
+
if p < 0 or p > 1:
|
208 |
+
raise ValueError(
|
209 |
+
"dropout probability has to be between 0 and 1, " "but got {}".format(p)
|
210 |
+
)
|
211 |
+
self.p = p
|
212 |
+
|
213 |
+
def forward(self, x: torch.Tensor):
|
214 |
+
# x: (batch_size, seq_len, dim)
|
215 |
+
if self.training and self.p > 0:
|
216 |
+
zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
|
217 |
+
# TODO: permutation idx for each batch using torch.argsort
|
218 |
+
if zero_mask.any():
|
219 |
+
x = x[:, ~zero_mask, :]
|
220 |
+
return x
|
221 |
+
|
222 |
+
|
223 |
+
class ZeroTokenDropout(nn.Module):
|
224 |
+
def __init__(self, p: float = 0.1):
|
225 |
+
super().__init__()
|
226 |
+
if p < 0 or p > 1:
|
227 |
+
raise ValueError(
|
228 |
+
"dropout probability has to be between 0 and 1, " "but got {}".format(p)
|
229 |
+
)
|
230 |
+
self.p = p
|
231 |
+
|
232 |
+
def forward(self, x: torch.Tensor):
|
233 |
+
# x: (batch_size, seq_len, dim)
|
234 |
+
if self.training and self.p > 0:
|
235 |
+
zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
|
236 |
+
# Zero-out the masked tokens
|
237 |
+
x[zero_mask, :] = 0
|
238 |
+
return x
|
239 |
+
|
240 |
+
|
241 |
+
class TransformerEncoder(nn.Module):
|
242 |
+
def __init__(
|
243 |
+
self,
|
244 |
+
num_tokens: int,
|
245 |
+
token_dim: int,
|
246 |
+
dim: int,
|
247 |
+
depth: int,
|
248 |
+
heads: int,
|
249 |
+
mlp_dim: int,
|
250 |
+
dim_head: int = 64,
|
251 |
+
dropout: float = 0.0,
|
252 |
+
emb_dropout: float = 0.0,
|
253 |
+
emb_dropout_type: str = "drop",
|
254 |
+
emb_dropout_loc: str = "token",
|
255 |
+
norm: str = "layer",
|
256 |
+
norm_cond_dim: int = -1,
|
257 |
+
token_pe_numfreq: int = -1,
|
258 |
+
):
|
259 |
+
super().__init__()
|
260 |
+
if token_pe_numfreq > 0:
|
261 |
+
token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
|
262 |
+
self.to_token_embedding = nn.Sequential(
|
263 |
+
Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
|
264 |
+
FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
|
265 |
+
Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
|
266 |
+
nn.Linear(token_dim_new, dim),
|
267 |
+
)
|
268 |
+
else:
|
269 |
+
self.to_token_embedding = nn.Linear(token_dim, dim)
|
270 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
|
271 |
+
if emb_dropout_type == "drop":
|
272 |
+
self.dropout = DropTokenDropout(emb_dropout)
|
273 |
+
elif emb_dropout_type == "zero":
|
274 |
+
self.dropout = ZeroTokenDropout(emb_dropout)
|
275 |
+
else:
|
276 |
+
raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
|
277 |
+
self.emb_dropout_loc = emb_dropout_loc
|
278 |
+
|
279 |
+
self.transformer = Transformer(
|
280 |
+
dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
|
281 |
+
)
|
282 |
+
|
283 |
+
def forward(self, inp: torch.Tensor, *args, **kwargs):
|
284 |
+
x = inp
|
285 |
+
|
286 |
+
if self.emb_dropout_loc == "input":
|
287 |
+
x = self.dropout(x)
|
288 |
+
x = self.to_token_embedding(x)
|
289 |
+
|
290 |
+
if self.emb_dropout_loc == "token":
|
291 |
+
x = self.dropout(x)
|
292 |
+
b, n, _ = x.shape
|
293 |
+
x += self.pos_embedding[:, :n]
|
294 |
+
|
295 |
+
if self.emb_dropout_loc == "token_afterpos":
|
296 |
+
x = self.dropout(x)
|
297 |
+
x = self.transformer(x, *args)
|
298 |
+
return x
|
299 |
+
|
300 |
+
|
301 |
+
class TransformerDecoder(nn.Module):
|
302 |
+
def __init__(
|
303 |
+
self,
|
304 |
+
num_tokens: int,
|
305 |
+
token_dim: int,
|
306 |
+
dim: int,
|
307 |
+
depth: int,
|
308 |
+
heads: int,
|
309 |
+
mlp_dim: int,
|
310 |
+
dim_head: int = 64,
|
311 |
+
dropout: float = 0.0,
|
312 |
+
emb_dropout: float = 0.0,
|
313 |
+
emb_dropout_type: str = 'drop',
|
314 |
+
norm: str = "layer",
|
315 |
+
norm_cond_dim: int = -1,
|
316 |
+
context_dim: Optional[int] = None,
|
317 |
+
skip_token_embedding: bool = False,
|
318 |
+
):
|
319 |
+
super().__init__()
|
320 |
+
if not skip_token_embedding:
|
321 |
+
self.to_token_embedding = nn.Linear(token_dim, dim)
|
322 |
+
else:
|
323 |
+
self.to_token_embedding = nn.Identity()
|
324 |
+
if token_dim != dim:
|
325 |
+
raise ValueError(
|
326 |
+
f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
|
327 |
+
)
|
328 |
+
|
329 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
|
330 |
+
if emb_dropout_type == "drop":
|
331 |
+
self.dropout = DropTokenDropout(emb_dropout)
|
332 |
+
elif emb_dropout_type == "zero":
|
333 |
+
self.dropout = ZeroTokenDropout(emb_dropout)
|
334 |
+
elif emb_dropout_type == "normal":
|
335 |
+
self.dropout = nn.Dropout(emb_dropout)
|
336 |
+
|
337 |
+
self.transformer = TransformerCrossAttn(
|
338 |
+
dim,
|
339 |
+
depth,
|
340 |
+
heads,
|
341 |
+
dim_head,
|
342 |
+
mlp_dim,
|
343 |
+
dropout,
|
344 |
+
norm=norm,
|
345 |
+
norm_cond_dim=norm_cond_dim,
|
346 |
+
context_dim=context_dim,
|
347 |
+
)
|
348 |
+
|
349 |
+
def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
|
350 |
+
x = self.to_token_embedding(inp)
|
351 |
+
b, n, _ = x.shape
|
352 |
+
|
353 |
+
x = self.dropout(x)
|
354 |
+
x += self.pos_embedding[:, :n]
|
355 |
+
|
356 |
+
x = self.transformer(x, *args, context=context, context_list=context_list)
|
357 |
+
return x
|
lib/models/preproc/backbone/smpl_head.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import einops
|
6 |
+
|
7 |
+
from configs import constants as _C
|
8 |
+
from lib.utils.transforms import axis_angle_to_matrix
|
9 |
+
from .pose_transformer import TransformerDecoder
|
10 |
+
|
11 |
+
def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
|
12 |
+
"""
|
13 |
+
Convert 6D rotation representation to 3x3 rotation matrix.
|
14 |
+
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
|
15 |
+
Args:
|
16 |
+
x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
|
17 |
+
Returns:
|
18 |
+
torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
|
19 |
+
"""
|
20 |
+
x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
|
21 |
+
a1 = x[:, :, 0]
|
22 |
+
a2 = x[:, :, 1]
|
23 |
+
b1 = F.normalize(a1)
|
24 |
+
b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
|
25 |
+
b3 = torch.cross(b1, b2)
|
26 |
+
return torch.stack((b1, b2, b3), dim=-1)
|
27 |
+
|
28 |
+
def build_smpl_head(cfg):
|
29 |
+
smpl_head_type = 'transformer_decoder'
|
30 |
+
if smpl_head_type == 'transformer_decoder':
|
31 |
+
return SMPLTransformerDecoderHead(cfg)
|
32 |
+
else:
|
33 |
+
raise ValueError('Unknown SMPL head type: {}'.format(smpl_head_type))
|
34 |
+
|
35 |
+
class SMPLTransformerDecoderHead(nn.Module):
|
36 |
+
""" Cross-attention based SMPL Transformer decoder
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self):
|
40 |
+
super().__init__()
|
41 |
+
self.joint_rep_type = '6d'
|
42 |
+
self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
|
43 |
+
npose = self.joint_rep_dim * 24
|
44 |
+
self.npose = npose
|
45 |
+
self.input_is_mean_shape = False
|
46 |
+
transformer_args = dict(
|
47 |
+
num_tokens=1,
|
48 |
+
token_dim=(npose + 10 + 3) if self.input_is_mean_shape else 1,
|
49 |
+
dim=1024,
|
50 |
+
)
|
51 |
+
transformer_args_from_cfg = dict(
|
52 |
+
depth=6, heads=8, mlp_dim=1024, dim_head=64, dropout=0.0, emb_dropout=0.0, norm='layer', context_dim=1280
|
53 |
+
)
|
54 |
+
transformer_args = (transformer_args | transformer_args_from_cfg)
|
55 |
+
self.transformer = TransformerDecoder(
|
56 |
+
**transformer_args
|
57 |
+
)
|
58 |
+
dim=transformer_args['dim']
|
59 |
+
self.decpose = nn.Linear(dim, npose)
|
60 |
+
self.decshape = nn.Linear(dim, 10)
|
61 |
+
self.deccam = nn.Linear(dim, 3)
|
62 |
+
|
63 |
+
mean_params = np.load(_C.BMODEL.MEAN_PARAMS)
|
64 |
+
init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
|
65 |
+
init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
|
66 |
+
init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
|
67 |
+
self.register_buffer('init_body_pose', init_body_pose)
|
68 |
+
self.register_buffer('init_betas', init_betas)
|
69 |
+
self.register_buffer('init_cam', init_cam)
|
70 |
+
|
71 |
+
def forward(self, x, **kwargs):
|
72 |
+
|
73 |
+
batch_size = x.shape[0]
|
74 |
+
# vit pretrained backbone is channel-first. Change to token-first
|
75 |
+
|
76 |
+
init_body_pose = self.init_body_pose.expand(batch_size, -1)
|
77 |
+
init_betas = self.init_betas.expand(batch_size, -1)
|
78 |
+
init_cam = self.init_cam.expand(batch_size, -1)
|
79 |
+
|
80 |
+
# TODO: Convert init_body_pose to aa rep if needed
|
81 |
+
if self.joint_rep_type == 'aa':
|
82 |
+
raise NotImplementedError
|
83 |
+
|
84 |
+
pred_body_pose = init_body_pose
|
85 |
+
pred_betas = init_betas
|
86 |
+
pred_cam = init_cam
|
87 |
+
pred_body_pose_list = []
|
88 |
+
pred_betas_list = []
|
89 |
+
pred_cam_list = []
|
90 |
+
|
91 |
+
# Input token to transformer is zero token
|
92 |
+
if len(x.shape) > 2:
|
93 |
+
x = einops.rearrange(x, 'b c h w -> b (h w) c')
|
94 |
+
if self.input_is_mean_shape:
|
95 |
+
token = torch.cat([pred_body_pose, pred_betas, pred_cam], dim=1)[:,None,:]
|
96 |
+
else:
|
97 |
+
token = torch.zeros(batch_size, 1, 1).to(x.device)
|
98 |
+
|
99 |
+
# Pass through transformer
|
100 |
+
token_out = self.transformer(token, context=x)
|
101 |
+
token_out = token_out.squeeze(1) # (B, C)
|
102 |
+
else:
|
103 |
+
token_out = x
|
104 |
+
|
105 |
+
# Readout from token_out
|
106 |
+
pred_body_pose = self.decpose(token_out) + pred_body_pose
|
107 |
+
pred_betas = self.decshape(token_out) + pred_betas
|
108 |
+
pred_cam = self.deccam(token_out) + pred_cam
|
109 |
+
pred_body_pose_list.append(pred_body_pose)
|
110 |
+
pred_betas_list.append(pred_betas)
|
111 |
+
pred_cam_list.append(pred_cam)
|
112 |
+
|
113 |
+
# Convert self.joint_rep_type -> rotmat
|
114 |
+
joint_conversion_fn = {
|
115 |
+
'6d': rot6d_to_rotmat,
|
116 |
+
'aa': lambda x: axis_angle_to_matrix(x.view(-1, 3).contiguous())
|
117 |
+
}[self.joint_rep_type]
|
118 |
+
|
119 |
+
pred_smpl_params_list = {}
|
120 |
+
pred_smpl_params_list['body_pose'] = torch.cat([joint_conversion_fn(pbp).view(batch_size, -1, 3, 3)[:, 1:, :, :] for pbp in pred_body_pose_list], dim=0)
|
121 |
+
pred_smpl_params_list['betas'] = torch.cat(pred_betas_list, dim=0)
|
122 |
+
pred_smpl_params_list['cam'] = torch.cat(pred_cam_list, dim=0)
|
123 |
+
pred_body_pose = joint_conversion_fn(pred_body_pose).view(batch_size, 24, 3, 3)
|
124 |
+
|
125 |
+
pred_smpl_params = {'global_orient': pred_body_pose[:, [0]],
|
126 |
+
'body_pose': pred_body_pose[:, 1:],
|
127 |
+
'betas': pred_betas}
|
128 |
+
return pred_smpl_params, pred_cam, pred_smpl_params_list
|
lib/models/preproc/backbone/t_cond_mlp.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import List, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
class AdaptiveLayerNorm1D(torch.nn.Module):
|
8 |
+
def __init__(self, data_dim: int, norm_cond_dim: int):
|
9 |
+
super().__init__()
|
10 |
+
if data_dim <= 0:
|
11 |
+
raise ValueError(f"data_dim must be positive, but got {data_dim}")
|
12 |
+
if norm_cond_dim <= 0:
|
13 |
+
raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
|
14 |
+
self.norm = torch.nn.LayerNorm(
|
15 |
+
data_dim
|
16 |
+
) # TODO: Check if elementwise_affine=True is correct
|
17 |
+
self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
|
18 |
+
torch.nn.init.zeros_(self.linear.weight)
|
19 |
+
torch.nn.init.zeros_(self.linear.bias)
|
20 |
+
|
21 |
+
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
22 |
+
# x: (batch, ..., data_dim)
|
23 |
+
# t: (batch, norm_cond_dim)
|
24 |
+
# return: (batch, data_dim)
|
25 |
+
x = self.norm(x)
|
26 |
+
alpha, beta = self.linear(t).chunk(2, dim=-1)
|
27 |
+
|
28 |
+
# Add singleton dimensions to alpha and beta
|
29 |
+
if x.dim() > 2:
|
30 |
+
alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
|
31 |
+
beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
|
32 |
+
|
33 |
+
return x * (1 + alpha) + beta
|
34 |
+
|
35 |
+
|
36 |
+
class SequentialCond(torch.nn.Sequential):
|
37 |
+
def forward(self, input, *args, **kwargs):
|
38 |
+
for module in self:
|
39 |
+
if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
|
40 |
+
# print(f'Passing on args to {module}', [a.shape for a in args])
|
41 |
+
input = module(input, *args, **kwargs)
|
42 |
+
else:
|
43 |
+
# print(f'Skipping passing args to {module}', [a.shape for a in args])
|
44 |
+
input = module(input)
|
45 |
+
return input
|
46 |
+
|
47 |
+
|
48 |
+
def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
|
49 |
+
if norm == "batch":
|
50 |
+
return torch.nn.BatchNorm1d(dim)
|
51 |
+
elif norm == "layer":
|
52 |
+
return torch.nn.LayerNorm(dim)
|
53 |
+
elif norm == "ada":
|
54 |
+
assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
|
55 |
+
return AdaptiveLayerNorm1D(dim, norm_cond_dim)
|
56 |
+
elif norm is None:
|
57 |
+
return torch.nn.Identity()
|
58 |
+
else:
|
59 |
+
raise ValueError(f"Unknown norm: {norm}")
|
60 |
+
|
61 |
+
|
62 |
+
def linear_norm_activ_dropout(
|
63 |
+
input_dim: int,
|
64 |
+
output_dim: int,
|
65 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
66 |
+
bias: bool = True,
|
67 |
+
norm: Optional[str] = "layer", # Options: ada/batch/layer
|
68 |
+
dropout: float = 0.0,
|
69 |
+
norm_cond_dim: int = -1,
|
70 |
+
) -> SequentialCond:
|
71 |
+
layers = []
|
72 |
+
layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
|
73 |
+
if norm is not None:
|
74 |
+
layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
|
75 |
+
layers.append(copy.deepcopy(activation))
|
76 |
+
if dropout > 0.0:
|
77 |
+
layers.append(torch.nn.Dropout(dropout))
|
78 |
+
return SequentialCond(*layers)
|
79 |
+
|
80 |
+
|
81 |
+
def create_simple_mlp(
|
82 |
+
input_dim: int,
|
83 |
+
hidden_dims: List[int],
|
84 |
+
output_dim: int,
|
85 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
86 |
+
bias: bool = True,
|
87 |
+
norm: Optional[str] = "layer", # Options: ada/batch/layer
|
88 |
+
dropout: float = 0.0,
|
89 |
+
norm_cond_dim: int = -1,
|
90 |
+
) -> SequentialCond:
|
91 |
+
layers = []
|
92 |
+
prev_dim = input_dim
|
93 |
+
for hidden_dim in hidden_dims:
|
94 |
+
layers.extend(
|
95 |
+
linear_norm_activ_dropout(
|
96 |
+
prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
|
97 |
+
)
|
98 |
+
)
|
99 |
+
prev_dim = hidden_dim
|
100 |
+
layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
|
101 |
+
return SequentialCond(*layers)
|
102 |
+
|
103 |
+
|
104 |
+
class ResidualMLPBlock(torch.nn.Module):
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
input_dim: int,
|
108 |
+
hidden_dim: int,
|
109 |
+
num_hidden_layers: int,
|
110 |
+
output_dim: int,
|
111 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
112 |
+
bias: bool = True,
|
113 |
+
norm: Optional[str] = "layer", # Options: ada/batch/layer
|
114 |
+
dropout: float = 0.0,
|
115 |
+
norm_cond_dim: int = -1,
|
116 |
+
):
|
117 |
+
super().__init__()
|
118 |
+
if not (input_dim == output_dim == hidden_dim):
|
119 |
+
raise NotImplementedError(
|
120 |
+
f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
|
121 |
+
)
|
122 |
+
|
123 |
+
layers = []
|
124 |
+
prev_dim = input_dim
|
125 |
+
for i in range(num_hidden_layers):
|
126 |
+
layers.append(
|
127 |
+
linear_norm_activ_dropout(
|
128 |
+
prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
|
129 |
+
)
|
130 |
+
)
|
131 |
+
prev_dim = hidden_dim
|
132 |
+
self.model = SequentialCond(*layers)
|
133 |
+
self.skip = torch.nn.Identity()
|
134 |
+
|
135 |
+
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
136 |
+
return x + self.model(x, *args, **kwargs)
|
137 |
+
|
138 |
+
|
139 |
+
class ResidualMLP(torch.nn.Module):
|
140 |
+
def __init__(
|
141 |
+
self,
|
142 |
+
input_dim: int,
|
143 |
+
hidden_dim: int,
|
144 |
+
num_hidden_layers: int,
|
145 |
+
output_dim: int,
|
146 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
147 |
+
bias: bool = True,
|
148 |
+
norm: Optional[str] = "layer", # Options: ada/batch/layer
|
149 |
+
dropout: float = 0.0,
|
150 |
+
num_blocks: int = 1,
|
151 |
+
norm_cond_dim: int = -1,
|
152 |
+
):
|
153 |
+
super().__init__()
|
154 |
+
self.input_dim = input_dim
|
155 |
+
self.model = SequentialCond(
|
156 |
+
linear_norm_activ_dropout(
|
157 |
+
input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
|
158 |
+
),
|
159 |
+
*[
|
160 |
+
ResidualMLPBlock(
|
161 |
+
hidden_dim,
|
162 |
+
hidden_dim,
|
163 |
+
num_hidden_layers,
|
164 |
+
hidden_dim,
|
165 |
+
activation,
|
166 |
+
bias,
|
167 |
+
norm,
|
168 |
+
dropout,
|
169 |
+
norm_cond_dim,
|
170 |
+
)
|
171 |
+
for _ in range(num_blocks)
|
172 |
+
],
|
173 |
+
torch.nn.Linear(hidden_dim, output_dim, bias=bias),
|
174 |
+
)
|
175 |
+
|
176 |
+
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
177 |
+
return self.model(x, *args, **kwargs)
|
178 |
+
|
179 |
+
|
180 |
+
class FrequencyEmbedder(torch.nn.Module):
|
181 |
+
def __init__(self, num_frequencies, max_freq_log2):
|
182 |
+
super().__init__()
|
183 |
+
frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
|
184 |
+
self.register_buffer("frequencies", frequencies)
|
185 |
+
|
186 |
+
def forward(self, x):
|
187 |
+
# x should be of size (N,) or (N, D)
|
188 |
+
N = x.size(0)
|
189 |
+
if x.dim() == 1: # (N,)
|
190 |
+
x = x.unsqueeze(1) # (N, D) where D=1
|
191 |
+
x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
|
192 |
+
scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
|
193 |
+
s = torch.sin(scaled)
|
194 |
+
c = torch.cos(scaled)
|
195 |
+
embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
|
196 |
+
N, -1
|
197 |
+
) # (N, D * 2 * num_frequencies + D)
|
198 |
+
return embedded
|
lib/models/preproc/backbone/utils.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import print_function
|
3 |
+
from __future__ import division
|
4 |
+
|
5 |
+
import os
|
6 |
+
import os.path as osp
|
7 |
+
from collections import OrderedDict
|
8 |
+
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
from skimage.filters import gaussian
|
12 |
+
|
13 |
+
|
14 |
+
def get_transform(center, scale, res, rot=0):
|
15 |
+
"""Generate transformation matrix."""
|
16 |
+
# res: (height, width), (rows, cols)
|
17 |
+
crop_aspect_ratio = res[0] / float(res[1])
|
18 |
+
h = 200 * scale
|
19 |
+
w = h / crop_aspect_ratio
|
20 |
+
t = np.zeros((3, 3))
|
21 |
+
t[0, 0] = float(res[1]) / w
|
22 |
+
t[1, 1] = float(res[0]) / h
|
23 |
+
t[0, 2] = res[1] * (-float(center[0]) / w + .5)
|
24 |
+
t[1, 2] = res[0] * (-float(center[1]) / h + .5)
|
25 |
+
t[2, 2] = 1
|
26 |
+
if not rot == 0:
|
27 |
+
rot = -rot # To match direction of rotation from cropping
|
28 |
+
rot_mat = np.zeros((3, 3))
|
29 |
+
rot_rad = rot * np.pi / 180
|
30 |
+
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
|
31 |
+
rot_mat[0, :2] = [cs, -sn]
|
32 |
+
rot_mat[1, :2] = [sn, cs]
|
33 |
+
rot_mat[2, 2] = 1
|
34 |
+
# Need to rotate around center
|
35 |
+
t_mat = np.eye(3)
|
36 |
+
t_mat[0, 2] = -res[1] / 2
|
37 |
+
t_mat[1, 2] = -res[0] / 2
|
38 |
+
t_inv = t_mat.copy()
|
39 |
+
t_inv[:2, 2] *= -1
|
40 |
+
t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
|
41 |
+
return t
|
42 |
+
|
43 |
+
|
44 |
+
def transform(pt, center, scale, res, invert=0, rot=0):
|
45 |
+
"""Transform pixel location to different reference."""
|
46 |
+
t = get_transform(center, scale, res, rot=rot)
|
47 |
+
if invert:
|
48 |
+
t = np.linalg.inv(t)
|
49 |
+
new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
|
50 |
+
new_pt = np.dot(t, new_pt)
|
51 |
+
return np.array([round(new_pt[0]), round(new_pt[1])], dtype=int) + 1
|
52 |
+
|
53 |
+
|
54 |
+
def crop(img, center, scale, res):
|
55 |
+
"""
|
56 |
+
Crop image according to the supplied bounding box.
|
57 |
+
res: [rows, cols]
|
58 |
+
"""
|
59 |
+
# Upper left point
|
60 |
+
ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
|
61 |
+
# Bottom right point
|
62 |
+
br = np.array(transform([res[1] + 1, res[0] + 1], center, scale, res, invert=1)) - 1
|
63 |
+
|
64 |
+
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
65 |
+
if len(img.shape) > 2:
|
66 |
+
new_shape += [img.shape[2]]
|
67 |
+
new_img = np.zeros(new_shape, dtype=np.float32)
|
68 |
+
|
69 |
+
# Range to fill new array
|
70 |
+
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
|
71 |
+
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
|
72 |
+
# Range to sample from original image
|
73 |
+
old_x = max(0, ul[0]), min(len(img[0]), br[0])
|
74 |
+
old_y = max(0, ul[1]), min(len(img), br[1])
|
75 |
+
try:
|
76 |
+
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
|
77 |
+
except Exception as e:
|
78 |
+
print(e)
|
79 |
+
|
80 |
+
new_img = cv2.resize(new_img, (res[1], res[0])) # (cols, rows)
|
81 |
+
|
82 |
+
return new_img, ul, br
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
def process_image(orig_img_rgb, center, scale, crop_height=256, crop_width=192, blur=False, do_crop=True):
|
87 |
+
"""
|
88 |
+
Read image, do preprocessing and possibly crop it according to the bounding box.
|
89 |
+
If there are bounding box annotations, use them to crop the image.
|
90 |
+
If no bounding box is specified but openpose detections are available, use them to get the bounding box.
|
91 |
+
"""
|
92 |
+
|
93 |
+
if blur:
|
94 |
+
# Blur image to avoid aliasing artifacts
|
95 |
+
downsampling_factor = ((scale * 200 * 1.0) / crop_height)
|
96 |
+
downsampling_factor = downsampling_factor / 2.0
|
97 |
+
if downsampling_factor > 1.1:
|
98 |
+
orig_img_rgb = gaussian(orig_img_rgb, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True)
|
99 |
+
|
100 |
+
IMG_NORM_MEAN = [0.485, 0.456, 0.406]
|
101 |
+
IMG_NORM_STD = [0.229, 0.224, 0.225]
|
102 |
+
|
103 |
+
if do_crop:
|
104 |
+
img, ul, br = crop(orig_img_rgb, center, scale, (crop_height, crop_width))
|
105 |
+
else:
|
106 |
+
img = orig_img_rgb.copy()
|
107 |
+
crop_img = img.copy()
|
108 |
+
|
109 |
+
img = img / 255.
|
110 |
+
mean = np.array(IMG_NORM_MEAN, dtype=np.float32)
|
111 |
+
std = np.array(IMG_NORM_STD, dtype=np.float32)
|
112 |
+
norm_img = (img - mean) / std
|
113 |
+
norm_img = np.transpose(norm_img, (2, 0, 1))
|
114 |
+
|
115 |
+
return norm_img, crop_img
|
lib/models/preproc/backbone/vit.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from functools import partial
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torch.utils.checkpoint as checkpoint
|
9 |
+
|
10 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
11 |
+
|
12 |
+
def vit():
|
13 |
+
return ViT(
|
14 |
+
img_size=(256, 192),
|
15 |
+
patch_size=16,
|
16 |
+
embed_dim=1280,
|
17 |
+
depth=32,
|
18 |
+
num_heads=16,
|
19 |
+
ratio=1,
|
20 |
+
use_checkpoint=False,
|
21 |
+
mlp_ratio=4,
|
22 |
+
qkv_bias=True,
|
23 |
+
drop_path_rate=0.55,
|
24 |
+
)
|
25 |
+
|
26 |
+
def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
|
27 |
+
"""
|
28 |
+
Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
|
29 |
+
dimension for the original embeddings.
|
30 |
+
Args:
|
31 |
+
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
|
32 |
+
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
|
33 |
+
hw (Tuple): size of input image tokens.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
Absolute positional embeddings after processing with shape (1, H, W, C)
|
37 |
+
"""
|
38 |
+
cls_token = None
|
39 |
+
B, L, C = abs_pos.shape
|
40 |
+
if has_cls_token:
|
41 |
+
cls_token = abs_pos[:, 0:1]
|
42 |
+
abs_pos = abs_pos[:, 1:]
|
43 |
+
|
44 |
+
if ori_h != h or ori_w != w:
|
45 |
+
new_abs_pos = F.interpolate(
|
46 |
+
abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
|
47 |
+
size=(h, w),
|
48 |
+
mode="bicubic",
|
49 |
+
align_corners=False,
|
50 |
+
).permute(0, 2, 3, 1).reshape(B, -1, C)
|
51 |
+
|
52 |
+
else:
|
53 |
+
new_abs_pos = abs_pos
|
54 |
+
|
55 |
+
if cls_token is not None:
|
56 |
+
new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
|
57 |
+
return new_abs_pos
|
58 |
+
|
59 |
+
class DropPath(nn.Module):
|
60 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
61 |
+
"""
|
62 |
+
def __init__(self, drop_prob=None):
|
63 |
+
super(DropPath, self).__init__()
|
64 |
+
self.drop_prob = drop_prob
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
return drop_path(x, self.drop_prob, self.training)
|
68 |
+
|
69 |
+
def extra_repr(self):
|
70 |
+
return 'p={}'.format(self.drop_prob)
|
71 |
+
|
72 |
+
class Mlp(nn.Module):
|
73 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
74 |
+
super().__init__()
|
75 |
+
out_features = out_features or in_features
|
76 |
+
hidden_features = hidden_features or in_features
|
77 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
78 |
+
self.act = act_layer()
|
79 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
80 |
+
self.drop = nn.Dropout(drop)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
x = self.fc1(x)
|
84 |
+
x = self.act(x)
|
85 |
+
x = self.fc2(x)
|
86 |
+
x = self.drop(x)
|
87 |
+
return x
|
88 |
+
|
89 |
+
class Attention(nn.Module):
|
90 |
+
def __init__(
|
91 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
92 |
+
proj_drop=0., attn_head_dim=None,):
|
93 |
+
super().__init__()
|
94 |
+
self.num_heads = num_heads
|
95 |
+
head_dim = dim // num_heads
|
96 |
+
self.dim = dim
|
97 |
+
|
98 |
+
if attn_head_dim is not None:
|
99 |
+
head_dim = attn_head_dim
|
100 |
+
all_head_dim = head_dim * self.num_heads
|
101 |
+
|
102 |
+
self.scale = qk_scale or head_dim ** -0.5
|
103 |
+
|
104 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
|
105 |
+
|
106 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
107 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
108 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
B, N, C = x.shape
|
112 |
+
qkv = self.qkv(x)
|
113 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
114 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
115 |
+
|
116 |
+
q = q * self.scale
|
117 |
+
attn = (q @ k.transpose(-2, -1))
|
118 |
+
|
119 |
+
attn = attn.softmax(dim=-1)
|
120 |
+
attn = self.attn_drop(attn)
|
121 |
+
|
122 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
123 |
+
x = self.proj(x)
|
124 |
+
x = self.proj_drop(x)
|
125 |
+
|
126 |
+
return x
|
127 |
+
|
128 |
+
class Block(nn.Module):
|
129 |
+
|
130 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
|
131 |
+
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
|
132 |
+
norm_layer=nn.LayerNorm, attn_head_dim=None
|
133 |
+
):
|
134 |
+
super().__init__()
|
135 |
+
|
136 |
+
self.norm1 = norm_layer(dim)
|
137 |
+
self.attn = Attention(
|
138 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
139 |
+
attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
|
140 |
+
)
|
141 |
+
|
142 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
143 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
144 |
+
self.norm2 = norm_layer(dim)
|
145 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
146 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
150 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
151 |
+
return x
|
152 |
+
|
153 |
+
|
154 |
+
class PatchEmbed(nn.Module):
|
155 |
+
""" Image to Patch Embedding
|
156 |
+
"""
|
157 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
|
158 |
+
super().__init__()
|
159 |
+
img_size = to_2tuple(img_size)
|
160 |
+
patch_size = to_2tuple(patch_size)
|
161 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
|
162 |
+
self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
|
163 |
+
self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
|
164 |
+
self.img_size = img_size
|
165 |
+
self.patch_size = patch_size
|
166 |
+
self.num_patches = num_patches
|
167 |
+
|
168 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
|
169 |
+
|
170 |
+
def forward(self, x, **kwargs):
|
171 |
+
B, C, H, W = x.shape
|
172 |
+
x = self.proj(x)
|
173 |
+
Hp, Wp = x.shape[2], x.shape[3]
|
174 |
+
|
175 |
+
x = x.flatten(2).transpose(1, 2)
|
176 |
+
return x, (Hp, Wp)
|
177 |
+
|
178 |
+
|
179 |
+
class HybridEmbed(nn.Module):
|
180 |
+
""" CNN Feature Map Embedding
|
181 |
+
Extract feature map from CNN, flatten, project to embedding dim.
|
182 |
+
"""
|
183 |
+
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
|
184 |
+
super().__init__()
|
185 |
+
assert isinstance(backbone, nn.Module)
|
186 |
+
img_size = to_2tuple(img_size)
|
187 |
+
self.img_size = img_size
|
188 |
+
self.backbone = backbone
|
189 |
+
if feature_size is None:
|
190 |
+
with torch.no_grad():
|
191 |
+
training = backbone.training
|
192 |
+
if training:
|
193 |
+
backbone.eval()
|
194 |
+
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
|
195 |
+
feature_size = o.shape[-2:]
|
196 |
+
feature_dim = o.shape[1]
|
197 |
+
backbone.train(training)
|
198 |
+
else:
|
199 |
+
feature_size = to_2tuple(feature_size)
|
200 |
+
feature_dim = self.backbone.feature_info.channels()[-1]
|
201 |
+
self.num_patches = feature_size[0] * feature_size[1]
|
202 |
+
self.proj = nn.Linear(feature_dim, embed_dim)
|
203 |
+
|
204 |
+
def forward(self, x):
|
205 |
+
x = self.backbone(x)[-1]
|
206 |
+
x = x.flatten(2).transpose(1, 2)
|
207 |
+
x = self.proj(x)
|
208 |
+
return x
|
209 |
+
|
210 |
+
|
211 |
+
class ViT(nn.Module):
|
212 |
+
|
213 |
+
def __init__(self,
|
214 |
+
img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
|
215 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
216 |
+
drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
|
217 |
+
frozen_stages=-1, ratio=1, last_norm=True,
|
218 |
+
patch_padding='pad', freeze_attn=False, freeze_ffn=False,
|
219 |
+
):
|
220 |
+
# Protect mutable default arguments
|
221 |
+
super(ViT, self).__init__()
|
222 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
223 |
+
self.num_classes = num_classes
|
224 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
225 |
+
self.frozen_stages = frozen_stages
|
226 |
+
self.use_checkpoint = use_checkpoint
|
227 |
+
self.patch_padding = patch_padding
|
228 |
+
self.freeze_attn = freeze_attn
|
229 |
+
self.freeze_ffn = freeze_ffn
|
230 |
+
self.depth = depth
|
231 |
+
|
232 |
+
if hybrid_backbone is not None:
|
233 |
+
self.patch_embed = HybridEmbed(
|
234 |
+
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
|
235 |
+
else:
|
236 |
+
self.patch_embed = PatchEmbed(
|
237 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
|
238 |
+
num_patches = self.patch_embed.num_patches
|
239 |
+
|
240 |
+
# since the pretraining model has class token
|
241 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
242 |
+
|
243 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
244 |
+
|
245 |
+
self.blocks = nn.ModuleList([
|
246 |
+
Block(
|
247 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
248 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
249 |
+
)
|
250 |
+
for i in range(depth)])
|
251 |
+
|
252 |
+
self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
|
253 |
+
|
254 |
+
if self.pos_embed is not None:
|
255 |
+
trunc_normal_(self.pos_embed, std=.02)
|
256 |
+
|
257 |
+
self._freeze_stages()
|
258 |
+
|
259 |
+
def _freeze_stages(self):
|
260 |
+
"""Freeze parameters."""
|
261 |
+
if self.frozen_stages >= 0:
|
262 |
+
self.patch_embed.eval()
|
263 |
+
for param in self.patch_embed.parameters():
|
264 |
+
param.requires_grad = False
|
265 |
+
|
266 |
+
for i in range(1, self.frozen_stages + 1):
|
267 |
+
m = self.blocks[i]
|
268 |
+
m.eval()
|
269 |
+
for param in m.parameters():
|
270 |
+
param.requires_grad = False
|
271 |
+
|
272 |
+
if self.freeze_attn:
|
273 |
+
for i in range(0, self.depth):
|
274 |
+
m = self.blocks[i]
|
275 |
+
m.attn.eval()
|
276 |
+
m.norm1.eval()
|
277 |
+
for param in m.attn.parameters():
|
278 |
+
param.requires_grad = False
|
279 |
+
for param in m.norm1.parameters():
|
280 |
+
param.requires_grad = False
|
281 |
+
|
282 |
+
if self.freeze_ffn:
|
283 |
+
self.pos_embed.requires_grad = False
|
284 |
+
self.patch_embed.eval()
|
285 |
+
for param in self.patch_embed.parameters():
|
286 |
+
param.requires_grad = False
|
287 |
+
for i in range(0, self.depth):
|
288 |
+
m = self.blocks[i]
|
289 |
+
m.mlp.eval()
|
290 |
+
m.norm2.eval()
|
291 |
+
for param in m.mlp.parameters():
|
292 |
+
param.requires_grad = False
|
293 |
+
for param in m.norm2.parameters():
|
294 |
+
param.requires_grad = False
|
295 |
+
|
296 |
+
def init_weights(self):
|
297 |
+
"""Initialize the weights in backbone.
|
298 |
+
Args:
|
299 |
+
pretrained (str, optional): Path to pre-trained weights.
|
300 |
+
Defaults to None.
|
301 |
+
"""
|
302 |
+
def _init_weights(m):
|
303 |
+
if isinstance(m, nn.Linear):
|
304 |
+
trunc_normal_(m.weight, std=.02)
|
305 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
306 |
+
nn.init.constant_(m.bias, 0)
|
307 |
+
elif isinstance(m, nn.LayerNorm):
|
308 |
+
nn.init.constant_(m.bias, 0)
|
309 |
+
nn.init.constant_(m.weight, 1.0)
|
310 |
+
|
311 |
+
self.apply(_init_weights)
|
312 |
+
|
313 |
+
def get_num_layers(self):
|
314 |
+
return len(self.blocks)
|
315 |
+
|
316 |
+
@torch.jit.ignore
|
317 |
+
def no_weight_decay(self):
|
318 |
+
return {'pos_embed', 'cls_token'}
|
319 |
+
|
320 |
+
def forward_features(self, x):
|
321 |
+
B, C, H, W = x.shape
|
322 |
+
x, (Hp, Wp) = self.patch_embed(x)
|
323 |
+
|
324 |
+
if self.pos_embed is not None:
|
325 |
+
# fit for multiple GPU training
|
326 |
+
# since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
|
327 |
+
x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
|
328 |
+
|
329 |
+
for blk in self.blocks:
|
330 |
+
if self.use_checkpoint:
|
331 |
+
x = checkpoint.checkpoint(blk, x)
|
332 |
+
else:
|
333 |
+
x = blk(x)
|
334 |
+
|
335 |
+
x = self.last_norm(x)
|
336 |
+
|
337 |
+
xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
|
338 |
+
|
339 |
+
return xp
|
340 |
+
|
341 |
+
def forward(self, x):
|
342 |
+
x = self.forward_features(x)
|
343 |
+
return x
|
344 |
+
|
345 |
+
def train(self, mode=True):
|
346 |
+
"""Convert the model into training mode."""
|
347 |
+
super().train(mode)
|
348 |
+
self._freeze_stages()
|
lib/models/preproc/detector.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import scipy.signal as signal
|
11 |
+
from progress.bar import Bar
|
12 |
+
|
13 |
+
from ultralytics import YOLO
|
14 |
+
from mmpose.apis import (
|
15 |
+
inference_top_down_pose_model,
|
16 |
+
init_pose_model,
|
17 |
+
get_track_id,
|
18 |
+
vis_pose_result,
|
19 |
+
)
|
20 |
+
|
21 |
+
ROOT_DIR = osp.abspath(f"{__file__}/../../../../")
|
22 |
+
VIT_DIR = osp.join(ROOT_DIR, "third-party/ViTPose")
|
23 |
+
|
24 |
+
VIS_THRESH = 0.3
|
25 |
+
BBOX_CONF = 0.5
|
26 |
+
TRACKING_THR = 0.1
|
27 |
+
MINIMUM_FRMAES = 30
|
28 |
+
MINIMUM_JOINTS = 6
|
29 |
+
|
30 |
+
class DetectionModel(object):
|
31 |
+
def __init__(self, device):
|
32 |
+
|
33 |
+
# ViTPose
|
34 |
+
pose_model_cfg = osp.join(VIT_DIR, 'configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_huge_coco_256x192.py')
|
35 |
+
pose_model_ckpt = osp.join(ROOT_DIR, 'checkpoints', 'vitpose-h-multi-coco.pth')
|
36 |
+
self.pose_model = init_pose_model(pose_model_cfg, pose_model_ckpt, device=device.lower())
|
37 |
+
|
38 |
+
# YOLO
|
39 |
+
bbox_model_ckpt = osp.join(ROOT_DIR, 'checkpoints', 'yolov8x.pt')
|
40 |
+
self.bbox_model = YOLO(bbox_model_ckpt)
|
41 |
+
|
42 |
+
self.device = device
|
43 |
+
self.initialize_tracking()
|
44 |
+
|
45 |
+
def initialize_tracking(self, ):
|
46 |
+
self.next_id = 0
|
47 |
+
self.frame_id = 0
|
48 |
+
self.pose_results_last = []
|
49 |
+
self.tracking_results = {
|
50 |
+
'id': [],
|
51 |
+
'frame_id': [],
|
52 |
+
'bbox': [],
|
53 |
+
'keypoints': []
|
54 |
+
}
|
55 |
+
|
56 |
+
def xyxy_to_cxcys(self, bbox, s_factor=1.05):
|
57 |
+
cx, cy = bbox[[0, 2]].mean(), bbox[[1, 3]].mean()
|
58 |
+
scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 200 * s_factor
|
59 |
+
return np.array([[cx, cy, scale]])
|
60 |
+
|
61 |
+
def compute_bboxes_from_keypoints(self, s_factor=1.2):
|
62 |
+
X = self.tracking_results['keypoints'].copy()
|
63 |
+
mask = X[..., -1] > VIS_THRESH
|
64 |
+
|
65 |
+
bbox = np.zeros((len(X), 3))
|
66 |
+
for i, (kp, m) in enumerate(zip(X, mask)):
|
67 |
+
bb = [kp[m, 0].min(), kp[m, 1].min(),
|
68 |
+
kp[m, 0].max(), kp[m, 1].max()]
|
69 |
+
cx, cy = [(bb[2]+bb[0])/2, (bb[3]+bb[1])/2]
|
70 |
+
bb_w = bb[2] - bb[0]
|
71 |
+
bb_h = bb[3] - bb[1]
|
72 |
+
s = np.stack((bb_w, bb_h)).max()
|
73 |
+
bb = np.array((cx, cy, s))
|
74 |
+
bbox[i] = bb
|
75 |
+
|
76 |
+
bbox[:, 2] = bbox[:, 2] * s_factor / 200.0
|
77 |
+
self.tracking_results['bbox'] = bbox
|
78 |
+
|
79 |
+
def track(self, img, fps, length):
|
80 |
+
|
81 |
+
# bbox detection
|
82 |
+
bboxes = self.bbox_model.predict(
|
83 |
+
img, device=self.device, classes=0, conf=BBOX_CONF, save=False, verbose=False
|
84 |
+
)[0].boxes.xyxy.detach().cpu().numpy()
|
85 |
+
bboxes = [{'bbox': bbox} for bbox in bboxes]
|
86 |
+
|
87 |
+
# keypoints detection
|
88 |
+
pose_results, returned_outputs = inference_top_down_pose_model(
|
89 |
+
self.pose_model,
|
90 |
+
img,
|
91 |
+
person_results=bboxes,
|
92 |
+
format='xyxy',
|
93 |
+
return_heatmap=False,
|
94 |
+
outputs=None)
|
95 |
+
|
96 |
+
# person identification
|
97 |
+
pose_results, self.next_id = get_track_id(
|
98 |
+
pose_results,
|
99 |
+
self.pose_results_last,
|
100 |
+
self.next_id,
|
101 |
+
use_oks=False,
|
102 |
+
tracking_thr=TRACKING_THR,
|
103 |
+
use_one_euro=True,
|
104 |
+
fps=fps)
|
105 |
+
|
106 |
+
for pose_result in pose_results:
|
107 |
+
n_valid = (pose_result['keypoints'][:, -1] > VIS_THRESH).sum()
|
108 |
+
if n_valid < MINIMUM_JOINTS: continue
|
109 |
+
|
110 |
+
_id = pose_result['track_id']
|
111 |
+
xyxy = pose_result['bbox']
|
112 |
+
bbox = self.xyxy_to_cxcys(xyxy)
|
113 |
+
|
114 |
+
self.tracking_results['id'].append(_id)
|
115 |
+
self.tracking_results['frame_id'].append(self.frame_id)
|
116 |
+
self.tracking_results['bbox'].append(bbox)
|
117 |
+
self.tracking_results['keypoints'].append(pose_result['keypoints'])
|
118 |
+
|
119 |
+
self.frame_id += 1
|
120 |
+
self.pose_results_last = pose_results
|
121 |
+
|
122 |
+
def process(self, fps):
|
123 |
+
for key in ['id', 'frame_id', 'keypoints']:
|
124 |
+
self.tracking_results[key] = np.array(self.tracking_results[key])
|
125 |
+
self.compute_bboxes_from_keypoints()
|
126 |
+
|
127 |
+
output = defaultdict(lambda: defaultdict(list))
|
128 |
+
ids = np.unique(self.tracking_results['id'])
|
129 |
+
for _id in ids:
|
130 |
+
idxs = np.where(self.tracking_results['id'] == _id)[0]
|
131 |
+
for key, val in self.tracking_results.items():
|
132 |
+
if key == 'id': continue
|
133 |
+
output[_id][key] = val[idxs]
|
134 |
+
|
135 |
+
# Smooth bounding box detection
|
136 |
+
ids = list(output.keys())
|
137 |
+
for _id in ids:
|
138 |
+
if len(output[_id]['bbox']) < MINIMUM_FRMAES:
|
139 |
+
del output[_id]
|
140 |
+
continue
|
141 |
+
|
142 |
+
kernel = int(int(fps/2) / 2) * 2 + 1
|
143 |
+
smoothed_bbox = np.array([signal.medfilt(param, kernel) for param in output[_id]['bbox'].T]).T
|
144 |
+
output[_id]['bbox'] = smoothed_bbox
|
145 |
+
|
146 |
+
return output
|
lib/models/preproc/extractor.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
import scipy.signal as signal
|
11 |
+
from progress.bar import Bar
|
12 |
+
from scipy.ndimage.filters import gaussian_filter1d
|
13 |
+
|
14 |
+
from configs import constants as _C
|
15 |
+
from .backbone.hmr2 import hmr2
|
16 |
+
from .backbone.utils import process_image
|
17 |
+
from ...utils.imutils import flip_kp, flip_bbox
|
18 |
+
|
19 |
+
ROOT_DIR = osp.abspath(f"{__file__}/../../../../")
|
20 |
+
|
21 |
+
class FeatureExtractor(object):
|
22 |
+
def __init__(self, device, flip_eval=False, max_batch_size=64):
|
23 |
+
|
24 |
+
self.device = device
|
25 |
+
self.flip_eval = flip_eval
|
26 |
+
self.max_batch_size = max_batch_size
|
27 |
+
|
28 |
+
ckpt = osp.join(ROOT_DIR, 'checkpoints', 'hmr2a.ckpt')
|
29 |
+
self.model = hmr2(ckpt).to(device).eval()
|
30 |
+
|
31 |
+
def run(self, video, tracking_results, patch_h=256, patch_w=256):
|
32 |
+
|
33 |
+
if osp.isfile(video):
|
34 |
+
cap = cv2.VideoCapture(video)
|
35 |
+
is_video = True
|
36 |
+
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
37 |
+
width, height = cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
38 |
+
else: # Image list
|
39 |
+
cap = video
|
40 |
+
is_video = False
|
41 |
+
length = len(video)
|
42 |
+
height, width = cv2.imread(video[0]).shape[:2]
|
43 |
+
|
44 |
+
frame_id = 0
|
45 |
+
bar = Bar('Feature extraction ...', fill='#', max=length)
|
46 |
+
while True:
|
47 |
+
if is_video:
|
48 |
+
flag, img = cap.read()
|
49 |
+
if not flag:
|
50 |
+
break
|
51 |
+
else:
|
52 |
+
if frame_id >= len(cap):
|
53 |
+
break
|
54 |
+
img = cv2.imread(cap[frame_id])
|
55 |
+
|
56 |
+
for _id, val in tracking_results.items():
|
57 |
+
if not frame_id in val['frame_id']: continue
|
58 |
+
|
59 |
+
frame_id2 = np.where(val['frame_id'] == frame_id)[0][0]
|
60 |
+
bbox = val['bbox'][frame_id2]
|
61 |
+
cx, cy, scale = bbox
|
62 |
+
|
63 |
+
norm_img, crop_img = process_image(img[..., ::-1], [cx, cy], scale, patch_h, patch_w)
|
64 |
+
norm_img = torch.from_numpy(norm_img).unsqueeze(0).to(self.device)
|
65 |
+
feature = self.model(norm_img, encode=True)
|
66 |
+
tracking_results[_id]['features'].append(feature.cpu())
|
67 |
+
|
68 |
+
if frame_id2 == 0: # First frame of this subject
|
69 |
+
tracking_results = self.predict_init(norm_img, tracking_results, _id, flip_eval=False)
|
70 |
+
|
71 |
+
if self.flip_eval:
|
72 |
+
flipped_bbox = flip_bbox(bbox, width, height)
|
73 |
+
tracking_results[_id]['flipped_bbox'].append(flipped_bbox)
|
74 |
+
|
75 |
+
keypoints = val['keypoints'][frame_id2]
|
76 |
+
flipped_keypoints = flip_kp(keypoints, width)
|
77 |
+
tracking_results[_id]['flipped_keypoints'].append(flipped_keypoints)
|
78 |
+
|
79 |
+
flipped_features = self.model(torch.flip(norm_img, (3, )), encode=True)
|
80 |
+
tracking_results[_id]['flipped_features'].append(flipped_features.cpu())
|
81 |
+
|
82 |
+
if frame_id2 == 0:
|
83 |
+
tracking_results = self.predict_init(torch.flip(norm_img, (3, )), tracking_results, _id, flip_eval=True)
|
84 |
+
|
85 |
+
bar.next()
|
86 |
+
frame_id += 1
|
87 |
+
|
88 |
+
return self.process(tracking_results)
|
89 |
+
|
90 |
+
def predict_init(self, norm_img, tracking_results, _id, flip_eval=False):
|
91 |
+
prefix = 'flipped_' if flip_eval else ''
|
92 |
+
|
93 |
+
pred_global_orient, pred_body_pose, pred_betas, _ = self.model(norm_img, encode=False)
|
94 |
+
tracking_results[_id][prefix + 'init_global_orient'] = pred_global_orient.cpu()
|
95 |
+
tracking_results[_id][prefix + 'init_body_pose'] = pred_body_pose.cpu()
|
96 |
+
tracking_results[_id][prefix + 'init_betas'] = pred_betas.cpu()
|
97 |
+
return tracking_results
|
98 |
+
|
99 |
+
def process(self, tracking_results):
|
100 |
+
output = defaultdict(dict)
|
101 |
+
|
102 |
+
for _id, results in tracking_results.items():
|
103 |
+
|
104 |
+
for key, val in results.items():
|
105 |
+
if isinstance(val, list):
|
106 |
+
if isinstance(val[0], torch.Tensor):
|
107 |
+
val = torch.cat(val)
|
108 |
+
elif isinstance(val[0], np.ndarray):
|
109 |
+
val = np.array(val)
|
110 |
+
output[_id][key] = val
|
111 |
+
|
112 |
+
return output
|
lib/models/preproc/slam.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import glob
|
4 |
+
import os.path as osp
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
import torch
|
8 |
+
from pathlib import Path
|
9 |
+
from multiprocessing import Process, Queue
|
10 |
+
|
11 |
+
from dpvo.utils import Timer
|
12 |
+
from dpvo.dpvo import DPVO
|
13 |
+
from dpvo.config import cfg
|
14 |
+
from dpvo.stream import image_stream, video_stream
|
15 |
+
|
16 |
+
ROOT_DIR = osp.abspath(f"{__file__}/../../../../")
|
17 |
+
DPVO_DIR = osp.join(ROOT_DIR, "third-party/DPVO")
|
18 |
+
|
19 |
+
|
20 |
+
class SLAMModel(object):
|
21 |
+
def __init__(self, video, output_pth, width, height, calib=None, stride=1, skip=0, buffer=2048):
|
22 |
+
|
23 |
+
if calib == None or not osp.exists(calib):
|
24 |
+
calib = osp.join(output_pth, 'calib.txt')
|
25 |
+
if not osp.exists(calib):
|
26 |
+
self.estimate_intrinsics(width, height, calib)
|
27 |
+
|
28 |
+
self.dpvo_cfg = osp.join(DPVO_DIR, 'config/default.yaml')
|
29 |
+
self.dpvo_ckpt = osp.join(ROOT_DIR, 'checkpoints', 'dpvo.pth')
|
30 |
+
|
31 |
+
self.buffer = buffer
|
32 |
+
self.times = []
|
33 |
+
self.slam = None
|
34 |
+
self.queue = Queue(maxsize=8)
|
35 |
+
self.reader = Process(target=video_stream, args=(self.queue, video, calib, stride, skip))
|
36 |
+
self.reader.start()
|
37 |
+
|
38 |
+
def estimate_intrinsics(self, width, height, calib):
|
39 |
+
focal_length = (height ** 2 + width ** 2) ** 0.5
|
40 |
+
center_x = width / 2
|
41 |
+
center_y = height / 2
|
42 |
+
|
43 |
+
with open(calib, 'w') as fopen:
|
44 |
+
line = f'{focal_length} {focal_length} {center_x} {center_y}'
|
45 |
+
fopen.write(line)
|
46 |
+
|
47 |
+
def track(self, ):
|
48 |
+
(t, image, intrinsics) = self.queue.get()
|
49 |
+
|
50 |
+
if t < 0: return
|
51 |
+
|
52 |
+
image = torch.from_numpy(image).permute(2,0,1).cuda()
|
53 |
+
intrinsics = torch.from_numpy(intrinsics).cuda()
|
54 |
+
|
55 |
+
if self.slam is None:
|
56 |
+
cfg.merge_from_file(self.dpvo_cfg)
|
57 |
+
cfg.BUFFER_SIZE = self.buffer
|
58 |
+
self.slam = DPVO(cfg, self.dpvo_ckpt, ht=image.shape[1], wd=image.shape[2], viz=False)
|
59 |
+
|
60 |
+
with Timer("SLAM", enabled=False):
|
61 |
+
t = time.time()
|
62 |
+
self.slam(t, image, intrinsics)
|
63 |
+
self.times.append(time.time() - t)
|
64 |
+
|
65 |
+
def process(self, ):
|
66 |
+
for _ in range(12):
|
67 |
+
self.slam.update()
|
68 |
+
|
69 |
+
self.reader.join()
|
70 |
+
return self.slam.terminate()[0]
|
lib/models/smpl.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import print_function
|
3 |
+
from __future__ import division
|
4 |
+
|
5 |
+
import os, sys
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
from lib.utils import transforms
|
10 |
+
|
11 |
+
from smplx import SMPL as _SMPL
|
12 |
+
from smplx.utils import SMPLOutput as ModelOutput
|
13 |
+
from smplx.lbs import vertices2joints
|
14 |
+
|
15 |
+
from configs import constants as _C
|
16 |
+
|
17 |
+
class SMPL(_SMPL):
|
18 |
+
""" Extension of the official SMPL implementation to support more joints """
|
19 |
+
|
20 |
+
def __init__(self, *args, **kwargs):
|
21 |
+
sys.stdout = open(os.devnull, 'w')
|
22 |
+
super(SMPL, self).__init__(*args, **kwargs)
|
23 |
+
sys.stdout = sys.__stdout__
|
24 |
+
|
25 |
+
J_regressor_wham = np.load(_C.BMODEL.JOINTS_REGRESSOR_WHAM)
|
26 |
+
J_regressor_eval = np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M)
|
27 |
+
self.register_buffer('J_regressor_wham', torch.tensor(
|
28 |
+
J_regressor_wham, dtype=torch.float32))
|
29 |
+
self.register_buffer('J_regressor_eval', torch.tensor(
|
30 |
+
J_regressor_eval, dtype=torch.float32))
|
31 |
+
self.register_buffer('J_regressor_feet', torch.from_numpy(
|
32 |
+
np.load(_C.BMODEL.JOINTS_REGRESSOR_FEET)
|
33 |
+
).float())
|
34 |
+
|
35 |
+
def get_local_pose_from_reduced_global_pose(self, reduced_pose):
|
36 |
+
full_pose = torch.eye(
|
37 |
+
3, device=reduced_pose.device
|
38 |
+
)[(None, ) * 2].repeat(reduced_pose.shape[0], 24, 1, 1)
|
39 |
+
full_pose[:, _C.BMODEL.MAIN_JOINTS] = reduced_pose
|
40 |
+
return full_pose
|
41 |
+
|
42 |
+
def forward(self,
|
43 |
+
pred_rot6d,
|
44 |
+
betas,
|
45 |
+
cam=None,
|
46 |
+
cam_intrinsics=None,
|
47 |
+
bbox=None,
|
48 |
+
res=None,
|
49 |
+
return_full_pose=False,
|
50 |
+
**kwargs):
|
51 |
+
|
52 |
+
rotmat = transforms.rotation_6d_to_matrix(pred_rot6d.reshape(*pred_rot6d.shape[:2], -1, 6)
|
53 |
+
).reshape(-1, 24, 3, 3)
|
54 |
+
|
55 |
+
output = self.get_output(body_pose=rotmat[:, 1:],
|
56 |
+
global_orient=rotmat[:, :1],
|
57 |
+
betas=betas.view(-1, 10),
|
58 |
+
pose2rot=False,
|
59 |
+
return_full_pose=return_full_pose)
|
60 |
+
|
61 |
+
if cam is not None:
|
62 |
+
joints3d = output.joints.reshape(*cam.shape[:2], -1, 3)
|
63 |
+
|
64 |
+
# Weak perspective projection (for InstaVariety)
|
65 |
+
weak_cam = convert_weak_perspective_to_perspective(cam)
|
66 |
+
|
67 |
+
weak_joints2d = weak_perspective_projection(
|
68 |
+
joints3d,
|
69 |
+
rotation=torch.eye(3, device=cam.device).unsqueeze(0).unsqueeze(0).expand(*cam.shape[:2], -1, -1),
|
70 |
+
translation=weak_cam,
|
71 |
+
focal_length=5000.,
|
72 |
+
camera_center=torch.zeros(*cam.shape[:2], 2, device=cam.device)
|
73 |
+
)
|
74 |
+
output.weak_joints2d = weak_joints2d
|
75 |
+
|
76 |
+
# Full perspective projection
|
77 |
+
full_cam = convert_pare_to_full_img_cam(
|
78 |
+
cam,
|
79 |
+
bbox[:, :, 2] * 200.,
|
80 |
+
bbox[:, :, :2],
|
81 |
+
res[:, 0].unsqueeze(-1),
|
82 |
+
res[:, 1].unsqueeze(-1),
|
83 |
+
focal_length=cam_intrinsics[:, :, 0, 0]
|
84 |
+
)
|
85 |
+
|
86 |
+
full_joints2d = full_perspective_projection(
|
87 |
+
joints3d,
|
88 |
+
translation=full_cam,
|
89 |
+
cam_intrinsics=cam_intrinsics,
|
90 |
+
)
|
91 |
+
output.full_joints2d = full_joints2d
|
92 |
+
output.full_cam = full_cam.reshape(-1, 3)
|
93 |
+
|
94 |
+
return output
|
95 |
+
|
96 |
+
def forward_nd(self,
|
97 |
+
pred_rot6d,
|
98 |
+
root,
|
99 |
+
betas,
|
100 |
+
return_full_pose=False):
|
101 |
+
|
102 |
+
rotmat = transforms.rotation_6d_to_matrix(pred_rot6d.reshape(*pred_rot6d.shape[:2], -1, 6)
|
103 |
+
).reshape(-1, 24, 3, 3)
|
104 |
+
|
105 |
+
output = self.get_output(body_pose=rotmat[:, 1:],
|
106 |
+
global_orient=root.reshape(-1, 1, 3, 3),
|
107 |
+
betas=betas.view(-1, 10),
|
108 |
+
pose2rot=False,
|
109 |
+
return_full_pose=return_full_pose)
|
110 |
+
|
111 |
+
return output
|
112 |
+
|
113 |
+
def get_output(self, *args, **kwargs):
|
114 |
+
kwargs['get_skin'] = True
|
115 |
+
smpl_output = super(SMPL, self).forward(*args, **kwargs)
|
116 |
+
joints = vertices2joints(self.J_regressor_wham, smpl_output.vertices)
|
117 |
+
feet = vertices2joints(self.J_regressor_feet, smpl_output.vertices)
|
118 |
+
|
119 |
+
offset = joints[..., [11, 12], :].mean(-2)
|
120 |
+
if 'transl' in kwargs:
|
121 |
+
offset = offset - kwargs['transl']
|
122 |
+
vertices = smpl_output.vertices - offset.unsqueeze(-2)
|
123 |
+
joints = joints - offset.unsqueeze(-2)
|
124 |
+
feet = feet - offset.unsqueeze(-2)
|
125 |
+
|
126 |
+
output = ModelOutput(vertices=vertices,
|
127 |
+
global_orient=smpl_output.global_orient,
|
128 |
+
body_pose=smpl_output.body_pose,
|
129 |
+
joints=joints,
|
130 |
+
betas=smpl_output.betas,
|
131 |
+
full_pose=smpl_output.full_pose)
|
132 |
+
output.feet = feet
|
133 |
+
output.offset = offset
|
134 |
+
return output
|
135 |
+
|
136 |
+
def get_offset(self, *args, **kwargs):
|
137 |
+
kwargs['get_skin'] = True
|
138 |
+
smpl_output = super(SMPL, self).forward(*args, **kwargs)
|
139 |
+
joints = vertices2joints(self.J_regressor_wham, smpl_output.vertices)
|
140 |
+
|
141 |
+
offset = joints[..., [11, 12], :].mean(-2)
|
142 |
+
return offset
|
143 |
+
|
144 |
+
def get_faces(self):
|
145 |
+
return np.array(self.faces)
|
146 |
+
|
147 |
+
|
148 |
+
def convert_weak_perspective_to_perspective(
|
149 |
+
weak_perspective_camera,
|
150 |
+
focal_length=5000.,
|
151 |
+
img_res=224,
|
152 |
+
):
|
153 |
+
|
154 |
+
perspective_camera = torch.stack(
|
155 |
+
[
|
156 |
+
weak_perspective_camera[..., 1],
|
157 |
+
weak_perspective_camera[..., 2],
|
158 |
+
2 * focal_length / (img_res * weak_perspective_camera[..., 0] + 1e-9)
|
159 |
+
],
|
160 |
+
dim=-1
|
161 |
+
)
|
162 |
+
return perspective_camera
|
163 |
+
|
164 |
+
|
165 |
+
def weak_perspective_projection(
|
166 |
+
points,
|
167 |
+
rotation,
|
168 |
+
translation,
|
169 |
+
focal_length,
|
170 |
+
camera_center,
|
171 |
+
img_res=224,
|
172 |
+
normalize_joints2d=True,
|
173 |
+
):
|
174 |
+
"""
|
175 |
+
This function computes the perspective projection of a set of points.
|
176 |
+
Input:
|
177 |
+
points (b, f, N, 3): 3D points
|
178 |
+
rotation (b, f, 3, 3): Camera rotation
|
179 |
+
translation (b, f, 3): Camera translation
|
180 |
+
focal_length (b, f,) or scalar: Focal length
|
181 |
+
camera_center (b, f, 2): Camera center
|
182 |
+
"""
|
183 |
+
|
184 |
+
K = torch.zeros([*points.shape[:2], 3, 3], device=points.device)
|
185 |
+
K[:,:,0,0] = focal_length
|
186 |
+
K[:,:,1,1] = focal_length
|
187 |
+
K[:,:,2,2] = 1.
|
188 |
+
K[:,:,:-1, -1] = camera_center
|
189 |
+
|
190 |
+
# Transform points
|
191 |
+
points = torch.einsum('bfij,bfkj->bfki', rotation, points)
|
192 |
+
points = points + translation.unsqueeze(-2)
|
193 |
+
|
194 |
+
# Apply perspective distortion
|
195 |
+
projected_points = points / points[...,-1].unsqueeze(-1)
|
196 |
+
|
197 |
+
# Apply camera intrinsics
|
198 |
+
projected_points = torch.einsum('bfij,bfkj->bfki', K, projected_points)
|
199 |
+
|
200 |
+
if normalize_joints2d:
|
201 |
+
projected_points = projected_points / (img_res / 2.)
|
202 |
+
|
203 |
+
return projected_points[..., :-1]
|
204 |
+
|
205 |
+
|
206 |
+
def full_perspective_projection(
|
207 |
+
points,
|
208 |
+
cam_intrinsics,
|
209 |
+
rotation=None,
|
210 |
+
translation=None,
|
211 |
+
):
|
212 |
+
|
213 |
+
K = cam_intrinsics
|
214 |
+
|
215 |
+
if rotation is not None:
|
216 |
+
points = (rotation @ points.transpose(-1, -2)).transpose(-1, -2)
|
217 |
+
if translation is not None:
|
218 |
+
points = points + translation.unsqueeze(-2)
|
219 |
+
projected_points = points / points[..., -1].unsqueeze(-1)
|
220 |
+
projected_points = (K @ projected_points.transpose(-1, -2)).transpose(-1, -2)
|
221 |
+
return projected_points[..., :-1]
|
222 |
+
|
223 |
+
|
224 |
+
def convert_pare_to_full_img_cam(
|
225 |
+
pare_cam,
|
226 |
+
bbox_height,
|
227 |
+
bbox_center,
|
228 |
+
img_w,
|
229 |
+
img_h,
|
230 |
+
focal_length,
|
231 |
+
crop_res=224
|
232 |
+
):
|
233 |
+
|
234 |
+
s, tx, ty = pare_cam[..., 0], pare_cam[..., 1], pare_cam[..., 2]
|
235 |
+
res = crop_res
|
236 |
+
r = bbox_height / res
|
237 |
+
tz = 2 * focal_length / (r * res * s)
|
238 |
+
|
239 |
+
cx = 2 * (bbox_center[..., 0] - (img_w / 2.)) / (s * bbox_height)
|
240 |
+
cy = 2 * (bbox_center[..., 1] - (img_h / 2.)) / (s * bbox_height)
|
241 |
+
|
242 |
+
cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
|
243 |
+
return cam_t
|
244 |
+
|
245 |
+
|
246 |
+
def cam_crop2full(crop_cam, center, scale, full_img_shape, focal_length):
|
247 |
+
"""
|
248 |
+
convert the camera parameters from the crop camera to the full camera
|
249 |
+
:param crop_cam: shape=(N, 3) weak perspective camera in cropped img coordinates (s, tx, ty)
|
250 |
+
:param center: shape=(N, 2) bbox coordinates (c_x, c_y)
|
251 |
+
:param scale: shape=(N) square bbox resolution (b / 200)
|
252 |
+
:param full_img_shape: shape=(N, 2) original image height and width
|
253 |
+
:param focal_length: shape=(N,)
|
254 |
+
:return:
|
255 |
+
"""
|
256 |
+
img_h, img_w = full_img_shape[:, 0], full_img_shape[:, 1]
|
257 |
+
cx, cy, b = center[:, 0], center[:, 1], scale * 200
|
258 |
+
w_2, h_2 = img_w / 2., img_h / 2.
|
259 |
+
bs = b * crop_cam[:, 0] + 1e-9
|
260 |
+
tz = 2 * focal_length / bs
|
261 |
+
tx = (2 * (cx - w_2) / bs) + crop_cam[:, 1]
|
262 |
+
ty = (2 * (cy - h_2) / bs) + crop_cam[:, 2]
|
263 |
+
full_cam = torch.stack([tx, ty, tz], dim=-1)
|
264 |
+
return full_cam
|
lib/models/smplify/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .smplify import TemporalSMPLify
|
lib/models/smplify/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (241 Bytes). View file
|
|
lib/models/smplify/__pycache__/losses.cpython-39.pyc
ADDED
Binary file (2.63 kB). View file
|
|
lib/models/smplify/__pycache__/smplify.cpython-39.pyc
ADDED
Binary file (2.09 kB). View file
|
|
lib/models/smplify/losses.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def gmof(x, sigma):
|
4 |
+
"""
|
5 |
+
Geman-McClure error function
|
6 |
+
"""
|
7 |
+
x_squared = x ** 2
|
8 |
+
sigma_squared = sigma ** 2
|
9 |
+
return (sigma_squared * x_squared) / (sigma_squared + x_squared)
|
10 |
+
|
11 |
+
|
12 |
+
def compute_jitter(x):
|
13 |
+
"""
|
14 |
+
Compute jitter for the input tensor
|
15 |
+
"""
|
16 |
+
return torch.linalg.norm(x[:, 2:] + x[:, :-2] - 2 * x[:, 1:-1], dim=-1)
|
17 |
+
|
18 |
+
|
19 |
+
class SMPLifyLoss(torch.nn.Module):
|
20 |
+
def __init__(self,
|
21 |
+
res,
|
22 |
+
cam_intrinsics,
|
23 |
+
init_pose,
|
24 |
+
device,
|
25 |
+
**kwargs
|
26 |
+
):
|
27 |
+
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.res = res
|
31 |
+
self.cam_intrinsics = cam_intrinsics
|
32 |
+
self.init_pose = torch.from_numpy(init_pose).float().to(device)
|
33 |
+
|
34 |
+
def forward(self, output, params, input_keypoints, bbox,
|
35 |
+
reprojection_weight=100., regularize_weight=60.0,
|
36 |
+
consistency_weight=10.0, sprior_weight=0.04,
|
37 |
+
smooth_weight=20.0, sigma=100):
|
38 |
+
|
39 |
+
pose, shape, cam = params
|
40 |
+
scale = bbox[..., 2:].unsqueeze(-1) * 200.
|
41 |
+
|
42 |
+
# Loss 1. Data term
|
43 |
+
pred_keypoints = output.full_joints2d[..., :17, :]
|
44 |
+
joints_conf = input_keypoints[..., -1:]
|
45 |
+
reprojection_error = gmof(pred_keypoints - input_keypoints[..., :-1], sigma)
|
46 |
+
reprojection_error = ((reprojection_error * joints_conf) / scale).mean()
|
47 |
+
|
48 |
+
# Loss 2. Regularization term
|
49 |
+
regularize_error = torch.linalg.norm(pose - self.init_pose, dim=-1).mean()
|
50 |
+
|
51 |
+
# Loss 3. Shape prior and consistency error
|
52 |
+
consistency_error = shape.std(dim=1).mean()
|
53 |
+
sprior_error = torch.linalg.norm(shape, dim=-1).mean()
|
54 |
+
shape_error = sprior_weight * sprior_error + consistency_weight * consistency_error
|
55 |
+
|
56 |
+
# Loss 4. Smooth loss
|
57 |
+
pose_diff = compute_jitter(pose).mean()
|
58 |
+
cam_diff = compute_jitter(cam).mean()
|
59 |
+
smooth_error = pose_diff + cam_diff
|
60 |
+
|
61 |
+
# Sum up losses
|
62 |
+
loss = {
|
63 |
+
'reprojection': reprojection_weight * reprojection_error,
|
64 |
+
'regularize': regularize_weight * regularize_error,
|
65 |
+
'shape': shape_error,
|
66 |
+
'smooth': smooth_weight * smooth_error
|
67 |
+
}
|
68 |
+
|
69 |
+
return loss
|
70 |
+
|
71 |
+
def create_closure(self,
|
72 |
+
optimizer,
|
73 |
+
smpl,
|
74 |
+
params,
|
75 |
+
bbox,
|
76 |
+
input_keypoints):
|
77 |
+
|
78 |
+
def closure():
|
79 |
+
optimizer.zero_grad()
|
80 |
+
output = smpl(*params, cam_intrinsics=self.cam_intrinsics, bbox=bbox, res=self.res)
|
81 |
+
|
82 |
+
loss_dict = self.forward(output, params, input_keypoints, bbox)
|
83 |
+
loss = sum(loss_dict.values())
|
84 |
+
loss.backward()
|
85 |
+
return loss
|
86 |
+
|
87 |
+
return closure
|
lib/models/smplify/smplify.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
from lib.models import build_body_model
|
6 |
+
from .losses import SMPLifyLoss
|
7 |
+
|
8 |
+
class TemporalSMPLify():
|
9 |
+
|
10 |
+
def __init__(self,
|
11 |
+
smpl=None,
|
12 |
+
lr=1e-2,
|
13 |
+
num_iters=5,
|
14 |
+
num_steps=10,
|
15 |
+
img_w=None,
|
16 |
+
img_h=None,
|
17 |
+
device=None
|
18 |
+
):
|
19 |
+
|
20 |
+
self.smpl = smpl
|
21 |
+
self.lr = lr
|
22 |
+
self.num_iters = num_iters
|
23 |
+
self.num_steps = num_steps
|
24 |
+
self.img_w = img_w
|
25 |
+
self.img_h = img_h
|
26 |
+
self.device = device
|
27 |
+
|
28 |
+
def fit(self, init_pred, keypoints, bbox, **kwargs):
|
29 |
+
|
30 |
+
def to_params(param):
|
31 |
+
return torch.from_numpy(param).float().to(self.device).requires_grad_(True)
|
32 |
+
|
33 |
+
pose = init_pred['pose'].detach().cpu().numpy()
|
34 |
+
betas = init_pred['betas'].detach().cpu().numpy()
|
35 |
+
cam = init_pred['cam'].detach().cpu().numpy()
|
36 |
+
keypoints = torch.from_numpy(keypoints).float().unsqueeze(0).to(self.device)
|
37 |
+
|
38 |
+
BN = pose.shape[1]
|
39 |
+
lr = self.lr
|
40 |
+
|
41 |
+
# Stage 1. Optimize translation
|
42 |
+
params = [to_params(pose), to_params(betas), to_params(cam)]
|
43 |
+
optim_params = [params[2]]
|
44 |
+
|
45 |
+
optimizer = torch.optim.LBFGS(
|
46 |
+
optim_params,
|
47 |
+
lr=lr,
|
48 |
+
max_iter=self.num_iters,
|
49 |
+
line_search_fn='strong_wolfe')
|
50 |
+
|
51 |
+
loss_fn = SMPLifyLoss(init_pose=pose, device=self.device, **kwargs)
|
52 |
+
|
53 |
+
closure = loss_fn.create_closure(optimizer,
|
54 |
+
self.smpl,
|
55 |
+
params,
|
56 |
+
bbox,
|
57 |
+
keypoints)
|
58 |
+
|
59 |
+
for j in (j_bar := tqdm(range(self.num_steps), leave=False)):
|
60 |
+
optimizer.zero_grad()
|
61 |
+
loss = optimizer.step(closure)
|
62 |
+
msg = f'Loss: {loss.item():.1f}'
|
63 |
+
j_bar.set_postfix_str(msg)
|
64 |
+
|
65 |
+
|
66 |
+
# Stage 2. Optimize all params
|
67 |
+
optimizer = torch.optim.LBFGS(
|
68 |
+
params,
|
69 |
+
lr=lr * BN,
|
70 |
+
max_iter=self.num_iters,
|
71 |
+
line_search_fn='strong_wolfe')
|
72 |
+
|
73 |
+
for j in (j_bar := tqdm(range(self.num_steps), leave=False)):
|
74 |
+
optimizer.zero_grad()
|
75 |
+
loss = optimizer.step(closure)
|
76 |
+
msg = f'Loss: {loss.item():.1f}'
|
77 |
+
j_bar.set_postfix_str(msg)
|
78 |
+
|
79 |
+
init_pred['pose'] = params[0].detach()
|
80 |
+
init_pred['betas'] = params[1].detach()
|
81 |
+
init_pred['cam'] = params[2].detach()
|
82 |
+
|
83 |
+
return init_pred
|
lib/models/wham.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import print_function
|
3 |
+
from __future__ import division
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from configs import constants as _C
|
10 |
+
from lib.models.layers import (MotionEncoder, MotionDecoder, TrajectoryDecoder, TrajectoryRefiner, Integrator,
|
11 |
+
rollout_global_motion, reset_root_velocity, compute_camera_motion)
|
12 |
+
from lib.utils.transforms import axis_angle_to_matrix
|
13 |
+
|
14 |
+
|
15 |
+
class Network(nn.Module):
|
16 |
+
def __init__(self,
|
17 |
+
smpl,
|
18 |
+
pose_dr=0.1,
|
19 |
+
d_embed=512,
|
20 |
+
n_layers=3,
|
21 |
+
d_feat=2048,
|
22 |
+
rnn_type='LSTM',
|
23 |
+
**kwargs
|
24 |
+
):
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
n_joints = _C.KEYPOINTS.NUM_JOINTS
|
28 |
+
self.smpl = smpl
|
29 |
+
in_dim = n_joints * 2 + 3
|
30 |
+
d_context = d_embed + n_joints * 3
|
31 |
+
|
32 |
+
self.mask_embedding = nn.Parameter(torch.zeros(1, 1, n_joints, 2))
|
33 |
+
|
34 |
+
# Module 1. Motion Encoder
|
35 |
+
self.motion_encoder = MotionEncoder(in_dim=in_dim,
|
36 |
+
d_embed=d_embed,
|
37 |
+
pose_dr=pose_dr,
|
38 |
+
rnn_type=rnn_type,
|
39 |
+
n_layers=n_layers,
|
40 |
+
n_joints=n_joints)
|
41 |
+
|
42 |
+
self.trajectory_decoder = TrajectoryDecoder(d_embed=d_context,
|
43 |
+
rnn_type=rnn_type,
|
44 |
+
n_layers=n_layers)
|
45 |
+
|
46 |
+
# Module 3. Feature Integrator
|
47 |
+
self.integrator = Integrator(in_channel=d_feat + d_context,
|
48 |
+
out_channel=d_context)
|
49 |
+
|
50 |
+
# Module 4. Motion Decoder
|
51 |
+
self.motion_decoder = MotionDecoder(d_embed=d_context,
|
52 |
+
rnn_type=rnn_type,
|
53 |
+
n_layers=n_layers)
|
54 |
+
|
55 |
+
# Module 5. Trajectory Refiner
|
56 |
+
self.trajectory_refiner = TrajectoryRefiner(d_embed=d_context,
|
57 |
+
d_hidden=d_embed,
|
58 |
+
rnn_type=rnn_type,
|
59 |
+
n_layers=2)
|
60 |
+
|
61 |
+
def compute_global_feet(self, root_world, trans):
|
62 |
+
# # Compute world-coordinate motion
|
63 |
+
cam_R, cam_T = compute_camera_motion(self.output, self.pred_pose[:, :, :6], root_world, trans, self.pred_cam)
|
64 |
+
feet_cam = self.output.feet.reshape(self.b, self.f, -1, 3) + self.output.full_cam.reshape(self.b, self.f, 1, 3)
|
65 |
+
feet_world = (cam_R.mT @ (feet_cam - cam_T.unsqueeze(-2)).mT).mT
|
66 |
+
|
67 |
+
return feet_world, cam_R
|
68 |
+
|
69 |
+
def forward_smpl(self, **kwargs):
|
70 |
+
self.output = self.smpl(self.pred_pose,
|
71 |
+
self.pred_shape,
|
72 |
+
cam=self.pred_cam,
|
73 |
+
return_full_pose=not self.training,
|
74 |
+
**kwargs,
|
75 |
+
)
|
76 |
+
|
77 |
+
from loguru import logger
|
78 |
+
logger.info(f"Output Joints: {self.output.joints}")
|
79 |
+
logger.info(f"Output Vertices: {self.output.vertices}")
|
80 |
+
|
81 |
+
# Save joints and vertices as .npy arrays
|
82 |
+
|
83 |
+
np.save('joints.npy', self.output.joints.cpu().numpy())
|
84 |
+
np.save('vertices.npy', self.output.vertices.cpu().numpy())
|
85 |
+
|
86 |
+
# Feet location in global coordinate
|
87 |
+
root_world, trans = rollout_global_motion(self.pred_root, self.pred_vel)
|
88 |
+
feet_world, cam_R = self.compute_global_feet(root_world, trans)
|
89 |
+
|
90 |
+
# Return output
|
91 |
+
output = {'feet': feet_world,
|
92 |
+
'contact': self.pred_contact,
|
93 |
+
'pose': self.pred_pose,
|
94 |
+
'betas': self.pred_shape,
|
95 |
+
'cam': self.pred_cam,
|
96 |
+
'poses_root_cam': self.output.global_orient,
|
97 |
+
'poses_root_r6d': self.pred_root,
|
98 |
+
'vel_root': self.pred_vel,
|
99 |
+
'pose_root': self.pred_root,
|
100 |
+
'verts_cam': self.output.vertices}
|
101 |
+
|
102 |
+
if self.training:
|
103 |
+
output.update({
|
104 |
+
'kp3d': self.output.joints,
|
105 |
+
'kp3d_nn': self.pred_kp3d,
|
106 |
+
'full_kp2d': self.output.full_joints2d,
|
107 |
+
'weak_kp2d': self.output.weak_joints2d,
|
108 |
+
'R': cam_R,
|
109 |
+
})
|
110 |
+
else:
|
111 |
+
output.update({
|
112 |
+
'poses_root_r6d': self.pred_root,
|
113 |
+
'trans_cam': self.output.full_cam,
|
114 |
+
'poses_body': self.output.body_pose})
|
115 |
+
|
116 |
+
return output
|
117 |
+
|
118 |
+
|
119 |
+
def preprocess(self, x, mask):
|
120 |
+
self.b, self.f = x.shape[:2]
|
121 |
+
|
122 |
+
# Treat masked keypoints
|
123 |
+
mask_embedding = mask.unsqueeze(-1) * self.mask_embedding
|
124 |
+
_mask = mask.unsqueeze(-1).repeat(1, 1, 1, 2).reshape(self.b, self.f, -1)
|
125 |
+
_mask = torch.cat((_mask, torch.zeros_like(_mask[..., :3])), dim=-1)
|
126 |
+
_mask_embedding = mask_embedding.reshape(self.b, self.f, -1)
|
127 |
+
_mask_embedding = torch.cat((_mask_embedding, torch.zeros_like(_mask_embedding[..., :3])), dim=-1)
|
128 |
+
x[_mask] = 0.0
|
129 |
+
x = x + _mask_embedding
|
130 |
+
return x
|
131 |
+
|
132 |
+
|
133 |
+
def rollout(self, output, pred_root, pred_vel, return_y_up):
|
134 |
+
root_world, trans_world = rollout_global_motion(pred_root, pred_vel)
|
135 |
+
|
136 |
+
if return_y_up:
|
137 |
+
yup2ydown = axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float().to(root_world.device)
|
138 |
+
root_world = yup2ydown.mT @ root_world
|
139 |
+
trans_world = (yup2ydown.mT @ trans_world.unsqueeze(-1)).squeeze(-1)
|
140 |
+
|
141 |
+
output.update({
|
142 |
+
'poses_root_world': root_world,
|
143 |
+
'trans_world': trans_world,
|
144 |
+
})
|
145 |
+
|
146 |
+
return output
|
147 |
+
|
148 |
+
|
149 |
+
def refine_trajectory(self, output, cam_angvel, return_y_up, **kwargs):
|
150 |
+
|
151 |
+
# --------- Refine trajectory --------- #
|
152 |
+
update_vel = reset_root_velocity(self.smpl, self.output, self.pred_contact, self.pred_root, self.pred_vel, thr=0.5)
|
153 |
+
output = self.trajectory_refiner(self.old_motion_context, update_vel, output, cam_angvel, return_y_up=return_y_up)
|
154 |
+
# --------- #
|
155 |
+
|
156 |
+
# Do rollout
|
157 |
+
output = self.rollout(output, output['poses_root_r6d_refined'], output['vel_root_refined'], return_y_up)
|
158 |
+
|
159 |
+
# --------- Compute refined feet --------- #
|
160 |
+
if self.training:
|
161 |
+
feet_world, cam_R = self.compute_global_feet(output['poses_root_world'], output['trans_world'])
|
162 |
+
output.update({'feet_refined': feet_world})
|
163 |
+
|
164 |
+
return output
|
165 |
+
|
166 |
+
|
167 |
+
def forward(self, x, inits, img_features=None, mask=None, init_root=None, cam_angvel=None,
|
168 |
+
cam_intrinsics=None, bbox=None, res=None, return_y_up=False, refine_traj=True, **kwargs):
|
169 |
+
|
170 |
+
x = self.preprocess(x, mask)
|
171 |
+
init_kp, init_smpl = inits
|
172 |
+
|
173 |
+
# --------- Inference --------- #
|
174 |
+
# Stage 1. Encode motion
|
175 |
+
pred_kp3d, motion_context = self.motion_encoder(x, init_kp)
|
176 |
+
self.old_motion_context = motion_context.detach().clone()
|
177 |
+
|
178 |
+
# Stage 2. Decode global trajectory
|
179 |
+
pred_root, pred_vel = self.trajectory_decoder(motion_context, init_root, cam_angvel)
|
180 |
+
|
181 |
+
# Stage 3. Integrate features
|
182 |
+
if img_features is not None and self.integrator is not None:
|
183 |
+
motion_context = self.integrator(motion_context, img_features)
|
184 |
+
|
185 |
+
# Stage 4. Decode SMPL motion
|
186 |
+
pred_pose, pred_shape, pred_cam, pred_contact = self.motion_decoder(motion_context, init_smpl)
|
187 |
+
# --------- #
|
188 |
+
|
189 |
+
# --------- Register predictions --------- #
|
190 |
+
self.pred_kp3d = pred_kp3d
|
191 |
+
self.pred_root = pred_root
|
192 |
+
self.pred_vel = pred_vel
|
193 |
+
self.pred_pose = pred_pose
|
194 |
+
self.pred_shape = pred_shape
|
195 |
+
self.pred_cam = pred_cam
|
196 |
+
self.pred_contact = pred_contact
|
197 |
+
# --------- #
|
198 |
+
|
199 |
+
# --------- Build SMPL --------- #
|
200 |
+
output = self.forward_smpl(cam_intrinsics=cam_intrinsics, bbox=bbox, res=res)
|
201 |
+
# --------- #
|
202 |
+
|
203 |
+
# --------- Refine trajectory --------- #
|
204 |
+
if refine_traj:
|
205 |
+
output = self.refine_trajectory(output, cam_angvel, return_y_up)
|
206 |
+
else:
|
207 |
+
output = self.rollout(output, self.pred_root, self.pred_vel, return_y_up)
|
208 |
+
# --------- #
|
209 |
+
|
210 |
+
return output
|
lib/utils/__pycache__/data_utils.cpython-39.pyc
ADDED
Binary file (3.57 kB). View file
|
|
lib/utils/__pycache__/imutils.cpython-39.pyc
ADDED
Binary file (10.6 kB). View file
|
|
lib/utils/__pycache__/kp_utils.cpython-39.pyc
ADDED
Binary file (9.99 kB). View file
|
|
lib/utils/__pycache__/transforms.cpython-39.pyc
ADDED
Binary file (23.2 kB). View file
|
|
lib/utils/data_utils.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import print_function
|
3 |
+
from __future__ import division
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from lib.utils import transforms
|
9 |
+
|
10 |
+
|
11 |
+
def make_collate_fn():
|
12 |
+
def collate_fn(items):
|
13 |
+
items = list(filter(lambda x: x is not None , items))
|
14 |
+
batch = dict()
|
15 |
+
try: batch['vid'] = [item['vid'] for item in items]
|
16 |
+
except: pass
|
17 |
+
try: batch['gender'] = [item['gender'] for item in items]
|
18 |
+
except: pass
|
19 |
+
for key in items[0].keys():
|
20 |
+
try: batch[key] = torch.stack([item[key] for item in items])
|
21 |
+
except: pass
|
22 |
+
return batch
|
23 |
+
|
24 |
+
return collate_fn
|
25 |
+
|
26 |
+
|
27 |
+
def prepare_keypoints_data(target):
|
28 |
+
"""Prepare keypoints data"""
|
29 |
+
|
30 |
+
# Prepare 2D keypoints
|
31 |
+
target['init_kp2d'] = target['kp2d'][:1]
|
32 |
+
target['kp2d'] = target['kp2d'][1:]
|
33 |
+
if 'kp3d' in target:
|
34 |
+
target['kp3d'] = target['kp3d'][1:]
|
35 |
+
|
36 |
+
return target
|
37 |
+
|
38 |
+
|
39 |
+
def prepare_smpl_data(target):
|
40 |
+
if 'pose' in target.keys():
|
41 |
+
# Use only the main joints
|
42 |
+
pose = target['pose'][:]
|
43 |
+
# 6-D Rotation representation
|
44 |
+
pose6d = transforms.matrix_to_rotation_6d(pose)
|
45 |
+
target['pose'] = pose6d[1:]
|
46 |
+
|
47 |
+
if 'betas' in target.keys():
|
48 |
+
target['betas'] = target['betas'][1:]
|
49 |
+
|
50 |
+
# Translation and shape parameters
|
51 |
+
if 'transl' in target.keys():
|
52 |
+
target['cam'] = target['transl'][1:]
|
53 |
+
|
54 |
+
# Initial pose and translation
|
55 |
+
target['init_pose'] = transforms.matrix_to_rotation_6d(target['init_pose'])
|
56 |
+
|
57 |
+
return target
|
58 |
+
|
59 |
+
|
60 |
+
def append_target(target, label, key_list, idx1, idx2=None, pad=True):
|
61 |
+
for key in key_list:
|
62 |
+
if idx2 is None: data = label[key][idx1]
|
63 |
+
else: data = label[key][idx1:idx2+1]
|
64 |
+
if not pad: data = data[2:]
|
65 |
+
target[key] = data
|
66 |
+
|
67 |
+
return target
|
68 |
+
|
69 |
+
|
70 |
+
def map_dmpl_to_smpl(pose):
|
71 |
+
""" Map AMASS DMPL pose representation to SMPL pose representation
|
72 |
+
|
73 |
+
Args:
|
74 |
+
pose - tensor / array with shape of (n_frames, 156)
|
75 |
+
|
76 |
+
Return:
|
77 |
+
pose - tensor / array with shape of (n_frames, 24, 3)
|
78 |
+
"""
|
79 |
+
|
80 |
+
pose = pose.reshape(pose.shape[0], -1, 3)
|
81 |
+
pose[:, 23] = pose[:, 37] # right hand
|
82 |
+
if isinstance(pose, np.ndarray): pose = pose[:, :24].copy()
|
83 |
+
else: pose = pose[:, :24].clone()
|
84 |
+
return pose
|
85 |
+
|
86 |
+
|
87 |
+
def transform_global_coordinate(pose, T, transl=None):
|
88 |
+
""" Transform global coordinate of dataset with respect to the given matrix.
|
89 |
+
Various datasets have different global coordinate system,
|
90 |
+
thus we united all datasets to the cronical coordinate system.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
pose - SMPL pose; tensor / array
|
94 |
+
T - Transformation matrix
|
95 |
+
transl - SMPL translation
|
96 |
+
"""
|
97 |
+
|
98 |
+
return_to_numpy = False
|
99 |
+
if isinstance(pose, np.ndarray):
|
100 |
+
return_to_numpy = True
|
101 |
+
pose = torch.from_numpy(pose).float()
|
102 |
+
if transl is not None: transl = torch.from_numpy(transl).float()
|
103 |
+
|
104 |
+
pose = transforms.axis_angle_to_matrix(pose)
|
105 |
+
pose[:, 0] = T @ pose[:, 0]
|
106 |
+
pose = transforms.matrix_to_axis_angle(pose)
|
107 |
+
if transl is not None:
|
108 |
+
transl = (T @ transl.T).squeeze().T
|
109 |
+
|
110 |
+
if return_to_numpy:
|
111 |
+
pose = pose.detach().numpy()
|
112 |
+
if transl is not None: transl = transl.detach().numpy()
|
113 |
+
return pose, transl
|
lib/utils/imutils.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
from . import transforms
|
6 |
+
|
7 |
+
def do_augmentation(scale_factor=0.2, trans_factor=0.1):
|
8 |
+
scale = random.uniform(1.2 - scale_factor, 1.2 + scale_factor)
|
9 |
+
trans_x = random.uniform(-trans_factor, trans_factor)
|
10 |
+
trans_y = random.uniform(-trans_factor, trans_factor)
|
11 |
+
|
12 |
+
return scale, trans_x, trans_y
|
13 |
+
|
14 |
+
def get_transform(center, scale, res, rot=0):
|
15 |
+
"""Generate transformation matrix."""
|
16 |
+
# res: (height, width), (rows, cols)
|
17 |
+
crop_aspect_ratio = res[0] / float(res[1])
|
18 |
+
h = 200 * scale
|
19 |
+
w = h / crop_aspect_ratio
|
20 |
+
t = np.zeros((3, 3))
|
21 |
+
t[0, 0] = float(res[1]) / w
|
22 |
+
t[1, 1] = float(res[0]) / h
|
23 |
+
t[0, 2] = res[1] * (-float(center[0]) / w + .5)
|
24 |
+
t[1, 2] = res[0] * (-float(center[1]) / h + .5)
|
25 |
+
t[2, 2] = 1
|
26 |
+
if not rot == 0:
|
27 |
+
rot = -rot # To match direction of rotation from cropping
|
28 |
+
rot_mat = np.zeros((3, 3))
|
29 |
+
rot_rad = rot * np.pi / 180
|
30 |
+
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
|
31 |
+
rot_mat[0, :2] = [cs, -sn]
|
32 |
+
rot_mat[1, :2] = [sn, cs]
|
33 |
+
rot_mat[2, 2] = 1
|
34 |
+
# Need to rotate around center
|
35 |
+
t_mat = np.eye(3)
|
36 |
+
t_mat[0, 2] = -res[1] / 2
|
37 |
+
t_mat[1, 2] = -res[0] / 2
|
38 |
+
t_inv = t_mat.copy()
|
39 |
+
t_inv[:2, 2] *= -1
|
40 |
+
t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
|
41 |
+
return t
|
42 |
+
|
43 |
+
|
44 |
+
def transform(pt, center, scale, res, invert=0, rot=0):
|
45 |
+
"""Transform pixel location to different reference."""
|
46 |
+
t = get_transform(center, scale, res, rot=rot)
|
47 |
+
if invert:
|
48 |
+
t = np.linalg.inv(t)
|
49 |
+
new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
|
50 |
+
new_pt = np.dot(t, new_pt)
|
51 |
+
return np.array([round(new_pt[0]), round(new_pt[1])], dtype=int) + 1
|
52 |
+
|
53 |
+
|
54 |
+
def crop_cliff(img, center, scale, res):
|
55 |
+
"""
|
56 |
+
Crop image according to the supplied bounding box.
|
57 |
+
res: [rows, cols]
|
58 |
+
"""
|
59 |
+
# Upper left point
|
60 |
+
ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
|
61 |
+
# Bottom right point
|
62 |
+
br = np.array(transform([res[1] + 1, res[0] + 1], center, scale, res, invert=1)) - 1
|
63 |
+
|
64 |
+
# Padding so that when rotated proper amount of context is included
|
65 |
+
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
|
66 |
+
|
67 |
+
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
68 |
+
if len(img.shape) > 2:
|
69 |
+
new_shape += [img.shape[2]]
|
70 |
+
new_img = np.zeros(new_shape, dtype=np.float32)
|
71 |
+
|
72 |
+
# Range to fill new array
|
73 |
+
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
|
74 |
+
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
|
75 |
+
# Range to sample from original image
|
76 |
+
old_x = max(0, ul[0]), min(len(img[0]), br[0])
|
77 |
+
old_y = max(0, ul[1]), min(len(img), br[1])
|
78 |
+
|
79 |
+
try:
|
80 |
+
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
|
81 |
+
except Exception as e:
|
82 |
+
print(e)
|
83 |
+
|
84 |
+
new_img = cv2.resize(new_img, (res[1], res[0])) # (cols, rows)
|
85 |
+
|
86 |
+
return new_img, ul, br
|
87 |
+
|
88 |
+
|
89 |
+
def obtain_bbox(center, scale, res, org_res):
|
90 |
+
# Upper left point
|
91 |
+
ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
|
92 |
+
# Bottom right point
|
93 |
+
br = np.array(transform([res[1] + 1, res[0] + 1], center, scale, res, invert=1)) - 1
|
94 |
+
|
95 |
+
# Padding so that when rotated proper amount of context is included
|
96 |
+
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
|
97 |
+
|
98 |
+
# Range to sample from original image
|
99 |
+
old_x = max(0, ul[0]), min(org_res[0], br[0])
|
100 |
+
old_y = max(0, ul[1]), min(org_res[1], br[1])
|
101 |
+
|
102 |
+
return old_x, old_y
|
103 |
+
|
104 |
+
|
105 |
+
def cam_crop2full(crop_cam, bbox, full_img_shape, focal_length=None):
|
106 |
+
"""
|
107 |
+
convert the camera parameters from the crop camera to the full camera
|
108 |
+
:param crop_cam: shape=(N, 3) weak perspective camera in cropped img coordinates (s, tx, ty)
|
109 |
+
:param center: shape=(N, 2) bbox coordinates (c_x, c_y)
|
110 |
+
:param scale: shape=(N, 1) square bbox resolution (b / 200)
|
111 |
+
:param full_img_shape: shape=(N, 2) original image height and width
|
112 |
+
:param focal_length: shape=(N,)
|
113 |
+
:return:
|
114 |
+
"""
|
115 |
+
|
116 |
+
cx = bbox[..., 0].clone(); cy = bbox[..., 1].clone(); b = bbox[..., 2].clone() * 200
|
117 |
+
img_h, img_w = full_img_shape[:, 0], full_img_shape[:, 1]
|
118 |
+
w_2, h_2 = img_w / 2., img_h / 2.
|
119 |
+
bs = b * crop_cam[:, :, 0] + 1e-9
|
120 |
+
|
121 |
+
if focal_length is None:
|
122 |
+
focal_length = (img_w * img_w + img_h * img_h) ** 0.5
|
123 |
+
|
124 |
+
tz = 2 * focal_length.unsqueeze(-1) / bs
|
125 |
+
tx = (2 * (cx - w_2.unsqueeze(-1)) / bs) + crop_cam[:, :, 1]
|
126 |
+
ty = (2 * (cy - h_2.unsqueeze(-1)) / bs) + crop_cam[:, :, 2]
|
127 |
+
full_cam = torch.stack([tx, ty, tz], dim=-1)
|
128 |
+
return full_cam
|
129 |
+
|
130 |
+
|
131 |
+
def cam_pred2full(crop_cam, center, scale, full_img_shape, focal_length=2000.,):
|
132 |
+
"""
|
133 |
+
Reference CLIFF: Carrying Location Information in Full Frames into Human Pose and Shape Estimation
|
134 |
+
|
135 |
+
convert the camera parameters from the crop camera to the full camera
|
136 |
+
:param crop_cam: shape=(N, 3) weak perspective camera in cropped img coordinates (s, tx, ty)
|
137 |
+
:param center: shape=(N, 2) bbox coordinates (c_x, c_y)
|
138 |
+
:param scale: shape=(N, ) square bbox resolution (b / 200)
|
139 |
+
:param full_img_shape: shape=(N, 2) original image height and width
|
140 |
+
:param focal_length: shape=(N,)
|
141 |
+
:return:
|
142 |
+
"""
|
143 |
+
|
144 |
+
# img_h, img_w = full_img_shape[:, 0], full_img_shape[:, 1]
|
145 |
+
img_w, img_h = full_img_shape[:, 0], full_img_shape[:, 1]
|
146 |
+
cx, cy, b = center[:, 0], center[:, 1], scale * 200
|
147 |
+
w_2, h_2 = img_w / 2., img_h / 2.
|
148 |
+
bs = b * crop_cam[:, 0] + 1e-9
|
149 |
+
tz = 2 * focal_length / bs
|
150 |
+
tx = (2 * (cx - w_2) / bs) + crop_cam[:, 1]
|
151 |
+
ty = (2 * (cy - h_2) / bs) + crop_cam[:, 2]
|
152 |
+
full_cam = torch.stack([tx, ty, tz], dim=-1)
|
153 |
+
return full_cam
|
154 |
+
|
155 |
+
|
156 |
+
def cam_full2pred(full_cam, center, scale, full_img_shape, focal_length=2000.):
|
157 |
+
# img_h, img_w = full_img_shape[:, 0], full_img_shape[:, 1]
|
158 |
+
img_w, img_h = full_img_shape[:, 0], full_img_shape[:, 1]
|
159 |
+
cx, cy, b = center[:, 0], center[:, 1], scale * 200
|
160 |
+
w_2, h_2 = img_w / 2., img_h / 2.
|
161 |
+
|
162 |
+
bs = (2 * focal_length / full_cam[:, 2])
|
163 |
+
_s = bs / b
|
164 |
+
_tx = full_cam[:, 0] - (2 * (cx - w_2) / bs)
|
165 |
+
_ty = full_cam[:, 1] - (2 * (cy - h_2) / bs)
|
166 |
+
crop_cam = torch.stack([_s, _tx, _ty], dim=-1)
|
167 |
+
return crop_cam
|
168 |
+
|
169 |
+
|
170 |
+
def obtain_camera_intrinsics(image_shape, focal_length):
|
171 |
+
res_w = image_shape[..., 0].clone()
|
172 |
+
res_h = image_shape[..., 1].clone()
|
173 |
+
K = torch.eye(3).unsqueeze(0).expand(focal_length.shape[0], -1, -1).to(focal_length.device)
|
174 |
+
K[..., 0, 0] = focal_length.clone()
|
175 |
+
K[..., 1, 1] = focal_length.clone()
|
176 |
+
K[..., 0, 2] = res_w / 2
|
177 |
+
K[..., 1, 2] = res_h / 2
|
178 |
+
|
179 |
+
return K.unsqueeze(1)
|
180 |
+
|
181 |
+
|
182 |
+
def trans_point2d(pt_2d, trans):
|
183 |
+
src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T
|
184 |
+
dst_pt = np.dot(trans, src_pt)
|
185 |
+
return dst_pt[0:2]
|
186 |
+
|
187 |
+
def rotate_2d(pt_2d, rot_rad):
|
188 |
+
x = pt_2d[0]
|
189 |
+
y = pt_2d[1]
|
190 |
+
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
|
191 |
+
xx = x * cs - y * sn
|
192 |
+
yy = x * sn + y * cs
|
193 |
+
return np.array([xx, yy], dtype=np.float32)
|
194 |
+
|
195 |
+
def gen_trans_from_patch_cv(c_x, c_y, src_width, src_height, dst_width, dst_height, scale, rot, inv=False):
|
196 |
+
# augment size with scale
|
197 |
+
src_w = src_width * scale
|
198 |
+
src_h = src_height * scale
|
199 |
+
src_center = np.zeros(2)
|
200 |
+
src_center[0] = c_x
|
201 |
+
src_center[1] = c_y # np.array([c_x, c_y], dtype=np.float32)
|
202 |
+
# augment rotation
|
203 |
+
rot_rad = np.pi * rot / 180
|
204 |
+
src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
|
205 |
+
src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
|
206 |
+
|
207 |
+
dst_w = dst_width
|
208 |
+
dst_h = dst_height
|
209 |
+
dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
|
210 |
+
dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
|
211 |
+
dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
|
212 |
+
|
213 |
+
src = np.zeros((3, 2), dtype=np.float32)
|
214 |
+
src[0, :] = src_center
|
215 |
+
src[1, :] = src_center + src_downdir
|
216 |
+
src[2, :] = src_center + src_rightdir
|
217 |
+
|
218 |
+
dst = np.zeros((3, 2), dtype=np.float32)
|
219 |
+
dst[0, :] = dst_center
|
220 |
+
dst[1, :] = dst_center + dst_downdir
|
221 |
+
dst[2, :] = dst_center + dst_rightdir
|
222 |
+
|
223 |
+
if inv:
|
224 |
+
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
225 |
+
else:
|
226 |
+
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
227 |
+
|
228 |
+
return trans
|
229 |
+
|
230 |
+
def transform_keypoints(kp_2d, bbox, patch_width, patch_height):
|
231 |
+
|
232 |
+
center_x, center_y, scale = bbox[:3]
|
233 |
+
width = height = scale * 200
|
234 |
+
# scale, rot = 1.2, 0
|
235 |
+
scale, rot = 1.0, 0
|
236 |
+
|
237 |
+
# generate transformation
|
238 |
+
trans = gen_trans_from_patch_cv(
|
239 |
+
center_x,
|
240 |
+
center_y,
|
241 |
+
width,
|
242 |
+
height,
|
243 |
+
patch_width,
|
244 |
+
patch_height,
|
245 |
+
scale,
|
246 |
+
rot,
|
247 |
+
inv=False,
|
248 |
+
)
|
249 |
+
|
250 |
+
for n_jt in range(kp_2d.shape[0]):
|
251 |
+
kp_2d[n_jt] = trans_point2d(kp_2d[n_jt], trans)
|
252 |
+
|
253 |
+
return kp_2d, trans
|
254 |
+
|
255 |
+
|
256 |
+
def transform(pt, center, scale, res, invert=0, rot=0):
|
257 |
+
"""Transform pixel location to different reference."""
|
258 |
+
t = get_transform(center, scale, res, rot=rot)
|
259 |
+
if invert:
|
260 |
+
t = np.linalg.inv(t)
|
261 |
+
new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
|
262 |
+
new_pt = np.dot(t, new_pt)
|
263 |
+
return new_pt[:2].astype(int) + 1
|
264 |
+
|
265 |
+
|
266 |
+
def compute_cam_intrinsics(res):
|
267 |
+
img_w, img_h = res
|
268 |
+
focal_length = (img_w * img_w + img_h * img_h) ** 0.5
|
269 |
+
cam_intrinsics = torch.eye(3).repeat(1, 1, 1).float()
|
270 |
+
cam_intrinsics[:, 0, 0] = focal_length
|
271 |
+
cam_intrinsics[:, 1, 1] = focal_length
|
272 |
+
cam_intrinsics[:, 0, 2] = img_w/2.
|
273 |
+
cam_intrinsics[:, 1, 2] = img_h/2.
|
274 |
+
return cam_intrinsics
|
275 |
+
|
276 |
+
|
277 |
+
def flip_kp(kp, img_w=None):
|
278 |
+
"""Flip keypoints."""
|
279 |
+
|
280 |
+
flipped_parts = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
|
281 |
+
kp = kp[..., flipped_parts, :]
|
282 |
+
|
283 |
+
if img_w is not None:
|
284 |
+
# Assume 2D keypoints
|
285 |
+
kp[...,0] = img_w - kp[...,0]
|
286 |
+
return kp
|
287 |
+
|
288 |
+
|
289 |
+
def flip_bbox(bbox, img_w, img_h):
|
290 |
+
center = bbox[..., :2]
|
291 |
+
scale = bbox[..., -1:]
|
292 |
+
|
293 |
+
WH = np.ones_like(center)
|
294 |
+
WH[..., 0] *= img_w
|
295 |
+
WH[..., 1] *= img_h
|
296 |
+
|
297 |
+
center = center - WH/2
|
298 |
+
center[...,0] = - center[...,0]
|
299 |
+
center = center + WH/2
|
300 |
+
|
301 |
+
flipped_bbox = np.concatenate((center, scale), axis=-1)
|
302 |
+
return flipped_bbox
|
303 |
+
|
304 |
+
|
305 |
+
def flip_pose(rotation, representation='rotation_6d'):
|
306 |
+
"""Flip pose.
|
307 |
+
The flipping is based on SMPL parameters.
|
308 |
+
"""
|
309 |
+
|
310 |
+
BN = rotation.shape[0]
|
311 |
+
|
312 |
+
if representation == 'axis_angle':
|
313 |
+
pose = rotation.reshape(BN, -1).transpose(0, 1)
|
314 |
+
elif representation == 'matrix':
|
315 |
+
pose = transforms.matrix_to_axis_angle(rotation).reshape(BN, -1).transpose(0, 1)
|
316 |
+
elif representation == 'rotation_6d':
|
317 |
+
pose = transforms.matrix_to_axis_angle(
|
318 |
+
transforms.rotation_6d_to_matrix(rotation)
|
319 |
+
).reshape(BN, -1).transpose(0, 1)
|
320 |
+
else:
|
321 |
+
raise ValueError(f"Unknown representation: {representation}")
|
322 |
+
|
323 |
+
SMPL_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22]
|
324 |
+
SMPL_POSE_FLIP_PERM = []
|
325 |
+
for i in SMPL_JOINTS_FLIP_PERM:
|
326 |
+
SMPL_POSE_FLIP_PERM.append(3*i)
|
327 |
+
SMPL_POSE_FLIP_PERM.append(3*i+1)
|
328 |
+
SMPL_POSE_FLIP_PERM.append(3*i+2)
|
329 |
+
|
330 |
+
pose = pose[SMPL_POSE_FLIP_PERM]
|
331 |
+
|
332 |
+
# we also negate the second and the third dimension of the axis-angle
|
333 |
+
pose[1::3] = -pose[1::3]
|
334 |
+
pose[2::3] = -pose[2::3]
|
335 |
+
pose = pose.transpose(0, 1).reshape(BN, -1, 3)
|
336 |
+
|
337 |
+
if representation == 'aa':
|
338 |
+
return pose
|
339 |
+
elif representation == 'rotmat':
|
340 |
+
return transforms.axis_angle_to_matrix(pose)
|
341 |
+
else:
|
342 |
+
return transforms.matrix_to_rotation_6d(
|
343 |
+
transforms.axis_angle_to_matrix(pose)
|
344 |
+
)
|
345 |
+
|
346 |
+
def avg_preds(rotation, shape, flipped_rotation, flipped_shape, representation='rotation_6d'):
|
347 |
+
# Rotation
|
348 |
+
flipped_rotation = flip_pose(flipped_rotation, representation=representation)
|
349 |
+
|
350 |
+
if representation != 'matrix':
|
351 |
+
flipped_rotation = eval(f'transforms.{representation}_to_matrix')(flipped_rotation)
|
352 |
+
rotation = eval(f'transforms.{representation}_to_matrix')(rotation)
|
353 |
+
|
354 |
+
avg_rotation = torch.stack([rotation, flipped_rotation])
|
355 |
+
avg_rotation = transforms.avg_rot(avg_rotation)
|
356 |
+
|
357 |
+
if representation != 'matrix':
|
358 |
+
avg_rotation = eval(f'transforms.matrix_to_{representation}')(avg_rotation)
|
359 |
+
|
360 |
+
# Shape
|
361 |
+
avg_shape = (shape + flipped_shape) / 2.0
|
362 |
+
|
363 |
+
return avg_rotation, avg_shape
|
lib/utils/kp_utils.py
ADDED
@@ -0,0 +1,761 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import print_function
|
3 |
+
from __future__ import division
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from configs import constants as _C
|
9 |
+
|
10 |
+
def root_centering(X, joint_type='coco'):
|
11 |
+
"""Center the root joint to the pelvis."""
|
12 |
+
if joint_type != 'common' and X.shape[-2] == 14: return X
|
13 |
+
|
14 |
+
conf = None
|
15 |
+
if X.shape[-1] == 4:
|
16 |
+
conf = X[..., -1:]
|
17 |
+
X = X[..., :-1]
|
18 |
+
|
19 |
+
if X.shape[-2] == 31:
|
20 |
+
X[..., :17, :] = X[..., :17, :] - X[..., [12, 11], :].mean(-2, keepdims=True)
|
21 |
+
X[..., 17:, :] = X[..., 17:, :] - X[..., [19, 20], :].mean(-2, keepdims=True)
|
22 |
+
|
23 |
+
elif joint_type == 'coco':
|
24 |
+
X = X - X[..., [12, 11], :].mean(-2, keepdims=True)
|
25 |
+
|
26 |
+
elif joint_type == 'common':
|
27 |
+
X = X - X[..., [2, 3], :].mean(-2, keepdims=True)
|
28 |
+
|
29 |
+
if conf is not None:
|
30 |
+
X = torch.cat((X, conf), dim=-1)
|
31 |
+
|
32 |
+
return X
|
33 |
+
|
34 |
+
|
35 |
+
def convert_kps(joints2d, src, dst):
|
36 |
+
src_names = eval(f'get_{src}_joint_names')()
|
37 |
+
dst_names = eval(f'get_{dst}_joint_names')()
|
38 |
+
|
39 |
+
if isinstance(joints2d, np.ndarray):
|
40 |
+
out_joints2d = np.zeros((*joints2d.shape[:-2], len(dst_names), joints2d.shape[-1]))
|
41 |
+
else:
|
42 |
+
out_joints2d = torch.zeros((*joints2d.shape[:-2], len(dst_names), joints2d.shape[-1]), device=joints2d.device)
|
43 |
+
|
44 |
+
for idx, jn in enumerate(dst_names):
|
45 |
+
if jn in src_names:
|
46 |
+
out_joints2d[..., idx, :] = joints2d[..., src_names.index(jn), :]
|
47 |
+
|
48 |
+
return out_joints2d
|
49 |
+
|
50 |
+
def get_perm_idxs(src, dst):
|
51 |
+
src_names = eval(f'get_{src}_joint_names')()
|
52 |
+
dst_names = eval(f'get_{dst}_joint_names')()
|
53 |
+
idxs = [src_names.index(h) for h in dst_names if h in src_names]
|
54 |
+
return idxs
|
55 |
+
|
56 |
+
def get_mpii3d_test_joint_names():
|
57 |
+
return [
|
58 |
+
'headtop', # 'head_top',
|
59 |
+
'neck',
|
60 |
+
'rshoulder',# 'right_shoulder',
|
61 |
+
'relbow',# 'right_elbow',
|
62 |
+
'rwrist',# 'right_wrist',
|
63 |
+
'lshoulder',# 'left_shoulder',
|
64 |
+
'lelbow', # 'left_elbow',
|
65 |
+
'lwrist', # 'left_wrist',
|
66 |
+
'rhip', # 'right_hip',
|
67 |
+
'rknee', # 'right_knee',
|
68 |
+
'rankle',# 'right_ankle',
|
69 |
+
'lhip',# 'left_hip',
|
70 |
+
'lknee',# 'left_knee',
|
71 |
+
'lankle',# 'left_ankle'
|
72 |
+
'hip',# 'pelvis',
|
73 |
+
'Spine (H36M)',# 'spine',
|
74 |
+
'Head (H36M)',# 'head'
|
75 |
+
]
|
76 |
+
|
77 |
+
def get_mpii3d_joint_names():
|
78 |
+
return [
|
79 |
+
'spine3', # 0,
|
80 |
+
'spine4', # 1,
|
81 |
+
'spine2', # 2,
|
82 |
+
'Spine (H36M)', #'spine', # 3,
|
83 |
+
'hip', # 'pelvis', # 4,
|
84 |
+
'neck', # 5,
|
85 |
+
'Head (H36M)', # 'head', # 6,
|
86 |
+
"headtop", # 'head_top', # 7,
|
87 |
+
'left_clavicle', # 8,
|
88 |
+
"lshoulder", # 'left_shoulder', # 9,
|
89 |
+
"lelbow", # 'left_elbow',# 10,
|
90 |
+
"lwrist", # 'left_wrist',# 11,
|
91 |
+
'left_hand',# 12,
|
92 |
+
'right_clavicle',# 13,
|
93 |
+
'rshoulder',# 'right_shoulder',# 14,
|
94 |
+
'relbow',# 'right_elbow',# 15,
|
95 |
+
'rwrist',# 'right_wrist',# 16,
|
96 |
+
'right_hand',# 17,
|
97 |
+
'lhip', # left_hip',# 18,
|
98 |
+
'lknee', # 'left_knee',# 19,
|
99 |
+
'lankle', #left ankle # 20
|
100 |
+
'left_foot', # 21
|
101 |
+
'left_toe', # 22
|
102 |
+
"rhip", # 'right_hip',# 23
|
103 |
+
"rknee", # 'right_knee',# 24
|
104 |
+
"rankle", #'right_ankle', # 25
|
105 |
+
'right_foot',# 26
|
106 |
+
'right_toe' # 27
|
107 |
+
]
|
108 |
+
|
109 |
+
def get_insta_joint_names():
|
110 |
+
return [
|
111 |
+
'OP RHeel',
|
112 |
+
'OP RKnee',
|
113 |
+
'OP RHip',
|
114 |
+
'OP LHip',
|
115 |
+
'OP LKnee',
|
116 |
+
'OP LHeel',
|
117 |
+
'OP RWrist',
|
118 |
+
'OP RElbow',
|
119 |
+
'OP RShoulder',
|
120 |
+
'OP LShoulder',
|
121 |
+
'OP LElbow',
|
122 |
+
'OP LWrist',
|
123 |
+
'OP Neck',
|
124 |
+
'headtop',
|
125 |
+
'OP Nose',
|
126 |
+
'OP LEye',
|
127 |
+
'OP REye',
|
128 |
+
'OP LEar',
|
129 |
+
'OP REar',
|
130 |
+
'OP LBigToe',
|
131 |
+
'OP RBigToe',
|
132 |
+
'OP LSmallToe',
|
133 |
+
'OP RSmallToe',
|
134 |
+
'OP LAnkle',
|
135 |
+
'OP RAnkle',
|
136 |
+
]
|
137 |
+
|
138 |
+
def get_insta_skeleton():
|
139 |
+
return np.array(
|
140 |
+
[
|
141 |
+
[0 , 1],
|
142 |
+
[1 , 2],
|
143 |
+
[2 , 3],
|
144 |
+
[3 , 4],
|
145 |
+
[4 , 5],
|
146 |
+
[6 , 7],
|
147 |
+
[7 , 8],
|
148 |
+
[8 , 9],
|
149 |
+
[9 ,10],
|
150 |
+
[2 , 8],
|
151 |
+
[3 , 9],
|
152 |
+
[10,11],
|
153 |
+
[8 ,12],
|
154 |
+
[9 ,12],
|
155 |
+
[12,13],
|
156 |
+
[12,14],
|
157 |
+
[14,15],
|
158 |
+
[14,16],
|
159 |
+
[15,17],
|
160 |
+
[16,18],
|
161 |
+
[0 ,20],
|
162 |
+
[20,22],
|
163 |
+
[5 ,19],
|
164 |
+
[19,21],
|
165 |
+
[5 ,23],
|
166 |
+
[0 ,24],
|
167 |
+
])
|
168 |
+
|
169 |
+
def get_staf_skeleton():
|
170 |
+
return np.array(
|
171 |
+
[
|
172 |
+
[0, 1],
|
173 |
+
[1, 2],
|
174 |
+
[2, 3],
|
175 |
+
[3, 4],
|
176 |
+
[1, 5],
|
177 |
+
[5, 6],
|
178 |
+
[6, 7],
|
179 |
+
[1, 8],
|
180 |
+
[8, 9],
|
181 |
+
[9, 10],
|
182 |
+
[10, 11],
|
183 |
+
[8, 12],
|
184 |
+
[12, 13],
|
185 |
+
[13, 14],
|
186 |
+
[0, 15],
|
187 |
+
[0, 16],
|
188 |
+
[15, 17],
|
189 |
+
[16, 18],
|
190 |
+
[2, 9],
|
191 |
+
[5, 12],
|
192 |
+
[1, 19],
|
193 |
+
[20, 19],
|
194 |
+
]
|
195 |
+
)
|
196 |
+
|
197 |
+
def get_staf_joint_names():
|
198 |
+
return [
|
199 |
+
'OP Nose', # 0,
|
200 |
+
'OP Neck', # 1,
|
201 |
+
'OP RShoulder', # 2,
|
202 |
+
'OP RElbow', # 3,
|
203 |
+
'OP RWrist', # 4,
|
204 |
+
'OP LShoulder', # 5,
|
205 |
+
'OP LElbow', # 6,
|
206 |
+
'OP LWrist', # 7,
|
207 |
+
'OP MidHip', # 8,
|
208 |
+
'OP RHip', # 9,
|
209 |
+
'OP RKnee', # 10,
|
210 |
+
'OP RAnkle', # 11,
|
211 |
+
'OP LHip', # 12,
|
212 |
+
'OP LKnee', # 13,
|
213 |
+
'OP LAnkle', # 14,
|
214 |
+
'OP REye', # 15,
|
215 |
+
'OP LEye', # 16,
|
216 |
+
'OP REar', # 17,
|
217 |
+
'OP LEar', # 18,
|
218 |
+
'Neck (LSP)', # 19,
|
219 |
+
'Top of Head (LSP)', # 20,
|
220 |
+
]
|
221 |
+
|
222 |
+
def get_spin_joint_names():
|
223 |
+
return [
|
224 |
+
'OP Nose', # 0
|
225 |
+
'OP Neck', # 1
|
226 |
+
'OP RShoulder', # 2
|
227 |
+
'OP RElbow', # 3
|
228 |
+
'OP RWrist', # 4
|
229 |
+
'OP LShoulder', # 5
|
230 |
+
'OP LElbow', # 6
|
231 |
+
'OP LWrist', # 7
|
232 |
+
'OP MidHip', # 8
|
233 |
+
'OP RHip', # 9
|
234 |
+
'OP RKnee', # 10
|
235 |
+
'OP RAnkle', # 11
|
236 |
+
'OP LHip', # 12
|
237 |
+
'OP LKnee', # 13
|
238 |
+
'OP LAnkle', # 14
|
239 |
+
'OP REye', # 15
|
240 |
+
'OP LEye', # 16
|
241 |
+
'OP REar', # 17
|
242 |
+
'OP LEar', # 18
|
243 |
+
'OP LBigToe', # 19
|
244 |
+
'OP LSmallToe', # 20
|
245 |
+
'OP LHeel', # 21
|
246 |
+
'OP RBigToe', # 22
|
247 |
+
'OP RSmallToe', # 23
|
248 |
+
'OP RHeel', # 24
|
249 |
+
'rankle', # 25
|
250 |
+
'rknee', # 26
|
251 |
+
'rhip', # 27
|
252 |
+
'lhip', # 28
|
253 |
+
'lknee', # 29
|
254 |
+
'lankle', # 30
|
255 |
+
'rwrist', # 31
|
256 |
+
'relbow', # 32
|
257 |
+
'rshoulder', # 33
|
258 |
+
'lshoulder', # 34
|
259 |
+
'lelbow', # 35
|
260 |
+
'lwrist', # 36
|
261 |
+
'neck', # 37
|
262 |
+
'headtop', # 38
|
263 |
+
'hip', # 39 'Pelvis (MPII)', # 39
|
264 |
+
'thorax', # 40 'Thorax (MPII)', # 40
|
265 |
+
'Spine (H36M)', # 41
|
266 |
+
'Jaw (H36M)', # 42
|
267 |
+
'Head (H36M)', # 43
|
268 |
+
'nose', # 44
|
269 |
+
'leye', # 45 'Left Eye', # 45
|
270 |
+
'reye', # 46 'Right Eye', # 46
|
271 |
+
'lear', # 47 'Left Ear', # 47
|
272 |
+
'rear', # 48 'Right Ear', # 48
|
273 |
+
]
|
274 |
+
|
275 |
+
def get_h36m_joint_names():
|
276 |
+
return [
|
277 |
+
'hip', # 0
|
278 |
+
'lhip', # 1
|
279 |
+
'lknee', # 2
|
280 |
+
'lankle', # 3
|
281 |
+
'rhip', # 4
|
282 |
+
'rknee', # 5
|
283 |
+
'rankle', # 6
|
284 |
+
'Spine (H36M)', # 7
|
285 |
+
'neck', # 8
|
286 |
+
'Head (H36M)', # 9
|
287 |
+
'headtop', # 10
|
288 |
+
'lshoulder', # 11
|
289 |
+
'lelbow', # 12
|
290 |
+
'lwrist', # 13
|
291 |
+
'rshoulder', # 14
|
292 |
+
'relbow', # 15
|
293 |
+
'rwrist', # 16
|
294 |
+
]
|
295 |
+
|
296 |
+
'Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head_top', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist'
|
297 |
+
|
298 |
+
def get_spin_skeleton():
|
299 |
+
return np.array(
|
300 |
+
[
|
301 |
+
[0 , 1],
|
302 |
+
[1 , 2],
|
303 |
+
[2 , 3],
|
304 |
+
[3 , 4],
|
305 |
+
[1 , 5],
|
306 |
+
[5 , 6],
|
307 |
+
[6 , 7],
|
308 |
+
[1 , 8],
|
309 |
+
[8 , 9],
|
310 |
+
[9 ,10],
|
311 |
+
[10,11],
|
312 |
+
[8 ,12],
|
313 |
+
[12,13],
|
314 |
+
[13,14],
|
315 |
+
[0 ,15],
|
316 |
+
[0 ,16],
|
317 |
+
[15,17],
|
318 |
+
[16,18],
|
319 |
+
[21,19],
|
320 |
+
[19,20],
|
321 |
+
[14,21],
|
322 |
+
[11,24],
|
323 |
+
[24,22],
|
324 |
+
[22,23],
|
325 |
+
[0 ,38],
|
326 |
+
]
|
327 |
+
)
|
328 |
+
|
329 |
+
def get_posetrack_joint_names():
|
330 |
+
return [
|
331 |
+
"nose",
|
332 |
+
"neck",
|
333 |
+
"headtop",
|
334 |
+
"lear",
|
335 |
+
"rear",
|
336 |
+
"lshoulder",
|
337 |
+
"rshoulder",
|
338 |
+
"lelbow",
|
339 |
+
"relbow",
|
340 |
+
"lwrist",
|
341 |
+
"rwrist",
|
342 |
+
"lhip",
|
343 |
+
"rhip",
|
344 |
+
"lknee",
|
345 |
+
"rknee",
|
346 |
+
"lankle",
|
347 |
+
"rankle"
|
348 |
+
]
|
349 |
+
|
350 |
+
def get_posetrack_original_kp_names():
|
351 |
+
return [
|
352 |
+
'nose',
|
353 |
+
'head_bottom',
|
354 |
+
'head_top',
|
355 |
+
'left_ear',
|
356 |
+
'right_ear',
|
357 |
+
'left_shoulder',
|
358 |
+
'right_shoulder',
|
359 |
+
'left_elbow',
|
360 |
+
'right_elbow',
|
361 |
+
'left_wrist',
|
362 |
+
'right_wrist',
|
363 |
+
'left_hip',
|
364 |
+
'right_hip',
|
365 |
+
'left_knee',
|
366 |
+
'right_knee',
|
367 |
+
'left_ankle',
|
368 |
+
'right_ankle'
|
369 |
+
]
|
370 |
+
|
371 |
+
def get_pennaction_joint_names():
|
372 |
+
return [
|
373 |
+
"headtop", # 0
|
374 |
+
"lshoulder", # 1
|
375 |
+
"rshoulder", # 2
|
376 |
+
"lelbow", # 3
|
377 |
+
"relbow", # 4
|
378 |
+
"lwrist", # 5
|
379 |
+
"rwrist", # 6
|
380 |
+
"lhip" , # 7
|
381 |
+
"rhip" , # 8
|
382 |
+
"lknee", # 9
|
383 |
+
"rknee" , # 10
|
384 |
+
"lankle", # 11
|
385 |
+
"rankle" # 12
|
386 |
+
]
|
387 |
+
|
388 |
+
def get_common_joint_names():
|
389 |
+
return [
|
390 |
+
"rankle", # 0 "lankle", # 0
|
391 |
+
"rknee", # 1 "lknee", # 1
|
392 |
+
"rhip", # 2 "lhip", # 2
|
393 |
+
"lhip", # 3 "rhip", # 3
|
394 |
+
"lknee", # 4 "rknee", # 4
|
395 |
+
"lankle", # 5 "rankle", # 5
|
396 |
+
"rwrist", # 6 "lwrist", # 6
|
397 |
+
"relbow", # 7 "lelbow", # 7
|
398 |
+
"rshoulder", # 8 "lshoulder", # 8
|
399 |
+
"lshoulder", # 9 "rshoulder", # 9
|
400 |
+
"lelbow", # 10 "relbow", # 10
|
401 |
+
"lwrist", # 11 "rwrist", # 11
|
402 |
+
"neck", # 12 "neck", # 12
|
403 |
+
"headtop", # 13 "headtop", # 13
|
404 |
+
]
|
405 |
+
|
406 |
+
def get_coco_common_joint_names():
|
407 |
+
return [
|
408 |
+
"nose", # 0
|
409 |
+
"leye", # 1
|
410 |
+
"reye", # 2
|
411 |
+
"lear", # 3
|
412 |
+
"rear", # 4
|
413 |
+
"lshoulder", # 5
|
414 |
+
"rshoulder", # 6
|
415 |
+
"lelbow", # 7
|
416 |
+
"relbow", # 8
|
417 |
+
"lwrist", # 9
|
418 |
+
"rwrist", # 10
|
419 |
+
"lhip", # 11
|
420 |
+
"rhip", # 12
|
421 |
+
"lknee", # 13
|
422 |
+
"rknee", # 14
|
423 |
+
"lankle", # 15
|
424 |
+
"rankle", # 16
|
425 |
+
"neck", # 17 "neck", # 12
|
426 |
+
"headtop", # 18 "headtop", # 13
|
427 |
+
]
|
428 |
+
|
429 |
+
def get_common_skeleton():
|
430 |
+
return np.array(
|
431 |
+
[
|
432 |
+
[ 0, 1 ],
|
433 |
+
[ 1, 2 ],
|
434 |
+
[ 3, 4 ],
|
435 |
+
[ 4, 5 ],
|
436 |
+
[ 6, 7 ],
|
437 |
+
[ 7, 8 ],
|
438 |
+
[ 8, 2 ],
|
439 |
+
[ 8, 9 ],
|
440 |
+
[ 9, 3 ],
|
441 |
+
[ 2, 3 ],
|
442 |
+
[ 8, 12],
|
443 |
+
[ 9, 10],
|
444 |
+
[12, 9 ],
|
445 |
+
[10, 11],
|
446 |
+
[12, 13],
|
447 |
+
]
|
448 |
+
)
|
449 |
+
|
450 |
+
def get_coco_joint_names():
|
451 |
+
return [
|
452 |
+
"nose", # 0
|
453 |
+
"leye", # 1
|
454 |
+
"reye", # 2
|
455 |
+
"lear", # 3
|
456 |
+
"rear", # 4
|
457 |
+
"lshoulder", # 5
|
458 |
+
"rshoulder", # 6
|
459 |
+
"lelbow", # 7
|
460 |
+
"relbow", # 8
|
461 |
+
"lwrist", # 9
|
462 |
+
"rwrist", # 10
|
463 |
+
"lhip", # 11
|
464 |
+
"rhip", # 12
|
465 |
+
"lknee", # 13
|
466 |
+
"rknee", # 14
|
467 |
+
"lankle", # 15
|
468 |
+
"rankle", # 16
|
469 |
+
]
|
470 |
+
|
471 |
+
def get_coco_skeleton():
|
472 |
+
# 0 - nose,
|
473 |
+
# 1 - leye,
|
474 |
+
# 2 - reye,
|
475 |
+
# 3 - lear,
|
476 |
+
# 4 - rear,
|
477 |
+
# 5 - lshoulder,
|
478 |
+
# 6 - rshoulder,
|
479 |
+
# 7 - lelbow,
|
480 |
+
# 8 - relbow,
|
481 |
+
# 9 - lwrist,
|
482 |
+
# 10 - rwrist,
|
483 |
+
# 11 - lhip,
|
484 |
+
# 12 - rhip,
|
485 |
+
# 13 - lknee,
|
486 |
+
# 14 - rknee,
|
487 |
+
# 15 - lankle,
|
488 |
+
# 16 - rankle,
|
489 |
+
return np.array(
|
490 |
+
[
|
491 |
+
[15, 13],
|
492 |
+
[13, 11],
|
493 |
+
[16, 14],
|
494 |
+
[14, 12],
|
495 |
+
[11, 12],
|
496 |
+
[ 5, 11],
|
497 |
+
[ 6, 12],
|
498 |
+
[ 5, 6 ],
|
499 |
+
[ 5, 7 ],
|
500 |
+
[ 6, 8 ],
|
501 |
+
[ 7, 9 ],
|
502 |
+
[ 8, 10],
|
503 |
+
[ 1, 2 ],
|
504 |
+
[ 0, 1 ],
|
505 |
+
[ 0, 2 ],
|
506 |
+
[ 1, 3 ],
|
507 |
+
[ 2, 4 ],
|
508 |
+
[ 3, 5 ],
|
509 |
+
[ 4, 6 ]
|
510 |
+
]
|
511 |
+
)
|
512 |
+
|
513 |
+
def get_mpii_joint_names():
|
514 |
+
return [
|
515 |
+
"rankle", # 0
|
516 |
+
"rknee", # 1
|
517 |
+
"rhip", # 2
|
518 |
+
"lhip", # 3
|
519 |
+
"lknee", # 4
|
520 |
+
"lankle", # 5
|
521 |
+
"hip", # 6
|
522 |
+
"thorax", # 7
|
523 |
+
"neck", # 8
|
524 |
+
"headtop", # 9
|
525 |
+
"rwrist", # 10
|
526 |
+
"relbow", # 11
|
527 |
+
"rshoulder", # 12
|
528 |
+
"lshoulder", # 13
|
529 |
+
"lelbow", # 14
|
530 |
+
"lwrist", # 15
|
531 |
+
]
|
532 |
+
|
533 |
+
def get_mpii_skeleton():
|
534 |
+
# 0 - rankle,
|
535 |
+
# 1 - rknee,
|
536 |
+
# 2 - rhip,
|
537 |
+
# 3 - lhip,
|
538 |
+
# 4 - lknee,
|
539 |
+
# 5 - lankle,
|
540 |
+
# 6 - hip,
|
541 |
+
# 7 - thorax,
|
542 |
+
# 8 - neck,
|
543 |
+
# 9 - headtop,
|
544 |
+
# 10 - rwrist,
|
545 |
+
# 11 - relbow,
|
546 |
+
# 12 - rshoulder,
|
547 |
+
# 13 - lshoulder,
|
548 |
+
# 14 - lelbow,
|
549 |
+
# 15 - lwrist,
|
550 |
+
return np.array(
|
551 |
+
[
|
552 |
+
[ 0, 1 ],
|
553 |
+
[ 1, 2 ],
|
554 |
+
[ 2, 6 ],
|
555 |
+
[ 6, 3 ],
|
556 |
+
[ 3, 4 ],
|
557 |
+
[ 4, 5 ],
|
558 |
+
[ 6, 7 ],
|
559 |
+
[ 7, 8 ],
|
560 |
+
[ 8, 9 ],
|
561 |
+
[ 7, 12],
|
562 |
+
[12, 11],
|
563 |
+
[11, 10],
|
564 |
+
[ 7, 13],
|
565 |
+
[13, 14],
|
566 |
+
[14, 15]
|
567 |
+
]
|
568 |
+
)
|
569 |
+
|
570 |
+
def get_aich_joint_names():
|
571 |
+
return [
|
572 |
+
"rshoulder", # 0
|
573 |
+
"relbow", # 1
|
574 |
+
"rwrist", # 2
|
575 |
+
"lshoulder", # 3
|
576 |
+
"lelbow", # 4
|
577 |
+
"lwrist", # 5
|
578 |
+
"rhip", # 6
|
579 |
+
"rknee", # 7
|
580 |
+
"rankle", # 8
|
581 |
+
"lhip", # 9
|
582 |
+
"lknee", # 10
|
583 |
+
"lankle", # 11
|
584 |
+
"headtop", # 12
|
585 |
+
"neck", # 13
|
586 |
+
]
|
587 |
+
|
588 |
+
def get_aich_skeleton():
|
589 |
+
# 0 - rshoulder,
|
590 |
+
# 1 - relbow,
|
591 |
+
# 2 - rwrist,
|
592 |
+
# 3 - lshoulder,
|
593 |
+
# 4 - lelbow,
|
594 |
+
# 5 - lwrist,
|
595 |
+
# 6 - rhip,
|
596 |
+
# 7 - rknee,
|
597 |
+
# 8 - rankle,
|
598 |
+
# 9 - lhip,
|
599 |
+
# 10 - lknee,
|
600 |
+
# 11 - lankle,
|
601 |
+
# 12 - headtop,
|
602 |
+
# 13 - neck,
|
603 |
+
return np.array(
|
604 |
+
[
|
605 |
+
[ 0, 1 ],
|
606 |
+
[ 1, 2 ],
|
607 |
+
[ 3, 4 ],
|
608 |
+
[ 4, 5 ],
|
609 |
+
[ 6, 7 ],
|
610 |
+
[ 7, 8 ],
|
611 |
+
[ 9, 10],
|
612 |
+
[10, 11],
|
613 |
+
[12, 13],
|
614 |
+
[13, 0 ],
|
615 |
+
[13, 3 ],
|
616 |
+
[ 0, 6 ],
|
617 |
+
[ 3, 9 ]
|
618 |
+
]
|
619 |
+
)
|
620 |
+
|
621 |
+
def get_3dpw_joint_names():
|
622 |
+
return [
|
623 |
+
"nose", # 0
|
624 |
+
"thorax", # 1
|
625 |
+
"rshoulder", # 2
|
626 |
+
"relbow", # 3
|
627 |
+
"rwrist", # 4
|
628 |
+
"lshoulder", # 5
|
629 |
+
"lelbow", # 6
|
630 |
+
"lwrist", # 7
|
631 |
+
"rhip", # 8
|
632 |
+
"rknee", # 9
|
633 |
+
"rankle", # 10
|
634 |
+
"lhip", # 11
|
635 |
+
"lknee", # 12
|
636 |
+
"lankle", # 13
|
637 |
+
]
|
638 |
+
|
639 |
+
def get_3dpw_skeleton():
|
640 |
+
return np.array(
|
641 |
+
[
|
642 |
+
[ 0, 1 ],
|
643 |
+
[ 1, 2 ],
|
644 |
+
[ 2, 3 ],
|
645 |
+
[ 3, 4 ],
|
646 |
+
[ 1, 5 ],
|
647 |
+
[ 5, 6 ],
|
648 |
+
[ 6, 7 ],
|
649 |
+
[ 2, 8 ],
|
650 |
+
[ 5, 11],
|
651 |
+
[ 8, 11],
|
652 |
+
[ 8, 9 ],
|
653 |
+
[ 9, 10],
|
654 |
+
[11, 12],
|
655 |
+
[12, 13]
|
656 |
+
]
|
657 |
+
)
|
658 |
+
|
659 |
+
def get_smplcoco_joint_names():
|
660 |
+
return [
|
661 |
+
"rankle", # 0
|
662 |
+
"rknee", # 1
|
663 |
+
"rhip", # 2
|
664 |
+
"lhip", # 3
|
665 |
+
"lknee", # 4
|
666 |
+
"lankle", # 5
|
667 |
+
"rwrist", # 6
|
668 |
+
"relbow", # 7
|
669 |
+
"rshoulder", # 8
|
670 |
+
"lshoulder", # 9
|
671 |
+
"lelbow", # 10
|
672 |
+
"lwrist", # 11
|
673 |
+
"neck", # 12
|
674 |
+
"headtop", # 13
|
675 |
+
"nose", # 14
|
676 |
+
"leye", # 15
|
677 |
+
"reye", # 16
|
678 |
+
"lear", # 17
|
679 |
+
"rear", # 18
|
680 |
+
]
|
681 |
+
|
682 |
+
def get_smplcoco_skeleton():
|
683 |
+
return np.array(
|
684 |
+
[
|
685 |
+
[ 0, 1 ],
|
686 |
+
[ 1, 2 ],
|
687 |
+
[ 3, 4 ],
|
688 |
+
[ 4, 5 ],
|
689 |
+
[ 6, 7 ],
|
690 |
+
[ 7, 8 ],
|
691 |
+
[ 8, 12],
|
692 |
+
[12, 9 ],
|
693 |
+
[ 9, 10],
|
694 |
+
[10, 11],
|
695 |
+
[12, 13],
|
696 |
+
[14, 15],
|
697 |
+
[15, 17],
|
698 |
+
[16, 18],
|
699 |
+
[14, 16],
|
700 |
+
[ 8, 2 ],
|
701 |
+
[ 9, 3 ],
|
702 |
+
[ 2, 3 ],
|
703 |
+
]
|
704 |
+
)
|
705 |
+
|
706 |
+
def get_smpl_joint_names():
|
707 |
+
return [
|
708 |
+
'hips', # 0
|
709 |
+
'leftUpLeg', # 1
|
710 |
+
'rightUpLeg', # 2
|
711 |
+
'spine', # 3
|
712 |
+
'leftLeg', # 4
|
713 |
+
'rightLeg', # 5
|
714 |
+
'spine1', # 6
|
715 |
+
'leftFoot', # 7
|
716 |
+
'rightFoot', # 8
|
717 |
+
'spine2', # 9
|
718 |
+
'leftToeBase', # 10
|
719 |
+
'rightToeBase', # 11
|
720 |
+
'neck', # 12
|
721 |
+
'leftShoulder', # 13
|
722 |
+
'rightShoulder', # 14
|
723 |
+
'head', # 15
|
724 |
+
'leftArm', # 16
|
725 |
+
'rightArm', # 17
|
726 |
+
'leftForeArm', # 18
|
727 |
+
'rightForeArm', # 19
|
728 |
+
'leftHand', # 20
|
729 |
+
'rightHand', # 21
|
730 |
+
'leftHandIndex1', # 22
|
731 |
+
'rightHandIndex1', # 23
|
732 |
+
]
|
733 |
+
|
734 |
+
def get_smpl_skeleton():
|
735 |
+
return np.array(
|
736 |
+
[
|
737 |
+
[ 0, 1 ],
|
738 |
+
[ 0, 2 ],
|
739 |
+
[ 0, 3 ],
|
740 |
+
[ 1, 4 ],
|
741 |
+
[ 2, 5 ],
|
742 |
+
[ 3, 6 ],
|
743 |
+
[ 4, 7 ],
|
744 |
+
[ 5, 8 ],
|
745 |
+
[ 6, 9 ],
|
746 |
+
[ 7, 10],
|
747 |
+
[ 8, 11],
|
748 |
+
[ 9, 12],
|
749 |
+
[ 9, 13],
|
750 |
+
[ 9, 14],
|
751 |
+
[12, 15],
|
752 |
+
[13, 16],
|
753 |
+
[14, 17],
|
754 |
+
[16, 18],
|
755 |
+
[17, 19],
|
756 |
+
[18, 20],
|
757 |
+
[19, 21],
|
758 |
+
[20, 22],
|
759 |
+
[21, 23],
|
760 |
+
]
|
761 |
+
)
|
lib/utils/transforms.py
ADDED
@@ -0,0 +1,828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This transforms function is mainly borrowed from PyTorch3D"""
|
2 |
+
|
3 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This source code is licensed under the BSD-style license found in the
|
7 |
+
# LICENSE file in the root directory of this source tree.
|
8 |
+
|
9 |
+
from typing import Optional, Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
Device = Union[str, torch.device]
|
15 |
+
|
16 |
+
"""
|
17 |
+
The transformation matrices returned from the functions in this file assume
|
18 |
+
the points on which the transformation will be applied are column vectors.
|
19 |
+
i.e. the R matrix is structured as
|
20 |
+
|
21 |
+
R = [
|
22 |
+
[Rxx, Rxy, Rxz],
|
23 |
+
[Ryx, Ryy, Ryz],
|
24 |
+
[Rzx, Rzy, Rzz],
|
25 |
+
] # (3, 3)
|
26 |
+
|
27 |
+
This matrix can be applied to column vectors by post multiplication
|
28 |
+
by the points e.g.
|
29 |
+
|
30 |
+
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
|
31 |
+
transformed_points = R * points
|
32 |
+
|
33 |
+
To apply the same matrix to points which are row vectors, the R matrix
|
34 |
+
can be transposed and pre multiplied by the points:
|
35 |
+
|
36 |
+
e.g.
|
37 |
+
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
|
38 |
+
transformed_points = points * R.transpose(1, 0)
|
39 |
+
"""
|
40 |
+
|
41 |
+
|
42 |
+
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
|
43 |
+
"""
|
44 |
+
Convert rotations given as quaternions to rotation matrices.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
quaternions: quaternions with real part first,
|
48 |
+
as tensor of shape (..., 4).
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
52 |
+
"""
|
53 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
54 |
+
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
55 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
56 |
+
|
57 |
+
o = torch.stack(
|
58 |
+
(
|
59 |
+
1 - two_s * (j * j + k * k),
|
60 |
+
two_s * (i * j - k * r),
|
61 |
+
two_s * (i * k + j * r),
|
62 |
+
two_s * (i * j + k * r),
|
63 |
+
1 - two_s * (i * i + k * k),
|
64 |
+
two_s * (j * k - i * r),
|
65 |
+
two_s * (i * k - j * r),
|
66 |
+
two_s * (j * k + i * r),
|
67 |
+
1 - two_s * (i * i + j * j),
|
68 |
+
),
|
69 |
+
-1,
|
70 |
+
)
|
71 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
76 |
+
"""
|
77 |
+
Return a tensor where each element has the absolute value taken from the,
|
78 |
+
corresponding element of a, with sign taken from the corresponding
|
79 |
+
element of b. This is like the standard copysign floating-point operation,
|
80 |
+
but is not careful about negative 0 and NaN.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
a: source tensor.
|
84 |
+
b: tensor whose signs will be used, of the same shape as a.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
Tensor of the same shape as a with the signs of b.
|
88 |
+
"""
|
89 |
+
signs_differ = (a < 0) != (b < 0)
|
90 |
+
return torch.where(signs_differ, -a, a)
|
91 |
+
|
92 |
+
|
93 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
94 |
+
"""
|
95 |
+
Returns torch.sqrt(torch.max(0, x))
|
96 |
+
but with a zero subgradient where x is 0.
|
97 |
+
"""
|
98 |
+
ret = torch.zeros_like(x)
|
99 |
+
positive_mask = x > 0
|
100 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
101 |
+
return ret
|
102 |
+
|
103 |
+
|
104 |
+
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
105 |
+
"""
|
106 |
+
Convert rotations given as rotation matrices to quaternions.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
113 |
+
"""
|
114 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
115 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
116 |
+
|
117 |
+
batch_dim = matrix.shape[:-2]
|
118 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
119 |
+
matrix.reshape(batch_dim + (9,)), dim=-1
|
120 |
+
)
|
121 |
+
|
122 |
+
q_abs = _sqrt_positive_part(
|
123 |
+
torch.stack(
|
124 |
+
[
|
125 |
+
1.0 + m00 + m11 + m22,
|
126 |
+
1.0 + m00 - m11 - m22,
|
127 |
+
1.0 - m00 + m11 - m22,
|
128 |
+
1.0 - m00 - m11 + m22,
|
129 |
+
],
|
130 |
+
dim=-1,
|
131 |
+
)
|
132 |
+
)
|
133 |
+
|
134 |
+
# we produce the desired quaternion multiplied by each of r, i, j, k
|
135 |
+
quat_by_rijk = torch.stack(
|
136 |
+
[
|
137 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
138 |
+
# `int`.
|
139 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
140 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
141 |
+
# `int`.
|
142 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
143 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
144 |
+
# `int`.
|
145 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
146 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
147 |
+
# `int`.
|
148 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
149 |
+
],
|
150 |
+
dim=-2,
|
151 |
+
)
|
152 |
+
|
153 |
+
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
154 |
+
# the candidate won't be picked.
|
155 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
156 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
157 |
+
|
158 |
+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
159 |
+
# forall i; we pick the best-conditioned one (with the largest denominator)
|
160 |
+
|
161 |
+
return quat_candidates[
|
162 |
+
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
|
163 |
+
].reshape(batch_dim + (4,))
|
164 |
+
|
165 |
+
|
166 |
+
|
167 |
+
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
|
168 |
+
"""
|
169 |
+
Return the rotation matrices for one of the rotations about an axis
|
170 |
+
of which Euler angles describe, for each value of the angle given.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
axis: Axis label "X" or "Y or "Z".
|
174 |
+
angle: any shape tensor of Euler angles in radians
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
178 |
+
"""
|
179 |
+
|
180 |
+
cos = torch.cos(angle)
|
181 |
+
sin = torch.sin(angle)
|
182 |
+
one = torch.ones_like(angle)
|
183 |
+
zero = torch.zeros_like(angle)
|
184 |
+
|
185 |
+
if axis == "X":
|
186 |
+
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
187 |
+
elif axis == "Y":
|
188 |
+
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
189 |
+
elif axis == "Z":
|
190 |
+
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
191 |
+
else:
|
192 |
+
raise ValueError("letter must be either X, Y or Z.")
|
193 |
+
|
194 |
+
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
195 |
+
|
196 |
+
|
197 |
+
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
|
198 |
+
"""
|
199 |
+
Convert rotations given as Euler angles in radians to rotation matrices.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
euler_angles: Euler angles in radians as tensor of shape (..., 3).
|
203 |
+
convention: Convention string of three uppercase letters from
|
204 |
+
{"X", "Y", and "Z"}.
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
208 |
+
"""
|
209 |
+
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
|
210 |
+
raise ValueError("Invalid input euler angles.")
|
211 |
+
if len(convention) != 3:
|
212 |
+
raise ValueError("Convention must have 3 letters.")
|
213 |
+
if convention[1] in (convention[0], convention[2]):
|
214 |
+
raise ValueError(f"Invalid convention {convention}.")
|
215 |
+
for letter in convention:
|
216 |
+
if letter not in ("X", "Y", "Z"):
|
217 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
218 |
+
matrices = [
|
219 |
+
_axis_angle_rotation(c, e)
|
220 |
+
for c, e in zip(convention, torch.unbind(euler_angles, -1))
|
221 |
+
]
|
222 |
+
# return functools.reduce(torch.matmul, matrices)
|
223 |
+
return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
def _angle_from_tan(
|
228 |
+
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
229 |
+
) -> torch.Tensor:
|
230 |
+
"""
|
231 |
+
Extract the first or third Euler angle from the two members of
|
232 |
+
the matrix which are positive constant times its sine and cosine.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
|
236 |
+
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
|
237 |
+
convention.
|
238 |
+
data: Rotation matrices as tensor of shape (..., 3, 3).
|
239 |
+
horizontal: Whether we are looking for the angle for the third axis,
|
240 |
+
which means the relevant entries are in the same row of the
|
241 |
+
rotation matrix. If not, they are in the same column.
|
242 |
+
tait_bryan: Whether the first and third axes in the convention differ.
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
Euler Angles in radians for each matrix in data as a tensor
|
246 |
+
of shape (...).
|
247 |
+
"""
|
248 |
+
|
249 |
+
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
250 |
+
if horizontal:
|
251 |
+
i2, i1 = i1, i2
|
252 |
+
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
253 |
+
if horizontal == even:
|
254 |
+
return torch.atan2(data[..., i1], data[..., i2])
|
255 |
+
if tait_bryan:
|
256 |
+
return torch.atan2(-data[..., i2], data[..., i1])
|
257 |
+
return torch.atan2(data[..., i2], -data[..., i1])
|
258 |
+
|
259 |
+
|
260 |
+
def _index_from_letter(letter: str) -> int:
|
261 |
+
if letter == "X":
|
262 |
+
return 0
|
263 |
+
if letter == "Y":
|
264 |
+
return 1
|
265 |
+
if letter == "Z":
|
266 |
+
return 2
|
267 |
+
raise ValueError("letter must be either X, Y or Z.")
|
268 |
+
|
269 |
+
|
270 |
+
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
|
271 |
+
"""
|
272 |
+
Convert rotations given as rotation matrices to Euler angles in radians.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
276 |
+
convention: Convention string of three uppercase letters.
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
Euler angles in radians as tensor of shape (..., 3).
|
280 |
+
"""
|
281 |
+
if len(convention) != 3:
|
282 |
+
raise ValueError("Convention must have 3 letters.")
|
283 |
+
if convention[1] in (convention[0], convention[2]):
|
284 |
+
raise ValueError(f"Invalid convention {convention}.")
|
285 |
+
for letter in convention:
|
286 |
+
if letter not in ("X", "Y", "Z"):
|
287 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
288 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
289 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
290 |
+
i0 = _index_from_letter(convention[0])
|
291 |
+
i2 = _index_from_letter(convention[2])
|
292 |
+
tait_bryan = i0 != i2
|
293 |
+
if tait_bryan:
|
294 |
+
central_angle = torch.asin(
|
295 |
+
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
|
296 |
+
)
|
297 |
+
else:
|
298 |
+
central_angle = torch.acos(matrix[..., i0, i0])
|
299 |
+
|
300 |
+
o = (
|
301 |
+
_angle_from_tan(
|
302 |
+
convention[0], convention[1], matrix[..., i2], False, tait_bryan
|
303 |
+
),
|
304 |
+
central_angle,
|
305 |
+
_angle_from_tan(
|
306 |
+
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
|
307 |
+
),
|
308 |
+
)
|
309 |
+
return torch.stack(o, -1)
|
310 |
+
|
311 |
+
|
312 |
+
|
313 |
+
def random_quaternions(
|
314 |
+
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
315 |
+
) -> torch.Tensor:
|
316 |
+
"""
|
317 |
+
Generate random quaternions representing rotations,
|
318 |
+
i.e. versors with nonnegative real part.
|
319 |
+
|
320 |
+
Args:
|
321 |
+
n: Number of quaternions in a batch to return.
|
322 |
+
dtype: Type to return.
|
323 |
+
device: Desired device of returned tensor. Default:
|
324 |
+
uses the current device for the default tensor type.
|
325 |
+
|
326 |
+
Returns:
|
327 |
+
Quaternions as tensor of shape (N, 4).
|
328 |
+
"""
|
329 |
+
if isinstance(device, str):
|
330 |
+
device = torch.device(device)
|
331 |
+
o = torch.randn((n, 4), dtype=dtype, device=device)
|
332 |
+
s = (o * o).sum(1)
|
333 |
+
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
334 |
+
return o
|
335 |
+
|
336 |
+
|
337 |
+
|
338 |
+
def random_rotations(
|
339 |
+
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
340 |
+
) -> torch.Tensor:
|
341 |
+
"""
|
342 |
+
Generate random rotations as 3x3 rotation matrices.
|
343 |
+
|
344 |
+
Args:
|
345 |
+
n: Number of rotation matrices in a batch to return.
|
346 |
+
dtype: Type to return.
|
347 |
+
device: Device of returned tensor. Default: if None,
|
348 |
+
uses the current device for the default tensor type.
|
349 |
+
|
350 |
+
Returns:
|
351 |
+
Rotation matrices as tensor of shape (n, 3, 3).
|
352 |
+
"""
|
353 |
+
quaternions = random_quaternions(n, dtype=dtype, device=device)
|
354 |
+
return quaternion_to_matrix(quaternions)
|
355 |
+
|
356 |
+
|
357 |
+
|
358 |
+
def random_rotation(
|
359 |
+
dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
360 |
+
) -> torch.Tensor:
|
361 |
+
"""
|
362 |
+
Generate a single random 3x3 rotation matrix.
|
363 |
+
|
364 |
+
Args:
|
365 |
+
dtype: Type to return
|
366 |
+
device: Device of returned tensor. Default: if None,
|
367 |
+
uses the current device for the default tensor type
|
368 |
+
|
369 |
+
Returns:
|
370 |
+
Rotation matrix as tensor of shape (3, 3).
|
371 |
+
"""
|
372 |
+
return random_rotations(1, dtype, device)[0]
|
373 |
+
|
374 |
+
|
375 |
+
|
376 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
377 |
+
"""
|
378 |
+
Convert a unit quaternion to a standard form: one in which the real
|
379 |
+
part is non negative.
|
380 |
+
|
381 |
+
Args:
|
382 |
+
quaternions: Quaternions with real part first,
|
383 |
+
as tensor of shape (..., 4).
|
384 |
+
|
385 |
+
Returns:
|
386 |
+
Standardized quaternions as tensor of shape (..., 4).
|
387 |
+
"""
|
388 |
+
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
389 |
+
|
390 |
+
|
391 |
+
|
392 |
+
def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
393 |
+
"""
|
394 |
+
Multiply two quaternions.
|
395 |
+
Usual torch rules for broadcasting apply.
|
396 |
+
|
397 |
+
Args:
|
398 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
399 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
400 |
+
|
401 |
+
Returns:
|
402 |
+
The product of a and b, a tensor of quaternions shape (..., 4).
|
403 |
+
"""
|
404 |
+
aw, ax, ay, az = torch.unbind(a, -1)
|
405 |
+
bw, bx, by, bz = torch.unbind(b, -1)
|
406 |
+
ow = aw * bw - ax * bx - ay * by - az * bz
|
407 |
+
ox = aw * bx + ax * bw + ay * bz - az * by
|
408 |
+
oy = aw * by - ax * bz + ay * bw + az * bx
|
409 |
+
oz = aw * bz + ax * by - ay * bx + az * bw
|
410 |
+
return torch.stack((ow, ox, oy, oz), -1)
|
411 |
+
|
412 |
+
|
413 |
+
|
414 |
+
def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
415 |
+
"""
|
416 |
+
Multiply two quaternions representing rotations, returning the quaternion
|
417 |
+
representing their composition, i.e. the versor with nonnegative real part.
|
418 |
+
Usual torch rules for broadcasting apply.
|
419 |
+
|
420 |
+
Args:
|
421 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
422 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
423 |
+
|
424 |
+
Returns:
|
425 |
+
The product of a and b, a tensor of quaternions of shape (..., 4).
|
426 |
+
"""
|
427 |
+
ab = quaternion_raw_multiply(a, b)
|
428 |
+
return standardize_quaternion(ab)
|
429 |
+
|
430 |
+
|
431 |
+
|
432 |
+
def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
|
433 |
+
"""
|
434 |
+
Given a quaternion representing rotation, get the quaternion representing
|
435 |
+
its inverse.
|
436 |
+
|
437 |
+
Args:
|
438 |
+
quaternion: Quaternions as tensor of shape (..., 4), with real part
|
439 |
+
first, which must be versors (unit quaternions).
|
440 |
+
|
441 |
+
Returns:
|
442 |
+
The inverse, a tensor of quaternions of shape (..., 4).
|
443 |
+
"""
|
444 |
+
|
445 |
+
scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
|
446 |
+
return quaternion * scaling
|
447 |
+
|
448 |
+
|
449 |
+
|
450 |
+
def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
|
451 |
+
"""
|
452 |
+
Apply the rotation given by a quaternion to a 3D point.
|
453 |
+
Usual torch rules for broadcasting apply.
|
454 |
+
|
455 |
+
Args:
|
456 |
+
quaternion: Tensor of quaternions, real part first, of shape (..., 4).
|
457 |
+
point: Tensor of 3D points of shape (..., 3).
|
458 |
+
|
459 |
+
Returns:
|
460 |
+
Tensor of rotated points of shape (..., 3).
|
461 |
+
"""
|
462 |
+
if point.size(-1) != 3:
|
463 |
+
raise ValueError(f"Points are not in 3D, {point.shape}.")
|
464 |
+
real_parts = point.new_zeros(point.shape[:-1] + (1,))
|
465 |
+
point_as_quaternion = torch.cat((real_parts, point), -1)
|
466 |
+
out = quaternion_raw_multiply(
|
467 |
+
quaternion_raw_multiply(quaternion, point_as_quaternion),
|
468 |
+
quaternion_invert(quaternion),
|
469 |
+
)
|
470 |
+
return out[..., 1:]
|
471 |
+
|
472 |
+
|
473 |
+
|
474 |
+
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
|
475 |
+
"""
|
476 |
+
Convert rotations given as axis/angle to rotation matrices.
|
477 |
+
|
478 |
+
Args:
|
479 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
480 |
+
as a tensor of shape (..., 3), where the magnitude is
|
481 |
+
the angle turned anticlockwise in radians around the
|
482 |
+
vector's direction.
|
483 |
+
|
484 |
+
Returns:
|
485 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
486 |
+
"""
|
487 |
+
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
488 |
+
|
489 |
+
|
490 |
+
|
491 |
+
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
|
492 |
+
"""
|
493 |
+
Convert rotations given as rotation matrices to axis/angle.
|
494 |
+
|
495 |
+
Args:
|
496 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
497 |
+
|
498 |
+
Returns:
|
499 |
+
Rotations given as a vector in axis angle form, as a tensor
|
500 |
+
of shape (..., 3), where the magnitude is the angle
|
501 |
+
turned anticlockwise in radians around the vector's
|
502 |
+
direction.
|
503 |
+
"""
|
504 |
+
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
505 |
+
|
506 |
+
|
507 |
+
|
508 |
+
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
|
509 |
+
"""
|
510 |
+
Convert rotations given as axis/angle to quaternions.
|
511 |
+
|
512 |
+
Args:
|
513 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
514 |
+
as a tensor of shape (..., 3), where the magnitude is
|
515 |
+
the angle turned anticlockwise in radians around the
|
516 |
+
vector's direction.
|
517 |
+
|
518 |
+
Returns:
|
519 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
520 |
+
"""
|
521 |
+
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
522 |
+
half_angles = angles * 0.5
|
523 |
+
eps = 1e-6
|
524 |
+
small_angles = angles.abs() < eps
|
525 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
526 |
+
sin_half_angles_over_angles[~small_angles] = (
|
527 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
528 |
+
)
|
529 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
530 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
531 |
+
sin_half_angles_over_angles[small_angles] = (
|
532 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
533 |
+
)
|
534 |
+
quaternions = torch.cat(
|
535 |
+
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
|
536 |
+
)
|
537 |
+
return quaternions
|
538 |
+
|
539 |
+
|
540 |
+
|
541 |
+
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
|
542 |
+
"""
|
543 |
+
Convert rotations given as quaternions to axis/angle.
|
544 |
+
|
545 |
+
Args:
|
546 |
+
quaternions: quaternions with real part first,
|
547 |
+
as tensor of shape (..., 4).
|
548 |
+
|
549 |
+
Returns:
|
550 |
+
Rotations given as a vector in axis angle form, as a tensor
|
551 |
+
of shape (..., 3), where the magnitude is the angle
|
552 |
+
turned anticlockwise in radians around the vector's
|
553 |
+
direction.
|
554 |
+
"""
|
555 |
+
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
556 |
+
half_angles = torch.atan2(norms, quaternions[..., :1])
|
557 |
+
angles = 2 * half_angles
|
558 |
+
eps = 1e-6
|
559 |
+
small_angles = angles.abs() < eps
|
560 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
561 |
+
sin_half_angles_over_angles[~small_angles] = (
|
562 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
563 |
+
)
|
564 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
565 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
566 |
+
sin_half_angles_over_angles[small_angles] = (
|
567 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
568 |
+
)
|
569 |
+
return quaternions[..., 1:] / sin_half_angles_over_angles
|
570 |
+
|
571 |
+
|
572 |
+
|
573 |
+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
|
574 |
+
"""
|
575 |
+
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
|
576 |
+
using Gram--Schmidt orthogonalization per Section B of [1].
|
577 |
+
Args:
|
578 |
+
d6: 6D rotation representation, of size (*, 6)
|
579 |
+
|
580 |
+
Returns:
|
581 |
+
batch of rotation matrices of size (*, 3, 3)
|
582 |
+
|
583 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
584 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
585 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
586 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
587 |
+
"""
|
588 |
+
|
589 |
+
a1, a2 = d6[..., :3], d6[..., 3:]
|
590 |
+
b1 = F.normalize(a1, dim=-1)
|
591 |
+
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
|
592 |
+
b2 = F.normalize(b2, dim=-1)
|
593 |
+
b3 = torch.cross(b1, b2, dim=-1)
|
594 |
+
return torch.stack((b1, b2, b3), dim=-2)
|
595 |
+
|
596 |
+
|
597 |
+
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
|
598 |
+
"""
|
599 |
+
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
|
600 |
+
by dropping the last row. Note that 6D representation is not unique.
|
601 |
+
Args:
|
602 |
+
matrix: batch of rotation matrices of size (*, 3, 3)
|
603 |
+
|
604 |
+
Returns:
|
605 |
+
6D rotation representation, of size (*, 6)
|
606 |
+
|
607 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
608 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
609 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
610 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
611 |
+
"""
|
612 |
+
batch_dim = matrix.size()[:-2]
|
613 |
+
return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
|
614 |
+
|
615 |
+
|
616 |
+
def clean_rotation_6d(d6d: torch.Tensor) -> torch.Tensor:
|
617 |
+
"""
|
618 |
+
Clean rotation 6d by converting it to matrix and then reconvert to d6
|
619 |
+
"""
|
620 |
+
matrix = rotation_6d_to_matrix(d6d)
|
621 |
+
d6d = matrix_to_rotation_6d(matrix)
|
622 |
+
return d6d
|
623 |
+
|
624 |
+
|
625 |
+
def rot6d_to_rotmat(x):
|
626 |
+
"""Convert 6D rotation representation to 3x3 rotation matrix.
|
627 |
+
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
|
628 |
+
Input:
|
629 |
+
(B,6) Batch of 6-D rotation representations
|
630 |
+
Output:
|
631 |
+
(B,3,3) Batch of corresponding rotation matrices
|
632 |
+
"""
|
633 |
+
if x.shape[-1] == 6:
|
634 |
+
batch_dim = x.size()[:-1]
|
635 |
+
else:
|
636 |
+
x = x.reshape(*x.shape[:-1], -1, 6)
|
637 |
+
batch_dim = x.size()[:-1]
|
638 |
+
|
639 |
+
x = x.reshape(*batch_dim, 3, 2)
|
640 |
+
a1, a2 = x[..., 0], x[..., 1]
|
641 |
+
|
642 |
+
b1 = F.normalize(a1, dim=-1)
|
643 |
+
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
|
644 |
+
b2 = F.normalize(b2, dim=-1)
|
645 |
+
b3 = torch.cross(b1, b2, dim=-1)
|
646 |
+
|
647 |
+
return torch.stack((b1, b2, b3), dim=-1)
|
648 |
+
|
649 |
+
|
650 |
+
def rotmat_to_rot6d(x):
|
651 |
+
"""Inverse computation of rot6d_to_rotmat."""
|
652 |
+
batch_dim = x.size()[:-2]
|
653 |
+
return x[..., :2].clone().reshape(batch_dim + (6,))
|
654 |
+
|
655 |
+
|
656 |
+
def convert_rotation_matrix_to_homogeneous(rotation_matrix):
|
657 |
+
"Add empty translation vector to Rotation matrix"""
|
658 |
+
|
659 |
+
transl = torch.zeros_like(rotation_matrix[...,:1])
|
660 |
+
rotation_matrix_hom = torch.cat((rotation_matrix, transl), dim=-1)
|
661 |
+
|
662 |
+
return rotation_matrix_hom
|
663 |
+
|
664 |
+
|
665 |
+
def rotation_matrix_to_angle_axis(rotation_matrix):
|
666 |
+
"""Convert 3x4 rotation matrix to Rodrigues vector
|
667 |
+
|
668 |
+
Args:
|
669 |
+
rotation_matrix (Tensor): rotation matrix.
|
670 |
+
|
671 |
+
Returns:
|
672 |
+
Tensor: Rodrigues vector transformation.
|
673 |
+
|
674 |
+
Shape:
|
675 |
+
- Input: :math:`(N, 3, 4)`
|
676 |
+
- Output: :math:`(N, 3)`
|
677 |
+
|
678 |
+
Example:
|
679 |
+
>>> input = torch.rand(2, 3, 4) # Nx3x4
|
680 |
+
>>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
|
681 |
+
"""
|
682 |
+
|
683 |
+
if rotation_matrix.size(-1) == 3:
|
684 |
+
rotation_matrix = convert_rotation_matrix_to_homogeneous(rotation_matrix)
|
685 |
+
|
686 |
+
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
|
687 |
+
return quaternion_to_angle_axis(quaternion)
|
688 |
+
|
689 |
+
|
690 |
+
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
|
691 |
+
"""Convert 3x4 rotation matrix to 4d quaternion vector
|
692 |
+
|
693 |
+
This algorithm is based on algorithm described in
|
694 |
+
https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
|
695 |
+
|
696 |
+
Args:
|
697 |
+
rotation_matrix (Tensor): the rotation matrix to convert.
|
698 |
+
|
699 |
+
Return:
|
700 |
+
Tensor: the rotation in quaternion
|
701 |
+
|
702 |
+
Shape:
|
703 |
+
- Input: :math:`(N, 3, 4)`
|
704 |
+
- Output: :math:`(N, 4)`
|
705 |
+
|
706 |
+
Example:
|
707 |
+
>>> input = torch.rand(4, 3, 4) # Nx3x4
|
708 |
+
>>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
|
709 |
+
"""
|
710 |
+
if not torch.is_tensor(rotation_matrix):
|
711 |
+
raise TypeError("Input type is not a torch.Tensor. Got {}".format(
|
712 |
+
type(rotation_matrix)))
|
713 |
+
|
714 |
+
if len(rotation_matrix.shape) > 3:
|
715 |
+
raise ValueError(
|
716 |
+
"Input size must be a three dimensional tensor. Got {}".format(
|
717 |
+
rotation_matrix.shape))
|
718 |
+
if not rotation_matrix.shape[-2:] == (3, 4):
|
719 |
+
raise ValueError(
|
720 |
+
"Input size must be a N x 3 x 4 tensor. Got {}".format(
|
721 |
+
rotation_matrix.shape))
|
722 |
+
|
723 |
+
rmat_t = torch.transpose(rotation_matrix, 1, 2)
|
724 |
+
|
725 |
+
mask_d2 = rmat_t[:, 2, 2] < eps
|
726 |
+
|
727 |
+
mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
|
728 |
+
mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
|
729 |
+
|
730 |
+
t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
731 |
+
q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
|
732 |
+
t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
|
733 |
+
rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
|
734 |
+
t0_rep = t0.repeat(4, 1).t()
|
735 |
+
|
736 |
+
t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
737 |
+
q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
|
738 |
+
rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
|
739 |
+
t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
|
740 |
+
t1_rep = t1.repeat(4, 1).t()
|
741 |
+
|
742 |
+
t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
743 |
+
q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
|
744 |
+
rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
|
745 |
+
rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
|
746 |
+
t2_rep = t2.repeat(4, 1).t()
|
747 |
+
|
748 |
+
t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
749 |
+
q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
|
750 |
+
rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
|
751 |
+
rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
|
752 |
+
t3_rep = t3.repeat(4, 1).t()
|
753 |
+
|
754 |
+
mask_c0 = mask_d2 * mask_d0_d1
|
755 |
+
# mask_c1 = mask_d2 * (1 - mask_d0_d1)
|
756 |
+
mask_c1 = mask_d2 * ~mask_d0_d1
|
757 |
+
# mask_c2 = (1 - mask_d2) * mask_d0_nd1
|
758 |
+
mask_c2 = ~mask_d2 * mask_d0_nd1
|
759 |
+
# mask_c3 = (1 - mask_d2) * (1 - mask_d0_nd1)
|
760 |
+
mask_c3 = ~mask_d2 * ~mask_d0_nd1
|
761 |
+
mask_c0 = mask_c0.view(-1, 1).type_as(q0)
|
762 |
+
mask_c1 = mask_c1.view(-1, 1).type_as(q1)
|
763 |
+
mask_c2 = mask_c2.view(-1, 1).type_as(q2)
|
764 |
+
mask_c3 = mask_c3.view(-1, 1).type_as(q3)
|
765 |
+
|
766 |
+
q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
|
767 |
+
q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
|
768 |
+
t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
|
769 |
+
q *= 0.5
|
770 |
+
return q
|
771 |
+
|
772 |
+
|
773 |
+
def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
|
774 |
+
"""Convert quaternion vector to angle axis of rotation.
|
775 |
+
|
776 |
+
Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
|
777 |
+
|
778 |
+
Args:
|
779 |
+
quaternion (torch.Tensor): tensor with quaternions.
|
780 |
+
|
781 |
+
Return:
|
782 |
+
torch.Tensor: tensor with angle axis of rotation.
|
783 |
+
|
784 |
+
Shape:
|
785 |
+
- Input: :math:`(*, 4)` where `*` means, any number of dimensions
|
786 |
+
- Output: :math:`(*, 3)`
|
787 |
+
|
788 |
+
Example:
|
789 |
+
>>> quaternion = torch.rand(2, 4) # Nx4
|
790 |
+
>>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
|
791 |
+
"""
|
792 |
+
if not torch.is_tensor(quaternion):
|
793 |
+
raise TypeError("Input type is not a torch.Tensor. Got {}".format(
|
794 |
+
type(quaternion)))
|
795 |
+
|
796 |
+
if not quaternion.shape[-1] == 4:
|
797 |
+
raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
|
798 |
+
.format(quaternion.shape))
|
799 |
+
# unpack input and compute conversion
|
800 |
+
q1: torch.Tensor = quaternion[..., 1]
|
801 |
+
q2: torch.Tensor = quaternion[..., 2]
|
802 |
+
q3: torch.Tensor = quaternion[..., 3]
|
803 |
+
sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
|
804 |
+
|
805 |
+
sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
|
806 |
+
cos_theta: torch.Tensor = quaternion[..., 0]
|
807 |
+
two_theta: torch.Tensor = 2.0 * torch.where(
|
808 |
+
cos_theta < 0.0,
|
809 |
+
torch.atan2(-sin_theta, -cos_theta),
|
810 |
+
torch.atan2(sin_theta, cos_theta))
|
811 |
+
|
812 |
+
k_pos: torch.Tensor = two_theta / sin_theta
|
813 |
+
k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
|
814 |
+
k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
|
815 |
+
|
816 |
+
angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
|
817 |
+
angle_axis[..., 0] += q1 * k
|
818 |
+
angle_axis[..., 1] += q2 * k
|
819 |
+
angle_axis[..., 2] += q3 * k
|
820 |
+
return angle_axis
|
821 |
+
|
822 |
+
|
823 |
+
def avg_rot(rot):
|
824 |
+
# input [B,...,3,3] --> output [...,3,3]
|
825 |
+
rot = rot.mean(dim=0)
|
826 |
+
U, _, V = torch.svd(rot)
|
827 |
+
rot = U @ V.transpose(-1, -2)
|
828 |
+
return rot
|
lib/utils/utils.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
+
# holder of all proprietary rights on this computer program.
|
5 |
+
# You can only use this computer program if you have closed
|
6 |
+
# a license agreement with MPG or you get the right to use the computer
|
7 |
+
# program from someone who is authorized to grant you that right.
|
8 |
+
# Any use of the computer program without a valid license is prohibited and
|
9 |
+
# liable to prosecution.
|
10 |
+
#
|
11 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
12 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
+
# for Intelligent Systems. All rights reserved.
|
14 |
+
#
|
15 |
+
# Contact: [email protected]
|
16 |
+
|
17 |
+
import os
|
18 |
+
import yaml
|
19 |
+
import torch
|
20 |
+
import shutil
|
21 |
+
import logging
|
22 |
+
import operator
|
23 |
+
from tqdm import tqdm
|
24 |
+
from os import path as osp
|
25 |
+
from functools import reduce
|
26 |
+
from typing import List, Union
|
27 |
+
from collections import OrderedDict
|
28 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
29 |
+
|
30 |
+
class CustomScheduler(_LRScheduler):
|
31 |
+
def __init__(self, optimizer, lr_lambda):
|
32 |
+
self.lr_lambda = lr_lambda
|
33 |
+
super(CustomScheduler, self).__init__(optimizer)
|
34 |
+
|
35 |
+
def get_lr(self):
|
36 |
+
return [base_lr * self.lr_lambda(self.last_epoch)
|
37 |
+
for base_lr in self.base_lrs]
|
38 |
+
|
39 |
+
def lr_decay_fn(epoch):
|
40 |
+
if epoch == 0: return 1.0
|
41 |
+
if epoch % big_epoch == 0:
|
42 |
+
return big_decay
|
43 |
+
else:
|
44 |
+
return small_decay
|
45 |
+
|
46 |
+
def save_obj(v, f, file_name='output.obj'):
|
47 |
+
obj_file = open(file_name, 'w')
|
48 |
+
for i in range(len(v)):
|
49 |
+
obj_file.write('v ' + str(v[i][0]) + ' ' + str(v[i][1]) + ' ' + str(v[i][2]) + '\n')
|
50 |
+
for i in range(len(f)):
|
51 |
+
obj_file.write('f ' + str(f[i][0]+1) + '/' + str(f[i][0]+1) + ' ' + str(f[i][1]+1) + '/' + str(f[i][1]+1) + ' ' + str(f[i][2]+1) + '/' + str(f[i][2]+1) + '\n')
|
52 |
+
obj_file.close()
|
53 |
+
|
54 |
+
|
55 |
+
def check_data_pararell(train_weight):
|
56 |
+
new_state_dict = OrderedDict()
|
57 |
+
for k, v in train_weight.items():
|
58 |
+
name = k[7:] if k.startswith('module') else k # remove `module.`
|
59 |
+
new_state_dict[name] = v
|
60 |
+
return new_state_dict
|
61 |
+
|
62 |
+
|
63 |
+
def get_from_dict(dict, keys):
|
64 |
+
return reduce(operator.getitem, keys, dict)
|
65 |
+
|
66 |
+
|
67 |
+
def tqdm_enumerate(iter):
|
68 |
+
i = 0
|
69 |
+
for y in tqdm(iter):
|
70 |
+
yield i, y
|
71 |
+
i += 1
|
72 |
+
|
73 |
+
|
74 |
+
def iterdict(d):
|
75 |
+
for k,v in d.items():
|
76 |
+
if isinstance(v, dict):
|
77 |
+
d[k] = dict(v)
|
78 |
+
iterdict(v)
|
79 |
+
return d
|
80 |
+
|
81 |
+
|
82 |
+
def accuracy(output, target):
|
83 |
+
_, pred = output.topk(1)
|
84 |
+
pred = pred.view(-1)
|
85 |
+
|
86 |
+
correct = pred.eq(target).sum()
|
87 |
+
|
88 |
+
return correct.item(), target.size(0) - correct.item()
|
89 |
+
|
90 |
+
|
91 |
+
def lr_decay(optimizer, step, lr, decay_step, gamma):
|
92 |
+
lr = lr * gamma ** (step/decay_step)
|
93 |
+
for param_group in optimizer.param_groups:
|
94 |
+
param_group['lr'] = lr
|
95 |
+
return lr
|
96 |
+
|
97 |
+
|
98 |
+
def step_decay(optimizer, step, lr, decay_step, gamma):
|
99 |
+
lr = lr * gamma ** (step / decay_step)
|
100 |
+
for param_group in optimizer.param_groups:
|
101 |
+
param_group['lr'] = lr
|
102 |
+
return lr
|
103 |
+
|
104 |
+
|
105 |
+
def read_yaml(filename):
|
106 |
+
return yaml.load(open(filename, 'r'))
|
107 |
+
|
108 |
+
|
109 |
+
def write_yaml(filename, object):
|
110 |
+
with open(filename, 'w') as f:
|
111 |
+
yaml.dump(object, f)
|
112 |
+
|
113 |
+
|
114 |
+
def save_dict_to_yaml(obj, filename, mode='w'):
|
115 |
+
with open(filename, mode) as f:
|
116 |
+
yaml.dump(obj, f, default_flow_style=False)
|
117 |
+
|
118 |
+
|
119 |
+
def save_to_file(obj, filename, mode='w'):
|
120 |
+
with open(filename, mode) as f:
|
121 |
+
f.write(obj)
|
122 |
+
|
123 |
+
|
124 |
+
def concatenate_dicts(dict_list, dim=0):
|
125 |
+
rdict = dict.fromkeys(dict_list[0].keys())
|
126 |
+
for k in rdict.keys():
|
127 |
+
rdict[k] = torch.cat([d[k] for d in dict_list], dim=dim)
|
128 |
+
return rdict
|
129 |
+
|
130 |
+
|
131 |
+
def bool_to_string(x: Union[List[bool],bool]) -> Union[List[str],str]:
|
132 |
+
"""
|
133 |
+
boolean to string conversion
|
134 |
+
:param x: list or bool to be converted
|
135 |
+
:return: string converted thing
|
136 |
+
"""
|
137 |
+
if isinstance(x, bool):
|
138 |
+
return [str(x)]
|
139 |
+
for i, j in enumerate(x):
|
140 |
+
x[i]=str(j)
|
141 |
+
return x
|
142 |
+
|
143 |
+
|
144 |
+
def checkpoint2model(checkpoint, key='gen_state_dict'):
|
145 |
+
state_dict = checkpoint[key]
|
146 |
+
print(f'Performance of loaded model on 3DPW is {checkpoint["performance"]:.2f}mm')
|
147 |
+
# del state_dict['regressor.mean_theta']
|
148 |
+
return state_dict
|
149 |
+
|
150 |
+
|
151 |
+
def get_optimizer(cfg, model, optim_type, momentum, stage):
|
152 |
+
if stage == 'stage2':
|
153 |
+
param_list = [{'params': model.integrator.parameters()}]
|
154 |
+
for name, param in model.named_parameters():
|
155 |
+
# if 'integrator' not in name and 'motion_encoder' not in name and 'trajectory_decoder' not in name:
|
156 |
+
if 'integrator' not in name:
|
157 |
+
param_list.append({'params': param, 'lr': cfg.TRAIN.LR_FINETUNE})
|
158 |
+
else:
|
159 |
+
param_list = [{'params': model.parameters()}]
|
160 |
+
|
161 |
+
if optim_type in ['sgd', 'SGD']:
|
162 |
+
opt = torch.optim.SGD(lr=cfg.TRAIN.LR, params=param_list, momentum=momentum)
|
163 |
+
elif optim_type in ['Adam', 'adam', 'ADAM']:
|
164 |
+
opt = torch.optim.Adam(lr=cfg.TRAIN.LR, params=param_list, weight_decay=cfg.TRAIN.WD, betas=(0.9, 0.999))
|
165 |
+
else:
|
166 |
+
raise ModuleNotFoundError
|
167 |
+
|
168 |
+
return opt
|
169 |
+
|
170 |
+
|
171 |
+
def create_logger(logdir, phase='train'):
|
172 |
+
os.makedirs(logdir, exist_ok=True)
|
173 |
+
|
174 |
+
log_file = osp.join(logdir, f'{phase}_log.txt')
|
175 |
+
|
176 |
+
head = '%(asctime)-15s %(message)s'
|
177 |
+
logging.basicConfig(filename=log_file,
|
178 |
+
format=head)
|
179 |
+
logger = logging.getLogger()
|
180 |
+
logger.setLevel(logging.INFO)
|
181 |
+
console = logging.StreamHandler()
|
182 |
+
logging.getLogger('').addHandler(console)
|
183 |
+
|
184 |
+
return logger
|
185 |
+
|
186 |
+
|
187 |
+
class AverageMeter(object):
|
188 |
+
def __init__(self):
|
189 |
+
self.val = 0
|
190 |
+
self.avg = 0
|
191 |
+
self.sum = 0
|
192 |
+
self.count = 0
|
193 |
+
|
194 |
+
def update(self, val, n=1):
|
195 |
+
self.val = val
|
196 |
+
self.sum += val * n
|
197 |
+
self.count += n
|
198 |
+
self.avg = self.sum / self.count
|
199 |
+
|
200 |
+
|
201 |
+
def prepare_output_dir(cfg, cfg_file):
|
202 |
+
|
203 |
+
# ==== create logdir
|
204 |
+
logdir = osp.join(cfg.OUTPUT_DIR, cfg.EXP_NAME)
|
205 |
+
os.makedirs(logdir, exist_ok=True)
|
206 |
+
shutil.copy(src=cfg_file, dst=osp.join(cfg.OUTPUT_DIR, 'config.yaml'))
|
207 |
+
|
208 |
+
cfg.LOGDIR = logdir
|
209 |
+
|
210 |
+
# save config
|
211 |
+
save_dict_to_yaml(cfg, osp.join(cfg.LOGDIR, 'config.yaml'))
|
212 |
+
|
213 |
+
return cfg
|
214 |
+
|
215 |
+
|
216 |
+
def prepare_groundtruth(batch, device):
|
217 |
+
groundtruths = dict()
|
218 |
+
gt_keys = ['pose', 'cam', 'betas', 'kp3d', 'bbox'] # Evaluation
|
219 |
+
gt_keys += ['pose_root', 'vel_root', 'weak_kp2d', 'verts', # Training
|
220 |
+
'full_kp2d', 'contact', 'R', 'cam_angvel',
|
221 |
+
'has_smpl', 'has_traj', 'has_full_screen', 'has_verts']
|
222 |
+
for gt_key in gt_keys:
|
223 |
+
if gt_key in batch.keys():
|
224 |
+
dtype = torch.float32 if batch[gt_key].dtype == torch.float64 else batch[gt_key].dtype
|
225 |
+
groundtruths[gt_key] = batch[gt_key].to(dtype=dtype, device=device)
|
226 |
+
|
227 |
+
return groundtruths
|
228 |
+
|
229 |
+
def prepare_auxiliary(batch, device):
|
230 |
+
aux = dict()
|
231 |
+
aux_keys = ['mask', 'bbox', 'res', 'cam_intrinsics', 'init_root', 'cam_angvel']
|
232 |
+
for key in aux_keys:
|
233 |
+
if key in batch.keys():
|
234 |
+
dtype = torch.float32 if batch[key].dtype == torch.float64 else batch[key].dtype
|
235 |
+
aux[key] = batch[key].to(dtype=dtype, device=device)
|
236 |
+
|
237 |
+
return aux
|
238 |
+
|
239 |
+
def prepare_input(batch, device, use_features):
|
240 |
+
# Input keypoints data
|
241 |
+
kp2d = batch['kp2d'].to(device).float()
|
242 |
+
|
243 |
+
# Input features
|
244 |
+
if use_features and 'features' in batch.keys():
|
245 |
+
features = batch['features'].to(device).float()
|
246 |
+
else:
|
247 |
+
features = None
|
248 |
+
|
249 |
+
# Initial SMPL parameters
|
250 |
+
init_smpl = batch['init_pose'].to(device).float()
|
251 |
+
|
252 |
+
# Initial keypoints
|
253 |
+
init_kp = torch.cat((
|
254 |
+
batch['init_kp3d'], batch['init_kp2d']
|
255 |
+
), dim=-1).to(device).float()
|
256 |
+
|
257 |
+
return kp2d, (init_kp, init_smpl), features
|
258 |
+
|
259 |
+
|
260 |
+
def prepare_batch(batch, device, use_features=True):
|
261 |
+
x, inits, features = prepare_input(batch, device, use_features)
|
262 |
+
aux = prepare_auxiliary(batch, device)
|
263 |
+
groundtruths = prepare_groundtruth(batch, device)
|
264 |
+
|
265 |
+
return x, inits, features, aux, groundtruths
|
lib/vis/__pycache__/renderer.cpython-39.pyc
ADDED
Binary file (9.31 kB). View file
|
|
lib/vis/__pycache__/run_vis.cpython-39.pyc
ADDED
Binary file (3.08 kB). View file
|
|
lib/vis/__pycache__/tools.cpython-39.pyc
ADDED
Binary file (15.9 kB). View file
|
|
lib/vis/renderer.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from pytorch3d.renderer import (
|
6 |
+
PerspectiveCameras,
|
7 |
+
TexturesVertex,
|
8 |
+
PointLights,
|
9 |
+
Materials,
|
10 |
+
RasterizationSettings,
|
11 |
+
MeshRenderer,
|
12 |
+
MeshRasterizer,
|
13 |
+
SoftPhongShader,
|
14 |
+
)
|
15 |
+
from pytorch3d.structures import Meshes
|
16 |
+
from pytorch3d.structures.meshes import join_meshes_as_scene
|
17 |
+
from pytorch3d.renderer.cameras import look_at_rotation
|
18 |
+
|
19 |
+
from .tools import get_colors, checkerboard_geometry
|
20 |
+
|
21 |
+
|
22 |
+
def overlay_image_onto_background(image, mask, bbox, background):
|
23 |
+
if isinstance(image, torch.Tensor):
|
24 |
+
image = image.detach().cpu().numpy()
|
25 |
+
if isinstance(mask, torch.Tensor):
|
26 |
+
mask = mask.detach().cpu().numpy()
|
27 |
+
|
28 |
+
out_image = background.copy()
|
29 |
+
bbox = bbox[0].int().cpu().numpy().copy()
|
30 |
+
roi_image = out_image[bbox[1]:bbox[3], bbox[0]:bbox[2]]
|
31 |
+
|
32 |
+
roi_image[mask] = image[mask]
|
33 |
+
out_image[bbox[1]:bbox[3], bbox[0]:bbox[2]] = roi_image
|
34 |
+
|
35 |
+
return out_image
|
36 |
+
|
37 |
+
|
38 |
+
def update_intrinsics_from_bbox(K_org, bbox):
|
39 |
+
device, dtype = K_org.device, K_org.dtype
|
40 |
+
|
41 |
+
K = torch.zeros((K_org.shape[0], 4, 4)
|
42 |
+
).to(device=device, dtype=dtype)
|
43 |
+
K[:, :3, :3] = K_org.clone()
|
44 |
+
K[:, 2, 2] = 0
|
45 |
+
K[:, 2, -1] = 1
|
46 |
+
K[:, -1, 2] = 1
|
47 |
+
|
48 |
+
image_sizes = []
|
49 |
+
for idx, bbox in enumerate(bbox):
|
50 |
+
left, upper, right, lower = bbox
|
51 |
+
cx, cy = K[idx, 0, 2], K[idx, 1, 2]
|
52 |
+
|
53 |
+
new_cx = cx - left
|
54 |
+
new_cy = cy - upper
|
55 |
+
new_height = max(lower - upper, 1)
|
56 |
+
new_width = max(right - left, 1)
|
57 |
+
new_cx = new_width - new_cx
|
58 |
+
new_cy = new_height - new_cy
|
59 |
+
|
60 |
+
K[idx, 0, 2] = new_cx
|
61 |
+
K[idx, 1, 2] = new_cy
|
62 |
+
image_sizes.append((int(new_height), int(new_width)))
|
63 |
+
|
64 |
+
return K, image_sizes
|
65 |
+
|
66 |
+
|
67 |
+
def perspective_projection(x3d, K, R=None, T=None):
|
68 |
+
if R != None:
|
69 |
+
x3d = torch.matmul(R, x3d.transpose(1, 2)).transpose(1, 2)
|
70 |
+
if T != None:
|
71 |
+
x3d = x3d + T.transpose(1, 2)
|
72 |
+
|
73 |
+
x2d = torch.div(x3d, x3d[..., 2:])
|
74 |
+
x2d = torch.matmul(K, x2d.transpose(-1, -2)).transpose(-1, -2)[..., :2]
|
75 |
+
return x2d
|
76 |
+
|
77 |
+
|
78 |
+
def compute_bbox_from_points(X, img_w, img_h, scaleFactor=1.2):
|
79 |
+
left = torch.clamp(X.min(1)[0][:, 0], min=0, max=img_w)
|
80 |
+
right = torch.clamp(X.max(1)[0][:, 0], min=0, max=img_w)
|
81 |
+
top = torch.clamp(X.min(1)[0][:, 1], min=0, max=img_h)
|
82 |
+
bottom = torch.clamp(X.max(1)[0][:, 1], min=0, max=img_h)
|
83 |
+
|
84 |
+
cx = (left + right) / 2
|
85 |
+
cy = (top + bottom) / 2
|
86 |
+
width = (right - left)
|
87 |
+
height = (bottom - top)
|
88 |
+
|
89 |
+
new_left = torch.clamp(cx - width/2 * scaleFactor, min=0, max=img_w-1)
|
90 |
+
new_right = torch.clamp(cx + width/2 * scaleFactor, min=1, max=img_w)
|
91 |
+
new_top = torch.clamp(cy - height / 2 * scaleFactor, min=0, max=img_h-1)
|
92 |
+
new_bottom = torch.clamp(cy + height / 2 * scaleFactor, min=1, max=img_h)
|
93 |
+
|
94 |
+
bbox = torch.stack((new_left.detach(), new_top.detach(),
|
95 |
+
new_right.detach(), new_bottom.detach())).int().float().T
|
96 |
+
|
97 |
+
return bbox
|
98 |
+
|
99 |
+
|
100 |
+
class Renderer():
|
101 |
+
def __init__(self, width, height, focal_length, device, faces=None):
|
102 |
+
|
103 |
+
self.width = width
|
104 |
+
self.height = height
|
105 |
+
self.focal_length = focal_length
|
106 |
+
|
107 |
+
self.device = device
|
108 |
+
if faces is not None:
|
109 |
+
self.faces = torch.from_numpy(
|
110 |
+
(faces).astype('int')
|
111 |
+
).unsqueeze(0).to(self.device)
|
112 |
+
|
113 |
+
self.initialize_camera_params()
|
114 |
+
self.lights = PointLights(device=device, location=[[0.0, 0.0, -10.0]])
|
115 |
+
self.create_renderer()
|
116 |
+
|
117 |
+
def create_renderer(self):
|
118 |
+
self.renderer = MeshRenderer(
|
119 |
+
rasterizer=MeshRasterizer(
|
120 |
+
raster_settings=RasterizationSettings(
|
121 |
+
image_size=self.image_sizes[0],
|
122 |
+
blur_radius=1e-5),
|
123 |
+
),
|
124 |
+
shader=SoftPhongShader(
|
125 |
+
device=self.device,
|
126 |
+
lights=self.lights,
|
127 |
+
)
|
128 |
+
)
|
129 |
+
|
130 |
+
def create_camera(self, R=None, T=None):
|
131 |
+
if R is not None:
|
132 |
+
self.R = R.clone().view(1, 3, 3).to(self.device)
|
133 |
+
if T is not None:
|
134 |
+
self.T = T.clone().view(1, 3).to(self.device)
|
135 |
+
|
136 |
+
return PerspectiveCameras(
|
137 |
+
device=self.device,
|
138 |
+
R=self.R.mT,
|
139 |
+
T=self.T,
|
140 |
+
K=self.K_full,
|
141 |
+
image_size=self.image_sizes,
|
142 |
+
in_ndc=False)
|
143 |
+
|
144 |
+
|
145 |
+
def initialize_camera_params(self):
|
146 |
+
"""Hard coding for camera parameters
|
147 |
+
TODO: Do some soft coding"""
|
148 |
+
|
149 |
+
# Extrinsics
|
150 |
+
self.R = torch.diag(
|
151 |
+
torch.tensor([1, 1, 1])
|
152 |
+
).float().to(self.device).unsqueeze(0)
|
153 |
+
|
154 |
+
self.T = torch.tensor(
|
155 |
+
[0, 0, 0]
|
156 |
+
).unsqueeze(0).float().to(self.device)
|
157 |
+
|
158 |
+
# Intrinsics
|
159 |
+
self.K = torch.tensor(
|
160 |
+
[[self.focal_length, 0, self.width/2],
|
161 |
+
[0, self.focal_length, self.height/2],
|
162 |
+
[0, 0, 1]]
|
163 |
+
).unsqueeze(0).float().to(self.device)
|
164 |
+
self.bboxes = torch.tensor([[0, 0, self.width, self.height]]).float()
|
165 |
+
self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes)
|
166 |
+
self.cameras = self.create_camera()
|
167 |
+
|
168 |
+
|
169 |
+
def set_ground(self, length, center_x, center_z):
|
170 |
+
device = self.device
|
171 |
+
v, f, vc, fc = map(torch.from_numpy, checkerboard_geometry(length=length, c1=center_x, c2=center_z, up="y"))
|
172 |
+
v, f, vc = v.to(device), f.to(device), vc.to(device)
|
173 |
+
self.ground_geometry = [v, f, vc]
|
174 |
+
|
175 |
+
|
176 |
+
def update_bbox(self, x3d, scale=2.0, mask=None):
|
177 |
+
""" Update bbox of cameras from the given 3d points
|
178 |
+
|
179 |
+
x3d: input 3D keypoints (or vertices), (num_frames, num_points, 3)
|
180 |
+
"""
|
181 |
+
|
182 |
+
if x3d.size(-1) != 3:
|
183 |
+
x2d = x3d.unsqueeze(0)
|
184 |
+
else:
|
185 |
+
x2d = perspective_projection(x3d.unsqueeze(0), self.K, self.R, self.T.reshape(1, 3, 1))
|
186 |
+
|
187 |
+
if mask is not None:
|
188 |
+
x2d = x2d[:, ~mask]
|
189 |
+
|
190 |
+
bbox = compute_bbox_from_points(x2d, self.width, self.height, scale)
|
191 |
+
self.bboxes = bbox
|
192 |
+
|
193 |
+
self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
|
194 |
+
self.cameras = self.create_camera()
|
195 |
+
self.create_renderer()
|
196 |
+
|
197 |
+
def reset_bbox(self,):
|
198 |
+
bbox = torch.zeros((1, 4)).float().to(self.device)
|
199 |
+
bbox[0, 2] = self.width
|
200 |
+
bbox[0, 3] = self.height
|
201 |
+
self.bboxes = bbox
|
202 |
+
|
203 |
+
self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
|
204 |
+
self.cameras = self.create_camera()
|
205 |
+
self.create_renderer()
|
206 |
+
|
207 |
+
def render_mesh(self, vertices, background, colors=[0.8, 0.8, 0.8]):
|
208 |
+
self.update_bbox(vertices[::50], scale=1.2)
|
209 |
+
vertices = vertices.unsqueeze(0)
|
210 |
+
|
211 |
+
if colors[0] > 1: colors = [c / 255. for c in colors]
|
212 |
+
verts_features = torch.tensor(colors).reshape(1, 1, 3).to(device=vertices.device, dtype=vertices.dtype)
|
213 |
+
verts_features = verts_features.repeat(1, vertices.shape[1], 1)
|
214 |
+
textures = TexturesVertex(verts_features=verts_features)
|
215 |
+
|
216 |
+
mesh = Meshes(verts=vertices,
|
217 |
+
faces=self.faces,
|
218 |
+
textures=textures,)
|
219 |
+
|
220 |
+
materials = Materials(
|
221 |
+
device=self.device,
|
222 |
+
specular_color=(colors, ),
|
223 |
+
shininess=0
|
224 |
+
)
|
225 |
+
|
226 |
+
results = torch.flip(
|
227 |
+
self.renderer(mesh, materials=materials, cameras=self.cameras, lights=self.lights),
|
228 |
+
[1, 2]
|
229 |
+
)
|
230 |
+
image = results[0, ..., :3] * 255
|
231 |
+
mask = results[0, ..., -1] > 1e-3
|
232 |
+
|
233 |
+
image = overlay_image_onto_background(image, mask, self.bboxes, background.copy())
|
234 |
+
self.reset_bbox()
|
235 |
+
return image
|
236 |
+
|
237 |
+
|
238 |
+
def render_with_ground(self, verts, faces, colors, cameras, lights):
|
239 |
+
"""
|
240 |
+
:param verts (B, V, 3)
|
241 |
+
:param faces (F, 3)
|
242 |
+
:param colors (B, 3)
|
243 |
+
"""
|
244 |
+
|
245 |
+
# (B, V, 3), (B, F, 3), (B, V, 3)
|
246 |
+
verts, faces, colors = prep_shared_geometry(verts, faces, colors)
|
247 |
+
# (V, 3), (F, 3), (V, 3)
|
248 |
+
gv, gf, gc = self.ground_geometry
|
249 |
+
verts = list(torch.unbind(verts, dim=0)) + [gv]
|
250 |
+
faces = list(torch.unbind(faces, dim=0)) + [gf]
|
251 |
+
colors = list(torch.unbind(colors, dim=0)) + [gc[..., :3]]
|
252 |
+
mesh = create_meshes(verts, faces, colors)
|
253 |
+
|
254 |
+
materials = Materials(
|
255 |
+
device=self.device,
|
256 |
+
shininess=0
|
257 |
+
)
|
258 |
+
|
259 |
+
results = self.renderer(mesh, cameras=cameras, lights=lights, materials=materials)
|
260 |
+
image = (results[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
|
261 |
+
|
262 |
+
return image
|
263 |
+
|
264 |
+
|
265 |
+
def prep_shared_geometry(verts, faces, colors):
|
266 |
+
"""
|
267 |
+
:param verts (B, V, 3)
|
268 |
+
:param faces (F, 3)
|
269 |
+
:param colors (B, 4)
|
270 |
+
"""
|
271 |
+
B, V, _ = verts.shape
|
272 |
+
F, _ = faces.shape
|
273 |
+
colors = colors.unsqueeze(1).expand(B, V, -1)[..., :3]
|
274 |
+
faces = faces.unsqueeze(0).expand(B, F, -1)
|
275 |
+
return verts, faces, colors
|
276 |
+
|
277 |
+
|
278 |
+
def create_meshes(verts, faces, colors):
|
279 |
+
"""
|
280 |
+
:param verts (B, V, 3)
|
281 |
+
:param faces (B, F, 3)
|
282 |
+
:param colors (B, V, 3)
|
283 |
+
"""
|
284 |
+
textures = TexturesVertex(verts_features=colors)
|
285 |
+
meshes = Meshes(verts=verts, faces=faces, textures=textures)
|
286 |
+
return join_meshes_as_scene(meshes)
|
287 |
+
|
288 |
+
|
289 |
+
def get_global_cameras(verts, device, distance=5, position=(-5.0, 5.0, 0.0)):
|
290 |
+
positions = torch.tensor([position]).repeat(len(verts), 1)
|
291 |
+
targets = verts.mean(1)
|
292 |
+
|
293 |
+
directions = targets - positions
|
294 |
+
directions = directions / torch.norm(directions, dim=-1).unsqueeze(-1) * distance
|
295 |
+
positions = targets - directions
|
296 |
+
|
297 |
+
rotation = look_at_rotation(positions, targets, ).mT
|
298 |
+
translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1)
|
299 |
+
|
300 |
+
lights = PointLights(device=device, location=[position])
|
301 |
+
return rotation, translation, lights
|
302 |
+
|
303 |
+
|
304 |
+
def _get_global_cameras(verts, device, min_distance=3, chunk_size=100):
|
305 |
+
|
306 |
+
# split into smaller chunks to visualize
|
307 |
+
start_idxs = list(range(0, len(verts), chunk_size))
|
308 |
+
end_idxs = [min(start_idx + chunk_size, len(verts)) for start_idx in start_idxs]
|
309 |
+
|
310 |
+
Rs, Ts = [], []
|
311 |
+
for start_idx, end_idx in zip(start_idxs, end_idxs):
|
312 |
+
vert = verts[start_idx:end_idx].clone()
|
313 |
+
import pdb; pdb.set_trace()
|
lib/vis/run_vis.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import imageio
|
7 |
+
import numpy as np
|
8 |
+
from progress.bar import Bar
|
9 |
+
|
10 |
+
from lib.vis.renderer import Renderer, get_global_cameras
|
11 |
+
|
12 |
+
def run_vis_on_demo(cfg, video, results, output_pth, smpl, vis_global=True):
|
13 |
+
# to torch tensor
|
14 |
+
tt = lambda x: torch.from_numpy(x).float().to(cfg.DEVICE)
|
15 |
+
|
16 |
+
cap = cv2.VideoCapture(video)
|
17 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
18 |
+
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
19 |
+
width, height = cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
20 |
+
|
21 |
+
# create renderer with cliff focal length estimation
|
22 |
+
focal_length = (width ** 2 + height ** 2) ** 0.5
|
23 |
+
renderer = Renderer(width, height, focal_length, cfg.DEVICE, smpl.faces)
|
24 |
+
|
25 |
+
if vis_global:
|
26 |
+
# setup global coordinate subject
|
27 |
+
# current implementation only visualize the subject appeared longest
|
28 |
+
n_frames = {k: len(results[k]['frame_ids']) for k in results.keys()}
|
29 |
+
sid = max(n_frames, key=n_frames.get)
|
30 |
+
global_output = smpl.get_output(
|
31 |
+
body_pose=tt(results[sid]['pose_world'][:, 3:]),
|
32 |
+
global_orient=tt(results[sid]['pose_world'][:, :3]),
|
33 |
+
betas=tt(results[sid]['betas']),
|
34 |
+
transl=tt(results[sid]['trans_world']))
|
35 |
+
verts_glob = global_output.vertices.cpu()
|
36 |
+
verts_glob[..., 1] = verts_glob[..., 1] - verts_glob[..., 1].min()
|
37 |
+
cx, cz = (verts_glob.mean(1).max(0)[0] + verts_glob.mean(1).min(0)[0])[[0, 2]] / 2.0
|
38 |
+
sx, sz = (verts_glob.mean(1).max(0)[0] - verts_glob.mean(1).min(0)[0])[[0, 2]]
|
39 |
+
scale = max(sx.item(), sz.item()) * 1.5
|
40 |
+
|
41 |
+
# set default ground
|
42 |
+
renderer.set_ground(scale, cx.item(), cz.item())
|
43 |
+
|
44 |
+
# build global camera
|
45 |
+
global_R, global_T, global_lights = get_global_cameras(verts_glob, cfg.DEVICE)
|
46 |
+
|
47 |
+
# build default camera
|
48 |
+
default_R, default_T = torch.eye(3), torch.zeros(3)
|
49 |
+
|
50 |
+
writer = imageio.get_writer(
|
51 |
+
osp.join(output_pth, 'output.mp4'),
|
52 |
+
fps=fps, mode='I', format='FFMPEG', macro_block_size=1
|
53 |
+
)
|
54 |
+
bar = Bar('Rendering results ...', fill='#', max=length)
|
55 |
+
|
56 |
+
frame_i = 0
|
57 |
+
_global_R, _global_T = None, None
|
58 |
+
# run rendering
|
59 |
+
while (cap.isOpened()):
|
60 |
+
flag, org_img = cap.read()
|
61 |
+
if not flag: break
|
62 |
+
img = org_img[..., ::-1].copy()
|
63 |
+
|
64 |
+
# render onto the input video
|
65 |
+
renderer.create_camera(default_R, default_T)
|
66 |
+
for _id, val in results.items():
|
67 |
+
# render onto the image
|
68 |
+
frame_i2 = np.where(val['frame_ids'] == frame_i)[0]
|
69 |
+
if len(frame_i2) == 0: continue
|
70 |
+
frame_i2 = frame_i2[0]
|
71 |
+
img = renderer.render_mesh(torch.from_numpy(val['verts'][frame_i2]).to(cfg.DEVICE), img)
|
72 |
+
|
73 |
+
if vis_global:
|
74 |
+
# render the global coordinate
|
75 |
+
if frame_i in results[sid]['frame_ids']:
|
76 |
+
frame_i3 = np.where(results[sid]['frame_ids'] == frame_i)[0]
|
77 |
+
verts = verts_glob[[frame_i3]].to(cfg.DEVICE)
|
78 |
+
faces = renderer.faces.clone().squeeze(0)
|
79 |
+
colors = torch.ones((1, 4)).float().to(cfg.DEVICE); colors[..., :3] *= 0.9
|
80 |
+
|
81 |
+
if _global_R is None:
|
82 |
+
_global_R = global_R[frame_i3].clone(); _global_T = global_T[frame_i3].clone()
|
83 |
+
cameras = renderer.create_camera(global_R[frame_i3], global_T[frame_i3])
|
84 |
+
img_glob = renderer.render_with_ground(verts, faces, colors, cameras, global_lights)
|
85 |
+
|
86 |
+
try: img = np.concatenate((img, img_glob), axis=1)
|
87 |
+
except: img = np.concatenate((img, np.ones_like(img) * 255), axis=1)
|
88 |
+
|
89 |
+
writer.append_data(img)
|
90 |
+
bar.next()
|
91 |
+
frame_i += 1
|
92 |
+
writer.close()
|
lib/vis/tools.py
ADDED
@@ -0,0 +1,822 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
def read_image(path, scale=1):
|
9 |
+
im = Image.open(path)
|
10 |
+
if scale == 1:
|
11 |
+
return np.array(im)
|
12 |
+
W, H = im.size
|
13 |
+
w, h = int(scale * W), int(scale * H)
|
14 |
+
return np.array(im.resize((w, h), Image.ANTIALIAS))
|
15 |
+
|
16 |
+
|
17 |
+
def transform_torch3d(T_c2w):
|
18 |
+
"""
|
19 |
+
:param T_c2w (*, 4, 4)
|
20 |
+
returns (*, 3, 3), (*, 3)
|
21 |
+
"""
|
22 |
+
R1 = torch.tensor(
|
23 |
+
[[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 1.0],], device=T_c2w.device,
|
24 |
+
)
|
25 |
+
R2 = torch.tensor(
|
26 |
+
[[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0],], device=T_c2w.device,
|
27 |
+
)
|
28 |
+
cam_R, cam_t = T_c2w[..., :3, :3], T_c2w[..., :3, 3]
|
29 |
+
cam_R = torch.einsum("...ij,jk->...ik", cam_R, R1)
|
30 |
+
cam_t = torch.einsum("ij,...j->...i", R2, cam_t)
|
31 |
+
return cam_R, cam_t
|
32 |
+
|
33 |
+
|
34 |
+
def transform_pyrender(T_c2w):
|
35 |
+
"""
|
36 |
+
:param T_c2w (*, 4, 4)
|
37 |
+
"""
|
38 |
+
T_vis = torch.tensor(
|
39 |
+
[
|
40 |
+
[1.0, 0.0, 0.0, 0.0],
|
41 |
+
[0.0, -1.0, 0.0, 0.0],
|
42 |
+
[0.0, 0.0, -1.0, 0.0],
|
43 |
+
[0.0, 0.0, 0.0, 1.0],
|
44 |
+
],
|
45 |
+
device=T_c2w.device,
|
46 |
+
)
|
47 |
+
return torch.einsum(
|
48 |
+
"...ij,jk->...ik", torch.einsum("ij,...jk->...ik", T_vis, T_c2w), T_vis
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
+
def smpl_to_geometry(verts, faces, vis_mask=None, track_ids=None):
|
53 |
+
"""
|
54 |
+
:param verts (B, T, V, 3)
|
55 |
+
:param faces (F, 3)
|
56 |
+
:param vis_mask (optional) (B, T) visibility of each person
|
57 |
+
:param track_ids (optional) (B,)
|
58 |
+
returns list of T verts (B, V, 3), faces (F, 3), colors (B, 3)
|
59 |
+
where B is different depending on the visibility of the people
|
60 |
+
"""
|
61 |
+
B, T = verts.shape[:2]
|
62 |
+
device = verts.device
|
63 |
+
|
64 |
+
# (B, 3)
|
65 |
+
colors = (
|
66 |
+
track_to_colors(track_ids)
|
67 |
+
if track_ids is not None
|
68 |
+
else torch.ones(B, 3, device) * 0.5
|
69 |
+
)
|
70 |
+
|
71 |
+
# list T (B, V, 3), T (B, 3), T (F, 3)
|
72 |
+
return filter_visible_meshes(verts, colors, faces, vis_mask)
|
73 |
+
|
74 |
+
|
75 |
+
def filter_visible_meshes(verts, colors, faces, vis_mask=None, vis_opacity=False):
|
76 |
+
"""
|
77 |
+
:param verts (B, T, V, 3)
|
78 |
+
:param colors (B, 3)
|
79 |
+
:param faces (F, 3)
|
80 |
+
:param vis_mask (optional tensor, default None) (B, T) ternary mask
|
81 |
+
-1 if not in frame
|
82 |
+
0 if temporarily occluded
|
83 |
+
1 if visible
|
84 |
+
:param vis_opacity (optional bool, default False)
|
85 |
+
if True, make occluded people alpha=0.5, otherwise alpha=1
|
86 |
+
returns a list of T lists verts (Bi, V, 3), colors (Bi, 4), faces (F, 3)
|
87 |
+
"""
|
88 |
+
# import ipdb; ipdb.set_trace()
|
89 |
+
B, T = verts.shape[:2]
|
90 |
+
faces = [faces for t in range(T)]
|
91 |
+
if vis_mask is None:
|
92 |
+
verts = [verts[:, t] for t in range(T)]
|
93 |
+
colors = [colors for t in range(T)]
|
94 |
+
return verts, colors, faces
|
95 |
+
|
96 |
+
# render occluded and visible, but not removed
|
97 |
+
vis_mask = vis_mask >= 0
|
98 |
+
if vis_opacity:
|
99 |
+
alpha = 0.5 * (vis_mask[..., None] + 1)
|
100 |
+
else:
|
101 |
+
alpha = (vis_mask[..., None] >= 0).float()
|
102 |
+
vert_list = [verts[vis_mask[:, t], t] for t in range(T)]
|
103 |
+
colors = [
|
104 |
+
torch.cat([colors[vis_mask[:, t]], alpha[vis_mask[:, t], t]], dim=-1)
|
105 |
+
for t in range(T)
|
106 |
+
]
|
107 |
+
bounds = get_bboxes(verts, vis_mask)
|
108 |
+
return vert_list, colors, faces, bounds
|
109 |
+
|
110 |
+
|
111 |
+
def get_bboxes(verts, vis_mask):
|
112 |
+
"""
|
113 |
+
return bb_min, bb_max, and mean for each track (B, 3) over entire trajectory
|
114 |
+
:param verts (B, T, V, 3)
|
115 |
+
:param vis_mask (B, T)
|
116 |
+
"""
|
117 |
+
B, T, *_ = verts.shape
|
118 |
+
bb_min, bb_max, mean = [], [], []
|
119 |
+
for b in range(B):
|
120 |
+
v = verts[b, vis_mask[b, :T]] # (Tb, V, 3)
|
121 |
+
bb_min.append(v.amin(dim=(0, 1)))
|
122 |
+
bb_max.append(v.amax(dim=(0, 1)))
|
123 |
+
mean.append(v.mean(dim=(0, 1)))
|
124 |
+
bb_min = torch.stack(bb_min, dim=0)
|
125 |
+
bb_max = torch.stack(bb_max, dim=0)
|
126 |
+
mean = torch.stack(mean, dim=0)
|
127 |
+
# point to a track that's long and close to the camera
|
128 |
+
zs = mean[:, 2]
|
129 |
+
counts = vis_mask[:, :T].sum(dim=-1) # (B,)
|
130 |
+
mask = counts < 0.8 * T
|
131 |
+
zs[mask] = torch.inf
|
132 |
+
sel = torch.argmin(zs)
|
133 |
+
return bb_min.amin(dim=0), bb_max.amax(dim=0), mean[sel]
|
134 |
+
|
135 |
+
|
136 |
+
def track_to_colors(track_ids):
|
137 |
+
"""
|
138 |
+
:param track_ids (B)
|
139 |
+
"""
|
140 |
+
color_map = torch.from_numpy(get_colors()).to(track_ids)
|
141 |
+
return color_map[track_ids] / 255 # (B, 3)
|
142 |
+
|
143 |
+
|
144 |
+
def get_colors():
|
145 |
+
# color_file = os.path.abspath(os.path.join(__file__, "../colors_phalp.txt"))
|
146 |
+
color_file = os.path.abspath(os.path.join(__file__, "../colors.txt"))
|
147 |
+
RGB_tuples = np.vstack(
|
148 |
+
[
|
149 |
+
np.loadtxt(color_file, skiprows=0),
|
150 |
+
# np.loadtxt(color_file, skiprows=1),
|
151 |
+
np.random.uniform(0, 255, size=(10000, 3)),
|
152 |
+
[[0, 0, 0]],
|
153 |
+
]
|
154 |
+
)
|
155 |
+
b = np.where(RGB_tuples == 0)
|
156 |
+
RGB_tuples[b] = 1
|
157 |
+
return RGB_tuples.astype(np.float32)
|
158 |
+
|
159 |
+
|
160 |
+
def checkerboard_geometry(
|
161 |
+
length=12.0,
|
162 |
+
color0=[0.8, 0.9, 0.9],
|
163 |
+
color1=[0.6, 0.7, 0.7],
|
164 |
+
tile_width=0.5,
|
165 |
+
alpha=1.0,
|
166 |
+
up="y",
|
167 |
+
c1=0.0,
|
168 |
+
c2=0.0,
|
169 |
+
):
|
170 |
+
assert up == "y" or up == "z"
|
171 |
+
color0 = np.array(color0 + [alpha])
|
172 |
+
color1 = np.array(color1 + [alpha])
|
173 |
+
radius = length / 2.0
|
174 |
+
num_rows = num_cols = max(2, int(length / tile_width))
|
175 |
+
vertices = []
|
176 |
+
vert_colors = []
|
177 |
+
faces = []
|
178 |
+
face_colors = []
|
179 |
+
for i in range(num_rows):
|
180 |
+
for j in range(num_cols):
|
181 |
+
u0, v0 = j * tile_width - radius, i * tile_width - radius
|
182 |
+
us = np.array([u0, u0, u0 + tile_width, u0 + tile_width])
|
183 |
+
vs = np.array([v0, v0 + tile_width, v0 + tile_width, v0])
|
184 |
+
zs = np.zeros(4)
|
185 |
+
if up == "y":
|
186 |
+
cur_verts = np.stack([us, zs, vs], axis=-1) # (4, 3)
|
187 |
+
cur_verts[:, 0] += c1
|
188 |
+
cur_verts[:, 2] += c2
|
189 |
+
else:
|
190 |
+
cur_verts = np.stack([us, vs, zs], axis=-1) # (4, 3)
|
191 |
+
cur_verts[:, 0] += c1
|
192 |
+
cur_verts[:, 1] += c2
|
193 |
+
|
194 |
+
cur_faces = np.array(
|
195 |
+
[[0, 1, 3], [1, 2, 3], [0, 3, 1], [1, 3, 2]], dtype=np.int64
|
196 |
+
)
|
197 |
+
cur_faces += 4 * (i * num_cols + j) # the number of previously added verts
|
198 |
+
use_color0 = (i % 2 == 0 and j % 2 == 0) or (i % 2 == 1 and j % 2 == 1)
|
199 |
+
cur_color = color0 if use_color0 else color1
|
200 |
+
cur_colors = np.array([cur_color, cur_color, cur_color, cur_color])
|
201 |
+
|
202 |
+
vertices.append(cur_verts)
|
203 |
+
faces.append(cur_faces)
|
204 |
+
vert_colors.append(cur_colors)
|
205 |
+
face_colors.append(cur_colors)
|
206 |
+
|
207 |
+
vertices = np.concatenate(vertices, axis=0).astype(np.float32)
|
208 |
+
vert_colors = np.concatenate(vert_colors, axis=0).astype(np.float32)
|
209 |
+
faces = np.concatenate(faces, axis=0).astype(np.float32)
|
210 |
+
face_colors = np.concatenate(face_colors, axis=0).astype(np.float32)
|
211 |
+
|
212 |
+
return vertices, faces, vert_colors, face_colors
|
213 |
+
|
214 |
+
|
215 |
+
def camera_marker_geometry(radius, height, up):
|
216 |
+
assert up == "y" or up == "z"
|
217 |
+
if up == "y":
|
218 |
+
vertices = np.array(
|
219 |
+
[
|
220 |
+
[-radius, -radius, 0],
|
221 |
+
[radius, -radius, 0],
|
222 |
+
[radius, radius, 0],
|
223 |
+
[-radius, radius, 0],
|
224 |
+
[0, 0, height],
|
225 |
+
]
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
vertices = np.array(
|
229 |
+
[
|
230 |
+
[-radius, 0, -radius],
|
231 |
+
[radius, 0, -radius],
|
232 |
+
[radius, 0, radius],
|
233 |
+
[-radius, 0, radius],
|
234 |
+
[0, -height, 0],
|
235 |
+
]
|
236 |
+
)
|
237 |
+
|
238 |
+
faces = np.array(
|
239 |
+
[[0, 3, 1], [1, 3, 2], [0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4],]
|
240 |
+
)
|
241 |
+
|
242 |
+
face_colors = np.array(
|
243 |
+
[
|
244 |
+
[1.0, 1.0, 1.0, 1.0],
|
245 |
+
[1.0, 1.0, 1.0, 1.0],
|
246 |
+
[0.0, 1.0, 0.0, 1.0],
|
247 |
+
[1.0, 0.0, 0.0, 1.0],
|
248 |
+
[0.0, 1.0, 0.0, 1.0],
|
249 |
+
[1.0, 0.0, 0.0, 1.0],
|
250 |
+
]
|
251 |
+
)
|
252 |
+
return vertices, faces, face_colors
|
253 |
+
|
254 |
+
|
255 |
+
def vis_keypoints(
|
256 |
+
keypts_list,
|
257 |
+
img_size,
|
258 |
+
radius=6,
|
259 |
+
thickness=3,
|
260 |
+
kpt_score_thr=0.3,
|
261 |
+
dataset="TopDownCocoDataset",
|
262 |
+
):
|
263 |
+
"""
|
264 |
+
Visualize keypoints
|
265 |
+
From ViTPose/mmpose/apis/inference.py
|
266 |
+
"""
|
267 |
+
palette = np.array(
|
268 |
+
[
|
269 |
+
[255, 128, 0],
|
270 |
+
[255, 153, 51],
|
271 |
+
[255, 178, 102],
|
272 |
+
[230, 230, 0],
|
273 |
+
[255, 153, 255],
|
274 |
+
[153, 204, 255],
|
275 |
+
[255, 102, 255],
|
276 |
+
[255, 51, 255],
|
277 |
+
[102, 178, 255],
|
278 |
+
[51, 153, 255],
|
279 |
+
[255, 153, 153],
|
280 |
+
[255, 102, 102],
|
281 |
+
[255, 51, 51],
|
282 |
+
[153, 255, 153],
|
283 |
+
[102, 255, 102],
|
284 |
+
[51, 255, 51],
|
285 |
+
[0, 255, 0],
|
286 |
+
[0, 0, 255],
|
287 |
+
[255, 0, 0],
|
288 |
+
[255, 255, 255],
|
289 |
+
]
|
290 |
+
)
|
291 |
+
|
292 |
+
if dataset in (
|
293 |
+
"TopDownCocoDataset",
|
294 |
+
"BottomUpCocoDataset",
|
295 |
+
"TopDownOCHumanDataset",
|
296 |
+
"AnimalMacaqueDataset",
|
297 |
+
):
|
298 |
+
# show the results
|
299 |
+
skeleton = [
|
300 |
+
[15, 13],
|
301 |
+
[13, 11],
|
302 |
+
[16, 14],
|
303 |
+
[14, 12],
|
304 |
+
[11, 12],
|
305 |
+
[5, 11],
|
306 |
+
[6, 12],
|
307 |
+
[5, 6],
|
308 |
+
[5, 7],
|
309 |
+
[6, 8],
|
310 |
+
[7, 9],
|
311 |
+
[8, 10],
|
312 |
+
[1, 2],
|
313 |
+
[0, 1],
|
314 |
+
[0, 2],
|
315 |
+
[1, 3],
|
316 |
+
[2, 4],
|
317 |
+
[3, 5],
|
318 |
+
[4, 6],
|
319 |
+
]
|
320 |
+
|
321 |
+
pose_link_color = palette[
|
322 |
+
[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]
|
323 |
+
]
|
324 |
+
pose_kpt_color = palette[
|
325 |
+
[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]
|
326 |
+
]
|
327 |
+
|
328 |
+
elif dataset == "TopDownCocoWholeBodyDataset":
|
329 |
+
# show the results
|
330 |
+
skeleton = [
|
331 |
+
[15, 13],
|
332 |
+
[13, 11],
|
333 |
+
[16, 14],
|
334 |
+
[14, 12],
|
335 |
+
[11, 12],
|
336 |
+
[5, 11],
|
337 |
+
[6, 12],
|
338 |
+
[5, 6],
|
339 |
+
[5, 7],
|
340 |
+
[6, 8],
|
341 |
+
[7, 9],
|
342 |
+
[8, 10],
|
343 |
+
[1, 2],
|
344 |
+
[0, 1],
|
345 |
+
[0, 2],
|
346 |
+
[1, 3],
|
347 |
+
[2, 4],
|
348 |
+
[3, 5],
|
349 |
+
[4, 6],
|
350 |
+
[15, 17],
|
351 |
+
[15, 18],
|
352 |
+
[15, 19],
|
353 |
+
[16, 20],
|
354 |
+
[16, 21],
|
355 |
+
[16, 22],
|
356 |
+
[91, 92],
|
357 |
+
[92, 93],
|
358 |
+
[93, 94],
|
359 |
+
[94, 95],
|
360 |
+
[91, 96],
|
361 |
+
[96, 97],
|
362 |
+
[97, 98],
|
363 |
+
[98, 99],
|
364 |
+
[91, 100],
|
365 |
+
[100, 101],
|
366 |
+
[101, 102],
|
367 |
+
[102, 103],
|
368 |
+
[91, 104],
|
369 |
+
[104, 105],
|
370 |
+
[105, 106],
|
371 |
+
[106, 107],
|
372 |
+
[91, 108],
|
373 |
+
[108, 109],
|
374 |
+
[109, 110],
|
375 |
+
[110, 111],
|
376 |
+
[112, 113],
|
377 |
+
[113, 114],
|
378 |
+
[114, 115],
|
379 |
+
[115, 116],
|
380 |
+
[112, 117],
|
381 |
+
[117, 118],
|
382 |
+
[118, 119],
|
383 |
+
[119, 120],
|
384 |
+
[112, 121],
|
385 |
+
[121, 122],
|
386 |
+
[122, 123],
|
387 |
+
[123, 124],
|
388 |
+
[112, 125],
|
389 |
+
[125, 126],
|
390 |
+
[126, 127],
|
391 |
+
[127, 128],
|
392 |
+
[112, 129],
|
393 |
+
[129, 130],
|
394 |
+
[130, 131],
|
395 |
+
[131, 132],
|
396 |
+
]
|
397 |
+
|
398 |
+
pose_link_color = palette[
|
399 |
+
[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]
|
400 |
+
+ [16, 16, 16, 16, 16, 16]
|
401 |
+
+ [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
|
402 |
+
+ [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
|
403 |
+
]
|
404 |
+
pose_kpt_color = palette[
|
405 |
+
[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]
|
406 |
+
+ [0, 0, 0, 0, 0, 0]
|
407 |
+
+ [19] * (68 + 42)
|
408 |
+
]
|
409 |
+
|
410 |
+
elif dataset == "TopDownAicDataset":
|
411 |
+
skeleton = [
|
412 |
+
[2, 1],
|
413 |
+
[1, 0],
|
414 |
+
[0, 13],
|
415 |
+
[13, 3],
|
416 |
+
[3, 4],
|
417 |
+
[4, 5],
|
418 |
+
[8, 7],
|
419 |
+
[7, 6],
|
420 |
+
[6, 9],
|
421 |
+
[9, 10],
|
422 |
+
[10, 11],
|
423 |
+
[12, 13],
|
424 |
+
[0, 6],
|
425 |
+
[3, 9],
|
426 |
+
]
|
427 |
+
|
428 |
+
pose_link_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 0, 7, 7]]
|
429 |
+
pose_kpt_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 0, 0]]
|
430 |
+
|
431 |
+
elif dataset == "TopDownMpiiDataset":
|
432 |
+
skeleton = [
|
433 |
+
[0, 1],
|
434 |
+
[1, 2],
|
435 |
+
[2, 6],
|
436 |
+
[6, 3],
|
437 |
+
[3, 4],
|
438 |
+
[4, 5],
|
439 |
+
[6, 7],
|
440 |
+
[7, 8],
|
441 |
+
[8, 9],
|
442 |
+
[8, 12],
|
443 |
+
[12, 11],
|
444 |
+
[11, 10],
|
445 |
+
[8, 13],
|
446 |
+
[13, 14],
|
447 |
+
[14, 15],
|
448 |
+
]
|
449 |
+
|
450 |
+
pose_link_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 9, 9, 9, 9, 9, 9]]
|
451 |
+
pose_kpt_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 0, 9, 9, 9, 9, 9, 9]]
|
452 |
+
|
453 |
+
elif dataset == "TopDownMpiiTrbDataset":
|
454 |
+
skeleton = [
|
455 |
+
[12, 13],
|
456 |
+
[13, 0],
|
457 |
+
[13, 1],
|
458 |
+
[0, 2],
|
459 |
+
[1, 3],
|
460 |
+
[2, 4],
|
461 |
+
[3, 5],
|
462 |
+
[0, 6],
|
463 |
+
[1, 7],
|
464 |
+
[6, 7],
|
465 |
+
[6, 8],
|
466 |
+
[7, 9],
|
467 |
+
[8, 10],
|
468 |
+
[9, 11],
|
469 |
+
[14, 15],
|
470 |
+
[16, 17],
|
471 |
+
[18, 19],
|
472 |
+
[20, 21],
|
473 |
+
[22, 23],
|
474 |
+
[24, 25],
|
475 |
+
[26, 27],
|
476 |
+
[28, 29],
|
477 |
+
[30, 31],
|
478 |
+
[32, 33],
|
479 |
+
[34, 35],
|
480 |
+
[36, 37],
|
481 |
+
[38, 39],
|
482 |
+
]
|
483 |
+
|
484 |
+
pose_link_color = palette[[16] * 14 + [19] * 13]
|
485 |
+
pose_kpt_color = palette[[16] * 14 + [0] * 26]
|
486 |
+
|
487 |
+
elif dataset in ("OneHand10KDataset", "FreiHandDataset", "PanopticDataset"):
|
488 |
+
skeleton = [
|
489 |
+
[0, 1],
|
490 |
+
[1, 2],
|
491 |
+
[2, 3],
|
492 |
+
[3, 4],
|
493 |
+
[0, 5],
|
494 |
+
[5, 6],
|
495 |
+
[6, 7],
|
496 |
+
[7, 8],
|
497 |
+
[0, 9],
|
498 |
+
[9, 10],
|
499 |
+
[10, 11],
|
500 |
+
[11, 12],
|
501 |
+
[0, 13],
|
502 |
+
[13, 14],
|
503 |
+
[14, 15],
|
504 |
+
[15, 16],
|
505 |
+
[0, 17],
|
506 |
+
[17, 18],
|
507 |
+
[18, 19],
|
508 |
+
[19, 20],
|
509 |
+
]
|
510 |
+
|
511 |
+
pose_link_color = palette[
|
512 |
+
[0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
|
513 |
+
]
|
514 |
+
pose_kpt_color = palette[
|
515 |
+
[0, 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
|
516 |
+
]
|
517 |
+
|
518 |
+
elif dataset == "InterHand2DDataset":
|
519 |
+
skeleton = [
|
520 |
+
[0, 1],
|
521 |
+
[1, 2],
|
522 |
+
[2, 3],
|
523 |
+
[4, 5],
|
524 |
+
[5, 6],
|
525 |
+
[6, 7],
|
526 |
+
[8, 9],
|
527 |
+
[9, 10],
|
528 |
+
[10, 11],
|
529 |
+
[12, 13],
|
530 |
+
[13, 14],
|
531 |
+
[14, 15],
|
532 |
+
[16, 17],
|
533 |
+
[17, 18],
|
534 |
+
[18, 19],
|
535 |
+
[3, 20],
|
536 |
+
[7, 20],
|
537 |
+
[11, 20],
|
538 |
+
[15, 20],
|
539 |
+
[19, 20],
|
540 |
+
]
|
541 |
+
|
542 |
+
pose_link_color = palette[
|
543 |
+
[0, 0, 0, 4, 4, 4, 8, 8, 8, 12, 12, 12, 16, 16, 16, 0, 4, 8, 12, 16]
|
544 |
+
]
|
545 |
+
pose_kpt_color = palette[
|
546 |
+
[0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16, 0]
|
547 |
+
]
|
548 |
+
|
549 |
+
elif dataset == "Face300WDataset":
|
550 |
+
# show the results
|
551 |
+
skeleton = []
|
552 |
+
|
553 |
+
pose_link_color = palette[[]]
|
554 |
+
pose_kpt_color = palette[[19] * 68]
|
555 |
+
kpt_score_thr = 0
|
556 |
+
|
557 |
+
elif dataset == "FaceAFLWDataset":
|
558 |
+
# show the results
|
559 |
+
skeleton = []
|
560 |
+
|
561 |
+
pose_link_color = palette[[]]
|
562 |
+
pose_kpt_color = palette[[19] * 19]
|
563 |
+
kpt_score_thr = 0
|
564 |
+
|
565 |
+
elif dataset == "FaceCOFWDataset":
|
566 |
+
# show the results
|
567 |
+
skeleton = []
|
568 |
+
|
569 |
+
pose_link_color = palette[[]]
|
570 |
+
pose_kpt_color = palette[[19] * 29]
|
571 |
+
kpt_score_thr = 0
|
572 |
+
|
573 |
+
elif dataset == "FaceWFLWDataset":
|
574 |
+
# show the results
|
575 |
+
skeleton = []
|
576 |
+
|
577 |
+
pose_link_color = palette[[]]
|
578 |
+
pose_kpt_color = palette[[19] * 98]
|
579 |
+
kpt_score_thr = 0
|
580 |
+
|
581 |
+
elif dataset == "AnimalHorse10Dataset":
|
582 |
+
skeleton = [
|
583 |
+
[0, 1],
|
584 |
+
[1, 12],
|
585 |
+
[12, 16],
|
586 |
+
[16, 21],
|
587 |
+
[21, 17],
|
588 |
+
[17, 11],
|
589 |
+
[11, 10],
|
590 |
+
[10, 8],
|
591 |
+
[8, 9],
|
592 |
+
[9, 12],
|
593 |
+
[2, 3],
|
594 |
+
[3, 4],
|
595 |
+
[5, 6],
|
596 |
+
[6, 7],
|
597 |
+
[13, 14],
|
598 |
+
[14, 15],
|
599 |
+
[18, 19],
|
600 |
+
[19, 20],
|
601 |
+
]
|
602 |
+
|
603 |
+
pose_link_color = palette[[4] * 10 + [6] * 2 + [6] * 2 + [7] * 2 + [7] * 2]
|
604 |
+
pose_kpt_color = palette[
|
605 |
+
[4, 4, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 4, 7, 7, 7, 4, 4, 7, 7, 7, 4]
|
606 |
+
]
|
607 |
+
|
608 |
+
elif dataset == "AnimalFlyDataset":
|
609 |
+
skeleton = [
|
610 |
+
[1, 0],
|
611 |
+
[2, 0],
|
612 |
+
[3, 0],
|
613 |
+
[4, 3],
|
614 |
+
[5, 4],
|
615 |
+
[7, 6],
|
616 |
+
[8, 7],
|
617 |
+
[9, 8],
|
618 |
+
[11, 10],
|
619 |
+
[12, 11],
|
620 |
+
[13, 12],
|
621 |
+
[15, 14],
|
622 |
+
[16, 15],
|
623 |
+
[17, 16],
|
624 |
+
[19, 18],
|
625 |
+
[20, 19],
|
626 |
+
[21, 20],
|
627 |
+
[23, 22],
|
628 |
+
[24, 23],
|
629 |
+
[25, 24],
|
630 |
+
[27, 26],
|
631 |
+
[28, 27],
|
632 |
+
[29, 28],
|
633 |
+
[30, 3],
|
634 |
+
[31, 3],
|
635 |
+
]
|
636 |
+
|
637 |
+
pose_link_color = palette[[0] * 25]
|
638 |
+
pose_kpt_color = palette[[0] * 32]
|
639 |
+
|
640 |
+
elif dataset == "AnimalLocustDataset":
|
641 |
+
skeleton = [
|
642 |
+
[1, 0],
|
643 |
+
[2, 1],
|
644 |
+
[3, 2],
|
645 |
+
[4, 3],
|
646 |
+
[6, 5],
|
647 |
+
[7, 6],
|
648 |
+
[9, 8],
|
649 |
+
[10, 9],
|
650 |
+
[11, 10],
|
651 |
+
[13, 12],
|
652 |
+
[14, 13],
|
653 |
+
[15, 14],
|
654 |
+
[17, 16],
|
655 |
+
[18, 17],
|
656 |
+
[19, 18],
|
657 |
+
[21, 20],
|
658 |
+
[22, 21],
|
659 |
+
[24, 23],
|
660 |
+
[25, 24],
|
661 |
+
[26, 25],
|
662 |
+
[28, 27],
|
663 |
+
[29, 28],
|
664 |
+
[30, 29],
|
665 |
+
[32, 31],
|
666 |
+
[33, 32],
|
667 |
+
[34, 33],
|
668 |
+
]
|
669 |
+
|
670 |
+
pose_link_color = palette[[0] * 26]
|
671 |
+
pose_kpt_color = palette[[0] * 35]
|
672 |
+
|
673 |
+
elif dataset == "AnimalZebraDataset":
|
674 |
+
skeleton = [[1, 0], [2, 1], [3, 2], [4, 2], [5, 7], [6, 7], [7, 2], [8, 7]]
|
675 |
+
|
676 |
+
pose_link_color = palette[[0] * 8]
|
677 |
+
pose_kpt_color = palette[[0] * 9]
|
678 |
+
|
679 |
+
elif dataset in "AnimalPoseDataset":
|
680 |
+
skeleton = [
|
681 |
+
[0, 1],
|
682 |
+
[0, 2],
|
683 |
+
[1, 3],
|
684 |
+
[0, 4],
|
685 |
+
[1, 4],
|
686 |
+
[4, 5],
|
687 |
+
[5, 7],
|
688 |
+
[6, 7],
|
689 |
+
[5, 8],
|
690 |
+
[8, 12],
|
691 |
+
[12, 16],
|
692 |
+
[5, 9],
|
693 |
+
[9, 13],
|
694 |
+
[13, 17],
|
695 |
+
[6, 10],
|
696 |
+
[10, 14],
|
697 |
+
[14, 18],
|
698 |
+
[6, 11],
|
699 |
+
[11, 15],
|
700 |
+
[15, 19],
|
701 |
+
]
|
702 |
+
|
703 |
+
pose_link_color = palette[[0] * 20]
|
704 |
+
pose_kpt_color = palette[[0] * 20]
|
705 |
+
else:
|
706 |
+
NotImplementedError()
|
707 |
+
|
708 |
+
img_w, img_h = img_size
|
709 |
+
img = 255 * np.ones((img_h, img_w, 3), dtype=np.uint8)
|
710 |
+
img = imshow_keypoints(
|
711 |
+
img,
|
712 |
+
keypts_list,
|
713 |
+
skeleton,
|
714 |
+
kpt_score_thr,
|
715 |
+
pose_kpt_color,
|
716 |
+
pose_link_color,
|
717 |
+
radius,
|
718 |
+
thickness,
|
719 |
+
)
|
720 |
+
alpha = 255 * (img != 255).any(axis=-1, keepdims=True).astype(np.uint8)
|
721 |
+
return np.concatenate([img, alpha], axis=-1)
|
722 |
+
|
723 |
+
|
724 |
+
def imshow_keypoints(
|
725 |
+
img,
|
726 |
+
pose_result,
|
727 |
+
skeleton=None,
|
728 |
+
kpt_score_thr=0.3,
|
729 |
+
pose_kpt_color=None,
|
730 |
+
pose_link_color=None,
|
731 |
+
radius=4,
|
732 |
+
thickness=1,
|
733 |
+
show_keypoint_weight=False,
|
734 |
+
):
|
735 |
+
"""Draw keypoints and links on an image.
|
736 |
+
From ViTPose/mmpose/core/visualization/image.py
|
737 |
+
|
738 |
+
Args:
|
739 |
+
img (H, W, 3) array
|
740 |
+
pose_result (list[kpts]): The poses to draw. Each element kpts is
|
741 |
+
a set of K keypoints as an Kx3 numpy.ndarray, where each
|
742 |
+
keypoint is represented as x, y, score.
|
743 |
+
kpt_score_thr (float, optional): Minimum score of keypoints
|
744 |
+
to be shown. Default: 0.3.
|
745 |
+
pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
|
746 |
+
the keypoint will not be drawn.
|
747 |
+
pose_link_color (np.array[Mx3]): Color of M links. If None, the
|
748 |
+
links will not be drawn.
|
749 |
+
thickness (int): Thickness of lines.
|
750 |
+
show_keypoint_weight (bool): If True, opacity indicates keypoint score
|
751 |
+
"""
|
752 |
+
img_h, img_w, _ = img.shape
|
753 |
+
idcs = [0, 16, 15, 18, 17, 5, 2, 6, 3, 7, 4, 12, 9, 13, 10, 14, 11]
|
754 |
+
for kpts in pose_result:
|
755 |
+
kpts = np.array(kpts, copy=False)[idcs]
|
756 |
+
|
757 |
+
# draw each point on image
|
758 |
+
if pose_kpt_color is not None:
|
759 |
+
assert len(pose_kpt_color) == len(kpts)
|
760 |
+
for kid, kpt in enumerate(kpts):
|
761 |
+
x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
|
762 |
+
if kpt_score > kpt_score_thr:
|
763 |
+
color = tuple(int(c) for c in pose_kpt_color[kid])
|
764 |
+
if show_keypoint_weight:
|
765 |
+
img_copy = img.copy()
|
766 |
+
cv2.circle(
|
767 |
+
img_copy, (int(x_coord), int(y_coord)), radius, color, -1
|
768 |
+
)
|
769 |
+
transparency = max(0, min(1, kpt_score))
|
770 |
+
cv2.addWeighted(
|
771 |
+
img_copy, transparency, img, 1 - transparency, 0, dst=img
|
772 |
+
)
|
773 |
+
else:
|
774 |
+
cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
|
775 |
+
|
776 |
+
# draw links
|
777 |
+
if skeleton is not None and pose_link_color is not None:
|
778 |
+
assert len(pose_link_color) == len(skeleton)
|
779 |
+
for sk_id, sk in enumerate(skeleton):
|
780 |
+
pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
|
781 |
+
pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
|
782 |
+
if (
|
783 |
+
pos1[0] > 0
|
784 |
+
and pos1[0] < img_w
|
785 |
+
and pos1[1] > 0
|
786 |
+
and pos1[1] < img_h
|
787 |
+
and pos2[0] > 0
|
788 |
+
and pos2[0] < img_w
|
789 |
+
and pos2[1] > 0
|
790 |
+
and pos2[1] < img_h
|
791 |
+
and kpts[sk[0], 2] > kpt_score_thr
|
792 |
+
and kpts[sk[1], 2] > kpt_score_thr
|
793 |
+
):
|
794 |
+
color = tuple(int(c) for c in pose_link_color[sk_id])
|
795 |
+
if show_keypoint_weight:
|
796 |
+
img_copy = img.copy()
|
797 |
+
X = (pos1[0], pos2[0])
|
798 |
+
Y = (pos1[1], pos2[1])
|
799 |
+
mX = np.mean(X)
|
800 |
+
mY = np.mean(Y)
|
801 |
+
length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
|
802 |
+
angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
|
803 |
+
stickwidth = 2
|
804 |
+
polygon = cv2.ellipse2Poly(
|
805 |
+
(int(mX), int(mY)),
|
806 |
+
(int(length / 2), int(stickwidth)),
|
807 |
+
int(angle),
|
808 |
+
0,
|
809 |
+
360,
|
810 |
+
1,
|
811 |
+
)
|
812 |
+
cv2.fillConvexPoly(img_copy, polygon, color)
|
813 |
+
transparency = max(
|
814 |
+
0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))
|
815 |
+
)
|
816 |
+
cv2.addWeighted(
|
817 |
+
img_copy, transparency, img, 1 - transparency, 0, dst=img
|
818 |
+
)
|
819 |
+
else:
|
820 |
+
cv2.line(img, pos1, pos2, color, thickness=thickness)
|
821 |
+
|
822 |
+
return img
|
output/demo/test19/output.mp4
ADDED
Binary file (602 kB). View file
|
|
output/demo/test19/slam_results.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eb6e0b47809fe94bdc26bc99318f9eb9beccb005ec81c407887a5bd7223b5b81
|
3 |
+
size 2353
|
output/demo/test19/tracking_results.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3d1b3d6e23597e07daaa1b124cba63a1ea91d3909fe2c903c3c9b2b2819ce140
|
3 |
+
size 333898
|
output/demo/test19/wham_output.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1484cca6cf2774c3c0cbefa3d47ed7f2dd04db96b05ebdd14e5a11610a415b3e
|
3 |
+
size 3167067
|