Spaces: Build error

Commit 523c2b9 · Duplicate from YoonaAI/yoonaAvatarSpace

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitignore +18 -0
- assets/garment_teaser.png +3 -0
- assets/intermediate_results.png +3 -0
- assets/teaser.gif +3 -0
- assets/thumbnail.png +3 -0
- examples/22097467bffc92d4a5c4246f7d4edb75.png +3 -0
- examples/44c0f84c957b6b9bdf77662af5bb7078.png +3 -0
- examples/5a6a25963db2f667441d5076972c207c.png +3 -0
- examples/8da7ceb94669c2f65cbd28022e1f9876.png +3 -0
- examples/923d65f767c85a42212cae13fba3750b.png +3 -0
- examples/c9856a2bc31846d684cbb965457fad59.png +3 -0
- examples/e1e7622af7074a022f5d96dc16672517.png +3 -0
- examples/fb9d20fdb93750584390599478ecf86e.png +3 -0
- .gitattributes +37 -0
- README.md +14 -0
- app.py +144 -0
- apps/ICON.py +735 -0
- apps/Normal.py +220 -0
- apps/infer.py +492 -0
- configs/icon-filter.yaml +25 -0
- configs/icon-nofilter.yaml +25 -0
- configs/pamir.yaml +24 -0
- configs/pifu.yaml +24 -0
- lib/pymaf/configs/pymaf_config.yaml +47 -0
- lib/pymaf/core/__init__.py +0 -0
- lib/pymaf/core/train_options.py +135 -0
- lib/pymaf/core/base_trainer.py +107 -0
- lib/pymaf/core/cfgs.py +100 -0
- lib/pymaf/core/constants.py +153 -0
- lib/pymaf/core/fits_dict.py +133 -0
- lib/pymaf/core/path_config.py +24 -0
- lib/pymaf/models/__init__.py +3 -0
- lib/pymaf/models/pymaf_net.py +362 -0
- lib/pymaf/models/smpl.py +92 -0
- lib/pymaf/models/hmr.py +303 -0
- lib/pymaf/models/maf_extractor.py +135 -0
- lib/pymaf/models/res_module.py +385 -0
- lib/pymaf/utils/__init__.py +0 -0
- lib/pymaf/utils/geometry.py +435 -0
- lib/pymaf/utils/imutils.py +491 -0
- lib/pymaf/utils/streamer.py +142 -0
- lib/pymaf/utils/transforms.py +78 -0
- lib/renderer/__init__.py +0 -0
- lib/renderer/camera.py +226 -0
- lib/renderer/gl/__init__.py +0 -0
- lib/renderer/gl/data/color.fs +20 -0
- lib/renderer/gl/data/color.vs +29 -0
- lib/renderer/gl/data/normal.fs +12 -0
- lib/renderer/gl/data/normal.vs +15 -0
- lib/renderer/gl/data/prt.fs +157 -0
.gitignore
ADDED
@@ -0,0 +1,18 @@
+data/*/*
+data/thuman*
+!data/tbfo.ttf
+__pycache__
+debug/
+log/
+.vscode
+!.gitignore
+force_push.sh
+.idea
+human_det/
+kaolin/
+neural_voxelization_layer/
+pytorch3d/
+force_push.sh
+results/
+gradio_cached_examples/
+gradio_queue.db
assets/garment_teaser.png
ADDED
(binary file, stored with Git LFS)
assets/intermediate_results.png
ADDED
(binary file, stored with Git LFS)
assets/teaser.gif
ADDED
(binary file, stored with Git LFS)
assets/thumbnail.png
ADDED
(binary file, stored with Git LFS)
examples/22097467bffc92d4a5c4246f7d4edb75.png
ADDED
(binary file, stored with Git LFS)
examples/44c0f84c957b6b9bdf77662af5bb7078.png
ADDED
(binary file, stored with Git LFS)
examples/5a6a25963db2f667441d5076972c207c.png
ADDED
(binary file, stored with Git LFS)
examples/8da7ceb94669c2f65cbd28022e1f9876.png
ADDED
(binary file, stored with Git LFS)
examples/923d65f767c85a42212cae13fba3750b.png
ADDED
(binary file, stored with Git LFS)
examples/c9856a2bc31846d684cbb965457fad59.png
ADDED
(binary file, stored with Git LFS)
examples/e1e7622af7074a022f5d96dc16672517.png
ADDED
(binary file, stored with Git LFS)
examples/fb9d20fdb93750584390599478ecf86e.png
ADDED
(binary file, stored with Git LFS)
.gitattributes
ADDED
@@ -0,0 +1,37 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zstandard filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.obj filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.glb filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: YoonaAvatar
+sdk: gradio
+emoji: 🔥
+colorFrom: red
+colorTo: purple
+sdk_version: 3.2
+app_file: app.py
+pinned: false
+python_version: 3.8.13
+duplicated_from: YoonaAI/yoonaAvatarSpace
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,144 @@
+# install
+
+
+import glob
+import gradio as gr
+import os
+import numpy as np
+
+import subprocess
+
+if os.getenv('SYSTEM') == 'spaces':
+    subprocess.run('pip install pyembree'.split())
+    subprocess.run(
+        'pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html'.split())
+    subprocess.run(
+        'pip install https://download.is.tue.mpg.de/icon/HF/kaolin-0.11.0-cp38-cp38-linux_x86_64.whl'.split())
+    subprocess.run(
+        'pip install https://download.is.tue.mpg.de/icon/HF/pytorch3d-0.7.0-cp38-cp38-linux_x86_64.whl'.split())
+    subprocess.run(
+        'pip install git+https://github.com/YuliangXiu/neural_voxelization_layer.git'.split())
+
+from apps.infer import generate_model
+
+# running
+
+description = '''
+# ICON Clothed Human Digitization
+### ICON: Implicit Clothed humans Obtained from Normals (CVPR 2022)
+<table>
+<th>
+<ul>
+<li><strong>Homepage</strong> <a href="http://icon.is.tue.mpg.de">icon.is.tue.mpg.de</a></li>
+<li><strong>Code</strong> <a href="https://github.com/YuliangXiu/ICON">YuliangXiu/ICON</a></li>
+<li><strong>Paper</strong> <a href="https://arxiv.org/abs/2112.09127">arXiv</a>, <a href="https://readpaper.com/paper/4569785684533977089">ReadPaper</a></li>
+<li><strong>Chatroom</strong> <a href="https://discord.gg/Vqa7KBGRyk">Discord</a></li>
+<li><strong>Colab Notebook</strong> <a href="https://colab.research.google.com/drive/1-AWeWhPvCTBX0KfMtgtMk10uPU05ihoA?usp=sharing">Google Colab</a></li>
+</ul>
+<a href="https://twitter.com/yuliangxiu"><img alt="Twitter Follow" src="https://img.shields.io/twitter/follow/yuliangxiu?style=social"></a>
+<iframe src="https://ghbtns.com/github-btn.html?user=yuliangxiu&repo=ICON&type=star&count=true&v=2&size=small" frameborder="0" scrolling="0" width="100" height="20"></iframe>
+<a href="https://youtu.be/hZd6AYin2DE"><img alt="YouTube Video Views" src="https://img.shields.io/youtube/views/hZd6AYin2DE?style=social"></a>
+</th>
+<th>
+<iframe width="560" height="315" src="https://www.youtube.com/embed/hZd6AYin2DE" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
+</th>
+</table>
+<h4> The reconstruction + refinement + video takes about 200 seconds for a single image. <span style="color:red"> If ERROR, try "Submit Image" again.</span></h4>
+<details>
+<summary>More</summary>
+#### Citation
+```
+@inproceedings{xiu2022icon,
+  title     = {{ICON}: {I}mplicit {C}lothed humans {O}btained from {N}ormals},
+  author    = {Xiu, Yuliang and Yang, Jinlong and Tzionas, Dimitrios and Black, Michael J.},
+  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+  month     = {June},
+  year      = {2022},
+  pages     = {13296-13306}
+}
+```
+#### Acknowledgments:
+- [StyleGAN-Human, ECCV 2022](https://stylegan-human.github.io/)
+- [nagolinc/styleGanHuman_and_PIFu](https://huggingface.co/spaces/nagolinc/styleGanHuman_and_PIFu)
+- [radames/PIFu-Clothed-Human-Digitization](https://huggingface.co/spaces/radames/PIFu-Clothed-Human-Digitization)
+#### Image Credits
+* [Pinterest](https://www.pinterest.com/search/pins/?q=parkour&rs=sitelinks_searchbox)
+#### Related works
+* [ICON @ MPI](https://icon.is.tue.mpg.de/)
+* [MonoPort @ USC](https://xiuyuliang.cn/monoport)
+* [Phorhum @ Google](https://phorhum.github.io/)
+* [PIFuHD @ Meta](https://shunsukesaito.github.io/PIFuHD/)
+* [PaMIR @ Tsinghua](http://www.liuyebin.com/pamir/pamir.html)
+</details>
+'''
+
+
+def generate_image(seed, psi):
+    iface = gr.Interface.load("spaces/hysts/StyleGAN-Human")
+    img = iface(seed, psi)
+    return img
+
+
+model_types = ['ICON', 'PIFu', 'PaMIR']
+examples_names = glob.glob('examples/*.png')
+examples_types = np.random.choice(
+    model_types, len(examples_names), p=[0.6, 0.2, 0.2])
+
+examples = [list(item) for item in zip(examples_names, examples_types)]
+
+with gr.Blocks() as demo:
+    gr.Markdown(description)
+
+    out_lst = []
+    with gr.Row():
+        with gr.Column():
+            with gr.Row():
+                with gr.Column():
+                    seed = gr.inputs.Slider(
+                        0, 1000, step=1, default=0, label='Seed (For Image Generation)')
+                    psi = gr.inputs.Slider(
+                        0, 2, step=0.05, default=0.7, label='Truncation psi (For Image Generation)')
+                    radio_choice = gr.Radio(
+                        model_types, label='Method (For Reconstruction)', value='icon-filter')
+                    inp = gr.Image(type="filepath", label="Input Image")
+                    with gr.Row():
+                        btn_sample = gr.Button("Generate Image")
+                        btn_submit = gr.Button("Submit Image")
+
+            gr.Examples(examples=examples,
+                        inputs=[inp, radio_choice],
+                        cache_examples=False,
+                        fn=generate_model,
+                        outputs=out_lst)
+
+            out_vid = gr.Video(
+                label="Image + Normal + SMPL Body + Clothed Human")
+            out_vid_download = gr.File(
+                label="Download Video, welcome to share on Twitter with #ICON")
+
+        with gr.Column():
+            overlap_inp = gr.Image(
+                type="filepath", label="Image Normal Overlap")
+            out_final = gr.Model3D(
+                clear_color=[0.0, 0.0, 0.0, 0.0], label="Clothed human")
+            out_final_download = gr.File(
+                label="Download clothed human mesh")
+            out_smpl = gr.Model3D(
+                clear_color=[0.0, 0.0, 0.0, 0.0], label="SMPL body")
+            out_smpl_download = gr.File(label="Download SMPL body mesh")
+            out_smpl_npy_download = gr.File(label="Download SMPL params")
+
+    out_lst = [out_smpl, out_smpl_download, out_smpl_npy_download,
+               out_final, out_final_download, out_vid, out_vid_download, overlap_inp]
+
+    btn_submit.click(fn=generate_model, inputs=[
+                     inp, radio_choice], outputs=out_lst)
+    btn_sample.click(fn=generate_image, inputs=[seed, psi], outputs=inp)
+
+if __name__ == "__main__":
+
+    # demo.launch(debug=False, enable_queue=False,
+    #             auth=(os.environ['USER'], os.environ['PASSWORD']),
+    #             auth_message="Register at icon.is.tue.mpg.de to get HuggingFace username and password.")
+
+    demo.launch(debug=True, enable_queue=True)
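For orientation, here is a minimal sketch (not part of this commit) of how the `generate_model` callback wired above could be called outside the Gradio UI. The example path is one of the bundled images, and the unpacking assumes the eight return values follow the `out_lst` ordering defined in app.py:

```python
# Hypothetical standalone call, assuming the Space's dependencies and
# model checkpoints are available locally (see the pip installs above).
from apps.infer import generate_model

results = generate_model("examples/22097467bffc92d4a5c4246f7d4edb75.png", "ICON")

# Assumed to mirror out_lst: SMPL mesh + files, clothed mesh + file, video + file, overlap image.
(smpl_mesh, smpl_mesh_file, smpl_npy_file,
 clothed_mesh, clothed_mesh_file,
 video, video_file, overlap_image) = results
```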
apps/ICON.py
ADDED
@@ -0,0 +1,735 @@
+# -*- coding: utf-8 -*-
+
+# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+# holder of all proprietary rights on this computer program.
+# You can only use this computer program if you have closed
+# a license agreement with MPG or you get the right to use the computer
+# program from someone who is authorized to grant you that right.
+# Any use of the computer program without a valid license is prohibited and
+# liable to prosecution.
+#
+# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+# for Intelligent Systems. All rights reserved.
+#
+# Contact: [email protected]
+
+from lib.common.seg3d_lossless import Seg3dLossless
+from lib.dataset.Evaluator import Evaluator
+from lib.net import HGPIFuNet
+from lib.common.train_util import *
+from lib.common.render import Render
+from lib.dataset.mesh_util import SMPLX, update_mesh_shape_prior_losses, get_visibility
+import torch
+import lib.smplx as smplx
+import numpy as np
+from torch import nn
+from skimage.transform import resize
+import pytorch_lightning as pl
+
+torch.backends.cudnn.benchmark = True
+
+
+class ICON(pl.LightningModule):
+
+    def __init__(self, cfg):
+        super(ICON, self).__init__()
+
+        self.cfg = cfg
+        self.batch_size = self.cfg.batch_size
+        self.lr_G = self.cfg.lr_G
+
+        self.use_sdf = cfg.sdf
+        self.prior_type = cfg.net.prior_type
+        self.mcube_res = cfg.mcube_res
+        self.clean_mesh_flag = cfg.clean_mesh
+
+        self.netG = HGPIFuNet(
+            self.cfg,
+            self.cfg.projection_mode,
+            error_term=nn.SmoothL1Loss() if self.use_sdf else nn.MSELoss(),
+        )
+
+        self.evaluator = Evaluator(
+            device=torch.device(f"cuda:{self.cfg.gpus[0]}"))
+
+        self.resolutions = (np.logspace(
+            start=5,
+            stop=np.log2(self.mcube_res),
+            base=2,
+            num=int(np.log2(self.mcube_res) - 4),
+            endpoint=True,
+        ) + 1.0)
+        self.resolutions = self.resolutions.astype(np.int16).tolist()
+
+        self.base_keys = ["smpl_verts", "smpl_faces"]
+        self.feat_names = self.cfg.net.smpl_feats
+
+        self.icon_keys = self.base_keys + [
+            f"smpl_{feat_name}" for feat_name in self.feat_names
+        ]
+        self.keypoint_keys = self.base_keys + [
+            f"smpl_{feat_name}" for feat_name in self.feat_names
+        ]
+        self.pamir_keys = [
+            "voxel_verts", "voxel_faces", "pad_v_num", "pad_f_num"
+        ]
+        self.pifu_keys = []
+
+        self.reconEngine = Seg3dLossless(
+            query_func=query_func,
+            b_min=[[-1.0, 1.0, -1.0]],
+            b_max=[[1.0, -1.0, 1.0]],
+            resolutions=self.resolutions,
+            align_corners=True,
+            balance_value=0.50,
+            device=torch.device(f"cuda:{self.cfg.test_gpus[0]}"),
+            visualize=False,
+            debug=False,
+            use_cuda_impl=False,
+            faster=True,
+        )
+
+        self.render = Render(
+            size=512, device=torch.device(f"cuda:{self.cfg.test_gpus[0]}"))
+        self.smpl_data = SMPLX()
+
+        self.get_smpl_model = lambda smpl_type, gender, age, v_template: smplx.create(
+            self.smpl_data.model_dir,
+            kid_template_path=osp.join(
+                osp.realpath(self.smpl_data.model_dir),
+                f"{smpl_type}/{smpl_type}_kid_template.npy",
+            ),
+            model_type=smpl_type,
+            gender=gender,
+            age=age,
+            v_template=v_template,
+            use_face_contour=False,
+            ext="pkl",
+        )
+
+        self.in_geo = [item[0] for item in cfg.net.in_geo]
+        self.in_nml = [item[0] for item in cfg.net.in_nml]
+        self.in_geo_dim = [item[1] for item in cfg.net.in_geo]
+        self.in_total = self.in_geo + self.in_nml
+        self.smpl_dim = cfg.net.smpl_dim
+
+        self.export_dir = None
+        self.result_eval = {}
+
+    def get_progress_bar_dict(self):
+        tqdm_dict = super().get_progress_bar_dict()
+        if "v_num" in tqdm_dict:
+            del tqdm_dict["v_num"]
+        return tqdm_dict
+
+    # Training related
+    def configure_optimizers(self):
+
+        # set optimizer
+        weight_decay = self.cfg.weight_decay
+        momentum = self.cfg.momentum
+
+        optim_params_G = [{
+            "params": self.netG.if_regressor.parameters(),
+            "lr": self.lr_G
+        }]
+
+        if self.cfg.net.use_filter:
+            optim_params_G.append({
+                "params": self.netG.F_filter.parameters(),
+                "lr": self.lr_G
+            })
+
+        if self.cfg.net.prior_type == "pamir":
+            optim_params_G.append({
+                "params": self.netG.ve.parameters(),
+                "lr": self.lr_G
+            })
+
+        if self.cfg.optim == "Adadelta":
+
+            optimizer_G = torch.optim.Adadelta(optim_params_G,
+                                               lr=self.lr_G,
+                                               weight_decay=weight_decay)
+
+        elif self.cfg.optim == "Adam":
+
+            optimizer_G = torch.optim.Adam(optim_params_G,
+                                           lr=self.lr_G,
+                                           weight_decay=weight_decay)
+
+        elif self.cfg.optim == "RMSprop":
+
+            optimizer_G = torch.optim.RMSprop(
+                optim_params_G,
+                lr=self.lr_G,
+                weight_decay=weight_decay,
+                momentum=momentum,
+            )
+
+        else:
+            raise NotImplementedError
+
+        # set scheduler
+        scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma)
+
+        return [optimizer_G], [scheduler_G]
+
+    def training_step(self, batch, batch_idx):
+
+        if not self.cfg.fast_dev:
+            export_cfg(self.logger, self.cfg)
+
+        self.netG.train()
+
+        in_tensor_dict = {
+            "sample": batch["samples_geo"].permute(0, 2, 1),
+            "calib": batch["calib"],
+            "label": batch["labels_geo"].unsqueeze(1),
+        }
+
+        for name in self.in_total:
+            in_tensor_dict.update({name: batch[name]})
+
+        in_tensor_dict.update({
+            k: batch[k] if k in batch.keys() else None
+            for k in getattr(self, f"{self.prior_type}_keys")
+        })
+
+        preds_G, error_G = self.netG(in_tensor_dict)
+
+        acc, iou, prec, recall = self.evaluator.calc_acc(
+            preds_G.flatten(),
+            in_tensor_dict["label"].flatten(),
+            0.5,
+            use_sdf=self.cfg.sdf,
+        )
+
+        # metrics processing
+        metrics_log = {
+            "train_loss": error_G.item(),
+            "train_acc": acc.item(),
+            "train_iou": iou.item(),
+            "train_prec": prec.item(),
+            "train_recall": recall.item(),
+        }
+
+        tf_log = tf_log_convert(metrics_log)
+        bar_log = bar_log_convert(metrics_log)
+
+        if batch_idx % int(self.cfg.freq_show_train) == 0:
+
+            with torch.no_grad():
+                self.render_func(in_tensor_dict, dataset="train")
+
+        metrics_return = {
+            k.replace("train_", ""): torch.tensor(v)
+            for k, v in metrics_log.items()
+        }
+
+        metrics_return.update({
+            "loss": error_G,
+            "log": tf_log,
+            "progress_bar": bar_log
+        })
+
+        return metrics_return
+
+    def training_epoch_end(self, outputs):
+
+        if [] in outputs:
+            outputs = outputs[0]
+
+        # metrics processing
+        metrics_log = {
+            "train_avgloss": batch_mean(outputs, "loss"),
+            "train_avgiou": batch_mean(outputs, "iou"),
+            "train_avgprec": batch_mean(outputs, "prec"),
+            "train_avgrecall": batch_mean(outputs, "recall"),
+            "train_avgacc": batch_mean(outputs, "acc"),
+        }
+
+        tf_log = tf_log_convert(metrics_log)
+
+        return {"log": tf_log}
+
+    def validation_step(self, batch, batch_idx):
+
+        self.netG.eval()
+        self.netG.training = False
+
+        in_tensor_dict = {
+            "sample": batch["samples_geo"].permute(0, 2, 1),
+            "calib": batch["calib"],
+            "label": batch["labels_geo"].unsqueeze(1),
+        }
+
+        for name in self.in_total:
+            in_tensor_dict.update({name: batch[name]})
+
+        in_tensor_dict.update({
+            k: batch[k] if k in batch.keys() else None
+            for k in getattr(self, f"{self.prior_type}_keys")
+        })
+
+        preds_G, error_G = self.netG(in_tensor_dict)
+
+        acc, iou, prec, recall = self.evaluator.calc_acc(
+            preds_G.flatten(),
+            in_tensor_dict["label"].flatten(),
+            0.5,
+            use_sdf=self.cfg.sdf,
+        )
+
+        if batch_idx % int(self.cfg.freq_show_val) == 0:
+            with torch.no_grad():
+                self.render_func(in_tensor_dict, dataset="val", idx=batch_idx)
+
+        metrics_return = {
+            "val_loss": error_G,
+            "val_acc": acc,
+            "val_iou": iou,
+            "val_prec": prec,
+            "val_recall": recall,
+        }
+
+        return metrics_return
+
+    def validation_epoch_end(self, outputs):
+
+        # metrics processing
+        metrics_log = {
+            "val_avgloss": batch_mean(outputs, "val_loss"),
+            "val_avgacc": batch_mean(outputs, "val_acc"),
+            "val_avgiou": batch_mean(outputs, "val_iou"),
+            "val_avgprec": batch_mean(outputs, "val_prec"),
+            "val_avgrecall": batch_mean(outputs, "val_recall"),
+        }
+
+        tf_log = tf_log_convert(metrics_log)
+
+        return {"log": tf_log}
+
+    def compute_vis_cmap(self, smpl_type, smpl_verts, smpl_faces):
+
+        (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
+        smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
+        smpl_cmap = self.smpl_data.cmap_smpl_vids(smpl_type)
+
+        return {
+            "smpl_vis": smpl_vis.unsqueeze(0).to(self.device),
+            "smpl_cmap": smpl_cmap.unsqueeze(0).to(self.device),
+            "smpl_verts": smpl_verts.unsqueeze(0),
+        }
+
+    @torch.enable_grad()
+    def optim_body(self, in_tensor_dict, batch):
+
+        smpl_model = self.get_smpl_model(batch["type"][0], batch["gender"][0],
+                                         batch["age"][0], None).to(self.device)
+        in_tensor_dict["smpl_faces"] = (torch.tensor(
+            smpl_model.faces.astype(np.int)).long().unsqueeze(0).to(
+                self.device))
+
+        # The optimizer and variables
+        optimed_pose = torch.tensor(batch["body_pose"][0],
+                                    device=self.device,
+                                    requires_grad=True)  # [1,23,3,3]
+        optimed_trans = torch.tensor(batch["transl"][0],
+                                     device=self.device,
+                                     requires_grad=True)  # [3]
+        optimed_betas = torch.tensor(batch["betas"][0],
+                                     device=self.device,
+                                     requires_grad=True)  # [1,10]
+        optimed_orient = torch.tensor(batch["global_orient"][0],
+                                      device=self.device,
+                                      requires_grad=True)  # [1,1,3,3]
+
+        optimizer_smpl = torch.optim.SGD(
+            [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
+            lr=1e-3,
+            momentum=0.9,
+        )
+        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer_smpl,
+            mode="min",
+            factor=0.5,
+            verbose=0,
+            min_lr=1e-5,
+            patience=5)
+        loop_smpl = range(50)
+        for i in loop_smpl:
+
+            optimizer_smpl.zero_grad()
+
+            # prior_loss, optimed_pose = dataset.vposer_prior(optimed_pose)
+            smpl_out = smpl_model(
+                betas=optimed_betas,
+                body_pose=optimed_pose,
+                global_orient=optimed_orient,
+                transl=optimed_trans,
+                return_verts=True,
+            )
+
+            smpl_verts = smpl_out.vertices[0] * 100.0
+            smpl_verts = projection(smpl_verts,
+                                    batch["calib"][0],
+                                    format="tensor")
+            smpl_verts[:, 1] *= -1
+            # render optimized mesh (normal, T_normal, image [-1,1])
+            self.render.load_meshes(smpl_verts, in_tensor_dict["smpl_faces"])
+            (
+                in_tensor_dict["T_normal_F"],
+                in_tensor_dict["T_normal_B"],
+            ) = self.render.get_rgb_image()
+
+            T_mask_F, T_mask_B = self.render.get_silhouette_image()
+
+            with torch.no_grad():
+                (
+                    in_tensor_dict["normal_F"],
+                    in_tensor_dict["normal_B"],
+                ) = self.netG.normal_filter(in_tensor_dict)
+
+            # mask = torch.abs(in_tensor['T_normal_F']).sum(dim=0, keepdims=True) > 0.0
+            diff_F_smpl = torch.abs(in_tensor_dict["T_normal_F"] -
+                                    in_tensor_dict["normal_F"])
+            diff_B_smpl = torch.abs(in_tensor_dict["T_normal_B"] -
+                                    in_tensor_dict["normal_B"])
+            loss = (diff_F_smpl + diff_B_smpl).mean()
+
+            # silhouette loss
+            smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
+            gt_arr = torch.cat(
+                [in_tensor_dict["normal_F"][0], in_tensor_dict["normal_B"][0]],
+                dim=2).permute(1, 2, 0)
+            gt_arr = ((gt_arr + 1.0) * 0.5).to(self.device)
+            bg_color = (torch.Tensor(
+                [0.5, 0.5, 0.5]).unsqueeze(0).unsqueeze(0).to(self.device))
+            gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
+            loss += torch.abs(smpl_arr - gt_arr).mean()
+
+            # Image.fromarray(((in_tensor_dict['T_normal_F'][0].permute(1,2,0)+1.0)*0.5*255.0).detach().cpu().numpy().astype(np.uint8)).show()
+
+            # loop_smpl.set_description(f"smpl = {loss:.3f}")
+
+            loss.backward(retain_graph=True)
+            optimizer_smpl.step()
+            scheduler_smpl.step(loss)
+            in_tensor_dict["smpl_verts"] = smpl_verts.unsqueeze(0)
+
+        in_tensor_dict.update(
+            self.compute_vis_cmap(
+                batch["type"][0],
+                in_tensor_dict["smpl_verts"][0],
+                in_tensor_dict["smpl_faces"][0],
+            ))
+
+        features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
+
+        return features, inter, in_tensor_dict
+
+    @torch.enable_grad()
+    def optim_cloth(self, verts_pr, faces_pr, inter):
+
+        # convert from GT to SDF
+        verts_pr -= (self.resolutions[-1] - 1) / 2.0
+        verts_pr /= (self.resolutions[-1] - 1) / 2.0
+
+        losses = {
+            "cloth": {
+                "weight": 5.0,
+                "value": 0.0
+            },
+            "edge": {
+                "weight": 100.0,
+                "value": 0.0
+            },
+            "normal": {
+                "weight": 0.2,
+                "value": 0.0
+            },
+            "laplacian": {
+                "weight": 100.0,
+                "value": 0.0
+            },
+            "smpl": {
+                "weight": 1.0,
+                "value": 0.0
+            },
+            "deform": {
+                "weight": 20.0,
+                "value": 0.0
+            },
+        }
+
+        deform_verts = torch.full(verts_pr.shape,
+                                  0.0,
+                                  device=self.device,
+                                  requires_grad=True)
+        optimizer_cloth = torch.optim.SGD([deform_verts],
+                                          lr=1e-1,
+                                          momentum=0.9)
+        scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer_cloth,
+            mode="min",
+            factor=0.1,
+            verbose=0,
+            min_lr=1e-3,
+            patience=5)
+        # cloth optimization
+        loop_cloth = range(100)
+
+        for i in loop_cloth:
+
+            optimizer_cloth.zero_grad()
+
+            self.render.load_meshes(
+                verts_pr.unsqueeze(0).to(self.device),
+                faces_pr.unsqueeze(0).to(self.device).long(),
+                deform_verts,
+            )
+            P_normal_F, P_normal_B = self.render.get_rgb_image()
+
+            update_mesh_shape_prior_losses(self.render.mesh, losses)
+            diff_F_cloth = torch.abs(P_normal_F[0] - inter[:3])
+            diff_B_cloth = torch.abs(P_normal_B[0] - inter[3:])
+            losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
+            losses["deform"]["value"] = torch.topk(
+                torch.abs(deform_verts.flatten()), 30)[0].mean()
+
+            # Weighted sum of the losses
+            cloth_loss = torch.tensor(0.0, device=self.device)
+            pbar_desc = ""
+
+            for k in losses.keys():
+                if k != "smpl":
+                    cloth_loss_per_cls = losses[k]["value"] * \
+                        losses[k]["weight"]
+                    pbar_desc += f"{k}: {cloth_loss_per_cls:.3f} | "
+                    cloth_loss += cloth_loss_per_cls
+
+            # loop_cloth.set_description(pbar_desc)
+            cloth_loss.backward(retain_graph=True)
+            optimizer_cloth.step()
+            scheduler_cloth.step(cloth_loss)
+
+        # convert from GT to SDF
+        deform_verts = deform_verts.flatten().detach()
+        deform_verts[torch.topk(torch.abs(deform_verts),
+                                30)[1]] = deform_verts.mean()
+        deform_verts = deform_verts.view(-1, 3).cpu()
+
+        verts_pr += deform_verts
+        verts_pr *= (self.resolutions[-1] - 1) / 2.0
+        verts_pr += (self.resolutions[-1] - 1) / 2.0
+
+        return verts_pr
+
+    def test_step(self, batch, batch_idx):
+
+        self.netG.eval()
+        self.netG.training = False
+        in_tensor_dict = {}
+
+        # export paths
+        mesh_name = batch["subject"][0]
+        mesh_rot = batch["rotation"][0].item()
+
+        self.export_dir = osp.join(self.cfg.results_path, self.cfg.name,
+                                   "-".join(self.cfg.dataset.types), mesh_name)
+
+        os.makedirs(self.export_dir, exist_ok=True)
+
+        for name in self.in_total:
+            if name in batch.keys():
+                in_tensor_dict.update({name: batch[name]})
+
+        in_tensor_dict.update({
+            k: batch[k] if k in batch.keys() else None
+            for k in getattr(self, f"{self.prior_type}_keys")
+        })
+
+        if "T_normal_F" not in in_tensor_dict.keys(
+        ) or "T_normal_B" not in in_tensor_dict.keys():
+
+            # update the new T_normal_F/B
+            self.render.load_meshes(
+                batch["smpl_verts"] *
+                torch.tensor([1.0, -1.0, 1.0]).to(self.device),
+                batch["smpl_faces"])
+            T_normal_F, T_normal_B = self.render.get_rgb_image()
+            in_tensor_dict.update({
+                'T_normal_F': T_normal_F,
+                'T_normal_B': T_normal_B
+            })
+
+        with torch.no_grad():
+            features, inter = self.netG.filter(in_tensor_dict,
+                                               return_inter=True)
+            sdf = self.reconEngine(opt=self.cfg,
+                                   netG=self.netG,
+                                   features=features,
+                                   proj_matrix=None)
+
+        def tensor2arr(x):
+            return (x[0].permute(1, 2, 0).detach().cpu().numpy() +
+                    1.0) * 0.5 * 255.0
+
+        # save inter results
+        image = tensor2arr(in_tensor_dict["image"])
+        smpl_F = tensor2arr(in_tensor_dict["T_normal_F"])
+        smpl_B = tensor2arr(in_tensor_dict["T_normal_B"])
+        image_inter = np.concatenate(self.tensor2image(512, inter[0]) +
+                                     [smpl_F, smpl_B, image],
+                                     axis=1)
+        Image.fromarray((image_inter).astype(np.uint8)).save(
+            osp.join(self.export_dir, f"{mesh_rot}_inter.png"))
+
+        verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)
+
+        if self.clean_mesh_flag:
+            verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)
+
+        verts_gt = batch["verts"][0]
+        faces_gt = batch["faces"][0]
+
+        self.result_eval.update({
+            "verts_gt": verts_gt,
+            "faces_gt": faces_gt,
+            "verts_pr": verts_pr,
+            "faces_pr": faces_pr,
+            "recon_size": (self.resolutions[-1] - 1.0),
+            "calib": batch["calib"][0],
+        })
+
+        self.evaluator.set_mesh(self.result_eval)
+        chamfer, p2s = self.evaluator.calculate_chamfer_p2s(num_samples=1000)
+        normal_consist = self.evaluator.calculate_normal_consist(
+            osp.join(self.export_dir, f"{mesh_rot}_nc.png"))
+
+        test_log = {"chamfer": chamfer, "p2s": p2s, "NC": normal_consist}
+
+        return test_log
+
+    def test_epoch_end(self, outputs):
+
+        # make_test_gif("/".join(self.export_dir.split("/")[:-2]))
+
+        accu_outputs = accumulate(
+            outputs,
+            rot_num=3,
+            split={
+                "cape-easy": (0, 50),
+                "cape-hard": (50, 100)
+            },
+        )
+
+        print(colored(self.cfg.name, "green"))
+        print(colored(self.cfg.dataset.noise_scale, "green"))
+
+        self.logger.experiment.add_hparams(
+            hparam_dict={
+                "lr_G": self.lr_G,
+                "bsize": self.batch_size
+            },
+            metric_dict=accu_outputs,
+        )
+
+        np.save(
+            osp.join(self.export_dir, "../test_results.npy"),
+            accu_outputs,
+            allow_pickle=True,
+        )
+
+        return accu_outputs
+
+    def tensor2image(self, height, inter):
+
+        all = []
+        for dim in self.in_geo_dim:
+            img = resize(
+                np.tile(
+                    ((inter[:dim].cpu().numpy() + 1.0) / 2.0 *
+                     255.0).transpose(1, 2, 0),
+                    (1, 1, int(3 / dim)),
+                ),
+                (height, height),
+                anti_aliasing=True,
+            )
+
+            all.append(img)
+            inter = inter[dim:]
+
+        return all
+
+    def render_func(self, in_tensor_dict, dataset="title", idx=0):
+
+        for name in in_tensor_dict.keys():
+            if in_tensor_dict[name] is not None:
+                in_tensor_dict[name] = in_tensor_dict[name][0:1]
+
+        self.netG.eval()
+        features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
+        sdf = self.reconEngine(opt=self.cfg,
+                               netG=self.netG,
+                               features=features,
+                               proj_matrix=None)
+
+        if sdf is not None:
+            render = self.reconEngine.display(sdf)
+
+            image_pred = np.flip(render[:, :, ::-1], axis=0)
+            height = image_pred.shape[0]
+
+            image_gt = resize(
+                ((in_tensor_dict["image"].cpu().numpy()[0] + 1.0) / 2.0 *
+                 255.0).transpose(1, 2, 0),
+                (height, height),
+                anti_aliasing=True,
+            )
+            image_inter = self.tensor2image(height, inter[0])
+            image = np.concatenate([image_pred, image_gt] + image_inter,
+                                   axis=1)
+
+            step_id = self.global_step if dataset == "train" else self.global_step + idx
+            self.logger.experiment.add_image(
+                tag=f"Occupancy-{dataset}/{step_id}",
+                img_tensor=image.transpose(2, 0, 1),
+                global_step=step_id,
+            )
+
+    def test_single(self, batch):
+
+        self.netG.eval()
+        self.netG.training = False
+        in_tensor_dict = {}
+
+        for name in self.in_total:
+            if name in batch.keys():
+                in_tensor_dict.update({name: batch[name]})
+
+        in_tensor_dict.update({
+            k: batch[k] if k in batch.keys() else None
+            for k in getattr(self, f"{self.prior_type}_keys")
+        })
+
+        with torch.no_grad():
+            features, inter = self.netG.filter(in_tensor_dict,
+                                               return_inter=True)
+            sdf = self.reconEngine(opt=self.cfg,
+                                   netG=self.netG,
+                                   features=features,
+                                   proj_matrix=None)
+
+        verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)
+
+        if self.clean_mesh_flag:
+            verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)
+
+        verts_pr -= (self.resolutions[-1] - 1) / 2.0
+        verts_pr /= (self.resolutions[-1] - 1) / 2.0
+
+        return verts_pr, faces_pr, inter
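As a worked check of the coarse-to-fine resolution schedule built in `ICON.__init__` above, here is a standalone snippet (assuming `mcube_res = 256`, the value merged into the config by apps/infer.py below):

```python
import numpy as np

mcube_res = 256  # pushed into cfg via cfg_show_list in apps/infer.py
resolutions = (np.logspace(start=5,
                           stop=np.log2(mcube_res),
                           base=2,
                           num=int(np.log2(mcube_res) - 4),
                           endpoint=True) + 1.0)
print(resolutions.astype(np.int16).tolist())  # [33, 65, 129, 257]
# These are the grid sizes handed to Seg3dLossless as its coarse-to-fine
# resolutions for extracting the implicit surface.
```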
apps/Normal.py
ADDED
@@ -0,0 +1,220 @@
+from lib.net import NormalNet
+from lib.common.train_util import *
+import logging
+import torch
+import numpy as np
+from torch import nn
+from skimage.transform import resize
+import pytorch_lightning as pl
+
+torch.backends.cudnn.benchmark = True
+
+logging.getLogger("lightning").setLevel(logging.ERROR)
+import warnings
+
+warnings.filterwarnings("ignore")
+
+
+class Normal(pl.LightningModule):
+
+    def __init__(self, cfg):
+        super(Normal, self).__init__()
+        self.cfg = cfg
+        self.batch_size = self.cfg.batch_size
+        self.lr_N = self.cfg.lr_N
+
+        self.schedulers = []
+
+        self.netG = NormalNet(self.cfg, error_term=nn.SmoothL1Loss())
+
+        self.in_nml = [item[0] for item in cfg.net.in_nml]
+
+    def get_progress_bar_dict(self):
+        tqdm_dict = super().get_progress_bar_dict()
+        if "v_num" in tqdm_dict:
+            del tqdm_dict["v_num"]
+        return tqdm_dict
+
+    # Training related
+    def configure_optimizers(self):
+
+        # set optimizer
+        weight_decay = self.cfg.weight_decay
+        momentum = self.cfg.momentum
+
+        optim_params_N_F = [{
+            "params": self.netG.netF.parameters(),
+            "lr": self.lr_N
+        }]
+        optim_params_N_B = [{
+            "params": self.netG.netB.parameters(),
+            "lr": self.lr_N
+        }]
+
+        optimizer_N_F = torch.optim.Adam(optim_params_N_F,
+                                         lr=self.lr_N,
+                                         weight_decay=weight_decay)
+
+        optimizer_N_B = torch.optim.Adam(optim_params_N_B,
+                                         lr=self.lr_N,
+                                         weight_decay=weight_decay)
+
+        scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma)
+
+        scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma)
+
+        self.schedulers = [scheduler_N_F, scheduler_N_B]
+        optims = [optimizer_N_F, optimizer_N_B]
+
+        return optims, self.schedulers
+
+    def render_func(self, render_tensor):
+
+        height = render_tensor["image"].shape[2]
+        result_list = []
+
+        for name in render_tensor.keys():
+            result_list.append(
+                resize(
+                    ((render_tensor[name].cpu().numpy()[0] + 1.0) /
+                     2.0).transpose(1, 2, 0),
+                    (height, height),
+                    anti_aliasing=True,
+                ))
+        result_array = np.concatenate(result_list, axis=1)
+
+        return result_array
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+
+        export_cfg(self.logger, self.cfg)
+
+        # retrieve the data
+        in_tensor = {}
+        for name in self.in_nml:
+            in_tensor[name] = batch[name]
+
+        FB_tensor = {
+            "normal_F": batch["normal_F"],
+            "normal_B": batch["normal_B"]
+        }
+
+        self.netG.train()
+
+        preds_F, preds_B = self.netG(in_tensor)
+        error_NF, error_NB = self.netG.get_norm_error(preds_F, preds_B,
+                                                      FB_tensor)
+
+        (opt_nf, opt_nb) = self.optimizers()
+
+        opt_nf.zero_grad()
+        opt_nb.zero_grad()
+
+        self.manual_backward(error_NF, opt_nf)
+        self.manual_backward(error_NB, opt_nb)
+
+        opt_nf.step()
+        opt_nb.step()
+
+        if batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0:
+
+            self.netG.eval()
+            with torch.no_grad():
+                nmlF, nmlB = self.netG(in_tensor)
+                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
+                result_array = self.render_func(in_tensor)
+
+                self.logger.experiment.add_image(
+                    tag=f"Normal-train/{self.global_step}",
+                    img_tensor=result_array.transpose(2, 0, 1),
+                    global_step=self.global_step,
+                )
+
+        # metrics processing
+        metrics_log = {
+            "train_loss-NF": error_NF.item(),
+            "train_loss-NB": error_NB.item(),
+        }
+
+        tf_log = tf_log_convert(metrics_log)
+        bar_log = bar_log_convert(metrics_log)
+
+        return {
+            "loss": error_NF + error_NB,
+            "loss-NF": error_NF,
+            "loss-NB": error_NB,
+            "log": tf_log,
+            "progress_bar": bar_log,
+        }
+
+    def training_epoch_end(self, outputs):
+
+        if [] in outputs:
+            outputs = outputs[0]
+
+        # metrics processing
+        metrics_log = {
+            "train_avgloss": batch_mean(outputs, "loss"),
+            "train_avgloss-NF": batch_mean(outputs, "loss-NF"),
+            "train_avgloss-NB": batch_mean(outputs, "loss-NB"),
+        }
+
+        tf_log = tf_log_convert(metrics_log)
+
+        tf_log["lr-NF"] = self.schedulers[0].get_last_lr()[0]
+        tf_log["lr-NB"] = self.schedulers[1].get_last_lr()[0]
+
+        return {"log": tf_log}
+
+    def validation_step(self, batch, batch_idx):
+
+        # retrieve the data
+        in_tensor = {}
+        for name in self.in_nml:
+            in_tensor[name] = batch[name]
+
+        FB_tensor = {
+            "normal_F": batch["normal_F"],
+            "normal_B": batch["normal_B"]
+        }
+
+        self.netG.train()
+
+        preds_F, preds_B = self.netG(in_tensor)
+        error_NF, error_NB = self.netG.get_norm_error(preds_F, preds_B,
+                                                      FB_tensor)
+
+        if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train)
+                == 0) or (batch_idx == 0):
+
+            with torch.no_grad():
+                nmlF, nmlB = self.netG(in_tensor)
+                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
+                result_array = self.render_func(in_tensor)
+
+                self.logger.experiment.add_image(
+                    tag=f"Normal-val/{self.global_step}",
+                    img_tensor=result_array.transpose(2, 0, 1),
+                    global_step=self.global_step,
+                )
+
+        return {
+            "val_loss": error_NF + error_NB,
+            "val_loss-NF": error_NF,
+            "val_loss-NB": error_NB,
+        }
+
+    def validation_epoch_end(self, outputs):
+
+        # metrics processing
+        metrics_log = {
+            "val_avgloss": batch_mean(outputs, "val_loss"),
+            "val_avgloss-NF": batch_mean(outputs, "val_loss-NF"),
+            "val_avgloss-NB": batch_mean(outputs, "val_loss-NB"),
+        }
+
+        tf_log = tf_log_convert(metrics_log)
+
+        return {"log": tf_log}
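The training step above updates the front and back normal networks with two separate optimizers driven by hand (Lightning's manual optimization). Below is a minimal self-contained sketch of the same pattern, using toy stand-in modules rather than the repo's NormalNet:

```python
# Toy illustration only: nn.Linear layers stand in for the F/B normal heads.
import torch
from torch import nn

netF, netB = nn.Linear(8, 8), nn.Linear(8, 8)
opt_nf = torch.optim.Adam(netF.parameters(), lr=1e-4)
opt_nb = torch.optim.Adam(netB.parameters(), lr=1e-4)

x = torch.randn(4, 8)
target_F, target_B = torch.randn(4, 8), torch.randn(4, 8)

error_NF = nn.functional.smooth_l1_loss(netF(x), target_F)
error_NB = nn.functional.smooth_l1_loss(netB(x), target_B)

opt_nf.zero_grad()
opt_nb.zero_grad()
error_NF.backward()  # self.manual_backward(...) reduces to .backward() in plain PyTorch
error_NB.backward()
opt_nf.step()
opt_nb.step()
```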
apps/infer.py
ADDED
@@ -0,0 +1,492 @@
+# -*- coding: utf-8 -*-
+
+# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+# holder of all proprietary rights on this computer program.
+# You can only use this computer program if you have closed
+# a license agreement with MPG or you get the right to use the computer
+# program from someone who is authorized to grant you that right.
+# Any use of the computer program without a valid license is prohibited and
+# liable to prosecution.
+#
+# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+# for Intelligent Systems. All rights reserved.
+#
+# Contact: [email protected]
+
+import os
+import gc
+
+import logging
+from lib.common.config import cfg
+from lib.dataset.mesh_util import (
+    load_checkpoint,
+    update_mesh_shape_prior_losses,
+    blend_rgb_norm,
+    unwrap,
+    remesh,
+    tensor2variable,
+    rot6d_to_rotmat
+)
+
+from lib.dataset.TestDataset import TestDataset
+from lib.common.render import query_color
+from lib.net.local_affine import LocalAffine
+from pytorch3d.structures import Meshes
+from apps.ICON import ICON
+
+from termcolor import colored
+import numpy as np
+from PIL import Image
+import trimesh
+import numpy as np
+from tqdm import tqdm
+
+import torch
+torch.backends.cudnn.benchmark = True
+
+logging.getLogger("trimesh").setLevel(logging.ERROR)
+
+
+def generate_model(in_path, model_type):
+
+    torch.cuda.empty_cache()
+
+    if model_type == 'ICON':
+        model_type = 'icon-filter'
+    else:
+        model_type = model_type.lower()
+
+    config_dict = {'loop_smpl': 100,
+                   'loop_cloth': 200,
+                   'patience': 5,
+                   'out_dir': './results',
+                   'hps_type': 'pymaf',
+                   'config': f"./configs/{model_type}.yaml"}
+
+    # cfg read and merge
+    cfg.merge_from_file(config_dict['config'])
+    cfg.merge_from_file("./lib/pymaf/configs/pymaf_config.yaml")
+
+    os.makedirs(config_dict['out_dir'], exist_ok=True)
+
+    cfg_show_list = [
+        "test_gpus",
+        [0],
+        "mcube_res",
+        256,
+        "clean_mesh",
+        True,
+    ]
+
+    cfg.merge_from_list(cfg_show_list)
+    cfg.freeze()
+
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+    device = torch.device(f"cuda:0")
+
+    # load model and dataloader
+    model = ICON(cfg)
+    model = load_checkpoint(model, cfg)
+
+    dataset_param = {
+        'image_path': in_path,
+        'seg_dir': None,
+        'has_det': True,  # w/ or w/o detection
+        'hps_type': 'pymaf'  # pymaf/pare/pixie
+    }
+
+    if config_dict['hps_type'] == "pixie" and "pamir" in config_dict['config']:
+        print(colored("PIXIE isn't compatible with PaMIR, thus switch to PyMAF", "red"))
+        dataset_param["hps_type"] = "pymaf"
+
+    dataset = TestDataset(dataset_param, device)
+
+    print(colored(f"Dataset Size: {len(dataset)}", "green"))
+
+    pbar = tqdm(dataset)
+
+    for data in pbar:
+
+        pbar.set_description(f"{data['name']}")
+
+        in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["image"]}
+
+        # The optimizer and variables
+        optimed_pose = torch.tensor(
+            data["body_pose"], device=device, requires_grad=True
+        )  # [1,23,3,3]
+        optimed_trans = torch.tensor(
+            data["trans"], device=device, requires_grad=True
+        )  # [3]
+        optimed_betas = torch.tensor(
+            data["betas"], device=device, requires_grad=True
+        )  # [1,10]
+        optimed_orient = torch.tensor(
+            data["global_orient"], device=device, requires_grad=True
+        )  # [1,1,3,3]
+
+        optimizer_smpl = torch.optim.Adam(
+            [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
+            lr=1e-3,
+            amsgrad=True,
+        )
+        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer_smpl,
+            mode="min",
+            factor=0.5,
+            verbose=0,
+            min_lr=1e-5,
+            patience=config_dict['patience'],
+        )
+
+        losses = {
+            # Cloth: Normal_recon - Normal_pred
+            "cloth": {"weight": 1e1, "value": 0.0},
+            # Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
+            "stiffness": {"weight": 1e5, "value": 0.0},
+            # Cloth: det(R) = 1
+            "rigid": {"weight": 1e5, "value": 0.0},
+            # Cloth: edge length
+            "edge": {"weight": 0, "value": 0.0},
+            # Cloth: normal consistency
+            "nc": {"weight": 0, "value": 0.0},
+            # Cloth: laplacian smoothing
+            "laplacian": {"weight": 1e2, "value": 0.0},
+            # Body: Normal_pred - Normal_smpl
+            "normal": {"weight": 1e0, "value": 0.0},
+            # Body: Silhouette_pred - Silhouette_smpl
+            "silhouette": {"weight": 1e0, "value": 0.0},
+        }
+
+        # smpl optimization
+
+        loop_smpl = tqdm(range(config_dict['loop_smpl']))
+
+        for _ in loop_smpl:
+
+            optimizer_smpl.zero_grad()
+
+            # 6d_rot to rot_mat
+            optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).unsqueeze(0)
+            optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).unsqueeze(0)
+
+            if dataset_param["hps_type"] != "pixie":
+                smpl_out = dataset.smpl_model(
+                    betas=optimed_betas,
+                    body_pose=optimed_pose_mat,
+                    global_orient=optimed_orient_mat,
+                    pose2rot=False,
+                )
+
+                smpl_verts = ((smpl_out.vertices) +
+                              optimed_trans) * data["scale"]
+            else:
+                smpl_verts, _, _ = dataset.smpl_model(
+                    shape_params=optimed_betas,
+                    expression_params=tensor2variable(data["exp"], device),
+                    body_pose=optimed_pose_mat,
+                    global_pose=optimed_orient_mat,
+                    jaw_pose=tensor2variable(data["jaw_pose"], device),
+                    left_hand_pose=tensor2variable(
+                        data["left_hand_pose"], device),
+                    right_hand_pose=tensor2variable(
+                        data["right_hand_pose"], device),
+                )
+
+                smpl_verts = (smpl_verts + optimed_trans) * data["scale"]
+
+            # render optimized mesh (normal, T_normal, image [-1,1])
+            in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
+                smpl_verts *
+                torch.tensor([1.0, -1.0, -1.0]
+                             ).to(device), in_tensor["smpl_faces"]
+            )
+            T_mask_F, T_mask_B = dataset.render.get_silhouette_image()
+
+            with torch.no_grad():
+                in_tensor["normal_F"], in_tensor["normal_B"] = model.netG.normal_filter(
+                    in_tensor
+                )
+
+            diff_F_smpl = torch.abs(
+                in_tensor["T_normal_F"] - in_tensor["normal_F"])
+            diff_B_smpl = torch.abs(
+                in_tensor["T_normal_B"] - in_tensor["normal_B"])
+
+            losses["normal"]["value"] = (diff_F_smpl + diff_B_smpl).mean()
+
+            # silhouette loss
+            smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
+            gt_arr = torch.cat(
+                [in_tensor["normal_F"][0], in_tensor["normal_B"][0]], dim=2
+            ).permute(1, 2, 0)
+            gt_arr = ((gt_arr + 1.0) * 0.5).to(device)
+            bg_color = (
+                torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(
+                    0).unsqueeze(0).to(device)
+            )
+            gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
+            diff_S = torch.abs(smpl_arr - gt_arr)
+            losses["silhouette"]["value"] = diff_S.mean()
+
+            # Weighted sum of the losses
+            smpl_loss = 0.0
+            pbar_desc = "Body Fitting --- "
+            for k in ["normal", "silhouette"]:
+                pbar_desc += f"{k}: {losses[k]['value'] * losses[k]['weight']:.3f} | "
+                smpl_loss += losses[k]["value"] * losses[k]["weight"]
+            pbar_desc += f"Total: {smpl_loss:.3f}"
+            loop_smpl.set_description(pbar_desc)
+
+            smpl_loss.backward()
+            optimizer_smpl.step()
+            scheduler_smpl.step(smpl_loss)
+            in_tensor["smpl_verts"] = smpl_verts * \
+                torch.tensor([1.0, 1.0, -1.0]).to(device)
+
+        # visualize the optimization process
+        # 1. SMPL Fitting
+        # 2. Clothes Refinement
+
+        os.makedirs(os.path.join(config_dict['out_dir'], cfg.name,
+                                 "refinement"), exist_ok=True)
+
+        # visualize the final results in self-rotation mode
+        os.makedirs(os.path.join(config_dict['out_dir'],
|
257 |
+
cfg.name, "vid"), exist_ok=True)
|
258 |
+
|
259 |
+
# final results rendered as image
|
260 |
+
# 1. Render the final fitted SMPL (xxx_smpl.png)
|
261 |
+
# 2. Render the final reconstructed clothed human (xxx_cloth.png)
|
262 |
+
# 3. Blend the original image with predicted cloth normal (xxx_overlap.png)
|
263 |
+
|
264 |
+
os.makedirs(os.path.join(config_dict['out_dir'],
|
265 |
+
cfg.name, "png"), exist_ok=True)
|
266 |
+
|
267 |
+
# final reconstruction meshes
|
268 |
+
# 1. SMPL mesh (xxx_smpl.obj)
|
269 |
+
# 2. SMPL params (xxx_smpl.npy)
|
270 |
+
# 3. clothed mesh (xxx_recon.obj)
|
271 |
+
# 4. remeshed clothed mesh (xxx_remesh.obj)
|
272 |
+
# 5. refined clothed mesh (xxx_refine.obj)
|
273 |
+
|
274 |
+
os.makedirs(os.path.join(config_dict['out_dir'],
|
275 |
+
cfg.name, "obj"), exist_ok=True)
|
276 |
+
|
277 |
+
norm_pred_F = (
|
278 |
+
((in_tensor["normal_F"][0].permute(1, 2, 0) + 1.0) * 255.0 / 2.0)
|
279 |
+
.detach()
|
280 |
+
.cpu()
|
281 |
+
.numpy()
|
282 |
+
.astype(np.uint8)
|
283 |
+
)
|
284 |
+
|
285 |
+
norm_pred_B = (
|
286 |
+
((in_tensor["normal_B"][0].permute(1, 2, 0) + 1.0) * 255.0 / 2.0)
|
287 |
+
.detach()
|
288 |
+
.cpu()
|
289 |
+
.numpy()
|
290 |
+
.astype(np.uint8)
|
291 |
+
)
|
292 |
+
|
293 |
+
norm_orig_F = unwrap(norm_pred_F, data)
|
294 |
+
norm_orig_B = unwrap(norm_pred_B, data)
|
295 |
+
|
296 |
+
mask_orig = unwrap(
|
297 |
+
np.repeat(
|
298 |
+
data["mask"].permute(1, 2, 0).detach().cpu().numpy(), 3, axis=2
|
299 |
+
).astype(np.uint8),
|
300 |
+
data,
|
301 |
+
)
|
302 |
+
rgb_norm_F = blend_rgb_norm(data["ori_image"], norm_orig_F, mask_orig)
|
303 |
+
rgb_norm_B = blend_rgb_norm(data["ori_image"], norm_orig_B, mask_orig)
|
304 |
+
|
305 |
+
Image.fromarray(
|
306 |
+
np.concatenate(
|
307 |
+
[data["ori_image"].astype(np.uint8), rgb_norm_F, rgb_norm_B], axis=1)
|
308 |
+
).save(os.path.join(config_dict['out_dir'], cfg.name, f"png/{data['name']}_overlap.png"))
|
309 |
+
|
310 |
+
smpl_obj = trimesh.Trimesh(
|
311 |
+
in_tensor["smpl_verts"].detach().cpu()[0] *
|
312 |
+
torch.tensor([1.0, -1.0, 1.0]),
|
313 |
+
in_tensor['smpl_faces'].detach().cpu()[0],
|
314 |
+
process=False,
|
315 |
+
maintains_order=True
|
316 |
+
)
|
317 |
+
smpl_obj.visual.vertex_colors = (smpl_obj.vertex_normals+1.0)*255.0*0.5
|
318 |
+
smpl_obj.export(
|
319 |
+
f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.obj")
|
320 |
+
smpl_obj.export(
|
321 |
+
f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb")
|
322 |
+
|
323 |
+
smpl_info = {'betas': optimed_betas,
|
324 |
+
'pose': optimed_pose_mat,
|
325 |
+
'orient': optimed_orient_mat,
|
326 |
+
'trans': optimed_trans}
|
327 |
+
|
328 |
+
np.save(
|
329 |
+
f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.npy", smpl_info, allow_pickle=True)
|
330 |
+
|
331 |
+
# ------------------------------------------------------------------------------------------------------------------
|
332 |
+
|
333 |
+
# cloth optimization
|
334 |
+
|
335 |
+
# cloth recon
|
336 |
+
in_tensor.update(
|
337 |
+
dataset.compute_vis_cmap(
|
338 |
+
in_tensor["smpl_verts"][0], in_tensor["smpl_faces"][0]
|
339 |
+
)
|
340 |
+
)
|
341 |
+
|
342 |
+
if cfg.net.prior_type == "pamir":
|
343 |
+
in_tensor.update(
|
344 |
+
dataset.compute_voxel_verts(
|
345 |
+
optimed_pose,
|
346 |
+
optimed_orient,
|
347 |
+
optimed_betas,
|
348 |
+
optimed_trans,
|
349 |
+
data["scale"],
|
350 |
+
)
|
351 |
+
)
|
352 |
+
|
353 |
+
with torch.no_grad():
|
354 |
+
verts_pr, faces_pr, _ = model.test_single(in_tensor)
|
355 |
+
|
356 |
+
recon_obj = trimesh.Trimesh(
|
357 |
+
verts_pr, faces_pr, process=False, maintains_order=True
|
358 |
+
)
|
359 |
+
recon_obj.visual.vertex_colors = (
|
360 |
+
recon_obj.vertex_normals+1.0)*255.0*0.5
|
361 |
+
recon_obj.export(
|
362 |
+
os.path.join(config_dict['out_dir'], cfg.name,
|
363 |
+
f"obj/{data['name']}_recon.obj")
|
364 |
+
)
|
365 |
+
|
366 |
+
# Isotropic Explicit Remeshing for better geometry topology
|
367 |
+
verts_refine, faces_refine = remesh(os.path.join(config_dict['out_dir'], cfg.name,
|
368 |
+
f"obj/{data['name']}_recon.obj"), 0.5, device)
|
369 |
+
|
370 |
+
# define local_affine deform verts
|
371 |
+
mesh_pr = Meshes(verts_refine, faces_refine).to(device)
|
372 |
+
local_affine_model = LocalAffine(
|
373 |
+
mesh_pr.verts_padded().shape[1], mesh_pr.verts_padded().shape[0], mesh_pr.edges_packed()).to(device)
|
374 |
+
optimizer_cloth = torch.optim.Adam(
|
375 |
+
[{'params': local_affine_model.parameters()}], lr=1e-4, amsgrad=True)
|
376 |
+
|
377 |
+
scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
378 |
+
optimizer_cloth,
|
379 |
+
mode="min",
|
380 |
+
factor=0.1,
|
381 |
+
verbose=0,
|
382 |
+
min_lr=1e-5,
|
383 |
+
patience=config_dict['patience'],
|
384 |
+
)
|
385 |
+
|
386 |
+
final = None
|
387 |
+
|
388 |
+
if config_dict['loop_cloth'] > 0:
|
389 |
+
|
390 |
+
loop_cloth = tqdm(range(config_dict['loop_cloth']))
|
391 |
+
|
392 |
+
for _ in loop_cloth:
|
393 |
+
|
394 |
+
optimizer_cloth.zero_grad()
|
395 |
+
|
396 |
+
deformed_verts, stiffness, rigid = local_affine_model(
|
397 |
+
verts_refine.to(device), return_stiff=True)
|
398 |
+
mesh_pr = mesh_pr.update_padded(deformed_verts)
|
399 |
+
|
400 |
+
# losses for laplacian, edge, normal consistency
|
401 |
+
update_mesh_shape_prior_losses(mesh_pr, losses)
|
402 |
+
|
403 |
+
in_tensor["P_normal_F"], in_tensor["P_normal_B"] = dataset.render_normal(
|
404 |
+
mesh_pr.verts_padded(), mesh_pr.faces_padded())
|
405 |
+
|
406 |
+
diff_F_cloth = torch.abs(
|
407 |
+
in_tensor["P_normal_F"] - in_tensor["normal_F"])
|
408 |
+
diff_B_cloth = torch.abs(
|
409 |
+
in_tensor["P_normal_B"] - in_tensor["normal_B"])
|
410 |
+
|
411 |
+
losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
|
412 |
+
losses["stiffness"]["value"] = torch.mean(stiffness)
|
413 |
+
losses["rigid"]["value"] = torch.mean(rigid)
|
414 |
+
|
415 |
+
# Weighted sum of the losses
|
416 |
+
cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
|
417 |
+
pbar_desc = "Cloth Refinement --- "
|
418 |
+
|
419 |
+
for k in losses.keys():
|
420 |
+
if k not in ["normal", "silhouette"] and losses[k]["weight"] > 0.0:
|
421 |
+
cloth_loss = cloth_loss + \
|
422 |
+
losses[k]["value"] * losses[k]["weight"]
|
423 |
+
pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.5f} | "
|
424 |
+
|
425 |
+
pbar_desc += f"Total: {cloth_loss:.5f}"
|
426 |
+
loop_cloth.set_description(pbar_desc)
|
427 |
+
|
428 |
+
# update params
|
429 |
+
cloth_loss.backward()
|
430 |
+
optimizer_cloth.step()
|
431 |
+
scheduler_cloth.step(cloth_loss)
|
432 |
+
|
433 |
+
final = trimesh.Trimesh(
|
434 |
+
mesh_pr.verts_packed().detach().squeeze(0).cpu(),
|
435 |
+
mesh_pr.faces_packed().detach().squeeze(0).cpu(),
|
436 |
+
process=False, maintains_order=True
|
437 |
+
)
|
438 |
+
|
439 |
+
# only with front texture
|
440 |
+
tex_colors = query_color(
|
441 |
+
mesh_pr.verts_packed().detach().squeeze(0).cpu(),
|
442 |
+
mesh_pr.faces_packed().detach().squeeze(0).cpu(),
|
443 |
+
in_tensor["image"],
|
444 |
+
device=device,
|
445 |
+
)
|
446 |
+
|
447 |
+
# full normal textures
|
448 |
+
norm_colors = (mesh_pr.verts_normals_padded().squeeze(
|
449 |
+
0).detach().cpu() + 1.0) * 0.5 * 255.0
|
450 |
+
|
451 |
+
final.visual.vertex_colors = tex_colors
|
452 |
+
final.export(
|
453 |
+
f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.obj")
|
454 |
+
|
455 |
+
final.visual.vertex_colors = norm_colors
|
456 |
+
final.export(
|
457 |
+
f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.glb")
|
458 |
+
|
459 |
+
# always export the visualized video regardless of the cloth refinement
|
460 |
+
verts_lst = [smpl_obj.vertices, final.vertices]
|
461 |
+
faces_lst = [smpl_obj.faces, final.faces]
|
462 |
+
|
463 |
+
# self-rotated video
|
464 |
+
dataset.render.load_meshes(
|
465 |
+
verts_lst, faces_lst)
|
466 |
+
dataset.render.get_rendered_video(
|
467 |
+
[data["ori_image"], rgb_norm_F, rgb_norm_B],
|
468 |
+
os.path.join(config_dict['out_dir'], cfg.name,
|
469 |
+
f"vid/{data['name']}_cloth.mp4"),
|
470 |
+
)
|
471 |
+
|
472 |
+
smpl_obj_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.obj"
|
473 |
+
smpl_glb_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb"
|
474 |
+
smpl_npy_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.npy"
|
475 |
+
refine_obj_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.obj"
|
476 |
+
refine_glb_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.glb"
|
477 |
+
|
478 |
+
video_path = os.path.join(
|
479 |
+
config_dict['out_dir'], cfg.name, f"vid/{data['name']}_cloth.mp4")
|
480 |
+
overlap_path = os.path.join(
|
481 |
+
config_dict['out_dir'], cfg.name, f"png/{data['name']}_overlap.png")
|
482 |
+
|
483 |
+
# clean all the variables
|
484 |
+
for element in dir():
|
485 |
+
if 'path' not in element:
|
486 |
+
del locals()[element]
|
487 |
+
gc.collect()
|
488 |
+
torch.cuda.empty_cache()
|
489 |
+
|
490 |
+
return [smpl_glb_path, smpl_obj_path,smpl_npy_path,
|
491 |
+
refine_glb_path, refine_obj_path,
|
492 |
+
video_path, video_path, overlap_path]
|
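The block above ends by returning eight file paths: the fitted SMPL body (GLB/OBJ/NPY), the refined clothed mesh (GLB/OBJ), the self-rotation video (listed twice) and the normal-overlap PNG. A minimal sketch of how such a return value could be wired to Gradio output components follows; the function name generate_model and the exact component choices are assumptions, not the Space's actual app.py wiring:

    import gradio as gr

    # generate_model(in_path, model_type) is assumed to be the function whose tail is shown above;
    # it returns the eight paths in the order of the return statement.
    demo = gr.Interface(
        fn=generate_model,
        inputs=[
            gr.Image(type="filepath", label="Input image"),
            gr.Radio(["icon-filter", "icon-nofilter", "pifu", "pamir"], label="model_type"),
        ],
        outputs=[
            gr.Model3D(label="SMPL body (.glb)"),
            gr.File(label="SMPL body (.obj)"),
            gr.File(label="SMPL params (.npy)"),
            gr.Model3D(label="Refined clothed mesh (.glb)"),
            gr.File(label="Refined clothed mesh (.obj)"),
            gr.Video(label="Self-rotation video"),
            gr.Video(label="Self-rotation video (download)"),
            gr.Image(label="Image + predicted normal overlay"),
        ],
    )
    demo.launch()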
configs/icon-filter.yaml
ADDED
@@ -0,0 +1,25 @@
1 |
+
name: icon-filter
|
2 |
+
ckpt_dir: "./data/ckpt/"
|
3 |
+
resume_path: "https://huggingface.co/Yuliang/ICON/resolve/main/icon-filter.ckpt"
|
4 |
+
normal_path: "https://huggingface.co/Yuliang/ICON/resolve/main/normal.ckpt"
|
5 |
+
|
6 |
+
test_mode: True
|
7 |
+
batch_size: 1
|
8 |
+
|
9 |
+
net:
|
10 |
+
mlp_dim: [256, 512, 256, 128, 1]
|
11 |
+
res_layers: [2,3,4]
|
12 |
+
num_stack: 2
|
13 |
+
prior_type: "icon" # icon/pamir/icon
|
14 |
+
use_filter: True
|
15 |
+
in_geo: (('normal_F',3), ('normal_B',3))
|
16 |
+
in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
|
17 |
+
smpl_feats: ['sdf', 'norm', 'vis', 'cmap']
|
18 |
+
gtype: 'HGPIFuNet'
|
19 |
+
norm_mlp: 'batch'
|
20 |
+
hourglass_dim: 6
|
21 |
+
smpl_dim: 7
|
22 |
+
|
23 |
+
# user defined
|
24 |
+
mcube_res: 512 # occupancy field resolution, higher --> more details
|
25 |
+
clean_mesh: False # if True, will remove floating pieces
|
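Each of these YAML files is one of the ./configs/{model_type}.yaml choices consumed by the inference code above. A small sketch of the merge pattern, assuming cfg is the shared yacs CfgNode exposed by lib.common.config (the module itself is not shown in this diff):

    from lib.common.config import cfg  # shared yacs CfgNode, assumed import path

    cfg.merge_from_file("./configs/icon-filter.yaml")
    cfg.merge_from_file("./lib/pymaf/configs/pymaf_config.yaml")
    # key/value overrides, e.g. a coarser marching-cubes grid for a quick preview
    cfg.merge_from_list(["mcube_res", 256, "clean_mesh", True])
    cfg.freeze()

    print(cfg.name, cfg.net.prior_type, cfg.mcube_res)

Once freeze() has been called, any further mutation of the node raises, which is why the inference code merges every override before freezing.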
configs/icon-nofilter.yaml
ADDED
@@ -0,0 +1,25 @@
1 |
+
name: icon-nofilter
|
2 |
+
ckpt_dir: "./data/ckpt/"
|
3 |
+
resume_path: "https://huggingface.co/Yuliang/ICON/resolve/main/icon-nofilter.ckpt"
|
4 |
+
normal_path: "https://huggingface.co/Yuliang/ICON/resolve/main/normal.ckpt"
|
5 |
+
|
6 |
+
test_mode: True
|
7 |
+
batch_size: 1
|
8 |
+
|
9 |
+
net:
|
10 |
+
mlp_dim: [256, 512, 256, 128, 1]
|
11 |
+
res_layers: [2,3,4]
|
12 |
+
num_stack: 2
|
13 |
+
prior_type: "icon" # icon/pamir/icon
|
14 |
+
use_filter: False
|
15 |
+
in_geo: (('normal_F',3), ('normal_B',3))
|
16 |
+
in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
|
17 |
+
smpl_feats: ['sdf', 'norm', 'vis', 'cmap']
|
18 |
+
gtype: 'HGPIFuNet'
|
19 |
+
norm_mlp: 'batch'
|
20 |
+
hourglass_dim: 6
|
21 |
+
smpl_dim: 7
|
22 |
+
|
23 |
+
# user defined
|
24 |
+
mcube_res: 512 # occupancy field resolution, higher --> more details
|
25 |
+
clean_mesh: False # if True, will remove floating pieces
|
configs/pamir.yaml
ADDED
@@ -0,0 +1,24 @@
1 |
+
name: pamir
|
2 |
+
ckpt_dir: "./data/ckpt/"
|
3 |
+
resume_path: "https://huggingface.co/Yuliang/ICON/resolve/main/pamir.ckpt"
|
4 |
+
normal_path: "https://huggingface.co/Yuliang/ICON/resolve/main/normal.ckpt"
|
5 |
+
|
6 |
+
test_mode: True
|
7 |
+
batch_size: 1
|
8 |
+
|
9 |
+
net:
|
10 |
+
mlp_dim: [256, 512, 256, 128, 1]
|
11 |
+
res_layers: [2,3,4]
|
12 |
+
num_stack: 2
|
13 |
+
prior_type: "pamir" # icon/pamir/icon
|
14 |
+
use_filter: True
|
15 |
+
in_geo: (('image',3), ('normal_F',3), ('normal_B',3))
|
16 |
+
in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
|
17 |
+
gtype: 'HGPIFuNet'
|
18 |
+
norm_mlp: 'batch'
|
19 |
+
hourglass_dim: 6
|
20 |
+
voxel_dim: 7
|
21 |
+
|
22 |
+
# user defined
|
23 |
+
mcube_res: 512 # occupancy field resolution, higher --> more details
|
24 |
+
clean_mesh: False # if True, will remove floating pieces
|
configs/pifu.yaml
ADDED
@@ -0,0 +1,24 @@
1 |
+
name: pifu
|
2 |
+
ckpt_dir: "./data/ckpt/"
|
3 |
+
resume_path: "https://huggingface.co/Yuliang/ICON/resolve/main/pifu.ckpt"
|
4 |
+
normal_path: "https://huggingface.co/Yuliang/ICON/resolve/main/normal.ckpt"
|
5 |
+
|
6 |
+
test_mode: True
|
7 |
+
batch_size: 1
|
8 |
+
|
9 |
+
net:
|
10 |
+
mlp_dim: [256, 512, 256, 128, 1]
|
11 |
+
res_layers: [2,3,4]
|
12 |
+
num_stack: 2
|
13 |
+
prior_type: "pifu" # icon/pamir/icon
|
14 |
+
use_filter: True
|
15 |
+
in_geo: (('image',3), ('normal_F',3), ('normal_B',3))
|
16 |
+
in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
|
17 |
+
gtype: 'HGPIFuNet'
|
18 |
+
norm_mlp: 'batch'
|
19 |
+
hourglass_dim: 12
|
20 |
+
|
21 |
+
|
22 |
+
# user defined
|
23 |
+
mcube_res: 512 # occupancy field resolution, higher --> more details
|
24 |
+
clean_mesh: False # if True, will remove floating pieces
|
lib/pymaf/configs/pymaf_config.yaml
ADDED
@@ -0,0 +1,47 @@
1 |
+
SOLVER:
|
2 |
+
MAX_ITER: 500000
|
3 |
+
TYPE: Adam
|
4 |
+
BASE_LR: 0.00005
|
5 |
+
GAMMA: 0.1
|
6 |
+
STEPS: [0]
|
7 |
+
EPOCHS: [0]
|
8 |
+
DEBUG: False
|
9 |
+
LOGDIR: ''
|
10 |
+
DEVICE: cuda
|
11 |
+
NUM_WORKERS: 8
|
12 |
+
SEED_VALUE: -1
|
13 |
+
LOSS:
|
14 |
+
KP_2D_W: 300.0
|
15 |
+
KP_3D_W: 300.0
|
16 |
+
SHAPE_W: 0.06
|
17 |
+
POSE_W: 60.0
|
18 |
+
VERT_W: 0.0
|
19 |
+
INDEX_WEIGHTS: 2.0
|
20 |
+
# Loss weights for surface parts. (24 Parts)
|
21 |
+
PART_WEIGHTS: 0.3
|
22 |
+
# Loss weights for UV regression.
|
23 |
+
POINT_REGRESSION_WEIGHTS: 0.5
|
24 |
+
TRAIN:
|
25 |
+
NUM_WORKERS: 8
|
26 |
+
BATCH_SIZE: 64
|
27 |
+
PIN_MEMORY: True
|
28 |
+
TEST:
|
29 |
+
BATCH_SIZE: 32
|
30 |
+
MODEL:
|
31 |
+
PyMAF:
|
32 |
+
BACKBONE: 'res50'
|
33 |
+
MLP_DIM: [256, 128, 64, 5]
|
34 |
+
N_ITER: 3
|
35 |
+
AUX_SUPV_ON: True
|
36 |
+
DP_HEATMAP_SIZE: 56
|
37 |
+
RES_MODEL:
|
38 |
+
DECONV_WITH_BIAS: False
|
39 |
+
NUM_DECONV_LAYERS: 3
|
40 |
+
NUM_DECONV_FILTERS:
|
41 |
+
- 256
|
42 |
+
- 256
|
43 |
+
- 256
|
44 |
+
NUM_DECONV_KERNELS:
|
45 |
+
- 4
|
46 |
+
- 4
|
47 |
+
- 4
|
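MODEL.PyMAF.MLP_DIM and N_ITER above fix the feature sizes that the PyMAF regressors (defined later in lib/pymaf/models/pymaf_net.py) expect. A quick bookkeeping sketch; the 21x21 sampling grid and the 431-vertex downsampled SMPL are taken from pymaf_net.py and the MAF extractor, the other values come from this YAML:

    mlp_out_dim = 5                      # MODEL.PyMAF.MLP_DIM[-1]
    grid_size = 21                       # points_grid in pymaf_net.py
    n_smpl_ds_verts = 431                # downsampled SMPL vertices used by the MAF extractor

    grid_feat_len = grid_size * grid_size * mlp_out_dim   # 2205: input dim of the first Regressor
    ma_feat_len = n_smpl_ds_verts * mlp_out_dim           # 2155: input dim of the later Regressors
    print(grid_feat_len, ma_feat_len)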
lib/pymaf/core/__init__.py
ADDED
File without changes
|
lib/pymaf/core/train_options.py
ADDED
@@ -0,0 +1,135 @@
1 |
+
import argparse
|
2 |
+
|
3 |
+
|
4 |
+
class TrainOptions():
|
5 |
+
def __init__(self):
|
6 |
+
self.parser = argparse.ArgumentParser()
|
7 |
+
|
8 |
+
gen = self.parser.add_argument_group('General')
|
9 |
+
gen.add_argument(
|
10 |
+
'--resume',
|
11 |
+
dest='resume',
|
12 |
+
default=False,
|
13 |
+
action='store_true',
|
14 |
+
help='Resume from checkpoint (use latest checkpoint by default)')
|
15 |
+
|
16 |
+
io = self.parser.add_argument_group('io')
|
17 |
+
io.add_argument('--log_dir',
|
18 |
+
default='logs',
|
19 |
+
help='Directory to store logs')
|
20 |
+
io.add_argument(
|
21 |
+
'--pretrained_checkpoint',
|
22 |
+
default=None,
|
23 |
+
help='Load a pretrained checkpoint at the beginning of training')
|
24 |
+
|
25 |
+
train = self.parser.add_argument_group('Training Options')
|
26 |
+
train.add_argument('--num_epochs',
|
27 |
+
type=int,
|
28 |
+
default=200,
|
29 |
+
help='Total number of training epochs')
|
30 |
+
train.add_argument('--regressor',
|
31 |
+
type=str,
|
32 |
+
choices=['hmr', 'pymaf_net'],
|
33 |
+
default='pymaf_net',
|
34 |
+
help='Name of the SMPL regressor.')
|
35 |
+
train.add_argument('--cfg_file',
|
36 |
+
type=str,
|
37 |
+
default='./configs/pymaf_config.yaml',
|
38 |
+
help='config file path for PyMAF.')
|
39 |
+
train.add_argument(
|
40 |
+
'--img_res',
|
41 |
+
type=int,
|
42 |
+
default=224,
|
43 |
+
help='Rescale bounding boxes to size [img_res, img_res] before feeding them in the network'
|
44 |
+
)
|
45 |
+
train.add_argument(
|
46 |
+
'--rot_factor',
|
47 |
+
type=float,
|
48 |
+
default=30,
|
49 |
+
help='Random rotation in the range [-rot_factor, rot_factor]')
|
50 |
+
train.add_argument(
|
51 |
+
'--noise_factor',
|
52 |
+
type=float,
|
53 |
+
default=0.4,
|
54 |
+
help='Randomly multiply pixel values with factor in the range [1-noise_factor, 1+noise_factor]'
|
55 |
+
)
|
56 |
+
train.add_argument(
|
57 |
+
'--scale_factor',
|
58 |
+
type=float,
|
59 |
+
default=0.25,
|
60 |
+
help='Rescale bounding boxes by a factor of [1-scale_factor,1+scale_factor]'
|
61 |
+
)
|
62 |
+
train.add_argument(
|
63 |
+
'--openpose_train_weight',
|
64 |
+
default=0.,
|
65 |
+
help='Weight for OpenPose keypoints during training')
|
66 |
+
train.add_argument('--gt_train_weight',
|
67 |
+
default=1.,
|
68 |
+
help='Weight for GT keypoints during training')
|
69 |
+
train.add_argument('--eval_dataset',
|
70 |
+
type=str,
|
71 |
+
default='h36m-p2-mosh',
|
72 |
+
help='Name of the evaluation dataset.')
|
73 |
+
train.add_argument('--single_dataset',
|
74 |
+
default=False,
|
75 |
+
action='store_true',
|
76 |
+
help='Use a single dataset')
|
77 |
+
train.add_argument('--single_dataname',
|
78 |
+
type=str,
|
79 |
+
default='h36m',
|
80 |
+
help='Name of the single dataset.')
|
81 |
+
train.add_argument('--eval_pve',
|
82 |
+
default=False,
|
83 |
+
action='store_true',
|
84 |
+
help='evaluate PVE')
|
85 |
+
train.add_argument('--overwrite',
|
86 |
+
default=False,
|
87 |
+
action='store_true',
|
88 |
+
help='overwrite the latest checkpoint')
|
89 |
+
|
90 |
+
train.add_argument('--distributed',
|
91 |
+
action='store_true',
|
92 |
+
help='Use distributed training')
|
93 |
+
train.add_argument('--dist_backend',
|
94 |
+
default='nccl',
|
95 |
+
type=str,
|
96 |
+
help='distributed backend')
|
97 |
+
train.add_argument('--dist_url',
|
98 |
+
default='tcp://127.0.0.1:10356',
|
99 |
+
type=str,
|
100 |
+
help='url used to set up distributed training')
|
101 |
+
train.add_argument('--world_size',
|
102 |
+
default=1,
|
103 |
+
type=int,
|
104 |
+
help='number of nodes for distributed training')
|
105 |
+
train.add_argument("--local_rank", default=0, type=int)
|
106 |
+
train.add_argument('--rank',
|
107 |
+
default=0,
|
108 |
+
type=int,
|
109 |
+
help='node rank for distributed training')
|
110 |
+
train.add_argument(
|
111 |
+
'--multiprocessing_distributed',
|
112 |
+
action='store_true',
|
113 |
+
help='Use multi-processing distributed training to launch '
|
114 |
+
'N processes per node, which has N GPUs. This is the '
|
115 |
+
'fastest way to use PyTorch for either single node or '
|
116 |
+
'multi node data parallel training')
|
117 |
+
|
118 |
+
misc = self.parser.add_argument_group('Misc Options')
|
119 |
+
misc.add_argument('--misc',
|
120 |
+
help="Modify config options using the command-line",
|
121 |
+
default=None,
|
122 |
+
nargs=argparse.REMAINDER)
|
123 |
+
return
|
124 |
+
|
125 |
+
def parse_args(self):
|
126 |
+
"""Parse input arguments."""
|
127 |
+
self.args = self.parser.parse_args()
|
128 |
+
self.save_dump()
|
129 |
+
return self.args
|
130 |
+
|
131 |
+
def save_dump(self):
|
132 |
+
"""Store all argument values to a json file.
|
133 |
+
The default location is logs/expname/args.json.
|
134 |
+
"""
|
135 |
+
pass
|
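A minimal usage sketch for TrainOptions; the import path assumes the file lands at lib/pymaf/core/train_options.py as in this commit:

    from lib.pymaf.core.train_options import TrainOptions

    if __name__ == "__main__":
        # e.g. `python train.py --regressor pymaf_net --num_epochs 100`
        options = TrainOptions().parse_args()
        print(options.regressor, options.cfg_file, options.num_epochs)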
lib/pymaf/core/base_trainer.py
ADDED
@@ -0,0 +1,107 @@
1 |
+
# This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/utils/base_trainer.py
|
2 |
+
from __future__ import division
|
3 |
+
import logging
|
4 |
+
from utils import CheckpointSaver
|
5 |
+
from tensorboardX import SummaryWriter
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
tqdm.monitor_interval = 0
|
11 |
+
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
class BaseTrainer(object):
|
17 |
+
"""Base class for Trainer objects.
|
18 |
+
Takes care of checkpointing/logging/resuming training.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, options):
|
22 |
+
self.options = options
|
23 |
+
if options.multiprocessing_distributed:
|
24 |
+
self.device = torch.device('cuda', options.gpu)
|
25 |
+
else:
|
26 |
+
self.device = torch.device(
|
27 |
+
'cuda' if torch.cuda.is_available() else 'cpu')
|
28 |
+
# override this function to define your model, optimizers etc.
|
29 |
+
self.saver = CheckpointSaver(save_dir=options.checkpoint_dir,
|
30 |
+
overwrite=options.overwrite)
|
31 |
+
if options.rank == 0:
|
32 |
+
self.summary_writer = SummaryWriter(self.options.summary_dir)
|
33 |
+
self.init_fn()
|
34 |
+
|
35 |
+
self.checkpoint = None
|
36 |
+
if options.resume and self.saver.exists_checkpoint():
|
37 |
+
self.checkpoint = self.saver.load_checkpoint(
|
38 |
+
self.models_dict, self.optimizers_dict)
|
39 |
+
|
40 |
+
if self.checkpoint is None:
|
41 |
+
self.epoch_count = 0
|
42 |
+
self.step_count = 0
|
43 |
+
else:
|
44 |
+
self.epoch_count = self.checkpoint['epoch']
|
45 |
+
self.step_count = self.checkpoint['total_step_count']
|
46 |
+
|
47 |
+
if self.checkpoint is not None:
|
48 |
+
self.checkpoint_batch_idx = self.checkpoint['batch_idx']
|
49 |
+
else:
|
50 |
+
self.checkpoint_batch_idx = 0
|
51 |
+
|
52 |
+
self.best_performance = float('inf')
|
53 |
+
|
54 |
+
def load_pretrained(self, checkpoint_file=None):
|
55 |
+
"""Load a pretrained checkpoint.
|
56 |
+
This is different from resuming training using --resume.
|
57 |
+
"""
|
58 |
+
if checkpoint_file is not None:
|
59 |
+
checkpoint = torch.load(checkpoint_file)
|
60 |
+
for model in self.models_dict:
|
61 |
+
if model in checkpoint:
|
62 |
+
self.models_dict[model].load_state_dict(checkpoint[model],
|
63 |
+
strict=True)
|
64 |
+
print(f'Checkpoint {model} loaded')
|
65 |
+
|
66 |
+
def move_dict_to_device(self, dict, device, tensor2float=False):
|
67 |
+
for k, v in dict.items():
|
68 |
+
if isinstance(v, torch.Tensor):
|
69 |
+
if tensor2float:
|
70 |
+
dict[k] = v.float().to(device)
|
71 |
+
else:
|
72 |
+
dict[k] = v.to(device)
|
73 |
+
|
74 |
+
# The following methods (with the possible exception of test) have to be implemented in the derived classes
|
75 |
+
def train(self, epoch):
|
76 |
+
raise NotImplementedError('You need to provide a train method')
|
77 |
+
|
78 |
+
def init_fn(self):
|
79 |
+
raise NotImplementedError('You need to provide an _init_fn method')
|
80 |
+
|
81 |
+
def train_step(self, input_batch):
|
82 |
+
raise NotImplementedError('You need to provide a _train_step method')
|
83 |
+
|
84 |
+
def train_summaries(self, input_batch):
|
85 |
+
raise NotImplementedError(
|
86 |
+
'You need to provide a _train_summaries method')
|
87 |
+
|
88 |
+
def visualize(self, input_batch):
|
89 |
+
raise NotImplementedError('You need to provide a visualize method')
|
90 |
+
|
91 |
+
def validate(self):
|
92 |
+
pass
|
93 |
+
|
94 |
+
def test(self):
|
95 |
+
pass
|
96 |
+
|
97 |
+
def evaluate(self):
|
98 |
+
pass
|
99 |
+
|
100 |
+
def fit(self):
|
101 |
+
# Run training for num_epochs epochs
|
102 |
+
for epoch in tqdm(range(self.epoch_count, self.options.num_epochs),
|
103 |
+
total=self.options.num_epochs,
|
104 |
+
initial=self.epoch_count):
|
105 |
+
self.epoch_count = epoch
|
106 |
+
self.train(epoch)
|
107 |
+
return
|
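BaseTrainer calls init_fn() from its constructor and expects it to populate models_dict and optimizers_dict so that checkpoints can be saved and resumed; train(epoch) is then driven by fit(). A structural sketch of a derived trainer follows. ToyTrainer is illustrative only, it does not exist in this repo, and constructing it still requires an options object providing checkpoint_dir, summary_dir, rank and so on:

    import torch
    from lib.pymaf.core.base_trainer import BaseTrainer

    class ToyTrainer(BaseTrainer):
        def init_fn(self):
            # called inside BaseTrainer.__init__; register everything the saver should track
            self.model = torch.nn.Linear(10, 1).to(self.device)
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
            self.models_dict = {"model": self.model}
            self.optimizers_dict = {"optimizer": self.optimizer}

        def train(self, epoch):
            # one epoch of (dummy) optimisation
            x = torch.randn(32, 10, device=self.device)
            loss = self.model(x).pow(2).mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.step_count += 1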
lib/pymaf/core/cfgs.py
ADDED
@@ -0,0 +1,100 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
+
# holder of all proprietary rights on this computer program.
|
5 |
+
# You can only use this computer program if you have closed
|
6 |
+
# a license agreement with MPG or you get the right to use the computer
|
7 |
+
# program from someone who is authorized to grant you that right.
|
8 |
+
# Any use of the computer program without a valid license is prohibited and
|
9 |
+
# liable to prosecution.
|
10 |
+
#
|
11 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
12 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
+
# for Intelligent Systems. All rights reserved.
|
14 |
+
#
|
15 |
+
# Contact: [email protected]
|
16 |
+
|
17 |
+
import os
|
18 |
+
import json
|
19 |
+
from yacs.config import CfgNode as CN
|
20 |
+
|
21 |
+
# Configuration variables
|
22 |
+
cfg = CN(new_allowed=True)
|
23 |
+
|
24 |
+
cfg.OUTPUT_DIR = 'results'
|
25 |
+
cfg.DEVICE = 'cuda'
|
26 |
+
cfg.DEBUG = False
|
27 |
+
cfg.LOGDIR = ''
|
28 |
+
cfg.VAL_VIS_BATCH_FREQ = 200
|
29 |
+
cfg.TRAIN_VIS_ITER_FERQ = 1000
|
30 |
+
cfg.SEED_VALUE = -1
|
31 |
+
|
32 |
+
cfg.TRAIN = CN(new_allowed=True)
|
33 |
+
|
34 |
+
cfg.LOSS = CN(new_allowed=True)
|
35 |
+
cfg.LOSS.KP_2D_W = 300.0
|
36 |
+
cfg.LOSS.KP_3D_W = 300.0
|
37 |
+
cfg.LOSS.SHAPE_W = 0.06
|
38 |
+
cfg.LOSS.POSE_W = 60.0
|
39 |
+
cfg.LOSS.VERT_W = 0.0
|
40 |
+
|
41 |
+
# Loss weights for dense correspondences
|
42 |
+
cfg.LOSS.INDEX_WEIGHTS = 2.0
|
43 |
+
# Loss weights for surface parts. (24 Parts)
|
44 |
+
cfg.LOSS.PART_WEIGHTS = 0.3
|
45 |
+
# Loss weights for UV regression.
|
46 |
+
cfg.LOSS.POINT_REGRESSION_WEIGHTS = 0.5
|
47 |
+
|
48 |
+
cfg.MODEL = CN(new_allowed=True)
|
49 |
+
|
50 |
+
cfg.MODEL.PyMAF = CN(new_allowed=True)
|
51 |
+
|
52 |
+
# switch
|
53 |
+
cfg.TRAIN.VAL_LOOP = True
|
54 |
+
|
55 |
+
cfg.TEST = CN(new_allowed=True)
|
56 |
+
|
57 |
+
|
58 |
+
def get_cfg_defaults():
|
59 |
+
"""Get a yacs CfgNode object with default values for my_project."""
|
60 |
+
# Return a clone so that the defaults will not be altered
|
61 |
+
# This is for the "local variable" use pattern
|
62 |
+
# return cfg.clone()
|
63 |
+
return cfg
|
64 |
+
|
65 |
+
|
66 |
+
def update_cfg(cfg_file):
|
67 |
+
# cfg = get_cfg_defaults()
|
68 |
+
cfg.merge_from_file(cfg_file)
|
69 |
+
# return cfg.clone()
|
70 |
+
return cfg
|
71 |
+
|
72 |
+
|
73 |
+
def parse_args(args):
|
74 |
+
cfg_file = args.cfg_file
|
75 |
+
if args.cfg_file is not None:
|
76 |
+
cfg = update_cfg(args.cfg_file)
|
77 |
+
else:
|
78 |
+
cfg = get_cfg_defaults()
|
79 |
+
|
80 |
+
# if args.misc is not None:
|
81 |
+
# cfg.merge_from_list(args.misc)
|
82 |
+
|
83 |
+
return cfg
|
84 |
+
|
85 |
+
|
86 |
+
def parse_args_extend(args):
|
87 |
+
if args.resume:
|
88 |
+
if not os.path.exists(args.log_dir):
|
89 |
+
raise ValueError(
|
90 |
+
'Experiment is set to resume mode, but the log directory does not exist.'
|
91 |
+
)
|
92 |
+
|
93 |
+
# load log's cfg
|
94 |
+
cfg_file = os.path.join(args.log_dir, 'cfg.yaml')
|
95 |
+
cfg = update_cfg(cfg_file)
|
96 |
+
|
97 |
+
if args.misc is not None:
|
98 |
+
cfg.merge_from_list(args.misc)
|
99 |
+
else:
|
100 |
+
parse_args(args)
|
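The module-level cfg above is shared, mutable state: update_cfg() merges a YAML file into it in place and returns the same node. A small sketch, with argparse.Namespace standing in for the parsed TrainOptions result:

    from argparse import Namespace
    from lib.pymaf.core.cfgs import parse_args, cfg

    args = Namespace(cfg_file="lib/pymaf/configs/pymaf_config.yaml", misc=None)
    parse_args(args)

    # the merged values are now visible through the shared `cfg` node
    print(cfg.MODEL.PyMAF.BACKBONE, cfg.MODEL.PyMAF.N_ITER, cfg.SOLVER.BASE_LR)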
lib/pymaf/core/constants.py
ADDED
@@ -0,0 +1,153 @@
1 |
+
# This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/constants.py
|
2 |
+
FOCAL_LENGTH = 5000.
|
3 |
+
IMG_RES = 224
|
4 |
+
|
5 |
+
# Mean and standard deviation for normalizing input image
|
6 |
+
IMG_NORM_MEAN = [0.485, 0.456, 0.406]
|
7 |
+
IMG_NORM_STD = [0.229, 0.224, 0.225]
|
8 |
+
"""
|
9 |
+
We create a superset of joints containing the OpenPose joints together with the ones that each dataset provides.
|
10 |
+
We keep a superset of 24 joints such that we include all joints from every dataset.
|
11 |
+
If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
|
12 |
+
The joints used here are the following:
|
13 |
+
"""
|
14 |
+
JOINT_NAMES = [
|
15 |
+
# 25 OpenPose joints (in the order provided by OpenPose)
|
16 |
+
'OP Nose',
|
17 |
+
'OP Neck',
|
18 |
+
'OP RShoulder',
|
19 |
+
'OP RElbow',
|
20 |
+
'OP RWrist',
|
21 |
+
'OP LShoulder',
|
22 |
+
'OP LElbow',
|
23 |
+
'OP LWrist',
|
24 |
+
'OP MidHip',
|
25 |
+
'OP RHip',
|
26 |
+
'OP RKnee',
|
27 |
+
'OP RAnkle',
|
28 |
+
'OP LHip',
|
29 |
+
'OP LKnee',
|
30 |
+
'OP LAnkle',
|
31 |
+
'OP REye',
|
32 |
+
'OP LEye',
|
33 |
+
'OP REar',
|
34 |
+
'OP LEar',
|
35 |
+
'OP LBigToe',
|
36 |
+
'OP LSmallToe',
|
37 |
+
'OP LHeel',
|
38 |
+
'OP RBigToe',
|
39 |
+
'OP RSmallToe',
|
40 |
+
'OP RHeel',
|
41 |
+
# 24 Ground Truth joints (superset of joints from different datasets)
|
42 |
+
'Right Ankle',
|
43 |
+
'Right Knee',
|
44 |
+
'Right Hip', # 2
|
45 |
+
'Left Hip',
|
46 |
+
'Left Knee', # 4
|
47 |
+
'Left Ankle',
|
48 |
+
'Right Wrist', # 6
|
49 |
+
'Right Elbow',
|
50 |
+
'Right Shoulder', # 8
|
51 |
+
'Left Shoulder',
|
52 |
+
'Left Elbow', # 10
|
53 |
+
'Left Wrist',
|
54 |
+
'Neck (LSP)', # 12
|
55 |
+
'Top of Head (LSP)',
|
56 |
+
'Pelvis (MPII)', # 14
|
57 |
+
'Thorax (MPII)',
|
58 |
+
'Spine (H36M)', # 16
|
59 |
+
'Jaw (H36M)',
|
60 |
+
'Head (H36M)', # 18
|
61 |
+
'Nose',
|
62 |
+
'Left Eye',
|
63 |
+
'Right Eye',
|
64 |
+
'Left Ear',
|
65 |
+
'Right Ear'
|
66 |
+
]
|
67 |
+
|
68 |
+
# Dict containing the joints in numerical order
|
69 |
+
JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}
|
70 |
+
|
71 |
+
# Map joints to SMPL joints
|
72 |
+
JOINT_MAP = {
|
73 |
+
'OP Nose': 24,
|
74 |
+
'OP Neck': 12,
|
75 |
+
'OP RShoulder': 17,
|
76 |
+
'OP RElbow': 19,
|
77 |
+
'OP RWrist': 21,
|
78 |
+
'OP LShoulder': 16,
|
79 |
+
'OP LElbow': 18,
|
80 |
+
'OP LWrist': 20,
|
81 |
+
'OP MidHip': 0,
|
82 |
+
'OP RHip': 2,
|
83 |
+
'OP RKnee': 5,
|
84 |
+
'OP RAnkle': 8,
|
85 |
+
'OP LHip': 1,
|
86 |
+
'OP LKnee': 4,
|
87 |
+
'OP LAnkle': 7,
|
88 |
+
'OP REye': 25,
|
89 |
+
'OP LEye': 26,
|
90 |
+
'OP REar': 27,
|
91 |
+
'OP LEar': 28,
|
92 |
+
'OP LBigToe': 29,
|
93 |
+
'OP LSmallToe': 30,
|
94 |
+
'OP LHeel': 31,
|
95 |
+
'OP RBigToe': 32,
|
96 |
+
'OP RSmallToe': 33,
|
97 |
+
'OP RHeel': 34,
|
98 |
+
'Right Ankle': 8,
|
99 |
+
'Right Knee': 5,
|
100 |
+
'Right Hip': 45,
|
101 |
+
'Left Hip': 46,
|
102 |
+
'Left Knee': 4,
|
103 |
+
'Left Ankle': 7,
|
104 |
+
'Right Wrist': 21,
|
105 |
+
'Right Elbow': 19,
|
106 |
+
'Right Shoulder': 17,
|
107 |
+
'Left Shoulder': 16,
|
108 |
+
'Left Elbow': 18,
|
109 |
+
'Left Wrist': 20,
|
110 |
+
'Neck (LSP)': 47,
|
111 |
+
'Top of Head (LSP)': 48,
|
112 |
+
'Pelvis (MPII)': 49,
|
113 |
+
'Thorax (MPII)': 50,
|
114 |
+
'Spine (H36M)': 51,
|
115 |
+
'Jaw (H36M)': 52,
|
116 |
+
'Head (H36M)': 53,
|
117 |
+
'Nose': 24,
|
118 |
+
'Left Eye': 26,
|
119 |
+
'Right Eye': 25,
|
120 |
+
'Left Ear': 28,
|
121 |
+
'Right Ear': 27
|
122 |
+
}
|
123 |
+
|
124 |
+
# Joint selectors
|
125 |
+
# Indices to get the 14 LSP joints from the 17 H36M joints
|
126 |
+
H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
|
127 |
+
H36M_TO_J14 = H36M_TO_J17[:14]
|
128 |
+
# Indices to get the 14 LSP joints from the ground truth joints
|
129 |
+
J24_TO_J17 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18, 14, 16, 17]
|
130 |
+
J24_TO_J14 = J24_TO_J17[:14]
|
131 |
+
J24_TO_J19 = J24_TO_J17[:14] + [19, 20, 21, 22, 23]
|
132 |
+
J24_TO_JCOCO = [19, 20, 21, 22, 23, 9, 8, 10, 7, 11, 6, 3, 2, 4, 1, 5, 0]
|
133 |
+
|
134 |
+
# Permutation of SMPL pose parameters when flipping the shape
|
135 |
+
SMPL_JOINTS_FLIP_PERM = [
|
136 |
+
0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21,
|
137 |
+
20, 23, 22
|
138 |
+
]
|
139 |
+
SMPL_POSE_FLIP_PERM = []
|
140 |
+
for i in SMPL_JOINTS_FLIP_PERM:
|
141 |
+
SMPL_POSE_FLIP_PERM.append(3 * i)
|
142 |
+
SMPL_POSE_FLIP_PERM.append(3 * i + 1)
|
143 |
+
SMPL_POSE_FLIP_PERM.append(3 * i + 2)
|
144 |
+
# Permutation indices for the 24 ground truth joints
|
145 |
+
J24_FLIP_PERM = [
|
146 |
+
5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21,
|
147 |
+
20, 23, 22
|
148 |
+
]
|
149 |
+
# Permutation indices for the full set of 49 joints
|
150 |
+
J49_FLIP_PERM = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]\
|
151 |
+
+ [25+i for i in J24_FLIP_PERM]
|
152 |
+
SMPL_J49_FLIP_PERM = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]\
|
153 |
+
+ [25+i for i in SMPL_JOINTS_FLIP_PERM]
|
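A quick sanity sketch for the lookup tables above: JOINT_IDS maps a joint name to its row in the 49-joint superset (25 OpenPose joints plus 24 ground-truth joints), while JOINT_MAP gives the matching index in the SMPL joint set returned by the model.

    from lib.pymaf.core import constants

    assert len(constants.JOINT_NAMES) == 49
    name = "OP RWrist"
    print(constants.JOINT_IDS[name])   # 4  -> position in the 49-joint superset
    print(constants.JOINT_MAP[name])   # 21 -> corresponding SMPL joint index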
lib/pymaf/core/fits_dict.py
ADDED
@@ -0,0 +1,133 @@
1 |
+
'''
|
2 |
+
This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py
|
3 |
+
'''
|
4 |
+
import os
|
5 |
+
import cv2
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from torchgeometry import angle_axis_to_rotation_matrix, rotation_matrix_to_angle_axis
|
9 |
+
|
10 |
+
from core import path_config, constants
|
11 |
+
|
12 |
+
import logging
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
class FitsDict():
|
18 |
+
""" Dictionary keeping track of the best fit per image in the training set """
|
19 |
+
|
20 |
+
def __init__(self, options, train_dataset):
|
21 |
+
self.options = options
|
22 |
+
self.train_dataset = train_dataset
|
23 |
+
self.fits_dict = {}
|
24 |
+
self.valid_fit_state = {}
|
25 |
+
# array used to flip SMPL pose parameters
|
26 |
+
self.flipped_parts = torch.tensor(constants.SMPL_POSE_FLIP_PERM,
|
27 |
+
dtype=torch.int64)
|
28 |
+
# Load dictionary state
|
29 |
+
for ds_name, ds in train_dataset.dataset_dict.items():
|
30 |
+
if ds_name in ['h36m']:
|
31 |
+
dict_file = os.path.join(path_config.FINAL_FITS_DIR,
|
32 |
+
ds_name + '.npy')
|
33 |
+
self.fits_dict[ds_name] = torch.from_numpy(np.load(dict_file))
|
34 |
+
self.valid_fit_state[ds_name] = torch.ones(len(
|
35 |
+
self.fits_dict[ds_name]),
|
36 |
+
dtype=torch.uint8)
|
37 |
+
else:
|
38 |
+
dict_file = os.path.join(path_config.FINAL_FITS_DIR,
|
39 |
+
ds_name + '.npz')
|
40 |
+
fits_dict = np.load(dict_file)
|
41 |
+
opt_pose = torch.from_numpy(fits_dict['pose'])
|
42 |
+
opt_betas = torch.from_numpy(fits_dict['betas'])
|
43 |
+
opt_valid_fit = torch.from_numpy(fits_dict['valid_fit']).to(
|
44 |
+
torch.uint8)
|
45 |
+
self.fits_dict[ds_name] = torch.cat([opt_pose, opt_betas],
|
46 |
+
dim=1)
|
47 |
+
self.valid_fit_state[ds_name] = opt_valid_fit
|
48 |
+
|
49 |
+
if not options.single_dataset:
|
50 |
+
for ds in train_dataset.datasets:
|
51 |
+
if ds.dataset not in ['h36m']:
|
52 |
+
ds.pose = self.fits_dict[ds.dataset][:, :72].numpy()
|
53 |
+
ds.betas = self.fits_dict[ds.dataset][:, 72:].numpy()
|
54 |
+
ds.has_smpl = self.valid_fit_state[ds.dataset].numpy()
|
55 |
+
|
56 |
+
def save(self):
|
57 |
+
""" Save dictionary state to disk """
|
58 |
+
for ds_name in self.train_dataset.dataset_dict.keys():
|
59 |
+
dict_file = os.path.join(self.options.checkpoint_dir,
|
60 |
+
ds_name + '_fits.npy')
|
61 |
+
np.save(dict_file, self.fits_dict[ds_name].cpu().numpy())
|
62 |
+
|
63 |
+
def __getitem__(self, x):
|
64 |
+
""" Retrieve dictionary entries """
|
65 |
+
dataset_name, ind, rot, is_flipped = x
|
66 |
+
batch_size = len(dataset_name)
|
67 |
+
pose = torch.zeros((batch_size, 72))
|
68 |
+
betas = torch.zeros((batch_size, 10))
|
69 |
+
for ds, i, n in zip(dataset_name, ind, range(batch_size)):
|
70 |
+
params = self.fits_dict[ds][i]
|
71 |
+
pose[n, :] = params[:72]
|
72 |
+
betas[n, :] = params[72:]
|
73 |
+
pose = pose.clone()
|
74 |
+
# Apply flipping and rotation
|
75 |
+
pose = self.flip_pose(self.rotate_pose(pose, rot), is_flipped)
|
76 |
+
betas = betas.clone()
|
77 |
+
return pose, betas
|
78 |
+
|
79 |
+
def get_vaild_state(self, dataset_name, ind):
|
80 |
+
batch_size = len(dataset_name)
|
81 |
+
valid_fit = torch.zeros(batch_size, dtype=torch.uint8)
|
82 |
+
for ds, i, n in zip(dataset_name, ind, range(batch_size)):
|
83 |
+
valid_fit[n] = self.valid_fit_state[ds][i]
|
84 |
+
valid_fit = valid_fit.clone()
|
85 |
+
return valid_fit
|
86 |
+
|
87 |
+
def __setitem__(self, x, val):
|
88 |
+
""" Update dictionary entries """
|
89 |
+
dataset_name, ind, rot, is_flipped, update = x
|
90 |
+
pose, betas = val
|
91 |
+
batch_size = len(dataset_name)
|
92 |
+
# Undo flipping and rotation
|
93 |
+
pose = self.rotate_pose(self.flip_pose(pose, is_flipped), -rot)
|
94 |
+
params = torch.cat((pose, betas), dim=-1).cpu()
|
95 |
+
for ds, i, n in zip(dataset_name, ind, range(batch_size)):
|
96 |
+
if update[n]:
|
97 |
+
self.fits_dict[ds][i] = params[n]
|
98 |
+
|
99 |
+
def flip_pose(self, pose, is_flipped):
|
100 |
+
"""flip SMPL pose parameters"""
|
101 |
+
is_flipped = is_flipped.byte()
|
102 |
+
pose_f = pose.clone()
|
103 |
+
pose_f[is_flipped, :] = pose[is_flipped][:, self.flipped_parts]
|
104 |
+
# we also negate the second and the third dimension of the axis-angle representation
|
105 |
+
pose_f[is_flipped, 1::3] *= -1
|
106 |
+
pose_f[is_flipped, 2::3] *= -1
|
107 |
+
return pose_f
|
108 |
+
|
109 |
+
def rotate_pose(self, pose, rot):
|
110 |
+
"""Rotate SMPL pose parameters by rot degrees"""
|
111 |
+
pose = pose.clone()
|
112 |
+
cos = torch.cos(-np.pi * rot / 180.)
|
113 |
+
sin = torch.sin(-np.pi * rot / 180.)
|
114 |
+
zeros = torch.zeros_like(cos)
|
115 |
+
r3 = torch.zeros(cos.shape[0], 1, 3, device=cos.device)
|
116 |
+
r3[:, 0, -1] = 1
|
117 |
+
R = torch.cat([
|
118 |
+
torch.stack([cos, -sin, zeros], dim=-1).unsqueeze(1),
|
119 |
+
torch.stack([sin, cos, zeros], dim=-1).unsqueeze(1), r3
|
120 |
+
],
|
121 |
+
dim=1)
|
122 |
+
global_pose = pose[:, :3]
|
123 |
+
global_pose_rotmat = angle_axis_to_rotation_matrix(global_pose)
|
124 |
+
global_pose_rotmat_3b3 = global_pose_rotmat[:, :3, :3]
|
125 |
+
global_pose_rotmat_3b3 = torch.matmul(R, global_pose_rotmat_3b3)
|
126 |
+
global_pose_rotmat[:, :3, :3] = global_pose_rotmat_3b3
|
127 |
+
global_pose_rotmat = global_pose_rotmat[:, :-1, :-1].cpu().numpy()
|
128 |
+
global_pose_np = np.zeros((global_pose.shape[0], 3))
|
129 |
+
for i in range(global_pose.shape[0]):
|
130 |
+
aa, _ = cv2.Rodrigues(global_pose_rotmat[i])
|
131 |
+
global_pose_np[i, :] = aa.squeeze()
|
132 |
+
pose[:, :3] = torch.from_numpy(global_pose_np).to(pose.device)
|
133 |
+
return pose
|
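In rotate_pose() the in-plane augmentation angle rot (in degrees) is undone by left-multiplying the global orientation with a rotation about the camera z-axis; written out, with \phi = -\pi \cdot \mathrm{rot} / 180:

    R_z(\phi) = \begin{pmatrix} \cos\phi & -\sin\phi & 0 \\ \sin\phi & \cos\phi & 0 \\ 0 & 0 & 1 \end{pmatrix},
    \qquad R_{\mathrm{global}} \leftarrow R_z(\phi)\, R_{\mathrm{global}}

The rotated matrix is then converted back to axis-angle with cv2.Rodrigues before being written into pose[:, :3].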
lib/pymaf/core/path_config.py
ADDED
@@ -0,0 +1,24 @@
1 |
+
"""
|
2 |
+
This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/path_config.py
|
3 |
+
path configuration
|
4 |
+
This file contains definitions of useful data structures and the paths
|
5 |
+
for the datasets and data files necessary to run the code.
|
6 |
+
Things you need to change: the *_ROOT variables that indicate the path to each dataset
|
7 |
+
"""
|
8 |
+
import os
|
9 |
+
from huggingface_hub import hf_hub_url, cached_download
|
10 |
+
|
11 |
+
# pymaf
|
12 |
+
pymaf_data_dir = hf_hub_url('Yuliang/PyMAF', '')
|
13 |
+
smpl_data_dir = hf_hub_url('Yuliang/SMPL', '')
|
14 |
+
SMPL_MODEL_DIR = os.path.join(smpl_data_dir, 'models/smpl')
|
15 |
+
|
16 |
+
SMPL_MEAN_PARAMS = cached_download(os.path.join(pymaf_data_dir, 'smpl_mean_params.npz'), use_auth_token=os.environ['ICON'])
|
17 |
+
MESH_DOWNSAMPLEING = cached_download(os.path.join(pymaf_data_dir, 'mesh_downsampling.npz'), use_auth_token=os.environ['ICON'])
|
18 |
+
CUBE_PARTS_FILE = cached_download(os.path.join(pymaf_data_dir, 'cube_parts.npy'), use_auth_token=os.environ['ICON'])
|
19 |
+
JOINT_REGRESSOR_TRAIN_EXTRA = cached_download(os.path.join(pymaf_data_dir, 'J_regressor_extra.npy'), use_auth_token=os.environ['ICON'])
|
20 |
+
JOINT_REGRESSOR_H36M = cached_download(os.path.join(pymaf_data_dir, 'J_regressor_h36m.npy'), use_auth_token=os.environ['ICON'])
|
21 |
+
VERTEX_TEXTURE_FILE = cached_download(os.path.join(pymaf_data_dir, 'vertex_texture.npy'), use_auth_token=os.environ['ICON'])
|
22 |
+
SMPL_MEAN_PARAMS = cached_download(os.path.join(pymaf_data_dir, 'smpl_mean_params.npz'), use_auth_token=os.environ['ICON'])
|
23 |
+
CHECKPOINT_FILE = cached_download(os.path.join(pymaf_data_dir, 'pretrained_model/PyMAF_model_checkpoint.pt'), use_auth_token=os.environ['ICON'])
|
24 |
+
|
lib/pymaf/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
1 |
+
from .hmr import hmr
|
2 |
+
from .pymaf_net import pymaf_net
|
3 |
+
from .smpl import SMPL
|
lib/pymaf/models/pymaf_net.py
ADDED
@@ -0,0 +1,362 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from lib.pymaf.utils.geometry import rot6d_to_rotmat, projection, rotation_matrix_to_angle_axis
|
6 |
+
from .maf_extractor import MAF_Extractor
|
7 |
+
from .smpl import SMPL, SMPL_MODEL_DIR, SMPL_MEAN_PARAMS, H36M_TO_J14
|
8 |
+
from .hmr import ResNet_Backbone
|
9 |
+
from .res_module import IUV_predict_layer
|
10 |
+
from lib.common.config import cfg
|
11 |
+
import logging
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
BN_MOMENTUM = 0.1
|
16 |
+
|
17 |
+
|
18 |
+
class Regressor(nn.Module):
|
19 |
+
def __init__(self, feat_dim, smpl_mean_params):
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
npose = 24 * 6
|
23 |
+
|
24 |
+
self.fc1 = nn.Linear(feat_dim + npose + 13, 1024)
|
25 |
+
self.drop1 = nn.Dropout()
|
26 |
+
self.fc2 = nn.Linear(1024, 1024)
|
27 |
+
self.drop2 = nn.Dropout()
|
28 |
+
self.decpose = nn.Linear(1024, npose)
|
29 |
+
self.decshape = nn.Linear(1024, 10)
|
30 |
+
self.deccam = nn.Linear(1024, 3)
|
31 |
+
nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
|
32 |
+
nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
|
33 |
+
nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
|
34 |
+
|
35 |
+
self.smpl = SMPL(SMPL_MODEL_DIR, batch_size=64, create_transl=False)
|
36 |
+
|
37 |
+
mean_params = np.load(smpl_mean_params)
|
38 |
+
init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
|
39 |
+
init_shape = torch.from_numpy(
|
40 |
+
mean_params['shape'][:].astype('float32')).unsqueeze(0)
|
41 |
+
init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
|
42 |
+
self.register_buffer('init_pose', init_pose)
|
43 |
+
self.register_buffer('init_shape', init_shape)
|
44 |
+
self.register_buffer('init_cam', init_cam)
|
45 |
+
|
46 |
+
def forward(self,
|
47 |
+
x,
|
48 |
+
init_pose=None,
|
49 |
+
init_shape=None,
|
50 |
+
init_cam=None,
|
51 |
+
n_iter=1,
|
52 |
+
J_regressor=None):
|
53 |
+
batch_size = x.shape[0]
|
54 |
+
|
55 |
+
if init_pose is None:
|
56 |
+
init_pose = self.init_pose.expand(batch_size, -1)
|
57 |
+
if init_shape is None:
|
58 |
+
init_shape = self.init_shape.expand(batch_size, -1)
|
59 |
+
if init_cam is None:
|
60 |
+
init_cam = self.init_cam.expand(batch_size, -1)
|
61 |
+
|
62 |
+
pred_pose = init_pose
|
63 |
+
pred_shape = init_shape
|
64 |
+
pred_cam = init_cam
|
65 |
+
for i in range(n_iter):
|
66 |
+
xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1)
|
67 |
+
xc = self.fc1(xc)
|
68 |
+
xc = self.drop1(xc)
|
69 |
+
xc = self.fc2(xc)
|
70 |
+
xc = self.drop2(xc)
|
71 |
+
pred_pose = self.decpose(xc) + pred_pose
|
72 |
+
pred_shape = self.decshape(xc) + pred_shape
|
73 |
+
pred_cam = self.deccam(xc) + pred_cam
|
74 |
+
|
75 |
+
pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)
|
76 |
+
|
77 |
+
pred_output = self.smpl(betas=pred_shape,
|
78 |
+
body_pose=pred_rotmat[:, 1:],
|
79 |
+
global_orient=pred_rotmat[:, 0].unsqueeze(1),
|
80 |
+
pose2rot=False)
|
81 |
+
|
82 |
+
pred_vertices = pred_output.vertices
|
83 |
+
pred_joints = pred_output.joints
|
84 |
+
pred_smpl_joints = pred_output.smpl_joints
|
85 |
+
pred_keypoints_2d = projection(pred_joints, pred_cam)
|
86 |
+
pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3,
|
87 |
+
3)).reshape(
|
88 |
+
-1, 72)
|
89 |
+
|
90 |
+
if J_regressor is not None:
|
91 |
+
pred_joints = torch.matmul(J_regressor, pred_vertices)
|
92 |
+
pred_pelvis = pred_joints[:, [0], :].clone()
|
93 |
+
pred_joints = pred_joints[:, H36M_TO_J14, :]
|
94 |
+
pred_joints = pred_joints - pred_pelvis
|
95 |
+
|
96 |
+
output = {
|
97 |
+
'theta': torch.cat([pred_cam, pred_shape, pose], dim=1),
|
98 |
+
'verts': pred_vertices,
|
99 |
+
'kp_2d': pred_keypoints_2d,
|
100 |
+
'kp_3d': pred_joints,
|
101 |
+
'smpl_kp_3d': pred_smpl_joints,
|
102 |
+
'rotmat': pred_rotmat,
|
103 |
+
'pred_cam': pred_cam,
|
104 |
+
'pred_shape': pred_shape,
|
105 |
+
'pred_pose': pred_pose,
|
106 |
+
}
|
107 |
+
return output
|
108 |
+
|
109 |
+
def forward_init(self,
|
110 |
+
x,
|
111 |
+
init_pose=None,
|
112 |
+
init_shape=None,
|
113 |
+
init_cam=None,
|
114 |
+
n_iter=1,
|
115 |
+
J_regressor=None):
|
116 |
+
batch_size = x.shape[0]
|
117 |
+
|
118 |
+
if init_pose is None:
|
119 |
+
init_pose = self.init_pose.expand(batch_size, -1)
|
120 |
+
if init_shape is None:
|
121 |
+
init_shape = self.init_shape.expand(batch_size, -1)
|
122 |
+
if init_cam is None:
|
123 |
+
init_cam = self.init_cam.expand(batch_size, -1)
|
124 |
+
|
125 |
+
pred_pose = init_pose
|
126 |
+
pred_shape = init_shape
|
127 |
+
pred_cam = init_cam
|
128 |
+
|
129 |
+
pred_rotmat = rot6d_to_rotmat(pred_pose.contiguous()).view(
|
130 |
+
batch_size, 24, 3, 3)
|
131 |
+
|
132 |
+
pred_output = self.smpl(betas=pred_shape,
|
133 |
+
body_pose=pred_rotmat[:, 1:],
|
134 |
+
global_orient=pred_rotmat[:, 0].unsqueeze(1),
|
135 |
+
pose2rot=False)
|
136 |
+
|
137 |
+
pred_vertices = pred_output.vertices
|
138 |
+
pred_joints = pred_output.joints
|
139 |
+
pred_smpl_joints = pred_output.smpl_joints
|
140 |
+
pred_keypoints_2d = projection(pred_joints, pred_cam)
|
141 |
+
pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3,
|
142 |
+
3)).reshape(
|
143 |
+
-1, 72)
|
144 |
+
|
145 |
+
if J_regressor is not None:
|
146 |
+
pred_joints = torch.matmul(J_regressor, pred_vertices)
|
147 |
+
pred_pelvis = pred_joints[:, [0], :].clone()
|
148 |
+
pred_joints = pred_joints[:, H36M_TO_J14, :]
|
149 |
+
pred_joints = pred_joints - pred_pelvis
|
150 |
+
|
151 |
+
output = {
|
152 |
+
'theta': torch.cat([pred_cam, pred_shape, pose], dim=1),
|
153 |
+
'verts': pred_vertices,
|
154 |
+
'kp_2d': pred_keypoints_2d,
|
155 |
+
            'kp_3d': pred_joints,
            'smpl_kp_3d': pred_smpl_joints,
            'rotmat': pred_rotmat,
            'pred_cam': pred_cam,
            'pred_shape': pred_shape,
            'pred_pose': pred_pose,
        }
        return output


class PyMAF(nn.Module):
    """ PyMAF based Deep Regressor for Human Mesh Recovery
    PyMAF: 3D Human Pose and Shape Regression with Pyramidal Mesh Alignment Feedback Loop, in ICCV, 2021
    """

    def __init__(self, smpl_mean_params=SMPL_MEAN_PARAMS, pretrained=True):
        super().__init__()
        self.feature_extractor = ResNet_Backbone(
            model=cfg.MODEL.PyMAF.BACKBONE, pretrained=pretrained)

        # deconv layers
        self.inplanes = self.feature_extractor.inplanes
        self.deconv_with_bias = cfg.RES_MODEL.DECONV_WITH_BIAS
        self.deconv_layers = self._make_deconv_layer(
            cfg.RES_MODEL.NUM_DECONV_LAYERS,
            cfg.RES_MODEL.NUM_DECONV_FILTERS,
            cfg.RES_MODEL.NUM_DECONV_KERNELS,
        )

        self.maf_extractor = nn.ModuleList()
        for _ in range(cfg.MODEL.PyMAF.N_ITER):
            self.maf_extractor.append(MAF_Extractor())
        ma_feat_len = self.maf_extractor[-1].Dmap.shape[0] * cfg.MODEL.PyMAF.MLP_DIM[-1]

        grid_size = 21
        xv, yv = torch.meshgrid(
            [torch.linspace(-1, 1, grid_size),
             torch.linspace(-1, 1, grid_size)])
        points_grid = torch.stack([xv.reshape(-1), yv.reshape(-1)]).unsqueeze(0)
        self.register_buffer('points_grid', points_grid)
        grid_feat_len = grid_size * grid_size * cfg.MODEL.PyMAF.MLP_DIM[-1]

        self.regressor = nn.ModuleList()
        for i in range(cfg.MODEL.PyMAF.N_ITER):
            if i == 0:
                ref_infeat_dim = grid_feat_len
            else:
                ref_infeat_dim = ma_feat_len
            self.regressor.append(
                Regressor(feat_dim=ref_infeat_dim, smpl_mean_params=smpl_mean_params))

        dp_feat_dim = 256
        self.with_uv = cfg.LOSS.POINT_REGRESSION_WEIGHTS > 0
        if cfg.MODEL.PyMAF.AUX_SUPV_ON:
            self.dp_head = IUV_predict_layer(feat_dim=dp_feat_dim)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """
        Deconv_layer used in Simple Baselines:
        Xiao et al. Simple Baselines for Human Pose Estimation and Tracking
        https://github.com/microsoft/human-pose-estimation.pytorch
        """
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'

        def _get_deconv_cfg(deconv_kernel, index):
            if deconv_kernel == 4:
                padding = 1
                output_padding = 0
            elif deconv_kernel == 3:
                padding = 1
                output_padding = 1
            elif deconv_kernel == 2:
                padding = 0
                output_padding = 0

            return deconv_kernel, padding, output_padding

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = _get_deconv_cfg(num_kernels[i], i)

            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(in_channels=self.inplanes, out_channels=planes,
                                   kernel_size=kernel, stride=2, padding=padding,
                                   output_padding=output_padding,
                                   bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes

        return nn.Sequential(*layers)

    def forward(self, x, J_regressor=None):

        batch_size = x.shape[0]

        # spatial features and global features
        s_feat, g_feat = self.feature_extractor(x)

        assert cfg.MODEL.PyMAF.N_ITER >= 0 and cfg.MODEL.PyMAF.N_ITER <= 3
        if cfg.MODEL.PyMAF.N_ITER == 1:
            deconv_blocks = [self.deconv_layers]
        elif cfg.MODEL.PyMAF.N_ITER == 2:
            deconv_blocks = [self.deconv_layers[0:6], self.deconv_layers[6:9]]
        elif cfg.MODEL.PyMAF.N_ITER == 3:
            deconv_blocks = [
                self.deconv_layers[0:3], self.deconv_layers[3:6],
                self.deconv_layers[6:9]
            ]

        out_list = {}

        # initial parameters
        # TODO: remove the initial mesh generation during forward to reduce runtime,
        # by generating the initial mesh beforehand: smpl_output = self.init_smpl
        smpl_output = self.regressor[0].forward_init(g_feat, J_regressor=J_regressor)

        out_list['smpl_out'] = [smpl_output]
        out_list['dp_out'] = []

        # for visualization
        vis_feat_list = [s_feat.detach()]

        # parameter predictions
        for rf_i in range(cfg.MODEL.PyMAF.N_ITER):
            pred_cam = smpl_output['pred_cam']
            pred_shape = smpl_output['pred_shape']
            pred_pose = smpl_output['pred_pose']

            pred_cam = pred_cam.detach()
            pred_shape = pred_shape.detach()
            pred_pose = pred_pose.detach()

            s_feat_i = deconv_blocks[rf_i](s_feat)
            s_feat = s_feat_i
            vis_feat_list.append(s_feat_i.detach())

            self.maf_extractor[rf_i].im_feat = s_feat_i
            self.maf_extractor[rf_i].cam = pred_cam

            if rf_i == 0:
                sample_points = torch.transpose(
                    self.points_grid.expand(batch_size, -1, -1), 1, 2)
                ref_feature = self.maf_extractor[rf_i].sampling(sample_points)
            else:
                pred_smpl_verts = smpl_output['verts'].detach()
                # TODO: use a more sparse SMPL implementation (with 431 vertices) for acceleration
                pred_smpl_verts_ds = torch.matmul(
                    self.maf_extractor[rf_i].Dmap.unsqueeze(0),
                    pred_smpl_verts)  # [B, 431, 3]
                ref_feature = self.maf_extractor[rf_i](pred_smpl_verts_ds)  # [B, 431 * n_feat]

            smpl_output = self.regressor[rf_i](ref_feature, pred_pose, pred_shape,
                                               pred_cam, n_iter=1, J_regressor=J_regressor)
            out_list['smpl_out'].append(smpl_output)

        if self.training and cfg.MODEL.PyMAF.AUX_SUPV_ON:
            iuv_out_dict = self.dp_head(s_feat)
            out_list['dp_out'].append(iuv_out_dict)

        return out_list


def pymaf_net(smpl_mean_params, pretrained=True):
    """ Constructs a PyMAF model with ResNet50 backbone.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = PyMAF(smpl_mean_params, pretrained)
    return model
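A minimal usage sketch of the regressor above (an editorial illustration, not part of the commit): it assumes the PyMAF config and the SMPL/mean-parameter assets referenced by the imports are already in place, and that the input is a 224x224 crop as elsewhere in the repo; the 'smpl_out' list and 'verts' key are the ones produced by the forward pass shown above.

# Hypothetical usage sketch -- assumes cfg and the SMPL assets are set up.
import torch

model = pymaf_net(SMPL_MEAN_PARAMS, pretrained=True).eval()
crop = torch.randn(1, 3, 224, 224)   # one normalized RGB crop
with torch.no_grad():
    out = model(crop)
last = out['smpl_out'][-1]           # prediction after the final alignment-feedback iteration
verts = last['verts']                # SMPL vertices, [1, 6890, 3]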
lib/pymaf/models/smpl.py
ADDED
@@ -0,0 +1,92 @@
# This script is borrowed from https://github.com/nkolot/SPIN/blob/master/models/smpl.py

import torch
import numpy as np
from lib.smplx import SMPL as _SMPL
from lib.smplx.body_models import ModelOutput
from lib.smplx.lbs import vertices2joints
from collections import namedtuple

from lib.pymaf.core import path_config, constants

SMPL_MEAN_PARAMS = path_config.SMPL_MEAN_PARAMS
SMPL_MODEL_DIR = path_config.SMPL_MODEL_DIR

# Indices to get the 14 LSP joints from the 17 H36M joints
H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
H36M_TO_J14 = H36M_TO_J17[:14]


class SMPL(_SMPL):
    """ Extension of the official SMPL implementation to support more joints """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
        J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra',
            torch.tensor(J_regressor_extra, dtype=torch.float32))
        self.joint_map = torch.tensor(joints, dtype=torch.long)
        self.ModelOutput = namedtuple(
            'ModelOutput_', ModelOutput._fields + (
                'smpl_joints',
                'joints_J19',
            ))
        self.ModelOutput.__new__.__defaults__ = (None, ) * len(
            self.ModelOutput._fields)

    def forward(self, *args, **kwargs):
        kwargs['get_skin'] = True
        smpl_output = super().forward(*args, **kwargs)
        extra_joints = vertices2joints(self.J_regressor_extra,
                                       smpl_output.vertices)
        # smpl_output.joints: [B, 45, 3]  extra_joints: [B, 9, 3]
        vertices = smpl_output.vertices
        joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
        smpl_joints = smpl_output.joints[:, :24]
        joints = joints[:, self.joint_map, :]  # [B, 49, 3]
        joints_J24 = joints[:, -24:, :]
        joints_J19 = joints_J24[:, constants.J24_TO_J19, :]
        output = self.ModelOutput(vertices=vertices,
                                  global_orient=smpl_output.global_orient,
                                  body_pose=smpl_output.body_pose,
                                  joints=joints,
                                  joints_J19=joints_J19,
                                  smpl_joints=smpl_joints,
                                  betas=smpl_output.betas,
                                  full_pose=smpl_output.full_pose)
        return output


def get_smpl_faces():
    smpl = SMPL(SMPL_MODEL_DIR, batch_size=1, create_transl=False)
    return smpl.faces


def get_part_joints(smpl_joints):
    batch_size = smpl_joints.shape[0]

    # part_joints = torch.zeros().to(smpl_joints.device)

    one_seg_pairs = [(0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14),
                     (12, 15), (13, 16), (14, 17)]
    two_seg_pairs = [(1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19),
                     (18, 20), (19, 21)]

    one_seg_pairs.extend(two_seg_pairs)

    single_joints = [(10), (11), (15), (22), (23)]

    part_joints = []

    for j_p in one_seg_pairs:
        new_joint = torch.mean(smpl_joints[:, j_p], dim=1, keepdim=True)
        part_joints.append(new_joint)

    for j_p in single_joints:
        part_joints.append(smpl_joints[:, j_p:j_p + 1])

    part_joints = torch.cat(part_joints, dim=1)

    return part_joints
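A tiny self-contained illustration of the joint remapping tables defined above (editorial sketch; no model files needed, the index list is copied from the file):

# Mapping a dummy batch of 17 H36M joints down to the 14 LSP joints.
import torch

H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
H36M_TO_J14 = H36M_TO_J17[:14]

h36m = torch.randn(2, 17, 3)        # [B, 17, 3] H36M joints
lsp14 = h36m[:, H36M_TO_J14, :]     # reordered/truncated to [B, 14, 3]
print(lsp14.shape)                  # torch.Size([2, 14, 3])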
lib/pymaf/models/hmr.py
ADDED
@@ -0,0 +1,303 @@
# This script is borrowed from https://github.com/nkolot/SPIN/blob/master/models/hmr.py

import torch
import torch.nn as nn
import torchvision.models.resnet as resnet
import numpy as np
import math
from lib.pymaf.utils.geometry import rot6d_to_rotmat

import logging

logger = logging.getLogger(__name__)

BN_MOMENTUM = 0.1


class Bottleneck(nn.Module):
    """ Redefinition of Bottleneck residual block
        Adapted from the official PyTorch implementation
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet_Backbone(nn.Module):
    """ Feature Extractor with ResNet backbone """

    def __init__(self, model='res50', pretrained=True):
        if model == 'res50':
            block, layers = Bottleneck, [3, 4, 6, 3]
        else:
            pass  # TODO

        self.inplanes = 64
        super().__init__()
        npose = 24 * 6
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)

        if pretrained:
            resnet_imagenet = resnet.resnet50(pretrained=True)
            self.load_state_dict(resnet_imagenet.state_dict(), strict=False)
            logger.info('loaded resnet50 imagenet pretrained model')

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'

        def _get_deconv_cfg(deconv_kernel, index):
            if deconv_kernel == 4:
                padding = 1
                output_padding = 0
            elif deconv_kernel == 3:
                padding = 1
                output_padding = 1
            elif deconv_kernel == 2:
                padding = 0
                output_padding = 0

            return deconv_kernel, padding, output_padding

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = _get_deconv_cfg(num_kernels[i], i)

            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(in_channels=self.inplanes, out_channels=planes,
                                   kernel_size=kernel, stride=2, padding=padding,
                                   output_padding=output_padding,
                                   bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes

        return nn.Sequential(*layers)

    def forward(self, x):

        batch_size = x.shape[0]

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        xf = self.avgpool(x4)
        xf = xf.view(xf.size(0), -1)

        x_featmap = x4

        return x_featmap, xf


class HMR(nn.Module):
    """ SMPL Iterative Regressor with ResNet50 backbone """

    def __init__(self, block, layers, smpl_mean_params):
        self.inplanes = 64
        super().__init__()
        npose = 24 * 6
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc1 = nn.Linear(512 * block.expansion + npose + 13, 1024)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(1024, npose)
        self.decshape = nn.Linear(1024, 10)
        self.deccam = nn.Linear(1024, 3)
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        init_shape = torch.from_numpy(
            mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n_iter=3):

        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        xf = self.avgpool(x4)
        xf = xf.view(xf.size(0), -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for i in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        return pred_rotmat, pred_shape, pred_cam


def hmr(smpl_mean_params, pretrained=True, **kwargs):
    """ Constructs an HMR model with ResNet50 backbone.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs)
    if pretrained:
        resnet_imagenet = resnet.resnet50(pretrained=True)
        model.load_state_dict(resnet_imagenet.state_dict(), strict=False)
    return model
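A short editorial sketch of what the backbone returns (not part of the commit; the module path follows the file header above, and pretrained=False is used so no torchvision download is triggered):

import torch
from lib.pymaf.models.hmr import ResNet_Backbone

backbone = ResNet_Backbone(model='res50', pretrained=False)
s_feat, g_feat = backbone(torch.randn(1, 3, 224, 224))
print(s_feat.shape)   # torch.Size([1, 2048, 7, 7])  spatial map fed to the deconv layers
print(g_feat.shape)   # torch.Size([1, 2048])        global vector used for the initial estimate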
lib/pymaf/models/maf_extractor.py
ADDED
@@ -0,0 +1,135 @@
# This script is borrowed and extended from https://github.com/shunsukesaito/PIFu/blob/master/lib/model/SurfaceClassifier.py

from packaging import version
import torch
import scipy
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from lib.common.config import cfg
from lib.pymaf.utils.geometry import projection
from lib.pymaf.core.path_config import MESH_DOWNSAMPLEING

import logging

logger = logging.getLogger(__name__)


class MAF_Extractor(nn.Module):
    ''' Mesh-aligned Feature Extractor
    As discussed in the paper, we extract mesh-aligned features based on the 2D projection of the mesh vertices.
    The features extracted from spatial feature maps go through an MLP for dimension reduction.
    '''

    def __init__(self, device=torch.device('cuda')):
        super().__init__()

        self.device = device
        self.filters = []
        self.num_views = 1
        filter_channels = cfg.MODEL.PyMAF.MLP_DIM
        self.last_op = nn.ReLU(True)

        for l in range(0, len(filter_channels) - 1):
            if 0 != l:
                self.filters.append(
                    nn.Conv1d(filter_channels[l] + filter_channels[0],
                              filter_channels[l + 1], 1))
            else:
                self.filters.append(
                    nn.Conv1d(filter_channels[l], filter_channels[l + 1], 1))

            self.add_module("conv%d" % l, self.filters[l])

        self.im_feat = None
        self.cam = None

        # downsample SMPL mesh and assign part labels
        # from https://github.com/nkolot/GraphCMR/blob/master/data/mesh_downsampling.npz
        smpl_mesh_graph = np.load(MESH_DOWNSAMPLEING,
                                  allow_pickle=True,
                                  encoding='latin1')

        A = smpl_mesh_graph['A']
        U = smpl_mesh_graph['U']
        D = smpl_mesh_graph['D']  # shape: (2,)

        # downsampling
        ptD = []
        for i in range(len(D)):
            d = scipy.sparse.coo_matrix(D[i])
            i = torch.LongTensor(np.array([d.row, d.col]))
            v = torch.FloatTensor(d.data)
            ptD.append(torch.sparse.FloatTensor(i, v, d.shape))

        # downsampling mapping from 6890 points to 431 points
        # ptD[0].to_dense() - Size: [1723, 6890]
        # ptD[1].to_dense() - Size: [431, 1723]
        Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense())  # 6890 -> 431
        self.register_buffer('Dmap', Dmap)

    def reduce_dim(self, feature):
        '''
        Dimension reduction by multi-layer perceptrons
        :param feature: list of [B, C_s, N] point-wise features before dimension reduction
        :return: [B, C_p x N] concatenation of point-wise features after dimension reduction
        '''
        y = feature
        tmpy = feature
        for i, f in enumerate(self.filters):
            y = self._modules['conv' +
                              str(i)](y if i == 0 else torch.cat([y, tmpy], 1))
            if i != len(self.filters) - 1:
                y = F.leaky_relu(y)
            if self.num_views > 1 and i == len(self.filters) // 2:
                y = y.view(-1, self.num_views, y.shape[1], y.shape[2]).mean(dim=1)
                tmpy = feature.view(-1, self.num_views, feature.shape[1],
                                    feature.shape[2]).mean(dim=1)

        y = self.last_op(y)

        y = y.view(y.shape[0], -1)
        return y

    def sampling(self, points, im_feat=None, z_feat=None):
        '''
        Given 2D points, sample the point-wise features for each point;
        the dimension of point-wise features is reduced from C_s to C_p by the MLP.
        Image features should be pre-computed before this call.
        :param points: [B, N, 2] image coordinates of points
        :im_feat: [B, C_s, H_s, W_s] spatial feature maps
        :return: [B, C_p x N] concatenation of point-wise features after dimension reduction
        '''
        if im_feat is None:
            im_feat = self.im_feat

        batch_size = im_feat.shape[0]

        if version.parse(torch.__version__) >= version.parse('1.3.0'):
            # Default grid_sample behavior has changed to align_corners=False since 1.3.0.
            point_feat = torch.nn.functional.grid_sample(
                im_feat, points.unsqueeze(2), align_corners=True)[..., 0]
        else:
            point_feat = torch.nn.functional.grid_sample(
                im_feat, points.unsqueeze(2))[..., 0]

        mesh_align_feat = self.reduce_dim(point_feat)
        return mesh_align_feat

    def forward(self, p, s_feat=None, cam=None, **kwargs):
        ''' Returns mesh-aligned features for the 3D mesh points.
        Args:
            p (tensor): [B, N_m, 3] mesh vertices
            s_feat (tensor): [B, C_s, H_s, W_s] spatial feature maps
            cam (tensor): [B, 3] camera
        Return:
            mesh_align_feat (tensor): [B, C_p x N_m] mesh-aligned features
        '''
        if cam is None:
            cam = self.cam
        p_proj_2d = projection(p, cam, retain_z=False)
        mesh_align_feat = self.sampling(p_proj_2d, s_feat)
        return mesh_align_feat
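A stand-alone editorial illustration of the sampling step above (constructing MAF_Extractor itself needs cfg and the mesh_downsampling.npz asset, so only the grid_sample core is shown; all tensors are dummies):

import torch
import torch.nn.functional as F

im_feat = torch.randn(1, 256, 56, 56)    # a spatial feature map [B, C_s, H_s, W_s]
points = torch.rand(1, 431, 2) * 2 - 1   # projected vertices in normalized [-1, 1] coords
point_feat = F.grid_sample(im_feat, points.unsqueeze(2), align_corners=True)[..., 0]
print(point_feat.shape)                  # torch.Size([1, 256, 431]) per-point features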
lib/pymaf/models/res_module.py
ADDED
@@ -0,0 +1,385 @@
# code brought in part from https://github.com/microsoft/human-pose-estimation.pytorch/blob/master/lib/models/pose_resnet.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import os
from lib.pymaf.core.cfgs import cfg

import logging

logger = logging.getLogger(__name__)

BN_MOMENTUM = 0.1


def conv3x3(in_planes, out_planes, stride=1, bias=False, groups=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes * groups, out_planes * groups, kernel_size=3,
                     stride=stride, padding=1, bias=bias, groups=groups)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride, groups=groups)
        self.bn1 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, groups=groups)
        self.bn2 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes * groups, planes * groups,
                               kernel_size=1, bias=False, groups=groups)
        self.bn1 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes * groups, planes * groups, kernel_size=3,
                               stride=stride, padding=1, bias=False, groups=groups)
        self.bn2 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes * groups, planes * self.expansion * groups,
                               kernel_size=1, bias=False, groups=groups)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion * groups,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


resnet_spec = {
    18: (BasicBlock, [2, 2, 2, 2]),
    34: (BasicBlock, [3, 4, 6, 3]),
    50: (Bottleneck, [3, 4, 6, 3]),
    101: (Bottleneck, [3, 4, 23, 3]),
    152: (Bottleneck, [3, 8, 36, 3])
}


class IUV_predict_layer(nn.Module):
    def __init__(self, feat_dim=256, final_cov_k=3, part_out_dim=25, with_uv=True):
        super().__init__()

        self.with_uv = with_uv
        if self.with_uv:
            self.predict_u = nn.Conv2d(in_channels=feat_dim, out_channels=25,
                                       kernel_size=final_cov_k, stride=1,
                                       padding=1 if final_cov_k == 3 else 0)

            self.predict_v = nn.Conv2d(in_channels=feat_dim, out_channels=25,
                                       kernel_size=final_cov_k, stride=1,
                                       padding=1 if final_cov_k == 3 else 0)

        self.predict_ann_index = nn.Conv2d(in_channels=feat_dim, out_channels=15,
                                           kernel_size=final_cov_k, stride=1,
                                           padding=1 if final_cov_k == 3 else 0)

        self.predict_uv_index = nn.Conv2d(in_channels=feat_dim, out_channels=25,
                                          kernel_size=final_cov_k, stride=1,
                                          padding=1 if final_cov_k == 3 else 0)

        self.inplanes = feat_dim

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        return_dict = {}

        predict_uv_index = self.predict_uv_index(x)
        predict_ann_index = self.predict_ann_index(x)

        return_dict['predict_uv_index'] = predict_uv_index
        return_dict['predict_ann_index'] = predict_ann_index

        if self.with_uv:
            predict_u = self.predict_u(x)
            predict_v = self.predict_v(x)
            return_dict['predict_u'] = predict_u
            return_dict['predict_v'] = predict_v
        else:
            return_dict['predict_u'] = None
            return_dict['predict_v'] = None
            # return_dict['predict_u'] = torch.zeros(predict_uv_index.shape).to(predict_uv_index.device)
            # return_dict['predict_v'] = torch.zeros(predict_uv_index.shape).to(predict_uv_index.device)

        return return_dict


class SmplResNet(nn.Module):
    def __init__(self, resnet_nums, in_channels=3, num_classes=229, last_stride=2,
                 n_extra_feat=0, truncate=0, **kwargs):
        super().__init__()

        self.inplanes = 64
        self.truncate = truncate
        # extra = cfg.MODEL.EXTRA
        # self.deconv_with_bias = extra.DECONV_WITH_BIAS
        block, layers = resnet_spec[resnet_nums]

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2],
                                       stride=2) if truncate < 2 else None
        self.layer4 = self._make_layer(block, 512, layers[3],
                                       stride=last_stride) if truncate < 1 else None

        self.avg_pooling = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        if num_classes > 0:
            self.final_layer = nn.Linear(512 * block.expansion, num_classes)
            nn.init.xavier_uniform_(self.final_layer.weight, gain=0.01)

        self.n_extra_feat = n_extra_feat
        if n_extra_feat > 0:
            self.trans_conv = nn.Sequential(
                nn.Conv2d(n_extra_feat + 512 * block.expansion,
                          512 * block.expansion, kernel_size=1, bias=False),
                nn.BatchNorm2d(512 * block.expansion, momentum=BN_MOMENTUM),
                nn.ReLU(True))

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, infeat=None):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2) if self.truncate < 2 else x2
        x4 = self.layer4(x3) if self.truncate < 1 else x3

        if infeat is not None:
            x4 = self.trans_conv(torch.cat([infeat, x4], 1))

        if self.num_classes > 0:
            xp = self.avg_pooling(x4)
            cls = self.final_layer(xp.view(xp.size(0), -1))
            if not cfg.DANET.USE_MEAN_PARA:
                # for non-negative scale
                scale = F.relu(cls[:, 0]).unsqueeze(1)
                cls = torch.cat((scale, cls[:, 1:]), dim=1)
        else:
            cls = None

        return cls, {'x4': x4}

    def init_weights(self, pretrained=''):
        if os.path.isfile(pretrained):
            logger.info('=> loading pretrained model {}'.format(pretrained))
            # self.load_state_dict(pretrained_state_dict, strict=False)
            checkpoint = torch.load(pretrained)
            if isinstance(checkpoint, OrderedDict):
                # state_dict = checkpoint
                state_dict_old = self.state_dict()
                for key in state_dict_old.keys():
                    if key in checkpoint.keys():
                        if state_dict_old[key].shape != checkpoint[key].shape:
                            del checkpoint[key]
                state_dict = checkpoint
            elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                state_dict_old = checkpoint['state_dict']
                state_dict = OrderedDict()
                # delete 'module.' because it is saved from DataParallel module
                for key in state_dict_old.keys():
                    if key.startswith('module.'):
                        # state_dict[key[7:]] = state_dict[key]
                        # state_dict.pop(key)
                        state_dict[key[7:]] = state_dict_old[key]
                    else:
                        state_dict[key] = state_dict_old[key]
            else:
                raise RuntimeError(
                    'No state_dict found in checkpoint file {}'.format(pretrained))
            self.load_state_dict(state_dict, strict=False)
        else:
            logger.error('=> imagenet pretrained model does not exist')
            logger.error('=> please download it first')
            raise ValueError('imagenet pretrained model does not exist')


class LimbResLayers(nn.Module):
    def __init__(self, resnet_nums, inplanes, outplanes=None, groups=1, **kwargs):
        super().__init__()

        self.inplanes = inplanes
        block, layers = resnet_spec[resnet_nums]
        self.outplanes = 512 if outplanes is None else outplanes
        self.layer4 = self._make_layer(block, self.outplanes, layers[3],
                                       stride=2, groups=groups)

        self.avg_pooling = nn.AdaptiveAvgPool2d(1)

    def _make_layer(self, block, planes, blocks, stride=1, groups=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes * groups, planes * block.expansion * groups,
                          kernel_size=1, stride=stride, bias=False, groups=groups),
                nn.BatchNorm2d(planes * block.expansion * groups,
                               momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, groups=groups))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=groups))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer4(x)
        x = self.avg_pooling(x)

        return x
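A small editorial sketch of the auxiliary IUV head's outputs (assumes the repo's cfg module can be imported; all tensors are dummies and the 56x56 feature-map size is illustrative):

import torch
from lib.pymaf.models.res_module import IUV_predict_layer

head = IUV_predict_layer(feat_dim=256)
maps = head(torch.randn(1, 256, 56, 56))
print(maps['predict_uv_index'].shape)    # torch.Size([1, 25, 56, 56]) UV part logits
print(maps['predict_ann_index'].shape)   # torch.Size([1, 15, 56, 56]) annotation part logits
print(maps['predict_u'].shape)           # torch.Size([1, 25, 56, 56]) U map (with_uv=True)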
lib/pymaf/utils/__init__.py
ADDED
File without changes
lib/pymaf/utils/geometry.py
ADDED
@@ -0,0 +1,435 @@
import torch
|
2 |
+
import numpy as np
|
3 |
+
from torch.nn import functional as F
|
4 |
+
"""
|
5 |
+
Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula
|
6 |
+
Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
|
7 |
+
"""
|
8 |
+
|
9 |
+
|
10 |
+
def batch_rodrigues(theta):
|
11 |
+
"""Convert axis-angle representation to rotation matrix.
|
12 |
+
Args:
|
13 |
+
theta: size = [B, 3]
|
14 |
+
Returns:
|
15 |
+
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
|
16 |
+
"""
|
17 |
+
l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
|
18 |
+
angle = torch.unsqueeze(l1norm, -1)
|
19 |
+
normalized = torch.div(theta, angle)
|
20 |
+
angle = angle * 0.5
|
21 |
+
v_cos = torch.cos(angle)
|
22 |
+
v_sin = torch.sin(angle)
|
23 |
+
quat = torch.cat([v_cos, v_sin * normalized], dim=1)
|
24 |
+
return quat_to_rotmat(quat)
|
25 |
+
|
26 |
+
|
27 |
+
def quat_to_rotmat(quat):
|
28 |
+
"""Convert quaternion coefficients to rotation matrix.
|
29 |
+
Args:
|
30 |
+
quat: size = [B, 4] 4 <===>(w, x, y, z)
|
31 |
+
Returns:
|
32 |
+
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
|
33 |
+
"""
|
34 |
+
norm_quat = quat
|
35 |
+
norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
|
36 |
+
w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:,
|
37 |
+
2], norm_quat[:,
|
38 |
+
3]
|
39 |
+
|
40 |
+
B = quat.size(0)
|
41 |
+
|
42 |
+
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
|
43 |
+
wx, wy, wz = w * x, w * y, w * z
|
44 |
+
xy, xz, yz = x * y, x * z, y * z
|
45 |
+
|
46 |
+
rotMat = torch.stack([
|
47 |
+
w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
|
48 |
+
w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
|
49 |
+
w2 - x2 - y2 + z2
|
50 |
+
],
|
51 |
+
dim=1).view(B, 3, 3)
|
52 |
+
return rotMat
|
53 |
+
|
54 |
+
|
55 |
+
def rotation_matrix_to_angle_axis(rotation_matrix):
|
56 |
+
"""
|
57 |
+
This function is borrowed from https://github.com/kornia/kornia
|
58 |
+
Convert 3x4 rotation matrix to Rodrigues vector
|
59 |
+
Args:
|
60 |
+
rotation_matrix (Tensor): rotation matrix.
|
61 |
+
Returns:
|
62 |
+
Tensor: Rodrigues vector transformation.
|
63 |
+
Shape:
|
64 |
+
- Input: :math:`(N, 3, 4)`
|
65 |
+
- Output: :math:`(N, 3)`
|
66 |
+
Example:
|
67 |
+
>>> input = torch.rand(2, 3, 4) # Nx4x4
|
68 |
+
>>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
|
69 |
+
"""
|
70 |
+
if rotation_matrix.shape[1:] == (3, 3):
|
71 |
+
rot_mat = rotation_matrix.reshape(-1, 3, 3)
|
72 |
+
hom = torch.tensor([0, 0, 1],
|
73 |
+
dtype=torch.float32,
|
74 |
+
device=rotation_matrix.device).reshape(
|
75 |
+
1, 3, 1).expand(rot_mat.shape[0], -1, -1)
|
76 |
+
rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
|
77 |
+
|
78 |
+
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
|
79 |
+
aa = quaternion_to_angle_axis(quaternion)
|
80 |
+
aa[torch.isnan(aa)] = 0.0
|
81 |
+
return aa
|
82 |
+
|
83 |
+
|
84 |
+
def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
|
85 |
+
"""
|
86 |
+
This function is borrowed from https://github.com/kornia/kornia
|
87 |
+
Convert quaternion vector to angle axis of rotation.
|
88 |
+
Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
|
89 |
+
Args:
|
90 |
+
quaternion (torch.Tensor): tensor with quaternions.
|
91 |
+
Return:
|
92 |
+
torch.Tensor: tensor with angle axis of rotation.
|
93 |
+
Shape:
|
94 |
+
- Input: :math:`(*, 4)` where `*` means, any number of dimensions
|
95 |
+
- Output: :math:`(*, 3)`
|
96 |
+
Example:
|
97 |
+
>>> quaternion = torch.rand(2, 4) # Nx4
|
98 |
+
>>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
|
99 |
+
"""
|
100 |
+
if not torch.is_tensor(quaternion):
|
101 |
+
raise TypeError("Input type is not a torch.Tensor. Got {}".format(
|
102 |
+
type(quaternion)))
|
103 |
+
|
104 |
+
if not quaternion.shape[-1] == 4:
|
105 |
+
raise ValueError(
|
106 |
+
"Input must be a tensor of shape Nx4 or 4. Got {}".format(
|
107 |
+
quaternion.shape))
|
108 |
+
# unpack input and compute conversion
|
109 |
+
q1: torch.Tensor = quaternion[..., 1]
|
110 |
+
q2: torch.Tensor = quaternion[..., 2]
|
111 |
+
q3: torch.Tensor = quaternion[..., 3]
|
112 |
+
sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
|
113 |
+
|
114 |
+
sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
|
115 |
+
cos_theta: torch.Tensor = quaternion[..., 0]
|
116 |
+
two_theta: torch.Tensor = 2.0 * torch.where(
|
117 |
+
cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta),
|
118 |
+
torch.atan2(sin_theta, cos_theta))
|
119 |
+
|
120 |
+
k_pos: torch.Tensor = two_theta / sin_theta
|
121 |
+
k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
|
122 |
+
k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
|
123 |
+
|
124 |
+
angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
|
125 |
+
angle_axis[..., 0] += q1 * k
|
126 |
+
angle_axis[..., 1] += q2 * k
|
127 |
+
angle_axis[..., 2] += q3 * k
|
128 |
+
return angle_axis
|
129 |
+
|
130 |
+
|
131 |
+
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
|
132 |
+
"""
|
133 |
+
This function is borrowed from https://github.com/kornia/kornia
|
134 |
+
Convert 3x4 rotation matrix to 4d quaternion vector
|
135 |
+
This algorithm is based on algorithm described in
|
136 |
+
https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
|
137 |
+
Args:
|
138 |
+
rotation_matrix (Tensor): the rotation matrix to convert.
|
139 |
+
Return:
|
140 |
+
Tensor: the rotation in quaternion
|
141 |
+
Shape:
|
142 |
+
- Input: :math:`(N, 3, 4)`
|
143 |
+
- Output: :math:`(N, 4)`
|
144 |
+
Example:
|
145 |
+
>>> input = torch.rand(4, 3, 4) # Nx3x4
|
146 |
+
>>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
|
147 |
+
"""
|
148 |
+
if not torch.is_tensor(rotation_matrix):
|
149 |
+
raise TypeError("Input type is not a torch.Tensor. Got {}".format(
|
150 |
+
type(rotation_matrix)))
|
151 |
+
|
152 |
+
if len(rotation_matrix.shape) > 3:
|
153 |
+
raise ValueError(
|
154 |
+
"Input size must be a three dimensional tensor. Got {}".format(
|
155 |
+
rotation_matrix.shape))
|
156 |
+
if not rotation_matrix.shape[-2:] == (3, 4):
|
157 |
+
raise ValueError(
|
158 |
+
"Input size must be a N x 3 x 4 tensor. Got {}".format(
|
159 |
+
rotation_matrix.shape))
|
160 |
+
|
161 |
+
rmat_t = torch.transpose(rotation_matrix, 1, 2)
|
162 |
+
|
163 |
+
mask_d2 = rmat_t[:, 2, 2] < eps
|
164 |
+
|
165 |
+
mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
|
166 |
+
mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
|
167 |
+
|
168 |
+
t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
169 |
+
q0 = torch.stack([
|
170 |
+
rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0,
|
171 |
+
rmat_t[:, 0, 1] + rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2]
|
172 |
+
], -1)
|
173 |
+
t0_rep = t0.repeat(4, 1).t()
|
174 |
+
|
175 |
+
t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
176 |
+
q1 = torch.stack([
|
177 |
+
rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
|
178 |
+
t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]
|
179 |
+
], -1)
|
180 |
+
t1_rep = t1.repeat(4, 1).t()
|
181 |
+
|
182 |
+
t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
183 |
+
q2 = torch.stack([
|
184 |
+
rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
|
185 |
+
rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2
|
186 |
+
], -1)
|
187 |
+
t2_rep = t2.repeat(4, 1).t()
|
188 |
+
|
189 |
+
t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
190 |
+
q3 = torch.stack([
|
191 |
+
t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
|
192 |
+
rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] - rmat_t[:, 1, 0]
|
193 |
+
], -1)
|
194 |
+
t3_rep = t3.repeat(4, 1).t()
|
195 |
+
|
196 |
+
mask_c0 = mask_d2 * mask_d0_d1
|
197 |
+
mask_c1 = mask_d2 * ~mask_d0_d1
|
198 |
+
mask_c2 = ~mask_d2 * mask_d0_nd1
|
199 |
+
mask_c3 = ~mask_d2 * ~mask_d0_nd1
|
200 |
+
mask_c0 = mask_c0.view(-1, 1).type_as(q0)
|
201 |
+
mask_c1 = mask_c1.view(-1, 1).type_as(q1)
|
202 |
+
mask_c2 = mask_c2.view(-1, 1).type_as(q2)
|
203 |
+
mask_c3 = mask_c3.view(-1, 1).type_as(q3)
|
204 |
+
|
205 |
+
q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
|
206 |
+
q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
|
207 |
+
t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
|
208 |
+
q *= 0.5
|
209 |
+
return q
|
210 |
+
|
211 |
+
|
212 |
+
def rot6d_to_rotmat(x):
|
213 |
+
"""Convert 6D rotation representation to 3x3 rotation matrix.
|
214 |
+
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
|
215 |
+
Input:
|
216 |
+
(B,6) Batch of 6-D rotation representations
|
217 |
+
Output:
|
218 |
+
(B,3,3) Batch of corresponding rotation matrices
|
219 |
+
"""
|
220 |
+
x = x.view(-1, 3, 2)
|
221 |
+
a1 = x[:, :, 0]
|
222 |
+
a2 = x[:, :, 1]
|
223 |
+
b1 = F.normalize(a1)
|
224 |
+
b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
|
225 |
+
b3 = torch.cross(b1, b2)
|
226 |
+
return torch.stack((b1, b2, b3), dim=-1)
|
227 |
+
|
228 |
+
|
229 |
+
def projection(pred_joints, pred_camera, retain_z=False):
|
230 |
+
pred_cam_t = torch.stack([
|
231 |
+
pred_camera[:, 1], pred_camera[:, 2], 2 * 5000. /
|
232 |
+
(224. * pred_camera[:, 0] + 1e-9)
|
233 |
+
],
|
234 |
+
dim=-1)
|
235 |
+
batch_size = pred_joints.shape[0]
|
236 |
+
camera_center = torch.zeros(batch_size, 2)
|
237 |
+
pred_keypoints_2d = perspective_projection(
|
238 |
+
pred_joints,
|
239 |
+
rotation=torch.eye(3).unsqueeze(0).expand(batch_size, -1,
|
240 |
+
-1).to(pred_joints.device),
|
241 |
+
translation=pred_cam_t,
|
242 |
+
focal_length=5000.,
|
243 |
+
camera_center=camera_center,
|
244 |
+
retain_z=retain_z)
|
245 |
+
# Normalize keypoints to [-1,1]
|
246 |
+
pred_keypoints_2d = pred_keypoints_2d / (224. / 2.)
|
247 |
+
return pred_keypoints_2d
|
248 |
+
|
249 |
+
|
250 |
+
def perspective_projection(points,
|
251 |
+
rotation,
|
252 |
+
translation,
|
253 |
+
focal_length,
|
254 |
+
camera_center,
|
255 |
+
retain_z=False):
|
256 |
+
"""
|
257 |
+
This function computes the perspective projection of a set of points.
|
258 |
+
Input:
|
259 |
+
points (bs, N, 3): 3D points
|
260 |
+
rotation (bs, 3, 3): Camera rotation
|
261 |
+
translation (bs, 3): Camera translation
|
262 |
+
focal_length (bs,) or scalar: Focal length
|
263 |
+
camera_center (bs, 2): Camera center
|
264 |
+
"""
|
265 |
+
batch_size = points.shape[0]
|
266 |
+
K = torch.zeros([batch_size, 3, 3], device=points.device)
|
267 |
+
K[:, 0, 0] = focal_length
|
268 |
+
    K[:, 1, 1] = focal_length
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center

    # Transform points
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)

    # Apply perspective distortion
    projected_points = points / points[:, :, -1].unsqueeze(-1)

    # Apply camera intrinsics
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points)

    if retain_z:
        return projected_points
    else:
        return projected_points[:, :, :-1]


def estimate_translation_np(S,
                            joints_2d,
                            joints_conf,
                            focal_length=5000,
                            img_size=224):
    """Find the camera translation that brings the 3D joints S closest to
    their corresponding 2D detections joints_2d.
    Input:
        S: (25, 3) 3D joint locations
        joints_2d: (25, 2) 2D joint locations
        joints_conf: (25,) confidence of the 2D detections
    Returns:
        (3,) camera translation vector
    """

    num_joints = S.shape[0]
    # focal length
    f = np.array([focal_length, focal_length])
    # optical center
    center = np.array([img_size / 2., img_size / 2.])

    # transformations
    Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1)
    XY = np.reshape(S[:, 0:2], -1)
    O = np.tile(center, num_joints)
    F = np.tile(f, num_joints)
    weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)

    # least squares
    Q = np.array([
        F * np.tile(np.array([1, 0]), num_joints),
        F * np.tile(np.array([0, 1]), num_joints),
        O - np.reshape(joints_2d, -1)
    ]).T
    c = (np.reshape(joints_2d, -1) - O) * Z - F * XY

    # weighted least squares
    W = np.diagflat(weight2)
    Q = np.dot(W, Q)
    c = np.dot(W, c)

    # square matrix
    A = np.dot(Q.T, Q)
    b = np.dot(Q.T, c)

    # solution
    trans = np.linalg.solve(A, b)

    return trans


def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.):
    """Find the camera translation that brings the 3D joints S closest to
    their corresponding 2D detections joints_2d.
    Input:
        S: (B, 49, 3) 3D joint locations
        joints_2d: (B, 49, 3) 2D joint locations and confidence
    Returns:
        (B, 3) camera translation vectors
    """

    device = S.device
    # Use only joints 25:49 (GT joints)
    S = S[:, 25:, :].cpu().numpy()
    joints_2d = joints_2d[:, 25:, :].cpu().numpy()
    joints_conf = joints_2d[:, :, -1]
    joints_2d = joints_2d[:, :, :-1]
    trans = np.zeros((S.shape[0], 3), dtype=np.float32)
    # Find the translation for each example in the batch
    for i in range(S.shape[0]):
        S_i = S[i]
        joints_i = joints_2d[i]
        conf_i = joints_conf[i]
        trans[i] = estimate_translation_np(S_i,
                                           joints_i,
                                           conf_i,
                                           focal_length=focal_length,
                                           img_size=img_size)
    return torch.from_numpy(trans).to(device)


def Rot_y(angle, category='torch', prepend_dim=True, device=None):
    '''Rotate around y-axis by angle
    Args:
        category: 'torch' or 'numpy'
        prepend_dim: prepend an extra dimension
    Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
    '''
    m = np.array([[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.],
                  [-np.sin(angle), 0., np.cos(angle)]])
    if category == 'torch':
        if prepend_dim:
            return torch.tensor(m, dtype=torch.float,
                                device=device).unsqueeze(0)
        else:
            return torch.tensor(m, dtype=torch.float, device=device)
    elif category == 'numpy':
        if prepend_dim:
            return np.expand_dims(m, 0)
        else:
            return m
    else:
        raise ValueError("category must be 'torch' or 'numpy'")


def Rot_x(angle, category='torch', prepend_dim=True, device=None):
    '''Rotate around x-axis by angle
    Args:
        category: 'torch' or 'numpy'
        prepend_dim: prepend an extra dimension
    Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
    '''
    m = np.array([[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)],
                  [0., np.sin(angle), np.cos(angle)]])
    if category == 'torch':
        if prepend_dim:
            return torch.tensor(m, dtype=torch.float,
                                device=device).unsqueeze(0)
        else:
            return torch.tensor(m, dtype=torch.float, device=device)
    elif category == 'numpy':
        if prepend_dim:
            return np.expand_dims(m, 0)
        else:
            return m
    else:
        raise ValueError("category must be 'torch' or 'numpy'")


def Rot_z(angle, category='torch', prepend_dim=True, device=None):
    '''Rotate around z-axis by angle
    Args:
        category: 'torch' or 'numpy'
        prepend_dim: prepend an extra dimension
    Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
    '''
    m = np.array([[np.cos(angle), -np.sin(angle), 0.],
                  [np.sin(angle), np.cos(angle), 0.], [0., 0., 1.]])
    if category == 'torch':
        if prepend_dim:
            return torch.tensor(m, dtype=torch.float,
                                device=device).unsqueeze(0)
        else:
            return torch.tensor(m, dtype=torch.float, device=device)
    elif category == 'numpy':
        if prepend_dim:
            return np.expand_dims(m, 0)
        else:
            return m
    else:
        raise ValueError("category must be 'torch' or 'numpy'")
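A minimal sanity-check sketch, not part of the commit: it assumes estimate_translation_np above is importable, synthesizes 3D joints, projects them with a made-up ground-truth translation through the same pinhole model the least-squares system encodes, and checks that the solve recovers that translation.

import numpy as np

rng = np.random.default_rng(0)
S = rng.uniform(-0.5, 0.5, size=(25, 3))        # hypothetical 3D joints around the origin
gt_trans = np.array([0.1, -0.05, 8.0])          # hypothetical camera translation
f, c = 5000.0, 224 / 2.0                        # defaults used by the function

cam = S + gt_trans                              # joints in the camera frame
joints_2d = f * cam[:, :2] / cam[:, 2:3] + c    # pinhole projection onto the image
joints_conf = np.ones(25)                       # treat every joint as fully confident

est = estimate_translation_np(S, joints_2d, joints_conf,
                              focal_length=f, img_size=224)
print(np.allclose(est, gt_trans, atol=1e-4))    # expected: True on noise-free data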
lib/pymaf/utils/imutils.py
ADDED
@@ -0,0 +1,491 @@
"""
This file contains functions that are used to perform data augmentation.
"""
import cv2
import io
import torch
import numpy as np
from PIL import Image
from rembg import remove
from rembg.session_factory import new_session
from torchvision.models import detection

from lib.pymaf.core import constants
from lib.pymaf.utils.streamer import aug_matrix
from lib.common.cloth_extraction import load_segmentation
from torchvision import transforms


def load_img(img_file):

    img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)
    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if not img_file.endswith("png"):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)

    return img


def get_bbox(img, det):

    input = np.float32(img)
    input = (input / 255.0 -
             (0.5, 0.5, 0.5)) / (0.5, 0.5, 0.5)  # TO [-1.0, 1.0]
    input = input.transpose(2, 0, 1)  # TO [3 x H x W]
    bboxes, probs = det(torch.from_numpy(input).float().unsqueeze(0))

    probs = probs.unsqueeze(3)
    bboxes = (bboxes * probs).sum(dim=1, keepdim=True) / probs.sum(
        dim=1, keepdim=True)
    bbox = bboxes[0, 0, 0].cpu().numpy()

    return bbox


def get_transformer(input_res):

    image_to_tensor = transforms.Compose([
        transforms.Resize(input_res),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    mask_to_tensor = transforms.Compose([
        transforms.Resize(input_res),
        transforms.ToTensor(),
        transforms.Normalize((0.0, ), (1.0, ))
    ])

    image_to_pymaf_tensor = transforms.Compose([
        transforms.Resize(size=224),
        transforms.Normalize(mean=constants.IMG_NORM_MEAN,
                             std=constants.IMG_NORM_STD)
    ])

    image_to_pixie_tensor = transforms.Compose([
        transforms.Resize(224)
    ])

    def image_to_hybrik_tensor(img):
        # mean
        img[0].add_(-0.406)
        img[1].add_(-0.457)
        img[2].add_(-0.480)

        # std
        img[0].div_(0.225)
        img[1].div_(0.224)
        img[2].div_(0.229)
        return img

    return [image_to_tensor, mask_to_tensor, image_to_pymaf_tensor, image_to_pixie_tensor, image_to_hybrik_tensor]


def process_image(img_file, hps_type, input_res=512, device=None, seg_path=None):
    """Read image, do preprocessing and possibly crop it according to the bounding box.
    If there are bounding box annotations, use them to crop the image.
    If no bounding box is specified but openpose detections are available, use them to get the bounding box.
    """

    [image_to_tensor, mask_to_tensor, image_to_pymaf_tensor,
     image_to_pixie_tensor, image_to_hybrik_tensor] = get_transformer(input_res)

    img_ori = load_img(img_file)

    in_height, in_width, _ = img_ori.shape
    M = aug_matrix(in_width, in_height, input_res*2, input_res*2)

    # from rectangle to square
    img_for_crop = cv2.warpAffine(img_ori, M[0:2, :],
                                  (input_res*2, input_res*2), flags=cv2.INTER_CUBIC)

    # detection for bbox
    detector = detection.maskrcnn_resnet50_fpn(pretrained=True)
    detector.eval()
    predictions = detector(
        [torch.from_numpy(img_for_crop).permute(2, 0, 1) / 255.])[0]
    human_ids = torch.where(
        predictions["scores"] == predictions["scores"][predictions['labels'] == 1].max())
    bbox = predictions["boxes"][human_ids, :].flatten().detach().cpu().numpy()

    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    center = np.array([(bbox[0] + bbox[2]) / 2.0,
                       (bbox[1] + bbox[3]) / 2.0])

    scale = max(height, width) / 180

    if hps_type == 'hybrik':
        img_np = crop_for_hybrik(img_for_crop, center,
                                 np.array([scale * 180, scale * 180]))
    else:
        img_np, cropping_parameters = crop(
            img_for_crop, center, scale, (input_res, input_res))

    img_pil = Image.fromarray(remove(img_np, post_process_mask=True, session=new_session("u2net")))

    # for icon
    img_rgb = image_to_tensor(img_pil.convert("RGB"))
    img_mask = torch.tensor(1.0) - (mask_to_tensor(img_pil.split()[-1]) <
                                    torch.tensor(0.5)).float()
    img_tensor = img_rgb * img_mask

    # for hps
    img_hps = img_np.astype(np.float32) / 255.
    img_hps = torch.from_numpy(img_hps).permute(2, 0, 1)

    if hps_type == 'bev':
        img_hps = img_np[:, :, [2, 1, 0]]
    elif hps_type == 'hybrik':
        img_hps = image_to_hybrik_tensor(img_hps).unsqueeze(0).to(device)
    elif hps_type != 'pixie':
        img_hps = image_to_pymaf_tensor(img_hps).unsqueeze(0).to(device)
    else:
        img_hps = image_to_pixie_tensor(img_hps).unsqueeze(0).to(device)

    # uncrop params
    uncrop_param = {'center': center,
                    'scale': scale,
                    'ori_shape': img_ori.shape,
                    'box_shape': img_np.shape,
                    'crop_shape': img_for_crop.shape,
                    'M': M}

    if seg_path is not None:
        segmentations = load_segmentation(seg_path, (in_height, in_width))
        seg_coord_normalized = []
        for seg in segmentations:
            coord_normalized = []
            for xy in seg['coordinates']:
                xy_h = np.vstack((xy[:, 0], xy[:, 1], np.ones(len(xy)))).T
                warped_indeces = M[0:2, :] @ xy_h[:, :, None]
                warped_indeces = np.array(warped_indeces).astype(int)
                warped_indeces.resize((warped_indeces.shape[:2]))

                # cropped_indeces = crop_segmentation(warped_indeces, center, scale, (input_res, input_res), img_np.shape)
                cropped_indeces = crop_segmentation(
                    warped_indeces, (input_res, input_res), cropping_parameters)

                indices = np.vstack(
                    (cropped_indeces[:, 0], cropped_indeces[:, 1])).T

                # Convert to NDC coordinates
                seg_cropped_normalized = 2*(indices / input_res) - 1
                # Don't know why we need to divide by 50 but it works ¯\_(ツ)_/¯ (probably some scaling factor somewhere)
                # Divide only by 45 on the horizontal axis to take the curve of the human body into account
                seg_cropped_normalized[:, 0] = (
                    1/40) * seg_cropped_normalized[:, 0]
                seg_cropped_normalized[:, 1] = (
                    1/50) * seg_cropped_normalized[:, 1]
                coord_normalized.append(seg_cropped_normalized)

            seg['coord_normalized'] = coord_normalized
            seg_coord_normalized.append(seg)

        return img_tensor, img_hps, img_ori, img_mask, uncrop_param, seg_coord_normalized

    return img_tensor, img_hps, img_ori, img_mask, uncrop_param


def get_transform(center, scale, res):
    """Generate transformation matrix."""
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + .5)
    t[1, 2] = res[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1

    return t


def transform(pt, center, scale, res, invert=0):
    """Transform pixel location to different reference."""
    t = get_transform(center, scale, res)
    if invert:
        t = np.linalg.inv(t)
    new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
    new_pt = np.dot(t, new_pt)
    return np.around(new_pt[:2]).astype(np.int16)


def crop(img, center, scale, res):
    """Crop image according to the supplied bounding box."""

    # Upper left point
    ul = np.array(transform([0, 0], center, scale, res, invert=1))

    # Bottom right point
    br = np.array(transform(res, center, scale, res, invert=1))

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)

    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]

    # Range to sample from original image
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])

    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]
            ] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
    if len(img.shape) == 2:
        new_img = np.array(Image.fromarray(new_img).resize(res))
    else:
        new_img = np.array(Image.fromarray(
            new_img.astype(np.uint8)).resize(res))

    return new_img, (old_x, new_x, old_y, new_y, new_shape)


def crop_segmentation(org_coord, res, cropping_parameters):
    old_x, new_x, old_y, new_y, new_shape = cropping_parameters

    new_coord = np.zeros((org_coord.shape))
    new_coord[:, 0] = new_x[0] + (org_coord[:, 0] - old_x[0])
    new_coord[:, 1] = new_y[0] + (org_coord[:, 1] - old_y[0])

    new_coord[:, 0] = res[0] * (new_coord[:, 0] / new_shape[1])
    new_coord[:, 1] = res[1] * (new_coord[:, 1] / new_shape[0])

    return new_coord


def crop_for_hybrik(img, center, scale):
    inp_h, inp_w = (256, 256)
    trans = get_affine_transform(center, scale, 0, [inp_w, inp_h])
    new_img = cv2.warpAffine(
        img, trans, (int(inp_w), int(inp_h)), flags=cv2.INTER_LINEAR)
    return new_img


def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):

    def get_dir(src_point, rot_rad):
        """Rotate the point by `rot_rad` degree."""
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)

        src_result = [0, 0]
        src_result[0] = src_point[0] * cs - src_point[1] * sn
        src_result[1] = src_point[0] * sn + src_point[1] * cs

        return src_result

    def get_3rd_point(a, b):
        """Return a vector c that is perpendicular to (a - b)."""
        direct = a - b
        return b + np.array([-direct[1], direct[0]], dtype=np.float32)

    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale])

    scale_tmp = scale
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans


def corner_align(ul, br):

    if ul[1]-ul[0] != br[1]-br[0]:
        ul[1] = ul[0]+br[1]-br[0]

    return ul, br


def uncrop(img, center, scale, orig_shape):
    """'Undo' the image cropping/resizing.
    This function is used when evaluating mask/part segmentation.
    """

    res = img.shape[:2]

    # Upper left point
    ul = np.array(transform([0, 0], center, scale, res, invert=1))
    # Bottom right point
    br = np.array(transform(res, center, scale, res, invert=1))

    # quick fix
    ul, br = corner_align(ul, br)

    # size of cropped image
    crop_shape = [br[1] - ul[1], br[0] - ul[0]]
    new_img = np.zeros(orig_shape, dtype=np.uint8)

    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]

    # Range to sample from original image
    old_x = max(0, ul[0]), min(orig_shape[1], br[0])
    old_y = max(0, ul[1]), min(orig_shape[0], br[1])

    img = np.array(Image.fromarray(img.astype(np.uint8)).resize(crop_shape))

    new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]
            ] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]

    return new_img


def rot_aa(aa, rot):
    """Rotate axis angle parameters."""
    # pose parameters
    R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
                  [np.sin(np.deg2rad(-rot)),
                   np.cos(np.deg2rad(-rot)), 0], [0, 0, 1]])
    # find the rotation of the body in camera frame
    per_rdg, _ = cv2.Rodrigues(aa)
    # apply the global rotation to the global orientation
    resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
    aa = (resrot.T)[0]
    return aa


def flip_img(img):
    """Flip rgb images or masks.
    channels come last, e.g. (256,256,3).
    """
    img = np.fliplr(img)
    return img


def flip_kp(kp, is_smpl=False):
    """Flip keypoints."""
    if len(kp) == 24:
        if is_smpl:
            flipped_parts = constants.SMPL_JOINTS_FLIP_PERM
        else:
            flipped_parts = constants.J24_FLIP_PERM
    elif len(kp) == 49:
        if is_smpl:
            flipped_parts = constants.SMPL_J49_FLIP_PERM
        else:
            flipped_parts = constants.J49_FLIP_PERM
    kp = kp[flipped_parts]
    kp[:, 0] = -kp[:, 0]
    return kp


def flip_pose(pose):
    """Flip pose.
    The flipping is based on SMPL parameters.
    """
    flipped_parts = constants.SMPL_POSE_FLIP_PERM
    pose = pose[flipped_parts]
    # we also negate the second and the third dimension of the axis-angle
    pose[1::3] = -pose[1::3]
    pose[2::3] = -pose[2::3]
    return pose


def normalize_2d_kp(kp_2d, crop_size=224, inv=False):
    # Normalize keypoints between -1, 1
    if not inv:
        ratio = 1.0 / crop_size
        kp_2d = 2.0 * kp_2d * ratio - 1.0
    else:
        ratio = 1.0 / crop_size
        kp_2d = (kp_2d + 1.0) / (2 * ratio)

    return kp_2d


def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None):
    '''
    param joints: [num_joints, 3]
    param joints_vis: [num_joints, 3]
    return: target, target_weight (1: visible, 0: invisible)
    '''
    num_joints = joints.shape[0]
    device = joints.device
    cur_device = torch.device(device.type, device.index)
    if not hasattr(heatmap_size, '__len__'):
        # width, height
        heatmap_size = [heatmap_size, heatmap_size]
    assert len(heatmap_size) == 2
    target_weight = np.ones((num_joints, 1), dtype=np.float32)
    if joints_vis is not None:
        target_weight[:, 0] = joints_vis[:, 0]
    target = torch.zeros((num_joints, heatmap_size[1], heatmap_size[0]),
                         dtype=torch.float32,
                         device=cur_device)

    tmp_size = sigma * 3

    for joint_id in range(num_joints):
        mu_x = int(joints[joint_id][0] * heatmap_size[0] + 0.5)
        mu_y = int(joints[joint_id][1] * heatmap_size[1] + 0.5)
        # Check that any part of the gaussian is in-bounds
        ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
        br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
        if ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] \
                or br[0] < 0 or br[1] < 0:
            # If not, just return the image as is
            target_weight[joint_id] = 0
            continue

        # Generate gaussian
        size = 2 * tmp_size + 1
        # x = np.arange(0, size, 1, np.float32)
        # y = x[:, np.newaxis]
        # x0 = y0 = size // 2
        # # The gaussian is not normalized, we want the center value to equal 1
        # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
        # g = torch.from_numpy(g.astype(np.float32))

        x = torch.arange(0, size, dtype=torch.float32, device=cur_device)
        y = x.unsqueeze(-1)
        x0 = y0 = size // 2
        # The gaussian is not normalized, we want the center value to equal 1
        g = torch.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))

        # Usable gaussian range
        g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]
        g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]
        # Image range
        img_x = max(0, ul[0]), min(br[0], heatmap_size[0])
        img_y = max(0, ul[1]), min(br[1], heatmap_size[1])

        v = target_weight[joint_id]
        if v > 0.5:
            target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
                g[g_y[0]:g_y[1], g_x[0]:g_x[1]]

    return target, target_weight
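A small round-trip sketch, not part of the file: the values are made up, and it only assumes the transform() helper defined above. A point at the bounding-box center should land near the middle of the crop, and the inverted transform should map it back close to where it started.

import numpy as np

center = np.array([320.0, 240.0])   # hypothetical bbox center in the source image
scale = 1.2                         # bbox side is 200 * scale pixels
res = (512, 512)                    # crop resolution

pt_crop = transform(center, center, scale, res)             # roughly [256, 256]
pt_back = transform(pt_crop, center, scale, res, invert=1)  # roughly the original center
print(pt_crop, pt_back)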
lib/pymaf/utils/streamer.py
ADDED
@@ -0,0 +1,142 @@
import cv2
import torch
import numpy as np
import imageio


def aug_matrix(w1, h1, w2, h2):
    dx = (w2 - w1) / 2.0
    dy = (h2 - h1) / 2.0

    matrix_trans = np.array([[1.0, 0, dx],
                             [0, 1.0, dy],
                             [0, 0, 1.0]])

    scale = np.min([float(w2)/w1, float(h2)/h1])

    M = get_affine_matrix(
        center=(w2 / 2.0, h2 / 2.0),
        translate=(0, 0),
        scale=scale)

    M = np.array(M + [0., 0., 1.]).reshape(3, 3)
    M = M.dot(matrix_trans)

    return M


def get_affine_matrix(center, translate, scale):
    cx, cy = center
    tx, ty = translate

    M = [1, 0, 0,
         0, 1, 0]
    M = [x * scale for x in M]

    # Apply inverse of center translation: RSS * C^-1
    M[2] += M[0] * (-cx) + M[1] * (-cy)
    M[5] += M[3] * (-cx) + M[4] * (-cy)

    # Apply center translation: T * C * RSS * C^-1
    M[2] += cx + tx
    M[5] += cy + ty
    return M


class BaseStreamer():
    """This streamer will return images at 512x512 size.
    """

    def __init__(self,
                 width=512, height=512, pad=True,
                 mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                 **kwargs):
        self.width = width
        self.height = height
        self.pad = pad
        self.mean = np.array(mean)
        self.std = np.array(std)

        self.loader = self.create_loader()

    def create_loader(self):
        raise NotImplementedError
        yield np.zeros((600, 400, 3))  # in RGB (0, 255)

    def __getitem__(self, index):
        image = next(self.loader)
        in_height, in_width, _ = image.shape
        M = aug_matrix(in_width, in_height, self.width, self.height)
        image = cv2.warpAffine(
            image, M[0:2, :], (self.width, self.height), flags=cv2.INTER_CUBIC)

        input = np.float32(image)
        input = (input / 255.0 - self.mean) / self.std  # TO [-1.0, 1.0]
        input = input.transpose(2, 0, 1)  # TO [3 x H x W]
        return torch.from_numpy(input).float()

    def __len__(self):
        raise NotImplementedError


class CaptureStreamer(BaseStreamer):
    """This streamer takes webcam as input.
    """

    def __init__(self, id=0, width=512, height=512, pad=True, **kwargs):
        super().__init__(width, height, pad, **kwargs)
        self.capture = cv2.VideoCapture(id)

    def create_loader(self):
        while True:
            _, image = self.capture.read()
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # RGB
            yield image

    def __len__(self):
        return 100_000_000

    def __del__(self):
        self.capture.release()


class VideoListStreamer(BaseStreamer):
    """This streamer takes a list of video files as input.
    """

    def __init__(self, files, width=512, height=512, pad=True, **kwargs):
        super().__init__(width, height, pad, **kwargs)
        self.files = files
        self.captures = [imageio.get_reader(f) for f in files]
        self.nframes = sum([int(cap._meta["fps"] * cap._meta["duration"])
                            for cap in self.captures])

    def create_loader(self):
        for capture in self.captures:
            for image in capture:  # RGB
                yield image

    def __len__(self):
        return self.nframes

    def __del__(self):
        for capture in self.captures:
            capture.close()


class ImageListStreamer(BaseStreamer):
    """This streamer takes a list of image files as input.
    """

    def __init__(self, files, width=512, height=512, pad=True, **kwargs):
        super().__init__(width, height, pad, **kwargs)
        self.files = files

    def create_loader(self):
        for f in self.files:
            image = cv2.imread(f, cv2.IMREAD_UNCHANGED)[:, :, 0:3]
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # RGB
            yield image

    def __len__(self):
        return len(self.files)
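A hypothetical usage sketch, not part of the file; the glob pattern is made up. ImageListStreamer yields normalized 512x512 tensors in [-1, 1]; note that __getitem__ ignores the index and simply advances the underlying generator.

import glob

files = sorted(glob.glob("./some_images/*.png"))   # any list of image paths (assumed to exist)
streamer = ImageListStreamer(files)
for i in range(len(streamer)):
    tensor = streamer[i]            # torch.FloatTensor of shape [3, 512, 512]
    print(i, tuple(tensor.shape))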
lib/pymaf/utils/transforms.py
ADDED
@@ -0,0 +1,78 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao ([email protected])
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cv2
import numpy as np


def transform_preds(coords, center, scale, output_size):
    target_coords = np.zeros(coords.shape)
    trans = get_affine_transform(center, scale, 0, output_size, inv=1)
    for p in range(coords.shape[0]):
        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
    return target_coords


def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        # print(scale)
        scale = np.array([scale, scale])

    scale_tmp = scale * 200.0
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans


def affine_transform(pt, t):
    new_pt = np.array([pt[0], pt[1], 1.]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2]


def get_3rd_point(a, b):
    direct = a - b
    return b + np.array([-direct[1], direct[0]], dtype=np.float32)


def get_dir(src_point, rot_rad):
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)

    src_result = [0, 0]
    src_result[0] = src_point[0] * cs - src_point[1] * sn
    src_result[1] = src_point[0] * sn + src_point[1] * cs

    return src_result
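A small self-check sketch, not part of the file, using only the functions above with made-up values: the forward transform and the transform built with inv=1 are exact inverses, so mapping a point into the output space and back should recover it.

import numpy as np

center = np.array([128.0, 128.0], dtype=np.float32)
scale = 1.0                       # interpreted as a 200 * scale pixel box
output_size = (64, 64)

t_fwd = get_affine_transform(center, scale, 0, output_size)
t_inv = get_affine_transform(center, scale, 0, output_size, inv=1)

pt = np.array([140.0, 100.0])
pt_out = affine_transform(pt, t_fwd)     # source image -> 64x64 output space
pt_back = affine_transform(pt_out, t_inv)
print(np.allclose(pt, pt_back, atol=1e-3))   # expected: True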
lib/renderer/__init__.py
ADDED
File without changes
lib/renderer/camera.py
ADDED
@@ -0,0 +1,226 @@

# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

import cv2
import numpy as np

from .glm import ortho


class Camera:
    def __init__(self, width=1600, height=1200):
        # Focal Length
        # equivalent 50mm
        focal = np.sqrt(width * width + height * height)
        self.focal_x = focal
        self.focal_y = focal
        # Principal Point Offset
        self.principal_x = width / 2
        self.principal_y = height / 2
        # Axis Skew
        self.skew = 0
        # Image Size
        self.width = width
        self.height = height

        self.near = 1
        self.far = 10

        # Camera Center
        self.center = np.array([0, 0, 1.6])
        self.direction = np.array([0, 0, -1])
        self.right = np.array([1, 0, 0])
        self.up = np.array([0, 1, 0])

        self.ortho_ratio = None

    def sanity_check(self):
        self.center = self.center.reshape([-1])
        self.direction = self.direction.reshape([-1])
        self.right = self.right.reshape([-1])
        self.up = self.up.reshape([-1])

        assert len(self.center) == 3
        assert len(self.direction) == 3
        assert len(self.right) == 3
        assert len(self.up) == 3

    @staticmethod
    def normalize_vector(v):
        v_norm = np.linalg.norm(v)
        return v if v_norm == 0 else v / v_norm

    def get_real_z_value(self, z):
        z_near = self.near
        z_far = self.far
        z_n = 2.0 * z - 1.0
        z_e = 2.0 * z_near * z_far / (z_far + z_near - z_n * (z_far - z_near))
        return z_e

    def get_rotation_matrix(self):
        rot_mat = np.eye(3)
        s = self.right
        s = self.normalize_vector(s)
        rot_mat[0, :] = s
        u = self.up
        u = self.normalize_vector(u)
        rot_mat[1, :] = -u
        rot_mat[2, :] = self.normalize_vector(self.direction)

        return rot_mat

    def get_translation_vector(self):
        rot_mat = self.get_rotation_matrix()
        trans = -np.dot(rot_mat, self.center)
        return trans

    def get_intrinsic_matrix(self):
        int_mat = np.eye(3)

        int_mat[0, 0] = self.focal_x
        int_mat[1, 1] = self.focal_y
        int_mat[0, 1] = self.skew
        int_mat[0, 2] = self.principal_x
        int_mat[1, 2] = self.principal_y

        return int_mat

    def get_projection_matrix(self):
        ext_mat = self.get_extrinsic_matrix()
        int_mat = self.get_intrinsic_matrix()

        return np.matmul(int_mat, ext_mat)

    def get_extrinsic_matrix(self):
        rot_mat = self.get_rotation_matrix()
        int_mat = self.get_intrinsic_matrix()
        trans = self.get_translation_vector()

        extrinsic = np.eye(4)
        extrinsic[:3, :3] = rot_mat
        extrinsic[:3, 3] = trans

        return extrinsic[:3, :]

    def set_rotation_matrix(self, rot_mat):
        self.direction = rot_mat[2, :]
        self.up = -rot_mat[1, :]
        self.right = rot_mat[0, :]

    def set_intrinsic_matrix(self, int_mat):
        self.focal_x = int_mat[0, 0]
        self.focal_y = int_mat[1, 1]
        self.skew = int_mat[0, 1]
        self.principal_x = int_mat[0, 2]
        self.principal_y = int_mat[1, 2]

    def set_projection_matrix(self, proj_mat):
        res = cv2.decomposeProjectionMatrix(proj_mat)
        int_mat, rot_mat, camera_center_homo = res[0], res[1], res[2]
        camera_center = camera_center_homo[0:3] / camera_center_homo[3]
        camera_center = camera_center.reshape(-1)
        int_mat = int_mat / int_mat[2][2]

        self.set_intrinsic_matrix(int_mat)
        self.set_rotation_matrix(rot_mat)
        self.center = camera_center

        self.sanity_check()

    def get_gl_matrix(self):
        z_near = self.near
        z_far = self.far
        rot_mat = self.get_rotation_matrix()
        int_mat = self.get_intrinsic_matrix()
        trans = self.get_translation_vector()

        extrinsic = np.eye(4)
        extrinsic[:3, :3] = rot_mat
        extrinsic[:3, 3] = trans
        axis_adj = np.eye(4)
        axis_adj[2, 2] = -1
        axis_adj[1, 1] = -1
        model_view = np.matmul(axis_adj, extrinsic)

        projective = np.zeros([4, 4])
        projective[:2, :2] = int_mat[:2, :2]
        projective[:2, 2:3] = -int_mat[:2, 2:3]
        projective[3, 2] = -1
        projective[2, 2] = (z_near + z_far)
        projective[2, 3] = (z_near * z_far)

        if self.ortho_ratio is None:
            ndc = ortho(0, self.width, 0, self.height, z_near, z_far)
            perspective = np.matmul(ndc, projective)
        else:
            perspective = ortho(-self.width * self.ortho_ratio / 2,
                                self.width * self.ortho_ratio / 2,
                                -self.height * self.ortho_ratio / 2,
                                self.height * self.ortho_ratio / 2, z_near,
                                z_far)

        return perspective, model_view


def KRT_from_P(proj_mat, normalize_K=True):
    res = cv2.decomposeProjectionMatrix(proj_mat)
    K, Rot, camera_center_homog = res[0], res[1], res[2]
    camera_center = camera_center_homog[0:3] / camera_center_homog[3]
    trans = -Rot.dot(camera_center)
    if normalize_K:
        K = K / K[2][2]
    return K, Rot, trans


def MVP_from_P(proj_mat, width, height, near=0.1, far=10000):
    '''
    Convert OpenCV camera calibration matrix to OpenGL projection and model view matrix
    :param proj_mat: OpenCV camera projection matrix
    :param width: Image width
    :param height: Image height
    :param near: Z near value
    :param far: Z far value
    :return: OpenGL projection matrix and model view matrix
    '''
    res = cv2.decomposeProjectionMatrix(proj_mat)
    K, Rot, camera_center_homog = res[0], res[1], res[2]
    camera_center = camera_center_homog[0:3] / camera_center_homog[3]
    trans = -Rot.dot(camera_center)
    K = K / K[2][2]

    extrinsic = np.eye(4)
    extrinsic[:3, :3] = Rot
    extrinsic[:3, 3:4] = trans
    axis_adj = np.eye(4)
    axis_adj[2, 2] = -1
    axis_adj[1, 1] = -1
    model_view = np.matmul(axis_adj, extrinsic)

    zFar = far
    zNear = near
    projective = np.zeros([4, 4])
    projective[:2, :2] = K[:2, :2]
    projective[:2, 2:3] = -K[:2, 2:3]
    projective[3, 2] = -1
    projective[2, 2] = (zNear + zFar)
    projective[2, 3] = (zNear * zFar)

    ndc = ortho(0, width, 0, height, zNear, zFar)

    perspective = np.matmul(ndc, projective)

    return perspective, model_view
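A hedged usage sketch, not part of the file: it assumes the Camera class is importable from lib.renderer.camera (which in turn needs lib/renderer/glm.py for ortho), and the ortho_ratio value is a made-up choice just to exercise the orthographic branch of get_gl_matrix.

import numpy as np
from lib.renderer.camera import Camera   # assumed import path

cam = Camera(width=512, height=512)
cam.ortho_ratio = 0.4                     # hypothetical orthographic scaling
cam.center = np.array([0, 0, 1.6])        # camera 1.6 units in front of the subject

perspective, model_view = cam.get_gl_matrix()
print(perspective.shape, model_view.shape)   # (4, 4) (4, 4)

# A homogeneous world-space point p (4,) is projected as:
# p_clip = perspective @ model_view @ p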
lib/renderer/gl/__init__.py
ADDED
File without changes
lib/renderer/gl/data/color.fs
ADDED
@@ -0,0 +1,20 @@
#version 330 core

layout (location = 0) out vec4 FragColor;
layout (location = 1) out vec4 FragNormal;
layout (location = 2) out vec4 FragDepth;

in vec3 Color;
in vec3 CamNormal;
in vec3 depth;


void main()
{
    FragColor = vec4(Color, 1.0);

    vec3 cam_norm_normalized = normalize(CamNormal);
    vec3 rgb = (cam_norm_normalized + 1.0) / 2.0;
    FragNormal = vec4(rgb, 1.0);
    FragDepth = vec4(depth.xyz, 1.0);
}
lib/renderer/gl/data/color.vs
ADDED
@@ -0,0 +1,29 @@
#version 330 core

layout (location = 0) in vec3 a_Position;
layout (location = 1) in vec3 a_Color;
layout (location = 2) in vec3 a_Normal;

out vec3 CamNormal;
out vec3 CamPos;
out vec3 Color;
out vec3 depth;


uniform mat3 RotMat;
uniform mat4 NormMat;
uniform mat4 ModelMat;
uniform mat4 PerspMat;

void main()
{
    vec3 a_Position = (NormMat * vec4(a_Position, 1.0)).xyz;
    gl_Position = PerspMat * ModelMat * vec4(RotMat * a_Position, 1.0);
    Color = a_Color;

    mat3 R = mat3(ModelMat) * RotMat;
    CamNormal = (R * a_Normal);

    depth = vec3(gl_Position.z / gl_Position.w);
}
lib/renderer/gl/data/normal.fs
ADDED
@@ -0,0 +1,12 @@
#version 330

out vec4 FragColor;

in vec3 CamNormal;

void main()
{
    vec3 cam_norm_normalized = normalize(CamNormal);
    vec3 rgb = (cam_norm_normalized + 1.0) / 2.0;
    FragColor = vec4(rgb, 1.0);
}
lib/renderer/gl/data/normal.vs
ADDED
@@ -0,0 +1,15 @@
#version 330

layout (location = 0) in vec3 Position;
layout (location = 1) in vec3 Normal;

out vec3 CamNormal;

uniform mat4 ModelMat;
uniform mat4 PerspMat;

void main()
{
    gl_Position = PerspMat * ModelMat * vec4(Position, 1.0);
    CamNormal = (ModelMat * vec4(Normal, 0.0)).xyz;
}
lib/renderer/gl/data/prt.fs
ADDED
@@ -0,0 +1,157 @@
#version 330

uniform vec3 SHCoeffs[9];
uniform uint analytic;

uniform uint hasNormalMap;
uniform uint hasAlbedoMap;

uniform sampler2D AlbedoMap;
uniform sampler2D NormalMap;

in VertexData {
    vec3 Position;
    vec3 Depth;
    vec3 ModelNormal;
    vec2 Texcoord;
    vec3 Tangent;
    vec3 Bitangent;
    vec3 PRT1;
    vec3 PRT2;
    vec3 PRT3;
    vec3 Label;
} VertexIn;

layout (location = 0) out vec4 FragColor;
layout (location = 1) out vec4 FragNormal;
layout (location = 2) out vec4 FragPosition;
layout (location = 3) out vec4 FragAlbedo;
layout (location = 4) out vec4 FragShading;
layout (location = 5) out vec4 FragPRT1;
layout (location = 6) out vec4 FragPRT2;
// layout (location = 7) out vec4 FragPRT3;
layout (location = 7) out vec4 FragLabel;


vec4 gammaCorrection(vec4 vec, float g)
{
    return vec4(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g), vec.w);
}

vec3 gammaCorrection(vec3 vec, float g)
{
    return vec3(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g));
}

void evaluateH(vec3 n, out float H[9])
{
    float c1 = 0.429043, c2 = 0.511664,
        c3 = 0.743125, c4 = 0.886227, c5 = 0.247708;

    H[0] = c4;
    H[1] = 2.0 * c2 * n[1];
    H[2] = 2.0 * c2 * n[2];
    H[3] = 2.0 * c2 * n[0];
    H[4] = 2.0 * c1 * n[0] * n[1];
    H[5] = 2.0 * c1 * n[1] * n[2];
    H[6] = c3 * n[2] * n[2] - c5;
    H[7] = 2.0 * c1 * n[2] * n[0];
    H[8] = c1 * (n[0] * n[0] - n[1] * n[1]);
}

vec3 evaluateLightingModel(vec3 normal)
{
    float H[9];
    evaluateH(normal, H);
    vec3 res = vec3(0.0);
    for (int i = 0; i < 9; i++) {
        res += H[i] * SHCoeffs[i];
    }
    return res;
}

// nC: coarse geometry normal, nH: fine normal from normal map
vec3 evaluateLightingModelHybrid(vec3 nC, vec3 nH, mat3 prt)
{
    float HC[9], HH[9];
    evaluateH(nC, HC);
    evaluateH(nH, HH);

    vec3 res = vec3(0.0);
    vec3 shadow = vec3(0.0);
    vec3 unshadow = vec3(0.0);
    for(int i = 0; i < 3; ++i){
        for(int j = 0; j < 3; ++j){
            int id = i*3+j;
            res += HH[id] * SHCoeffs[id];
            shadow += prt[i][j] * SHCoeffs[id];
            unshadow += HC[id] * SHCoeffs[id];
        }
    }
    vec3 ratio = clamp(shadow/unshadow, 0.0, 1.0);
    res = ratio * res;

    return res;
}

vec3 evaluateLightingModelPRT(mat3 prt)
{
    vec3 res = vec3(0.0);
    for(int i = 0; i < 3; ++i){
        for(int j = 0; j < 3; ++j){
            res += prt[i][j] * SHCoeffs[i*3+j];
        }
    }

    return res;
}

void main()
{
    vec2 uv = VertexIn.Texcoord;
    vec3 nC = normalize(VertexIn.ModelNormal);
    vec3 nml = nC;
    mat3 prt = mat3(VertexIn.PRT1, VertexIn.PRT2, VertexIn.PRT3);

    if(hasAlbedoMap == uint(0))
        FragAlbedo = vec4(1.0);
    else
        FragAlbedo = texture(AlbedoMap, uv); //gammaCorrection(texture(AlbedoMap, uv), 1.0/2.2);

    if(hasNormalMap == uint(0))
    {
        if(analytic == uint(0))
            FragShading = vec4(evaluateLightingModelPRT(prt), 1.0f);
        else
            FragShading = vec4(evaluateLightingModel(nC), 1.0f);
    }
    else
    {
        vec3 n_tan = normalize(texture(NormalMap, uv).rgb*2.0 - vec3(1.0));

        mat3 TBN = mat3(normalize(VertexIn.Tangent), normalize(VertexIn.Bitangent), nC);
        vec3 nH = normalize(TBN * n_tan);

        if(analytic == uint(0))
            FragShading = vec4(evaluateLightingModelHybrid(nC, nH, prt), 1.0f);
        else
            FragShading = vec4(evaluateLightingModel(nH), 1.0f);

        nml = nH;
    }

    FragShading = gammaCorrection(FragShading, 2.2);
    FragColor = clamp(FragAlbedo * FragShading, 0.0, 1.0);
    FragNormal = vec4(0.5*(nml + vec3(1.0)), 1.0);
    FragPosition = vec4(VertexIn.Depth.xyz, 1.0);
    FragShading = vec4(clamp(0.5*FragShading.xyz, 0.0, 1.0), 1.0);
    // FragColor = gammaCorrection(clamp(FragAlbedo * FragShading, 0.0, 1.0), 2.2);
    // FragNormal = vec4(0.5*(nml + vec3(1.0)), 1.0);
    // FragPosition = vec4(VertexIn.Position, VertexIn.Depth.x);
    // FragShading = vec4(gammaCorrection(clamp(0.5*FragShading.xyz, 0.0, 1.0), 2.2), 1.0);
    // FragAlbedo = gammaCorrection(FragAlbedo, 2.2);
    FragPRT1 = vec4(VertexIn.PRT1, 1.0);
    FragPRT2 = vec4(VertexIn.PRT2, 1.0);
    // FragPRT3 = vec4(VertexIn.PRT3, 1.0);
    FragLabel = vec4(VertexIn.Label, 1.0);
}
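A NumPy sketch, not shipped in the repo, of the spherical-harmonics shading that evaluateH and evaluateLightingModel compute above for a single normal, given 9 RGB SH coefficients; the constants are the same ones hard-coded in the shader, and the sample coefficients below are made up.

import numpy as np

def evaluate_H(n):
    # same polynomial basis as evaluateH() in prt.fs
    c1, c2, c3, c4, c5 = 0.429043, 0.511664, 0.743125, 0.886227, 0.247708
    x, y, z = n
    return np.array([
        c4,
        2.0 * c2 * y, 2.0 * c2 * z, 2.0 * c2 * x,
        2.0 * c1 * x * y, 2.0 * c1 * y * z,
        c3 * z * z - c5,
        2.0 * c1 * z * x,
        c1 * (x * x - y * y),
    ])

def shade(normal, sh_coeffs):
    """normal: (3,), sh_coeffs: (9, 3) RGB SH coefficients."""
    H = evaluate_H(np.asarray(normal) / np.linalg.norm(normal))
    return H @ sh_coeffs          # (3,) RGB radiance, the same sum as evaluateLightingModel

sh = np.random.rand(9, 3) * 0.1   # hypothetical lighting coefficients
print(shade([0.0, 0.0, 1.0], sh))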