diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..1e8a20b6bd8a30656a0d54968fa8b6ee5461b5bf --- /dev/null +++ b/.gitattributes @@ -0,0 +1,52 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +checkpoints/BFM_Fitting/01_MorphableModel.mat filter=lfs diff=lfs merge=lfs -text +checkpoints/BFM_Fitting/BFM09_model_info.mat filter=lfs diff=lfs merge=lfs -text +checkpoints/facevid2vid_00189-model.pth.tar filter=lfs diff=lfs merge=lfs -text +checkpoints/mapping_00229-model.pth.tar filter=lfs diff=lfs merge=lfs -text +checkpoints/shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text +examples/driven_audio/chinese_news.wav filter=lfs diff=lfs merge=lfs -text +examples/driven_audio/deyu.wav filter=lfs diff=lfs merge=lfs -text +examples/driven_audio/eluosi.wav filter=lfs diff=lfs merge=lfs -text +examples/driven_audio/fayu.wav filter=lfs diff=lfs merge=lfs -text +examples/driven_audio/imagine.wav filter=lfs diff=lfs merge=lfs -text +examples/driven_audio/japanese.wav filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_16.png filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_17.png filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_3.png filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_4.png filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_5.png filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_8.png filter=lfs diff=lfs merge=lfs -text +examples/source_image/art_9.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..bd6a181ac8d66af3c486ec82538d5f297645b990 --- /dev/null +++ b/.gitignore @@ -0,0 +1,155 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +results/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..5ddc6e3d8b246534a58f9612a88b309fa7e10795 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,59 @@ +FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04 +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y --no-install-recommends \ + git \ + zip \ + unzip \ + git-lfs \ + wget \ + curl \ + # ffmpeg \ + ffmpeg \ + x264 \ + # python build dependencies \ + build-essential \ + libssl-dev \ + zlib1g-dev \ + libbz2-dev \ + libreadline-dev \ + libsqlite3-dev \ + libncursesw5-dev \ + xz-utils \ + tk-dev \ + libxml2-dev \ + libxmlsec1-dev \ + libffi-dev \ + liblzma-dev && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN useradd -m -u 1000 user +USER user +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:${PATH} +WORKDIR ${HOME}/app + +RUN curl https://pyenv.run | bash +ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH} +ENV PYTHON_VERSION=3.10.9 +RUN pyenv install ${PYTHON_VERSION} && \ + pyenv global ${PYTHON_VERSION} && \ + pyenv rehash && \ + pip install --no-cache-dir -U pip setuptools wheel + +RUN pip install --no-cache-dir -U torch==1.12.1 torchvision==0.13.1 +COPY --chown=1000 requirements.txt /tmp/requirements.txt +RUN pip install --no-cache-dir -U -r /tmp/requirements.txt + +COPY --chown=1000 . ${HOME}/app +RUN ls -a +ENV PYTHONPATH=${HOME}/app \ + PYTHONUNBUFFERED=1 \ + GRADIO_ALLOW_FLAGGING=never \ + GRADIO_NUM_PORTS=1 \ + GRADIO_SERVER_NAME=0.0.0.0 \ + GRADIO_THEME=huggingface \ + SYSTEM=spaces +CMD ["python", "app.py"] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b2a615ac931ce1e81df51deb56c3df2414b59e63 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Tencent AI Lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d6468baa29b635161882902a3efbdb98f3c71317 --- /dev/null +++ b/README.md @@ -0,0 +1,15 @@ +--- +title: SadTalker +emoji: 😭 +colorFrom: purple +colorTo: green +sdk: gradio +sdk_version: 3.23.0 +app_file: app.py +pinned: false +license: mit +duplicated_from: vinthony/SadTalker +--- + + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..44ffb9b5bfe538cbbabacf93652e5aaa45be50f5 --- /dev/null +++ b/app.py @@ -0,0 +1,111 @@ +import os, sys +import tempfile +import gradio as gr +from modules.text2speech import text2speech +from modules.sadtalker_test import SadTalker + +def get_driven_audio(audio): + if os.path.isfile(audio): + return audio + else: + save_path = tempfile.NamedTemporaryFile( + delete=False, + suffix=("." + "wav"), + ) + gen_audio = text2speech(audio, save_path.name) + return gen_audio, gen_audio + +def get_source_image(image): + return image + +def sadtalker_demo(result_dir='./tmp/'): + + sad_talker = SadTalker() + with gr.Blocks(analytics_enabled=False) as sadtalker_interface: + gr.Markdown("

😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023)

\ + Arxiv       \ + Homepage       \ + Github
") + + with gr.Row(): + with gr.Column(variant='panel'): + with gr.Tabs(elem_id="sadtalker_source_image"): + with gr.TabItem('Upload image'): + with gr.Row(): + source_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256) + + with gr.Tabs(elem_id="sadtalker_driven_audio"): + with gr.TabItem('Upload audio(wav/mp3 only currently)'): + with gr.Column(variant='panel'): + driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath") + + with gr.Column(variant='panel'): + with gr.Tabs(elem_id="sadtalker_checkbox"): + with gr.TabItem('Settings'): + with gr.Column(variant='panel'): + is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion)").style(container=True) + is_resize_mode = gr.Checkbox(label="Resize Mode (⚠️ Resize mode need manually crop the image firstly, can handle larger image crop)").style(container=True) + is_enhance_mode = gr.Checkbox(label="Enhance Mode (better face quality )").style(container=True) + submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary') + + with gr.Tabs(elem_id="sadtalker_genearted"): + gen_video = gr.Video(label="Generated video", format="mp4").style(width=256) + gen_text = gr.Textbox(visible=False) + + with gr.Row(): + examples = [ + [ + 'examples/source_image/art_10.png', + 'examples/driven_audio/deyu.wav', + True, + False, + False + ], + [ + 'examples/source_image/art_1.png', + 'examples/driven_audio/fayu.wav', + True, + True, + False + ], + [ + 'examples/source_image/art_9.png', + 'examples/driven_audio/itosinger1.wav', + True, + False, + True + ] + ] + gr.Examples(examples=examples, + inputs=[ + source_image, + driven_audio, + is_still_mode, + is_resize_mode, + is_enhance_mode, + gr.Textbox(value=result_dir, visible=False)], + outputs=[gen_video, gen_text], + fn=sad_talker.test, + cache_examples=os.getenv('SYSTEM') == 'spaces') + + submit.click( + fn=sad_talker.test, + inputs=[source_image, + driven_audio, + is_still_mode, + is_resize_mode, + is_enhance_mode, + gr.Textbox(value=result_dir, visible=False)], + outputs=[gen_video, gen_text] + ) + + return sadtalker_interface + + +if __name__ == "__main__": + + sadtalker_result_dir = os.path.join('./', 'results') + demo = sadtalker_demo(sadtalker_result_dir) + demo.launch() + + diff --git a/checkpoints/BFM_Fitting.zip b/checkpoints/BFM_Fitting.zip new file mode 100644 index 0000000000000000000000000000000000000000..895479e053ea9f18c12cf68217cd58543b1d2d84 --- /dev/null +++ b/checkpoints/BFM_Fitting.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:785f77f3de288568e76666cd419dcf40646d3f74eae6d4fa3b766c933087a9d8 +size 404051745 diff --git a/checkpoints/BFM_Fitting/01_MorphableModel.mat b/checkpoints/BFM_Fitting/01_MorphableModel.mat new file mode 100644 index 0000000000000000000000000000000000000000..f251485b55d35adac0ad4f1622a47d7a39a1502c --- /dev/null +++ b/checkpoints/BFM_Fitting/01_MorphableModel.mat @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37b1f0742db356a3b1568a8365a06f5b0fe0ab687ac1c3068c803666cbd4d8e2 +size 240875364 diff --git a/checkpoints/BFM_Fitting/BFM09_model_info.mat b/checkpoints/BFM_Fitting/BFM09_model_info.mat new file mode 100644 index 0000000000000000000000000000000000000000..605b1aa60286236b4041d15fccdd978b9d89761d --- /dev/null +++ b/checkpoints/BFM_Fitting/BFM09_model_info.mat @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db8d00544f0b0182f1b8430a3bb87662b3ff674eb33c84e6f52dbe2971adb81b +size 127170280 diff --git a/checkpoints/BFM_Fitting/BFM_exp_idx.mat b/checkpoints/BFM_Fitting/BFM_exp_idx.mat new file mode 100644 index 0000000000000000000000000000000000000000..1146e4e9c3bef303a497383aa7974c014fe945c7 Binary files /dev/null and b/checkpoints/BFM_Fitting/BFM_exp_idx.mat differ diff --git a/checkpoints/BFM_Fitting/BFM_front_idx.mat b/checkpoints/BFM_Fitting/BFM_front_idx.mat new file mode 100644 index 0000000000000000000000000000000000000000..b9d7b0953dd1dc5b1e28144610485409ac321f9b Binary files /dev/null and b/checkpoints/BFM_Fitting/BFM_front_idx.mat differ diff --git a/checkpoints/BFM_Fitting/Exp_Pca.bin b/checkpoints/BFM_Fitting/Exp_Pca.bin new file mode 100644 index 0000000000000000000000000000000000000000..3c1785e6abc52b13e54a573f9f3ebc099915b1e0 --- /dev/null +++ b/checkpoints/BFM_Fitting/Exp_Pca.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7f31380e6cbdaf2aeec698db220bac4f221946e4d551d88c092d47ec49b1726 +size 51086404 diff --git a/checkpoints/BFM_Fitting/facemodel_info.mat b/checkpoints/BFM_Fitting/facemodel_info.mat new file mode 100644 index 0000000000000000000000000000000000000000..3e516ec7297fa3248098f49ecea10579f4831c0a Binary files /dev/null and b/checkpoints/BFM_Fitting/facemodel_info.mat differ diff --git a/checkpoints/BFM_Fitting/select_vertex_id.mat b/checkpoints/BFM_Fitting/select_vertex_id.mat new file mode 100644 index 0000000000000000000000000000000000000000..5b8b220093d93b133acc94ffed159f31a74854cd Binary files /dev/null and b/checkpoints/BFM_Fitting/select_vertex_id.mat differ diff --git a/checkpoints/BFM_Fitting/similarity_Lm3D_all.mat b/checkpoints/BFM_Fitting/similarity_Lm3D_all.mat new file mode 100644 index 0000000000000000000000000000000000000000..a0e23588302bc71fc899eef53ff06df5f4df4c1d Binary files /dev/null and b/checkpoints/BFM_Fitting/similarity_Lm3D_all.mat differ diff --git a/checkpoints/BFM_Fitting/std_exp.txt b/checkpoints/BFM_Fitting/std_exp.txt new file mode 100644 index 0000000000000000000000000000000000000000..767b8de4ea1ca78b6f22b98ff2dee4fa345500bb --- /dev/null +++ b/checkpoints/BFM_Fitting/std_exp.txt @@ -0,0 +1 @@ +453980 257264 263068 211890 135873 184721 47055.6 72732 62787.4 106226 56708.5 51439.8 34887.1 44378.7 51813.4 31030.7 23354.9 23128.1 19400 21827.6 22767.7 22057.4 19894.3 16172.8 17142.7 10035.3 14727.5 12972.5 10763.8 8953.93 8682.62 8941.81 6342.3 5205.3 7065.65 6083.35 6678.88 4666.63 5082.89 5134.76 4908.16 3964.93 3739.95 3180.09 2470.45 1866.62 1624.71 2423.74 1668.53 1471.65 1194.52 782.102 815.044 835.782 834.937 744.496 575.146 633.76 705.685 753.409 620.306 673.326 766.189 619.866 559.93 357.264 396.472 556.849 455.048 460.592 400.735 326.702 279.428 291.535 326.584 305.664 287.816 283.642 276.19 \ No newline at end of file diff --git a/checkpoints/auido2exp_00300-model.pth b/checkpoints/auido2exp_00300-model.pth new file mode 100644 index 0000000000000000000000000000000000000000..927072b6ea4aafde874e6cc0f51594f20e8dac17 --- /dev/null +++ b/checkpoints/auido2exp_00300-model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7608f0e6b477e50e03ca569ac5b04a841b9217f89d502862fc78fda4e46dec4 +size 34278319 diff --git a/checkpoints/auido2pose_00140-model.pth b/checkpoints/auido2pose_00140-model.pth new file mode 100644 index 0000000000000000000000000000000000000000..db44aee66c9246710511e59b552fc041aebe5d8a --- /dev/null +++ b/checkpoints/auido2pose_00140-model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fba6701852dc57efbed25b1e4276e4ff752941860d69fc4429f08a02326ebce +size 95916155 diff --git a/checkpoints/epoch_20.pth b/checkpoints/epoch_20.pth new file mode 100644 index 0000000000000000000000000000000000000000..97ebd6753f7ca4bcd39d3b82e7109b66a2dbc1fb --- /dev/null +++ b/checkpoints/epoch_20.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d17a6b23457b521801baae583cb6a58f7238fe6721fc3d65d76407460e9149b +size 288860037 diff --git a/checkpoints/facevid2vid_00189-model.pth.tar b/checkpoints/facevid2vid_00189-model.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..6c676eb119186e3ae866188f5b0ab2cff10473bc --- /dev/null +++ b/checkpoints/facevid2vid_00189-model.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbad01d46f0510276dc4521322dde6824a873a4222cd0740c85762e7067ea71d +size 2112619148 diff --git a/checkpoints/hub/checkpoints/2DFAN4-cd938726ad.zip b/checkpoints/hub/checkpoints/2DFAN4-cd938726ad.zip new file mode 100644 index 0000000000000000000000000000000000000000..6bb44e2ec4c154607a919de8e3d5c5448e86b586 --- /dev/null +++ b/checkpoints/hub/checkpoints/2DFAN4-cd938726ad.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd938726adb1f15f361263cce2db9cb820c42585fa8796ec72ce19107f369a46 +size 96316515 diff --git a/checkpoints/hub/checkpoints/s3fd-619a316812.pth b/checkpoints/hub/checkpoints/s3fd-619a316812.pth new file mode 100644 index 0000000000000000000000000000000000000000..895538e7fb6df3ad6e0e80d6d48b3c4e60cd9e6c --- /dev/null +++ b/checkpoints/hub/checkpoints/s3fd-619a316812.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:619a31681264d3f7f7fc7a16a42cbbe8b23f31a256f75a366e5a1bcd59b33543 +size 89843225 diff --git a/checkpoints/mapping_00229-model.pth.tar b/checkpoints/mapping_00229-model.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..6400233ae3fa5ff9426800ef761fd6c830bc0cd7 --- /dev/null +++ b/checkpoints/mapping_00229-model.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62a1e06006cc963220f6477438518ed86e9788226c62ae382ddc42fbcefb83f1 +size 155521183 diff --git a/checkpoints/shape_predictor_68_face_landmarks.dat b/checkpoints/shape_predictor_68_face_landmarks.dat new file mode 100644 index 0000000000000000000000000000000000000000..1e5da4f9a556bec8582e6c55b89b3e6bfdd60021 --- /dev/null +++ b/checkpoints/shape_predictor_68_face_landmarks.dat @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f +size 99693937 diff --git a/checkpoints/wav2lip.pth b/checkpoints/wav2lip.pth new file mode 100644 index 0000000000000000000000000000000000000000..c575a07ac4e62abfd60cb8681ebb6df241cf31e6 --- /dev/null +++ b/checkpoints/wav2lip.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b78b681b68ad9fe6c6fb1debc6ff43ad05834a8af8a62ffc4167b7b34ef63c37 +size 435807851 diff --git a/config/auido2exp.yaml b/config/auido2exp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7369dbf350476e14a1d600507f1f8b7d8aa6ecd3 --- /dev/null +++ b/config/auido2exp.yaml @@ -0,0 +1,58 @@ +DATASET: + TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt + EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt + TRAIN_BATCH_SIZE: 32 + EVAL_BATCH_SIZE: 32 + EXP: True + EXP_DIM: 64 + FRAME_LEN: 32 + COEFF_LEN: 73 + NUM_CLASSES: 46 + AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav + COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm + LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb + DEBUG: True + NUM_REPEATS: 2 + T: 40 + + +MODEL: + FRAMEWORK: V2 + AUDIOENCODER: + LEAKY_RELU: True + NORM: 'IN' + DISCRIMINATOR: + LEAKY_RELU: False + INPUT_CHANNELS: 6 + CVAE: + AUDIO_EMB_IN_SIZE: 512 + AUDIO_EMB_OUT_SIZE: 128 + SEQ_LEN: 32 + LATENT_SIZE: 256 + ENCODER_LAYER_SIZES: [192, 1024] + DECODER_LAYER_SIZES: [1024, 192] + + +TRAIN: + MAX_EPOCH: 300 + GENERATOR: + LR: 2.0e-5 + DISCRIMINATOR: + LR: 1.0e-5 + LOSS: + W_FEAT: 0 + W_COEFF_EXP: 2 + W_LM: 1.0e-2 + W_LM_MOUTH: 0 + W_REG: 0 + W_SYNC: 0 + W_COLOR: 0 + W_EXPRESSION: 0 + W_LIPREADING: 0.01 + W_LIPREADING_VV: 0 + W_EYE_BLINK: 4 + +TAG: + NAME: small_dataset + + diff --git a/config/auido2pose.yaml b/config/auido2pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc61f94d12f406f2d8d02545e55b61075051484d --- /dev/null +++ b/config/auido2pose.yaml @@ -0,0 +1,49 @@ +DATASET: + TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt + EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt + TRAIN_BATCH_SIZE: 64 + EVAL_BATCH_SIZE: 1 + EXP: True + EXP_DIM: 64 + FRAME_LEN: 32 + COEFF_LEN: 73 + NUM_CLASSES: 46 + AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav + COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb + DEBUG: True + + +MODEL: + AUDIOENCODER: + LEAKY_RELU: True + NORM: 'IN' + DISCRIMINATOR: + LEAKY_RELU: False + INPUT_CHANNELS: 6 + CVAE: + AUDIO_EMB_IN_SIZE: 512 + AUDIO_EMB_OUT_SIZE: 6 + SEQ_LEN: 32 + LATENT_SIZE: 64 + ENCODER_LAYER_SIZES: [192, 128] + DECODER_LAYER_SIZES: [128, 192] + + +TRAIN: + MAX_EPOCH: 150 + GENERATOR: + LR: 1.0e-4 + DISCRIMINATOR: + LR: 1.0e-4 + LOSS: + LAMBDA_REG: 1 + LAMBDA_LANDMARKS: 0 + LAMBDA_VERTICES: 0 + LAMBDA_GAN_MOTION: 0.7 + LAMBDA_GAN_COEFF: 0 + LAMBDA_KL: 1 + +TAG: + NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder + + diff --git a/config/facerender.yaml b/config/facerender.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9494ef82dfa16b16b7aa0b848ebdd6b23e739e2a --- /dev/null +++ b/config/facerender.yaml @@ -0,0 +1,45 @@ +model_params: + common_params: + num_kp: 15 + image_channel: 3 + feature_channel: 32 + estimate_jacobian: False # True + kp_detector_params: + temperature: 0.1 + block_expansion: 32 + max_features: 1024 + scale_factor: 0.25 # 0.25 + num_blocks: 5 + reshape_channel: 16384 # 16384 = 1024 * 16 + reshape_depth: 16 + he_estimator_params: + block_expansion: 64 + max_features: 2048 + num_bins: 66 + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + reshape_channel: 32 + reshape_depth: 16 # 512 = 32 * 16 + num_resblocks: 6 + estimate_occlusion_map: True + dense_motion_params: + block_expansion: 32 + max_features: 1024 + num_blocks: 5 + reshape_depth: 16 + compress: 4 + discriminator_params: + scales: [1] + block_expansion: 32 + max_features: 512 + num_blocks: 4 + sn: True + mapping_params: + coeff_nc: 70 + descriptor_nc: 1024 + layer: 3 + num_kp: 15 + num_bins: 66 + diff --git a/examples/driven_audio/RD_Radio31_000.wav b/examples/driven_audio/RD_Radio31_000.wav new file mode 100644 index 0000000000000000000000000000000000000000..3b04940a0bff7481179c29bfc47553d9c4224bcf Binary files /dev/null and b/examples/driven_audio/RD_Radio31_000.wav differ diff --git a/examples/driven_audio/RD_Radio34_002.wav b/examples/driven_audio/RD_Radio34_002.wav new file mode 100644 index 0000000000000000000000000000000000000000..6813e812a8d1c57cb2f02eee3fece68a0864d96e Binary files /dev/null and b/examples/driven_audio/RD_Radio34_002.wav differ diff --git a/examples/driven_audio/RD_Radio36_000.wav b/examples/driven_audio/RD_Radio36_000.wav new file mode 100644 index 0000000000000000000000000000000000000000..c73adfed5f142886940bc249904d77f9e54befda Binary files /dev/null and b/examples/driven_audio/RD_Radio36_000.wav differ diff --git a/examples/driven_audio/RD_Radio40_000.wav b/examples/driven_audio/RD_Radio40_000.wav new file mode 100644 index 0000000000000000000000000000000000000000..88ce964e1734210451e3a364f87f8661db388b74 Binary files /dev/null and b/examples/driven_audio/RD_Radio40_000.wav differ diff --git a/examples/driven_audio/chinese_news.wav b/examples/driven_audio/chinese_news.wav new file mode 100644 index 0000000000000000000000000000000000000000..9232795586cbcb926cca70f90691a9e281d32ab9 --- /dev/null +++ b/examples/driven_audio/chinese_news.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b0f4d313a1ca671bc4831d60bcf0c12225efbffe6c0e93e54fbfe9bcd4021cb +size 1536078 diff --git a/examples/driven_audio/chinese_poem1.wav b/examples/driven_audio/chinese_poem1.wav new file mode 100644 index 0000000000000000000000000000000000000000..17c0871100d454bcd95b4281ab6b153c04724fe5 Binary files /dev/null and b/examples/driven_audio/chinese_poem1.wav differ diff --git a/examples/driven_audio/chinese_poem2.wav b/examples/driven_audio/chinese_poem2.wav new file mode 100644 index 0000000000000000000000000000000000000000..e3b294eceff5c5ee43124b7cfa42e4a70196a45f Binary files /dev/null and b/examples/driven_audio/chinese_poem2.wav differ diff --git a/examples/driven_audio/deyu.wav b/examples/driven_audio/deyu.wav new file mode 100644 index 0000000000000000000000000000000000000000..438cd45b36be0d7cec6732d1ffa1c396141a563e --- /dev/null +++ b/examples/driven_audio/deyu.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba1839c57770a2ab0b593ce814344bfd4d750da02acc9be9e8cf5b9113a0f88a +size 2694784 diff --git a/examples/driven_audio/eluosi.wav b/examples/driven_audio/eluosi.wav new file mode 100644 index 0000000000000000000000000000000000000000..336e85fe5cb8d7110fbade7684cce4a33fdffb98 --- /dev/null +++ b/examples/driven_audio/eluosi.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4a3593815dc7b68c256672baa61934c9479efa770af2065fb0886f02713606e +size 1786672 diff --git a/examples/driven_audio/fayu.wav b/examples/driven_audio/fayu.wav new file mode 100644 index 0000000000000000000000000000000000000000..bf5cb6e65b2f959174facc80e13ce145226991cc --- /dev/null +++ b/examples/driven_audio/fayu.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16ebd13626ae4171030b4ea05cceef06078483c352e4b68d469fc2a52bfffceb +size 1940428 diff --git a/examples/driven_audio/imagine.wav b/examples/driven_audio/imagine.wav new file mode 100644 index 0000000000000000000000000000000000000000..c02a95b80b8e2b5c4353a4047239c361e9e3d01a --- /dev/null +++ b/examples/driven_audio/imagine.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2db410217e074d91ae6011e1c5dc0b94f02d05d381c50af8e54253eeacad17d2 +size 1618510 diff --git a/examples/driven_audio/itosinger1.wav b/examples/driven_audio/itosinger1.wav new file mode 100644 index 0000000000000000000000000000000000000000..4937dbb264e2fc24d4752baf8b802b0bac41be24 Binary files /dev/null and b/examples/driven_audio/itosinger1.wav differ diff --git a/examples/driven_audio/japanese.wav b/examples/driven_audio/japanese.wav new file mode 100644 index 0000000000000000000000000000000000000000..63db9ffc287a9186f144b635f87bf352ba30ff22 --- /dev/null +++ b/examples/driven_audio/japanese.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3db5426d0b158799e2be4f609b11f75bfbd4affffe18e9a1c8e6f241fcdedcfc +size 2622712 diff --git a/examples/source_image/art_0.png b/examples/source_image/art_0.png new file mode 100644 index 0000000000000000000000000000000000000000..d8d97645a4ecd9018bf2ad6d9094cf581f816f58 Binary files /dev/null and b/examples/source_image/art_0.png differ diff --git a/examples/source_image/art_1.png b/examples/source_image/art_1.png new file mode 100644 index 0000000000000000000000000000000000000000..4388abe026a5ba1f6c2e9f3a782564bb611f5781 Binary files /dev/null and b/examples/source_image/art_1.png differ diff --git a/examples/source_image/art_10.png b/examples/source_image/art_10.png new file mode 100644 index 0000000000000000000000000000000000000000..5f6568b30f063b09cef08c54df629dae7ff54360 Binary files /dev/null and b/examples/source_image/art_10.png differ diff --git a/examples/source_image/art_11.png b/examples/source_image/art_11.png new file mode 100644 index 0000000000000000000000000000000000000000..4caf17ca866fe54cc5c3af33fb0e93114da1bfb9 Binary files /dev/null and b/examples/source_image/art_11.png differ diff --git a/examples/source_image/art_12.png b/examples/source_image/art_12.png new file mode 100644 index 0000000000000000000000000000000000000000..e15306c30f09807f7df80504032cc39b1c265b6a Binary files /dev/null and b/examples/source_image/art_12.png differ diff --git a/examples/source_image/art_13.png b/examples/source_image/art_13.png new file mode 100644 index 0000000000000000000000000000000000000000..129374120f1f01580a9baa0f37d8bbbe904b2373 Binary files /dev/null and b/examples/source_image/art_13.png differ diff --git a/examples/source_image/art_14.png b/examples/source_image/art_14.png new file mode 100644 index 0000000000000000000000000000000000000000..0f0489bf7cebb41346f029421fdf41dc2e52519b Binary files /dev/null and b/examples/source_image/art_14.png differ diff --git a/examples/source_image/art_15.png b/examples/source_image/art_15.png new file mode 100644 index 0000000000000000000000000000000000000000..a0af242a4b3e962aef8ce5c10a5026646509bfc6 Binary files /dev/null and b/examples/source_image/art_15.png differ diff --git a/examples/source_image/art_16.png b/examples/source_image/art_16.png new file mode 100644 index 0000000000000000000000000000000000000000..afb659b641b564a3d850229c67d014483516af67 --- /dev/null +++ b/examples/source_image/art_16.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f6d350055eea3abe35ee3fe9df80dcd99d8edae66ef4fc20bf06168bf189f25 +size 1480263 diff --git a/examples/source_image/art_17.png b/examples/source_image/art_17.png new file mode 100644 index 0000000000000000000000000000000000000000..875a3e3c2e985efe7407b6c8fff99faa591b9811 --- /dev/null +++ b/examples/source_image/art_17.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05747bb45dcf271d9bb24344bd1bce0e0746d24ce4e13545b27ad40b50c3bfe7 +size 2092096 diff --git a/examples/source_image/art_18.png b/examples/source_image/art_18.png new file mode 100644 index 0000000000000000000000000000000000000000..96358e0e542f66d1f4fd92acd092124e738fc6fe Binary files /dev/null and b/examples/source_image/art_18.png differ diff --git a/examples/source_image/art_19.png b/examples/source_image/art_19.png new file mode 100644 index 0000000000000000000000000000000000000000..4f477a1ab58994e3cb4140b1a8ca59dcc428f387 Binary files /dev/null and b/examples/source_image/art_19.png differ diff --git a/examples/source_image/art_2.png b/examples/source_image/art_2.png new file mode 100644 index 0000000000000000000000000000000000000000..9560673430d461ad94980731ee0b404fcda32084 Binary files /dev/null and b/examples/source_image/art_2.png differ diff --git a/examples/source_image/art_20.png b/examples/source_image/art_20.png new file mode 100644 index 0000000000000000000000000000000000000000..de1ea5c975dbed93ce80c1aa70f6298703acf70f Binary files /dev/null and b/examples/source_image/art_20.png differ diff --git a/examples/source_image/art_3.png b/examples/source_image/art_3.png new file mode 100644 index 0000000000000000000000000000000000000000..f2d3c117ed2d7074ec5427ebd1e68147e4476031 --- /dev/null +++ b/examples/source_image/art_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81be3a9cc605ab01cbf741330b406db5246e8bbbcb443ad43ffeca2ef161e005 +size 1353396 diff --git a/examples/source_image/art_4.png b/examples/source_image/art_4.png new file mode 100644 index 0000000000000000000000000000000000000000..ce5fda1d95dd1d6d497648fbfb95dc53380d367e --- /dev/null +++ b/examples/source_image/art_4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab322220d8eab1bfefdaedea91ca5d08a34258c1ab1e585a9b1c85b32968f983 +size 3625669 diff --git a/examples/source_image/art_5.png b/examples/source_image/art_5.png new file mode 100644 index 0000000000000000000000000000000000000000..2726da0cb91b4ab9d54eef21efa653d2f8cda959 --- /dev/null +++ b/examples/source_image/art_5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:199217b4c839ed849577aedcad32f2bce934628b9783ba4654a93756b25e7896 +size 1228028 diff --git a/examples/source_image/art_6.png b/examples/source_image/art_6.png new file mode 100644 index 0000000000000000000000000000000000000000..e9f6d8f272dc9bf971285667ecbe765ede41c967 Binary files /dev/null and b/examples/source_image/art_6.png differ diff --git a/examples/source_image/art_7.png b/examples/source_image/art_7.png new file mode 100644 index 0000000000000000000000000000000000000000..d8cc380aacb76a6ce9f5e41086bb1fb375a4e7db Binary files /dev/null and b/examples/source_image/art_7.png differ diff --git a/examples/source_image/art_8.png b/examples/source_image/art_8.png new file mode 100644 index 0000000000000000000000000000000000000000..169035fba5a1ab690564e661e2e5ea95a5a71e87 --- /dev/null +++ b/examples/source_image/art_8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d704497947c07ac16534299451fc0526acddf286c2ab4ceb48161ff6facc2af +size 3119298 diff --git a/examples/source_image/art_9.png b/examples/source_image/art_9.png new file mode 100644 index 0000000000000000000000000000000000000000..61a02dd4a57d382f215a73d635959ae45c208635 --- /dev/null +++ b/examples/source_image/art_9.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90f84739e2aa2388efaf0fac2b57a82df279b213a8dab9faa7af8ae7468b4e80 +size 1262963 diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..3582c98cf92b34c93fdf3df585aeb84b2c7d77f4 --- /dev/null +++ b/inference.py @@ -0,0 +1,134 @@ +import torch +from time import strftime +import os, sys, time +from argparse import ArgumentParser + +from src.utils.preprocess import CropAndExtract +from src.test_audio2coeff import Audio2Coeff +from src.facerender.animate import AnimateFromCoeff +from src.generate_batch import get_data +from src.generate_facerender_batch import get_facerender_data + +def main(args): + #torch.backends.cudnn.enabled = False + + pic_path = args.source_image + audio_path = args.driven_audio + save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S")) + os.makedirs(save_dir, exist_ok=True) + pose_style = args.pose_style + device = args.device + batch_size = args.batch_size + camera_yaw_list = args.camera_yaw + camera_pitch_list = args.camera_pitch + camera_roll_list = args.camera_roll + + current_code_path = sys.argv[0] + current_root_path = os.path.split(current_code_path)[0] + + os.environ['TORCH_HOME']=os.path.join(current_root_path, args.checkpoint_dir) + + path_of_lm_croper = os.path.join(current_root_path, args.checkpoint_dir, 'shape_predictor_68_face_landmarks.dat') + path_of_net_recon_model = os.path.join(current_root_path, args.checkpoint_dir, 'epoch_20.pth') + dir_of_BFM_fitting = os.path.join(current_root_path, args.checkpoint_dir, 'BFM_Fitting') + wav2lip_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'wav2lip.pth') + + audio2pose_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2pose_00140-model.pth') + audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml') + + audio2exp_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2exp_00300-model.pth') + audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml') + + free_view_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'facevid2vid_00189-model.pth.tar') + mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00229-model.pth.tar') + facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml') + + #init model + print(path_of_net_recon_model) + preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device) + + print(audio2pose_checkpoint) + print(audio2exp_checkpoint) + audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path, + audio2exp_checkpoint, audio2exp_yaml_path, + wav2lip_checkpoint, device) + + print(free_view_checkpoint) + print(mapping_checkpoint) + animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint, + facerender_yaml_path, device) + + #crop image and extract 3dmm from image + first_frame_dir = os.path.join(save_dir, 'first_frame_dir') + os.makedirs(first_frame_dir, exist_ok=True) + first_coeff_path, crop_pic_path = preprocess_model.generate(pic_path, first_frame_dir) + if first_coeff_path is None: + print("Can't get the coeffs of the input") + return + + #audio2ceoff + batch = get_data(first_coeff_path, audio_path, device) + coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style) + + # 3dface render + if args.face3dvis: + from src.face3d.visualize import gen_composed_video + gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4')) + + #coeff2video + data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, + batch_size, camera_yaw_list, camera_pitch_list, camera_roll_list, + expression_scale=args.expression_scale, still_mode=args.still) + + animate_from_coeff.generate(data, save_dir, enhancer=args.enhancer) + video_name = data['video_name'] + + if args.enhancer is not None: + print(f'The generated video is named {video_name}_enhanced in {save_dir}') + else: + print(f'The generated video is named {video_name} in {save_dir}') + + return os.path.join(save_dir, video_name+'.mp4'), os.path.join(save_dir, video_name+'.mp4') + + +if __name__ == '__main__': + + parser = ArgumentParser() + parser.add_argument("--driven_audio", default='./examples/driven_audio/japanese.wav', help="path to driven audio") + parser.add_argument("--source_image", default='./examples/source_image/art_0.png', help="path to source image") + parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to output") + parser.add_argument("--result_dir", default='./results', help="path to output") + parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)") + parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender") + parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender") + parser.add_argument('--camera_yaw', nargs='+', type=int, default=[0], help="the camera yaw degree") + parser.add_argument('--camera_pitch', nargs='+', type=int, default=[0], help="the camera pitch degree") + parser.add_argument('--camera_roll', nargs='+', type=int, default=[0], help="the camera roll degree") + parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [GFPGAN]") + parser.add_argument("--cpu", dest="cpu", action="store_true") + parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks") + parser.add_argument("--still", action="store_true") + + # net structure and parameters + parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='not use') + parser.add_argument('--init_path', type=str, default=None, help='not Use') + parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc') + parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/') + parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model') + + # default renderer parameters + parser.add_argument('--focal', type=float, default=1015.) + parser.add_argument('--center', type=float, default=112.) + parser.add_argument('--camera_d', type=float, default=10.) + parser.add_argument('--z_near', type=float, default=5.) + parser.add_argument('--z_far', type=float, default=15.) + + args = parser.parse_args() + + if torch.cuda.is_available() and not args.cpu: + args.device = "cuda" + else: + args.device = "cpu" + + main(args) + diff --git a/modules/__pycache__/gfpgan_inference.cpython-38.pyc b/modules/__pycache__/gfpgan_inference.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83f5c7d7953f80bd717ae9a965fbe7b7a6d0e3a3 Binary files /dev/null and b/modules/__pycache__/gfpgan_inference.cpython-38.pyc differ diff --git a/modules/__pycache__/gfpgan_inference.cpython-39.pyc b/modules/__pycache__/gfpgan_inference.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0ecebb67b685d1cd74fd4266ddb33d75f3c60b2 Binary files /dev/null and b/modules/__pycache__/gfpgan_inference.cpython-39.pyc differ diff --git a/modules/__pycache__/sadtalker_test.cpython-38.pyc b/modules/__pycache__/sadtalker_test.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c54ce9b8728a52636f9cb9f9c47616709d04cfe4 Binary files /dev/null and b/modules/__pycache__/sadtalker_test.cpython-38.pyc differ diff --git a/modules/__pycache__/sadtalker_test.cpython-39.pyc b/modules/__pycache__/sadtalker_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b73e50c9766b517d9d765d1e3e58b8a7153b8dd Binary files /dev/null and b/modules/__pycache__/sadtalker_test.cpython-39.pyc differ diff --git a/modules/__pycache__/text2speech.cpython-38.pyc b/modules/__pycache__/text2speech.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90ad4127ce0050c2215bdb797974ad849d12a96c Binary files /dev/null and b/modules/__pycache__/text2speech.cpython-38.pyc differ diff --git a/modules/__pycache__/text2speech.cpython-39.pyc b/modules/__pycache__/text2speech.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74134d79b4e1171fed217853c1457430705b6616 Binary files /dev/null and b/modules/__pycache__/text2speech.cpython-39.pyc differ diff --git a/modules/gfpgan_inference.py b/modules/gfpgan_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e7dc80eac012906b797843aa6019c2c4a39b3b --- /dev/null +++ b/modules/gfpgan_inference.py @@ -0,0 +1,36 @@ +import os,sys + +def gfpgan(scale, origin_mp4_path): + current_code_path = sys.argv[0] + current_root_path = os.path.split(current_code_path)[0] + print(current_root_path) + gfpgan_code_path = current_root_path+'/repositories/GFPGAN/inference_gfpgan.py' + print(gfpgan_code_path) + + #video2pic + result_dir = os.path.split(origin_mp4_path)[0] + video_name = os.path.split(origin_mp4_path)[1] + video_name = video_name.split('.')[0] + print(video_name) + str_scale = str(scale).replace('.', '_') + output_mp4_path = os.path.join(result_dir, video_name+'##'+str_scale+'.mp4') + temp_output_mp4_path = os.path.join(result_dir, 'temp_'+video_name+'##'+str_scale+'.mp4') + + audio_name = video_name.split('##')[-1] + audio_path = os.path.join(result_dir, audio_name+'.wav') + temp_pic_dir1 = os.path.join(result_dir, video_name) + temp_pic_dir2 = os.path.join(result_dir, video_name+'##'+str_scale) + os.makedirs(temp_pic_dir1, exist_ok=True) + os.makedirs(temp_pic_dir2, exist_ok=True) + cmd1 = 'ffmpeg -i \"{}\" -start_number 0 \"{}\"/%06d.png -loglevel error -y'.format(origin_mp4_path, temp_pic_dir1) + os.system(cmd1) + cmd2 = f'python {gfpgan_code_path} -i {temp_pic_dir1} -o {temp_pic_dir2} -s {scale}' + os.system(cmd2) + cmd3 = f'ffmpeg -r 25 -f image2 -i {temp_pic_dir2}/%06d.png -vcodec libx264 -crf 25 -pix_fmt yuv420p {temp_output_mp4_path}' + os.system(cmd3) + cmd4 = f'ffmpeg -y -i {temp_output_mp4_path} -i {audio_path} -vcodec copy {output_mp4_path}' + os.system(cmd4) + #shutil.rmtree(temp_pic_dir1) + #shutil.rmtree(temp_pic_dir2) + + return output_mp4_path diff --git a/modules/sadtalker_test.py b/modules/sadtalker_test.py new file mode 100644 index 0000000000000000000000000000000000000000..34d9699f71fcd6d8f413f9cc96926dd6ceff36b1 --- /dev/null +++ b/modules/sadtalker_test.py @@ -0,0 +1,118 @@ +import torch +import os, sys, shutil +from src.utils.preprocess import CropAndExtract +from src.test_audio2coeff import Audio2Coeff +from src.facerender.animate import AnimateFromCoeff +from src.generate_batch import get_data +from src.generate_facerender_batch import get_facerender_data +import uuid + +from pydub import AudioSegment + +def mp3_to_wav(mp3_filename,wav_filename,frame_rate): + mp3_file = AudioSegment.from_file(file=mp3_filename) + mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav") + +from modules.text2speech import text2speech + +class SadTalker(): + + def __init__(self, checkpoint_path='checkpoints'): + + if torch.cuda.is_available() : + device = "cuda" + else: + device = "cpu" + + # current_code_path = sys.argv[0] + # modules_path = os.path.split(current_code_path)[0] + + current_root_path = './' + + os.environ['TORCH_HOME']=os.path.join(current_root_path, 'checkpoints') + + path_of_lm_croper = os.path.join(current_root_path, 'checkpoints', 'shape_predictor_68_face_landmarks.dat') + path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth') + dir_of_BFM_fitting = os.path.join(current_root_path, 'checkpoints', 'BFM_Fitting') + wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth') + + audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth') + audio2pose_yaml_path = os.path.join(current_root_path, 'config', 'auido2pose.yaml') + + audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth') + audio2exp_yaml_path = os.path.join(current_root_path, 'config', 'auido2exp.yaml') + + free_view_checkpoint = os.path.join(current_root_path, 'checkpoints', 'facevid2vid_00189-model.pth.tar') + mapping_checkpoint = os.path.join(current_root_path, 'checkpoints', 'mapping_00229-model.pth.tar') + facerender_yaml_path = os.path.join(current_root_path, 'config', 'facerender.yaml') + + #init model + print(path_of_lm_croper) + self.preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device) + + print(audio2pose_checkpoint) + self.audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path, + audio2exp_checkpoint, audio2exp_yaml_path, wav2lip_checkpoint, device) + print(free_view_checkpoint) + self.animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint, + facerender_yaml_path, device) + self.device = device + + def test(self, source_image, driven_audio, still_mode, resize_mode, use_enhancer, result_dir='./'): + + time_tag = str(uuid.uuid4()) # strftime("%Y_%m_%d_%H.%M.%S") + save_dir = os.path.join(result_dir, time_tag) + os.makedirs(save_dir, exist_ok=True) + + input_dir = os.path.join(save_dir, 'input') + os.makedirs(input_dir, exist_ok=True) + + print(source_image) + pic_path = os.path.join(input_dir, os.path.basename(source_image)) + shutil.move(source_image, input_dir) + + if os.path.isfile(driven_audio): + audio_path = os.path.join(input_dir, os.path.basename(driven_audio)) + + #### mp3 to wav + if '.mp3' in audio_path: + mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000) + audio_path = audio_path.replace('.mp3', '.wav') + else: + shutil.move(driven_audio, input_dir) + else: + text2speech + + + os.makedirs(save_dir, exist_ok=True) + pose_style = 0 + #crop image and extract 3dmm from image + first_frame_dir = os.path.join(save_dir, 'first_frame_dir') + os.makedirs(first_frame_dir, exist_ok=True) + first_coeff_path, crop_pic_path, original_size = self.preprocess_model.generate(pic_path, first_frame_dir, crop_or_resize= 'resize' if resize_mode else 'crop') + if first_coeff_path is None: + raise AttributeError("No face is detected") + + #audio2ceoff + batch = get_data(first_coeff_path, audio_path, self.device) + coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style) + #coeff2video + batch_size = 4 + data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode) + self.animate_from_coeff.generate(data, save_dir, enhancer='gfpgan' if use_enhancer else None, original_size=original_size) + video_name = data['video_name'] + print(f'The generated video is named {video_name} in {save_dir}') + + torch.cuda.empty_cache() + torch.cuda.synchronize() + + import gc; gc.collect() + + if use_enhancer: + return os.path.join(save_dir, video_name+'_enhanced.mp4'), os.path.join(save_dir, video_name+'_enhanced.mp4') + + else: + return os.path.join(save_dir, video_name+'.mp4'), os.path.join(save_dir, video_name+'.mp4') + + + \ No newline at end of file diff --git a/modules/text2speech.py b/modules/text2speech.py new file mode 100644 index 0000000000000000000000000000000000000000..3ecaef36961494c8b2b1f5771a70b997efa04ffd --- /dev/null +++ b/modules/text2speech.py @@ -0,0 +1,12 @@ +import os + +def text2speech(txt, audio_path): + print(txt) + cmd = f'tts --text "{txt}" --out_path {audio_path}' + print(cmd) + try: + os.system(cmd) + return audio_path + except: + print("Error: Failed convert txt to audio") + return None \ No newline at end of file diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..20645e641240cb419f5fc66c14c1447e91daf669 --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +ffmpeg diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4e3bea4c2687a5c78894cefa88c1983f6a358a12 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +torch==1.12.1 +torchvision==0.13.1 +torchaudio==0.12.1 +numpy==1.23.4 +face_alignment==1.3.5 +imageio==2.19.3 +imageio-ffmpeg==0.4.7 +librosa==0.9.2 # +numba +resampy==0.3.1 +pydub==0.25.1 +scipy==1.5.3 +kornia==0.6.8 +tqdm +yacs==0.1.8 +pyyaml +joblib==1.1.0 +scikit-image==0.19.3 +basicsr==1.4.2 +facexlib==0.2.5 +dlib-bin +gfpgan \ No newline at end of file diff --git a/src/__pycache__/generate_batch.cpython-38.pyc b/src/__pycache__/generate_batch.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c68dd09e49933b52115307195bf3aa446d924922 Binary files /dev/null and b/src/__pycache__/generate_batch.cpython-38.pyc differ diff --git a/src/__pycache__/generate_facerender_batch.cpython-38.pyc b/src/__pycache__/generate_facerender_batch.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a30615ed3eaa5902a2fa553ed3ed17a9ae92a51 Binary files /dev/null and b/src/__pycache__/generate_facerender_batch.cpython-38.pyc differ diff --git a/src/__pycache__/test_audio2coeff.cpython-38.pyc b/src/__pycache__/test_audio2coeff.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2553cc97f50096d7c7005ad39274a8653cb6ad4 Binary files /dev/null and b/src/__pycache__/test_audio2coeff.cpython-38.pyc differ diff --git a/src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc b/src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..460563d74a990c40a3c5bd6f3209acca6d86b550 Binary files /dev/null and b/src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc differ diff --git a/src/audio2exp_models/__pycache__/networks.cpython-38.pyc b/src/audio2exp_models/__pycache__/networks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..766660615f22f94c740dd420ccef83ed442c4fac Binary files /dev/null and b/src/audio2exp_models/__pycache__/networks.cpython-38.pyc differ diff --git a/src/audio2exp_models/audio2exp.py b/src/audio2exp_models/audio2exp.py new file mode 100644 index 0000000000000000000000000000000000000000..5f6e6b77b0ceb2089539caa440f7106c7b1e8aa2 --- /dev/null +++ b/src/audio2exp_models/audio2exp.py @@ -0,0 +1,40 @@ +from tqdm import tqdm +import torch +from torch import nn + + +class Audio2Exp(nn.Module): + def __init__(self, netG, cfg, device, prepare_training_loss=False): + super(Audio2Exp, self).__init__() + self.cfg = cfg + self.device = device + self.netG = netG.to(device) + + def test(self, batch): + + mel_input = batch['indiv_mels'] # bs T 1 80 16 + bs = mel_input.shape[0] + T = mel_input.shape[1] + + exp_coeff_pred = [] + + for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames + + current_mel_input = mel_input[:,i:i+10] + + ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64 + ratio = batch['ratio_gt'][:, i:i+10] #bs T + + audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16 + + curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64 + + exp_coeff_pred += [curr_exp_coeff_pred] + + # BS x T x 64 + results_dict = { + 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1) + } + return results_dict + + diff --git a/src/audio2exp_models/networks.py b/src/audio2exp_models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..f052e18101f5446a527ae354b3621e7d0d4991cc --- /dev/null +++ b/src/audio2exp_models/networks.py @@ -0,0 +1,74 @@ +import torch +import torch.nn.functional as F +from torch import nn + +class Conv2d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2d(cin, cout, kernel_size, stride, padding), + nn.BatchNorm2d(cout) + ) + self.act = nn.ReLU() + self.residual = residual + self.use_act = use_act + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + + if self.use_act: + return self.act(out) + else: + return out + +class SimpleWrapperV2(nn.Module): + def __init__(self) -> None: + super().__init__() + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0), + ) + + #### load the pre-trained audio_encoder + #self.audio_encoder = self.audio_encoder.to(device) + ''' + wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict'] + state_dict = self.audio_encoder.state_dict() + + for k,v in wav2lip_state_dict.items(): + if 'audio_encoder' in k: + print('init:', k) + state_dict[k.replace('module.audio_encoder.', '')] = v + self.audio_encoder.load_state_dict(state_dict) + ''' + + self.mapping1 = nn.Linear(512+64+1, 64) + #self.mapping2 = nn.Linear(30, 64) + #nn.init.constant_(self.mapping1.weight, 0.) + nn.init.constant_(self.mapping1.bias, 0.) + + def forward(self, x, ref, ratio): + x = self.audio_encoder(x).view(x.size(0), -1) + ref_reshape = ref.reshape(x.size(0), -1) + ratio = ratio.reshape(x.size(0), -1) + + y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1)) + out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial + return out diff --git a/src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc b/src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20fa93168344012f0bdb77727b5b5669fac8a10b Binary files /dev/null and b/src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc differ diff --git a/src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc b/src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97d9bdf072c5bd356cc312357646c6eae2b798d0 Binary files /dev/null and b/src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc differ diff --git a/src/audio2pose_models/__pycache__/cvae.cpython-38.pyc b/src/audio2pose_models/__pycache__/cvae.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d9aaee3ad4caa8afc40f723d224eb5b25e8afcd Binary files /dev/null and b/src/audio2pose_models/__pycache__/cvae.cpython-38.pyc differ diff --git a/src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc b/src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7ebfcd0dd3538cedeb7eba984f94d9763b392c6 Binary files /dev/null and b/src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc differ diff --git a/src/audio2pose_models/__pycache__/networks.cpython-38.pyc b/src/audio2pose_models/__pycache__/networks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..239626089b91321b1c00cfba2dfe0a3ba1ccb0b9 Binary files /dev/null and b/src/audio2pose_models/__pycache__/networks.cpython-38.pyc differ diff --git a/src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc b/src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e6b40591fd932ddb2cf686b72afd08c90de1a44 Binary files /dev/null and b/src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc differ diff --git a/src/audio2pose_models/audio2pose.py b/src/audio2pose_models/audio2pose.py new file mode 100644 index 0000000000000000000000000000000000000000..3a37179e221340662a817628df3d01ae9e34404f --- /dev/null +++ b/src/audio2pose_models/audio2pose.py @@ -0,0 +1,94 @@ +import torch +from torch import nn +from src.audio2pose_models.cvae import CVAE +from src.audio2pose_models.discriminator import PoseSequenceDiscriminator +from src.audio2pose_models.audio_encoder import AudioEncoder + +class Audio2Pose(nn.Module): + def __init__(self, cfg, wav2lip_checkpoint, device='cuda'): + super().__init__() + self.cfg = cfg + self.seq_len = cfg.MODEL.CVAE.SEQ_LEN + self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE + self.device = device + + self.audio_encoder = AudioEncoder(wav2lip_checkpoint) + self.audio_encoder.eval() + for param in self.audio_encoder.parameters(): + param.requires_grad = False + + self.netG = CVAE(cfg) + self.netD_motion = PoseSequenceDiscriminator(cfg) + + self.gan_criterion = nn.MSELoss() + self.reg_criterion = nn.L1Loss(reduction='none') + self.pair_criterion = nn.PairwiseDistance() + self.cosine_loss = nn.CosineSimilarity(dim=1) + + def forward(self, x): + + batch = {} + coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73 + batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3] #bs frame_len 6 + batch['ref'] = coeff_gt[:, 0, -9:-3] #bs 6 + batch['class'] = x['class'].squeeze(0).cuda() # bs + indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16 + + # forward + audio_emb_list = [] + audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512 + batch['audio_emb'] = audio_emb + batch = self.netG(batch) + + pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6 + pose_gt = coeff_gt[:, 1:, -9:-3].clone() # bs frame_len 6 + pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred # bs frame_len 6 + + batch['pose_pred'] = pose_pred + batch['pose_gt'] = pose_gt + + return batch + + def test(self, x): + + batch = {} + ref = x['ref'] #bs 1 70 + batch['ref'] = x['ref'][:,0,-6:] + batch['class'] = x['class'] + bs = ref.shape[0] + + indiv_mels= x['indiv_mels'] # bs T 1 80 16 + indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame + num_frames = x['num_frames'] + num_frames = int(num_frames) - 1 + + # + div = num_frames//self.seq_len + re = num_frames%self.seq_len + audio_emb_list = [] + pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, + device=batch['ref'].device)] + + for i in range(div): + z = torch.randn(bs, self.latent_dim).to(ref.device) + batch['z'] = z + audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512 + batch['audio_emb'] = audio_emb + batch = self.netG.test(batch) + pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6 + + if re != 0: + z = torch.randn(bs, self.latent_dim).to(ref.device) + batch['z'] = z + audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512 + batch['audio_emb'] = audio_emb + batch = self.netG.test(batch) + pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:]) + + pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1) + batch['pose_motion_pred'] = pose_motion_pred + + pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6 + + batch['pose_pred'] = pose_pred + return batch diff --git a/src/audio2pose_models/audio_encoder.py b/src/audio2pose_models/audio_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce036df119f86ef28c3ac8d6c834264571c309a --- /dev/null +++ b/src/audio2pose_models/audio_encoder.py @@ -0,0 +1,64 @@ +import torch +from torch import nn +from torch.nn import functional as F + +class Conv2d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2d(cin, cout, kernel_size, stride, padding), + nn.BatchNorm2d(cout) + ) + self.act = nn.ReLU() + self.residual = residual + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + return self.act(out) + +class AudioEncoder(nn.Module): + def __init__(self, wav2lip_checkpoint): + super(AudioEncoder, self).__init__() + + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) + + #### load the pre-trained audio_encoder\ + wav2lip_state_dict = torch.load(wav2lip_checkpoint)['state_dict'] + state_dict = self.audio_encoder.state_dict() + + for k,v in wav2lip_state_dict.items(): + if 'audio_encoder' in k: + state_dict[k.replace('module.audio_encoder.', '')] = v + self.audio_encoder.load_state_dict(state_dict) + + + def forward(self, audio_sequences): + # audio_sequences = (B, T, 1, 80, 16) + B = audio_sequences.size(0) + + audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) + + audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 + dim = audio_embedding.shape[1] + audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1)) + + return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512 diff --git a/src/audio2pose_models/cvae.py b/src/audio2pose_models/cvae.py new file mode 100644 index 0000000000000000000000000000000000000000..d017ce865a03bae40dfe066dbcd82e29839d89dc --- /dev/null +++ b/src/audio2pose_models/cvae.py @@ -0,0 +1,149 @@ +import torch +import torch.nn.functional as F +from torch import nn +from src.audio2pose_models.res_unet import ResUnet + +def class2onehot(idx, class_num): + + assert torch.max(idx).item() < class_num + onehot = torch.zeros(idx.size(0), class_num).to(idx.device) + onehot.scatter_(1, idx, 1) + return onehot + +class CVAE(nn.Module): + def __init__(self, cfg): + super().__init__() + encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES + decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES + latent_size = cfg.MODEL.CVAE.LATENT_SIZE + num_classes = cfg.DATASET.NUM_CLASSES + audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE + audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE + seq_len = cfg.MODEL.CVAE.SEQ_LEN + + self.latent_size = latent_size + + self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes, + audio_emb_in_size, audio_emb_out_size, seq_len) + self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes, + audio_emb_in_size, audio_emb_out_size, seq_len) + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, batch): + batch = self.encoder(batch) + mu = batch['mu'] + logvar = batch['logvar'] + z = self.reparameterize(mu, logvar) + batch['z'] = z + return self.decoder(batch) + + def test(self, batch): + ''' + class_id = batch['class'] + z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device) + batch['z'] = z + ''' + return self.decoder(batch) + +class ENCODER(nn.Module): + def __init__(self, layer_sizes, latent_size, num_classes, + audio_emb_in_size, audio_emb_out_size, seq_len): + super().__init__() + + self.resunet = ResUnet() + self.num_classes = num_classes + self.seq_len = seq_len + + self.MLP = nn.Sequential() + layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6 + for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])): + self.MLP.add_module( + name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) + self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) + + self.linear_means = nn.Linear(layer_sizes[-1], latent_size) + self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size) + self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size) + + self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size)) + + def forward(self, batch): + class_id = batch['class'] + pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6 + ref = batch['ref'] #bs 6 + bs = pose_motion_gt.shape[0] + audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size + + #pose encode + pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6 + pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6 + + #audio mapping + print(audio_in.shape) + audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size + audio_out = audio_out.reshape(bs, -1) + + class_bias = self.classbias[class_id] #bs latent_size + x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size + x_out = self.MLP(x_in) + + mu = self.linear_means(x_out) + logvar = self.linear_means(x_out) #bs latent_size + + batch.update({'mu':mu, 'logvar':logvar}) + return batch + +class DECODER(nn.Module): + def __init__(self, layer_sizes, latent_size, num_classes, + audio_emb_in_size, audio_emb_out_size, seq_len): + super().__init__() + + self.resunet = ResUnet() + self.num_classes = num_classes + self.seq_len = seq_len + + self.MLP = nn.Sequential() + input_size = latent_size + seq_len*audio_emb_out_size + 6 + for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)): + self.MLP.add_module( + name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) + if i+1 < len(layer_sizes): + self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) + else: + self.MLP.add_module(name="sigmoid", module=nn.Sigmoid()) + + self.pose_linear = nn.Linear(6, 6) + self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size) + + self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size)) + + def forward(self, batch): + + z = batch['z'] #bs latent_size + bs = z.shape[0] + class_id = batch['class'] + ref = batch['ref'] #bs 6 + audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size + #print('audio_in: ', audio_in[:, :, :10]) + + audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size + #print('audio_out: ', audio_out[:, :, :10]) + audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size + class_bias = self.classbias[class_id] #bs latent_size + + z = z + class_bias + x_in = torch.cat([ref, z, audio_out], dim=-1) + x_out = self.MLP(x_in) # bs layer_sizes[-1] + x_out = x_out.reshape((bs, self.seq_len, -1)) + + #print('x_out: ', x_out) + + pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6 + + pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6 + + batch.update({'pose_motion_pred':pose_motion_pred}) + return batch diff --git a/src/audio2pose_models/discriminator.py b/src/audio2pose_models/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..339c38e4812ff38a810f0f3a1c01812f6d5d78db --- /dev/null +++ b/src/audio2pose_models/discriminator.py @@ -0,0 +1,76 @@ +import torch +import torch.nn.functional as F +from torch import nn + +class ConvNormRelu(nn.Module): + def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False, + kernel_size=None, stride=None, padding=None, norm='BN', leaky=False): + super().__init__() + if kernel_size is None: + if downsample: + kernel_size, stride, padding = 4, 2, 1 + else: + kernel_size, stride, padding = 3, 1, 1 + + if conv_type == '2d': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias=False, + ) + if norm == 'BN': + self.norm = nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = nn.InstanceNorm2d(out_channels) + else: + raise NotImplementedError + elif conv_type == '1d': + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias=False, + ) + if norm == 'BN': + self.norm = nn.BatchNorm1d(out_channels) + elif norm == 'IN': + self.norm = nn.InstanceNorm1d(out_channels) + else: + raise NotImplementedError + nn.init.kaiming_normal_(self.conv.weight) + + self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + if isinstance(self.norm, nn.InstanceNorm1d): + x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C] + else: + x = self.norm(x) + x = self.act(x) + return x + + +class PoseSequenceDiscriminator(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU + + self.seq = nn.Sequential( + ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64 + ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32 + ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16 + nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16 + ) + + def forward(self, x): + x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2) + x = self.seq(x) + x = x.squeeze(1) + return x \ No newline at end of file diff --git a/src/audio2pose_models/networks.py b/src/audio2pose_models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa0b1390e7b4bb0e16057ac94d2fe84f48421af --- /dev/null +++ b/src/audio2pose_models/networks.py @@ -0,0 +1,140 @@ +import torch.nn as nn +import torch + + +class ResidualConv(nn.Module): + def __init__(self, input_dim, output_dim, stride, padding): + super(ResidualConv, self).__init__() + + self.conv_block = nn.Sequential( + nn.BatchNorm2d(input_dim), + nn.ReLU(), + nn.Conv2d( + input_dim, output_dim, kernel_size=3, stride=stride, padding=padding + ), + nn.BatchNorm2d(output_dim), + nn.ReLU(), + nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), + ) + self.conv_skip = nn.Sequential( + nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), + nn.BatchNorm2d(output_dim), + ) + + def forward(self, x): + + return self.conv_block(x) + self.conv_skip(x) + + +class Upsample(nn.Module): + def __init__(self, input_dim, output_dim, kernel, stride): + super(Upsample, self).__init__() + + self.upsample = nn.ConvTranspose2d( + input_dim, output_dim, kernel_size=kernel, stride=stride + ) + + def forward(self, x): + return self.upsample(x) + + +class Squeeze_Excite_Block(nn.Module): + def __init__(self, channel, reduction=16): + super(Squeeze_Excite_Block, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + +class ASPP(nn.Module): + def __init__(self, in_dims, out_dims, rate=[6, 12, 18]): + super(ASPP, self).__init__() + + self.aspp_block1 = nn.Sequential( + nn.Conv2d( + in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0] + ), + nn.ReLU(inplace=True), + nn.BatchNorm2d(out_dims), + ) + self.aspp_block2 = nn.Sequential( + nn.Conv2d( + in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1] + ), + nn.ReLU(inplace=True), + nn.BatchNorm2d(out_dims), + ) + self.aspp_block3 = nn.Sequential( + nn.Conv2d( + in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2] + ), + nn.ReLU(inplace=True), + nn.BatchNorm2d(out_dims), + ) + + self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1) + self._init_weights() + + def forward(self, x): + x1 = self.aspp_block1(x) + x2 = self.aspp_block2(x) + x3 = self.aspp_block3(x) + out = torch.cat([x1, x2, x3], dim=1) + return self.output(out) + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class Upsample_(nn.Module): + def __init__(self, scale=2): + super(Upsample_, self).__init__() + + self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale) + + def forward(self, x): + return self.upsample(x) + + +class AttentionBlock(nn.Module): + def __init__(self, input_encoder, input_decoder, output_dim): + super(AttentionBlock, self).__init__() + + self.conv_encoder = nn.Sequential( + nn.BatchNorm2d(input_encoder), + nn.ReLU(), + nn.Conv2d(input_encoder, output_dim, 3, padding=1), + nn.MaxPool2d(2, 2), + ) + + self.conv_decoder = nn.Sequential( + nn.BatchNorm2d(input_decoder), + nn.ReLU(), + nn.Conv2d(input_decoder, output_dim, 3, padding=1), + ) + + self.conv_attn = nn.Sequential( + nn.BatchNorm2d(output_dim), + nn.ReLU(), + nn.Conv2d(output_dim, 1, 1), + ) + + def forward(self, x1, x2): + out = self.conv_encoder(x1) + self.conv_decoder(x2) + out = self.conv_attn(out) + return out * x2 \ No newline at end of file diff --git a/src/audio2pose_models/res_unet.py b/src/audio2pose_models/res_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..f2611e1d1a9bf233507427b34928fca60e094224 --- /dev/null +++ b/src/audio2pose_models/res_unet.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from src.audio2pose_models.networks import ResidualConv, Upsample + + +class ResUnet(nn.Module): + def __init__(self, channel=1, filters=[32, 64, 128, 256]): + super(ResUnet, self).__init__() + + self.input_layer = nn.Sequential( + nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), + nn.BatchNorm2d(filters[0]), + nn.ReLU(), + nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), + ) + self.input_skip = nn.Sequential( + nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) + ) + + self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1) + self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1) + + self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1) + + self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1)) + self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1) + + self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1)) + self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1) + + self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1)) + self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1) + + self.output_layer = nn.Sequential( + nn.Conv2d(filters[0], 1, 1, 1), + nn.Sigmoid(), + ) + + def forward(self, x): + # Encode + x1 = self.input_layer(x) + self.input_skip(x) + x2 = self.residual_conv_1(x1) + x3 = self.residual_conv_2(x2) + # Bridge + x4 = self.bridge(x3) + + # Decode + x4 = self.upsample_1(x4) + x5 = torch.cat([x4, x3], dim=1) + + x6 = self.up_residual_conv1(x5) + + x6 = self.upsample_2(x6) + x7 = torch.cat([x6, x2], dim=1) + + x8 = self.up_residual_conv2(x7) + + x8 = self.upsample_3(x8) + x9 = torch.cat([x8, x1], dim=1) + + x10 = self.up_residual_conv3(x9) + + output = self.output_layer(x10) + + return output \ No newline at end of file diff --git a/src/config/auido2exp.yaml b/src/config/auido2exp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7369dbf350476e14a1d600507f1f8b7d8aa6ecd3 --- /dev/null +++ b/src/config/auido2exp.yaml @@ -0,0 +1,58 @@ +DATASET: + TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt + EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt + TRAIN_BATCH_SIZE: 32 + EVAL_BATCH_SIZE: 32 + EXP: True + EXP_DIM: 64 + FRAME_LEN: 32 + COEFF_LEN: 73 + NUM_CLASSES: 46 + AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav + COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm + LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb + DEBUG: True + NUM_REPEATS: 2 + T: 40 + + +MODEL: + FRAMEWORK: V2 + AUDIOENCODER: + LEAKY_RELU: True + NORM: 'IN' + DISCRIMINATOR: + LEAKY_RELU: False + INPUT_CHANNELS: 6 + CVAE: + AUDIO_EMB_IN_SIZE: 512 + AUDIO_EMB_OUT_SIZE: 128 + SEQ_LEN: 32 + LATENT_SIZE: 256 + ENCODER_LAYER_SIZES: [192, 1024] + DECODER_LAYER_SIZES: [1024, 192] + + +TRAIN: + MAX_EPOCH: 300 + GENERATOR: + LR: 2.0e-5 + DISCRIMINATOR: + LR: 1.0e-5 + LOSS: + W_FEAT: 0 + W_COEFF_EXP: 2 + W_LM: 1.0e-2 + W_LM_MOUTH: 0 + W_REG: 0 + W_SYNC: 0 + W_COLOR: 0 + W_EXPRESSION: 0 + W_LIPREADING: 0.01 + W_LIPREADING_VV: 0 + W_EYE_BLINK: 4 + +TAG: + NAME: small_dataset + + diff --git a/src/config/auido2pose.yaml b/src/config/auido2pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc61f94d12f406f2d8d02545e55b61075051484d --- /dev/null +++ b/src/config/auido2pose.yaml @@ -0,0 +1,49 @@ +DATASET: + TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt + EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt + TRAIN_BATCH_SIZE: 64 + EVAL_BATCH_SIZE: 1 + EXP: True + EXP_DIM: 64 + FRAME_LEN: 32 + COEFF_LEN: 73 + NUM_CLASSES: 46 + AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav + COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb + DEBUG: True + + +MODEL: + AUDIOENCODER: + LEAKY_RELU: True + NORM: 'IN' + DISCRIMINATOR: + LEAKY_RELU: False + INPUT_CHANNELS: 6 + CVAE: + AUDIO_EMB_IN_SIZE: 512 + AUDIO_EMB_OUT_SIZE: 6 + SEQ_LEN: 32 + LATENT_SIZE: 64 + ENCODER_LAYER_SIZES: [192, 128] + DECODER_LAYER_SIZES: [128, 192] + + +TRAIN: + MAX_EPOCH: 150 + GENERATOR: + LR: 1.0e-4 + DISCRIMINATOR: + LR: 1.0e-4 + LOSS: + LAMBDA_REG: 1 + LAMBDA_LANDMARKS: 0 + LAMBDA_VERTICES: 0 + LAMBDA_GAN_MOTION: 0.7 + LAMBDA_GAN_COEFF: 0 + LAMBDA_KL: 1 + +TAG: + NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder + + diff --git a/src/config/facerender.yaml b/src/config/facerender.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9494ef82dfa16b16b7aa0b848ebdd6b23e739e2a --- /dev/null +++ b/src/config/facerender.yaml @@ -0,0 +1,45 @@ +model_params: + common_params: + num_kp: 15 + image_channel: 3 + feature_channel: 32 + estimate_jacobian: False # True + kp_detector_params: + temperature: 0.1 + block_expansion: 32 + max_features: 1024 + scale_factor: 0.25 # 0.25 + num_blocks: 5 + reshape_channel: 16384 # 16384 = 1024 * 16 + reshape_depth: 16 + he_estimator_params: + block_expansion: 64 + max_features: 2048 + num_bins: 66 + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + reshape_channel: 32 + reshape_depth: 16 # 512 = 32 * 16 + num_resblocks: 6 + estimate_occlusion_map: True + dense_motion_params: + block_expansion: 32 + max_features: 1024 + num_blocks: 5 + reshape_depth: 16 + compress: 4 + discriminator_params: + scales: [1] + block_expansion: 32 + max_features: 512 + num_blocks: 4 + sn: True + mapping_params: + coeff_nc: 70 + descriptor_nc: 1024 + layer: 3 + num_kp: 15 + num_bins: 66 + diff --git a/src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc b/src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0469c877400338fae921f4aedf1159b03abbb101 Binary files /dev/null and b/src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc differ diff --git a/src/face3d/__pycache__/visualize.cpython-38.pyc b/src/face3d/__pycache__/visualize.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a666447a57777ba5a4c6ed6642f234b79c45d372 Binary files /dev/null and b/src/face3d/__pycache__/visualize.cpython-38.pyc differ diff --git a/src/face3d/data/__init__.py b/src/face3d/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9761c518a1b07c5996165869742af0a52c82bc --- /dev/null +++ b/src/face3d/data/__init__.py @@ -0,0 +1,116 @@ +"""This package includes all the modules related to data loading and preprocessing + + To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. + You need to implement four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point from data loader. + -- : (optionally) add dataset-specific options and set default options. + +Now you can use the dataset class by specifying flag '--dataset_mode dummy'. +See our template dataset class 'template_dataset.py' for more details. +""" +import numpy as np +import importlib +import torch.utils.data +from face3d.data.base_dataset import BaseDataset + + +def find_dataset_using_name(dataset_name): + """Import the module "data/[dataset_name]_dataset.py". + + In the file, the class called DatasetNameDataset() will + be instantiated. It has to be a subclass of BaseDataset, + and it is case-insensitive. + """ + dataset_filename = "data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) + + return dataset + + +def get_option_setter(dataset_name): + """Return the static method of the dataset class.""" + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataset(opt, rank=0): + """Create a dataset given the option. + + This function wraps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from data import create_dataset + >>> dataset = create_dataset(opt) + """ + data_loader = CustomDatasetDataLoader(opt, rank=rank) + dataset = data_loader.load_data() + return dataset + +class CustomDatasetDataLoader(): + """Wrapper class of Dataset class that performs multi-threaded data loading""" + + def __init__(self, opt, rank=0): + """Initialize this class + + Step 1: create a dataset instance given the name [dataset_mode] + Step 2: create a multi-threaded data loader. + """ + self.opt = opt + dataset_class = find_dataset_using_name(opt.dataset_mode) + self.dataset = dataset_class(opt) + self.sampler = None + print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) + if opt.use_ddp and opt.isTrain: + world_size = opt.world_size + self.sampler = torch.utils.data.distributed.DistributedSampler( + self.dataset, + num_replicas=world_size, + rank=rank, + shuffle=not opt.serial_batches + ) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + sampler=self.sampler, + num_workers=int(opt.num_threads / world_size), + batch_size=int(opt.batch_size / world_size), + drop_last=True) + else: + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batch_size, + shuffle=(not opt.serial_batches) and opt.isTrain, + num_workers=int(opt.num_threads), + drop_last=True + ) + + def set_epoch(self, epoch): + self.dataset.current_epoch = epoch + if self.sampler is not None: + self.sampler.set_epoch(epoch) + + def load_data(self): + return self + + def __len__(self): + """Return the number of data in the dataset""" + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + """Return a batch of data""" + for i, data in enumerate(self.dataloader): + if i * self.opt.batch_size >= self.opt.max_dataset_size: + break + yield data diff --git a/src/face3d/data/base_dataset.py b/src/face3d/data/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd57d082d519f512d7114b4f867b6695fb7de06 --- /dev/null +++ b/src/face3d/data/base_dataset.py @@ -0,0 +1,125 @@ +"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. + +It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. +""" +import random +import numpy as np +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +from abc import ABC, abstractmethod + + +class BaseDataset(data.Dataset, ABC): + """This class is an abstract base class (ABC) for datasets. + + To create a subclass, you need to implement the following four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point. + -- : (optionally) add dataset-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the class; save the options in the class + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + self.opt = opt + # self.root = opt.dataroot + self.current_epoch = 0 + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def __len__(self): + """Return the total number of images in the dataset.""" + return 0 + + @abstractmethod + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index - - a random integer for data indexing + + Returns: + a dictionary of data with their names. It ususally contains the data itself and its metadata information. + """ + pass + + +def get_transform(grayscale=False): + transform_list = [] + if grayscale: + transform_list.append(transforms.Grayscale(1)) + transform_list += [transforms.ToTensor()] + return transforms.Compose(transform_list) + +def get_affine_mat(opt, size): + shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False + w, h = size + + if 'shift' in opt.preprocess: + shift_pixs = int(opt.shift_pixs) + shift_x = random.randint(-shift_pixs, shift_pixs) + shift_y = random.randint(-shift_pixs, shift_pixs) + if 'scale' in opt.preprocess: + scale = 1 + opt.scale_delta * (2 * random.random() - 1) + if 'rot' in opt.preprocess: + rot_angle = opt.rot_angle * (2 * random.random() - 1) + rot_rad = -rot_angle * np.pi/180 + if 'flip' in opt.preprocess: + flip = random.random() > 0.5 + + shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3]) + flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3]) + shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3]) + rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3]) + scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3]) + shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3]) + + affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin + affine_inv = np.linalg.inv(affine) + return affine, affine_inv, flip + +def apply_img_affine(img, affine_inv, method=Image.BICUBIC): + return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC) + +def apply_lm_affine(landmark, affine, flip, size): + _, h = size + lm = landmark.copy() + lm[:, 1] = h - 1 - lm[:, 1] + lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1) + lm = lm @ np.transpose(affine) + lm[:, :2] = lm[:, :2] / lm[:, 2:] + lm = lm[:, :2] + lm[:, 1] = h - 1 - lm[:, 1] + if flip: + lm_ = lm.copy() + lm_[:17] = lm[16::-1] + lm_[17:22] = lm[26:21:-1] + lm_[22:27] = lm[21:16:-1] + lm_[31:36] = lm[35:30:-1] + lm_[36:40] = lm[45:41:-1] + lm_[40:42] = lm[47:45:-1] + lm_[42:46] = lm[39:35:-1] + lm_[46:48] = lm[41:39:-1] + lm_[48:55] = lm[54:47:-1] + lm_[55:60] = lm[59:54:-1] + lm_[60:65] = lm[64:59:-1] + lm_[65:68] = lm[67:64:-1] + lm = lm_ + return lm diff --git a/src/face3d/data/flist_dataset.py b/src/face3d/data/flist_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c0b6945c80aa756074a5d3c02b9443b15ddcfc57 --- /dev/null +++ b/src/face3d/data/flist_dataset.py @@ -0,0 +1,125 @@ +"""This script defines the custom dataset for Deep3DFaceRecon_pytorch +""" + +import os.path +from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine +from data.image_folder import make_dataset +from PIL import Image +import random +import util.util as util +import numpy as np +import json +import torch +from scipy.io import loadmat, savemat +import pickle +from util.preprocess import align_img, estimate_norm +from util.load_mats import load_lm3d + + +def default_flist_reader(flist): + """ + flist format: impath label\nimpath label\n ...(same to caffe's filelist) + """ + imlist = [] + with open(flist, 'r') as rf: + for line in rf.readlines(): + impath = line.strip() + imlist.append(impath) + + return imlist + +def jason_flist_reader(flist): + with open(flist, 'r') as fp: + info = json.load(fp) + return info + +def parse_label(label): + return torch.tensor(np.array(label).astype(np.float32)) + + +class FlistDataset(BaseDataset): + """ + It requires one directories to host training images '/path/to/data/train' + You can train the model with the dataset flag '--dataroot /path/to/data'. + """ + + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseDataset.__init__(self, opt) + + self.lm3d_std = load_lm3d(opt.bfm_folder) + + msk_names = default_flist_reader(opt.flist) + self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names] + + self.size = len(self.msk_paths) + self.opt = opt + + self.name = 'train' if opt.isTrain else 'val' + if '_' in opt.flist: + self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0] + + + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index (int) -- a random integer for data indexing + + Returns a dictionary that contains A, B, A_paths and B_paths + img (tensor) -- an image in the input domain + msk (tensor) -- its corresponding attention mask + lm (tensor) -- its corresponding 3d landmarks + im_paths (str) -- image paths + aug_flag (bool) -- a flag used to tell whether its raw or augmented + """ + msk_path = self.msk_paths[index % self.size] # make sure index is within then range + img_path = msk_path.replace('mask/', '') + lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt' + + raw_img = Image.open(img_path).convert('RGB') + raw_msk = Image.open(msk_path).convert('RGB') + raw_lm = np.loadtxt(lm_path).astype(np.float32) + + _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk) + + aug_flag = self.opt.use_aug and self.opt.isTrain + if aug_flag: + img, lm, msk = self._augmentation(img, lm, self.opt, msk) + + _, H = img.size + M = estimate_norm(lm, H) + transform = get_transform() + img_tensor = transform(img) + msk_tensor = transform(msk)[:1, ...] + lm_tensor = parse_label(lm) + M_tensor = parse_label(M) + + + return {'imgs': img_tensor, + 'lms': lm_tensor, + 'msks': msk_tensor, + 'M': M_tensor, + 'im_paths': img_path, + 'aug_flag': aug_flag, + 'dataset': self.name} + + def _augmentation(self, img, lm, opt, msk=None): + affine, affine_inv, flip = get_affine_mat(opt, img.size) + img = apply_img_affine(img, affine_inv) + lm = apply_lm_affine(lm, affine, flip, img.size) + if msk is not None: + msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR) + return img, lm, msk + + + + + def __len__(self): + """Return the total number of images in the dataset. + """ + return self.size diff --git a/src/face3d/data/image_folder.py b/src/face3d/data/image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..efadc2ecbe2fb4b53b78230aba25ec505eff0e55 --- /dev/null +++ b/src/face3d/data/image_folder.py @@ -0,0 +1,66 @@ +"""A modified image folder class + +We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) +so that this class can load images from both current directory and its subdirectories. +""" +import numpy as np +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', + '.tif', '.TIF', '.tiff', '.TIFF', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir, max_dataset_size=float("inf")): + images = [] + assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir, followlinks=True)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images[:min(max_dataset_size, len(images))] + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/src/face3d/data/template_dataset.py b/src/face3d/data/template_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bfdf16be2a8a834b204c45d88c86857b37b9bd25 --- /dev/null +++ b/src/face3d/data/template_dataset.py @@ -0,0 +1,75 @@ +"""Dataset class template + +This module provides a template for users to implement custom datasets. +You can specify '--dataset_mode template' to use this dataset. +The class name should be consistent with both the filename and its dataset_mode option. +The filename should be _dataset.py +The class name should be Dataset.py +You need to implement the following functions: + -- : Add dataset-specific options and rewrite default values for existing options. + -- <__init__>: Initialize this dataset class. + -- <__getitem__>: Return a data point and its metadata information. + -- <__len__>: Return the number of images. +""" +from data.base_dataset import BaseDataset, get_transform +# from data.image_folder import make_dataset +# from PIL import Image + + +class TemplateDataset(BaseDataset): + """A template dataset class for you to implement custom datasets.""" + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') + parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values + return parser + + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + + A few things can be done here. + - save the options (have been done in BaseDataset) + - get image paths and meta information of the dataset. + - define the image transformation. + """ + # save the option and dataset root + BaseDataset.__init__(self, opt) + # get the image paths of your dataset; + self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root + # define the default transform function. You can use ; You can also define your custom transform function + self.transform = get_transform(opt) + + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index -- a random integer for data indexing + + Returns: + a dictionary of data with their names. It usually contains the data itself and its metadata information. + + Step 1: get a random image path: e.g., path = self.image_paths[index] + Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). + Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) + Step 4: return a data point as a dictionary. + """ + path = 'temp' # needs to be a string + data_A = None # needs to be a tensor + data_B = None # needs to be a tensor + return {'data_A': data_A, 'data_B': data_B, 'path': path} + + def __len__(self): + """Return the total number of images.""" + return len(self.image_paths) diff --git a/src/face3d/extract_kp_videos.py b/src/face3d/extract_kp_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..f12e9ec3488d99a29620b744beaa46814b66db8f --- /dev/null +++ b/src/face3d/extract_kp_videos.py @@ -0,0 +1,107 @@ +import os +import cv2 +import time +import glob +import argparse +import face_alignment +import numpy as np +from PIL import Image +from tqdm import tqdm +from itertools import cycle + +from torch.multiprocessing import Pool, Process, set_start_method + +class KeypointExtractor(): + def __init__(self, device): + self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device) + + def extract_keypoint(self, images, name=None, info=True): + if isinstance(images, list): + keypoints = [] + if info: + i_range = tqdm(images,desc='landmark Det:') + else: + i_range = images + + for image in i_range: + current_kp = self.extract_keypoint(image) + if np.mean(current_kp) == -1 and keypoints: + keypoints.append(keypoints[-1]) + else: + keypoints.append(current_kp[None]) + + keypoints = np.concatenate(keypoints, 0) + np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) + return keypoints + else: + while True: + try: + keypoints = self.detector.get_landmarks_from_image(np.array(images))[0] + break + except RuntimeError as e: + if str(e).startswith('CUDA'): + print("Warning: out of memory, sleep for 1s") + time.sleep(1) + else: + print(e) + break + except TypeError: + print('No face detected in this image') + shape = [68, 2] + keypoints = -1. * np.ones(shape) + break + if name is not None: + np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) + return keypoints + +def read_video(filename): + frames = [] + cap = cv2.VideoCapture(filename) + while cap.isOpened(): + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + frames.append(frame) + else: + break + cap.release() + return frames + +def run(data): + filename, opt, device = data + os.environ['CUDA_VISIBLE_DEVICES'] = device + kp_extractor = KeypointExtractor() + images = read_video(filename) + name = filename.split('/')[-2:] + os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) + kp_extractor.extract_keypoint( + images, + name=os.path.join(opt.output_dir, name[-2], name[-1]) + ) + +if __name__ == '__main__': + set_start_method('spawn') + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--input_dir', type=str, help='the folder of the input files') + parser.add_argument('--output_dir', type=str, help='the folder of the output files') + parser.add_argument('--device_ids', type=str, default='0,1') + parser.add_argument('--workers', type=int, default=4) + + opt = parser.parse_args() + filenames = list() + VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} + VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) + extensions = VIDEO_EXTENSIONS + + for ext in extensions: + os.listdir(f'{opt.input_dir}') + print(f'{opt.input_dir}/*.{ext}') + filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}')) + print('Total number of videos:', len(filenames)) + pool = Pool(opt.workers) + args_list = cycle([opt]) + device_ids = opt.device_ids.split(",") + device_ids = cycle(device_ids) + for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): + None diff --git a/src/face3d/models/__init__.py b/src/face3d/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7986c7ad2ec48f404adf81fea5aa06aaf1eeb4 --- /dev/null +++ b/src/face3d/models/__init__.py @@ -0,0 +1,67 @@ +"""This package contains modules related to objective functions, optimizations, and network architectures. + +To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. +You need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate loss, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + +In the function <__init__>, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. + +Now you can use the model class by specifying flag '--model dummy'. +See our template model class 'template_model.py' for more details. +""" + +import importlib +from src.face3d.models.base_model import BaseModel + + +def find_model_using_name(model_name): + """Import the module "models/[model_name]_model.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseModel, + and it is case-insensitive. + """ + model_filename = "face3d.models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit(0) + + return model + + +def get_option_setter(model_name): + """Return the static method of the model class.""" + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + """Create a model given the option. + + This function warps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from models import create_model + >>> model = create_model(opt) + """ + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % type(instance).__name__) + return instance diff --git a/src/face3d/models/__pycache__/__init__.cpython-38.pyc b/src/face3d/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..886f0b184346c5530d0bf8d6f4b2300079511225 Binary files /dev/null and b/src/face3d/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/face3d/models/__pycache__/base_model.cpython-38.pyc b/src/face3d/models/__pycache__/base_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e42691ec8e26c5c38baf6bd0172dff8110754da1 Binary files /dev/null and b/src/face3d/models/__pycache__/base_model.cpython-38.pyc differ diff --git a/src/face3d/models/__pycache__/bfm.cpython-38.pyc b/src/face3d/models/__pycache__/bfm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..088a48bf9f0cabeb667c11c21000f0254c63ec81 Binary files /dev/null and b/src/face3d/models/__pycache__/bfm.cpython-38.pyc differ diff --git a/src/face3d/models/__pycache__/facerecon_model.cpython-38.pyc b/src/face3d/models/__pycache__/facerecon_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e8de7975dee1099cb3e7698227df4e4062f86ee Binary files /dev/null and b/src/face3d/models/__pycache__/facerecon_model.cpython-38.pyc differ diff --git a/src/face3d/models/__pycache__/losses.cpython-38.pyc b/src/face3d/models/__pycache__/losses.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffbf94d1f1e09d5ba0653c588b0cfaeb3df7b920 Binary files /dev/null and b/src/face3d/models/__pycache__/losses.cpython-38.pyc differ diff --git a/src/face3d/models/__pycache__/networks.cpython-38.pyc b/src/face3d/models/__pycache__/networks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a97b5cd3309786e87448c4478ae2d19a18e096b Binary files /dev/null and b/src/face3d/models/__pycache__/networks.cpython-38.pyc differ diff --git a/src/face3d/models/arcface_torch/README.md b/src/face3d/models/arcface_torch/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2ee63a861229b68873561fa39bfa7c9a8b53b947 --- /dev/null +++ b/src/face3d/models/arcface_torch/README.md @@ -0,0 +1,164 @@ +# Distributed Arcface Training in Pytorch + +This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions +identity on a single server. + +## Requirements + +- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md). +- `pip install -r requirements.txt`. +- Download the dataset + from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_) + . + +## How to Training + +To train a model, run `train.py` with the path to the configs: + +### 1. Single node, 8 GPUs: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 +``` + +### 2. Multiple nodes, each node 8 GPUs: + +Node 0: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50 +``` + +Node 1: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50 +``` + +### 3.Training resnet2060 with 8 GPUs: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py +``` + +## Model Zoo + +- The models are available for non-commercial research purposes only. +- All models can be found in here. +- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw +- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d) + +### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/) + +ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face +recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities. +As the result, we can evaluate the FAIR performance for different algorithms. + +For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The +globalised multi-racial testset contains 242,143 identities and 1,624,305 images. + +For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocal, with FAR less than 0.0001(e-4). +Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images. +There are totally 13,928 positive pairs and 96,983,824 negative pairs. + +| Datasets | backbone | Training throughout | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** | +| :---: | :--- | :--- | :--- |:--- |:--- | +| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** | +| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** | +| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** | +| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** | +| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** | +| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** | +| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** | +| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** | +| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** | +| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** | + +### Performance on IJB-C and Verification Datasets + +| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log | +| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- | +| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)| +| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)| +| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)| +| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)| +| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)| +| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)| +| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)| +| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)| +| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)| + +[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.) + + +## [Speed Benchmark](docs/speed_benchmark.md) + +**Arcface Torch** can train large-scale face recognition training set efficiently and quickly. When the number of +classes in training sets is greater than 300K and the training is sufficient, partial fc sampling strategy will get same +accuracy with several times faster training performance and smaller GPU memory. +Partial FC is a sparse variant of the model parallel architecture for large sacle face recognition. Partial FC use a +sparse softmax, where each batch dynamicly sample a subset of class centers for training. In each iteration, only a +sparse part of the parameters will be updated, which can reduce a lot of GPU memory and calculations. With Partial FC, +we can scale trainset of 29 millions identities, the largest to date. Partial FC also supports multi-machine distributed +training and mixed precision training. + +![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png) + +More details see +[speed_benchmark.md](docs/speed_benchmark.md) in docs. + +### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better) + +`-` means training failed because of gpu memory limitations. + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 4681 | 4824 | 5004 | +|1400000 | **1672** | 3043 | 4738 | +|5500000 | **-** | **1389** | 3975 | +|8000000 | **-** | **-** | 3565 | +|16000000 | **-** | **-** | 2679 | +|29000000 | **-** | **-** | **1855** | + +### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 7358 | 5306 | 4868 | +|1400000 | 32252 | 11178 | 6056 | +|5500000 | **-** | 32188 | 9854 | +|8000000 | **-** | **-** | 12310 | +|16000000 | **-** | **-** | 19950 | +|29000000 | **-** | **-** | 32324 | + +## Evaluation ICCV2021-MFR and IJB-C + +More details see [eval.md](docs/eval.md) in docs. + +## Test + +We tested many versions of PyTorch. Please create an issue if you are having trouble. + +- [x] torch 1.6.0 +- [x] torch 1.7.1 +- [x] torch 1.8.0 +- [x] torch 1.9.0 + +## Citation + +``` +@inproceedings{deng2019arcface, + title={Arcface: Additive angular margin loss for deep face recognition}, + author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos}, + booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, + pages={4690--4699}, + year={2019} +} +@inproceedings{an2020partical_fc, + title={Partial FC: Training 10 Million Identities on a Single Machine}, + author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and + Zhang, Debing and Fu Ying}, + booktitle={Arxiv 2010.05222}, + year={2020} +} +``` diff --git a/src/face3d/models/arcface_torch/backbones/__init__.py b/src/face3d/models/arcface_torch/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..55bd4c5d1889a1a998b52eb56793bbc1eef1b691 --- /dev/null +++ b/src/face3d/models/arcface_torch/backbones/__init__.py @@ -0,0 +1,25 @@ +from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 +from .mobilefacenet import get_mbf + + +def get_model(name, **kwargs): + # resnet + if name == "r18": + return iresnet18(False, **kwargs) + elif name == "r34": + return iresnet34(False, **kwargs) + elif name == "r50": + return iresnet50(False, **kwargs) + elif name == "r100": + return iresnet100(False, **kwargs) + elif name == "r200": + return iresnet200(False, **kwargs) + elif name == "r2060": + from .iresnet2060 import iresnet2060 + return iresnet2060(False, **kwargs) + elif name == "mbf": + fp16 = kwargs.get("fp16", False) + num_features = kwargs.get("num_features", 512) + return get_mbf(fp16=fp16, num_features=num_features) + else: + raise ValueError() \ No newline at end of file diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-36.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c49397797cf06eaa01ef1327d25f0c145a511994 Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-36.pyc differ diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-37.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82f8ed2b49d5c718fe15c47d620156600f776765 Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83f6ad3ed4af3cc3d3cfa9067e345cdffb058638 Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-39.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1291676de1f08eaba633f000d015eab672e0036 Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-36.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6be617e2ecf266f566e6e5d4972465fcd0379ac5 Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-36.pyc differ diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-37.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a085d7cb2aa24dabc85966931e3aa9db54310e3 Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-37.pyc differ diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f59247d26d9210b5fd2960df842753a903a90b3d Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc differ diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-39.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8a633135905cc3c5fe7673c6d6ab584e0692ce7 Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-39.pyc differ diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-36.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d9748f002ee2f953efa2391054329b6d32f9016 Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-36.pyc differ diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-37.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50b9f06989f4ca4f6f5bd7a1fdf1952f2035e974 Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-37.pyc differ diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8edc64d28aa3e3fb8c26ba795d04a8ef35b1540 Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc differ diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-39.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24ebbc749bfa90340e389e2c88bd1f8218c3e338 Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-39.pyc differ diff --git a/src/face3d/models/arcface_torch/backbones/iresnet.py b/src/face3d/models/arcface_torch/backbones/iresnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c6d3b9c240c24687d432197f976ee01fbf423216 --- /dev/null +++ b/src/face3d/models/arcface_torch/backbones/iresnet.py @@ -0,0 +1,187 @@ +import torch +from torch import nn + +__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): + super(IResNet, self).__init__() + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet18(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, + progress, **kwargs) + + +def iresnet34(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, + progress, **kwargs) + + +def iresnet50(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, + progress, **kwargs) + + +def iresnet100(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, + progress, **kwargs) + + +def iresnet200(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, + progress, **kwargs) + diff --git a/src/face3d/models/arcface_torch/backbones/iresnet2060.py b/src/face3d/models/arcface_torch/backbones/iresnet2060.py new file mode 100644 index 0000000000000000000000000000000000000000..21d1122144d207637d2444cba1f68fe630c89f31 --- /dev/null +++ b/src/face3d/models/arcface_torch/backbones/iresnet2060.py @@ -0,0 +1,176 @@ +import torch +from torch import nn + +assert torch.__version__ >= "1.8.1" +from torch.utils.checkpoint import checkpoint_sequential + +__all__ = ['iresnet2060'] + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): + super(IResNet, self).__init__() + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def checkpoint(self, func, num_seg, x): + if self.training: + return checkpoint_sequential(func, num_seg, x) + else: + return func(x) + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.checkpoint(self.layer2, 20, x) + x = self.checkpoint(self.layer3, 100, x) + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet2060(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) diff --git a/src/face3d/models/arcface_torch/backbones/mobilefacenet.py b/src/face3d/models/arcface_torch/backbones/mobilefacenet.py new file mode 100644 index 0000000000000000000000000000000000000000..87731491d76f9ff61cc70e57bb3f18c54fae308c --- /dev/null +++ b/src/face3d/models/arcface_torch/backbones/mobilefacenet.py @@ -0,0 +1,130 @@ +''' +Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py +Original author cavalleria +''' + +import torch.nn as nn +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module +import torch + + +class Flatten(Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class ConvBlock(Module): + def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): + super(ConvBlock, self).__init__() + self.layers = nn.Sequential( + Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), + BatchNorm2d(num_features=out_c), + PReLU(num_parameters=out_c) + ) + + def forward(self, x): + return self.layers(x) + + +class LinearBlock(Module): + def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): + super(LinearBlock, self).__init__() + self.layers = nn.Sequential( + Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), + BatchNorm2d(num_features=out_c) + ) + + def forward(self, x): + return self.layers(x) + + +class DepthWise(Module): + def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): + super(DepthWise, self).__init__() + self.residual = residual + self.layers = nn.Sequential( + ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), + ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), + LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) + ) + + def forward(self, x): + short_cut = None + if self.residual: + short_cut = x + x = self.layers(x) + if self.residual: + output = short_cut + x + else: + output = x + return output + + +class Residual(Module): + def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): + super(Residual, self).__init__() + modules = [] + for _ in range(num_block): + modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) + self.layers = Sequential(*modules) + + def forward(self, x): + return self.layers(x) + + +class GDC(Module): + def __init__(self, embedding_size): + super(GDC, self).__init__() + self.layers = nn.Sequential( + LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), + Flatten(), + Linear(512, embedding_size, bias=False), + BatchNorm1d(embedding_size)) + + def forward(self, x): + return self.layers(x) + + +class MobileFaceNet(Module): + def __init__(self, fp16=False, num_features=512): + super(MobileFaceNet, self).__init__() + scale = 2 + self.fp16 = fp16 + self.layers = nn.Sequential( + ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)), + ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64), + DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), + Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), + Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), + Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + ) + self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) + self.features = GDC(num_features) + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.layers(x) + x = self.conv_sep(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def get_mbf(fp16, num_features): + return MobileFaceNet(fp16, num_features) \ No newline at end of file diff --git a/src/face3d/models/arcface_torch/configs/3millions.py b/src/face3d/models/arcface_torch/configs/3millions.py new file mode 100644 index 0000000000000000000000000000000000000000..c9edc2f1414e35f93abfd3dfe11a61f1f406580e --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/3millions.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 300 * 10000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = [] diff --git a/src/face3d/models/arcface_torch/configs/3millions_pfc.py b/src/face3d/models/arcface_torch/configs/3millions_pfc.py new file mode 100644 index 0000000000000000000000000000000000000000..77caafdbb300d8109d5bfdb844f131710ef81f20 --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/3millions_pfc.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.1 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 300 * 10000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = [] diff --git a/src/face3d/models/arcface_torch/configs/__init__.py b/src/face3d/models/arcface_torch/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/face3d/models/arcface_torch/configs/base.py b/src/face3d/models/arcface_torch/configs/base.py new file mode 100644 index 0000000000000000000000000000000000000000..78e4b36a9142b649ec39a8c59331bb2557f2ad57 --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/base.py @@ -0,0 +1,56 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = "ms1mv3_arcface_r50" + +config.dataset = "ms1m-retinaface-t1" +config.embedding_size = 512 +config.sample_rate = 1 +config.fp16 = False +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +if config.dataset == "emore": + config.rec = "/train_tmp/faces_emore" + config.num_classes = 85742 + config.num_image = 5822653 + config.num_epoch = 16 + config.warmup_epoch = -1 + config.decay_epoch = [8, 14, ] + config.val_targets = ["lfw", ] + +elif config.dataset == "ms1m-retinaface-t1": + config.rec = "/train_tmp/ms1m-retinaface-t1" + config.num_classes = 93431 + config.num_image = 5179510 + config.num_epoch = 25 + config.warmup_epoch = -1 + config.decay_epoch = [11, 17, 22] + config.val_targets = ["lfw", "cfp_fp", "agedb_30"] + +elif config.dataset == "glint360k": + config.rec = "/train_tmp/glint360k" + config.num_classes = 360232 + config.num_image = 17091657 + config.num_epoch = 20 + config.warmup_epoch = -1 + config.decay_epoch = [8, 12, 15, 18] + config.val_targets = ["lfw", "cfp_fp", "agedb_30"] + +elif config.dataset == "webface": + config.rec = "/train_tmp/faces_webface_112x112" + config.num_classes = 10572 + config.num_image = "forget" + config.num_epoch = 34 + config.warmup_epoch = -1 + config.decay_epoch = [20, 28, 32] + config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/src/face3d/models/arcface_torch/configs/glint360k_mbf.py b/src/face3d/models/arcface_torch/configs/glint360k_mbf.py new file mode 100644 index 0000000000000000000000000000000000000000..46ae777cc97af41a531cba4e5d1ff31f2efcb468 --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/glint360k_mbf.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.1 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 2e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/src/face3d/models/arcface_torch/configs/glint360k_r100.py b/src/face3d/models/arcface_torch/configs/glint360k_r100.py new file mode 100644 index 0000000000000000000000000000000000000000..93d0701c0094517cec147c382b005e8063938548 --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/glint360k_r100.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/src/face3d/models/arcface_torch/configs/glint360k_r18.py b/src/face3d/models/arcface_torch/configs/glint360k_r18.py new file mode 100644 index 0000000000000000000000000000000000000000..7a8db34cd547e8e667103c93585296e47a894e97 --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/glint360k_r18.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r18" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/src/face3d/models/arcface_torch/configs/glint360k_r34.py b/src/face3d/models/arcface_torch/configs/glint360k_r34.py new file mode 100644 index 0000000000000000000000000000000000000000..fda2701758a839a7161d09c25f0ca3d26033baff --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/glint360k_r34.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r34" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/src/face3d/models/arcface_torch/configs/glint360k_r50.py b/src/face3d/models/arcface_torch/configs/glint360k_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..37e7922f1f63284e356dcc45a5f979f9c105f25e --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/glint360k_r50.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py b/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a00d6305eeda5a94788017afc1cda0d4a4cd2a --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 2e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 20, 25] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py b/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4e0d31f1aedf4590628d394e1606920fefb5c9 --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r18" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py b/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py new file mode 100644 index 0000000000000000000000000000000000000000..23ad81e082c4b6390b67b164d0ceb84bb0635684 --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r2060" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 64 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py b/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py new file mode 100644 index 0000000000000000000000000000000000000000..5f78337a3d1f9eb6e9145eb5093618796c6842d2 --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r34" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py b/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..08ba55dbbea6df0afffddbb3d1ed173efad99604 --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/src/face3d/models/arcface_torch/configs/speed.py b/src/face3d/models/arcface_torch/configs/speed.py new file mode 100644 index 0000000000000000000000000000000000000000..45e95237da65e44f35a172c25ac6dc4e313e4eae --- /dev/null +++ b/src/face3d/models/arcface_torch/configs/speed.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 100 * 10000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = [] diff --git a/src/face3d/models/arcface_torch/dataset.py b/src/face3d/models/arcface_torch/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..96bbb8bb6da99122f350bc8e1a6390245840e32b --- /dev/null +++ b/src/face3d/models/arcface_torch/dataset.py @@ -0,0 +1,124 @@ +import numbers +import os +import queue as Queue +import threading + +import mxnet as mx +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class BackgroundGenerator(threading.Thread): + def __init__(self, generator, local_rank, max_prefetch=6): + super(BackgroundGenerator, self).__init__() + self.queue = Queue.Queue(max_prefetch) + self.generator = generator + self.local_rank = local_rank + self.daemon = True + self.start() + + def run(self): + torch.cuda.set_device(self.local_rank) + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def next(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + +class DataLoaderX(DataLoader): + + def __init__(self, local_rank, **kwargs): + super(DataLoaderX, self).__init__(**kwargs) + self.stream = torch.cuda.Stream(local_rank) + self.local_rank = local_rank + + def __iter__(self): + self.iter = super(DataLoaderX, self).__iter__() + self.iter = BackgroundGenerator(self.iter, self.local_rank) + self.preload() + return self + + def preload(self): + self.batch = next(self.iter, None) + if self.batch is None: + return None + with torch.cuda.stream(self.stream): + for k in range(len(self.batch)): + self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) + + def __next__(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is None: + raise StopIteration + self.preload() + return batch + + +class MXFaceDataset(Dataset): + def __init__(self, root_dir, local_rank): + super(MXFaceDataset, self).__init__() + self.transform = transforms.Compose( + [transforms.ToPILImage(), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + self.root_dir = root_dir + self.local_rank = local_rank + path_imgrec = os.path.join(root_dir, 'train.rec') + path_imgidx = os.path.join(root_dir, 'train.idx') + self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') + s = self.imgrec.read_idx(0) + header, _ = mx.recordio.unpack(s) + if header.flag > 0: + self.header0 = (int(header.label[0]), int(header.label[1])) + self.imgidx = np.array(range(1, int(header.label[0]))) + else: + self.imgidx = np.array(list(self.imgrec.keys)) + + def __getitem__(self, index): + idx = self.imgidx[index] + s = self.imgrec.read_idx(idx) + header, img = mx.recordio.unpack(s) + label = header.label + if not isinstance(label, numbers.Number): + label = label[0] + label = torch.tensor(label, dtype=torch.long) + sample = mx.image.imdecode(img).asnumpy() + if self.transform is not None: + sample = self.transform(sample) + return sample, label + + def __len__(self): + return len(self.imgidx) + + +class SyntheticDataset(Dataset): + def __init__(self, local_rank): + super(SyntheticDataset, self).__init__() + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) + img = np.transpose(img, (2, 0, 1)) + img = torch.from_numpy(img).squeeze(0).float() + img = ((img / 255) - 0.5) / 0.5 + self.img = img + self.label = 1 + + def __getitem__(self, index): + return self.img, self.label + + def __len__(self): + return 1000000 diff --git a/src/face3d/models/arcface_torch/docs/eval.md b/src/face3d/models/arcface_torch/docs/eval.md new file mode 100644 index 0000000000000000000000000000000000000000..dd1d9e257367b6422680966198646c45e5a2671d --- /dev/null +++ b/src/face3d/models/arcface_torch/docs/eval.md @@ -0,0 +1,31 @@ +## Eval on ICCV2021-MFR + +coming soon. + + +## Eval IJBC +You can eval ijbc with pytorch or onnx. + + +1. Eval IJBC With Onnx +```shell +CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 +``` + +2. Eval IJBC With Pytorch +```shell +CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ +--model-prefix ms1mv3_arcface_r50/backbone.pth \ +--image-path IJB_release/IJBC \ +--result-dir ms1mv3_arcface_r50 \ +--batch-size 128 \ +--job ms1mv3_arcface_r50 \ +--target IJBC \ +--network iresnet50 +``` + +## Inference + +```shell +python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 +``` diff --git a/src/face3d/models/arcface_torch/docs/install.md b/src/face3d/models/arcface_torch/docs/install.md new file mode 100644 index 0000000000000000000000000000000000000000..6314a40441285e9236438e468caf8b71a407531a --- /dev/null +++ b/src/face3d/models/arcface_torch/docs/install.md @@ -0,0 +1,51 @@ +## v1.8.0 +### Linux and Windows +```shell +# CUDA 11.0 +pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html + +# CUDA 10.2 +pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 + +# CPU only +pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html + +``` + + +## v1.7.1 +### Linux and Windows +```shell +# CUDA 11.0 +pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html + +# CUDA 10.2 +pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 + +# CUDA 10.1 +pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html + +# CUDA 9.2 +pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html + +# CPU only +pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html +``` + + +## v1.6.0 + +### Linux and Windows +```shell +# CUDA 10.2 +pip install torch==1.6.0 torchvision==0.7.0 + +# CUDA 10.1 +pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html + +# CUDA 9.2 +pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html + +# CPU only +pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html +``` \ No newline at end of file diff --git a/src/face3d/models/arcface_torch/docs/modelzoo.md b/src/face3d/models/arcface_torch/docs/modelzoo.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/face3d/models/arcface_torch/docs/speed_benchmark.md b/src/face3d/models/arcface_torch/docs/speed_benchmark.md new file mode 100644 index 0000000000000000000000000000000000000000..055aee0defe2c43a523ced48260242f0f99b7cea --- /dev/null +++ b/src/face3d/models/arcface_torch/docs/speed_benchmark.md @@ -0,0 +1,93 @@ +## Test Training Speed + +- Test Commands + +You need to use the following two commands to test the Partial FC training performance. +The number of identites is **3 millions** (synthetic data), turn mixed precision training on, backbone is resnet50, +batch size is 1024. +```shell +# Model Parallel +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions +# Partial FC 0.1 +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc +``` + +- GPU Memory + +``` +# (Model Parallel) gpustat -i +[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB +[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB +[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB +[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB +[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB +[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB +[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB +[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB + +# (Partial FC 0.1) gpustat -i +[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │······················· +[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │······················· +[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │······················· +[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │······················· +[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │······················· +[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │······················· +[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │······················· +[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │······················· +``` + +- Training Speed + +```python +# (Model Parallel) trainging.log +Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100 +Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 +Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 +Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 +Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 + +# (Partial FC 0.1) trainging.log +Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100 +Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 +Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 +Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 +Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 +``` + +In this test case, Partial FC 0.1 only use1 1/3 of the GPU memory of the model parallel, +and the training speed is 2.5 times faster than the model parallel. + + +## Speed Benchmark + +1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 4681 | 4824 | 5004 | +|250000 | 4047 | 4521 | 4976 | +|500000 | 3087 | 4013 | 4900 | +|1000000 | 2090 | 3449 | 4803 | +|1400000 | 1672 | 3043 | 4738 | +|2000000 | - | 2593 | 4626 | +|4000000 | - | 1748 | 4208 | +|5500000 | - | 1389 | 3975 | +|8000000 | - | - | 3565 | +|16000000 | - | - | 2679 | +|29000000 | - | - | 1855 | + +2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. (Smaller is better) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 7358 | 5306 | 4868 | +|250000 | 9940 | 5826 | 5004 | +|500000 | 14220 | 7114 | 5202 | +|1000000 | 23708 | 9966 | 5620 | +|1400000 | 32252 | 11178 | 6056 | +|2000000 | - | 13978 | 6472 | +|4000000 | - | 23238 | 8284 | +|5500000 | - | 32188 | 9854 | +|8000000 | - | - | 12310 | +|16000000 | - | - | 19950 | +|29000000 | - | - | 32324 | diff --git a/src/face3d/models/arcface_torch/eval/__init__.py b/src/face3d/models/arcface_torch/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/face3d/models/arcface_torch/eval/verification.py b/src/face3d/models/arcface_torch/eval/verification.py new file mode 100644 index 0000000000000000000000000000000000000000..253343b83dbf9d1bd154d14ec068e098bf0968db --- /dev/null +++ b/src/face3d/models/arcface_torch/eval/verification.py @@ -0,0 +1,407 @@ +"""Helper for evaluation on the Labeled Faces in the Wild dataset +""" + +# MIT License +# +# Copyright (c) 2016 David Sandberg +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +import datetime +import os +import pickle + +import mxnet as mx +import numpy as np +import sklearn +import torch +from mxnet import ndarray as nd +from scipy import interpolate +from sklearn.decomposition import PCA +from sklearn.model_selection import KFold + + +class LFold: + def __init__(self, n_splits=2, shuffle=False): + self.n_splits = n_splits + if self.n_splits > 1: + self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle) + + def split(self, indices): + if self.n_splits > 1: + return self.k_fold.split(indices) + else: + return [(indices, indices)] + + +def calculate_roc(thresholds, + embeddings1, + embeddings2, + actual_issame, + nrof_folds=10, + pca=0): + assert (embeddings1.shape[0] == embeddings2.shape[0]) + assert (embeddings1.shape[1] == embeddings2.shape[1]) + nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) + nrof_thresholds = len(thresholds) + k_fold = LFold(n_splits=nrof_folds, shuffle=False) + + tprs = np.zeros((nrof_folds, nrof_thresholds)) + fprs = np.zeros((nrof_folds, nrof_thresholds)) + accuracy = np.zeros((nrof_folds)) + indices = np.arange(nrof_pairs) + + if pca == 0: + diff = np.subtract(embeddings1, embeddings2) + dist = np.sum(np.square(diff), 1) + + for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): + if pca > 0: + print('doing pca on', fold_idx) + embed1_train = embeddings1[train_set] + embed2_train = embeddings2[train_set] + _embed_train = np.concatenate((embed1_train, embed2_train), axis=0) + pca_model = PCA(n_components=pca) + pca_model.fit(_embed_train) + embed1 = pca_model.transform(embeddings1) + embed2 = pca_model.transform(embeddings2) + embed1 = sklearn.preprocessing.normalize(embed1) + embed2 = sklearn.preprocessing.normalize(embed2) + diff = np.subtract(embed1, embed2) + dist = np.sum(np.square(diff), 1) + + # Find the best threshold for the fold + acc_train = np.zeros((nrof_thresholds)) + for threshold_idx, threshold in enumerate(thresholds): + _, _, acc_train[threshold_idx] = calculate_accuracy( + threshold, dist[train_set], actual_issame[train_set]) + best_threshold_index = np.argmax(acc_train) + for threshold_idx, threshold in enumerate(thresholds): + tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy( + threshold, dist[test_set], + actual_issame[test_set]) + _, _, accuracy[fold_idx] = calculate_accuracy( + thresholds[best_threshold_index], dist[test_set], + actual_issame[test_set]) + + tpr = np.mean(tprs, 0) + fpr = np.mean(fprs, 0) + return tpr, fpr, accuracy + + +def calculate_accuracy(threshold, dist, actual_issame): + predict_issame = np.less(dist, threshold) + tp = np.sum(np.logical_and(predict_issame, actual_issame)) + fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) + tn = np.sum( + np.logical_and(np.logical_not(predict_issame), + np.logical_not(actual_issame))) + fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) + + tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) + fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) + acc = float(tp + tn) / dist.size + return tpr, fpr, acc + + +def calculate_val(thresholds, + embeddings1, + embeddings2, + actual_issame, + far_target, + nrof_folds=10): + assert (embeddings1.shape[0] == embeddings2.shape[0]) + assert (embeddings1.shape[1] == embeddings2.shape[1]) + nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) + nrof_thresholds = len(thresholds) + k_fold = LFold(n_splits=nrof_folds, shuffle=False) + + val = np.zeros(nrof_folds) + far = np.zeros(nrof_folds) + + diff = np.subtract(embeddings1, embeddings2) + dist = np.sum(np.square(diff), 1) + indices = np.arange(nrof_pairs) + + for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): + + # Find the threshold that gives FAR = far_target + far_train = np.zeros(nrof_thresholds) + for threshold_idx, threshold in enumerate(thresholds): + _, far_train[threshold_idx] = calculate_val_far( + threshold, dist[train_set], actual_issame[train_set]) + if np.max(far_train) >= far_target: + f = interpolate.interp1d(far_train, thresholds, kind='slinear') + threshold = f(far_target) + else: + threshold = 0.0 + + val[fold_idx], far[fold_idx] = calculate_val_far( + threshold, dist[test_set], actual_issame[test_set]) + + val_mean = np.mean(val) + far_mean = np.mean(far) + val_std = np.std(val) + return val_mean, val_std, far_mean + + +def calculate_val_far(threshold, dist, actual_issame): + predict_issame = np.less(dist, threshold) + true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) + false_accept = np.sum( + np.logical_and(predict_issame, np.logical_not(actual_issame))) + n_same = np.sum(actual_issame) + n_diff = np.sum(np.logical_not(actual_issame)) + # print(true_accept, false_accept) + # print(n_same, n_diff) + val = float(true_accept) / float(n_same) + far = float(false_accept) / float(n_diff) + return val, far + + +def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0): + # Calculate evaluation metrics + thresholds = np.arange(0, 4, 0.01) + embeddings1 = embeddings[0::2] + embeddings2 = embeddings[1::2] + tpr, fpr, accuracy = calculate_roc(thresholds, + embeddings1, + embeddings2, + np.asarray(actual_issame), + nrof_folds=nrof_folds, + pca=pca) + thresholds = np.arange(0, 4, 0.001) + val, val_std, far = calculate_val(thresholds, + embeddings1, + embeddings2, + np.asarray(actual_issame), + 1e-3, + nrof_folds=nrof_folds) + return tpr, fpr, accuracy, val, val_std, far + +@torch.no_grad() +def load_bin(path, image_size): + try: + with open(path, 'rb') as f: + bins, issame_list = pickle.load(f) # py2 + except UnicodeDecodeError as e: + with open(path, 'rb') as f: + bins, issame_list = pickle.load(f, encoding='bytes') # py3 + data_list = [] + for flip in [0, 1]: + data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) + data_list.append(data) + for idx in range(len(issame_list) * 2): + _bin = bins[idx] + img = mx.image.imdecode(_bin) + if img.shape[1] != image_size[0]: + img = mx.image.resize_short(img, image_size[0]) + img = nd.transpose(img, axes=(2, 0, 1)) + for flip in [0, 1]: + if flip == 1: + img = mx.ndarray.flip(data=img, axis=2) + data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()) + if idx % 1000 == 0: + print('loading bin', idx) + print(data_list[0].shape) + return data_list, issame_list + +@torch.no_grad() +def test(data_set, backbone, batch_size, nfolds=10): + print('testing verification..') + data_list = data_set[0] + issame_list = data_set[1] + embeddings_list = [] + time_consumed = 0.0 + for i in range(len(data_list)): + data = data_list[i] + embeddings = None + ba = 0 + while ba < data.shape[0]: + bb = min(ba + batch_size, data.shape[0]) + count = bb - ba + _data = data[bb - batch_size: bb] + time0 = datetime.datetime.now() + img = ((_data / 255) - 0.5) / 0.5 + net_out: torch.Tensor = backbone(img) + _embeddings = net_out.detach().cpu().numpy() + time_now = datetime.datetime.now() + diff = time_now - time0 + time_consumed += diff.total_seconds() + if embeddings is None: + embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) + embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] + ba = bb + embeddings_list.append(embeddings) + + _xnorm = 0.0 + _xnorm_cnt = 0 + for embed in embeddings_list: + for i in range(embed.shape[0]): + _em = embed[i] + _norm = np.linalg.norm(_em) + _xnorm += _norm + _xnorm_cnt += 1 + _xnorm /= _xnorm_cnt + + acc1 = 0.0 + std1 = 0.0 + embeddings = embeddings_list[0] + embeddings_list[1] + embeddings = sklearn.preprocessing.normalize(embeddings) + print(embeddings.shape) + print('infer time', time_consumed) + _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds) + acc2, std2 = np.mean(accuracy), np.std(accuracy) + return acc1, std1, acc2, std2, _xnorm, embeddings_list + + +def dumpR(data_set, + backbone, + batch_size, + name='', + data_extra=None, + label_shape=None): + print('dump verification embedding..') + data_list = data_set[0] + issame_list = data_set[1] + embeddings_list = [] + time_consumed = 0.0 + for i in range(len(data_list)): + data = data_list[i] + embeddings = None + ba = 0 + while ba < data.shape[0]: + bb = min(ba + batch_size, data.shape[0]) + count = bb - ba + + _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb) + time0 = datetime.datetime.now() + if data_extra is None: + db = mx.io.DataBatch(data=(_data,), label=(_label,)) + else: + db = mx.io.DataBatch(data=(_data, _data_extra), + label=(_label,)) + model.forward(db, is_train=False) + net_out = model.get_outputs() + _embeddings = net_out[0].asnumpy() + time_now = datetime.datetime.now() + diff = time_now - time0 + time_consumed += diff.total_seconds() + if embeddings is None: + embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) + embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] + ba = bb + embeddings_list.append(embeddings) + embeddings = embeddings_list[0] + embeddings_list[1] + embeddings = sklearn.preprocessing.normalize(embeddings) + actual_issame = np.asarray(issame_list) + outname = os.path.join('temp.bin') + with open(outname, 'wb') as f: + pickle.dump((embeddings, issame_list), + f, + protocol=pickle.HIGHEST_PROTOCOL) + + +# if __name__ == '__main__': +# +# parser = argparse.ArgumentParser(description='do verification') +# # general +# parser.add_argument('--data-dir', default='', help='') +# parser.add_argument('--model', +# default='../model/softmax,50', +# help='path to load model.') +# parser.add_argument('--target', +# default='lfw,cfp_ff,cfp_fp,agedb_30', +# help='test targets.') +# parser.add_argument('--gpu', default=0, type=int, help='gpu id') +# parser.add_argument('--batch-size', default=32, type=int, help='') +# parser.add_argument('--max', default='', type=str, help='') +# parser.add_argument('--mode', default=0, type=int, help='') +# parser.add_argument('--nfolds', default=10, type=int, help='') +# args = parser.parse_args() +# image_size = [112, 112] +# print('image_size', image_size) +# ctx = mx.gpu(args.gpu) +# nets = [] +# vec = args.model.split(',') +# prefix = args.model.split(',')[0] +# epochs = [] +# if len(vec) == 1: +# pdir = os.path.dirname(prefix) +# for fname in os.listdir(pdir): +# if not fname.endswith('.params'): +# continue +# _file = os.path.join(pdir, fname) +# if _file.startswith(prefix): +# epoch = int(fname.split('.')[0].split('-')[1]) +# epochs.append(epoch) +# epochs = sorted(epochs, reverse=True) +# if len(args.max) > 0: +# _max = [int(x) for x in args.max.split(',')] +# assert len(_max) == 2 +# if len(epochs) > _max[1]: +# epochs = epochs[_max[0]:_max[1]] +# +# else: +# epochs = [int(x) for x in vec[1].split('|')] +# print('model number', len(epochs)) +# time0 = datetime.datetime.now() +# for epoch in epochs: +# print('loading', prefix, epoch) +# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) +# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx) +# all_layers = sym.get_internals() +# sym = all_layers['fc1_output'] +# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None) +# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) +# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], +# image_size[1]))]) +# model.set_params(arg_params, aux_params) +# nets.append(model) +# time_now = datetime.datetime.now() +# diff = time_now - time0 +# print('model loading time', diff.total_seconds()) +# +# ver_list = [] +# ver_name_list = [] +# for name in args.target.split(','): +# path = os.path.join(args.data_dir, name + ".bin") +# if os.path.exists(path): +# print('loading.. ', name) +# data_set = load_bin(path, image_size) +# ver_list.append(data_set) +# ver_name_list.append(name) +# +# if args.mode == 0: +# for i in range(len(ver_list)): +# results = [] +# for model in nets: +# acc1, std1, acc2, std2, xnorm, embeddings_list = test( +# ver_list[i], model, args.batch_size, args.nfolds) +# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm)) +# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1)) +# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2)) +# results.append(acc2) +# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results))) +# elif args.mode == 1: +# raise ValueError +# else: +# model = nets[0] +# dumpR(ver_list[0], model, args.batch_size, args.target) diff --git a/src/face3d/models/arcface_torch/eval_ijbc.py b/src/face3d/models/arcface_torch/eval_ijbc.py new file mode 100644 index 0000000000000000000000000000000000000000..9c5a650d486d18eb02d6f60d448fc3b315261f5d --- /dev/null +++ b/src/face3d/models/arcface_torch/eval_ijbc.py @@ -0,0 +1,483 @@ +# coding: utf-8 + +import os +import pickle + +import matplotlib +import pandas as pd + +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import timeit +import sklearn +import argparse +import cv2 +import numpy as np +import torch +from skimage import transform as trans +from backbones import get_model +from sklearn.metrics import roc_curve, auc + +from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap +from prettytable import PrettyTable +from pathlib import Path + +import sys +import warnings + +sys.path.insert(0, "../") +warnings.filterwarnings("ignore") + +parser = argparse.ArgumentParser(description='do ijb test') +# general +parser.add_argument('--model-prefix', default='', help='path to load model.') +parser.add_argument('--image-path', default='', type=str, help='') +parser.add_argument('--result-dir', default='.', type=str, help='') +parser.add_argument('--batch-size', default=128, type=int, help='') +parser.add_argument('--network', default='iresnet50', type=str, help='') +parser.add_argument('--job', default='insightface', type=str, help='job name') +parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') +args = parser.parse_args() + +target = args.target +model_path = args.model_prefix +image_path = args.image_path +result_dir = args.result_dir +gpu_id = None +use_norm_score = True # if Ture, TestMode(N1) +use_detector_score = True # if Ture, TestMode(D1) +use_flip_test = True # if Ture, TestMode(F1) +job = args.job +batch_size = args.batch_size + + +class Embedding(object): + def __init__(self, prefix, data_shape, batch_size=1): + image_size = (112, 112) + self.image_size = image_size + weight = torch.load(prefix) + resnet = get_model(args.network, dropout=0, fp16=False).cuda() + resnet.load_state_dict(weight) + model = torch.nn.DataParallel(resnet) + self.model = model + self.model.eval() + src = np.array([ + [30.2946, 51.6963], + [65.5318, 51.5014], + [48.0252, 71.7366], + [33.5493, 92.3655], + [62.7299, 92.2041]], dtype=np.float32) + src[:, 0] += 8.0 + self.src = src + self.batch_size = batch_size + self.data_shape = data_shape + + def get(self, rimg, landmark): + + assert landmark.shape[0] == 68 or landmark.shape[0] == 5 + assert landmark.shape[1] == 2 + if landmark.shape[0] == 68: + landmark5 = np.zeros((5, 2), dtype=np.float32) + landmark5[0] = (landmark[36] + landmark[39]) / 2 + landmark5[1] = (landmark[42] + landmark[45]) / 2 + landmark5[2] = landmark[30] + landmark5[3] = landmark[48] + landmark5[4] = landmark[54] + else: + landmark5 = landmark + tform = trans.SimilarityTransform() + tform.estimate(landmark5, self.src) + M = tform.params[0:2, :] + img = cv2.warpAffine(rimg, + M, (self.image_size[1], self.image_size[0]), + borderValue=0.0) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_flip = np.fliplr(img) + img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB + img_flip = np.transpose(img_flip, (2, 0, 1)) + input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8) + input_blob[0] = img + input_blob[1] = img_flip + return input_blob + + @torch.no_grad() + def forward_db(self, batch_data): + imgs = torch.Tensor(batch_data).cuda() + imgs.div_(255).sub_(0.5).div_(0.5) + feat = self.model(imgs) + feat = feat.reshape([self.batch_size, 2 * feat.shape[1]]) + return feat.cpu().numpy() + + +# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[] +def divideIntoNstrand(listTemp, n): + twoList = [[] for i in range(n)] + for i, e in enumerate(listTemp): + twoList[i % n].append(e) + return twoList + + +def read_template_media_list(path): + # ijb_meta = np.loadtxt(path, dtype=str) + ijb_meta = pd.read_csv(path, sep=' ', header=None).values + templates = ijb_meta[:, 1].astype(np.int) + medias = ijb_meta[:, 2].astype(np.int) + return templates, medias + + +# In[ ]: + + +def read_template_pair_list(path): + # pairs = np.loadtxt(path, dtype=str) + pairs = pd.read_csv(path, sep=' ', header=None).values + # print(pairs.shape) + # print(pairs[:, 0].astype(np.int)) + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +# In[ ]: + + +def read_image_feature(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +# In[ ]: + + +def get_image_feature(img_path, files_list, model_path, epoch, gpu_id): + batch_size = args.batch_size + data_shape = (3, 112, 112) + + files = files_list + print('files:', len(files)) + rare_size = len(files) % batch_size + faceness_scores = [] + batch = 0 + img_feats = np.empty((len(files), 1024), dtype=np.float32) + + batch_data = np.empty((2 * batch_size, 3, 112, 112)) + embedding = Embedding(model_path, data_shape, batch_size) + for img_index, each_line in enumerate(files[:len(files) - rare_size]): + name_lmk_score = each_line.strip().split(' ') + img_name = os.path.join(img_path, name_lmk_score[0]) + img = cv2.imread(img_name) + lmk = np.array([float(x) for x in name_lmk_score[1:-1]], + dtype=np.float32) + lmk = lmk.reshape((5, 2)) + input_blob = embedding.get(img, lmk) + + batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0] + batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1] + if (img_index + 1) % batch_size == 0: + print('batch', batch) + img_feats[batch * batch_size:batch * batch_size + + batch_size][:] = embedding.forward_db(batch_data) + batch += 1 + faceness_scores.append(name_lmk_score[-1]) + + batch_data = np.empty((2 * rare_size, 3, 112, 112)) + embedding = Embedding(model_path, data_shape, rare_size) + for img_index, each_line in enumerate(files[len(files) - rare_size:]): + name_lmk_score = each_line.strip().split(' ') + img_name = os.path.join(img_path, name_lmk_score[0]) + img = cv2.imread(img_name) + lmk = np.array([float(x) for x in name_lmk_score[1:-1]], + dtype=np.float32) + lmk = lmk.reshape((5, 2)) + input_blob = embedding.get(img, lmk) + batch_data[2 * img_index][:] = input_blob[0] + batch_data[2 * img_index + 1][:] = input_blob[1] + if (img_index + 1) % rare_size == 0: + print('batch', batch) + img_feats[len(files) - + rare_size:][:] = embedding.forward_db(batch_data) + batch += 1 + faceness_scores.append(name_lmk_score[-1]) + faceness_scores = np.array(faceness_scores).astype(np.float32) + # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01 + # faceness_scores = np.ones( (len(files), ), dtype=np.float32 ) + return img_feats, faceness_scores + + +# In[ ]: + + +def image2template_feature(img_feats=None, templates=None, medias=None): + # ========================================================== + # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim] + # 2. compute media feature. + # 3. compute template feature. + # ========================================================== + unique_templates = np.unique(templates) + template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) + + for count_template, uqt in enumerate(unique_templates): + + (ind_t,) = np.where(templates == uqt) + face_norm_feats = img_feats[ind_t] + face_medias = medias[ind_t] + unique_medias, unique_media_counts = np.unique(face_medias, + return_counts=True) + media_norm_feats = [] + for u, ct in zip(unique_medias, unique_media_counts): + (ind_m,) = np.where(face_medias == u) + if ct == 1: + media_norm_feats += [face_norm_feats[ind_m]] + else: # image features from the same video will be aggregated into one feature + media_norm_feats += [ + np.mean(face_norm_feats[ind_m], axis=0, keepdims=True) + ] + media_norm_feats = np.array(media_norm_feats) + # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True)) + template_feats[count_template] = np.sum(media_norm_feats, axis=0) + if count_template % 2000 == 0: + print('Finish Calculating {} template features.'.format( + count_template)) + # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True)) + template_norm_feats = sklearn.preprocessing.normalize(template_feats) + # print(template_norm_feats.shape) + return template_norm_feats, unique_templates + + +# In[ ]: + + +def verification(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + # ========================================================== + # Compute set-to-set Similarity Score. + # ========================================================== + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + + score = np.zeros((len(p1),)) # save cosine distance between pairs + + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [ + total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) + ] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +# In[ ]: +def verification2(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) # save cosine distance between pairs + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [ + total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) + ] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def read_score(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +# # Step1: Load Meta Data + +# In[ ]: + +assert target == 'IJBC' or target == 'IJBB' + +# ============================================================= +# load image and template relationships for template feature embedding +# tid --> template id, mid --> media id +# format: +# image_name tid mid +# ============================================================= +start = timeit.default_timer() +templates, medias = read_template_media_list( + os.path.join('%s/meta' % image_path, + '%s_face_tid_mid.txt' % target.lower())) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# In[ ]: + +# ============================================================= +# load template pairs for template-to-template verification +# tid : template id, label : 1/0 +# format: +# tid_1 tid_2 label +# ============================================================= +start = timeit.default_timer() +p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % image_path, + '%s_template_pair_label.txt' % target.lower())) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# # Step 2: Get Image Features + +# In[ ]: + +# ============================================================= +# load image features +# format: +# img_feats: [image_num x feats_dim] (227630, 512) +# ============================================================= +start = timeit.default_timer() +img_path = '%s/loose_crop' % image_path +img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower()) +img_list = open(img_list_path) +files = img_list.readlines() +# files_list = divideIntoNstrand(files, rank_size) +files_list = files + +# img_feats +# for i in range(rank_size): +img_feats, faceness_scores = get_image_feature(img_path, files_list, + model_path, 0, gpu_id) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) +print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], + img_feats.shape[1])) + +# # Step3: Get Template Features + +# In[ ]: + +# ============================================================= +# compute template features from image features. +# ============================================================= +start = timeit.default_timer() +# ========================================================== +# Norm feature before aggregation into template feature? +# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face). +# ========================================================== +# 1. FaceScore (Feature Norm) +# 2. FaceScore (Detector) + +if use_flip_test: + # concat --- F1 + # img_input_feats = img_feats + # add --- F2 + img_input_feats = img_feats[:, 0:img_feats.shape[1] // + 2] + img_feats[:, img_feats.shape[1] // 2:] +else: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + +if use_norm_score: + img_input_feats = img_input_feats +else: + # normalise features to remove norm information + img_input_feats = img_input_feats / np.sqrt( + np.sum(img_input_feats ** 2, -1, keepdims=True)) + +if use_detector_score: + print(img_input_feats.shape, faceness_scores.shape) + img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] +else: + img_input_feats = img_input_feats + +template_norm_feats, unique_templates = image2template_feature( + img_input_feats, templates, medias) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# # Step 4: Get Template Similarity Scores + +# In[ ]: + +# ============================================================= +# compute verification scores between template pairs. +# ============================================================= +start = timeit.default_timer() +score = verification(template_norm_feats, unique_templates, p1, p2) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# In[ ]: +save_path = os.path.join(result_dir, args.job) +# save_path = result_dir + '/%s_result' % target + +if not os.path.exists(save_path): + os.makedirs(save_path) + +score_save_file = os.path.join(save_path, "%s.npy" % target.lower()) +np.save(score_save_file, score) + +# # Step 5: Get ROC Curves and TPR@FPR Table + +# In[ ]: + +files = [score_save_file] +methods = [] +scores = [] +for file in files: + methods.append(Path(file).stem) + scores.append(np.load(file)) + +methods = np.array(methods) +scores = dict(zip(methods, scores)) +colours = dict( + zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) +x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] +tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) +fig = plt.figure() +for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + roc_auc = auc(fpr, tpr) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) # select largest tpr at same fpr + plt.plot(fpr, + tpr, + color=colours[method], + lw=1, + label=('[%s (AUC = %0.4f %%)]' % + (method.split('-')[-1], roc_auc * 100))) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, target)) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) +plt.xlim([10 ** -6, 0.1]) +plt.ylim([0.3, 1.0]) +plt.grid(linestyle='--', linewidth=1) +plt.xticks(x_labels) +plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) +plt.xscale('log') +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.title('ROC on IJB') +plt.legend(loc="lower right") +fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower())) +print(tpr_fpr_table) diff --git a/src/face3d/models/arcface_torch/inference.py b/src/face3d/models/arcface_torch/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5156e8d649954837e397c2ff15ec29995e7502 --- /dev/null +++ b/src/face3d/models/arcface_torch/inference.py @@ -0,0 +1,35 @@ +import argparse + +import cv2 +import numpy as np +import torch + +from backbones import get_model + + +@torch.no_grad() +def inference(weight, name, img): + if img is None: + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) + else: + img = cv2.imread(img) + img = cv2.resize(img, (112, 112)) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = np.transpose(img, (2, 0, 1)) + img = torch.from_numpy(img).unsqueeze(0).float() + img.div_(255).sub_(0.5).div_(0.5) + net = get_model(name, fp16=False) + net.load_state_dict(torch.load(weight)) + net.eval() + feat = net(img).numpy() + print(feat) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') + parser.add_argument('--network', type=str, default='r50', help='backbone network') + parser.add_argument('--weight', type=str, default='') + parser.add_argument('--img', type=str, default=None) + args = parser.parse_args() + inference(args.weight, args.network, args.img) diff --git a/src/face3d/models/arcface_torch/losses.py b/src/face3d/models/arcface_torch/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..87aeaa107af4d53f5a6132b3739d5cafdcded7fc --- /dev/null +++ b/src/face3d/models/arcface_torch/losses.py @@ -0,0 +1,42 @@ +import torch +from torch import nn + + +def get_loss(name): + if name == "cosface": + return CosFace() + elif name == "arcface": + return ArcFace() + else: + raise ValueError() + + +class CosFace(nn.Module): + def __init__(self, s=64.0, m=0.40): + super(CosFace, self).__init__() + self.s = s + self.m = m + + def forward(self, cosine, label): + index = torch.where(label != -1)[0] + m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) + m_hot.scatter_(1, label[index, None], self.m) + cosine[index] -= m_hot + ret = cosine * self.s + return ret + + +class ArcFace(nn.Module): + def __init__(self, s=64.0, m=0.5): + super(ArcFace, self).__init__() + self.s = s + self.m = m + + def forward(self, cosine: torch.Tensor, label): + index = torch.where(label != -1)[0] + m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) + m_hot.scatter_(1, label[index, None], self.m) + cosine.acos_() + cosine[index] += m_hot + cosine.cos_().mul_(self.s) + return cosine diff --git a/src/face3d/models/arcface_torch/onnx_helper.py b/src/face3d/models/arcface_torch/onnx_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..ca922ca6d410655029e459cf8fd1c323d276c34c --- /dev/null +++ b/src/face3d/models/arcface_torch/onnx_helper.py @@ -0,0 +1,250 @@ +from __future__ import division +import datetime +import os +import os.path as osp +import glob +import numpy as np +import cv2 +import sys +import onnxruntime +import onnx +import argparse +from onnx import numpy_helper +from insightface.data import get_image + +class ArcFaceORT: + def __init__(self, model_path, cpu=False): + self.model_path = model_path + # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider" + self.providers = ['CPUExecutionProvider'] if cpu else None + + #input_size is (w,h), return error message, return None if success + def check(self, track='cfat', test_img = None): + #default is cfat + max_model_size_mb=1024 + max_feat_dim=512 + max_time_cost=15 + if track.startswith('ms1m'): + max_model_size_mb=1024 + max_feat_dim=512 + max_time_cost=10 + elif track.startswith('glint'): + max_model_size_mb=1024 + max_feat_dim=1024 + max_time_cost=20 + elif track.startswith('cfat'): + max_model_size_mb = 1024 + max_feat_dim = 512 + max_time_cost = 15 + elif track.startswith('unconstrained'): + max_model_size_mb=1024 + max_feat_dim=1024 + max_time_cost=30 + else: + return "track not found" + + if not os.path.exists(self.model_path): + return "model_path not exists" + if not os.path.isdir(self.model_path): + return "model_path should be directory" + onnx_files = [] + for _file in os.listdir(self.model_path): + if _file.endswith('.onnx'): + onnx_files.append(osp.join(self.model_path, _file)) + if len(onnx_files)==0: + return "do not have onnx files" + self.model_file = sorted(onnx_files)[-1] + print('use onnx-model:', self.model_file) + try: + session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) + except: + return "load onnx failed" + input_cfg = session.get_inputs()[0] + input_shape = input_cfg.shape + print('input-shape:', input_shape) + if len(input_shape)!=4: + return "length of input_shape should be 4" + if not isinstance(input_shape[0], str): + #return "input_shape[0] should be str to support batch-inference" + print('reset input-shape[0] to None') + model = onnx.load(self.model_file) + model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' + new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx') + onnx.save(model, new_model_file) + self.model_file = new_model_file + print('use new onnx-model:', self.model_file) + try: + session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) + except: + return "load onnx failed" + input_cfg = session.get_inputs()[0] + input_shape = input_cfg.shape + print('new-input-shape:', input_shape) + + self.image_size = tuple(input_shape[2:4][::-1]) + #print('image_size:', self.image_size) + input_name = input_cfg.name + outputs = session.get_outputs() + output_names = [] + for o in outputs: + output_names.append(o.name) + #print(o.name, o.shape) + if len(output_names)!=1: + return "number of output nodes should be 1" + self.session = session + self.input_name = input_name + self.output_names = output_names + #print(self.output_names) + model = onnx.load(self.model_file) + graph = model.graph + if len(graph.node)<8: + return "too small onnx graph" + + input_size = (112,112) + self.crop = None + if track=='cfat': + crop_file = osp.join(self.model_path, 'crop.txt') + if osp.exists(crop_file): + lines = open(crop_file,'r').readlines() + if len(lines)!=6: + return "crop.txt should contain 6 lines" + lines = [int(x) for x in lines] + self.crop = lines[:4] + input_size = tuple(lines[4:6]) + if input_size!=self.image_size: + return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size) + + self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024) + if self.model_size_mb > max_model_size_mb: + return "max model size exceed, given %.3f-MB"%self.model_size_mb + + input_mean = None + input_std = None + if track=='cfat': + pn_file = osp.join(self.model_path, 'pixel_norm.txt') + if osp.exists(pn_file): + lines = open(pn_file,'r').readlines() + if len(lines)!=2: + return "pixel_norm.txt should contain 2 lines" + input_mean = float(lines[0]) + input_std = float(lines[1]) + if input_mean is not None or input_std is not None: + if input_mean is None or input_std is None: + return "please set input_mean and input_std simultaneously" + else: + find_sub = False + find_mul = False + for nid, node in enumerate(graph.node[:8]): + print(nid, node.name) + if node.name.startswith('Sub') or node.name.startswith('_minus'): + find_sub = True + if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'): + find_mul = True + if find_sub and find_mul: + print("find sub and mul") + #mxnet arcface model + input_mean = 0.0 + input_std = 1.0 + else: + input_mean = 127.5 + input_std = 127.5 + self.input_mean = input_mean + self.input_std = input_std + for initn in graph.initializer: + weight_array = numpy_helper.to_array(initn) + dt = weight_array.dtype + if dt.itemsize<4: + return 'invalid weight type - (%s:%s)' % (initn.name, dt.name) + if test_img is None: + test_img = get_image('Tom_Hanks_54745') + test_img = cv2.resize(test_img, self.image_size) + else: + test_img = cv2.resize(test_img, self.image_size) + feat, cost = self.benchmark(test_img) + batch_result = self.check_batch(test_img) + batch_result_sum = float(np.sum(batch_result)) + if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum: + print(batch_result) + print(batch_result_sum) + return "batch result output contains NaN!" + + if len(feat.shape) < 2: + return "the shape of the feature must be two, but get {}".format(str(feat.shape)) + + if feat.shape[1] > max_feat_dim: + return "max feat dim exceed, given %d"%feat.shape[1] + self.feat_dim = feat.shape[1] + cost_ms = cost*1000 + if cost_ms>max_time_cost: + return "max time cost exceed, given %.4f"%cost_ms + self.cost_ms = cost_ms + print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std)) + return None + + def check_batch(self, img): + if not isinstance(img, list): + imgs = [img, ] * 32 + if self.crop is not None: + nimgs = [] + for img in imgs: + nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :] + if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]: + nimg = cv2.resize(nimg, self.image_size) + nimgs.append(nimg) + imgs = nimgs + blob = cv2.dnn.blobFromImages( + images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size, + mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_out = self.session.run(self.output_names, {self.input_name: blob})[0] + return net_out + + + def meta_info(self): + return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms} + + + def forward(self, imgs): + if not isinstance(imgs, list): + imgs = [imgs] + input_size = self.image_size + if self.crop is not None: + nimgs = [] + for img in imgs: + nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] + if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: + nimg = cv2.resize(nimg, input_size) + nimgs.append(nimg) + imgs = nimgs + blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_out = self.session.run(self.output_names, {self.input_name : blob})[0] + return net_out + + def benchmark(self, img): + input_size = self.image_size + if self.crop is not None: + nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] + if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: + nimg = cv2.resize(nimg, input_size) + img = nimg + blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + costs = [] + for _ in range(50): + ta = datetime.datetime.now() + net_out = self.session.run(self.output_names, {self.input_name : blob})[0] + tb = datetime.datetime.now() + cost = (tb-ta).total_seconds() + costs.append(cost) + costs = sorted(costs) + cost = costs[5] + return net_out, cost + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='') + # general + parser.add_argument('workdir', help='submitted work dir', type=str) + parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat') + args = parser.parse_args() + handler = ArcFaceORT(args.workdir) + err = handler.check(args.track) + print('err:', err) diff --git a/src/face3d/models/arcface_torch/onnx_ijbc.py b/src/face3d/models/arcface_torch/onnx_ijbc.py new file mode 100644 index 0000000000000000000000000000000000000000..05b50bfad4b4cf38903b89f596263a8e29a50d3e --- /dev/null +++ b/src/face3d/models/arcface_torch/onnx_ijbc.py @@ -0,0 +1,267 @@ +import argparse +import os +import pickle +import timeit + +import cv2 +import mxnet as mx +import numpy as np +import pandas as pd +import prettytable +import skimage.transform +from sklearn.metrics import roc_curve +from sklearn.preprocessing import normalize + +from onnx_helper import ArcFaceORT + +SRC = np.array( + [ + [30.2946, 51.6963], + [65.5318, 51.5014], + [48.0252, 71.7366], + [33.5493, 92.3655], + [62.7299, 92.2041]] + , dtype=np.float32) +SRC[:, 0] += 8.0 + + +class AlignedDataSet(mx.gluon.data.Dataset): + def __init__(self, root, lines, align=True): + self.lines = lines + self.root = root + self.align = align + + def __len__(self): + return len(self.lines) + + def __getitem__(self, idx): + each_line = self.lines[idx] + name_lmk_score = each_line.strip().split(' ') + name = os.path.join(self.root, name_lmk_score[0]) + img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB) + landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2)) + st = skimage.transform.SimilarityTransform() + st.estimate(landmark5, SRC) + img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0) + img_1 = np.expand_dims(img, 0) + img_2 = np.expand_dims(np.fliplr(img), 0) + output = np.concatenate((img_1, img_2), axis=0).astype(np.float32) + output = np.transpose(output, (0, 3, 1, 2)) + output = mx.nd.array(output) + return output + + +def extract(model_root, dataset): + model = ArcFaceORT(model_path=model_root) + model.check() + feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim)) + + def batchify_fn(data): + return mx.nd.concat(*data, dim=0) + + data_loader = mx.gluon.data.DataLoader( + dataset, 128, last_batch='keep', num_workers=4, + thread_pool=True, prefetch=16, batchify_fn=batchify_fn) + num_iter = 0 + for batch in data_loader: + batch = batch.asnumpy() + batch = (batch - model.input_mean) / model.input_std + feat = model.session.run(model.output_names, {model.input_name: batch})[0] + feat = np.reshape(feat, (-1, model.feat_dim * 2)) + feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat + num_iter += 1 + if num_iter % 50 == 0: + print(num_iter) + return feat_mat + + +def read_template_media_list(path): + ijb_meta = pd.read_csv(path, sep=' ', header=None).values + templates = ijb_meta[:, 1].astype(np.int) + medias = ijb_meta[:, 2].astype(np.int) + return templates, medias + + +def read_template_pair_list(path): + pairs = pd.read_csv(path, sep=' ', header=None).values + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +def read_image_feature(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +def image2template_feature(img_feats=None, + templates=None, + medias=None): + unique_templates = np.unique(templates) + template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) + for count_template, uqt in enumerate(unique_templates): + (ind_t,) = np.where(templates == uqt) + face_norm_feats = img_feats[ind_t] + face_medias = medias[ind_t] + unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True) + media_norm_feats = [] + for u, ct in zip(unique_medias, unique_media_counts): + (ind_m,) = np.where(face_medias == u) + if ct == 1: + media_norm_feats += [face_norm_feats[ind_m]] + else: # image features from the same video will be aggregated into one feature + media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ] + media_norm_feats = np.array(media_norm_feats) + template_feats[count_template] = np.sum(media_norm_feats, axis=0) + if count_template % 2000 == 0: + print('Finish Calculating {} template features.'.format( + count_template)) + template_norm_feats = normalize(template_feats) + return template_norm_feats, unique_templates + + +def verification(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) + total_pairs = np.array(range(len(p1))) + batchsize = 100000 + sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def verification2(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) # save cosine distance between pairs + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def main(args): + use_norm_score = True # if Ture, TestMode(N1) + use_detector_score = True # if Ture, TestMode(D1) + use_flip_test = True # if Ture, TestMode(F1) + assert args.target == 'IJBC' or args.target == 'IJBB' + + start = timeit.default_timer() + templates, medias = read_template_media_list( + os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower())) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % args.image_path, + '%s_template_pair_label.txt' % args.target.lower())) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + img_path = '%s/loose_crop' % args.image_path + img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower()) + img_list = open(img_list_path) + files = img_list.readlines() + dataset = AlignedDataSet(root=img_path, lines=files, align=True) + img_feats = extract(args.model_root, dataset) + + faceness_scores = [] + for each_line in files: + name_lmk_score = each_line.split() + faceness_scores.append(name_lmk_score[-1]) + faceness_scores = np.array(faceness_scores).astype(np.float32) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1])) + start = timeit.default_timer() + + if use_flip_test: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:] + else: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + + if use_norm_score: + img_input_feats = img_input_feats + else: + img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True)) + + if use_detector_score: + print(img_input_feats.shape, faceness_scores.shape) + img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] + else: + img_input_feats = img_input_feats + + template_norm_feats, unique_templates = image2template_feature( + img_input_feats, templates, medias) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + score = verification(template_norm_feats, unique_templates, p1, p2) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + save_path = os.path.join(args.result_dir, "{}_result".format(args.target)) + if not os.path.exists(save_path): + os.makedirs(save_path) + score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root)) + np.save(score_save_file, score) + files = [score_save_file] + methods = [] + scores = [] + for file in files: + methods.append(os.path.basename(file)) + scores.append(np.load(file)) + methods = np.array(methods) + scores = dict(zip(methods, scores)) + x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] + tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels]) + for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, args.target)) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) + print(tpr_fpr_table) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='do ijb test') + # general + parser.add_argument('--model-root', default='', help='path to load model.') + parser.add_argument('--image-path', default='', type=str, help='') + parser.add_argument('--result-dir', default='.', type=str, help='') + parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') + main(parser.parse_args()) diff --git a/src/face3d/models/arcface_torch/partial_fc.py b/src/face3d/models/arcface_torch/partial_fc.py new file mode 100644 index 0000000000000000000000000000000000000000..17e2d25715d10ba446c957e1d2528b0687ed71d5 --- /dev/null +++ b/src/face3d/models/arcface_torch/partial_fc.py @@ -0,0 +1,222 @@ +import logging +import os + +import torch +import torch.distributed as dist +from torch.nn import Module +from torch.nn.functional import normalize, linear +from torch.nn.parameter import Parameter + + +class PartialFC(Module): + """ + Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint, + Partial FC: Training 10 Million Identities on a Single Machine + See the original paper: + https://arxiv.org/abs/2010.05222 + """ + + @torch.no_grad() + def __init__(self, rank, local_rank, world_size, batch_size, resume, + margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"): + """ + rank: int + Unique process(GPU) ID from 0 to world_size - 1. + local_rank: int + Unique process(GPU) ID within the server from 0 to 7. + world_size: int + Number of GPU. + batch_size: int + Batch size on current rank(GPU). + resume: bool + Select whether to restore the weight of softmax. + margin_softmax: callable + A function of margin softmax, eg: cosface, arcface. + num_classes: int + The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size, + required. + sample_rate: float + The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling + can greatly speed up training, and reduce a lot of GPU memory, default is 1.0. + embedding_size: int + The feature dimension, default is 512. + prefix: str + Path for save checkpoint, default is './'. + """ + super(PartialFC, self).__init__() + # + self.num_classes: int = num_classes + self.rank: int = rank + self.local_rank: int = local_rank + self.device: torch.device = torch.device("cuda:{}".format(self.local_rank)) + self.world_size: int = world_size + self.batch_size: int = batch_size + self.margin_softmax: callable = margin_softmax + self.sample_rate: float = sample_rate + self.embedding_size: int = embedding_size + self.prefix: str = prefix + self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size) + self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size) + self.num_sample: int = int(self.sample_rate * self.num_local) + + self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank)) + self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank)) + + if resume: + try: + self.weight: torch.Tensor = torch.load(self.weight_name) + self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name) + if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local: + raise IndexError + logging.info("softmax weight resume successfully!") + logging.info("softmax weight mom resume successfully!") + except (FileNotFoundError, KeyError, IndexError): + self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) + self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) + logging.info("softmax weight init!") + logging.info("softmax weight mom init!") + else: + self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) + self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) + logging.info("softmax weight init successfully!") + logging.info("softmax weight mom init successfully!") + self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank) + + self.index = None + if int(self.sample_rate) == 1: + self.update = lambda: 0 + self.sub_weight = Parameter(self.weight) + self.sub_weight_mom = self.weight_mom + else: + self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank)) + + def save_params(self): + """ Save softmax weight for each rank on prefix + """ + torch.save(self.weight.data, self.weight_name) + torch.save(self.weight_mom, self.weight_mom_name) + + @torch.no_grad() + def sample(self, total_label): + """ + Sample all positive class centers in each rank, and random select neg class centers to filling a fixed + `num_sample`. + + total_label: tensor + Label after all gather, which cross all GPUs. + """ + index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local) + total_label[~index_positive] = -1 + total_label[index_positive] -= self.class_start + if int(self.sample_rate) != 1: + positive = torch.unique(total_label[index_positive], sorted=True) + if self.num_sample - positive.size(0) >= 0: + perm = torch.rand(size=[self.num_local], device=self.device) + perm[positive] = 2.0 + index = torch.topk(perm, k=self.num_sample)[1] + index = index.sort()[0] + else: + index = positive + self.index = index + total_label[index_positive] = torch.searchsorted(index, total_label[index_positive]) + self.sub_weight = Parameter(self.weight[index]) + self.sub_weight_mom = self.weight_mom[index] + + def forward(self, total_features, norm_weight): + """ Partial fc forward, `logits = X * sample(W)` + """ + torch.cuda.current_stream().wait_stream(self.stream) + logits = linear(total_features, norm_weight) + return logits + + @torch.no_grad() + def update(self): + """ Set updated weight and weight_mom to memory bank. + """ + self.weight_mom[self.index] = self.sub_weight_mom + self.weight[self.index] = self.sub_weight + + def prepare(self, label, optimizer): + """ + get sampled class centers for cal softmax. + + label: tensor + Label tensor on each rank. + optimizer: opt + Optimizer for partial fc, which need to get weight mom. + """ + with torch.cuda.stream(self.stream): + total_label = torch.zeros( + size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long) + dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label) + self.sample(total_label) + optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None) + optimizer.param_groups[-1]['params'][0] = self.sub_weight + optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom + norm_weight = normalize(self.sub_weight) + return total_label, norm_weight + + def forward_backward(self, label, features, optimizer): + """ + Partial fc forward and backward with model parallel + + label: tensor + Label tensor on each rank(GPU) + features: tensor + Features tensor on each rank(GPU) + optimizer: optimizer + Optimizer for partial fc + + Returns: + -------- + x_grad: tensor + The gradient of features. + loss_v: tensor + Loss value for cross entropy. + """ + total_label, norm_weight = self.prepare(label, optimizer) + total_features = torch.zeros( + size=[self.batch_size * self.world_size, self.embedding_size], device=self.device) + dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data) + total_features.requires_grad = True + + logits = self.forward(total_features, norm_weight) + logits = self.margin_softmax(logits, total_label) + + with torch.no_grad(): + max_fc = torch.max(logits, dim=1, keepdim=True)[0] + dist.all_reduce(max_fc, dist.ReduceOp.MAX) + + # calculate exp(logits) and all-reduce + logits_exp = torch.exp(logits - max_fc) + logits_sum_exp = logits_exp.sum(dim=1, keepdims=True) + dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM) + + # calculate prob + logits_exp.div_(logits_sum_exp) + + # get one-hot + grad = logits_exp + index = torch.where(total_label != -1)[0] + one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device) + one_hot.scatter_(1, total_label[index, None], 1) + + # calculate loss + loss = torch.zeros(grad.size()[0], 1, device=grad.device) + loss[index] = grad[index].gather(1, total_label[index, None]) + dist.all_reduce(loss, dist.ReduceOp.SUM) + loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1) + + # calculate grad + grad[index] -= one_hot + grad.div_(self.batch_size * self.world_size) + + logits.backward(grad) + if total_features.grad is not None: + total_features.grad.detach_() + x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True) + # feature gradient all-reduce + dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0))) + x_grad = x_grad * self.world_size + # backward backbone + return x_grad, loss_v diff --git a/src/face3d/models/arcface_torch/requirement.txt b/src/face3d/models/arcface_torch/requirement.txt new file mode 100644 index 0000000000000000000000000000000000000000..f72c1b3ba814ae1e0bc1c1f56402026978b9e870 --- /dev/null +++ b/src/face3d/models/arcface_torch/requirement.txt @@ -0,0 +1,5 @@ +tensorboard +easydict +mxnet +onnx +sklearn diff --git a/src/face3d/models/arcface_torch/run.sh b/src/face3d/models/arcface_torch/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..61af4b4950eb11334e55362e3e3c5e2796979a01 --- /dev/null +++ b/src/face3d/models/arcface_torch/run.sh @@ -0,0 +1,2 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 +ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh diff --git a/src/face3d/models/arcface_torch/torch2onnx.py b/src/face3d/models/arcface_torch/torch2onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..fc26ab82e552331bc8d75b34e81000418f4d38ec --- /dev/null +++ b/src/face3d/models/arcface_torch/torch2onnx.py @@ -0,0 +1,59 @@ +import numpy as np +import onnx +import torch + + +def convert_onnx(net, path_module, output, opset=11, simplify=False): + assert isinstance(net, torch.nn.Module) + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) + img = img.astype(np.float) + img = (img / 255. - 0.5) / 0.5 # torch style norm + img = img.transpose((2, 0, 1)) + img = torch.from_numpy(img).unsqueeze(0).float() + + weight = torch.load(path_module) + net.load_state_dict(weight) + net.eval() + torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset) + model = onnx.load(output) + graph = model.graph + graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' + if simplify: + from onnxsim import simplify + model, check = simplify(model) + assert check, "Simplified ONNX model could not be validated" + onnx.save(model, output) + + +if __name__ == '__main__': + import os + import argparse + from backbones import get_model + + parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') + parser.add_argument('input', type=str, help='input backbone.pth file or path') + parser.add_argument('--output', type=str, default=None, help='output onnx path') + parser.add_argument('--network', type=str, default=None, help='backbone network') + parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') + args = parser.parse_args() + input_file = args.input + if os.path.isdir(input_file): + input_file = os.path.join(input_file, "backbone.pth") + assert os.path.exists(input_file) + model_name = os.path.basename(os.path.dirname(input_file)).lower() + params = model_name.split("_") + if len(params) >= 3 and params[1] in ('arcface', 'cosface'): + if args.network is None: + args.network = params[2] + assert args.network is not None + print(args) + backbone_onnx = get_model(args.network, dropout=0) + + output_path = args.output + if output_path is None: + output_path = os.path.join(os.path.dirname(__file__), 'onnx') + if not os.path.exists(output_path): + os.makedirs(output_path) + assert os.path.isdir(output_path) + output_file = os.path.join(output_path, "%s.onnx" % model_name) + convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify) diff --git a/src/face3d/models/arcface_torch/train.py b/src/face3d/models/arcface_torch/train.py new file mode 100644 index 0000000000000000000000000000000000000000..55eca2d0ad9463415970e09bccab8b722e496704 --- /dev/null +++ b/src/face3d/models/arcface_torch/train.py @@ -0,0 +1,141 @@ +import argparse +import logging +import os + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils.data.distributed +from torch.nn.utils import clip_grad_norm_ + +import losses +from backbones import get_model +from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX +from partial_fc import PartialFC +from utils.utils_amp import MaxClipGradScaler +from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint +from utils.utils_config import get_config +from utils.utils_logging import AverageMeter, init_logging + + +def main(args): + cfg = get_config(args.config) + try: + world_size = int(os.environ['WORLD_SIZE']) + rank = int(os.environ['RANK']) + dist.init_process_group('nccl') + except KeyError: + world_size = 1 + rank = 0 + dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size) + + local_rank = args.local_rank + torch.cuda.set_device(local_rank) + os.makedirs(cfg.output, exist_ok=True) + init_logging(rank, cfg.output) + + if cfg.rec == "synthetic": + train_set = SyntheticDataset(local_rank=local_rank) + else: + train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank) + + train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True) + train_loader = DataLoaderX( + local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size, + sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True) + backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank) + + if cfg.resume: + try: + backbone_pth = os.path.join(cfg.output, "backbone.pth") + backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank))) + if rank == 0: + logging.info("backbone resume successfully!") + except (FileNotFoundError, KeyError, IndexError, RuntimeError): + if rank == 0: + logging.info("resume fail, backbone init successfully!") + + backbone = torch.nn.parallel.DistributedDataParallel( + module=backbone, broadcast_buffers=False, device_ids=[local_rank]) + backbone.train() + margin_softmax = losses.get_loss(cfg.loss) + module_partial_fc = PartialFC( + rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume, + batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes, + sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output) + + opt_backbone = torch.optim.SGD( + params=[{'params': backbone.parameters()}], + lr=cfg.lr / 512 * cfg.batch_size * world_size, + momentum=0.9, weight_decay=cfg.weight_decay) + opt_pfc = torch.optim.SGD( + params=[{'params': module_partial_fc.parameters()}], + lr=cfg.lr / 512 * cfg.batch_size * world_size, + momentum=0.9, weight_decay=cfg.weight_decay) + + num_image = len(train_set) + total_batch_size = cfg.batch_size * world_size + cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch + cfg.total_step = num_image // total_batch_size * cfg.num_epoch + + def lr_step_func(current_step): + cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch] + if current_step < cfg.warmup_step: + return current_step / cfg.warmup_step + else: + return 0.1 ** len([m for m in cfg.decay_step if m <= current_step]) + + scheduler_backbone = torch.optim.lr_scheduler.LambdaLR( + optimizer=opt_backbone, lr_lambda=lr_step_func) + scheduler_pfc = torch.optim.lr_scheduler.LambdaLR( + optimizer=opt_pfc, lr_lambda=lr_step_func) + + for key, value in cfg.items(): + num_space = 25 - len(key) + logging.info(": " + key + " " * num_space + str(value)) + + val_target = cfg.val_targets + callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec) + callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None) + callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output) + + loss = AverageMeter() + start_epoch = 0 + global_step = 0 + grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None + for epoch in range(start_epoch, cfg.num_epoch): + train_sampler.set_epoch(epoch) + for step, (img, label) in enumerate(train_loader): + global_step += 1 + features = F.normalize(backbone(img)) + x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc) + if cfg.fp16: + features.backward(grad_amp.scale(x_grad)) + grad_amp.unscale_(opt_backbone) + clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) + grad_amp.step(opt_backbone) + grad_amp.update() + else: + features.backward(x_grad) + clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) + opt_backbone.step() + + opt_pfc.step() + module_partial_fc.update() + opt_backbone.zero_grad() + opt_pfc.zero_grad() + loss.update(loss_v, 1) + callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp) + callback_verification(global_step, backbone) + scheduler_backbone.step() + scheduler_pfc.step() + callback_checkpoint(global_step, backbone, module_partial_fc) + dist.destroy_process_group() + + +if __name__ == "__main__": + torch.backends.cudnn.benchmark = True + parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') + parser.add_argument('config', type=str, help='py config file') + parser.add_argument('--local_rank', type=int, default=0, help='local_rank') + main(parser.parse_args()) diff --git a/src/face3d/models/arcface_torch/utils/__init__.py b/src/face3d/models/arcface_torch/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/face3d/models/arcface_torch/utils/plot.py b/src/face3d/models/arcface_torch/utils/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc588e5c01ca550b69c385aeb3fd139c59fb88a --- /dev/null +++ b/src/face3d/models/arcface_torch/utils/plot.py @@ -0,0 +1,72 @@ +# coding: utf-8 + +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap +from prettytable import PrettyTable +from sklearn.metrics import roc_curve, auc + +image_path = "/data/anxiang/IJB_release/IJBC" +files = [ + "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy" +] + + +def read_template_pair_list(path): + pairs = pd.read_csv(path, sep=' ', header=None).values + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % image_path, + '%s_template_pair_label.txt' % 'ijbc')) + +methods = [] +scores = [] +for file in files: + methods.append(file.split('/')[-2]) + scores.append(np.load(file)) + +methods = np.array(methods) +scores = dict(zip(methods, scores)) +colours = dict( + zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) +x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] +tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) +fig = plt.figure() +for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + roc_auc = auc(fpr, tpr) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) # select largest tpr at same fpr + plt.plot(fpr, + tpr, + color=colours[method], + lw=1, + label=('[%s (AUC = %0.4f %%)]' % + (method.split('-')[-1], roc_auc * 100))) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, "IJBC")) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) +plt.xlim([10 ** -6, 0.1]) +plt.ylim([0.3, 1.0]) +plt.grid(linestyle='--', linewidth=1) +plt.xticks(x_labels) +plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) +plt.xscale('log') +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.title('ROC on IJB') +plt.legend(loc="lower right") +print(tpr_fpr_table) diff --git a/src/face3d/models/arcface_torch/utils/utils_amp.py b/src/face3d/models/arcface_torch/utils/utils_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac2a03f4212faa129faed447a8f4519c0a00a8b --- /dev/null +++ b/src/face3d/models/arcface_torch/utils/utils_amp.py @@ -0,0 +1,88 @@ +from typing import Dict, List + +import torch + +if torch.__version__ < '1.9': + Iterable = torch._six.container_abcs.Iterable +else: + import collections + + Iterable = collections.abc.Iterable +from torch.cuda.amp import GradScaler + + +class _MultiDeviceReplicator(object): + """ + Lazily serves copies of a tensor to requested devices. Copies are cached per-device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + assert master_tensor.is_cuda + self.master = master_tensor + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + + def get(self, device) -> torch.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to(device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +class MaxClipGradScaler(GradScaler): + def __init__(self, init_scale, max_scale: float, growth_interval=100): + GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval) + self.max_scale = max_scale + + def scale_clip(self): + if self.get_scale() == self.max_scale: + self.set_growth_factor(1) + elif self.get_scale() < self.max_scale: + self.set_growth_factor(2) + elif self.get_scale() > self.max_scale: + self._scale.fill_(self.max_scale) + self.set_growth_factor(1) + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Arguments: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + self.scale_clip() + # Short-circuit for the common case. + if isinstance(outputs, torch.Tensor): + assert outputs.is_cuda + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + return outputs * self._scale.to(device=outputs.device, non_blocking=True) + + # Invoke the more complex machinery only if we're treating multiple outputs. + stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale + + def apply_scale(val): + if isinstance(val, torch.Tensor): + assert val.is_cuda + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + elif isinstance(val, Iterable): + iterable = map(apply_scale, val) + if isinstance(val, list) or isinstance(val, tuple): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) diff --git a/src/face3d/models/arcface_torch/utils/utils_callbacks.py b/src/face3d/models/arcface_torch/utils/utils_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..bd2f56cba47c57de102710ff56eaac591e59f4da --- /dev/null +++ b/src/face3d/models/arcface_torch/utils/utils_callbacks.py @@ -0,0 +1,117 @@ +import logging +import os +import time +from typing import List + +import torch + +from eval import verification +from utils.utils_logging import AverageMeter + + +class CallBackVerification(object): + def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)): + self.frequent: int = frequent + self.rank: int = rank + self.highest_acc: float = 0.0 + self.highest_acc_list: List[float] = [0.0] * len(val_targets) + self.ver_list: List[object] = [] + self.ver_name_list: List[str] = [] + if self.rank is 0: + self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) + + def ver_test(self, backbone: torch.nn.Module, global_step: int): + results = [] + for i in range(len(self.ver_list)): + acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( + self.ver_list[i], backbone, 10, 10) + logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm)) + logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) + if acc2 > self.highest_acc_list[i]: + self.highest_acc_list[i] = acc2 + logging.info( + '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) + results.append(acc2) + + def init_dataset(self, val_targets, data_dir, image_size): + for name in val_targets: + path = os.path.join(data_dir, name + ".bin") + if os.path.exists(path): + data_set = verification.load_bin(path, image_size) + self.ver_list.append(data_set) + self.ver_name_list.append(name) + + def __call__(self, num_update, backbone: torch.nn.Module): + if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0: + backbone.eval() + self.ver_test(backbone, num_update) + backbone.train() + + +class CallBackLogging(object): + def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None): + self.frequent: int = frequent + self.rank: int = rank + self.time_start = time.time() + self.total_step: int = total_step + self.batch_size: int = batch_size + self.world_size: int = world_size + self.writer = writer + + self.init = False + self.tic = 0 + + def __call__(self, + global_step: int, + loss: AverageMeter, + epoch: int, + fp16: bool, + learning_rate: float, + grad_scaler: torch.cuda.amp.GradScaler): + if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: + if self.init: + try: + speed: float = self.frequent * self.batch_size / (time.time() - self.tic) + speed_total = speed * self.world_size + except ZeroDivisionError: + speed_total = float('inf') + + time_now = (time.time() - self.time_start) / 3600 + time_total = time_now / ((global_step + 1) / self.total_step) + time_for_end = time_total - time_now + if self.writer is not None: + self.writer.add_scalar('time_for_end', time_for_end, global_step) + self.writer.add_scalar('learning_rate', learning_rate, global_step) + self.writer.add_scalar('loss', loss.avg, global_step) + if fp16: + msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ + "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( + speed_total, loss.avg, learning_rate, epoch, global_step, + grad_scaler.get_scale(), time_for_end + ) + else: + msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ + "Required: %1.f hours" % ( + speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end + ) + logging.info(msg) + loss.reset() + self.tic = time.time() + else: + self.init = True + self.tic = time.time() + + +class CallBackModelCheckpoint(object): + def __init__(self, rank, output="./"): + self.rank: int = rank + self.output: str = output + + def __call__(self, global_step, backbone, partial_fc, ): + if global_step > 100 and self.rank == 0: + path_module = os.path.join(self.output, "backbone.pth") + torch.save(backbone.module.state_dict(), path_module) + logging.info("Pytorch Model Saved in '{}'".format(path_module)) + + if global_step > 100 and partial_fc is not None: + partial_fc.save_params() diff --git a/src/face3d/models/arcface_torch/utils/utils_config.py b/src/face3d/models/arcface_torch/utils/utils_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0c02eaf70fc0140aca7925f621c29a496f491cae --- /dev/null +++ b/src/face3d/models/arcface_torch/utils/utils_config.py @@ -0,0 +1,16 @@ +import importlib +import os.path as osp + + +def get_config(config_file): + assert config_file.startswith('configs/'), 'config file setting must start with configs/' + temp_config_name = osp.basename(config_file) + temp_module_name = osp.splitext(temp_config_name)[0] + config = importlib.import_module("configs.base") + cfg = config.config + config = importlib.import_module("configs.%s" % temp_module_name) + job_cfg = config.config + cfg.update(job_cfg) + if cfg.output is None: + cfg.output = osp.join('work_dirs', temp_module_name) + return cfg \ No newline at end of file diff --git a/src/face3d/models/arcface_torch/utils/utils_logging.py b/src/face3d/models/arcface_torch/utils/utils_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..c787b6aae7cd037a4718df44d672b8ffa9e5c249 --- /dev/null +++ b/src/face3d/models/arcface_torch/utils/utils_logging.py @@ -0,0 +1,41 @@ +import logging +import os +import sys + + +class AverageMeter(object): + """Computes and stores the average and current value + """ + + def __init__(self): + self.val = None + self.avg = None + self.sum = None + self.count = None + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def init_logging(rank, models_root): + if rank == 0: + log_root = logging.getLogger() + log_root.setLevel(logging.INFO) + formatter = logging.Formatter("Training: %(asctime)s-%(message)s") + handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) + handler_stream = logging.StreamHandler(sys.stdout) + handler_file.setFormatter(formatter) + handler_stream.setFormatter(formatter) + log_root.addHandler(handler_file) + log_root.addHandler(handler_stream) + log_root.info('rank_id: %d' % rank) diff --git a/src/face3d/models/arcface_torch/utils/utils_os.py b/src/face3d/models/arcface_torch/utils/utils_os.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/face3d/models/base_model.py b/src/face3d/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe64a7f739ad8f8cfbf3073a2bf49e1468127fd --- /dev/null +++ b/src/face3d/models/base_model.py @@ -0,0 +1,316 @@ +"""This script defines the base network model for Deep3DFaceRecon_pytorch +""" + +import os +import numpy as np +import torch +from collections import OrderedDict +from abc import ABC, abstractmethod +from . import networks + + +class BaseModel(ABC): + """This class is an abstract base class (ABC) for models. + To create a subclass, you need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate losses, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the BaseModel class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + + When creating your custom class, you need to implement your own initialization. + In this fucntion, you should first call + Then, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): specify the images that you want to display and save. + -- self.visual_names (str list): define networks used in our training. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + """ + self.opt = opt + self.isTrain = False + self.device = torch.device('cpu') + self.save_dir = " " # os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.parallel_names = [] + self.optimizers = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + + @staticmethod + def dict_grad_hook_factory(add_func=lambda x: x): + saved_dict = dict() + + def hook_gen(name): + def grad_hook(grad): + saved_vals = add_func(grad) + saved_dict[name] = saved_vals + return grad_hook + return hook_gen, saved_dict + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): includes the data itself and its metadata information. + """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + + if not self.isTrain or opt.continue_train: + load_suffix = opt.epoch + self.load_networks(load_suffix) + + + # self.print_networks(opt.verbose) + + def parallelize(self, convert_sync_batchnorm=True): + if not self.opt.use_ddp: + for name in self.parallel_names: + if isinstance(name, str): + module = getattr(self, name) + setattr(self, name, module.to(self.device)) + else: + for name in self.model_names: + if isinstance(name, str): + module = getattr(self, name) + if convert_sync_batchnorm: + module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module) + setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device), + device_ids=[self.device.index], + find_unused_parameters=True, broadcast_buffers=True)) + + # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient. + for name in self.parallel_names: + if isinstance(name, str) and name not in self.model_names: + module = getattr(self, name) + setattr(self, name, module.to(self.device)) + + # put state_dict of optimizer to gpu device + if self.opt.phase != 'test': + if self.opt.continue_train: + for optim in self.optimizers: + for state in optim.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(self.device) + + def data_dependent_initialize(self, data): + pass + + def train(self): + """Make models train mode""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.train() + + def eval(self): + """Make models eval mode""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.eval() + + def test(self): + """Forward function used in test time. + + This function wraps function in no_grad() so we don't save intermediate steps for backprop + It also calls to produce additional visualization results + """ + with torch.no_grad(): + self.forward() + self.compute_visuals() + + def compute_visuals(self): + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self, name='A'): + """ Return image paths that are used to load current data""" + return self.image_paths if name =='A' else self.image_paths_B + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + + def get_current_visuals(self): + """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name)[:, :3, ...] + return visual_ret + + def get_current_losses(self): + """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number + return errors_ret + + def save_networks(self, epoch): + """Save all the networks to the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + if not os.path.isdir(self.save_dir): + os.makedirs(self.save_dir) + + save_filename = 'epoch_%s.pth' % (epoch) + save_path = os.path.join(self.save_dir, save_filename) + + save_dict = {} + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + if isinstance(net, torch.nn.DataParallel) or isinstance(net, + torch.nn.parallel.DistributedDataParallel): + net = net.module + save_dict[name] = net.state_dict() + + + for i, optim in enumerate(self.optimizers): + save_dict['opt_%02d'%i] = optim.state_dict() + + for i, sched in enumerate(self.schedulers): + save_dict['sched_%02d'%i] = sched.state_dict() + + torch.save(save_dict, save_path) + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + def load_networks(self, epoch): + """Load all the networks from the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + if self.opt.isTrain and self.opt.pretrained_name is not None: + load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name) + else: + load_dir = self.save_dir + load_filename = 'epoch_%s.pth' % (epoch) + load_path = os.path.join(load_dir, load_filename) + state_dict = torch.load(load_path, map_location=self.device) + print('loading the model from %s' % load_path) + + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + net.load_state_dict(state_dict[name]) + + if self.opt.phase != 'test': + if self.opt.continue_train: + print('loading the optim from %s' % load_path) + for i, optim in enumerate(self.optimizers): + optim.load_state_dict(state_dict['opt_%02d'%i]) + + try: + print('loading the sched from %s' % load_path) + for i, sched in enumerate(self.schedulers): + sched.load_state_dict(state_dict['sched_%02d'%i]) + except: + print('Failed to load schedulers, set schedulers according to epoch count manually') + for i, sched in enumerate(self.schedulers): + sched.last_epoch = self.opt.epoch_count - 1 + + + + + def print_networks(self, verbose): + """Print the total number of parameters in the network and (if verbose) network architecture + + Parameters: + verbose (bool) -- if verbose: print the network architecture + """ + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + def generate_visuals_for_evaluation(self, data, mode): + return {} diff --git a/src/face3d/models/bfm.py b/src/face3d/models/bfm.py new file mode 100644 index 0000000000000000000000000000000000000000..a75db682f02dd1979d4a7de1d11dd3aa5cdf5279 --- /dev/null +++ b/src/face3d/models/bfm.py @@ -0,0 +1,331 @@ +"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.io import loadmat +from src.face3d.util.load_mats import transferBFM09 +import os + +def perspective_projection(focal, center): + # return p.T (N, 3) @ (3, 3) + return np.array([ + focal, 0, center, + 0, focal, center, + 0, 0, 1 + ]).reshape([3, 3]).astype(np.float32).transpose() + +class SH: + def __init__(self): + self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)] + self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)] + + + +class ParametricFaceModel: + def __init__(self, + bfm_folder='./BFM', + recenter=True, + camera_distance=10., + init_lit=np.array([ + 0.8, 0, 0, 0, 0, 0, 0, 0, 0 + ]), + focal=1015., + center=112., + is_train=True, + default_name='BFM_model_front.mat'): + + if not os.path.isfile(os.path.join(bfm_folder, default_name)): + transferBFM09(bfm_folder) + + model = loadmat(os.path.join(bfm_folder, default_name)) + # mean face shape. [3*N,1] + self.mean_shape = model['meanshape'].astype(np.float32) + # identity basis. [3*N,80] + self.id_base = model['idBase'].astype(np.float32) + # expression basis. [3*N,64] + self.exp_base = model['exBase'].astype(np.float32) + # mean face texture. [3*N,1] (0-255) + self.mean_tex = model['meantex'].astype(np.float32) + # texture basis. [3*N,80] + self.tex_base = model['texBase'].astype(np.float32) + # face indices for each vertex that lies in. starts from 0. [N,8] + self.point_buf = model['point_buf'].astype(np.int64) - 1 + # vertex indices for each face. starts from 0. [F,3] + self.face_buf = model['tri'].astype(np.int64) - 1 + # vertex indices for 68 landmarks. starts from 0. [68,1] + self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1 + + if is_train: + # vertex indices for small face region to compute photometric error. starts from 0. + self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1 + # vertex indices for each face from small face region. starts from 0. [f,3] + self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1 + # vertex indices for pre-defined skin region to compute reflectance loss + self.skin_mask = np.squeeze(model['skinmask']) + + if recenter: + mean_shape = self.mean_shape.reshape([-1, 3]) + mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True) + self.mean_shape = mean_shape.reshape([-1, 1]) + + self.persc_proj = perspective_projection(focal, center) + self.device = 'cpu' + self.camera_distance = camera_distance + self.SH = SH() + self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32) + + + def to(self, device): + self.device = device + for key, value in self.__dict__.items(): + if type(value).__module__ == np.__name__: + setattr(self, key, torch.tensor(value).to(device)) + + + def compute_shape(self, id_coeff, exp_coeff): + """ + Return: + face_shape -- torch.tensor, size (B, N, 3) + + Parameters: + id_coeff -- torch.tensor, size (B, 80), identity coeffs + exp_coeff -- torch.tensor, size (B, 64), expression coeffs + """ + batch_size = id_coeff.shape[0] + id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff) + exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff) + face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1]) + return face_shape.reshape([batch_size, -1, 3]) + + + def compute_texture(self, tex_coeff, normalize=True): + """ + Return: + face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.) + + Parameters: + tex_coeff -- torch.tensor, size (B, 80) + """ + batch_size = tex_coeff.shape[0] + face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex + if normalize: + face_texture = face_texture / 255. + return face_texture.reshape([batch_size, -1, 3]) + + + def compute_norm(self, face_shape): + """ + Return: + vertex_norm -- torch.tensor, size (B, N, 3) + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + """ + + v1 = face_shape[:, self.face_buf[:, 0]] + v2 = face_shape[:, self.face_buf[:, 1]] + v3 = face_shape[:, self.face_buf[:, 2]] + e1 = v1 - v2 + e2 = v2 - v3 + face_norm = torch.cross(e1, e2, dim=-1) + face_norm = F.normalize(face_norm, dim=-1, p=2) + face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1) + + vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2) + vertex_norm = F.normalize(vertex_norm, dim=-1, p=2) + return vertex_norm + + + def compute_color(self, face_texture, face_norm, gamma): + """ + Return: + face_color -- torch.tensor, size (B, N, 3), range (0, 1.) + + Parameters: + face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.) + face_norm -- torch.tensor, size (B, N, 3), rotated face normal + gamma -- torch.tensor, size (B, 27), SH coeffs + """ + batch_size = gamma.shape[0] + v_num = face_texture.shape[1] + a, c = self.SH.a, self.SH.c + gamma = gamma.reshape([batch_size, 3, 9]) + gamma = gamma + self.init_lit + gamma = gamma.permute(0, 2, 1) + Y = torch.cat([ + a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device), + -a[1] * c[1] * face_norm[..., 1:2], + a[1] * c[1] * face_norm[..., 2:], + -a[1] * c[1] * face_norm[..., :1], + a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2], + -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:], + 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1), + -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:], + 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2) + ], dim=-1) + r = Y @ gamma[..., :1] + g = Y @ gamma[..., 1:2] + b = Y @ gamma[..., 2:] + face_color = torch.cat([r, g, b], dim=-1) * face_texture + return face_color + + + def compute_rotation(self, angles): + """ + Return: + rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat + + Parameters: + angles -- torch.tensor, size (B, 3), radian + """ + + batch_size = angles.shape[0] + ones = torch.ones([batch_size, 1]).to(self.device) + zeros = torch.zeros([batch_size, 1]).to(self.device) + x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:], + + rot_x = torch.cat([ + ones, zeros, zeros, + zeros, torch.cos(x), -torch.sin(x), + zeros, torch.sin(x), torch.cos(x) + ], dim=1).reshape([batch_size, 3, 3]) + + rot_y = torch.cat([ + torch.cos(y), zeros, torch.sin(y), + zeros, ones, zeros, + -torch.sin(y), zeros, torch.cos(y) + ], dim=1).reshape([batch_size, 3, 3]) + + rot_z = torch.cat([ + torch.cos(z), -torch.sin(z), zeros, + torch.sin(z), torch.cos(z), zeros, + zeros, zeros, ones + ], dim=1).reshape([batch_size, 3, 3]) + + rot = rot_z @ rot_y @ rot_x + return rot.permute(0, 2, 1) + + + def to_camera(self, face_shape): + face_shape[..., -1] = self.camera_distance - face_shape[..., -1] + return face_shape + + def to_image(self, face_shape): + """ + Return: + face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + """ + # to image_plane + face_proj = face_shape @ self.persc_proj + face_proj = face_proj[..., :2] / face_proj[..., 2:] + + return face_proj + + + def transform(self, face_shape, rot, trans): + """ + Return: + face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + rot -- torch.tensor, size (B, 3, 3) + trans -- torch.tensor, size (B, 3) + """ + return face_shape @ rot + trans.unsqueeze(1) + + + def get_landmarks(self, face_proj): + """ + Return: + face_lms -- torch.tensor, size (B, 68, 2) + + Parameters: + face_proj -- torch.tensor, size (B, N, 2) + """ + return face_proj[:, self.keypoints] + + def split_coeff(self, coeffs): + """ + Return: + coeffs_dict -- a dict of torch.tensors + + Parameters: + coeffs -- torch.tensor, size (B, 256) + """ + id_coeffs = coeffs[:, :80] + exp_coeffs = coeffs[:, 80: 144] + tex_coeffs = coeffs[:, 144: 224] + angles = coeffs[:, 224: 227] + gammas = coeffs[:, 227: 254] + translations = coeffs[:, 254:] + return { + 'id': id_coeffs, + 'exp': exp_coeffs, + 'tex': tex_coeffs, + 'angle': angles, + 'gamma': gammas, + 'trans': translations + } + def compute_for_render(self, coeffs): + """ + Return: + face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate + face_color -- torch.tensor, size (B, N, 3), in RGB order + landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction + Parameters: + coeffs -- torch.tensor, size (B, 257) + """ + coef_dict = self.split_coeff(coeffs) + face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) + rotation = self.compute_rotation(coef_dict['angle']) + + + face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) + face_vertex = self.to_camera(face_shape_transformed) + + face_proj = self.to_image(face_vertex) + landmark = self.get_landmarks(face_proj) + + face_texture = self.compute_texture(coef_dict['tex']) + face_norm = self.compute_norm(face_shape) + face_norm_roted = face_norm @ rotation + face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) + + return face_vertex, face_texture, face_color, landmark + + def compute_for_render_woRotation(self, coeffs): + """ + Return: + face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate + face_color -- torch.tensor, size (B, N, 3), in RGB order + landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction + Parameters: + coeffs -- torch.tensor, size (B, 257) + """ + coef_dict = self.split_coeff(coeffs) + face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) + #rotation = self.compute_rotation(coef_dict['angle']) + + + #face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) + face_vertex = self.to_camera(face_shape) + + face_proj = self.to_image(face_vertex) + landmark = self.get_landmarks(face_proj) + + face_texture = self.compute_texture(coef_dict['tex']) + face_norm = self.compute_norm(face_shape) + face_norm_roted = face_norm # @ rotation + face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) + + return face_vertex, face_texture, face_color, landmark + + +if __name__ == '__main__': + transferBFM09() \ No newline at end of file diff --git a/src/face3d/models/facerecon_model.py b/src/face3d/models/facerecon_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7de8ca6eebc50ff1ed52c5ba37d31b43f977b5e1 --- /dev/null +++ b/src/face3d/models/facerecon_model.py @@ -0,0 +1,220 @@ +"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import torch +from src.face3d.models.base_model import BaseModel +from src.face3d.models import networks +from src.face3d.models.bfm import ParametricFaceModel +from src.face3d.models.losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss +from src.face3d.util import util +from src.face3d.util.nvdiffrast import MeshRenderer +# from src.face3d.util.preprocess import estimate_norm_torch + +import trimesh +from scipy.io import savemat + +class FaceReconModel(BaseModel): + + @staticmethod + def modify_commandline_options(parser, is_train=False): + """ Configures options specific for CUT model + """ + # net structure and parameters + parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure') + parser.add_argument('--init_path', type=str, default='./checkpoints/init_model/resnet50-0676ba61.pth') + parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc') + parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/') + parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model') + + # renderer parameters + parser.add_argument('--focal', type=float, default=1015.) + parser.add_argument('--center', type=float, default=112.) + parser.add_argument('--camera_d', type=float, default=10.) + parser.add_argument('--z_near', type=float, default=5.) + parser.add_argument('--z_far', type=float, default=15.) + + if is_train: + # training parameters + parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure') + parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth') + parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss') + parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face') + + + # augmentation parameters + parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels') + parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor') + parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree') + + # loss weights + parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss') + parser.add_argument('--w_color', type=float, default=1.92, help='weight for loss loss') + parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss') + parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss') + parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss') + parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss') + parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss') + parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss') + parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss') + + opt, _ = parser.parse_known_args() + parser.set_defaults( + focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15. + ) + if is_train: + parser.set_defaults( + use_crop_face=True, use_predef_M=False + ) + return parser + + def __init__(self, opt): + """Initialize this model class. + + Parameters: + opt -- training/test options + + A few things can be done here. + - (required) call the initialization function of BaseModel + - define loss function, visualization images, model names, and optimizers + """ + BaseModel.__init__(self, opt) # call the initialization method of BaseModel + + self.visual_names = ['output_vis'] + self.model_names = ['net_recon'] + self.parallel_names = self.model_names + ['renderer'] + + self.facemodel = ParametricFaceModel( + bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center, + is_train=self.isTrain, default_name=opt.bfm_model + ) + + fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi + self.renderer = MeshRenderer( + rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center) + ) + + if self.isTrain: + self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc'] + + self.net_recog = networks.define_net_recog( + net_recog=opt.net_recog, pretrained_path=opt.net_recog_path + ) + # loss func name: (compute_%s_loss) % loss_name + self.compute_feat_loss = perceptual_loss + self.comupte_color_loss = photo_loss + self.compute_lm_loss = landmark_loss + self.compute_reg_loss = reg_loss + self.compute_reflc_loss = reflectance_loss + + self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr) + self.optimizers = [self.optimizer] + self.parallel_names += ['net_recog'] + # Our program will automatically call to define schedulers, load networks, and print networks + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input: a dictionary that contains the data itself and its metadata information. + """ + self.input_img = input['imgs'].to(self.device) + self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None + self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None + self.trans_m = input['M'].to(self.device) if 'M' in input else None + self.image_paths = input['im_paths'] if 'im_paths' in input else None + + def forward(self, output_coeff, device): + self.facemodel.to(device) + self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \ + self.facemodel.compute_for_render(output_coeff) + self.pred_mask, _, self.pred_face = self.renderer( + self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color) + + self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff) + + + def compute_losses(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + + assert self.net_recog.training == False + trans_m = self.trans_m + if not self.opt.use_predef_M: + trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2]) + + pred_feat = self.net_recog(self.pred_face, trans_m) + gt_feat = self.net_recog(self.input_img, self.trans_m) + self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat) + + face_mask = self.pred_mask + if self.opt.use_crop_face: + face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf) + + face_mask = face_mask.detach() + self.loss_color = self.opt.w_color * self.comupte_color_loss( + self.pred_face, self.input_img, self.atten_mask * face_mask) + + loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt) + self.loss_reg = self.opt.w_reg * loss_reg + self.loss_gamma = self.opt.w_gamma * loss_gamma + + self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm) + + self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask) + + self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \ + + self.loss_lm + self.loss_reflc + + + def optimize_parameters(self, isTrain=True): + self.forward() + self.compute_losses() + """Update network weights; it will be called in every training iteration.""" + if isTrain: + self.optimizer.zero_grad() + self.loss_all.backward() + self.optimizer.step() + + def compute_visuals(self): + with torch.no_grad(): + input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy() + output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img + output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy() + + if self.gt_lm is not None: + gt_lm_numpy = self.gt_lm.cpu().numpy() + pred_lm_numpy = self.pred_lm.detach().cpu().numpy() + output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b') + output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r') + + output_vis_numpy = np.concatenate((input_img_numpy, + output_vis_numpy_raw, output_vis_numpy), axis=-2) + else: + output_vis_numpy = np.concatenate((input_img_numpy, + output_vis_numpy_raw), axis=-2) + + self.output_vis = torch.tensor( + output_vis_numpy / 255., dtype=torch.float32 + ).permute(0, 3, 1, 2).to(self.device) + + def save_mesh(self, name): + + recon_shape = self.pred_vertex # get reconstructed shape + recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space + recon_shape = recon_shape.cpu().numpy()[0] + recon_color = self.pred_color + recon_color = recon_color.cpu().numpy()[0] + tri = self.facemodel.face_buf.cpu().numpy() + mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8)) + mesh.export(name) + + def save_coeff(self,name): + + pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict} + pred_lm = self.pred_lm.cpu().numpy() + pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate + pred_coeffs['lm68'] = pred_lm + savemat(name,pred_coeffs) + + + diff --git a/src/face3d/models/losses.py b/src/face3d/models/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..09d6a85870af1ef2b857e4a3fdd4b2f7fc991317 --- /dev/null +++ b/src/face3d/models/losses.py @@ -0,0 +1,113 @@ +import numpy as np +import torch +import torch.nn as nn +from kornia.geometry import warp_affine +import torch.nn.functional as F + +def resize_n_crop(image, M, dsize=112): + # image: (b, c, h, w) + # M : (b, 2, 3) + return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True) + +### perceptual level loss +class PerceptualLoss(nn.Module): + def __init__(self, recog_net, input_size=112): + super(PerceptualLoss, self).__init__() + self.recog_net = recog_net + self.preprocess = lambda x: 2 * x - 1 + self.input_size=input_size + def forward(imageA, imageB, M): + """ + 1 - cosine distance + Parameters: + imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order + imageB --same as imageA + """ + + imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size)) + imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size)) + + # freeze bn + self.recog_net.eval() + + id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2) + id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2) + cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) + # assert torch.sum((cosine_d > 1).float()) == 0 + return torch.sum(1 - cosine_d) / cosine_d.shape[0] + +def perceptual_loss(id_featureA, id_featureB): + cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) + # assert torch.sum((cosine_d > 1).float()) == 0 + return torch.sum(1 - cosine_d) / cosine_d.shape[0] + +### image level loss +def photo_loss(imageA, imageB, mask, eps=1e-6): + """ + l2 norm (with sqrt, to ensure backward stabililty, use eps, otherwise Nan may occur) + Parameters: + imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order + imageB --same as imageA + """ + loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask + loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device)) + return loss + +def landmark_loss(predict_lm, gt_lm, weight=None): + """ + weighted mse loss + Parameters: + predict_lm --torch.tensor (B, 68, 2) + gt_lm --torch.tensor (B, 68, 2) + weight --numpy.array (1, 68) + """ + if not weight: + weight = np.ones([68]) + weight[28:31] = 20 + weight[-8:] = 20 + weight = np.expand_dims(weight, 0) + weight = torch.tensor(weight).to(predict_lm.device) + loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight + loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1]) + return loss + + +### regulization +def reg_loss(coeffs_dict, opt=None): + """ + l2 norm without the sqrt, from yu's implementation (mse) + tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss + Parameters: + coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans + + """ + # coefficient regularization to ensure plausible 3d faces + if opt: + w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex + else: + w_id, w_exp, w_tex = 1, 1, 1, 1 + creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \ + w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \ + w_tex * torch.sum(coeffs_dict['tex'] ** 2) + creg_loss = creg_loss / coeffs_dict['id'].shape[0] + + # gamma regularization to ensure a nearly-monochromatic light + gamma = coeffs_dict['gamma'].reshape([-1, 3, 9]) + gamma_mean = torch.mean(gamma, dim=1, keepdims=True) + gamma_loss = torch.mean((gamma - gamma_mean) ** 2) + + return creg_loss, gamma_loss + +def reflectance_loss(texture, mask): + """ + minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo + Parameters: + texture --torch.tensor, (B, N, 3) + mask --torch.tensor, (N), 1 or 0 + + """ + mask = mask.reshape([1, mask.shape[0], 1]) + texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask) + loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask)) + return loss + diff --git a/src/face3d/models/networks.py b/src/face3d/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..ead9cdcb8720b845c233de79dc8a8d1668492108 --- /dev/null +++ b/src/face3d/models/networks.py @@ -0,0 +1,521 @@ +"""This script defines deep neural networks for Deep3DFaceRecon_pytorch +""" + +import os +import numpy as np +import torch.nn.functional as F +from torch.nn import init +import functools +from torch.optim import lr_scheduler +import torch +from torch import Tensor +import torch.nn as nn +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url +from typing import Type, Any, Callable, Union, List, Optional +from .arcface_torch.backbones import get_model +from kornia.geometry import warp_affine + +def resize_n_crop(image, M, dsize=112): + # image: (b, c, h, w) + # M : (b, 2, 3) + return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True) + +def filter_state_dict(state_dict, remove_name='fc'): + new_state_dict = {} + for key in state_dict: + if remove_name in key: + continue + new_state_dict[key] = state_dict[key] + return new_state_dict + +def get_scheduler(optimizer, opt): + """Return a learning rate scheduler + + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine + + For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. + See https://pytorch.org/docs/stable/optim.html for more details. + """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def define_net_recon(net_recon, use_last_fc=False, init_path=None): + return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path) + +def define_net_recog(net_recog, pretrained_path=None): + net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path) + net.eval() + return net + +class ReconNetWrapper(nn.Module): + fc_dim=257 + def __init__(self, net_recon, use_last_fc=False, init_path=None): + super(ReconNetWrapper, self).__init__() + self.use_last_fc = use_last_fc + if net_recon not in func_dict: + return NotImplementedError('network [%s] is not implemented', net_recon) + func, last_dim = func_dict[net_recon] + backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim) + if init_path and os.path.isfile(init_path): + state_dict = filter_state_dict(torch.load(init_path, map_location='cpu')) + backbone.load_state_dict(state_dict) + print("loading init net_recon %s from %s" %(net_recon, init_path)) + self.backbone = backbone + if not use_last_fc: + self.final_layers = nn.ModuleList([ + conv1x1(last_dim, 80, bias=True), # id layer + conv1x1(last_dim, 64, bias=True), # exp layer + conv1x1(last_dim, 80, bias=True), # tex layer + conv1x1(last_dim, 3, bias=True), # angle layer + conv1x1(last_dim, 27, bias=True), # gamma layer + conv1x1(last_dim, 2, bias=True), # tx, ty + conv1x1(last_dim, 1, bias=True) # tz + ]) + for m in self.final_layers: + nn.init.constant_(m.weight, 0.) + nn.init.constant_(m.bias, 0.) + + def forward(self, x): + x = self.backbone(x) + if not self.use_last_fc: + output = [] + for layer in self.final_layers: + output.append(layer(x)) + x = torch.flatten(torch.cat(output, dim=1), 1) + return x + + +class RecogNetWrapper(nn.Module): + def __init__(self, net_recog, pretrained_path=None, input_size=112): + super(RecogNetWrapper, self).__init__() + net = get_model(name=net_recog, fp16=False) + if pretrained_path: + state_dict = torch.load(pretrained_path, map_location='cpu') + net.load_state_dict(state_dict) + print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path)) + for param in net.parameters(): + param.requires_grad = False + self.net = net + self.preprocess = lambda x: 2 * x - 1 + self.input_size=input_size + + def forward(self, image, M): + image = self.preprocess(resize_n_crop(image, M, self.input_size)) + id_feature = F.normalize(self.net(image), dim=-1, p=2) + return id_feature + + +# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + use_last_fc: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.use_last_fc = use_last_fc + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + if self.use_last_fc: + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + if self.use_last_fc: + x = torch.flatten(x, 1) + x = self.fc(x) + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +func_dict = { + 'resnet18': (resnet18, 512), + 'resnet50': (resnet50, 2048) +} diff --git a/src/face3d/models/template_model.py b/src/face3d/models/template_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dac7b33d5889777eb63c9882a3b9fa094dcab293 --- /dev/null +++ b/src/face3d/models/template_model.py @@ -0,0 +1,100 @@ +"""Model class template + +This module provides a template for users to implement custom models. +You can specify '--model template' to use this model. +The class name should be consistent with both the filename and its model option. +The filename should be _dataset.py +The class name should be Dataset.py +It implements a simple image-to-image translation baseline based on regression loss. +Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss: + min_ ||netG(data_A) - data_B||_1 +You need to implement the following functions: + : Add model-specific options and rewrite default values for existing options. + <__init__>: Initialize this model class. + : Unpack input data and perform data pre-processing. + : Run forward pass. This will be called by both and . + : Update network weights; it will be called in every training iteration. +""" +import numpy as np +import torch +from .base_model import BaseModel +from . import networks + + +class TemplateModel(BaseModel): + @staticmethod + def modify_commandline_options(parser, is_train=True): + """Add new model-specific options and rewrite default values for existing options. + + Parameters: + parser -- the option parser + is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. + if is_train: + parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. + + return parser + + def __init__(self, opt): + """Initialize this model class. + + Parameters: + opt -- training/test options + + A few things can be done here. + - (required) call the initialization function of BaseModel + - define loss function, visualization images, model names, and optimizers + """ + BaseModel.__init__(self, opt) # call the initialization method of BaseModel + # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. + self.loss_names = ['loss_G'] + # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. + self.visual_names = ['data_A', 'data_B', 'output'] + # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. + # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them. + self.model_names = ['G'] + # define networks; you can use opt.isTrain to specify different behaviors for training and test. + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) + if self.isTrain: # only defined during training time + # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. + # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) + self.criterionLoss = torch.nn.L1Loss() + # define and initialize optimizers. You can define one optimizer for each network. + # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers = [self.optimizer] + + # Our program will automatically call to define schedulers, load networks, and print networks + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input: a dictionary that contains the data itself and its metadata information. + """ + AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B + self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A + self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B + self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths + + def forward(self): + """Run forward pass. This will be called by both functions and .""" + self.output = self.netG(self.data_A) # generate output image given the input data_A + + def backward(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + # caculate the intermediate results if necessary; here self.output has been computed during function + # calculate loss given the input and intermediate results + self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression + self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + self.forward() # first call forward to calculate intermediate results + self.optimizer.zero_grad() # clear network G's existing gradients + self.backward() # calculate gradients for network G + self.optimizer.step() # update gradients for network G diff --git a/src/face3d/options/__init__.py b/src/face3d/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7eedebe54aa70169fd25951b3034d819e396c90 --- /dev/null +++ b/src/face3d/options/__init__.py @@ -0,0 +1 @@ +"""This package options includes option modules: training options, test options, and basic options (used in both training and test).""" diff --git a/src/face3d/options/base_options.py b/src/face3d/options/base_options.py new file mode 100644 index 0000000000000000000000000000000000000000..d8f921d5a43434ae802a55a0fa3889c4b7ab9f6d --- /dev/null +++ b/src/face3d/options/base_options.py @@ -0,0 +1,169 @@ +"""This script contains base options for Deep3DFaceRecon_pytorch +""" + +import argparse +import os +from util import util +import numpy as np +import torch +import face3d.models as models +import face3d.data as data + + +class BaseOptions(): + """This class defines options used during both training and test time. + + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in functions in both dataset class and model class. + """ + + def __init__(self, cmd_line=None): + """Reset the class; indicates the class hasn't been initailized""" + self.initialized = False + self.cmd_line = None + if cmd_line is not None: + self.cmd_line = cmd_line.split() + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + # basic parameters + parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization') + parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation') + parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel') + parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port') + parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses') + parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard') + parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation') + + # model parameters + parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.') + + # additional parameters + parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + if self.cmd_line is None: + opt, _ = parser.parse_known_args() + else: + opt, _ = parser.parse_known_args(self.cmd_line) + + # set cuda visible devices + os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + if self.cmd_line is None: + opt, _ = parser.parse_known_args() # parse again with new defaults + else: + opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults + + # modify dataset-related parser options + if opt.dataset_mode: + dataset_name = opt.dataset_mode + dataset_option_setter = data.get_option_setter(dataset_name) + parser = dataset_option_setter(parser, self.isTrain) + + # save and return the parser + self.parser = parser + if self.cmd_line is None: + return parser.parse_args() + else: + return parser.parse_args(self.cmd_line) + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + try: + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + except PermissionError as error: + print("permission error {}".format(error)) + pass + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + gpu_ids.append(id) + opt.world_size = len(gpu_ids) + # if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(gpu_ids[0]) + if opt.world_size == 1: + opt.use_ddp = False + + if opt.phase != 'test': + # set continue_train automatically + if opt.pretrained_name is None: + model_dir = os.path.join(opt.checkpoints_dir, opt.name) + else: + model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name) + if os.path.isdir(model_dir): + model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')] + if os.path.isdir(model_dir) and len(model_pths) != 0: + opt.continue_train= True + + # update the latest epoch count + if opt.continue_train: + if opt.epoch == 'latest': + epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i] + if len(epoch_counts) != 0: + opt.epoch_count = max(epoch_counts) + 1 + else: + opt.epoch_count = int(opt.epoch) + 1 + + + self.print_options(opt) + self.opt = opt + return self.opt diff --git a/src/face3d/options/inference_options.py b/src/face3d/options/inference_options.py new file mode 100644 index 0000000000000000000000000000000000000000..c453965959ab4cfb31acbc424f994db68c3d4df5 --- /dev/null +++ b/src/face3d/options/inference_options.py @@ -0,0 +1,23 @@ +from face3d.options.base_options import BaseOptions + + +class InferenceOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') + + parser.add_argument('--input_dir', type=str, help='the folder of the input files') + parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files') + parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients') + parser.add_argument('--save_split_files', action='store_true', help='save split files or not') + parser.add_argument('--inference_batch_size', type=int, default=8) + + # Dropout and Batchnorm has different behavior during training and test. + self.isTrain = False + return parser diff --git a/src/face3d/options/test_options.py b/src/face3d/options/test_options.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff3ad142779850d1d5a1640bc00f70d34d4a862 --- /dev/null +++ b/src/face3d/options/test_options.py @@ -0,0 +1,21 @@ +"""This script contains the test options for Deep3DFaceRecon_pytorch +""" + +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') + parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') + + # Dropout and Batchnorm has different behavior during training and test. + self.isTrain = False + return parser diff --git a/src/face3d/options/train_options.py b/src/face3d/options/train_options.py new file mode 100644 index 0000000000000000000000000000000000000000..1337bfdd5f372b5c686a91b394a2aadbe5741f44 --- /dev/null +++ b/src/face3d/options/train_options.py @@ -0,0 +1,53 @@ +"""This script contains the training options for Deep3DFaceRecon_pytorch +""" + +from .base_options import BaseOptions +from util import util + +class TrainOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + # dataset parameters + # for train + parser.add_argument('--data_root', type=str, default='./', help='dataset root') + parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set') + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') + parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]') + parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation') + + # for val + parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') + parser.add_argument('--batch_size_val', type=int, default=32) + + + # visualization parameters + parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + + # network saving and loading parameters + parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') + parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') + + # training parameters + parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate') + parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') + parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]') + parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches') + + self.isTrain = True + return parser diff --git a/src/face3d/util/BBRegressorParam_r.mat b/src/face3d/util/BBRegressorParam_r.mat new file mode 100644 index 0000000000000000000000000000000000000000..1430a94ed2ab570a09f9d980d3585e8aaa933084 Binary files /dev/null and b/src/face3d/util/BBRegressorParam_r.mat differ diff --git a/src/face3d/util/__init__.py b/src/face3d/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04eecb58b62f8c9d11d17606c6241d278a48b9b9 --- /dev/null +++ b/src/face3d/util/__init__.py @@ -0,0 +1,3 @@ +"""This package includes a miscellaneous collection of useful helper functions.""" +from src.face3d.util import * + diff --git a/src/face3d/util/__pycache__/__init__.cpython-38.pyc b/src/face3d/util/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22771f3169f2da9a37c1bd619a0e5d05003492b9 Binary files /dev/null and b/src/face3d/util/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/face3d/util/__pycache__/load_mats.cpython-38.pyc b/src/face3d/util/__pycache__/load_mats.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a48b59ca078ef709825d54c069f518c15103c4e Binary files /dev/null and b/src/face3d/util/__pycache__/load_mats.cpython-38.pyc differ diff --git a/src/face3d/util/__pycache__/nvdiffrast.cpython-38.pyc b/src/face3d/util/__pycache__/nvdiffrast.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ac5cc3eb7c6fd3141005a9cd53f604c49036717 Binary files /dev/null and b/src/face3d/util/__pycache__/nvdiffrast.cpython-38.pyc differ diff --git a/src/face3d/util/__pycache__/preprocess.cpython-38.pyc b/src/face3d/util/__pycache__/preprocess.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7900dafbd8b74629c391eb8972f615650d4461df Binary files /dev/null and b/src/face3d/util/__pycache__/preprocess.cpython-38.pyc differ diff --git a/src/face3d/util/__pycache__/util.cpython-38.pyc b/src/face3d/util/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56d6f9217276ff22306a567df4861f802e61a82a Binary files /dev/null and b/src/face3d/util/__pycache__/util.cpython-38.pyc differ diff --git a/src/face3d/util/detect_lm68.py b/src/face3d/util/detect_lm68.py new file mode 100644 index 0000000000000000000000000000000000000000..b7e40997289e17405e1fb6c408d21adce7b626ce --- /dev/null +++ b/src/face3d/util/detect_lm68.py @@ -0,0 +1,106 @@ +import os +import cv2 +import numpy as np +from scipy.io import loadmat +import tensorflow as tf +from util.preprocess import align_for_lm +from shutil import move + +mean_face = np.loadtxt('util/test_mean_face.txt') +mean_face = mean_face.reshape([68, 2]) + +def save_label(labels, save_path): + np.savetxt(save_path, labels) + +def draw_landmarks(img, landmark, save_name): + landmark = landmark + lm_img = np.zeros([img.shape[0], img.shape[1], 3]) + lm_img[:] = img.astype(np.float32) + landmark = np.round(landmark).astype(np.int32) + + for i in range(len(landmark)): + for j in range(-1, 1): + for k in range(-1, 1): + if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \ + img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \ + landmark[i, 0]+k > 0 and \ + landmark[i, 0]+k < img.shape[1]: + lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k, + :] = np.array([0, 0, 255]) + lm_img = lm_img.astype(np.uint8) + + cv2.imwrite(save_name, lm_img) + + +def load_data(img_name, txt_name): + return cv2.imread(img_name), np.loadtxt(txt_name) + +# create tensorflow graph for landmark detector +def load_lm_graph(graph_filename): + with tf.gfile.GFile(graph_filename, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name='net') + img_224 = graph.get_tensor_by_name('net/input_imgs:0') + output_lm = graph.get_tensor_by_name('net/lm:0') + lm_sess = tf.Session(graph=graph) + + return lm_sess,img_224,output_lm + +# landmark detection +def detect_68p(img_path,sess,input_op,output_op): + print('detecting landmarks......') + names = [i for i in sorted(os.listdir( + img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] + vis_path = os.path.join(img_path, 'vis') + remove_path = os.path.join(img_path, 'remove') + save_path = os.path.join(img_path, 'landmarks') + if not os.path.isdir(vis_path): + os.makedirs(vis_path) + if not os.path.isdir(remove_path): + os.makedirs(remove_path) + if not os.path.isdir(save_path): + os.makedirs(save_path) + + for i in range(0, len(names)): + name = names[i] + print('%05d' % (i), ' ', name) + full_image_name = os.path.join(img_path, name) + txt_name = '.'.join(name.split('.')[:-1]) + '.txt' + full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image + + # if an image does not have detected 5 facial landmarks, remove it from the training list + if not os.path.isfile(full_txt_name): + move(full_image_name, os.path.join(remove_path, name)) + continue + + # load data + img, five_points = load_data(full_image_name, full_txt_name) + input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection + + # if the alignment fails, remove corresponding image from the training list + if scale == 0: + move(full_txt_name, os.path.join( + remove_path, txt_name)) + move(full_image_name, os.path.join(remove_path, name)) + continue + + # detect landmarks + input_img = np.reshape( + input_img, [1, 224, 224, 3]).astype(np.float32) + landmark = sess.run( + output_op, feed_dict={input_op: input_img}) + + # transform back to original image coordinate + landmark = landmark.reshape([68, 2]) + mean_face + landmark[:, 1] = 223 - landmark[:, 1] + landmark = landmark / scale + landmark[:, 0] = landmark[:, 0] + bbox[0] + landmark[:, 1] = landmark[:, 1] + bbox[1] + landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1] + + if i % 100 == 0: + draw_landmarks(img, landmark, os.path.join(vis_path, name)) + save_label(landmark, os.path.join(save_path, txt_name)) diff --git a/src/face3d/util/generate_list.py b/src/face3d/util/generate_list.py new file mode 100644 index 0000000000000000000000000000000000000000..943d906781063c3584a7e5b5c784f8aac0694985 --- /dev/null +++ b/src/face3d/util/generate_list.py @@ -0,0 +1,34 @@ +"""This script is to generate training list files for Deep3DFaceRecon_pytorch +""" + +import os + +# save path to training data +def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): + save_path = os.path.join(save_folder, mode) + if not os.path.isdir(save_path): + os.makedirs(save_path) + with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in lms_list]) + + with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in imgs_list]) + + with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in msks_list]) + +# check if the path is valid +def check_list(rlms_list, rimgs_list, rmsks_list): + lms_list, imgs_list, msks_list = [], [], [] + for i in range(len(rlms_list)): + flag = 'false' + lm_path = rlms_list[i] + im_path = rimgs_list[i] + msk_path = rmsks_list[i] + if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path): + flag = 'true' + lms_list.append(rlms_list[i]) + imgs_list.append(rimgs_list[i]) + msks_list.append(rmsks_list[i]) + print(i, rlms_list[i], flag) + return lms_list, imgs_list, msks_list diff --git a/src/face3d/util/html.py b/src/face3d/util/html.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3262a1eafda34842e4dbad47bb6ba72f0c5a68 --- /dev/null +++ b/src/face3d/util/html.py @@ -0,0 +1,86 @@ +import dominate +from dominate.tags import meta, h3, table, tr, td, p, a, img, br +import os + + +class HTML: + """This HTML class allows us to save images and write texts into a single HTML file. + + It consists of functions such as (add a text header to the HTML file), + (add a row of images to the HTML file), and (save the HTML to the disk). + It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. + """ + + def __init__(self, web_dir, title, refresh=0): + """Initialize the HTML classes + + Parameters: + web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: + with self.doc.head: + meta(http_equiv="refresh", content=str(refresh)) + + def get_image_dir(self): + """Return the directory that stores images""" + return self.img_dir + + def add_header(self, text): + """Insert a header to the HTML file + + Parameters: + text (str) -- the header text + """ + with self.doc: + h3(text) + + def add_images(self, ims, txts, links, width=400): + """add images to the HTML file + + Parameters: + ims (str list) -- a list of image paths + txts (str list) -- a list of image names shown on the website + links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page + """ + self.t = table(border=1, style="table-layout: fixed;") # Insert a table + self.doc.add(self.t) + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + """save the current content to the HMTL file""" + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': # we show an example usage here. + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims, txts, links = [], [], [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/src/face3d/util/load_mats.py b/src/face3d/util/load_mats.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a6fcc71de1d7dad8b0f81c67dc1c213764ff0b --- /dev/null +++ b/src/face3d/util/load_mats.py @@ -0,0 +1,120 @@ +"""This script is to load 3D face model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +from PIL import Image +from scipy.io import loadmat, savemat +from array import array +import os.path as osp + +# load expression basis +def LoadExpBasis(bfm_folder='BFM'): + n_vertex = 53215 + Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb') + exp_dim = array('i') + exp_dim.fromfile(Expbin, 1) + expMU = array('f') + expPC = array('f') + expMU.fromfile(Expbin, 3*n_vertex) + expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex) + Expbin.close() + + expPC = np.array(expPC) + expPC = np.reshape(expPC, [exp_dim[0], -1]) + expPC = np.transpose(expPC) + + expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt')) + + return expPC, expEV + + +# transfer original BFM09 to our face model +def transferBFM09(bfm_folder='BFM'): + print('Transfer BFM09 to BFM_model_front......') + original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat')) + shapePC = original_BFM['shapePC'] # shape basis + shapeEV = original_BFM['shapeEV'] # corresponding eigen value + shapeMU = original_BFM['shapeMU'] # mean face + texPC = original_BFM['texPC'] # texture basis + texEV = original_BFM['texEV'] # eigen value + texMU = original_BFM['texMU'] # mean texture + + expPC, expEV = LoadExpBasis(bfm_folder) + + # transfer BFM09 to our face model + + idBase = shapePC*np.reshape(shapeEV, [-1, 199]) + idBase = idBase/1e5 # unify the scale to decimeter + idBase = idBase[:, :80] # use only first 80 basis + + exBase = expPC*np.reshape(expEV, [-1, 79]) + exBase = exBase/1e5 # unify the scale to decimeter + exBase = exBase[:, :64] # use only first 64 basis + + texBase = texPC*np.reshape(texEV, [-1, 199]) + texBase = texBase[:, :80] # use only first 80 basis + + # our face model is cropped along face landmarks and contains only 35709 vertex. + # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex. + # thus we select corresponding vertex to get our face model. + + index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat')) + index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215) + + index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat')) + index_shape = index_shape['trimIndex'].astype( + np.int32) - 1 # starts from 0 (to 53490) + index_shape = index_shape[index_exp] + + idBase = np.reshape(idBase, [-1, 3, 80]) + idBase = idBase[index_shape, :, :] + idBase = np.reshape(idBase, [-1, 80]) + + texBase = np.reshape(texBase, [-1, 3, 80]) + texBase = texBase[index_shape, :, :] + texBase = np.reshape(texBase, [-1, 80]) + + exBase = np.reshape(exBase, [-1, 3, 64]) + exBase = exBase[index_exp, :, :] + exBase = np.reshape(exBase, [-1, 64]) + + meanshape = np.reshape(shapeMU, [-1, 3])/1e5 + meanshape = meanshape[index_shape, :] + meanshape = np.reshape(meanshape, [1, -1]) + + meantex = np.reshape(texMU, [-1, 3]) + meantex = meantex[index_shape, :] + meantex = np.reshape(meantex, [1, -1]) + + # other info contains triangles, region used for computing photometric loss, + # region used for skin texture regularization, and 68 landmarks index etc. + other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat')) + frontmask2_idx = other_info['frontmask2_idx'] + skinmask = other_info['skinmask'] + keypoints = other_info['keypoints'] + point_buf = other_info['point_buf'] + tri = other_info['tri'] + tri_mask2 = other_info['tri_mask2'] + + # save our face model + savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase, + 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask}) + + +# load landmarks for standard face, which is used for image preprocessing +def load_lm3d(bfm_folder): + + Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat')) + Lm3D = Lm3D['lm'] + + # calculate 5 facial landmarks using 68 landmarks + lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 + Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean( + Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0) + Lm3D = Lm3D[[1, 2, 0, 3, 4], :] + + return Lm3D + + +if __name__ == '__main__': + transferBFM09() \ No newline at end of file diff --git a/src/face3d/util/nvdiffrast.py b/src/face3d/util/nvdiffrast.py new file mode 100644 index 0000000000000000000000000000000000000000..f3245859c650afbfe841a66b74cddefaf28820d9 --- /dev/null +++ b/src/face3d/util/nvdiffrast.py @@ -0,0 +1,126 @@ +"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch + Attention, antialiasing step is missing in current version. +""" +import pytorch3d.ops +import torch +import torch.nn.functional as F +import kornia +from kornia.geometry.camera import pixel2cam +import numpy as np +from typing import List +from scipy.io import loadmat +from torch import nn + +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + look_at_view_transform, + FoVPerspectiveCameras, + DirectionalLights, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, + TexturesUV, +) + +# def ndc_projection(x=0.1, n=1.0, f=50.0): +# return np.array([[n/x, 0, 0, 0], +# [ 0, n/-x, 0, 0], +# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], +# [ 0, 0, -1, 0]]).astype(np.float32) + +class MeshRenderer(nn.Module): + def __init__(self, + rasterize_fov, + znear=0.1, + zfar=10, + rasterize_size=224): + super(MeshRenderer, self).__init__() + + # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear + # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul( + # torch.diag(torch.tensor([1., -1, -1, 1]))) + self.rasterize_size = rasterize_size + self.fov = rasterize_fov + self.znear = znear + self.zfar = zfar + + self.rasterizer = None + + def forward(self, vertex, tri, feat=None): + """ + Return: + mask -- torch.tensor, size (B, 1, H, W) + depth -- torch.tensor, size (B, 1, H, W) + features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None + + Parameters: + vertex -- torch.tensor, size (B, N, 3) + tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles + feat(optional) -- torch.tensor, size (B, N ,C), features + """ + device = vertex.device + rsize = int(self.rasterize_size) + # ndc_proj = self.ndc_proj.to(device) + # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v + if vertex.shape[-1] == 3: + vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) + vertex[..., 0] = -vertex[..., 0] + + + # vertex_ndc = vertex @ ndc_proj.t() + if self.rasterizer is None: + self.rasterizer = MeshRasterizer() + print("create rasterizer on device cuda:%d"%device.index) + + # ranges = None + # if isinstance(tri, List) or len(tri.shape) == 3: + # vum = vertex_ndc.shape[1] + # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) + # fstartidx = torch.cumsum(fnum, dim=0) - fnum + # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() + # for i in range(tri.shape[0]): + # tri[i] = tri[i] + i*vum + # vertex_ndc = torch.cat(vertex_ndc, dim=0) + # tri = torch.cat(tri, dim=0) + + # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] + tri = tri.type(torch.int32).contiguous() + + # rasterize + cameras = FoVPerspectiveCameras( + device=device, + fov=self.fov, + znear=self.znear, + zfar=self.zfar, + ) + + raster_settings = RasterizationSettings( + image_size=rsize + ) + + # print(vertex.shape, tri.shape) + mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1))) + + fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings) + rast_out = fragments.pix_to_face.squeeze(-1) + depth = fragments.zbuf + + # render depth + depth = depth.permute(0, 3, 1, 2) + mask = (rast_out > 0).float().unsqueeze(1) + depth = mask * depth + + + image = None + if feat is not None: + attributes = feat.reshape(-1,3)[mesh.faces_packed()] + image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face, + fragments.bary_coords, + attributes) + # print(image.shape) + image = image.squeeze(-2).permute(0, 3, 1, 2) + image = mask * image + + return mask, depth, image + diff --git a/src/face3d/util/preprocess.py b/src/face3d/util/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..b77a3a4058c208e5ba8cb1cfbb563954a5f7a3e2 --- /dev/null +++ b/src/face3d/util/preprocess.py @@ -0,0 +1,103 @@ +"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch +""" + +import numpy as np +from scipy.io import loadmat +from PIL import Image +import cv2 +import os +from skimage import transform as trans +import torch +import warnings +warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + + +# calculating least square problem for image alignment +def POS(xp, x): + npts = xp.shape[1] + + A = np.zeros([2*npts, 8]) + + A[0:2*npts-1:2, 0:3] = x.transpose() + A[0:2*npts-1:2, 3] = 1 + + A[1:2*npts:2, 4:7] = x.transpose() + A[1:2*npts:2, 7] = 1 + + b = np.reshape(xp.transpose(), [2*npts, 1]) + + k, _, _, _ = np.linalg.lstsq(A, b) + + R1 = k[0:3] + R2 = k[4:7] + sTx = k[3] + sTy = k[7] + s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 + t = np.stack([sTx, sTy], axis=0) + + return t, s + +# resize and crop images for face reconstruction +def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): + w0, h0 = img.size + w = (w0*s).astype(np.int32) + h = (h0*s).astype(np.int32) + left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) + right = left + target_size + up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) + below = up + target_size + + img = img.resize((w, h), resample=Image.BICUBIC) + img = img.crop((left, up, right, below)) + + if mask is not None: + mask = mask.resize((w, h), resample=Image.BICUBIC) + mask = mask.crop((left, up, right, below)) + + lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - + t[1] + h0/2], axis=1)*s + lm = lm - np.reshape( + np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) + + return img, lm, mask + +# utils for face reconstruction +def extract_5p(lm): + lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 + lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean( + lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0) + lm5p = lm5p[[1, 2, 0, 3, 4], :] + return lm5p + +# utils for face reconstruction +def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.): + """ + Return: + transparams --numpy.array (raw_W, raw_H, scale, tx, ty) + img_new --PIL.Image (target_size, target_size, 3) + lm_new --numpy.array (68, 2), y direction is opposite to v direction + mask_new --PIL.Image (target_size, target_size) + + Parameters: + img --PIL.Image (raw_H, raw_W, 3) + lm --numpy.array (68, 2), y direction is opposite to v direction + lm3D --numpy.array (5, 3) + mask --PIL.Image (raw_H, raw_W, 3) + """ + + w0, h0 = img.size + if lm.shape[0] != 5: + lm5p = extract_5p(lm) + else: + lm5p = lm + + # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face + t, s = POS(lm5p.transpose(), lm3D.transpose()) + s = rescale_factor/s + + # processing the image + img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) + trans_params = np.array([w0, h0, s, t[0], t[1]]) + + return trans_params, img_new, lm_new, mask_new diff --git a/src/face3d/util/skin_mask.py b/src/face3d/util/skin_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..a8a74e4c3b40d13b0258b83a12f56321a85bb179 --- /dev/null +++ b/src/face3d/util/skin_mask.py @@ -0,0 +1,125 @@ +"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch +""" + +import math +import numpy as np +import os +import cv2 + +class GMM: + def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv): + self.dim = dim # feature dimension + self.num = num # number of Gaussian components + self.w = w # weights of Gaussian components (a list of scalars) + self.mu= mu # mean of Gaussian components (a list of 1xdim vectors) + self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices) + self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars) + self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices) + + self.factor = [0]*num + for i in range(self.num): + self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5 + + def likelihood(self, data): + assert(data.shape[1] == self.dim) + N = data.shape[0] + lh = np.zeros(N) + + for i in range(self.num): + data_ = data - self.mu[i] + + tmp = np.matmul(data_,self.cov_inv[i]) * data_ + tmp = np.sum(tmp,axis=1) + power = -0.5 * tmp + + p = np.array([math.exp(power[j]) for j in range(N)]) + p = p/self.factor[i] + lh += p*self.w[i] + + return lh + + +def _rgb2ycbcr(rgb): + m = np.array([[65.481, 128.553, 24.966], + [-37.797, -74.203, 112], + [112, -93.786, -18.214]]) + shape = rgb.shape + rgb = rgb.reshape((shape[0] * shape[1], 3)) + ycbcr = np.dot(rgb, m.transpose() / 255.) + ycbcr[:, 0] += 16. + ycbcr[:, 1:] += 128. + return ycbcr.reshape(shape) + + +def _bgr2ycbcr(bgr): + rgb = bgr[..., ::-1] + return _rgb2ycbcr(rgb) + + +gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415] +gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]), + np.array([150.19858, 105.18467, 155.51428]), + np.array([183.92976, 107.62468, 152.71820]), + np.array([114.90524, 113.59782, 151.38217])] +gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.] +gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]), + np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]), + np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]), + np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])] + +gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv) + +gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393] +gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]), + np.array([110.91392, 125.52969, 130.19237]), + np.array([129.75864, 129.96107, 126.96808]), + np.array([112.29587, 128.85121, 129.05431])] +gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63] +gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]), + np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]), + np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]), + np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])] + +gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv) + +prior_skin = 0.8 +prior_nonskin = 1 - prior_skin + + +# calculate skin attention mask +def skinmask(imbgr): + im = _bgr2ycbcr(imbgr) + + data = im.reshape((-1,3)) + + lh_skin = gmm_skin.likelihood(data) + lh_nonskin = gmm_nonskin.likelihood(data) + + tmp1 = prior_skin * lh_skin + tmp2 = prior_nonskin * lh_nonskin + post_skin = tmp1 / (tmp1+tmp2) # posterior probability + + post_skin = post_skin.reshape((im.shape[0],im.shape[1])) + + post_skin = np.round(post_skin*255) + post_skin = post_skin.astype(np.uint8) + post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3 + + return post_skin + + +def get_skin_mask(img_path): + print('generating skin masks......') + names = [i for i in sorted(os.listdir( + img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] + save_path = os.path.join(img_path, 'mask') + if not os.path.isdir(save_path): + os.makedirs(save_path) + + for i in range(0, len(names)): + name = names[i] + print('%05d' % (i), ' ', name) + full_image_name = os.path.join(img_path, name) + img = cv2.imread(full_image_name).astype(np.float32) + skin_img = skinmask(img) + cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8)) diff --git a/src/face3d/util/test_mean_face.txt b/src/face3d/util/test_mean_face.txt new file mode 100644 index 0000000000000000000000000000000000000000..3a46d4db7699ffed8f898fcee64099631509946d --- /dev/null +++ b/src/face3d/util/test_mean_face.txt @@ -0,0 +1,136 @@ +-5.228591537475585938e+01 +2.078247070312500000e-01 +-5.064269638061523438e+01 +-1.315765380859375000e+01 +-4.952939224243164062e+01 +-2.592591094970703125e+01 +-4.793047332763671875e+01 +-3.832135772705078125e+01 +-4.512159729003906250e+01 +-5.059623336791992188e+01 +-3.917720794677734375e+01 +-6.043736648559570312e+01 +-2.929953765869140625e+01 +-6.861183166503906250e+01 +-1.719801330566406250e+01 +-7.572736358642578125e+01 +-1.961936950683593750e+00 +-7.862001037597656250e+01 +1.467941284179687500e+01 +-7.607844543457031250e+01 +2.744073486328125000e+01 +-6.915261840820312500e+01 +3.855677795410156250e+01 +-5.950350570678710938e+01 +4.478240966796875000e+01 +-4.867547225952148438e+01 +4.714337158203125000e+01 +-3.800830078125000000e+01 +4.940315246582031250e+01 +-2.496297454833984375e+01 +5.117234802246093750e+01 +-1.241538238525390625e+01 +5.190507507324218750e+01 +8.244247436523437500e-01 +-4.150688934326171875e+01 +2.386329650878906250e+01 +-3.570307159423828125e+01 +3.017010498046875000e+01 +-2.790358734130859375e+01 +3.212951660156250000e+01 +-1.941773223876953125e+01 +3.156523132324218750e+01 +-1.138106536865234375e+01 +2.841992187500000000e+01 +5.993263244628906250e+00 +2.895182800292968750e+01 +1.343590545654296875e+01 +3.189880371093750000e+01 +2.203153991699218750e+01 +3.302221679687500000e+01 +2.992478942871093750e+01 +3.099150085449218750e+01 +3.628388977050781250e+01 +2.765748596191406250e+01 +-1.933914184570312500e+00 +1.405374145507812500e+01 +-2.153038024902343750e+00 +5.772636413574218750e+00 +-2.270050048828125000e+00 +-2.121643066406250000e+00 +-2.218330383300781250e+00 +-1.068978118896484375e+01 +-1.187252044677734375e+01 +-1.997912597656250000e+01 +-6.879402160644531250e+00 +-2.143579864501953125e+01 +-1.227821350097656250e+00 +-2.193494415283203125e+01 +4.623237609863281250e+00 +-2.152721405029296875e+01 +9.721397399902343750e+00 +-1.953671264648437500e+01 +-3.648714447021484375e+01 +9.811126708984375000e+00 +-3.130242919921875000e+01 +1.422447967529296875e+01 +-2.212834930419921875e+01 +1.493019866943359375e+01 +-1.500880432128906250e+01 +1.073588562011718750e+01 +-2.095037078857421875e+01 +9.054298400878906250e+00 +-3.050099182128906250e+01 +8.704177856445312500e+00 +1.173237609863281250e+01 +1.054329681396484375e+01 +1.856353759765625000e+01 +1.535009765625000000e+01 +2.893331909179687500e+01 +1.451992797851562500e+01 +3.452944946289062500e+01 +1.065280151367187500e+01 +2.875990295410156250e+01 +8.654792785644531250e+00 +1.942100524902343750e+01 +9.422447204589843750e+00 +-2.204488372802734375e+01 +-3.983994293212890625e+01 +-1.324458312988281250e+01 +-3.467377471923828125e+01 +-6.749649047851562500e+00 +-3.092894744873046875e+01 +-9.183349609375000000e-01 +-3.196458435058593750e+01 +4.220649719238281250e+00 +-3.090406036376953125e+01 +1.089889526367187500e+01 +-3.497008514404296875e+01 +1.874589538574218750e+01 +-4.065438079833984375e+01 +1.124106597900390625e+01 +-4.438417816162109375e+01 +5.181709289550781250e+00 +-4.649170684814453125e+01 +-1.158607482910156250e+00 +-4.680406951904296875e+01 +-7.918922424316406250e+00 +-4.671575164794921875e+01 +-1.452505493164062500e+01 +-4.416526031494140625e+01 +-2.005007171630859375e+01 +-3.997841644287109375e+01 +-1.054919433593750000e+01 +-3.849683380126953125e+01 +-1.051826477050781250e+00 +-3.794863128662109375e+01 +6.412681579589843750e+00 +-3.804645538330078125e+01 +1.627674865722656250e+01 +-4.039697265625000000e+01 +6.373878479003906250e+00 +-4.087213897705078125e+01 +-8.551712036132812500e-01 +-4.157129669189453125e+01 +-1.014953613281250000e+01 +-4.128469085693359375e+01 diff --git a/src/face3d/util/util.py b/src/face3d/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..0d689ca138fc0fbf5bec794511ea0f9e638f9ea9 --- /dev/null +++ b/src/face3d/util/util.py @@ -0,0 +1,208 @@ +"""This script contains basic utilities for Deep3DFaceRecon_pytorch +""" +from __future__ import print_function +import numpy as np +import torch +from PIL import Image +import os +import importlib +import argparse +from argparse import Namespace +import torchvision + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def copyconf(default_opt, **kwargs): + conf = Namespace(**vars(default_opt)) + for key in kwargs: + setattr(conf, key, kwargs[key]) + return conf + +def genvalconf(train_opt, **kwargs): + conf = Namespace(**vars(train_opt)) + attr_dict = train_opt.__dict__ + for key, value in attr_dict.items(): + if 'val' in key and key.split('_')[0] in attr_dict: + setattr(conf, key.split('_')[0], value) + + for key in kwargs: + setattr(conf, key, kwargs[key]) + + return conf + +def find_class_in_module(target_cls_name, module): + target_cls_name = target_cls_name.replace('_', '').lower() + clslib = importlib.import_module(module) + cls = None + for name, clsobj in clslib.__dict__.items(): + if name.lower() == target_cls_name: + cls = clsobj + + assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name) + + return cls + + +def tensor2im(input_image, imtype=np.uint8): + """"Converts a Tensor array into a numpy image array. + + Parameters: + input_image (tensor) -- the input image tensor array, range(0, 1) + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array + if image_numpy.shape[0] == 1: # grayscale to RGB + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: tranpose and scaling + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path, aspect_ratio=1.0): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + + image_pil = Image.fromarray(image_numpy) + h, w, _ = image_numpy.shape + + if aspect_ratio is None: + pass + elif aspect_ratio > 1.0: + image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) + elif aspect_ratio < 1.0: + image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) + + +def correct_resize_label(t, size): + device = t.device + t = t.detach().cpu() + resized = [] + for i in range(t.size(0)): + one_t = t[i, :1] + one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0)) + one_np = one_np[:, :, 0] + one_image = Image.fromarray(one_np).resize(size, Image.NEAREST) + resized_t = torch.from_numpy(np.array(one_image)).long() + resized.append(resized_t) + return torch.stack(resized, dim=0).to(device) + + +def correct_resize(t, size, mode=Image.BICUBIC): + device = t.device + t = t.detach().cpu() + resized = [] + for i in range(t.size(0)): + one_t = t[i:i + 1] + one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC) + resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0 + resized.append(resized_t) + return torch.stack(resized, dim=0).to(device) + +def draw_landmarks(img, landmark, color='r', step=2): + """ + Return: + img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255) + + + Parameters: + img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255) + landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction + color -- str, 'r' or 'b' (red or blue) + """ + if color =='r': + c = np.array([255., 0, 0]) + else: + c = np.array([0, 0, 255.]) + + _, H, W, _ = img.shape + img, landmark = img.copy(), landmark.copy() + landmark[..., 1] = H - 1 - landmark[..., 1] + landmark = np.round(landmark).astype(np.int32) + for i in range(landmark.shape[1]): + x, y = landmark[:, i, 0], landmark[:, i, 1] + for j in range(-step, step): + for k in range(-step, step): + u = np.clip(x + j, 0, W - 1) + v = np.clip(y + k, 0, H - 1) + for m in range(landmark.shape[0]): + img[m, v[m], u[m]] = c + return img diff --git a/src/face3d/util/visualizer.py b/src/face3d/util/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4023a6d4086acba9bc88e079f625194d324d7c9e --- /dev/null +++ b/src/face3d/util/visualizer.py @@ -0,0 +1,227 @@ +"""This script defines the visualizer for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import os +import sys +import ntpath +import time +from . import util, html +from subprocess import Popen, PIPE +from torch.utils.tensorboard import SummaryWriter + +def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): + """Save images to the disk. + + Parameters: + webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) + visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs + image_path (str) -- the string is used to create image paths + aspect_ratio (float) -- the aspect ratio of saved images + width (int) -- the images will be resized to width x width + + This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. + """ + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, im_data in visuals.items(): + im = util.tensor2im(im_data) + image_name = '%s/%s.png' % (label, name) + os.makedirs(os.path.join(image_dir, label), exist_ok=True) + save_path = os.path.join(image_dir, image_name) + util.save_image(im, save_path, aspect_ratio=aspect_ratio) + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=width) + + +class Visualizer(): + """This class includes several functions that can display/save images and print/save logging information. + + It uses a Python library tensprboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. + """ + + def __init__(self, opt): + """Initialize the Visualizer class + + Parameters: + opt -- stores all the experiment flags; needs to be a subclass of BaseOptions + Step 1: Cache the training/test options + Step 2: create a tensorboard writer + Step 3: create an HTML object for saveing HTML filters + Step 4: create a logging file to store training losses + """ + self.opt = opt # cache the option + self.use_html = opt.isTrain and not opt.no_html + self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name)) + self.win_size = opt.display_winsize + self.name = opt.name + self.saved = False + if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + # create a logging file to store training losses + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + def reset(self): + """Reset the self.saved status""" + self.saved = False + + + def display_current_results(self, visuals, total_iters, epoch, save_result): + """Display current results on tensorboad; save current results to an HTML file. + + Parameters: + visuals (OrderedDict) - - dictionary of images to display or save + total_iters (int) -- total iterations + epoch (int) - - the current epoch + save_result (bool) - - if save the current results to an HTML file + """ + for label, image in visuals.items(): + self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC') + + if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. + self.saved = True + # save images to the disk + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims, txts, links = [], [], [] + + for label, image_numpy in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + + def plot_current_losses(self, total_iters, losses): + # G_loss_collection = {} + # D_loss_collection = {} + # for name, value in losses.items(): + # if 'G' in name or 'NCE' in name or 'idt' in name: + # G_loss_collection[name] = value + # else: + # D_loss_collection[name] = value + # self.writer.add_scalars('G_collec', G_loss_collection, total_iters) + # self.writer.add_scalars('D_collec', D_loss_collection, total_iters) + for name, value in losses.items(): + self.writer.add_scalar(name, value, total_iters) + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, iters, losses, t_comp, t_data): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message + + +class MyVisualizer: + def __init__(self, opt): + """Initialize the Visualizer class + + Parameters: + opt -- stores all the experiment flags; needs to be a subclass of BaseOptions + Step 1: Cache the training/test options + Step 2: create a tensorboard writer + Step 3: create an HTML object for saveing HTML filters + Step 4: create a logging file to store training losses + """ + self.opt = opt # cache the optio + self.name = opt.name + self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results') + + if opt.phase != 'test': + self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs')) + # create a logging file to store training losses + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + + def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None, + add_image=True): + """Display current results on tensorboad; save current results to an HTML file. + + Parameters: + visuals (OrderedDict) - - dictionary of images to display or save + total_iters (int) -- total iterations + epoch (int) - - the current epoch + dataset (str) - - 'train' or 'val' or 'test' + """ + # if (not add_image) and (not save_results): return + + for label, image in visuals.items(): + for i in range(image.shape[0]): + image_numpy = util.tensor2im(image[i]) + if add_image: + self.writer.add_image(label + '%s_%02d'%(dataset, i + count), + image_numpy, total_iters, dataformats='HWC') + + if save_results: + save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters)) + if not os.path.isdir(save_path): + os.makedirs(save_path) + + if name is not None: + img_path = os.path.join(save_path, '%s.png' % name) + else: + img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count)) + util.save_image(image_numpy, img_path) + + + def plot_current_losses(self, total_iters, losses, dataset='train'): + for name, value in losses.items(): + self.writer.add_scalar(name + '/%s'%dataset, value, total_iters) + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % ( + dataset, epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message diff --git a/src/face3d/visualize.py b/src/face3d/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..23a1110806a0ddf37d4aa549c023d1c3f7114e3e --- /dev/null +++ b/src/face3d/visualize.py @@ -0,0 +1,48 @@ +# check the sync of 3dmm feature and the audio +import cv2 +import numpy as np +from src.face3d.models.bfm import ParametricFaceModel +from src.face3d.models.facerecon_model import FaceReconModel +import torch +import subprocess, platform +import scipy.io as scio +from tqdm import tqdm + +# draft +def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, exp_dim=64): + + coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm'] + + coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm'] + + coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257 + + coeff_full[:, 80:144] = coeff_pred[:, 0:64] + coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation + coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation + + tmp_video_path = '/tmp/face3dtmp.mp4' + + facemodel = FaceReconModel(args) + + video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224)) + + for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'): + cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device) + + facemodel.forward(cur_coeff_full, device) + + predicted_landmark = facemodel.pred_lm # TODO. + predicted_landmark = predicted_landmark.cpu().numpy().squeeze() + + rendered_img = facemodel.pred_face + rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0) + out_img = rendered_img[:, :, :3].astype(np.uint8) + + video.write(np.uint8(out_img[:,:,::-1])) + + video.release() + + command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path) + subprocess.call(command, shell=platform.system() != 'Windows') + diff --git a/src/facerender/__pycache__/animate.cpython-38.pyc b/src/facerender/__pycache__/animate.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11fb3d0ee467093c0cb318003c52eb4c78f11cc9 Binary files /dev/null and b/src/facerender/__pycache__/animate.cpython-38.pyc differ diff --git a/src/facerender/animate.py b/src/facerender/animate.py new file mode 100644 index 0000000000000000000000000000000000000000..be2d62ebaeffe06a8dee1e268d832690b1937320 --- /dev/null +++ b/src/facerender/animate.py @@ -0,0 +1,182 @@ +import os +import cv2 +import yaml +import numpy as np +import warnings +from skimage import img_as_ubyte +warnings.filterwarnings('ignore') + +import imageio +import torch + +from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector +from src.facerender.modules.mapping import MappingNet +from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator +from src.facerender.modules.make_animation import make_animation + +from pydub import AudioSegment +from src.utils.face_enhancer import enhancer as face_enhancer + + +class AnimateFromCoeff(): + + def __init__(self, free_view_checkpoint, mapping_checkpoint, + config_path, device): + + with open(config_path) as f: + config = yaml.safe_load(f) + + generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], + **config['model_params']['common_params']) + kp_extractor = KPDetector(**config['model_params']['kp_detector_params'], + **config['model_params']['common_params']) + mapping = MappingNet(**config['model_params']['mapping_params']) + + + generator.to(device) + kp_extractor.to(device) + mapping.to(device) + for param in generator.parameters(): + param.requires_grad = False + for param in kp_extractor.parameters(): + param.requires_grad = False + for param in mapping.parameters(): + param.requires_grad = False + + if free_view_checkpoint is not None: + self.load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator) + else: + raise AttributeError("Checkpoint should be specified for video head pose estimator.") + + if mapping_checkpoint is not None: + self.load_cpk_mapping(mapping_checkpoint, mapping=mapping) + else: + raise AttributeError("Checkpoint should be specified for video head pose estimator.") + + self.kp_extractor = kp_extractor + self.generator = generator + self.mapping = mapping + + self.kp_extractor.eval() + self.generator.eval() + self.mapping.eval() + + self.device = device + + def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None, + kp_detector=None, he_estimator=None, optimizer_generator=None, + optimizer_discriminator=None, optimizer_kp_detector=None, + optimizer_he_estimator=None, device="cpu"): + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + if generator is not None: + generator.load_state_dict(checkpoint['generator']) + if kp_detector is not None: + kp_detector.load_state_dict(checkpoint['kp_detector']) + if he_estimator is not None: + he_estimator.load_state_dict(checkpoint['he_estimator']) + if discriminator is not None: + try: + discriminator.load_state_dict(checkpoint['discriminator']) + except: + print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') + if optimizer_generator is not None: + optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) + if optimizer_discriminator is not None: + try: + optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) + except RuntimeError as e: + print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized') + if optimizer_kp_detector is not None: + optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) + if optimizer_he_estimator is not None: + optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator']) + + return checkpoint['epoch'] + + def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None, + optimizer_mapping=None, optimizer_discriminator=None, device='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + if mapping is not None: + mapping.load_state_dict(checkpoint['mapping']) + if discriminator is not None: + discriminator.load_state_dict(checkpoint['discriminator']) + if optimizer_mapping is not None: + optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping']) + if optimizer_discriminator is not None: + optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) + + return checkpoint['epoch'] + + def generate(self, x, video_save_dir, enhancer=None, original_size=None): + + source_image=x['source_image'].type(torch.FloatTensor) + source_semantics=x['source_semantics'].type(torch.FloatTensor) + target_semantics=x['target_semantics_list'].type(torch.FloatTensor) + yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor) + pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor) + roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor) + source_image=source_image.to(self.device) + source_semantics=source_semantics.to(self.device) + target_semantics=target_semantics.to(self.device) + yaw_c_seq = x['yaw_c_seq'].to(self.device) + pitch_c_seq = x['pitch_c_seq'].to(self.device) + roll_c_seq = x['roll_c_seq'].to(self.device) + + frame_num = x['frame_num'] + + predictions_video = make_animation(source_image, source_semantics, target_semantics, + self.generator, self.kp_extractor, self.mapping, + yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True,) + + predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) + predictions_video = predictions_video[:frame_num] + + video = [] + for idx in range(predictions_video.shape[0]): + image = predictions_video[idx] + image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) + video.append(image) + result = img_as_ubyte(video) + + ### the generated video is 256x256, so we keep the aspect ratio, + if original_size: + result = [ cv2.resize(result_i,(256, int(256.0 * original_size[1]/original_size[0]) )) for result_i in result ] + + video_name = x['video_name'] + '.mp4' + path = os.path.join(video_save_dir, 'temp_'+video_name) + imageio.mimsave(path, result, fps=float(25)) + + if enhancer: + video_name_enhancer = x['video_name'] + '_enhanced.mp4' + av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer) + enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer) + enhanced_images = face_enhancer(result, method=enhancer) + + if original_size: + enhanced_images = [ cv2.resize(result_i,(256, int(256.0 * original_size[1]/original_size[0]) )) for result_i in enhanced_images ] + + imageio.mimsave(enhanced_path, enhanced_images, fps=float(25)) + + av_path = os.path.join(video_save_dir, video_name) + audio_path = x['audio_path'] + audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] + new_audio_path = os.path.join(video_save_dir, audio_name+'.wav') + start_time = 0 + sound = AudioSegment.from_mp3(audio_path) + frames = frame_num + end_time = start_time + frames*1/25*1000 + word1=sound.set_frame_rate(16000) + word = word1[start_time:end_time] + word.export(new_audio_path, format="wav") + + cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (path, new_audio_path, av_path) + os.system(cmd) + + if enhancer: + cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (enhanced_path, new_audio_path, av_path_enhancer) + os.system(cmd) + os.remove(enhanced_path) + + os.remove(path) + os.remove(new_audio_path) + diff --git a/src/facerender/modules/__pycache__/animate_model.cpython-38.pyc b/src/facerender/modules/__pycache__/animate_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ecb83e033911eb82d582e097c513ea0fd4cb69a Binary files /dev/null and b/src/facerender/modules/__pycache__/animate_model.cpython-38.pyc differ diff --git a/src/facerender/modules/__pycache__/animate_model.cpython-39.pyc b/src/facerender/modules/__pycache__/animate_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e9a594ddff05d41ed7fea66e42b37558869332a Binary files /dev/null and b/src/facerender/modules/__pycache__/animate_model.cpython-39.pyc differ diff --git a/src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc b/src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5178c3763bc9f6fcff3a8a410deff7d3c30060db Binary files /dev/null and b/src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc differ diff --git a/src/facerender/modules/__pycache__/dense_motion.cpython-39.pyc b/src/facerender/modules/__pycache__/dense_motion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a6cec5db6525ef350d0fcd52efe814b0d3f1e6d Binary files /dev/null and b/src/facerender/modules/__pycache__/dense_motion.cpython-39.pyc differ diff --git a/src/facerender/modules/__pycache__/generator.cpython-38.pyc b/src/facerender/modules/__pycache__/generator.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d132f05d36e505f21c864d4c95931472ba58051 Binary files /dev/null and b/src/facerender/modules/__pycache__/generator.cpython-38.pyc differ diff --git a/src/facerender/modules/__pycache__/generator.cpython-39.pyc b/src/facerender/modules/__pycache__/generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac9587fe99d8905d8ac99d60025ed1a8d5bacf1b Binary files /dev/null and b/src/facerender/modules/__pycache__/generator.cpython-39.pyc differ diff --git a/src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc b/src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccc5d4543365bfc022a06a72d6ed9d388249279a Binary files /dev/null and b/src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc differ diff --git a/src/facerender/modules/__pycache__/keypoint_detector.cpython-39.pyc b/src/facerender/modules/__pycache__/keypoint_detector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e609a2ce2bea049dcc08e711684347032da88e1a Binary files /dev/null and b/src/facerender/modules/__pycache__/keypoint_detector.cpython-39.pyc differ diff --git a/src/facerender/modules/__pycache__/make_animation.cpython-38.pyc b/src/facerender/modules/__pycache__/make_animation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b54bcc293d742f70db165849b9764666b0f9a8b Binary files /dev/null and b/src/facerender/modules/__pycache__/make_animation.cpython-38.pyc differ diff --git a/src/facerender/modules/__pycache__/mapping.cpython-38.pyc b/src/facerender/modules/__pycache__/mapping.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e1a2baa2bfab28fe7e3904f94a644633124b56c Binary files /dev/null and b/src/facerender/modules/__pycache__/mapping.cpython-38.pyc differ diff --git a/src/facerender/modules/__pycache__/mapping5.cpython-38.pyc b/src/facerender/modules/__pycache__/mapping5.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae35fb77f8552d2aa9cb263cba6ca9d37bbee9a7 Binary files /dev/null and b/src/facerender/modules/__pycache__/mapping5.cpython-38.pyc differ diff --git a/src/facerender/modules/__pycache__/mapping5.cpython-39.pyc b/src/facerender/modules/__pycache__/mapping5.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa6b6db40007f95fca648909a638810273b2c050 Binary files /dev/null and b/src/facerender/modules/__pycache__/mapping5.cpython-39.pyc differ diff --git a/src/facerender/modules/__pycache__/util.cpython-38.pyc b/src/facerender/modules/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e1c92955be38c880c52cc70b8051fd8ef4fa63a Binary files /dev/null and b/src/facerender/modules/__pycache__/util.cpython-38.pyc differ diff --git a/src/facerender/modules/__pycache__/util.cpython-39.pyc b/src/facerender/modules/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8764b93cb4e5964b831caf9ff376b70105f3dc5d Binary files /dev/null and b/src/facerender/modules/__pycache__/util.cpython-39.pyc differ diff --git a/src/facerender/modules/dense_motion.py b/src/facerender/modules/dense_motion.py new file mode 100644 index 0000000000000000000000000000000000000000..30c13060be8e82979771514b4ec51e5de23f49fa --- /dev/null +++ b/src/facerender/modules/dense_motion.py @@ -0,0 +1,117 @@ +from torch import nn +import torch.nn.functional as F +import torch +from src.facerender.modules.util import Hourglass, make_coordinate_grid, kp2gaussian + +from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d + + +class DenseMotionNetwork(nn.Module): + """ + Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving + """ + + def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, + estimate_occlusion_map=False): + super(DenseMotionNetwork, self).__init__() + # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks) + self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) + + self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) + + self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) + self.norm = BatchNorm3d(compress, affine=True) + + if estimate_occlusion_map: + # self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3) + self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3) + else: + self.occlusion = None + + self.num_kp = num_kp + + + def create_sparse_motions(self, feature, kp_driving, kp_source): + bs, _, d, h, w = feature.shape + identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type()) + identity_grid = identity_grid.view(1, 1, d, h, w, 3) + coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3) + + # if 'jacobian' in kp_driving: + if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None: + jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) + jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3) + jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1) + coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) + coordinate_grid = coordinate_grid.squeeze(-1) + + + driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3) + + #adding background feature + identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1) + sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs num_kp+1 d h w 3 + + # sparse_motions = driving_to_source + + return sparse_motions + + def create_deformed_feature(self, feature, sparse_motions): + bs, _, d, h, w = feature.shape + feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w) + feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w) + sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) !!!! + sparse_deformed = F.grid_sample(feature_repeat, sparse_motions) + sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w) + return sparse_deformed + + def create_heatmap_representations(self, feature, kp_driving, kp_source): + spatial_size = feature.shape[3:] + gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) + gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) + heatmap = gaussian_driving - gaussian_source + + # adding background feature + zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()) + heatmap = torch.cat([zeros, heatmap], dim=1) + heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) + return heatmap + + def forward(self, feature, kp_driving, kp_source): + bs, _, d, h, w = feature.shape + + feature = self.compress(feature) + feature = self.norm(feature) + feature = F.relu(feature) + + out_dict = dict() + sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) + deformed_feature = self.create_deformed_feature(feature, sparse_motion) + + heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) + + input_ = torch.cat([heatmap, deformed_feature], dim=2) + input_ = input_.view(bs, -1, d, h, w) + + # input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w) + + prediction = self.hourglass(input_) + + + mask = self.mask(prediction) + mask = F.softmax(mask, dim=1) + out_dict['mask'] = mask + mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) + sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w) + deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) + deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3) + + out_dict['deformation'] = deformation + + if self.occlusion: + bs, c, d, h, w = prediction.shape + prediction = prediction.view(bs, -1, h, w) + occlusion_map = torch.sigmoid(self.occlusion(prediction)) + out_dict['occlusion_map'] = occlusion_map + + return out_dict diff --git a/src/facerender/modules/discriminator.py b/src/facerender/modules/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..d4459b07cb075c9f9d345f9b3dffc02cd859313b --- /dev/null +++ b/src/facerender/modules/discriminator.py @@ -0,0 +1,90 @@ +from torch import nn +import torch.nn.functional as F +from facerender.modules.util import kp2gaussian +import torch + + +class DownBlock2d(nn.Module): + """ + Simple block for processing video (encoder). + """ + + def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) + + if sn: + self.conv = nn.utils.spectral_norm(self.conv) + + if norm: + self.norm = nn.InstanceNorm2d(out_features, affine=True) + else: + self.norm = None + self.pool = pool + + def forward(self, x): + out = x + out = self.conv(out) + if self.norm: + out = self.norm(out) + out = F.leaky_relu(out, 0.2) + if self.pool: + out = F.avg_pool2d(out, (2, 2)) + return out + + +class Discriminator(nn.Module): + """ + Discriminator similar to Pix2Pix + """ + + def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, + sn=False, **kwargs): + super(Discriminator, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append( + DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) + + self.down_blocks = nn.ModuleList(down_blocks) + self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) + if sn: + self.conv = nn.utils.spectral_norm(self.conv) + + def forward(self, x): + feature_maps = [] + out = x + + for down_block in self.down_blocks: + feature_maps.append(down_block(out)) + out = feature_maps[-1] + prediction_map = self.conv(out) + + return feature_maps, prediction_map + + +class MultiScaleDiscriminator(nn.Module): + """ + Multi-scale (scale) discriminator + """ + + def __init__(self, scales=(), **kwargs): + super(MultiScaleDiscriminator, self).__init__() + self.scales = scales + discs = {} + for scale in scales: + discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) + self.discs = nn.ModuleDict(discs) + + def forward(self, x): + out_dict = {} + for scale, disc in self.discs.items(): + scale = str(scale).replace('-', '.') + key = 'prediction_' + scale + feature_maps, prediction_map = disc(x[key]) + out_dict['feature_maps_' + scale] = feature_maps + out_dict['prediction_map_' + scale] = prediction_map + return out_dict diff --git a/src/facerender/modules/generator.py b/src/facerender/modules/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5f8d26b18a1fa5cb1d8cbe9d1fa2413bf39f01 --- /dev/null +++ b/src/facerender/modules/generator.py @@ -0,0 +1,251 @@ +import torch +from torch import nn +import torch.nn.functional as F +from src.facerender.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock +from src.facerender.modules.dense_motion import DenseMotionNetwork + + +class OcclusionAwareGenerator(nn.Module): + """ + Generator follows NVIDIA architecture. + """ + + def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, + num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): + super(OcclusionAwareGenerator, self).__init__() + + if dense_motion_params is not None: + self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, + estimate_occlusion_map=estimate_occlusion_map, + **dense_motion_params) + else: + self.dense_motion_network = None + + self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3)) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + + self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) + + self.reshape_channel = reshape_channel + self.reshape_depth = reshape_depth + + self.resblocks_3d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) + + out_features = block_expansion * (2 ** (num_down_blocks)) + self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) + self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) + + self.resblocks_2d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1)) + + up_blocks = [] + for i in range(num_down_blocks): + in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i))) + out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1))) + up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.up_blocks = nn.ModuleList(up_blocks) + + self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3)) + self.estimate_occlusion_map = estimate_occlusion_map + self.image_channel = image_channel + + def deform_input(self, inp, deformation): + _, d_old, h_old, w_old, _ = deformation.shape + _, _, d, h, w = inp.shape + if d_old != d or h_old != h or w_old != w: + deformation = deformation.permute(0, 4, 1, 2, 3) + deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') + deformation = deformation.permute(0, 2, 3, 4, 1) + return F.grid_sample(inp, deformation) + + def forward(self, source_image, kp_driving, kp_source): + # Encoding (downsampling) part + out = self.first(source_image) + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + out = self.second(out) + bs, c, h, w = out.shape + # print(out.shape) + feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) + feature_3d = self.resblocks_3d(feature_3d) + + # Transforming feature representation according to deformation and occlusion + output_dict = {} + if self.dense_motion_network is not None: + dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, + kp_source=kp_source) + output_dict['mask'] = dense_motion['mask'] + + if 'occlusion_map' in dense_motion: + occlusion_map = dense_motion['occlusion_map'] + output_dict['occlusion_map'] = occlusion_map + else: + occlusion_map = None + deformation = dense_motion['deformation'] + out = self.deform_input(feature_3d, deformation) + + bs, c, d, h, w = out.shape + out = out.view(bs, c*d, h, w) + out = self.third(out) + out = self.fourth(out) + + if occlusion_map is not None: + if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: + occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') + out = out * occlusion_map + + # output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image + + # Decoding part + out = self.resblocks_2d(out) + for i in range(len(self.up_blocks)): + out = self.up_blocks[i](out) + out = self.final(out) + out = F.sigmoid(out) + + output_dict["prediction"] = out + + return output_dict + + +class SPADEDecoder(nn.Module): + def __init__(self): + super().__init__() + ic = 256 + oc = 64 + norm_G = 'spadespectralinstance' + label_nc = 256 + + self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1) + self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc) + self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc) + self.conv_img = nn.Conv2d(oc, 3, 3, padding=1) + self.up = nn.Upsample(scale_factor=2) + + def forward(self, feature): + seg = feature + x = self.fc(feature) + x = self.G_middle_0(x, seg) + x = self.G_middle_1(x, seg) + x = self.G_middle_2(x, seg) + x = self.G_middle_3(x, seg) + x = self.G_middle_4(x, seg) + x = self.G_middle_5(x, seg) + x = self.up(x) + x = self.up_0(x, seg) # 256, 128, 128 + x = self.up(x) + x = self.up_1(x, seg) # 64, 256, 256 + + x = self.conv_img(F.leaky_relu(x, 2e-1)) + # x = torch.tanh(x) + x = F.sigmoid(x) + + return x + + +class OcclusionAwareSPADEGenerator(nn.Module): + + def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, + num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): + super(OcclusionAwareSPADEGenerator, self).__init__() + + if dense_motion_params is not None: + self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, + estimate_occlusion_map=estimate_occlusion_map, + **dense_motion_params) + else: + self.dense_motion_network = None + + self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + + self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) + + self.reshape_channel = reshape_channel + self.reshape_depth = reshape_depth + + self.resblocks_3d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) + + out_features = block_expansion * (2 ** (num_down_blocks)) + self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) + self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) + + self.estimate_occlusion_map = estimate_occlusion_map + self.image_channel = image_channel + + self.decoder = SPADEDecoder() + + def deform_input(self, inp, deformation): + _, d_old, h_old, w_old, _ = deformation.shape + _, _, d, h, w = inp.shape + if d_old != d or h_old != h or w_old != w: + deformation = deformation.permute(0, 4, 1, 2, 3) + deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') + deformation = deformation.permute(0, 2, 3, 4, 1) + return F.grid_sample(inp, deformation) + + def forward(self, source_image, kp_driving, kp_source): + # Encoding (downsampling) part + out = self.first(source_image) + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + out = self.second(out) + bs, c, h, w = out.shape + # print(out.shape) + feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) + feature_3d = self.resblocks_3d(feature_3d) + + # Transforming feature representation according to deformation and occlusion + output_dict = {} + if self.dense_motion_network is not None: + dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, + kp_source=kp_source) + output_dict['mask'] = dense_motion['mask'] + + if 'occlusion_map' in dense_motion: + occlusion_map = dense_motion['occlusion_map'] + output_dict['occlusion_map'] = occlusion_map + else: + occlusion_map = None + deformation = dense_motion['deformation'] + out = self.deform_input(feature_3d, deformation) + + bs, c, d, h, w = out.shape + out = out.view(bs, c*d, h, w) + out = self.third(out) + out = self.fourth(out) + + if occlusion_map is not None: + if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: + occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') + out = out * occlusion_map + + # Decoding part + out = self.decoder(out) + + output_dict["prediction"] = out + + return output_dict \ No newline at end of file diff --git a/src/facerender/modules/keypoint_detector.py b/src/facerender/modules/keypoint_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..62a38a962b2f1a4326aac771aced353ec5e22a96 --- /dev/null +++ b/src/facerender/modules/keypoint_detector.py @@ -0,0 +1,179 @@ +from torch import nn +import torch +import torch.nn.functional as F + +from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d +from src.facerender.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck + + +class KPDetector(nn.Module): + """ + Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint. + """ + + def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth, + num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False): + super(KPDetector, self).__init__() + + self.predictor = KPHourglass(block_expansion, in_features=image_channel, + max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks) + + # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3) + self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1) + + if estimate_jacobian: + self.num_jacobian_maps = 1 if single_jacobian_map else num_kp + # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3) + self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1) + ''' + initial as: + [[1 0 0] + [0 1 0] + [0 0 1]] + ''' + self.jacobian.weight.data.zero_() + self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) + else: + self.jacobian = None + + self.temperature = temperature + self.scale_factor = scale_factor + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor) + + def gaussian2kp(self, heatmap): + """ + Extract the mean from a heatmap + """ + shape = heatmap.shape + heatmap = heatmap.unsqueeze(-1) + grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) + value = (heatmap * grid).sum(dim=(2, 3, 4)) + kp = {'value': value} + + return kp + + def forward(self, x): + if self.scale_factor != 1: + x = self.down(x) + + feature_map = self.predictor(x) + prediction = self.kp(feature_map) + + final_shape = prediction.shape + heatmap = prediction.view(final_shape[0], final_shape[1], -1) + heatmap = F.softmax(heatmap / self.temperature, dim=2) + heatmap = heatmap.view(*final_shape) + + out = self.gaussian2kp(heatmap) + + if self.jacobian is not None: + jacobian_map = self.jacobian(feature_map) + jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2], + final_shape[3], final_shape[4]) + heatmap = heatmap.unsqueeze(2) + + jacobian = heatmap * jacobian_map + jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1) + jacobian = jacobian.sum(dim=-1) + jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3) + out['jacobian'] = jacobian + + return out + + +class HEEstimator(nn.Module): + """ + Estimating head pose and expression. + """ + + def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True): + super(HEEstimator, self).__init__() + + self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2) + self.norm1 = BatchNorm2d(block_expansion, affine=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1) + self.norm2 = BatchNorm2d(256, affine=True) + + self.block1 = nn.Sequential() + for i in range(3): + self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1)) + + self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1) + self.norm3 = BatchNorm2d(512, affine=True) + self.block2 = ResBottleneck(in_features=512, stride=2) + + self.block3 = nn.Sequential() + for i in range(3): + self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1)) + + self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1) + self.norm4 = BatchNorm2d(1024, affine=True) + self.block4 = ResBottleneck(in_features=1024, stride=2) + + self.block5 = nn.Sequential() + for i in range(5): + self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1)) + + self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1) + self.norm5 = BatchNorm2d(2048, affine=True) + self.block6 = ResBottleneck(in_features=2048, stride=2) + + self.block7 = nn.Sequential() + for i in range(2): + self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1)) + + self.fc_roll = nn.Linear(2048, num_bins) + self.fc_pitch = nn.Linear(2048, num_bins) + self.fc_yaw = nn.Linear(2048, num_bins) + + self.fc_t = nn.Linear(2048, 3) + + self.fc_exp = nn.Linear(2048, 3*num_kp) + + def forward(self, x): + out = self.conv1(x) + out = self.norm1(out) + out = F.relu(out) + out = self.maxpool(out) + + out = self.conv2(out) + out = self.norm2(out) + out = F.relu(out) + + out = self.block1(out) + + out = self.conv3(out) + out = self.norm3(out) + out = F.relu(out) + out = self.block2(out) + + out = self.block3(out) + + out = self.conv4(out) + out = self.norm4(out) + out = F.relu(out) + out = self.block4(out) + + out = self.block5(out) + + out = self.conv5(out) + out = self.norm5(out) + out = F.relu(out) + out = self.block6(out) + + out = self.block7(out) + + out = F.adaptive_avg_pool2d(out, 1) + out = out.view(out.shape[0], -1) + + yaw = self.fc_roll(out) + pitch = self.fc_pitch(out) + roll = self.fc_yaw(out) + t = self.fc_t(out) + exp = self.fc_exp(out) + + return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} + diff --git a/src/facerender/modules/make_animation.py b/src/facerender/modules/make_animation.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2382d82d26043145184b339103aac64abdaa62 --- /dev/null +++ b/src/facerender/modules/make_animation.py @@ -0,0 +1,160 @@ +from scipy.spatial import ConvexHull +import torch +import torch.nn.functional as F +import numpy as np +from tqdm import tqdm + +def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, + use_relative_movement=False, use_relative_jacobian=False): + if adapt_movement_scale: + source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume + driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume + adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) + else: + adapt_movement_scale = 1 + + kp_new = {k: v for k, v in kp_driving.items()} + + if use_relative_movement: + kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) + kp_value_diff *= adapt_movement_scale + kp_new['value'] = kp_value_diff + kp_source['value'] + + if use_relative_jacobian: + jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) + kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) + + return kp_new + +def headpose_pred_to_degree(pred): + device = pred.device + idx_tensor = [idx for idx in range(66)] + idx_tensor = torch.FloatTensor(idx_tensor).to(device) + pred = F.softmax(pred) + degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 + return degree + +def get_rotation_matrix(yaw, pitch, roll): + yaw = yaw / 180 * 3.14 + pitch = pitch / 180 * 3.14 + roll = roll / 180 * 3.14 + + roll = roll.unsqueeze(1) + pitch = pitch.unsqueeze(1) + yaw = yaw.unsqueeze(1) + + pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch), + torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch), + torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1) + pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) + + yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw), + torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw), + -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1) + yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) + + roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll), + torch.sin(roll), torch.cos(roll), torch.zeros_like(roll), + torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1) + roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) + + rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat) + + return rot_mat + +def keypoint_transformation(kp_canonical, he): + kp = kp_canonical['value'] # (bs, k, 3) + yaw, pitch, roll= he['yaw'], he['pitch'], he['roll'] + yaw = headpose_pred_to_degree(yaw) + pitch = headpose_pred_to_degree(pitch) + roll = headpose_pred_to_degree(roll) + + if 'yaw_c' in he: + yaw = yaw + he['yaw_c'] + if 'pitch_c' in he: + pitch = pitch + he['pitch_c'] + if 'roll_c' in he: + roll = roll + he['roll_c'] + + rot_mat = get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) + + t, exp = he['t'], he['exp'] + + # keypoint rotation + kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) + + # keypoint translation + t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1) + kp_t = kp_rotated + t + + # add expression deviation + exp = exp.view(exp.shape[0], -1, 3) + kp_transformed = kp_t + exp + + return {'value': kp_transformed} + + + +def make_animation(source_image, source_semantics, target_semantics, + generator, kp_detector, mapping, + yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None, + use_exp=True): + with torch.no_grad(): + predictions = [] + + kp_canonical = kp_detector(source_image) + he_source = mapping(source_semantics) + kp_source = keypoint_transformation(kp_canonical, he_source) + + for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'): + target_semantics_frame = target_semantics[:, frame_idx] + he_driving = mapping(target_semantics_frame) + if not use_exp: + he_driving['exp'] = he_driving['exp']*0 + if yaw_c_seq is not None: + he_driving['yaw_c'] = yaw_c_seq[:, frame_idx] + if pitch_c_seq is not None: + he_driving['pitch_c'] = pitch_c_seq[:, frame_idx] + if roll_c_seq is not None: + he_driving['roll_c'] = roll_c_seq[:, frame_idx] + + kp_driving = keypoint_transformation(kp_canonical, he_driving) + + #kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving, + #kp_driving_initial=kp_driving_initial) + kp_norm = kp_driving + out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm) + predictions.append(out['prediction']) + predictions_ts = torch.stack(predictions, dim=1) + return predictions_ts + +class AnimateModel(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, generator, kp_extractor, mapping): + super(AnimateModel, self).__init__() + self.kp_extractor = kp_extractor + self.generator = generator + self.mapping = mapping + + self.kp_extractor.eval() + self.generator.eval() + self.mapping.eval() + + def forward(self, x): + + source_image = x['source_image'] + source_semantics = x['source_semantics'] + target_semantics = x['target_semantics'] + yaw_c_seq = x['yaw_c_seq'] + pitch_c_seq = x['pitch_c_seq'] + roll_c_seq = x['roll_c_seq'] + + predictions_video = make_animation(source_image, source_semantics, target_semantics, + self.generator, self.kp_extractor, + self.mapping, use_exp = True, + yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq) + + return predictions_video \ No newline at end of file diff --git a/src/facerender/modules/mapping.py b/src/facerender/modules/mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..0e3a1c2d1770996080c08e9daafb346f05d7bcdd --- /dev/null +++ b/src/facerender/modules/mapping.py @@ -0,0 +1,47 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MappingNet(nn.Module): + def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins): + super( MappingNet, self).__init__() + + self.layer = layer + nonlinearity = nn.LeakyReLU(0.1) + + self.first = nn.Sequential( + torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) + + for i in range(layer): + net = nn.Sequential(nonlinearity, + torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) + setattr(self, 'encoder' + str(i), net) + + self.pooling = nn.AdaptiveAvgPool1d(1) + self.output_nc = descriptor_nc + + self.fc_roll = nn.Linear(descriptor_nc, num_bins) + self.fc_pitch = nn.Linear(descriptor_nc, num_bins) + self.fc_yaw = nn.Linear(descriptor_nc, num_bins) + self.fc_t = nn.Linear(descriptor_nc, 3) + self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp) + + def forward(self, input_3dmm): + out = self.first(input_3dmm) + for i in range(self.layer): + model = getattr(self, 'encoder' + str(i)) + out = model(out) + out[:,:,3:-3] + out = self.pooling(out) + out = out.view(out.shape[0], -1) + #print('out:', out.shape) + + yaw = self.fc_yaw(out) + pitch = self.fc_pitch(out) + roll = self.fc_roll(out) + t = self.fc_t(out) + exp = self.fc_exp(out) + + return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} \ No newline at end of file diff --git a/src/facerender/modules/util.py b/src/facerender/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b916deefbb8b957ad6ab3cd7403c28513e5ae18e --- /dev/null +++ b/src/facerender/modules/util.py @@ -0,0 +1,564 @@ +from torch import nn + +import torch.nn.functional as F +import torch + +from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d +from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d + +import torch.nn.utils.spectral_norm as spectral_norm + + +def kp2gaussian(kp, spatial_size, kp_variance): + """ + Transform a keypoint into gaussian like representation + """ + mean = kp['value'] + + coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) + number_of_leading_dimensions = len(mean.shape) - 1 + shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape + coordinate_grid = coordinate_grid.view(*shape) + repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1) + coordinate_grid = coordinate_grid.repeat(*repeats) + + # Preprocess kp shape + shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3) + mean = mean.view(*shape) + + mean_sub = (coordinate_grid - mean) + + out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) + + return out + +def make_coordinate_grid_2d(spatial_size, type): + """ + Create a meshgrid [-1,1] x [-1,1] of given spatial_size. + """ + h, w = spatial_size + x = torch.arange(w).type(type) + y = torch.arange(h).type(type) + + x = (2 * (x / (w - 1)) - 1) + y = (2 * (y / (h - 1)) - 1) + + yy = y.view(-1, 1).repeat(1, w) + xx = x.view(1, -1).repeat(h, 1) + + meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) + + return meshed + + +def make_coordinate_grid(spatial_size, type): + d, h, w = spatial_size + x = torch.arange(w).type(type) + y = torch.arange(h).type(type) + z = torch.arange(d).type(type) + + x = (2 * (x / (w - 1)) - 1) + y = (2 * (y / (h - 1)) - 1) + z = (2 * (z / (d - 1)) - 1) + + yy = y.view(1, -1, 1).repeat(d, 1, w) + xx = x.view(1, 1, -1).repeat(d, h, 1) + zz = z.view(-1, 1, 1).repeat(1, h, w) + + meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3) + + return meshed + + +class ResBottleneck(nn.Module): + def __init__(self, in_features, stride): + super(ResBottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1) + self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1) + self.norm1 = BatchNorm2d(in_features//4, affine=True) + self.norm2 = BatchNorm2d(in_features//4, affine=True) + self.norm3 = BatchNorm2d(in_features, affine=True) + + self.stride = stride + if self.stride != 1: + self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride) + self.norm4 = BatchNorm2d(in_features, affine=True) + + def forward(self, x): + out = self.conv1(x) + out = self.norm1(out) + out = F.relu(out) + out = self.conv2(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv3(out) + out = self.norm3(out) + if self.stride != 1: + x = self.skip(x) + x = self.norm4(x) + out += x + out = F.relu(out) + return out + + +class ResBlock2d(nn.Module): + """ + Res block, preserve spatial resolution. + """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock2d, self).__init__() + self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.norm1 = BatchNorm2d(in_features, affine=True) + self.norm2 = BatchNorm2d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class ResBlock3d(nn.Module): + """ + Res block, preserve spatial resolution. + """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock3d, self).__init__() + self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.norm1 = BatchNorm3d(in_features, affine=True) + self.norm2 = BatchNorm3d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class UpBlock2d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock2d, self).__init__() + + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm2d(out_features, affine=True) + + def forward(self, x): + out = F.interpolate(x, scale_factor=2) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + +class UpBlock3d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock3d, self).__init__() + + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm3d(out_features, affine=True) + + def forward(self, x): + # out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear') + out = F.interpolate(x, scale_factor=(1, 2, 2)) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class DownBlock2d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm2d(out_features, affine=True) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class DownBlock3d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock3d, self).__init__() + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups, stride=(1, 2, 2)) + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm3d(out_features, affine=True) + self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class SameBlock2d(nn.Module): + """ + Simple block, preserve spatial resolution. + """ + + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, + kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = BatchNorm2d(out_features, affine=True) + if lrelu: + self.ac = nn.LeakyReLU() + else: + self.ac = nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.ac(out) + return out + + +class Encoder(nn.Module): + """ + Hourglass Encoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Encoder, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + kernel_size=3, padding=1)) + self.down_blocks = nn.ModuleList(down_blocks) + + def forward(self, x): + outs = [x] + for down_block in self.down_blocks: + outs.append(down_block(outs[-1])) + return outs + + +class Decoder(nn.Module): + """ + Hourglass Decoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Decoder, self).__init__() + + up_blocks = [] + + for i in range(num_blocks)[::-1]: + in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) + out_filters = min(max_features, block_expansion * (2 ** i)) + up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) + + self.up_blocks = nn.ModuleList(up_blocks) + # self.out_filters = block_expansion + self.out_filters = block_expansion + in_features + + self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1) + self.norm = BatchNorm3d(self.out_filters, affine=True) + + def forward(self, x): + out = x.pop() + # for up_block in self.up_blocks[:-1]: + for up_block in self.up_blocks: + out = up_block(out) + skip = x.pop() + out = torch.cat([out, skip], dim=1) + # out = self.up_blocks[-1](out) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class Hourglass(nn.Module): + """ + Hourglass architecture. + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Hourglass, self).__init__() + self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) + self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) + self.out_filters = self.decoder.out_filters + + def forward(self, x): + return self.decoder(self.encoder(x)) + + +class KPHourglass(nn.Module): + """ + Hourglass architecture. + """ + + def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256): + super(KPHourglass, self).__init__() + + self.down_blocks = nn.Sequential() + for i in range(num_blocks): + self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + kernel_size=3, padding=1)) + + in_filters = min(max_features, block_expansion * (2 ** num_blocks)) + self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1) + + self.up_blocks = nn.Sequential() + for i in range(num_blocks): + in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i))) + out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1))) + self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) + + self.reshape_depth = reshape_depth + self.out_filters = out_filters + + def forward(self, x): + out = self.down_blocks(x) + out = self.conv(out) + bs, c, h, w = out.shape + out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w) + out = self.up_blocks(out) + + return out + + + +class AntiAliasInterpolation2d(nn.Module): + """ + Band-limited downsampling, for better preservation of the input signal. + """ + def __init__(self, channels, scale): + super(AntiAliasInterpolation2d, self).__init__() + sigma = (1 / scale - 1) / 2 + kernel_size = 2 * round(sigma * 4) + 1 + self.ka = kernel_size // 2 + self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka + + kernel_size = [kernel_size, kernel_size] + sigma = [sigma, sigma] + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + self.scale = scale + inv_scale = 1 / scale + self.int_inv_scale = int(inv_scale) + + def forward(self, input): + if self.scale == 1.0: + return input + + out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) + out = F.conv2d(out, weight=self.weight, groups=self.groups) + out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] + + return out + + +class SPADE(nn.Module): + def __init__(self, norm_nc, label_nc): + super().__init__() + + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + nhidden = 128 + + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1), + nn.ReLU()) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + + def forward(self, x, segmap): + normalized = self.param_free_norm(x) + segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out = normalized * (1 + gamma) + beta + return out + + +class SPADEResnetBlock(nn.Module): + def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1): + super().__init__() + # Attributes + self.learned_shortcut = (fin != fout) + fmiddle = min(fin, fout) + self.use_se = use_se + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + # apply spectral norm if specified + if 'spectral' in norm_G: + self.conv_0 = spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + # define normalization layers + self.norm_0 = SPADE(fin, label_nc) + self.norm_1 = SPADE(fmiddle, label_nc) + if self.learned_shortcut: + self.norm_s = SPADE(fin, label_nc) + + def forward(self, x, seg1): + x_s = self.shortcut(x, seg1) + dx = self.conv_0(self.actvn(self.norm_0(x, seg1))) + dx = self.conv_1(self.actvn(self.norm_1(dx, seg1))) + out = x_s + dx + return out + + def shortcut(self, x, seg1): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg1)) + else: + x_s = x + return x_s + + def actvn(self, x): + return F.leaky_relu(x, 2e-1) + +class audio2image(nn.Module): + def __init__(self, generator, kp_extractor, he_estimator_video, he_estimator_audio, train_params): + super().__init__() + # Attributes + self.generator = generator + self.kp_extractor = kp_extractor + self.he_estimator_video = he_estimator_video + self.he_estimator_audio = he_estimator_audio + self.train_params = train_params + + def headpose_pred_to_degree(self, pred): + device = pred.device + idx_tensor = [idx for idx in range(66)] + idx_tensor = torch.FloatTensor(idx_tensor).to(device) + pred = F.softmax(pred) + degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 + + return degree + + def get_rotation_matrix(self, yaw, pitch, roll): + yaw = yaw / 180 * 3.14 + pitch = pitch / 180 * 3.14 + roll = roll / 180 * 3.14 + + roll = roll.unsqueeze(1) + pitch = pitch.unsqueeze(1) + yaw = yaw.unsqueeze(1) + + roll_mat = torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll), + torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll), + torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1) + roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) + + pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch), + torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch), + -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1) + pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) + + yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw), + torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw), + torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1) + yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) + + rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat) + + return rot_mat + + def keypoint_transformation(self, kp_canonical, he): + kp = kp_canonical['value'] # (bs, k, 3) + yaw, pitch, roll = he['yaw'], he['pitch'], he['roll'] + t, exp = he['t'], he['exp'] + + yaw = self.headpose_pred_to_degree(yaw) + pitch = self.headpose_pred_to_degree(pitch) + roll = self.headpose_pred_to_degree(roll) + + rot_mat = self.get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) + + # keypoint rotation + kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) + + + + # keypoint translation + t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1) + kp_t = kp_rotated + t + + # add expression deviation + exp = exp.view(exp.shape[0], -1, 3) + kp_transformed = kp_t + exp + + return {'value': kp_transformed} + + def forward(self, source_image, target_audio): + pose_source = self.he_estimator_video(source_image) + pose_generated = self.he_estimator_audio(target_audio) + kp_canonical = self.kp_extractor(source_image) + kp_source = self.keypoint_transformation(kp_canonical, pose_source) + kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated) + generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated) + return generated \ No newline at end of file diff --git a/src/facerender/sync_batchnorm/__init__.py b/src/facerender/sync_batchnorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8709d92c610b36e0bcbd7da20c1eb41dc8cfcf --- /dev/null +++ b/src/facerender/sync_batchnorm/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-36.pyc b/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8327a281a1c119814499648bdec814cf753ba0ba Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-36.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-37.pyc b/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e9c9671abd49037eb51d66e7bb6046177433a27 Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc b/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03d5fdb5ff0e14c08894b394b8c1cae7e1f324c4 Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-39.pyc b/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c0d18c3cec16bbeccbc825186b14c60550563a1 Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc b/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24a89a661e425c0b49c5d616759928e701eab005 Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc b/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7658dccf719cd85ac0c6e6f6b190ffe6f32c5ed Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc b/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20a4560fc425087d5d63c70cc08fd12c2d8a7ea1 Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-39.pyc b/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1c07e4d0f03cd52a105f009d16f079559a5f97e Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-39.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/comm.cpython-36.pyc b/src/facerender/sync_batchnorm/__pycache__/comm.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7602415a703e1bd2b6008a9bf6dde9778d4349ae Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/comm.cpython-36.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/comm.cpython-37.pyc b/src/facerender/sync_batchnorm/__pycache__/comm.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ce98838a834f854dbbc7a8d2f4f1295802e97f3 Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/comm.cpython-37.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc b/src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb7252b8ad1b6aec2f5566979db0494f71a63d91 Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/comm.cpython-39.pyc b/src/facerender/sync_batchnorm/__pycache__/comm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b84f093a8aef9c2b92f0beead2318296163c9e1f Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/comm.cpython-39.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-36.pyc b/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a53e2cdf5b5c2d0f7fc9f6c928fe116d629a6c8 Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-36.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-37.pyc b/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b91c03d671fb5a9334bd4791f6e1f55d397f2e62 Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-37.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc b/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30c9811579d75333db1b60fe4622f682013f719b Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc differ diff --git a/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-39.pyc b/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..561b184da4d393c548f7eb0b3076c765d4bf3745 Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-39.pyc differ diff --git a/src/facerender/sync_batchnorm/batchnorm.py b/src/facerender/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4e763f0366dffa10320116413f8c7181a8aeb1 --- /dev/null +++ b/src/facerender/sync_batchnorm/batchnorm.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) diff --git a/src/facerender/sync_batchnorm/comm.py b/src/facerender/sync_batchnorm/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b --- /dev/null +++ b/src/facerender/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/src/facerender/sync_batchnorm/replicate.py b/src/facerender/sync_batchnorm/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06 --- /dev/null +++ b/src/facerender/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/src/facerender/sync_batchnorm/unittest.py b/src/facerender/sync_batchnorm/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..0675c022e4ba85d38d1f813490f6740150909524 --- /dev/null +++ b/src/facerender/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest + +import numpy as np +from torch.autograd import Variable + + +def as_numpy(v): + if isinstance(v, Variable): + v = v.data + return v.cpu().numpy() + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): + npa, npb = as_numpy(a), as_numpy(b) + self.assertTrue( + np.allclose(npa, npb, atol=atol), + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + ) diff --git a/src/generate_batch.py b/src/generate_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..2d9e19b6aa4c19c13caf0a208e1189cd6c19f796 --- /dev/null +++ b/src/generate_batch.py @@ -0,0 +1,94 @@ +import os + +from tqdm import tqdm +import torch +import numpy as np +import random +import scipy.io as scio +import src.utils.audio as audio + +def crop_pad_audio(wav, audio_length): + if len(wav) > audio_length: + wav = wav[:audio_length] + elif len(wav) < audio_length: + wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0) + return wav + +def parse_audio_length(audio_length, sr, fps): + bit_per_frames = sr / fps + + num_frames = int(audio_length / bit_per_frames) + audio_length = int(num_frames * bit_per_frames) + + return audio_length, num_frames + +def generate_blink_seq(num_frames): + ratio = np.zeros((num_frames,1)) + frame_id = 0 + while frame_id in range(num_frames): + start = 80 + if frame_id+start+9<=num_frames - 1: + ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5] + frame_id = frame_id+start+9 + else: + break + return ratio + +def generate_blink_seq_randomly(num_frames): + ratio = np.zeros((num_frames,1)) + if num_frames<=20: + return ratio + frame_id = 0 + while frame_id in range(num_frames): + start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70))) + if frame_id+start+5<=num_frames - 1: + ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5] + frame_id = frame_id+start+5 + else: + break + return ratio + +def get_data(first_coeff_path, audio_path, device): + + syncnet_mel_step_size = 16 + fps = 25 + + pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0] + audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] + + source_semantics_path = first_coeff_path + source_semantics_dict = scio.loadmat(source_semantics_path) + ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70 + + wav = audio.load_wav(audio_path, 16000) + wav_length, num_frames = parse_audio_length(len(wav), 16000, 25) + wav = crop_pad_audio(wav, wav_length) + orig_mel = audio.melspectrogram(wav).T + spec = orig_mel.copy() # nframes 80 + indiv_mels = [] + + for i in tqdm(range(num_frames), 'mel:'): + start_frame_num = i-2 + start_idx = int(80. * (start_frame_num / float(fps))) + end_idx = start_idx + syncnet_mel_step_size + seq = list(range(start_idx, end_idx)) + seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ] + m = spec[seq, :] + indiv_mels.append(m.T) + indiv_mels = np.asarray(indiv_mels) # T 80 16 + ratio = generate_blink_seq_randomly(num_frames) # T + + indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0) # bs T 1 80 16 + ratio = torch.FloatTensor(ratio).unsqueeze(0) # bs T + ref_coeff = torch.FloatTensor(ref_coeff).unsqueeze(0) # bs 1 70 + + indiv_mels = indiv_mels.to(device) + ratio = ratio.to(device) + ref_coeff = ref_coeff.to(device) + + return {'indiv_mels': indiv_mels, + 'ref': ref_coeff, + 'num_frames': num_frames, + 'ratio_gt': ratio, + 'audio_name': audio_name, 'pic_name': pic_name} + diff --git a/src/generate_facerender_batch.py b/src/generate_facerender_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..fc737cffc8e960828fb6e59ab1c22e7541a307f9 --- /dev/null +++ b/src/generate_facerender_batch.py @@ -0,0 +1,128 @@ +import os +import numpy as np +from PIL import Image +from skimage import io, img_as_float32, transform +import torch +import scipy.io as scio + +def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path, + batch_size, camera_yaw_list=[0], camera_pitch_list=[0], camera_roll_list=[0], + expression_scale=1.0, still_mode = False): + + semantic_radius = 13 + video_name = os.path.splitext(os.path.split(coeff_path)[-1])[0] + txt_path = os.path.splitext(coeff_path)[0] + + data={} + + img1 = Image.open(pic_path) + source_image = np.array(img1) + source_image = img_as_float32(source_image) + source_image = transform.resize(source_image, (256, 256, 3)) + source_image = source_image.transpose((2, 0, 1)) + source_image_ts = torch.FloatTensor(source_image).unsqueeze(0) + source_image_ts = source_image_ts.repeat(batch_size, 1, 1, 1) + data['source_image'] = source_image_ts + + source_semantics_dict = scio.loadmat(first_coeff_path) + source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70 + source_semantics_new = transform_semantic_1(source_semantics, semantic_radius) + source_semantics_ts = torch.FloatTensor(source_semantics_new).unsqueeze(0) + source_semantics_ts = source_semantics_ts.repeat(batch_size, 1, 1) + data['source_semantics'] = source_semantics_ts + + # target + generated_dict = scio.loadmat(coeff_path) + generated_3dmm = generated_dict['coeff_3dmm'] + generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale + + if still_mode: + generated_3dmm[:, 64:] = np.repeat(source_semantics[:, 64:], generated_3dmm.shape[0], axis=0) + + with open(txt_path+'.txt', 'w') as f: + for coeff in generated_3dmm: + for i in coeff: + f.write(str(i)[:7] + ' '+'\t') + f.write('\n') + + target_semantics_list = [] + frame_num = generated_3dmm.shape[0] + data['frame_num'] = frame_num + for frame_idx in range(frame_num): + target_semantics = transform_semantic_target(generated_3dmm, frame_idx, semantic_radius) + target_semantics_list.append(target_semantics) + + remainder = frame_num%batch_size + if remainder!=0: + for _ in range(batch_size-remainder): + target_semantics_list.append(target_semantics) + + target_semantics_np = np.array(target_semantics_list) #frame_num 70 semantic_radius*2+1 + target_semantics_np = target_semantics_np.reshape(batch_size, -1, target_semantics_np.shape[-2], target_semantics_np.shape[-1]) + data['target_semantics_list'] = torch.FloatTensor(target_semantics_np) + data['video_name'] = video_name + data['audio_path'] = audio_path + + yaw_c_seq = gen_camera_pose(camera_yaw_list, frame_num, batch_size) + pitch_c_seq = gen_camera_pose(camera_pitch_list, frame_num, batch_size) + roll_c_seq = gen_camera_pose(camera_roll_list, frame_num, batch_size) + + data['yaw_c_seq'] = torch.FloatTensor(yaw_c_seq) + data['pitch_c_seq'] = torch.FloatTensor(pitch_c_seq) + data['roll_c_seq'] = torch.FloatTensor(roll_c_seq) + return data + +def transform_semantic_1(semantic, semantic_radius): + semantic_list = [semantic for i in range(0, semantic_radius*2+1)] + coeff_3dmm = np.concatenate(semantic_list, 0) + return coeff_3dmm.transpose(1,0) + +def transform_semantic_target(coeff_3dmm, frame_index, semantic_radius): + num_frames = coeff_3dmm.shape[0] + seq = list(range(frame_index- semantic_radius, frame_index+ semantic_radius+1)) + index = [ min(max(item, 0), num_frames-1) for item in seq ] + coeff_3dmm_g = coeff_3dmm[index, :] + return coeff_3dmm_g.transpose(1,0) + +def gen_camera_pose(camera_degree_list, frame_num, batch_size): + + new_degree_list = [] + if len(camera_degree_list) == 1: + for _ in range(frame_num): + new_degree_list.append(camera_degree_list[0]) + remainder = frame_num%batch_size + if remainder!=0: + for _ in range(batch_size-remainder): + new_degree_list.append(new_degree_list[-1]) + new_degree_np = np.array(new_degree_list).reshape(batch_size, -1) + return new_degree_np + + degree_sum = 0. + for i, degree in enumerate(camera_degree_list[1:]): + degree_sum += abs(degree-camera_degree_list[i]) + + degree_per_frame = degree_sum/(frame_num-1) + for i, degree in enumerate(camera_degree_list[1:]): + degree_last = camera_degree_list[i] + degree_step = degree_per_frame * abs(degree-degree_last)/(degree-degree_last) + new_degree_list = new_degree_list + list(np.arange(degree_last, degree, degree_step)) + if len(new_degree_list) > frame_num: + new_degree_list = new_degree_list[:frame_num] + elif len(new_degree_list) < frame_num: + for _ in range(frame_num-len(new_degree_list)): + new_degree_list.append(new_degree_list[-1]) + print(len(new_degree_list)) + print(frame_num) + + remainder = frame_num%batch_size + if remainder!=0: + for _ in range(batch_size-remainder): + new_degree_list.append(new_degree_list[-1]) + new_degree_np = np.array(new_degree_list).reshape(batch_size, -1) + return new_degree_np + + + + + + diff --git a/src/gradio_demo.py b/src/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..4f78c97349652e23cf463c49527191fcec795564 --- /dev/null +++ b/src/gradio_demo.py @@ -0,0 +1,113 @@ +import torch, uuid +from time import gmtime, strftime +import os, sys, shutil +from src.utils.preprocess import CropAndExtract +from src.test_audio2coeff import Audio2Coeff +from src.facerender.animate import AnimateFromCoeff +from src.generate_batch import get_data +from src.generate_facerender_batch import get_facerender_data +from src.utils.text2speech import text2speech + +from pydub import AudioSegment + +def mp3_to_wav(mp3_filename,wav_filename,frame_rate): + mp3_file = AudioSegment.from_file(file=mp3_filename) + mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav") + + +class SadTalker(): + + def __init__(self, checkpoint_path='checkpoints', config_path='src/config'): + + if torch.cuda.is_available() : + device = "cuda" + else: + device = "cpu" + + os.environ['TORCH_HOME']= checkpoint_path + + path_of_lm_croper = os.path.join( checkpoint_path, 'shape_predictor_68_face_landmarks.dat') + path_of_net_recon_model = os.path.join( checkpoint_path, 'epoch_20.pth') + dir_of_BFM_fitting = os.path.join( checkpoint_path, 'BFM_Fitting') + wav2lip_checkpoint = os.path.join( checkpoint_path, 'wav2lip.pth') + + audio2pose_checkpoint = os.path.join( checkpoint_path, 'auido2pose_00140-model.pth') + audio2pose_yaml_path = os.path.join( config_path, 'auido2pose.yaml') + + audio2exp_checkpoint = os.path.join( checkpoint_path, 'auido2exp_00300-model.pth') + audio2exp_yaml_path = os.path.join( config_path, 'auido2exp.yaml') + + free_view_checkpoint = os.path.join( checkpoint_path, 'facevid2vid_00189-model.pth.tar') + mapping_checkpoint = os.path.join( checkpoint_path, 'mapping_00229-model.pth.tar') + facerender_yaml_path = os.path.join( config_path, 'facerender.yaml') + + #init model + print(path_of_lm_croper) + self.preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device) + + print(audio2pose_checkpoint) + self.audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path, + audio2exp_checkpoint, audio2exp_yaml_path, wav2lip_checkpoint, device) + print(free_view_checkpoint) + self.animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint, + facerender_yaml_path, device) + self.device = device + + def test(self, source_image, driven_audio, still_mode, use_enhancer, result_dir='./'): + + time_tag = str(uuid.uuid4()) + save_dir = os.path.join(result_dir, time_tag) + os.makedirs(save_dir, exist_ok=True) + + input_dir = os.path.join(save_dir, 'input') + os.makedirs(input_dir, exist_ok=True) + + print(source_image) + pic_path = os.path.join(input_dir, os.path.basename(source_image)) + shutil.move(source_image, input_dir) + + if os.path.isfile(driven_audio): + audio_path = os.path.join(input_dir, os.path.basename(driven_audio)) + + #### mp3 to wav + if '.mp3' in audio_path: + mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000) + audio_path = audio_path.replace('.mp3', '.wav') + else: + shutil.move(driven_audio, input_dir) + else: + text2speech + + + os.makedirs(save_dir, exist_ok=True) + pose_style = 0 + #crop image and extract 3dmm from image + first_frame_dir = os.path.join(save_dir, 'first_frame_dir') + os.makedirs(first_frame_dir, exist_ok=True) + first_coeff_path, crop_pic_path, original_size = self.preprocess_model.generate(pic_path, first_frame_dir) + + if first_coeff_path is None: + raise AttributeError("No face is detected") + + #audio2ceoff + batch = get_data(first_coeff_path, audio_path, self.device) # longer audio? + coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style) + #coeff2video + batch_size = 4 + data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode) + self.animate_from_coeff.generate(data, save_dir, enhancer='gfpgan' if use_enhancer else None, original_size=original_size) + video_name = data['video_name'] + print(f'The generated video is named {video_name} in {save_dir}') + + torch.cuda.empty_cache() + torch.cuda.synchronize() + import gc; gc.collect() + + if use_enhancer: + return os.path.join(save_dir, video_name+'_enhanced.mp4'), os.path.join(save_dir, video_name+'_enhanced.mp4') + + else: + return os.path.join(save_dir, video_name+'.mp4'), os.path.join(save_dir, video_name+'.mp4') + + + \ No newline at end of file diff --git a/src/test_audio2coeff.py b/src/test_audio2coeff.py new file mode 100644 index 0000000000000000000000000000000000000000..3db6be3af59b0319c50106d9a92c903118f28410 --- /dev/null +++ b/src/test_audio2coeff.py @@ -0,0 +1,87 @@ +import os +import torch +import numpy as np +from scipy.io import savemat +from yacs.config import CfgNode as CN +from scipy.signal import savgol_filter + +from src.audio2pose_models.audio2pose import Audio2Pose +from src.audio2exp_models.networks import SimpleWrapperV2 +from src.audio2exp_models.audio2exp import Audio2Exp + +def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"): + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + if model is not None: + model.load_state_dict(checkpoint['model']) + if optimizer is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + + return checkpoint['epoch'] + +class Audio2Coeff(): + + def __init__(self, audio2pose_checkpoint, audio2pose_yaml_path, + audio2exp_checkpoint, audio2exp_yaml_path, + wav2lip_checkpoint, device): + #load config + fcfg_pose = open(audio2pose_yaml_path) + cfg_pose = CN.load_cfg(fcfg_pose) + cfg_pose.freeze() + fcfg_exp = open(audio2exp_yaml_path) + cfg_exp = CN.load_cfg(fcfg_exp) + cfg_exp.freeze() + + # load audio2pose_model + self.audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint, device=device) + self.audio2pose_model = self.audio2pose_model.to(device) + self.audio2pose_model.eval() + for param in self.audio2pose_model.parameters(): + param.requires_grad = False + try: + load_cpk(audio2pose_checkpoint, model=self.audio2pose_model, device=device) + except: + raise Exception("Failed in loading audio2pose_checkpoint") + + # load audio2exp_model + netG = SimpleWrapperV2() + netG = netG.to(device) + for param in netG.parameters(): + netG.requires_grad = False + netG.eval() + try: + load_cpk(audio2exp_checkpoint, model=netG, device=device) + except: + raise Exception("Failed in loading audio2exp_checkpoint") + self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False) + self.audio2exp_model = self.audio2exp_model.to(device) + for param in self.audio2exp_model.parameters(): + param.requires_grad = False + self.audio2exp_model.eval() + + self.device = device + + def generate(self, batch, coeff_save_dir, pose_style): + + with torch.no_grad(): + #test + results_dict_exp= self.audio2exp_model.test(batch) + exp_pred = results_dict_exp['exp_coeff_pred'] #bs T 64 + + #for class_id in range(1): + #class_id = 0#(i+10)%45 + #class_id = random.randint(0,46) #46 styles can be selected + batch['class'] = torch.LongTensor([pose_style]).to(self.device) + results_dict_pose = self.audio2pose_model.test(batch) + pose_pred = results_dict_pose['pose_pred'] #bs T 6 + + pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), 13, 2, axis=1)).to(self.device) + coeffs_pred = torch.cat((exp_pred, pose_pred), dim=-1) #bs T 70 + + coeffs_pred_numpy = coeffs_pred[0].clone().detach().cpu().numpy() + + savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])), + {'coeff_3dmm': coeffs_pred_numpy}) + + return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])) + + diff --git a/src/utils/__pycache__/audio.cpython-38.pyc b/src/utils/__pycache__/audio.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9037ed6e9b29bf1f5ba29b25ed9c067103bb361 Binary files /dev/null and b/src/utils/__pycache__/audio.cpython-38.pyc differ diff --git a/src/utils/__pycache__/croper.cpython-38.pyc b/src/utils/__pycache__/croper.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..addfae662741dd661426427e2f29d506c399adba Binary files /dev/null and b/src/utils/__pycache__/croper.cpython-38.pyc differ diff --git a/src/utils/__pycache__/face_enhancer.cpython-38.pyc b/src/utils/__pycache__/face_enhancer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51b465795f49c49c741a7fb510d02564337deb28 Binary files /dev/null and b/src/utils/__pycache__/face_enhancer.cpython-38.pyc differ diff --git a/src/utils/__pycache__/hparams.cpython-38.pyc b/src/utils/__pycache__/hparams.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29278c1421204d040aa03f77ed43e18f9b60dad8 Binary files /dev/null and b/src/utils/__pycache__/hparams.cpython-38.pyc differ diff --git a/src/utils/__pycache__/preprocess.cpython-38.pyc b/src/utils/__pycache__/preprocess.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5e0b7f2a4c29050bfbb30405816311acd3060f0 Binary files /dev/null and b/src/utils/__pycache__/preprocess.cpython-38.pyc differ diff --git a/src/utils/audio.py b/src/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..89433eb4c681112804fbed72b157700f553739a8 --- /dev/null +++ b/src/utils/audio.py @@ -0,0 +1,136 @@ +import librosa +import librosa.filters +import numpy as np +# import tensorflow as tf +from scipy import signal +from scipy.io import wavfile +from src.utils.hparams import hparams as hp + +def load_wav(path, sr): + return librosa.core.load(path, sr=sr)[0] + +def save_wav(wav, path, sr): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + #proposed by @dsmiller + wavfile.write(path, sr, wav.astype(np.int16)) + +def save_wavenet_wav(wav, path, sr): + librosa.output.write_wav(path, wav, sr=sr) + +def preemphasis(wav, k, preemphasize=True): + if preemphasize: + return signal.lfilter([1, -k], [1], wav) + return wav + +def inv_preemphasis(wav, k, inv_preemphasize=True): + if inv_preemphasize: + return signal.lfilter([1], [1, -k], wav) + return wav + +def get_hop_size(): + hop_size = hp.hop_size + if hop_size is None: + assert hp.frame_shift_ms is not None + hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) + return hop_size + +def linearspectrogram(wav): + D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) + S = _amp_to_db(np.abs(D)) - hp.ref_level_db + + if hp.signal_normalization: + return _normalize(S) + return S + +def melspectrogram(wav): + D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) + S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db + + if hp.signal_normalization: + return _normalize(S) + return S + +def _lws_processor(): + import lws + return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") + +def _stft(y): + if hp.use_lws: + return _lws_processor(hp).stft(y).T + else: + return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) + +########################################################## +#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) +def num_frames(length, fsize, fshift): + """Compute number of time frames of spectrogram + """ + pad = (fsize - fshift) + if length % fshift == 0: + M = (length + pad * 2 - fsize) // fshift + 1 + else: + M = (length + pad * 2 - fsize) // fshift + 2 + return M + + +def pad_lr(x, fsize, fshift): + """Compute left and right padding + """ + M = num_frames(len(x), fsize, fshift) + pad = (fsize - fshift) + T = len(x) + 2 * pad + r = (M - 1) * fshift + fsize - T + return pad, pad + r +########################################################## +#Librosa correct padding +def librosa_pad_lr(x, fsize, fshift): + return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] + +# Conversions +_mel_basis = None + +def _linear_to_mel(spectogram): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis() + return np.dot(_mel_basis, spectogram) + +def _build_mel_basis(): + assert hp.fmax <= hp.sample_rate // 2 + return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, + fmin=hp.fmin, fmax=hp.fmax) + +def _amp_to_db(x): + min_level = np.exp(hp.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + +def _db_to_amp(x): + return np.power(10.0, (x) * 0.05) + +def _normalize(S): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, + -hp.max_abs_value, hp.max_abs_value) + else: + return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) + + assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 + if hp.symmetric_mels: + return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value + else: + return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) + +def _denormalize(D): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return (((np.clip(D, -hp.max_abs_value, + hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + + hp.min_level_db) + else: + return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) + + if hp.symmetric_mels: + return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) + else: + return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) diff --git a/src/utils/croper.py b/src/utils/croper.py new file mode 100644 index 0000000000000000000000000000000000000000..e68d280ee4bd83db2089c226af5d4be714fcca9d --- /dev/null +++ b/src/utils/croper.py @@ -0,0 +1,295 @@ +import os +import cv2 +import time +import glob +import argparse +import scipy +import numpy as np +from PIL import Image +from tqdm import tqdm +from itertools import cycle + +from torch.multiprocessing import Pool, Process, set_start_method + + +""" +brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset) +author: lzhbrian (https://lzhbrian.me) +date: 2020.1.5 +note: code is heavily borrowed from + https://github.com/NVlabs/ffhq-dataset + http://dlib.net/face_landmark_detection.py.html +requirements: + apt install cmake + conda install Pillow numpy scipy + pip install dlib + # download face landmark model from: + # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 +""" + +import numpy as np +from PIL import Image +import dlib + + +class Croper: + def __init__(self, path_of_lm): + # download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 + self.predictor = dlib.shape_predictor(path_of_lm) + + def get_landmark(self, img_np): + """get landmark with dlib + :return: np.array shape=(68, 2) + """ + detector = dlib.get_frontal_face_detector() + dets = detector(img_np, 1) + # print("Number of faces detected: {}".format(len(dets))) + # for k, d in enumerate(dets): + if len(dets) == 0: + return None + d = dets[0] + # Get the landmarks/parts for the face in box d. + shape = self.predictor(img_np, d) + # print("Part 0: {}, Part 1: {} ...".format(shape.part(0), shape.part(1))) + t = list(shape.parts()) + a = [] + for tt in t: + a.append([tt.x, tt.y]) + lm = np.array(a) + # lm is a shape=(68,2) np.array + return lm + + def align_face(self, img, lm, output_size=1024): + """ + :param filepath: str + :return: PIL Image + """ + lm_chin = lm[0: 17] # left-right + lm_eyebrow_left = lm[17: 22] # left-right + lm_eyebrow_right = lm[22: 27] # left-right + lm_nose = lm[27: 31] # top-down + lm_nostrils = lm[31: 36] # top-down + lm_eye_left = lm[36: 42] # left-clockwise + lm_eye_right = lm[42: 48] # left-clockwise + lm_mouth_outer = lm[48: 60] # left-clockwise + lm_mouth_inner = lm[60: 68] # left-clockwise + + # Calculate auxiliary vectors. + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] # 双眼差与双嘴差相加 + x /= np.hypot(*x) # hypot函数计算直角三角形的斜边长,用斜边长对三角形两条直边做归一化 + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) # 双眼差和眼嘴差,选较大的作为基准尺度 + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) # 定义四边形,以面部基准位置为中心上下左右平移得到四个顶点 + qsize = np.hypot(*x) * 2 # 定义四边形的大小(边长),为基准尺度的2倍 + + # Shrink. + # 如果计算出的四边形太大了,就按比例缩小它 + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) + img = img.resize(rsize, Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + + # Crop. + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), + min(crop[3] + border, img.size[1])) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + # img = img.crop(crop) + quad -= crop[0:2] + + # Pad. + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), + max(pad[3] - img.size[1] + border, 0)) + # if enable_padding and max(pad) > border - 4: + # pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + # img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + # h, w, _ = img.shape + # y, x, _ = np.ogrid[:h, :w, :1] + # mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), + # 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) + # blur = qsize * 0.02 + # img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + # img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + # img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') + # quad += pad[:2] + + # Transform. + quad = (quad + 0.5).flatten() + lx = max(min(quad[0], quad[2]), 0) + ly = max(min(quad[1], quad[7]), 0) + rx = min(max(quad[4], quad[6]), img.size[0]) + ry = min(max(quad[3], quad[5]), img.size[0]) + # img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), + # Image.BILINEAR) + # if output_size < transform_size: + # img = img.resize((output_size, output_size), Image.ANTIALIAS) + + # Save aligned image. + return crop, [lx, ly, rx, ry] + + # def crop(self, img_np_list): + # for _i in range(len(img_np_list)): + # img_np = img_np_list[_i] + # lm = self.get_landmark(img_np) + # if lm is None: + # return None + # crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=512) + # clx, cly, crx, cry = crop + # lx, ly, rx, ry = quad + # lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) + + # _inp = img_np_list[_i] + # _inp = _inp[cly:cry, clx:crx] + # _inp = _inp[ly:ry, lx:rx] + # img_np_list[_i] = _inp + # return img_np_list + + def crop(self, img_np_list, xsize=512): # first frame for all video + img_np = img_np_list[0] + lm = self.get_landmark(img_np) + if lm is None: + return None + crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize) + clx, cly, crx, cry = crop + lx, ly, rx, ry = quad + lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) + for _i in range(len(img_np_list)): + _inp = img_np_list[_i] + _inp = _inp[cly:cry, clx:crx] + # cv2.imwrite('test1.jpg', _inp) + _inp = _inp[ly:ry, lx:rx] + # cv2.imwrite('test2.jpg', _inp) + img_np_list[_i] = _inp + return img_np_list, crop, quad + + +def read_video(filename, uplimit=100): + frames = [] + cap = cv2.VideoCapture(filename) + cnt = 0 + while cap.isOpened(): + ret, frame = cap.read() + if ret: + frame = cv2.resize(frame, (512, 512)) + frames.append(frame) + else: + break + cnt += 1 + if cnt >= uplimit: + break + cap.release() + assert len(frames) > 0, f'{filename}: video with no frames!' + return frames + + +def create_video(video_name, frames, fps=25, video_format='.mp4', resize_ratio=1): + # video_name = os.path.dirname(image_folder) + video_format + # img_list = glob.glob1(image_folder, 'frame*') + # img_list.sort() + # frame = cv2.imread(os.path.join(image_folder, img_list[0])) + # frame = cv2.resize(frame, (0, 0), fx=resize_ratio, fy=resize_ratio) + # height, width, layers = frames[0].shape + height, width, layers = 512, 512, 3 + if video_format == '.mp4': + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + elif video_format == '.avi': + fourcc = cv2.VideoWriter_fourcc(*'XVID') + video = cv2.VideoWriter(video_name, fourcc, fps, (width, height)) + for _frame in frames: + _frame = cv2.resize(_frame, (height, width), interpolation=cv2.INTER_LINEAR) + video.write(_frame) + +def create_images(video_name, frames): + height, width, layers = 512, 512, 3 + images_dir = video_name.split('.')[0] + os.makedirs(images_dir, exist_ok=True) + for i, _frame in enumerate(frames): + _frame = cv2.resize(_frame, (height, width), interpolation=cv2.INTER_LINEAR) + _frame_path = os.path.join(images_dir, str(i)+'.jpg') + cv2.imwrite(_frame_path, _frame) + +def run(data): + filename, opt, device = data + os.environ['CUDA_VISIBLE_DEVICES'] = device + croper = Croper() + + frames = read_video(filename, uplimit=opt.uplimit) + name = filename.split('/')[-1] # .split('.')[0] + name = os.path.join(opt.output_dir, name) + + frames = croper.crop(frames) + if frames is None: + print(f'{name}: detect no face. should removed') + return + # create_video(name, frames) + create_images(name, frames) + + +def get_data_path(video_dir): + eg_video_files = ['/apdcephfs/share_1290939/quincheng/datasets/HDTF/backup_fps25/WDA_KatieHill_000.mp4'] + # filenames = list() + # VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} + # VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) + # extensions = VIDEO_EXTENSIONS + # for ext in extensions: + # filenames = sorted(glob.glob(f'{opt.input_dir}/**/*.{ext}')) + # print('Total number of videos:', len(filenames)) + return eg_video_files + + +def get_wra_data_path(video_dir): + if opt.option == 'video': + videos_path = sorted(glob.glob(f'{video_dir}/*.mp4')) + elif opt.option == 'image': + videos_path = sorted(glob.glob(f'{video_dir}/*/')) + else: + raise NotImplementedError + print('Example videos: ', videos_path[:2]) + return videos_path + + +if __name__ == '__main__': + set_start_method('spawn') + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--input_dir', type=str, help='the folder of the input files') + parser.add_argument('--output_dir', type=str, help='the folder of the output files') + parser.add_argument('--device_ids', type=str, default='0,1') + parser.add_argument('--workers', type=int, default=8) + parser.add_argument('--uplimit', type=int, default=500) + parser.add_argument('--option', type=str, default='video') + + root = '/apdcephfs/share_1290939/quincheng/datasets/HDTF' + cmd = f'--input_dir {root}/backup_fps25_first20s_sync/ ' \ + f'--output_dir {root}/crop512_stylegan_firstframe_sync/ ' \ + '--device_ids 0 ' \ + '--workers 8 ' \ + '--option video ' \ + '--uplimit 500 ' + opt = parser.parse_args(cmd.split()) + # filenames = get_data_path(opt.input_dir) + filenames = get_wra_data_path(opt.input_dir) + os.makedirs(opt.output_dir, exist_ok=True) + print(f'Video numbers: {len(filenames)}') + pool = Pool(opt.workers) + args_list = cycle([opt]) + device_ids = opt.device_ids.split(",") + device_ids = cycle(device_ids) + for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): + None diff --git a/src/utils/face_enhancer.py b/src/utils/face_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..6192649d7141f2cd05f1302f7c954bfb8fa612fa --- /dev/null +++ b/src/utils/face_enhancer.py @@ -0,0 +1,60 @@ +import os +from basicsr.utils import imwrite + +from gfpgan import GFPGANer + +from tqdm import tqdm + +def enhancer(images, method='gfpgan'): + + # ------------------------ set up GFPGAN restorer ------------------------ + if method == 'gfpgan': + arch = 'clean' + channel_multiplier = 2 + model_name = 'GFPGANv1.4' + url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth' + elif method == 'RestoreFormer': + arch = 'RestoreFormer' + channel_multiplier = 2 + model_name = 'RestoreFormer' + url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth' + elif method == 'codeformer': + arch = 'CodeFormer' + channel_multiplier = 2 + model_name = 'CodeFormer' + url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' + else: + raise ValueError(f'Wrong model version {method}.') + + # determine model paths + model_path = os.path.join('experiments/pretrained_models', model_name + '.pth') + + if not os.path.isfile(model_path): + model_path = os.path.join('checkpoints', model_name + '.pth') + + if not os.path.isfile(model_path): + # download pre-trained models from url + model_path = url + + restorer = GFPGANer( + model_path=model_path, + upscale=2, + arch=arch, + channel_multiplier=channel_multiplier, + bg_upsampler=None) + + # ------------------------ restore ------------------------ + restored_img = [] + for idx in tqdm(range(len(images)), 'Face Enhancer:'): + + # restore faces and background if necessary + cropped_faces, restored_faces, _ = restorer.enhance( + images[idx], + has_aligned=True, + only_center_face=False, + paste_back=True, + weight=0.5) + + restored_img += restored_faces + + return restored_img \ No newline at end of file diff --git a/src/utils/hparams.py b/src/utils/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..743c5c7d5a5a9e686f1ccd6fb3c2fb5cb382d62b --- /dev/null +++ b/src/utils/hparams.py @@ -0,0 +1,160 @@ +from glob import glob +import os + +class HParams: + def __init__(self, **kwargs): + self.data = {} + + for key, value in kwargs.items(): + self.data[key] = value + + def __getattr__(self, key): + if key not in self.data: + raise AttributeError("'HParams' object has no attribute %s" % key) + return self.data[key] + + def set_hparam(self, key, value): + self.data[key] = value + + +# Default hyperparameters +hparams = HParams( + num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality + # network + rescale=True, # Whether to rescale audio prior to preprocessing + rescaling_max=0.9, # Rescaling value + + # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction + # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder + # Does not work if n_ffit is not multiple of hop_size!! + use_lws=False, + + n_fft=800, # Extra window size is filled with 0 paddings to match this parameter + hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) + win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) + sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) + + frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) + + # Mel and Linear spectrograms normalization/scaling and clipping + signal_normalization=True, + # Whether to normalize mel spectrograms to some predefined range (following below parameters) + allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True + symmetric_mels=True, + # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, + # faster and cleaner convergence) + max_abs_value=4., + # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not + # be too big to avoid gradient explosion, + # not too small for fast convergence) + # Contribution by @begeekmyfriend + # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude + # levels. Also allows for better G&L phase reconstruction) + preemphasize=True, # whether to apply filter + preemphasis=0.97, # filter coefficient. + + # Limits + min_level_db=-100, + ref_level_db=20, + fmin=55, + # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To + # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) + fmax=7600, # To be increased/reduced depending on data. + + ###################### Our training parameters ################################# + img_size=96, + fps=25, + + batch_size=16, + initial_learning_rate=1e-4, + nepochs=300000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs + num_workers=20, + checkpoint_interval=3000, + eval_interval=3000, + writer_interval=300, + save_optimizer_state=True, + + syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. + syncnet_batch_size=64, + syncnet_lr=1e-4, + syncnet_eval_interval=1000, + syncnet_checkpoint_interval=10000, + + disc_wt=0.07, + disc_initial_learning_rate=1e-4, +) + + + +# Default hyperparameters +hparamsdebug = HParams( + num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality + # network + rescale=True, # Whether to rescale audio prior to preprocessing + rescaling_max=0.9, # Rescaling value + + # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction + # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder + # Does not work if n_ffit is not multiple of hop_size!! + use_lws=False, + + n_fft=800, # Extra window size is filled with 0 paddings to match this parameter + hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) + win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) + sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) + + frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) + + # Mel and Linear spectrograms normalization/scaling and clipping + signal_normalization=True, + # Whether to normalize mel spectrograms to some predefined range (following below parameters) + allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True + symmetric_mels=True, + # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, + # faster and cleaner convergence) + max_abs_value=4., + # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not + # be too big to avoid gradient explosion, + # not too small for fast convergence) + # Contribution by @begeekmyfriend + # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude + # levels. Also allows for better G&L phase reconstruction) + preemphasize=True, # whether to apply filter + preemphasis=0.97, # filter coefficient. + + # Limits + min_level_db=-100, + ref_level_db=20, + fmin=55, + # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To + # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) + fmax=7600, # To be increased/reduced depending on data. + + ###################### Our training parameters ################################# + img_size=96, + fps=25, + + batch_size=2, + initial_learning_rate=1e-3, + nepochs=100000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs + num_workers=0, + checkpoint_interval=10000, + eval_interval=10, + writer_interval=5, + save_optimizer_state=True, + + syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. + syncnet_batch_size=64, + syncnet_lr=1e-4, + syncnet_eval_interval=10000, + syncnet_checkpoint_interval=10000, + + disc_wt=0.07, + disc_initial_learning_rate=1e-4, +) + + +def hparams_debug_string(): + values = hparams.values() + hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] + return "Hyperparameters:\n" + "\n".join(hp) diff --git a/src/utils/preprocess.py b/src/utils/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..4e3dad8d4a49080a3300f672965a11a8a2054fa2 --- /dev/null +++ b/src/utils/preprocess.py @@ -0,0 +1,152 @@ +import numpy as np +import cv2, os, sys, torch +from tqdm import tqdm +from PIL import Image + +# 3dmm extraction +from src.face3d.util.preprocess import align_img +from src.face3d.util.load_mats import load_lm3d +from src.face3d.models import networks +from src.face3d.extract_kp_videos import KeypointExtractor + +from scipy.io import loadmat, savemat +from src.utils.croper import Croper + +import warnings +warnings.filterwarnings("ignore") + +def split_coeff(coeffs): + """ + Return: + coeffs_dict -- a dict of torch.tensors + + Parameters: + coeffs -- torch.tensor, size (B, 256) + """ + id_coeffs = coeffs[:, :80] + exp_coeffs = coeffs[:, 80: 144] + tex_coeffs = coeffs[:, 144: 224] + angles = coeffs[:, 224: 227] + gammas = coeffs[:, 227: 254] + translations = coeffs[:, 254:] + return { + 'id': id_coeffs, + 'exp': exp_coeffs, + 'tex': tex_coeffs, + 'angle': angles, + 'gamma': gammas, + 'trans': translations + } + + +class CropAndExtract(): + def __init__(self, path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device): + + self.croper = Croper(path_of_lm_croper) + self.kp_extractor = KeypointExtractor(device) + self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device) + checkpoint = torch.load(path_of_net_recon_model, map_location=torch.device(device)) + self.net_recon.load_state_dict(checkpoint['net_recon']) + self.net_recon.eval() + self.lm3d_std = load_lm3d(dir_of_BFM_fitting) + self.device = device + + def generate(self, input_path, save_dir, crop_or_resize='crop'): + + pic_size = 256 + pic_name = os.path.splitext(os.path.split(input_path)[-1])[0] + + landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt') + coeff_path = os.path.join(save_dir, pic_name+'.mat') + png_path = os.path.join(save_dir, pic_name+'.png') + + #load input + if not os.path.isfile(input_path): + raise ValueError('input_path must be a valid path to video/image file') + elif input_path.split('.')[1] in ['jpg', 'png', 'jpeg']: + # loader for first frame + full_frames = [cv2.imread(input_path)] + fps = 25 + else: + # loader for videos + video_stream = cv2.VideoCapture(input_path) + fps = video_stream.get(cv2.CAP_PROP_FPS) + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + full_frames.append(frame) + break + x_full_frames = [cv2.cvtColor(full_frames[0], cv2.COLOR_BGR2RGB) ] + + if crop_or_resize.lower() == 'crop': # default crop + x_full_frames, crop, quad = self.croper.crop(x_full_frames, xsize=pic_size) + clx, cly, crx, cry = crop + lx, ly, rx, ry = quad + lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) + oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx + original_size = (ox2 - ox1, oy2 - oy1) + else: + oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1] + original_size = (ox2 - ox1, oy2 - oy1) + + frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames] + if len(frames_pil) == 0: + print('No face is detected in the input file') + return None, None + + # save crop info + for frame in frames_pil: + cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)) + + # 2. get the landmark according to the detected face. + if not os.path.isfile(landmarks_path): + lm = self.kp_extractor.extract_keypoint(frames_pil, landmarks_path) + else: + print(' Using saved landmarks.') + lm = np.loadtxt(landmarks_path).astype(np.float32) + lm = lm.reshape([len(x_full_frames), -1, 2]) + + if not os.path.isfile(coeff_path): + # load 3dmm paramter generator from Deep3DFaceRecon_pytorch + video_coeffs, full_coeffs = [], [] + for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'): + frame = frames_pil[idx] + W,H = frame.size + lm1 = lm[idx].reshape([-1, 2]) + + if np.mean(lm1) == -1: + lm1 = (self.lm3d_std[:, :2]+1)/2. + lm1 = np.concatenate( + [lm1[:, :1]*W, lm1[:, 1:2]*H], 1 + ) + else: + lm1[:, -1] = H - 1 - lm1[:, -1] + + trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std) + + trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32) + im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0) + + with torch.no_grad(): + full_coeff = self.net_recon(im_t) + coeffs = split_coeff(full_coeff) + + pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs} + + pred_coeff = np.concatenate([ + pred_coeff['exp'], + pred_coeff['angle'], + pred_coeff['trans'], + trans_params[2:][None], + ], 1) + video_coeffs.append(pred_coeff) + full_coeffs.append(full_coeff.cpu().numpy()) + + semantic_npy = np.array(video_coeffs)[:,0] + + savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0]}) + + return coeff_path, png_path, original_size \ No newline at end of file diff --git a/src/utils/text2speech.py b/src/utils/text2speech.py new file mode 100644 index 0000000000000000000000000000000000000000..3ecaef36961494c8b2b1f5771a70b997efa04ffd --- /dev/null +++ b/src/utils/text2speech.py @@ -0,0 +1,12 @@ +import os + +def text2speech(txt, audio_path): + print(txt) + cmd = f'tts --text "{txt}" --out_path {audio_path}' + print(cmd) + try: + os.system(cmd) + return audio_path + except: + print("Error: Failed convert txt to audio") + return None \ No newline at end of file