update repo
This view is limited to 50 files because it contains too many changes.

- .DS_Store +0 -0
- .gitattributes +2 -0
- LICENSE.txt +97 -0
- dnnlib/__init__.py +9 -0
- dnnlib/__pycache__/__init__.cpython-38.pyc +0 -0
- dnnlib/__pycache__/util.cpython-38.pyc +0 -0
- dnnlib/util.py +473 -0
- encoder4editing/LICENSE +21 -0
- encoder4editing/configs/__init__.py +0 -0
- encoder4editing/configs/data_configs.py +41 -0
- encoder4editing/configs/paths_config.py +28 -0
- encoder4editing/configs/transforms_config.py +62 -0
- encoder4editing/criteria/__init__.py +0 -0
- encoder4editing/criteria/id_loss.py +47 -0
- encoder4editing/criteria/lpips/__init__.py +0 -0
- encoder4editing/criteria/lpips/lpips.py +35 -0
- encoder4editing/criteria/lpips/networks.py +96 -0
- encoder4editing/criteria/lpips/utils.py +30 -0
- encoder4editing/criteria/moco_loss.py +71 -0
- encoder4editing/criteria/w_norm.py +14 -0
- encoder4editing/datasets/__init__.py +0 -0
- encoder4editing/datasets/gt_res_dataset.py +32 -0
- encoder4editing/datasets/images_dataset.py +33 -0
- encoder4editing/datasets/inference_dataset.py +25 -0
- encoder4editing/editings/ganspace.py +22 -0
- encoder4editing/editings/ganspace_pca/cars_pca.pt +3 -0
- encoder4editing/editings/ganspace_pca/ffhq_pca.pt +3 -0
- encoder4editing/editings/interfacegan_directions/age.pt +3 -0
- encoder4editing/editings/interfacegan_directions/pose.pt +3 -0
- encoder4editing/editings/interfacegan_directions/smile.pt +3 -0
- encoder4editing/editings/latent_editor.py +45 -0
- encoder4editing/editings/sefa.py +46 -0
- encoder4editing/environment/e4e_env.yaml +73 -0
- encoder4editing/infer.py +134 -0
- encoder4editing/metrics/LEC.py +134 -0
- encoder4editing/models/__init__.py +0 -0
- encoder4editing/models/discriminator.py +20 -0
- encoder4editing/models/encoders/__init__.py +0 -0
- encoder4editing/models/encoders/helpers.py +140 -0
- encoder4editing/models/encoders/model_irse.py +84 -0
- encoder4editing/models/encoders/psp_encoders.py +235 -0
- encoder4editing/models/latent_codes_pool.py +55 -0
- encoder4editing/models/psp.py +100 -0
- encoder4editing/models/stylegan2/__init__.py +0 -0
- encoder4editing/models/stylegan2/model.py +673 -0
- encoder4editing/models/stylegan2/op/__init__.py +2 -0
- encoder4editing/models/stylegan2/op/fused_act.py +85 -0
- encoder4editing/models/stylegan2/op/fused_bias_act.cpp +21 -0
- encoder4editing/models/stylegan2/op/fused_bias_act_kernel.cu +99 -0
- encoder4editing/models/stylegan2/op/upfirdn2d.cpp +23 -0
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ

.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pth* filter=lfs diff=lfs merge=lfs -text
+filter=lfs diff=lfs merge=lfs -text

LICENSE.txt
ADDED
@@ -0,0 +1,97 @@
Copyright (c) 2021, NVIDIA Corporation. All rights reserved.


NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)


=======================================================================

1. Definitions

"Licensor" means any person or entity that distributes its Work.

"Software" means the original work of authorship made available under
this License.

"Work" means the Software and any additions to or derivative works of
the Software that are made available under this License.

The terms "reproduce," "reproduction," "derivative works," and
"distribution" have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative
works shall not include works that remain separable from, or merely
link (or bind by name) to the interfaces of, the Work.

Works, including the Software, are "made available" under this License
by including in or with the Work either (a) a copyright notice
referencing the applicability of this License to the Work, or (b) a
copy of this License.

2. License Grants

2.1 Copyright Grant. Subject to the terms and conditions of this
License, each Licensor grants to you a perpetual, worldwide,
non-exclusive, royalty-free, copyright license to reproduce,
prepare derivative works of, publicly display, publicly perform,
sublicense and distribute its Work and any resulting derivative
works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only
if (a) you do so under this License, (b) you include a complete
copy of this License with your distribution, and (c) you retain
without modification any copyright, patent, trademark, or
attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different
terms apply to the use, reproduction, and distribution of your
derivative works of the Work ("Your Terms") only if (a) Your Terms
provide that the use limitation in Section 3.3 applies to your
derivative works, and (b) you identify the specific derivative
works that are subject to Your Terms. Notwithstanding Your Terms,
this License (including the redistribution requirements in Section
3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only
may be used or intended for use non-commercially. Notwithstanding
the foregoing, NVIDIA and its affiliates may use the Work and any
derivative works commercially. As used herein, "non-commercially"
means for research or evaluation purposes only.

3.4 Patent Claims. If you bring or threaten to bring a patent claim
against any Licensor (including any claim, cross-claim or
counterclaim in a lawsuit) to enforce any patents that you allege
are infringed by any Work, then your rights under this License from
such Licensor (including the grant in Section 2.1) will terminate
immediately.

3.5 Trademarks. This License does not grant any rights to use any
Licensor's or its affiliates' names, logos, or trademarks, except
as necessary to reproduce the notices described in this License.

3.6 Termination. If you violate any term of this License, then your
rights under this License (including the grant in Section 2.1) will
terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
THE POSSIBILITY OF SUCH DAMAGES.

=======================================================================

dnnlib/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from .util import EasyDict, make_cache_dir_path

dnnlib/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (206 Bytes).

dnnlib/__pycache__/util.cpython-38.pyc
ADDED
Binary file (13.7 kB).

dnnlib/util.py
ADDED
@@ -0,0 +1,473 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Miscellaneous utility classes and functions."""

import ctypes
import fnmatch
import importlib
import inspect
import numpy as np
import os
import shutil
import sys
import types
import io
import pickle
import re
import requests
import html
import hashlib
import glob
import tempfile
import urllib
import urllib.request
import uuid

from distutils.util import strtobool
from typing import Any, List, Tuple, Union


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0:  # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None


# Cache directories
# ------------------------------------------------------------------------------------------

_dnnlib_cache_dir = None

def set_cache_dir(path: str) -> None:
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path

def make_cache_dir_path(*paths: str) -> str:
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    if 'DNNLIB_CACHE_DIR' in os.environ:
        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
    if 'HOME' in os.environ:
        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
    if 'USERPROFILE' in os.environ:
        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)

# Small util functions
# ------------------------------------------------------------------------------------------


def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    s = int(np.rint(seconds))

    if s < 60:
        return "{0}s".format(s)
    elif s < 60 * 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    elif s < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
    else:
        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)


def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer."""
    while True:
        try:
            print("{0} [y/n]".format(question))
            return strtobool(input().lower())
        except ValueError:
            pass


def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements."""
    result = 1

    for v in t:
        result *= v

    return result


_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
    type_str = None

    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()

    my_dtype = np.dtype(type_str)
    my_ctype = _str_to_ctype[type_str]

    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)

    return my_dtype, my_ctype


def is_pickleable(obj: Any) -> bool:
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except:
        return False


# Functionality to import modules/objects by name, and call functions by name
# ------------------------------------------------------------------------------------------

def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed)."""

    # allow convenience shorthands, substitute them by full names
    obj_name = re.sub("^np.", "numpy.", obj_name)
    obj_name = re.sub("^tf.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name)
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
            return module, local_obj_name
        except:
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name)  # may raise ImportError
        except ImportError:
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj


def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    module, obj_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, obj_name)


def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func_obj = get_obj_by_name(func_name)
    assert callable(func_obj)
    return func_obj(*args, **kwargs)


def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    return call_func_by_name(*args, func_name=class_name, **kwargs)


def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))


def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__


def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    module = obj.__module__
    if module == '__main__':
        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
    return module + "." + obj.__name__


# File system helpers
# ------------------------------------------------------------------------------------------

def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths."""
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))

    if ignores is None:
        ignores = []

    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for ignore_ in ignores:
            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]

            # dirs need to be edited in-place
            for d in dirs_to_remove:
                dirs.remove(d)

            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]

        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result


def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories."""
    for file in files:
        target_dir_name = os.path.dirname(file[1])

        # will create all intermediate-level directories
        if not os.path.exists(target_dir_name):
            os.makedirs(target_dir_name)

        shutil.copyfile(file[0], file[1])


# URL helpers
# ------------------------------------------------------------------------------------------

def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string."""
    if not isinstance(obj, str) or not "://" in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
    except:
        return False
    return True


def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data."""
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname() but
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file)  # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)

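A brief usage sketch of the helpers that dnnlib/__init__.py re-exports (EasyDict and make_cache_dir_path) together with open_url; the URL below is a placeholder, not something shipped in this commit:

# Hypothetical usage sketch -- not part of the commit.
import pickle
import dnnlib
from dnnlib.util import open_url

opts = dnnlib.EasyDict(batch_size=4, device='cuda')   # attribute-style access to dict keys
print(opts.batch_size, opts['device'])

cache_root = dnnlib.make_cache_dir_path('downloads')  # ~/.cache/dnnlib/downloads by default

# open_url transparently handles local paths, file:// URLs, and http(s) downloads with caching.
with open_url('https://example.com/network.pkl') as f:  # placeholder URL
    data = pickle.load(f)
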
encoder4editing/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 omertov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

encoder4editing/configs/__init__.py
ADDED
File without changes

encoder4editing/configs/data_configs.py
ADDED
@@ -0,0 +1,41 @@
from configs import transforms_config
from configs.paths_config import dataset_paths


DATASETS = {
    'ffhq_encode': {
        'transforms': transforms_config.EncodeTransforms,
        'train_source_root': dataset_paths['ffhq'],
        'train_target_root': dataset_paths['ffhq'],
        'test_source_root': dataset_paths['celeba_test'],
        'test_target_root': dataset_paths['celeba_test'],
    },
    'cars_encode': {
        'transforms': transforms_config.CarsEncodeTransforms,
        'train_source_root': dataset_paths['cars_train'],
        'train_target_root': dataset_paths['cars_train'],
        'test_source_root': dataset_paths['cars_test'],
        'test_target_root': dataset_paths['cars_test'],
    },
    'horse_encode': {
        'transforms': transforms_config.EncodeTransforms,
        'train_source_root': dataset_paths['horse_train'],
        'train_target_root': dataset_paths['horse_train'],
        'test_source_root': dataset_paths['horse_test'],
        'test_target_root': dataset_paths['horse_test'],
    },
    'church_encode': {
        'transforms': transforms_config.EncodeTransforms,
        'train_source_root': dataset_paths['church_train'],
        'train_target_root': dataset_paths['church_train'],
        'test_source_root': dataset_paths['church_test'],
        'test_target_root': dataset_paths['church_test'],
    },
    'cats_encode': {
        'transforms': transforms_config.EncodeTransforms,
        'train_source_root': dataset_paths['cats_train'],
        'train_target_root': dataset_paths['cats_train'],
        'test_source_root': dataset_paths['cats_test'],
        'test_target_root': dataset_paths['cats_test'],
    }
}

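A minimal sketch of how these entries are typically consumed; the helper name, the argparse-style opts object, and the dataset_type value are assumptions for illustration, not part of this commit:

# Hypothetical usage sketch -- not part of the commit.
from configs.data_configs import DATASETS

def load_dataset_config(dataset_type, opts):
    # Look up the per-dataset configuration, e.g. dataset_type = 'ffhq_encode'.
    if dataset_type not in DATASETS:
        raise KeyError('{} is not a valid dataset_type'.format(dataset_type))
    args = DATASETS[dataset_type]
    # Instantiate the transforms class referenced by the config.
    transforms_dict = args['transforms'](opts).get_transforms()
    return args, transforms_dict
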
encoder4editing/configs/paths_config.py
ADDED
@@ -0,0 +1,28 @@
dataset_paths = {
    # Face Datasets (In the paper: FFHQ - train, CelebAHQ - test)
    'ffhq': '',
    'celeba_test': '',

    # Cars Dataset (In the paper: Stanford cars)
    'cars_train': '',
    'cars_test': '',

    # Horse Dataset (In the paper: LSUN Horse)
    'horse_train': '',
    'horse_test': '',

    # Church Dataset (In the paper: LSUN Church)
    'church_train': '',
    'church_test': '',

    # Cats Dataset (In the paper: LSUN Cat)
    'cats_train': '',
    'cats_test': ''
}

model_paths = {
    'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
    'ir_se50': 'pretrained_models/model_ir_se50.pth',
    'shape_predictor': 'pretrained_models/shape_predictor_68_face_landmarks.dat',
    'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth'
}

encoder4editing/configs/transforms_config.py
ADDED
@@ -0,0 +1,62 @@
from abc import abstractmethod
import torchvision.transforms as transforms


class TransformsConfig(object):

    def __init__(self, opts):
        self.opts = opts

    @abstractmethod
    def get_transforms(self):
        pass


class EncodeTransforms(TransformsConfig):

    def __init__(self, opts):
        super(EncodeTransforms, self).__init__(opts)

    def get_transforms(self):
        transforms_dict = {
            'transform_gt_train': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_source': None,
            'transform_test': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_inference': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        }
        return transforms_dict


class CarsEncodeTransforms(TransformsConfig):

    def __init__(self, opts):
        super(CarsEncodeTransforms, self).__init__(opts)

    def get_transforms(self):
        transforms_dict = {
            'transform_gt_train': transforms.Compose([
                transforms.Resize((192, 256)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_source': None,
            'transform_test': transforms.Compose([
                transforms.Resize((192, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_inference': transforms.Compose([
                transforms.Resize((192, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        }
        return transforms_dict

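A small sketch of applying the 'transform_inference' entry to a single image before feeding it to the encoder, assuming a face image on disk; the file path is a placeholder:

# Hypothetical usage sketch -- not part of the commit.
from PIL import Image
from configs.transforms_config import EncodeTransforms

transforms_dict = EncodeTransforms(opts=None).get_transforms()
img = Image.open('face.jpg').convert('RGB')            # placeholder path
tensor = transforms_dict['transform_inference'](img)   # 3x256x256, normalized to [-1, 1]
batch = tensor.unsqueeze(0)                            # add a batch dimension for the encoder
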
encoder4editing/criteria/__init__.py
ADDED
File without changes

encoder4editing/criteria/id_loss.py
ADDED
@@ -0,0 +1,47 @@
import torch
from torch import nn
from configs.paths_config import model_paths
from models.encoders.model_irse import Backbone


class IDLoss(nn.Module):
    def __init__(self):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        for module in [self.facenet, self.face_pool]:
            for param in module.parameters():
                param.requires_grad = False

    def extract_feats(self, x):
        x = x[:, :, 35:223, 32:220]  # Crop interesting region
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y_hat, y, x):
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)  # Otherwise use the feature from there
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        id_logs = []
        count = 0
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            diff_input = y_hat_feats[i].dot(x_feats[i])
            diff_views = y_feats[i].dot(x_feats[i])
            id_logs.append({'diff_target': float(diff_target),
                            'diff_input': float(diff_input),
                            'diff_views': float(diff_views)})
            loss += 1 - diff_target
            id_diff = float(diff_target) - float(diff_views)
            sim_improvement += id_diff
            count += 1

        return loss / count, sim_improvement / count, id_logs
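A hedged sketch of how IDLoss is typically invoked during training; the tensors below are placeholders for the source x, the target y, and the reconstruction y_hat, all assumed to be 256x256 images in [-1, 1], and the ir_se50 weights referenced in model_paths must already be on disk:

# Hypothetical usage sketch -- not part of the commit.
import torch
from criteria.id_loss import IDLoss

id_loss = IDLoss().eval()               # loads the ir_se50 ArcFace backbone from model_paths
x = torch.randn(4, 3, 256, 256)         # source batch (placeholder data)
y = x.clone()                           # target batch (here identical to the source)
y_hat = torch.randn(4, 3, 256, 256)     # generator reconstruction (placeholder data)
loss, sim_improvement, id_logs = id_loss(y_hat, y, x)
# loss is the mean (1 - cosine similarity) between reconstruction and target identity features
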
encoder4editing/criteria/lpips/__init__.py
ADDED
File without changes

encoder4editing/criteria/lpips/lpips.py
ADDED
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn

from criteria.lpips.networks import get_network, LinLayers
from criteria.lpips.utils import get_state_dict


class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).
    Arguments:
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """
    def __init__(self, net_type: str = 'alex', version: str = '0.1'):

        assert version in ['0.1'], 'v0.1 is only supported now'

        super(LPIPS, self).__init__()

        # pretrained network
        self.net = get_network(net_type).to("cuda")

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        feat_x, feat_y = self.net(x), self.net(y)

        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

        return torch.sum(torch.cat(res, 0)) / x.shape[0]
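A brief usage sketch, assuming a CUDA device is available (the constructor above moves both sub-networks to "cuda") and using random tensors as placeholder images:

# Hypothetical usage sketch -- not part of the commit.
import torch
from criteria.lpips.lpips import LPIPS

lpips_loss = LPIPS(net_type='alex')                         # fetches the v0.1 linear weights on first use
y = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1       # placeholder "real" batch in [-1, 1]
y_hat = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1   # placeholder reconstruction
loss = lpips_loss(y_hat, y)                                 # scalar loss, averaged over the batch
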
encoder4editing/criteria/lpips/networks.py
ADDED
@@ -0,0 +1,96 @@
from typing import Sequence

from itertools import chain

import torch
import torch.nn as nn
from torchvision import models

from criteria.lpips.utils import normalize_activation


def get_network(net_type: str):
    if net_type == 'alex':
        return AlexNet()
    elif net_type == 'squeeze':
        return SqueezeNet()
    elif net_type == 'vgg':
        return VGG16()
    else:
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')


class LinLayers(nn.ModuleList):
    def __init__(self, n_channels_list: Sequence[int]):
        super(LinLayers, self).__init__([
            nn.Sequential(
                nn.Identity(),
                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
            ) for nc in n_channels_list
        ])

        for param in self.parameters():
            param.requires_grad = False


class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

        # register buffer
        self.register_buffer(
            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        x = self.z_score(x)

        output = []
        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output


class SqueezeNet(BaseNet):
    def __init__(self):
        super(SqueezeNet, self).__init__()

        self.layers = models.squeezenet1_1(True).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.layers = models.alexnet(True).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    def __init__(self):
        super(VGG16, self).__init__()

        self.layers = models.vgg16(True).features
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)

encoder4editing/criteria/lpips/utils.py
ADDED
@@ -0,0 +1,30 @@
from collections import OrderedDict

import torch


def normalize_activation(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    # build url
    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
        + f'master/lpips/weights/v{version}/{net_type}.pth'

    # download
    old_state_dict = torch.hub.load_state_dict_from_url(
        url, progress=True,
        map_location=None if torch.cuda.is_available() else torch.device('cpu')
    )

    # rename keys
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace('lin', '')
        new_key = new_key.replace('model.', '')
        new_state_dict[new_key] = val

    return new_state_dict

encoder4editing/criteria/moco_loss.py
ADDED
@@ -0,0 +1,71 @@
import torch
from torch import nn
import torch.nn.functional as F

from configs.paths_config import model_paths


class MocoLoss(nn.Module):

    def __init__(self, opts):
        super(MocoLoss, self).__init__()
        print("Loading MOCO model from path: {}".format(model_paths["moco"]))
        self.model = self.__load_model()
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def __load_model():
        import torchvision.models as models
        model = models.__dict__["resnet50"]()
        # freeze all layers but the last fc
        for name, param in model.named_parameters():
            if name not in ['fc.weight', 'fc.bias']:
                param.requires_grad = False
        checkpoint = torch.load(model_paths['moco'], map_location="cpu")
        state_dict = checkpoint['state_dict']
        # rename moco pre-trained keys
        for k in list(state_dict.keys()):
            # retain only encoder_q up to before the embedding layer
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                # remove prefix
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        # remove output layer
        model = nn.Sequential(*list(model.children())[:-1]).cuda()
        return model

    def extract_feats(self, x):
        x = F.interpolate(x, size=224)
        x_feats = self.model(x)
        x_feats = nn.functional.normalize(x_feats, dim=1)
        x_feats = x_feats.squeeze()
        return x_feats

    def forward(self, y_hat, y, x):
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        sim_logs = []
        count = 0
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            diff_input = y_hat_feats[i].dot(x_feats[i])
            diff_views = y_feats[i].dot(x_feats[i])
            sim_logs.append({'diff_target': float(diff_target),
                             'diff_input': float(diff_input),
                             'diff_views': float(diff_views)})
            loss += 1 - diff_target
            sim_diff = float(diff_target) - float(diff_views)
            sim_improvement += sim_diff
            count += 1

        return loss / count, sim_improvement / count, sim_logs

encoder4editing/criteria/w_norm.py
ADDED
@@ -0,0 +1,14 @@
import torch
from torch import nn


class WNormLoss(nn.Module):

    def __init__(self, start_from_latent_avg=True):
        super(WNormLoss, self).__init__()
        self.start_from_latent_avg = start_from_latent_avg

    def forward(self, latent, latent_avg=None):
        if self.start_from_latent_avg:
            latent = latent - latent_avg
        return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]

encoder4editing/datasets/__init__.py
ADDED
File without changes

encoder4editing/datasets/gt_res_dataset.py
ADDED
@@ -0,0 +1,32 @@
#!/usr/bin/python
# encoding: utf-8
import os
from torch.utils.data import Dataset
from PIL import Image
import torch

class GTResDataset(Dataset):

    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
        self.pairs = []
        for f in os.listdir(root_path):
            image_path = os.path.join(root_path, f)
            gt_path = os.path.join(gt_dir, f)
            if f.endswith(".jpg") or f.endswith(".png"):
                self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
        self.transform = transform
        self.transform_train = transform_train

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        from_path, to_path, _ = self.pairs[index]
        from_im = Image.open(from_path).convert('RGB')
        to_im = Image.open(to_path).convert('RGB')

        if self.transform:
            to_im = self.transform(to_im)
            from_im = self.transform(from_im)

        return from_im, to_im

encoder4editing/datasets/images_dataset.py
ADDED
@@ -0,0 +1,33 @@
from torch.utils.data import Dataset
from PIL import Image
from utils import data_utils


class ImagesDataset(Dataset):

    def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
        self.source_paths = sorted(data_utils.make_dataset(source_root))
        self.target_paths = sorted(data_utils.make_dataset(target_root))
        self.source_transform = source_transform
        self.target_transform = target_transform
        self.opts = opts

    def __len__(self):
        return len(self.source_paths)

    def __getitem__(self, index):
        from_path = self.source_paths[index]
        from_im = Image.open(from_path)
        from_im = from_im.convert('RGB')

        to_path = self.target_paths[index]
        to_im = Image.open(to_path).convert('RGB')
        if self.target_transform:
            to_im = self.target_transform(to_im)

        if self.source_transform:
            from_im = self.source_transform(from_im)
        else:
            from_im = to_im

        return from_im, to_im

encoder4editing/datasets/inference_dataset.py
ADDED
@@ -0,0 +1,25 @@
+from torch.utils.data import Dataset
+from PIL import Image
+from utils import data_utils
+
+
+class InferenceDataset(Dataset):
+
+    def __init__(self, root, opts, transform=None, preprocess=None):
+        self.paths = sorted(data_utils.make_dataset(root))
+        self.transform = transform
+        self.preprocess = preprocess
+        self.opts = opts
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        from_path = self.paths[index]
+        if self.preprocess is not None:
+            from_im = self.preprocess(from_path)
+        else:
+            from_im = Image.open(from_path).convert('RGB')
+        if self.transform:
+            from_im = self.transform(from_im)
+        return from_im
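
As a rough usage sketch, `InferenceDataset` drops straight into a standard PyTorch `DataLoader`; the directory name and batch size below are placeholders, the transform mirrors the one set up in `infer.py` further down, and `data_utils.make_dataset` is the repo helper imported above:

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets.inference_dataset import InferenceDataset

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

dataset = InferenceDataset(root='my_images/', opts=None, transform=transform)  # placeholder folder of .jpg/.png files
loader = DataLoader(dataset, batch_size=4, shuffle=False)

for batch in loader:
    print(batch.shape)  # (4, 3, 256, 256)
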
encoder4editing/editings/ganspace.py
ADDED
@@ -0,0 +1,22 @@
+import torch
+
+
+def edit(latents, pca, edit_directions):
+    edit_latents = []
+    for latent in latents:
+        for pca_idx, start, end, strength in edit_directions:
+            delta = get_delta(pca, latent, pca_idx, strength)
+            delta_padded = torch.zeros(latent.shape).to('cuda')
+            delta_padded[start:end] += delta.repeat(end - start, 1)
+            edit_latents.append(latent + delta_padded)
+    return torch.stack(edit_latents)
+
+
+def get_delta(pca, latent, idx, strength):
+    # pca: ganspace checkpoint. latent: (16, 512) w+
+    w_centered = latent - pca['mean'].to('cuda')
+    lat_comp = pca['comp'].to('cuda')
+    lat_std = pca['std'].to('cuda')
+    w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx]
+    delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx]
+    return delta
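
A hedged sketch of driving `edit` with one of the GANSpace PCA checkpoints bundled below; the direction tuple format `(pca_idx, start_layer, end_layer, strength)` comes from the loop above, while the component index and strength are arbitrary examples and the exact tensor shapes inside the .pt files are assumed:

import torch
from editings import ganspace

ganspace_pca = torch.load('editings/ganspace_pca/ffhq_pca.pt')  # dict with 'mean', 'comp', 'std'

latents = torch.randn(1, 18, 512).cuda()   # stand-in for an encoder-produced W+ code; edit() moves everything to CUDA

edit_directions = [(12, 3, 5, 10.0)]       # (pca_idx, start_layer, end_layer, strength) -- illustrative values
edited = ganspace.edit(latents, ganspace_pca, edit_directions)
print(edited.shape)                        # one edited W+ code per (latent, direction) pair
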
encoder4editing/editings/ganspace_pca/cars_pca.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5c3bae61ecd85de077fbbf103f5f30cf4b7676fe23a8508166eaf2ce73c8392
+size 167562
encoder4editing/editings/ganspace_pca/ffhq_pca.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d7f9df1c96180d9026b9cb8d04753579fbf385f321a9d0e263641601c5e5d36
+size 167562
encoder4editing/editings/interfacegan_directions/age.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50074516b1629707d89b5e43d6b8abd1792212fa3b961a87a11323d6a5222ae0
+size 2808
encoder4editing/editings/interfacegan_directions/pose.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:736e0eacc8488fa0b020a2e7bd256b957284c364191dfea693705e5d06d43e7d
+size 37624
encoder4editing/editings/interfacegan_directions/smile.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:817a7e732b59dee9eba862bec8bd7e8373568443bc9f9731a21cf9b0356f0653
+size 2808
encoder4editing/editings/latent_editor.py
ADDED
@@ -0,0 +1,45 @@
+import torch
+import sys
+sys.path.append(".")
+sys.path.append("..")
+from editings import ganspace, sefa
+from utils.common import tensor2im
+
+
+class LatentEditor(object):
+    def __init__(self, stylegan_generator, is_cars=False):
+        self.generator = stylegan_generator
+        self.is_cars = is_cars  # Since the cars StyleGAN output is 384x512, there is a need to crop the 512x512 output.
+
+    def apply_ganspace(self, latent, ganspace_pca, edit_directions):
+        edit_latents = ganspace.edit(latent, ganspace_pca, edit_directions)
+        return self._latents_to_image(edit_latents)
+
+    def apply_interfacegan(self, latent, direction, factor=1, factor_range=None):
+        edit_latents = []
+        if factor_range is not None:  # Apply a range of editing factors. for example, (-5, 5)
+            for f in range(*factor_range):
+                edit_latent = latent + f * direction
+                edit_latents.append(edit_latent)
+            edit_latents = torch.cat(edit_latents)
+        else:
+            edit_latents = latent + factor * direction
+        return self._latents_to_image(edit_latents)
+
+    def apply_sefa(self, latent, indices=[2, 3, 4, 5], **kwargs):
+        edit_latents = sefa.edit(self.generator, latent, indices, **kwargs)
+        return self._latents_to_image(edit_latents)
+
+    # Currently, in order to apply StyleFlow editings, one should run inference,
+    # save the latent codes and load them form the official StyleFlow repository.
+    # def apply_styleflow(self):
+    #     pass
+
+    def _latents_to_image(self, latents):
+        with torch.no_grad():
+            images, _ = self.generator([latents], randomize_noise=False, input_is_latent=True)
+            if self.is_cars:
+                images = images[:, :, 64:448, :]  # 512x512 -> 384x512
+        horizontal_concat_image = torch.cat(list(images), 2)
+        final_image = tensor2im(horizontal_concat_image)
+        return final_image
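
A rough sketch of using the editor with the InterFaceGAN directions shipped above; the generator and latent are assumed to come from an already-loaded e4e checkpoint (e.g. `net.decoder` of the pSp wrapper and the W+ code returned by the encoder):

import torch
from editings.latent_editor import LatentEditor

def age_sweep(g_ema, latent):
    # g_ema: a loaded StyleGAN2 Generator on GPU; latent: a (1, 18, 512) W+ code.
    editor = LatentEditor(g_ema, is_cars=False)
    direction = torch.load('editings/interfacegan_directions/age.pt').cuda()
    # factor_range sweeps the edit strength from -5 to +4; the result is a single
    # PIL image with all steps concatenated horizontally by _latents_to_image.
    return editor.apply_interfacegan(latent, direction, factor_range=(-5, 5))
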
encoder4editing/editings/sefa.py
ADDED
@@ -0,0 +1,46 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+
+
+def edit(generator, latents, indices, semantics=1, start_distance=-15.0, end_distance=15.0, num_samples=1, step=11):
+
+    layers, boundaries, values = factorize_weight(generator, indices)
+    codes = latents.detach().cpu().numpy()  # (1,18,512)
+
+    # Generate visualization pages.
+    distances = np.linspace(start_distance, end_distance, step)
+    num_sam = num_samples
+    num_sem = semantics
+
+    edited_latents = []
+    for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False):
+        boundary = boundaries[sem_id:sem_id + 1]
+        for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False):
+            code = codes[sam_id:sam_id + 1]
+            for col_id, d in enumerate(distances, start=1):
+                temp_code = code.copy()
+                temp_code[:, layers, :] += boundary * d
+                edited_latents.append(torch.from_numpy(temp_code).float().cuda())
+    return torch.cat(edited_latents)
+
+
+def factorize_weight(g_ema, layers='all'):
+
+    weights = []
+    if layers == 'all' or 0 in layers:
+        weight = g_ema.conv1.conv.modulation.weight.T
+        weights.append(weight.cpu().detach().numpy())
+
+    if layers == 'all':
+        layers = list(range(g_ema.num_layers - 1))
+    else:
+        layers = [l - 1 for l in layers if l != 0]
+
+    for idx in layers:
+        weight = g_ema.convs[idx].conv.modulation.weight.T
+        weights.append(weight.cpu().detach().numpy())
+    weight = np.concatenate(weights, axis=1).astype(np.float32)
+    weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
+    eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))
+    return layers, eigen_vectors.T, eigen_values
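
`factorize_weight` is the closed-form SeFa step: the style-modulation weights of the selected layers are concatenated, column-normalized, and the eigenvectors of `weight.dot(weight.T)` become the candidate semantic directions (`boundaries`), ranked by their eigenvalues. A small self-contained sketch of that decomposition with random stand-in weights:

import numpy as np

weight = np.random.randn(512, 3 * 512).astype(np.float32)   # stand-in for concatenated modulation weights
weight /= np.linalg.norm(weight, axis=0, keepdims=True)      # column-normalize, as above

eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))
boundaries = eigen_vectors.T      # each row is one 512-d candidate edit direction
print(boundaries.shape)           # (512, 512)
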
encoder4editing/environment/e4e_env.yaml
ADDED
|
@@ -0,0 +1,73 @@
| 1 |
+
name: e4e_env
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- _libgcc_mutex=0.1=main
|
| 7 |
+
- ca-certificates=2020.4.5.1=hecc5488_0
|
| 8 |
+
- certifi=2020.4.5.1=py36h9f0ad1d_0
|
| 9 |
+
- libedit=3.1.20181209=hc058e9b_0
|
| 10 |
+
- libffi=3.2.1=hd88cf55_4
|
| 11 |
+
- libgcc-ng=9.1.0=hdf63c60_0
|
| 12 |
+
- libstdcxx-ng=9.1.0=hdf63c60_0
|
| 13 |
+
- ncurses=6.2=he6710b0_1
|
| 14 |
+
- ninja=1.10.0=hc9558a2_0
|
| 15 |
+
- openssl=1.1.1g=h516909a_0
|
| 16 |
+
- pip=20.0.2=py36_3
|
| 17 |
+
- python=3.6.7=h0371630_0
|
| 18 |
+
- python_abi=3.6=1_cp36m
|
| 19 |
+
- readline=7.0=h7b6447c_5
|
| 20 |
+
- setuptools=46.4.0=py36_0
|
| 21 |
+
- sqlite=3.31.1=h62c20be_1
|
| 22 |
+
- tk=8.6.8=hbc83047_0
|
| 23 |
+
- wheel=0.34.2=py36_0
|
| 24 |
+
- xz=5.2.5=h7b6447c_0
|
| 25 |
+
- zlib=1.2.11=h7b6447c_3
|
| 26 |
+
- pip:
|
| 27 |
+
- absl-py==0.9.0
|
| 28 |
+
- cachetools==4.1.0
|
| 29 |
+
- chardet==3.0.4
|
| 30 |
+
- cycler==0.10.0
|
| 31 |
+
- decorator==4.4.2
|
| 32 |
+
- future==0.18.2
|
| 33 |
+
- google-auth==1.15.0
|
| 34 |
+
- google-auth-oauthlib==0.4.1
|
| 35 |
+
- grpcio==1.29.0
|
| 36 |
+
- idna==2.9
|
| 37 |
+
- imageio==2.8.0
|
| 38 |
+
- importlib-metadata==1.6.0
|
| 39 |
+
- kiwisolver==1.2.0
|
| 40 |
+
- markdown==3.2.2
|
| 41 |
+
- matplotlib==3.2.1
|
| 42 |
+
- mxnet==1.6.0
|
| 43 |
+
- networkx==2.4
|
| 44 |
+
- numpy==1.18.4
|
| 45 |
+
- oauthlib==3.1.0
|
| 46 |
+
- opencv-python==4.2.0.34
|
| 47 |
+
- pillow==7.1.2
|
| 48 |
+
- protobuf==3.12.1
|
| 49 |
+
- pyasn1==0.4.8
|
| 50 |
+
- pyasn1-modules==0.2.8
|
| 51 |
+
- pyparsing==2.4.7
|
| 52 |
+
- python-dateutil==2.8.1
|
| 53 |
+
- pytorch-lightning==0.7.1
|
| 54 |
+
- pywavelets==1.1.1
|
| 55 |
+
- requests==2.23.0
|
| 56 |
+
- requests-oauthlib==1.3.0
|
| 57 |
+
- rsa==4.0
|
| 58 |
+
- scikit-image==0.17.2
|
| 59 |
+
- scipy==1.4.1
|
| 60 |
+
- six==1.15.0
|
| 61 |
+
- tensorboard==2.2.1
|
| 62 |
+
- tensorboard-plugin-wit==1.6.0.post3
|
| 63 |
+
- tensorboardx==1.9
|
| 64 |
+
- tifffile==2020.5.25
|
| 65 |
+
- torch==1.6.0
|
| 66 |
+
- torchvision==0.7.1
|
| 67 |
+
- tqdm==4.46.0
|
| 68 |
+
- urllib3==1.25.9
|
| 69 |
+
- werkzeug==1.0.1
|
| 70 |
+
- zipp==3.1.0
|
| 71 |
+
- pyaml
|
| 72 |
+
prefix: ~/anaconda3/envs/e4e_env
|
| 73 |
+
|
encoder4editing/infer.py
ADDED
|
@@ -0,0 +1,134 @@
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
from argparse import Namespace
|
| 4 |
+
import time
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import torch
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
|
| 12 |
+
sys.path.append(".")
|
| 13 |
+
sys.path.append("..")
|
| 14 |
+
|
| 15 |
+
from utils.common import tensor2im
|
| 16 |
+
from models.psp import pSp # we use the pSp framework to load the e4e encoder.
|
| 17 |
+
experiment_type = 'ffhq_encode'
|
| 18 |
+
|
| 19 |
+
parser = argparse.ArgumentParser()
|
| 20 |
+
parser.add_argument('--input_image', type=str, default="", help='input image path')
|
| 21 |
+
args = parser.parse_args()
|
| 22 |
+
opts = vars(args)
|
| 23 |
+
print(opts)
|
| 24 |
+
image_path = opts["input_image"]
|
| 25 |
+
|
| 26 |
+
def get_download_model_command(file_id, file_name):
|
| 27 |
+
""" Get wget download command for downloading the desired model and save to directory pretrained_models. """
|
| 28 |
+
current_directory = os.getcwd()
|
| 29 |
+
save_path = "encoder4editing/saves"
|
| 30 |
+
if not os.path.exists(save_path):
|
| 31 |
+
os.makedirs(save_path)
|
| 32 |
+
url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
|
| 33 |
+
return url
|
| 34 |
+
|
| 35 |
+
MODEL_PATHS = {
|
| 36 |
+
"ffhq_encode": {"id": "1cUv_reLE6k3604or78EranS7XzuVMWeO", "name": "e4e_ffhq_encode.pt"},
|
| 37 |
+
"cars_encode": {"id": "17faPqBce2m1AQeLCLHUVXaDfxMRU2QcV", "name": "e4e_cars_encode.pt"},
|
| 38 |
+
"horse_encode": {"id": "1TkLLnuX86B_BMo2ocYD0kX9kWh53rUVX", "name": "e4e_horse_encode.pt"},
|
| 39 |
+
"church_encode": {"id": "1-L0ZdnQLwtdy6-A_Ccgq5uNJGTqE7qBa", "name": "e4e_church_encode.pt"}
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
path = MODEL_PATHS[experiment_type]
|
| 43 |
+
download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])
|
| 44 |
+
|
| 45 |
+
EXPERIMENT_DATA_ARGS = {
|
| 46 |
+
"ffhq_encode": {
|
| 47 |
+
"model_path": "encoder4editing/e4e_ffhq_encode.pt",
|
| 48 |
+
"image_path": "notebooks/images/input_img.jpg"
|
| 49 |
+
},
|
| 50 |
+
"cars_encode": {
|
| 51 |
+
"model_path": "pretrained_models/e4e_cars_encode.pt",
|
| 52 |
+
"image_path": "notebooks/images/car_img.jpg"
|
| 53 |
+
},
|
| 54 |
+
"horse_encode": {
|
| 55 |
+
"model_path": "pretrained_models/e4e_horse_encode.pt",
|
| 56 |
+
"image_path": "notebooks/images/horse_img.jpg"
|
| 57 |
+
},
|
| 58 |
+
"church_encode": {
|
| 59 |
+
"model_path": "pretrained_models/e4e_church_encode.pt",
|
| 60 |
+
"image_path": "notebooks/images/church_img.jpg"
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
}
|
| 64 |
+
# Setup required image transformations
|
| 65 |
+
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]
|
| 66 |
+
if experiment_type == 'cars_encode':
|
| 67 |
+
EXPERIMENT_ARGS['transform'] = transforms.Compose([
|
| 68 |
+
transforms.Resize((192, 256)),
|
| 69 |
+
transforms.ToTensor(),
|
| 70 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
| 71 |
+
resize_dims = (256, 192)
|
| 72 |
+
else:
|
| 73 |
+
EXPERIMENT_ARGS['transform'] = transforms.Compose([
|
| 74 |
+
transforms.Resize((256, 256)),
|
| 75 |
+
transforms.ToTensor(),
|
| 76 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
| 77 |
+
resize_dims = (256, 256)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
model_path = EXPERIMENT_ARGS['model_path']
|
| 81 |
+
ckpt = torch.load(model_path, map_location='cpu')
|
| 82 |
+
opts = ckpt['opts']
|
| 83 |
+
|
| 84 |
+
# update the training options
|
| 85 |
+
opts['checkpoint_path'] = model_path
|
| 86 |
+
opts= Namespace(**opts)
|
| 87 |
+
net = pSp(opts)
|
| 88 |
+
net.eval()
|
| 89 |
+
net.cuda()
|
| 90 |
+
print('Model successfully loaded!')
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
original_image = Image.open(image_path)
|
| 94 |
+
original_image = original_image.convert("RGB")
|
| 95 |
+
|
| 96 |
+
def run_alignment(image_path):
|
| 97 |
+
import dlib
|
| 98 |
+
from utils.alignment import align_face
|
| 99 |
+
predictor = dlib.shape_predictor("encoder4editing/shape_predictor_68_face_landmarks.dat")
|
| 100 |
+
aligned_image = align_face(filepath=image_path, predictor=predictor)
|
| 101 |
+
print("Aligned image has shape: {}".format(aligned_image.size))
|
| 102 |
+
return aligned_image
|
| 103 |
+
|
| 104 |
+
if experiment_type == "ffhq_encode":
|
| 105 |
+
input_image = run_alignment(image_path)
|
| 106 |
+
else:
|
| 107 |
+
input_image = original_image
|
| 108 |
+
|
| 109 |
+
input_image.resize(resize_dims)
|
| 110 |
+
|
| 111 |
+
img_transforms = EXPERIMENT_ARGS['transform']
|
| 112 |
+
transformed_image = img_transforms(input_image)
|
| 113 |
+
|
| 114 |
+
def display_alongside_source_image(result_image, source_image):
|
| 115 |
+
res = np.concatenate([np.array(source_image.resize(resize_dims)),
|
| 116 |
+
np.array(result_image.resize(resize_dims))], axis=1)
|
| 117 |
+
return Image.fromarray(res)
|
| 118 |
+
|
| 119 |
+
def run_on_batch(inputs, net):
|
| 120 |
+
images, latents = net(inputs.to("cuda").float(), randomize_noise=False, return_latents=True)
|
| 121 |
+
if experiment_type == 'cars_encode':
|
| 122 |
+
images = images[:, :, 32:224, :]
|
| 123 |
+
return images, latents
|
| 124 |
+
|
| 125 |
+
with torch.no_grad():
|
| 126 |
+
tic = time.time()
|
| 127 |
+
images, latents = run_on_batch(transformed_image.unsqueeze(0), net)
|
| 128 |
+
result_image, latent = images[0], latents[0]
|
| 129 |
+
toc = time.time()
|
| 130 |
+
print('Inference took {:.4f} seconds.'.format(toc - tic))
|
| 131 |
+
|
| 132 |
+
# Display inversion:
|
| 133 |
+
display_alongside_source_image(tensor2im(result_image), input_image)
|
| 134 |
+
np.savez(f'encoder4editing/projected_w.npz', w=latents.cpu().numpy())
|
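
The script ends by dumping the inverted W+ code to `encoder4editing/projected_w.npz`. A short sketch of loading it back, e.g. to feed one of the editing utilities or the generator again (the decoder call is left as a comment because it needs the loaded pSp network from above):

import numpy as np
import torch

data = np.load('encoder4editing/projected_w.npz')
w_plus = torch.from_numpy(data['w'])   # (1, 18, 512) for the FFHQ encoder
print(w_plus.shape)

# With the pSp wrapper loaded as `net` (see above):
# images, _ = net.decoder([w_plus.cuda()], input_is_latent=True, randomize_noise=False)
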
encoder4editing/metrics/LEC.py
ADDED
|
@@ -0,0 +1,134 @@
| 1 |
+
import sys
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
|
| 7 |
+
sys.path.append(".")
|
| 8 |
+
sys.path.append("..")
|
| 9 |
+
|
| 10 |
+
from configs import data_configs
|
| 11 |
+
from datasets.images_dataset import ImagesDataset
|
| 12 |
+
from utils.model_utils import setup_model
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LEC:
|
| 16 |
+
def __init__(self, net, is_cars=False):
|
| 17 |
+
"""
|
| 18 |
+
Latent Editing Consistency metric as proposed in the main paper.
|
| 19 |
+
:param net: e4e model loaded over the pSp framework.
|
| 20 |
+
:param is_cars: An indication as to whether or not to crop the middle of the StyleGAN's output images.
|
| 21 |
+
"""
|
| 22 |
+
self.net = net
|
| 23 |
+
self.is_cars = is_cars
|
| 24 |
+
|
| 25 |
+
def _encode(self, images):
|
| 26 |
+
"""
|
| 27 |
+
Encodes the given images into StyleGAN's latent space.
|
| 28 |
+
:param images: Tensor of shape NxCxHxW representing the images to be encoded.
|
| 29 |
+
:return: Tensor of shape NxKx512 representing the latent space embeddings of the given image (in W(K, *) space).
|
| 30 |
+
"""
|
| 31 |
+
codes = self.net.encoder(images)
|
| 32 |
+
assert codes.ndim == 3, f"Invalid latent codes shape, should be NxKx512 but is {codes.shape}"
|
| 33 |
+
# normalize with respect to the center of an average face
|
| 34 |
+
if self.net.opts.start_from_latent_avg:
|
| 35 |
+
codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
|
| 36 |
+
return codes
|
| 37 |
+
|
| 38 |
+
def _generate(self, codes):
|
| 39 |
+
"""
|
| 40 |
+
Generate the StyleGAN2 images of the given codes
|
| 41 |
+
:param codes: Tensor of shape NxKx512 representing the StyleGAN's latent codes (in W(K, *) space).
|
| 42 |
+
:return: Tensor of shape NxCxHxW representing the generated images.
|
| 43 |
+
"""
|
| 44 |
+
images, _ = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=True)
|
| 45 |
+
images = self.net.face_pool(images)
|
| 46 |
+
if self.is_cars:
|
| 47 |
+
images = images[:, :, 32:224, :]
|
| 48 |
+
return images
|
| 49 |
+
|
| 50 |
+
@staticmethod
|
| 51 |
+
def _filter_outliers(arr):
|
| 52 |
+
arr = np.array(arr)
|
| 53 |
+
|
| 54 |
+
lo = np.percentile(arr, 1, interpolation="lower")
|
| 55 |
+
hi = np.percentile(arr, 99, interpolation="higher")
|
| 56 |
+
return np.extract(
|
| 57 |
+
np.logical_and(lo <= arr, arr <= hi), arr
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
def calculate_metric(self, data_loader, edit_function, inverse_edit_function):
|
| 61 |
+
"""
|
| 62 |
+
Calculate the LEC metric score.
|
| 63 |
+
:param data_loader: An iterable that returns a tuple of (images, _), similar to the training data loader.
|
| 64 |
+
:param edit_function: A function that receives latent codes and performs a semantically meaningful edit in the
|
| 65 |
+
latent space.
|
| 66 |
+
:param inverse_edit_function: A function that receives latent codes and performs the inverse edit of the
|
| 67 |
+
`edit_function` parameter.
|
| 68 |
+
:return: The LEC metric score.
|
| 69 |
+
"""
|
| 70 |
+
distances = []
|
| 71 |
+
with torch.no_grad():
|
| 72 |
+
for batch in data_loader:
|
| 73 |
+
x, _ = batch
|
| 74 |
+
inputs = x.to(device).float()
|
| 75 |
+
|
| 76 |
+
codes = self._encode(inputs)
|
| 77 |
+
edited_codes = edit_function(codes)
|
| 78 |
+
edited_image = self._generate(edited_codes)
|
| 79 |
+
edited_image_inversion_codes = self._encode(edited_image)
|
| 80 |
+
inverse_edit_codes = inverse_edit_function(edited_image_inversion_codes)
|
| 81 |
+
|
| 82 |
+
dist = (codes - inverse_edit_codes).norm(2, dim=(1, 2)).mean()
|
| 83 |
+
distances.append(dist.to("cpu").numpy())
|
| 84 |
+
|
| 85 |
+
distances = self._filter_outliers(distances)
|
| 86 |
+
return distances.mean()
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
if __name__ == "__main__":
|
| 90 |
+
device = "cuda"
|
| 91 |
+
|
| 92 |
+
parser = argparse.ArgumentParser(description="LEC metric calculator")
|
| 93 |
+
|
| 94 |
+
parser.add_argument("--batch", type=int, default=8, help="batch size for the models")
|
| 95 |
+
parser.add_argument("--images_dir", type=str, default=None,
|
| 96 |
+
help="Path to the images directory on which we calculate the LEC score")
|
| 97 |
+
parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to the model checkpoints")
|
| 98 |
+
|
| 99 |
+
args = parser.parse_args()
|
| 100 |
+
print(args)
|
| 101 |
+
|
| 102 |
+
net, opts = setup_model(args.ckpt, device)
|
| 103 |
+
dataset_args = data_configs.DATASETS[opts.dataset_type]
|
| 104 |
+
transforms_dict = dataset_args['transforms'](opts).get_transforms()
|
| 105 |
+
|
| 106 |
+
images_directory = dataset_args['test_source_root'] if args.images_dir is None else args.images_dir
|
| 107 |
+
test_dataset = ImagesDataset(source_root=images_directory,
|
| 108 |
+
target_root=images_directory,
|
| 109 |
+
source_transform=transforms_dict['transform_source'],
|
| 110 |
+
target_transform=transforms_dict['transform_test'],
|
| 111 |
+
opts=opts)
|
| 112 |
+
|
| 113 |
+
data_loader = DataLoader(test_dataset,
|
| 114 |
+
batch_size=args.batch,
|
| 115 |
+
shuffle=False,
|
| 116 |
+
num_workers=2,
|
| 117 |
+
drop_last=True)
|
| 118 |
+
|
| 119 |
+
print(f'dataset length: {len(test_dataset)}')
|
| 120 |
+
|
| 121 |
+
# In the following example, we are using an InterfaceGAN based editing to calculate the LEC metric.
|
| 122 |
+
# Change the provided example according to your domain and needs.
|
| 123 |
+
direction = torch.load('../editings/interfacegan_directions/age.pt').to(device)
|
| 124 |
+
|
| 125 |
+
def edit_func_example(codes):
|
| 126 |
+
return codes + 3 * direction
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def inverse_edit_func_example(codes):
|
| 130 |
+
return codes - 3 * direction
|
| 131 |
+
|
| 132 |
+
lec = LEC(net, is_cars='car' in opts.dataset_type)
|
| 133 |
+
result = lec.calculate_metric(data_loader, edit_func_example, inverse_edit_func_example)
|
| 134 |
+
print(f"LEC: {result}")
|
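
In short, the metric encodes a batch, edits the codes, decodes and re-encodes the edited images, applies the inverse edit, and averages the L2 distance to the original codes after trimming the 1st/99th-percentile outliers. Any invertible pair of latent edits can be plugged into `calculate_metric`; a hedged variant of the example above using the bundled smile direction (the relative path may need adjusting depending on the working directory):

import torch

# assumes `device`, `net` and `data_loader` are set up as in the __main__ block above
smile = torch.load('../editings/interfacegan_directions/smile.pt').to(device)

def smile_edit(codes):
    return codes + 2 * smile

def smile_inverse_edit(codes):
    return codes - 2 * smile

lec = LEC(net, is_cars=False)
print(lec.calculate_metric(data_loader, smile_edit, smile_inverse_edit))
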
encoder4editing/models/__init__.py
ADDED
File without changes
encoder4editing/models/discriminator.py
ADDED
@@ -0,0 +1,20 @@
+from torch import nn
+
+
+class LatentCodesDiscriminator(nn.Module):
+    def __init__(self, style_dim, n_mlp):
+        super().__init__()
+
+        self.style_dim = style_dim
+
+        layers = []
+        for i in range(n_mlp-1):
+            layers.append(
+                nn.Linear(style_dim, style_dim)
+            )
+            layers.append(nn.LeakyReLU(0.2))
+        layers.append(nn.Linear(512, 1))
+        self.mlp = nn.Sequential(*layers)
+
+    def forward(self, w):
+        return self.mlp(w)
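
This small MLP scores latent codes along their last 512-dimensional axis for the latent-space adversarial objective; note the final layer is hard-coded to a 512-d input, so `style_dim` is effectively expected to be 512. A quick sketch:

import torch
from models.discriminator import LatentCodesDiscriminator

disc = LatentCodesDiscriminator(style_dim=512, n_mlp=4)
w = torch.randn(8, 512)     # a batch of single w codes
scores = disc(w)            # (8, 1) realness logits
print(scores.shape)
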
encoder4editing/models/encoders/__init__.py
ADDED
File without changes
encoder4editing/models/encoders/helpers.py
ADDED
|
@@ -0,0 +1,140 @@
| 1 |
+
from collections import namedtuple
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Flatten(Module):
|
| 12 |
+
def forward(self, input):
|
| 13 |
+
return input.view(input.size(0), -1)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def l2_norm(input, axis=1):
|
| 17 |
+
norm = torch.norm(input, 2, axis, True)
|
| 18 |
+
output = torch.div(input, norm)
|
| 19 |
+
return output
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
|
| 23 |
+
""" A named tuple describing a ResNet block. """
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_block(in_channel, depth, num_units, stride=2):
|
| 27 |
+
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_blocks(num_layers):
|
| 31 |
+
if num_layers == 50:
|
| 32 |
+
blocks = [
|
| 33 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
| 34 |
+
get_block(in_channel=64, depth=128, num_units=4),
|
| 35 |
+
get_block(in_channel=128, depth=256, num_units=14),
|
| 36 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
| 37 |
+
]
|
| 38 |
+
elif num_layers == 100:
|
| 39 |
+
blocks = [
|
| 40 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
| 41 |
+
get_block(in_channel=64, depth=128, num_units=13),
|
| 42 |
+
get_block(in_channel=128, depth=256, num_units=30),
|
| 43 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
| 44 |
+
]
|
| 45 |
+
elif num_layers == 152:
|
| 46 |
+
blocks = [
|
| 47 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
| 48 |
+
get_block(in_channel=64, depth=128, num_units=8),
|
| 49 |
+
get_block(in_channel=128, depth=256, num_units=36),
|
| 50 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
| 51 |
+
]
|
| 52 |
+
else:
|
| 53 |
+
raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
|
| 54 |
+
return blocks
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class SEModule(Module):
|
| 58 |
+
def __init__(self, channels, reduction):
|
| 59 |
+
super(SEModule, self).__init__()
|
| 60 |
+
self.avg_pool = AdaptiveAvgPool2d(1)
|
| 61 |
+
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
|
| 62 |
+
self.relu = ReLU(inplace=True)
|
| 63 |
+
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
|
| 64 |
+
self.sigmoid = Sigmoid()
|
| 65 |
+
|
| 66 |
+
def forward(self, x):
|
| 67 |
+
module_input = x
|
| 68 |
+
x = self.avg_pool(x)
|
| 69 |
+
x = self.fc1(x)
|
| 70 |
+
x = self.relu(x)
|
| 71 |
+
x = self.fc2(x)
|
| 72 |
+
x = self.sigmoid(x)
|
| 73 |
+
return module_input * x
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class bottleneck_IR(Module):
|
| 77 |
+
def __init__(self, in_channel, depth, stride):
|
| 78 |
+
super(bottleneck_IR, self).__init__()
|
| 79 |
+
if in_channel == depth:
|
| 80 |
+
self.shortcut_layer = MaxPool2d(1, stride)
|
| 81 |
+
else:
|
| 82 |
+
self.shortcut_layer = Sequential(
|
| 83 |
+
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
| 84 |
+
BatchNorm2d(depth)
|
| 85 |
+
)
|
| 86 |
+
self.res_layer = Sequential(
|
| 87 |
+
BatchNorm2d(in_channel),
|
| 88 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
|
| 89 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
def forward(self, x):
|
| 93 |
+
shortcut = self.shortcut_layer(x)
|
| 94 |
+
res = self.res_layer(x)
|
| 95 |
+
return res + shortcut
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class bottleneck_IR_SE(Module):
|
| 99 |
+
def __init__(self, in_channel, depth, stride):
|
| 100 |
+
super(bottleneck_IR_SE, self).__init__()
|
| 101 |
+
if in_channel == depth:
|
| 102 |
+
self.shortcut_layer = MaxPool2d(1, stride)
|
| 103 |
+
else:
|
| 104 |
+
self.shortcut_layer = Sequential(
|
| 105 |
+
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
| 106 |
+
BatchNorm2d(depth)
|
| 107 |
+
)
|
| 108 |
+
self.res_layer = Sequential(
|
| 109 |
+
BatchNorm2d(in_channel),
|
| 110 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
|
| 111 |
+
PReLU(depth),
|
| 112 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
|
| 113 |
+
BatchNorm2d(depth),
|
| 114 |
+
SEModule(depth, 16)
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def forward(self, x):
|
| 118 |
+
shortcut = self.shortcut_layer(x)
|
| 119 |
+
res = self.res_layer(x)
|
| 120 |
+
return res + shortcut
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _upsample_add(x, y):
|
| 124 |
+
"""Upsample and add two feature maps.
|
| 125 |
+
Args:
|
| 126 |
+
x: (Variable) top feature map to be upsampled.
|
| 127 |
+
y: (Variable) lateral feature map.
|
| 128 |
+
Returns:
|
| 129 |
+
(Variable) added feature map.
|
| 130 |
+
Note in PyTorch, when input size is odd, the upsampled feature map
|
| 131 |
+
with `F.upsample(..., scale_factor=2, mode='nearest')`
|
| 132 |
+
maybe not equal to the lateral feature map size.
|
| 133 |
+
e.g.
|
| 134 |
+
original input size: [N,_,15,15] ->
|
| 135 |
+
conv2d feature map size: [N,_,8,8] ->
|
| 136 |
+
upsampled feature map size: [N,_,16,16]
|
| 137 |
+
So we choose bilinear upsample which supports arbitrary output sizes.
|
| 138 |
+
"""
|
| 139 |
+
_, _, H, W = y.size()
|
| 140 |
+
return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
|
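
For orientation, `get_blocks(50)` simply describes the four IR-ResNet stages (3, 4, 14 and 3 units, with the first unit of each stage striding by 2); the encoders below turn every `Bottleneck` tuple into a `bottleneck_IR` or `bottleneck_IR_SE` module. A tiny inspection sketch:

from models.encoders.helpers import get_blocks

for stage in get_blocks(50):
    first = stage[0]
    print(len(stage), 'units:', first.in_channel, '->', first.depth, 'stride', first.stride)
# 3 units: 64 -> 64 stride 2
# 4 units: 64 -> 128 stride 2
# 14 units: 128 -> 256 stride 2
# 3 units: 256 -> 512 stride 2
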
encoder4editing/models/encoders/model_irse.py
ADDED
|
@@ -0,0 +1,84 @@
| 1 |
+
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
|
| 2 |
+
from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Backbone(Module):
|
| 10 |
+
def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
|
| 11 |
+
super(Backbone, self).__init__()
|
| 12 |
+
assert input_size in [112, 224], "input_size should be 112 or 224"
|
| 13 |
+
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
|
| 14 |
+
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
|
| 15 |
+
blocks = get_blocks(num_layers)
|
| 16 |
+
if mode == 'ir':
|
| 17 |
+
unit_module = bottleneck_IR
|
| 18 |
+
elif mode == 'ir_se':
|
| 19 |
+
unit_module = bottleneck_IR_SE
|
| 20 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
| 21 |
+
BatchNorm2d(64),
|
| 22 |
+
PReLU(64))
|
| 23 |
+
if input_size == 112:
|
| 24 |
+
self.output_layer = Sequential(BatchNorm2d(512),
|
| 25 |
+
Dropout(drop_ratio),
|
| 26 |
+
Flatten(),
|
| 27 |
+
Linear(512 * 7 * 7, 512),
|
| 28 |
+
BatchNorm1d(512, affine=affine))
|
| 29 |
+
else:
|
| 30 |
+
self.output_layer = Sequential(BatchNorm2d(512),
|
| 31 |
+
Dropout(drop_ratio),
|
| 32 |
+
Flatten(),
|
| 33 |
+
Linear(512 * 14 * 14, 512),
|
| 34 |
+
BatchNorm1d(512, affine=affine))
|
| 35 |
+
|
| 36 |
+
modules = []
|
| 37 |
+
for block in blocks:
|
| 38 |
+
for bottleneck in block:
|
| 39 |
+
modules.append(unit_module(bottleneck.in_channel,
|
| 40 |
+
bottleneck.depth,
|
| 41 |
+
bottleneck.stride))
|
| 42 |
+
self.body = Sequential(*modules)
|
| 43 |
+
|
| 44 |
+
def forward(self, x):
|
| 45 |
+
x = self.input_layer(x)
|
| 46 |
+
x = self.body(x)
|
| 47 |
+
x = self.output_layer(x)
|
| 48 |
+
return l2_norm(x)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def IR_50(input_size):
|
| 52 |
+
"""Constructs a ir-50 model."""
|
| 53 |
+
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
|
| 54 |
+
return model
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def IR_101(input_size):
|
| 58 |
+
"""Constructs a ir-101 model."""
|
| 59 |
+
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
|
| 60 |
+
return model
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def IR_152(input_size):
|
| 64 |
+
"""Constructs a ir-152 model."""
|
| 65 |
+
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
|
| 66 |
+
return model
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def IR_SE_50(input_size):
|
| 70 |
+
"""Constructs a ir_se-50 model."""
|
| 71 |
+
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
|
| 72 |
+
return model
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def IR_SE_101(input_size):
|
| 76 |
+
"""Constructs a ir_se-101 model."""
|
| 77 |
+
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
|
| 78 |
+
return model
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def IR_SE_152(input_size):
|
| 82 |
+
"""Constructs a ir_se-152 model."""
|
| 83 |
+
model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
|
| 84 |
+
return model
|
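
These constructors build the ArcFace-style recognition backbone: a 112x112 (or 224x224) aligned face crop goes in and an L2-normalized 512-d embedding comes out. A hedged sketch with random weights (in the repo, the pretrained checkpoint referenced as `model_paths['ir_se50']` in `models/psp.py` would be loaded on top):

import torch
from models.encoders.model_irse import IR_SE_50

net = IR_SE_50(input_size=112).eval()      # eval() so BatchNorm accepts a single image
face = torch.randn(1, 3, 112, 112)         # stand-in for an aligned face crop
with torch.no_grad():
    emb = net(face)
print(emb.shape, float(emb.norm(dim=1)))   # (1, 512), norm ~= 1 after l2_norm
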
encoder4editing/models/encoders/psp_encoders.py
ADDED
|
@@ -0,0 +1,235 @@
| 1 |
+
from enum import Enum
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
|
| 7 |
+
|
| 8 |
+
from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
|
| 9 |
+
from models.stylegan2.model import EqualLinear
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ProgressiveStage(Enum):
|
| 13 |
+
WTraining = 0
|
| 14 |
+
Delta1Training = 1
|
| 15 |
+
Delta2Training = 2
|
| 16 |
+
Delta3Training = 3
|
| 17 |
+
Delta4Training = 4
|
| 18 |
+
Delta5Training = 5
|
| 19 |
+
Delta6Training = 6
|
| 20 |
+
Delta7Training = 7
|
| 21 |
+
Delta8Training = 8
|
| 22 |
+
Delta9Training = 9
|
| 23 |
+
Delta10Training = 10
|
| 24 |
+
Delta11Training = 11
|
| 25 |
+
Delta12Training = 12
|
| 26 |
+
Delta13Training = 13
|
| 27 |
+
Delta14Training = 14
|
| 28 |
+
Delta15Training = 15
|
| 29 |
+
Delta16Training = 16
|
| 30 |
+
Delta17Training = 17
|
| 31 |
+
Inference = 18
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class GradualStyleBlock(Module):
|
| 35 |
+
def __init__(self, in_c, out_c, spatial):
|
| 36 |
+
super(GradualStyleBlock, self).__init__()
|
| 37 |
+
self.out_c = out_c
|
| 38 |
+
self.spatial = spatial
|
| 39 |
+
num_pools = int(np.log2(spatial))
|
| 40 |
+
modules = []
|
| 41 |
+
modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
|
| 42 |
+
nn.LeakyReLU()]
|
| 43 |
+
for i in range(num_pools - 1):
|
| 44 |
+
modules += [
|
| 45 |
+
Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
|
| 46 |
+
nn.LeakyReLU()
|
| 47 |
+
]
|
| 48 |
+
self.convs = nn.Sequential(*modules)
|
| 49 |
+
self.linear = EqualLinear(out_c, out_c, lr_mul=1)
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
x = self.convs(x)
|
| 53 |
+
x = x.view(-1, self.out_c)
|
| 54 |
+
x = self.linear(x)
|
| 55 |
+
return x
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class GradualStyleEncoder(Module):
|
| 59 |
+
def __init__(self, num_layers, mode='ir', opts=None):
|
| 60 |
+
super(GradualStyleEncoder, self).__init__()
|
| 61 |
+
assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
|
| 62 |
+
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
| 63 |
+
blocks = get_blocks(num_layers)
|
| 64 |
+
if mode == 'ir':
|
| 65 |
+
unit_module = bottleneck_IR
|
| 66 |
+
elif mode == 'ir_se':
|
| 67 |
+
unit_module = bottleneck_IR_SE
|
| 68 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
| 69 |
+
BatchNorm2d(64),
|
| 70 |
+
PReLU(64))
|
| 71 |
+
modules = []
|
| 72 |
+
for block in blocks:
|
| 73 |
+
for bottleneck in block:
|
| 74 |
+
modules.append(unit_module(bottleneck.in_channel,
|
| 75 |
+
bottleneck.depth,
|
| 76 |
+
bottleneck.stride))
|
| 77 |
+
self.body = Sequential(*modules)
|
| 78 |
+
|
| 79 |
+
self.styles = nn.ModuleList()
|
| 80 |
+
log_size = int(math.log(opts.stylegan_size, 2))
|
| 81 |
+
self.style_count = 2 * log_size - 2
|
| 82 |
+
self.coarse_ind = 3
|
| 83 |
+
self.middle_ind = 7
|
| 84 |
+
for i in range(self.style_count):
|
| 85 |
+
if i < self.coarse_ind:
|
| 86 |
+
style = GradualStyleBlock(512, 512, 16)
|
| 87 |
+
elif i < self.middle_ind:
|
| 88 |
+
style = GradualStyleBlock(512, 512, 32)
|
| 89 |
+
else:
|
| 90 |
+
style = GradualStyleBlock(512, 512, 64)
|
| 91 |
+
self.styles.append(style)
|
| 92 |
+
self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
|
| 93 |
+
self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
|
| 94 |
+
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
x = self.input_layer(x)
|
| 97 |
+
|
| 98 |
+
latents = []
|
| 99 |
+
modulelist = list(self.body._modules.values())
|
| 100 |
+
for i, l in enumerate(modulelist):
|
| 101 |
+
x = l(x)
|
| 102 |
+
if i == 6:
|
| 103 |
+
c1 = x
|
| 104 |
+
elif i == 20:
|
| 105 |
+
c2 = x
|
| 106 |
+
elif i == 23:
|
| 107 |
+
c3 = x
|
| 108 |
+
|
| 109 |
+
for j in range(self.coarse_ind):
|
| 110 |
+
latents.append(self.styles[j](c3))
|
| 111 |
+
|
| 112 |
+
p2 = _upsample_add(c3, self.latlayer1(c2))
|
| 113 |
+
for j in range(self.coarse_ind, self.middle_ind):
|
| 114 |
+
latents.append(self.styles[j](p2))
|
| 115 |
+
|
| 116 |
+
p1 = _upsample_add(p2, self.latlayer2(c1))
|
| 117 |
+
for j in range(self.middle_ind, self.style_count):
|
| 118 |
+
latents.append(self.styles[j](p1))
|
| 119 |
+
|
| 120 |
+
out = torch.stack(latents, dim=1)
|
| 121 |
+
return out
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Encoder4Editing(Module):
|
| 125 |
+
def __init__(self, num_layers, mode='ir', opts=None):
|
| 126 |
+
super(Encoder4Editing, self).__init__()
|
| 127 |
+
assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
|
| 128 |
+
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
| 129 |
+
blocks = get_blocks(num_layers)
|
| 130 |
+
if mode == 'ir':
|
| 131 |
+
unit_module = bottleneck_IR
|
| 132 |
+
elif mode == 'ir_se':
|
| 133 |
+
unit_module = bottleneck_IR_SE
|
| 134 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
| 135 |
+
BatchNorm2d(64),
|
| 136 |
+
PReLU(64))
|
| 137 |
+
modules = []
|
| 138 |
+
for block in blocks:
|
| 139 |
+
for bottleneck in block:
|
| 140 |
+
modules.append(unit_module(bottleneck.in_channel,
|
| 141 |
+
bottleneck.depth,
|
| 142 |
+
bottleneck.stride))
|
| 143 |
+
self.body = Sequential(*modules)
|
| 144 |
+
|
| 145 |
+
self.styles = nn.ModuleList()
|
| 146 |
+
log_size = int(math.log(opts.stylegan_size, 2))
|
| 147 |
+
self.style_count = 2 * log_size - 2
|
| 148 |
+
self.coarse_ind = 3
|
| 149 |
+
self.middle_ind = 7
|
| 150 |
+
|
| 151 |
+
for i in range(self.style_count):
|
| 152 |
+
if i < self.coarse_ind:
|
| 153 |
+
style = GradualStyleBlock(512, 512, 16)
|
| 154 |
+
elif i < self.middle_ind:
|
| 155 |
+
style = GradualStyleBlock(512, 512, 32)
|
| 156 |
+
else:
|
| 157 |
+
style = GradualStyleBlock(512, 512, 64)
|
| 158 |
+
self.styles.append(style)
|
| 159 |
+
|
| 160 |
+
self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
|
| 161 |
+
self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
|
| 162 |
+
|
| 163 |
+
self.progressive_stage = ProgressiveStage.Inference
|
| 164 |
+
|
| 165 |
+
def get_deltas_starting_dimensions(self):
|
| 166 |
+
''' Get a list of the initial dimension of every delta from which it is applied '''
|
| 167 |
+
return list(range(self.style_count)) # Each dimension has a delta applied to it
|
| 168 |
+
|
| 169 |
+
def set_progressive_stage(self, new_stage: ProgressiveStage):
|
| 170 |
+
self.progressive_stage = new_stage
|
| 171 |
+
print('Changed progressive stage to: ', new_stage)
|
| 172 |
+
|
| 173 |
+
def forward(self, x):
|
| 174 |
+
x = self.input_layer(x)
|
| 175 |
+
|
| 176 |
+
modulelist = list(self.body._modules.values())
|
| 177 |
+
for i, l in enumerate(modulelist):
|
| 178 |
+
x = l(x)
|
| 179 |
+
if i == 6:
|
| 180 |
+
c1 = x
|
| 181 |
+
elif i == 20:
|
| 182 |
+
c2 = x
|
| 183 |
+
elif i == 23:
|
| 184 |
+
c3 = x
|
| 185 |
+
|
| 186 |
+
# Infer main W and duplicate it
|
| 187 |
+
w0 = self.styles[0](c3)
|
| 188 |
+
w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
|
| 189 |
+
stage = self.progressive_stage.value
|
| 190 |
+
features = c3
|
| 191 |
+
for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas
|
| 192 |
+
if i == self.coarse_ind:
|
| 193 |
+
p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
|
| 194 |
+
features = p2
|
| 195 |
+
elif i == self.middle_ind:
|
| 196 |
+
p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
|
| 197 |
+
features = p1
|
| 198 |
+
delta_i = self.styles[i](features)
|
| 199 |
+
w[:, i] += delta_i
|
| 200 |
+
return w
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class BackboneEncoderUsingLastLayerIntoW(Module):
|
| 204 |
+
def __init__(self, num_layers, mode='ir', opts=None):
|
| 205 |
+
super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
|
| 206 |
+
print('Using BackboneEncoderUsingLastLayerIntoW')
|
| 207 |
+
assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
|
| 208 |
+
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
| 209 |
+
blocks = get_blocks(num_layers)
|
| 210 |
+
if mode == 'ir':
|
| 211 |
+
unit_module = bottleneck_IR
|
| 212 |
+
elif mode == 'ir_se':
|
| 213 |
+
unit_module = bottleneck_IR_SE
|
| 214 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
| 215 |
+
BatchNorm2d(64),
|
| 216 |
+
PReLU(64))
|
| 217 |
+
self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
|
| 218 |
+
self.linear = EqualLinear(512, 512, lr_mul=1)
|
| 219 |
+
modules = []
|
| 220 |
+
for block in blocks:
|
| 221 |
+
for bottleneck in block:
|
| 222 |
+
modules.append(unit_module(bottleneck.in_channel,
|
| 223 |
+
bottleneck.depth,
|
| 224 |
+
bottleneck.stride))
|
| 225 |
+
self.body = Sequential(*modules)
|
| 226 |
+
log_size = int(math.log(opts.stylegan_size, 2))
|
| 227 |
+
self.style_count = 2 * log_size - 2
|
| 228 |
+
|
| 229 |
+
def forward(self, x):
|
| 230 |
+
x = self.input_layer(x)
|
| 231 |
+
x = self.body(x)
|
| 232 |
+
x = self.output_pool(x)
|
| 233 |
+
x = x.view(-1, 512)
|
| 234 |
+
x = self.linear(x)
|
| 235 |
+
return x.repeat(self.style_count, 1, 1).permute(1, 0, 2)
|
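
The difference between the two encoders is the progressive formulation: `Encoder4Editing` predicts one base code w0 from the coarsest features and then a per-layer offset, so style input i receives w0 + delta_i, and `ProgressiveStage` controls how many offsets are active. A shape-level sketch of that composition with dummy tensors (not the real encoder):

import torch

style_count = 18                                   # 2*log2(1024) - 2 for a 1024px StyleGAN2
w0 = torch.randn(4, 512)                           # base code per image
deltas = torch.randn(4, style_count, 512) * 0.1    # per-layer offsets
deltas[:, 0] = 0                                   # layer 0 keeps the plain w0, as in forward() above

w = w0.unsqueeze(1).repeat(1, style_count, 1)      # start every layer from w0
w = w + deltas
print(w.shape)                                     # (4, 18, 512), the N x K x 512 code pSp expects
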
encoder4editing/models/latent_codes_pool.py
ADDED
|
@@ -0,0 +1,55 @@
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class LatentCodesPool:
|
| 6 |
+
"""This class implements latent codes buffer that stores previously generated w latent codes.
|
| 7 |
+
This buffer enables us to update discriminators using a history of generated w's
|
| 8 |
+
rather than the ones produced by the latest encoder.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, pool_size):
|
| 12 |
+
"""Initialize the ImagePool class
|
| 13 |
+
Parameters:
|
| 14 |
+
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
|
| 15 |
+
"""
|
| 16 |
+
self.pool_size = pool_size
|
| 17 |
+
if self.pool_size > 0: # create an empty pool
|
| 18 |
+
self.num_ws = 0
|
| 19 |
+
self.ws = []
|
| 20 |
+
|
| 21 |
+
def query(self, ws):
|
| 22 |
+
"""Return w's from the pool.
|
| 23 |
+
Parameters:
|
| 24 |
+
ws: the latest generated w's from the generator
|
| 25 |
+
Returns w's from the buffer.
|
| 26 |
+
By 50/100, the buffer will return input w's.
|
| 27 |
+
By 50/100, the buffer will return w's previously stored in the buffer,
|
| 28 |
+
and insert the current w's to the buffer.
|
| 29 |
+
"""
|
| 30 |
+
if self.pool_size == 0: # if the buffer size is 0, do nothing
|
| 31 |
+
return ws
|
| 32 |
+
return_ws = []
|
| 33 |
+
for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512)
|
| 34 |
+
# w = torch.unsqueeze(image.data, 0)
|
| 35 |
+
if w.ndim == 2:
|
| 36 |
+
i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate
|
| 37 |
+
w = w[i]
|
| 38 |
+
self.handle_w(w, return_ws)
|
| 39 |
+
return_ws = torch.stack(return_ws, 0) # collect all the images and return
|
| 40 |
+
return return_ws
|
| 41 |
+
|
| 42 |
+
def handle_w(self, w, return_ws):
|
| 43 |
+
if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer
|
| 44 |
+
self.num_ws = self.num_ws + 1
|
| 45 |
+
self.ws.append(w)
|
| 46 |
+
return_ws.append(w)
|
| 47 |
+
else:
|
| 48 |
+
p = random.uniform(0, 1)
|
| 49 |
+
if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer
|
| 50 |
+
random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
|
| 51 |
+
tmp = self.ws[random_id].clone()
|
| 52 |
+
self.ws[random_id] = w
|
| 53 |
+
return_ws.append(tmp)
|
| 54 |
+
else: # by another 50% chance, the buffer will return the current image
|
| 55 |
+
return_ws.append(w)
|
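
Usage mirrors the image-pool trick from image-to-image GANs: before each discriminator update, the freshly produced w's are routed through the pool, which returns a mix of current and previously stored codes (the 50/50 behaviour documented in `query`). A minimal sketch:

import torch
from models.latent_codes_pool import LatentCodesPool

pool = LatentCodesPool(pool_size=50)

fake_w = torch.randn(8, 512)       # w codes produced by the encoder this step
w_for_disc = pool.query(fake_w)    # same shape, possibly with older codes mixed in
print(w_for_disc.shape)            # torch.Size([8, 512])
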
encoder4editing/models/psp.py
ADDED
|
@@ -0,0 +1,100 @@
| 1 |
+
import matplotlib
|
| 2 |
+
|
| 3 |
+
matplotlib.use('Agg')
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from models.encoders import psp_encoders
|
| 7 |
+
from models.stylegan2.model import Generator
|
| 8 |
+
from configs.paths_config import model_paths
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_keys(d, name):
|
| 12 |
+
if 'state_dict' in d:
|
| 13 |
+
d = d['state_dict']
|
| 14 |
+
d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
|
| 15 |
+
return d_filt
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class pSp(nn.Module):
|
| 19 |
+
|
| 20 |
+
def __init__(self, opts):
|
| 21 |
+
super(pSp, self).__init__()
|
| 22 |
+
self.opts = opts
|
| 23 |
+
# Define architecture
|
| 24 |
+
self.encoder = self.set_encoder()
|
| 25 |
+
self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
|
| 26 |
+
self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
| 27 |
+
# Load weights if needed
|
| 28 |
+
self.load_weights()
|
| 29 |
+
|
| 30 |
+
def set_encoder(self):
|
| 31 |
+
if self.opts.encoder_type == 'GradualStyleEncoder':
|
| 32 |
+
encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
|
| 33 |
+
elif self.opts.encoder_type == 'Encoder4Editing':
|
| 34 |
+
encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts)
|
| 35 |
+
elif self.opts.encoder_type == 'SingleStyleCodeEncoder':
|
| 36 |
+
encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
|
| 37 |
+
else:
|
| 38 |
+
raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
|
| 39 |
+
return encoder
|
| 40 |
+
|
| 41 |
+
def load_weights(self):
|
| 42 |
+
if self.opts.checkpoint_path is not None:
|
| 43 |
+
print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
|
| 44 |
+
ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
|
| 45 |
+
self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
|
| 46 |
+
self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
|
| 47 |
+
self.__load_latent_avg(ckpt)
|
| 48 |
+
else:
|
| 49 |
+
print('Loading encoders weights from irse50!')
|
| 50 |
+
encoder_ckpt = torch.load(model_paths['ir_se50'])
|
| 51 |
+
self.encoder.load_state_dict(encoder_ckpt, strict=False)
|
| 52 |
+
print('Loading decoder weights from pretrained!')
|
| 53 |
+
ckpt = torch.load(self.opts.stylegan_weights)
|
| 54 |
+
self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
|
| 55 |
+
self.__load_latent_avg(ckpt, repeat=self.encoder.style_count)
|
| 56 |
+
|
| 57 |
+
def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
|
| 58 |
+
inject_latent=None, return_latents=False, alpha=None):
|
| 59 |
+
if input_code:
|
| 60 |
+
codes = x
|
| 61 |
+
else:
|
| 62 |
+
codes = self.encoder(x)
|
| 63 |
+
# normalize with respect to the center of an average face
|
| 64 |
+
if self.opts.start_from_latent_avg:
|
| 65 |
+
if codes.ndim == 2:
|
| 66 |
+
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
|
| 67 |
+
else:
|
| 68 |
+
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
|
| 69 |
+
|
| 70 |
+
if latent_mask is not None:
|
| 71 |
+
for i in latent_mask:
|
| 72 |
+
if inject_latent is not None:
|
| 73 |
+
if alpha is not None:
|
| 74 |
+
codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
|
| 75 |
+
else:
|
| 76 |
+
codes[:, i] = inject_latent[:, i]
|
| 77 |
+
else:
|
| 78 |
+
codes[:, i] = 0
|
| 79 |
+
|
| 80 |
+
input_is_latent = not input_code
|
| 81 |
+
images, result_latent = self.decoder([codes],
|
| 82 |
+
input_is_latent=input_is_latent,
|
| 83 |
+
randomize_noise=randomize_noise,
|
| 84 |
+
return_latents=return_latents)
|
| 85 |
+
|
| 86 |
+
if resize:
|
| 87 |
+
images = self.face_pool(images)
|
| 88 |
+
|
| 89 |
+
if return_latents:
|
| 90 |
+
return images, result_latent
|
| 91 |
+
else:
|
| 92 |
+
return images
|
| 93 |
+
|
| 94 |
+
def __load_latent_avg(self, ckpt, repeat=None):
|
| 95 |
+
if 'latent_avg' in ckpt:
|
| 96 |
+
self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
|
| 97 |
+
if repeat is not None:
|
| 98 |
+
self.latent_avg = self.latent_avg.repeat(repeat, 1)
|
| 99 |
+
else:
|
| 100 |
+
self.latent_avg = None
|
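
Besides plain inversion, `forward` exposes a simple style-mixing hook: `latent_mask` lists the style indices to overwrite, `inject_latent` supplies the donor code, and `alpha` blends donor and inversion. A hedged sketch of calling it that way (`net` is assumed to be a loaded `pSp` on GPU, `img` a preprocessed image batch, and `donor_latent` a matching (N, 18, 512) code):

import torch

def mix_coarse_styles(net, img, donor_latent, alpha=0.5):
    # Overwrite the coarse layers 0-3 of the inversion with a blend of the donor code.
    with torch.no_grad():
        mixed, _ = net(img,
                       latent_mask=[0, 1, 2, 3],
                       inject_latent=donor_latent,
                       alpha=alpha,
                       resize=True,
                       return_latents=True)
    return mixed
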
encoder4editing/models/stylegan2/__init__.py
ADDED
File without changes
encoder4editing/models/stylegan2/model.py
ADDED
|
@@ -0,0 +1,673 @@
+import math
+import random
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
+
+
+class PixelNorm(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input):
+        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
+
+
+def make_kernel(k):
+    k = torch.tensor(k, dtype=torch.float32)
+
+    if k.ndim == 1:
+        k = k[None, :] * k[:, None]
+
+    k /= k.sum()
+
+    return k
+
+
+class Upsample(nn.Module):
+    def __init__(self, kernel, factor=2):
+        super().__init__()
+
+        self.factor = factor
+        kernel = make_kernel(kernel) * (factor ** 2)
+        self.register_buffer('kernel', kernel)
+
+        p = kernel.shape[0] - factor
+
+        pad0 = (p + 1) // 2 + factor - 1
+        pad1 = p // 2
+
+        self.pad = (pad0, pad1)
+
+    def forward(self, input):
+        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+
+        return out
+
+
+class Downsample(nn.Module):
+    def __init__(self, kernel, factor=2):
+        super().__init__()
+
+        self.factor = factor
+        kernel = make_kernel(kernel)
+        self.register_buffer('kernel', kernel)
+
+        p = kernel.shape[0] - factor
+
+        pad0 = (p + 1) // 2
+        pad1 = p // 2
+
+        self.pad = (pad0, pad1)
+
+    def forward(self, input):
+        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+
+        return out
+
+
+class Blur(nn.Module):
+    def __init__(self, kernel, pad, upsample_factor=1):
+        super().__init__()
+
+        kernel = make_kernel(kernel)
+
+        if upsample_factor > 1:
+            kernel = kernel * (upsample_factor ** 2)
+
+        self.register_buffer('kernel', kernel)
+
+        self.pad = pad
+
+    def forward(self, input):
+        out = upfirdn2d(input, self.kernel, pad=self.pad)
+
+        return out
+
+
+class EqualConv2d(nn.Module):
+    def __init__(
+        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+    ):
+        super().__init__()
+
+        self.weight = nn.Parameter(
+            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+        )
+        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+
+        self.stride = stride
+        self.padding = padding
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_channel))
+
+        else:
+            self.bias = None
+
+    def forward(self, input):
+        out = F.conv2d(
+            input,
+            self.weight * self.scale,
+            bias=self.bias,
+            stride=self.stride,
+            padding=self.padding,
+        )
+
+        return out
+
+    def __repr__(self):
+        return (
+            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
+            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+        )
+
+
+class EqualLinear(nn.Module):
+    def __init__(
+        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+    ):
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+
+        else:
+            self.bias = None
+
+        self.activation = activation
+
+        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+        self.lr_mul = lr_mul
+
+    def forward(self, input):
+        if self.activation:
+            out = F.linear(input, self.weight * self.scale)
+            out = fused_leaky_relu(out, self.bias * self.lr_mul)
+
+        else:
+            out = F.linear(
+                input, self.weight * self.scale, bias=self.bias * self.lr_mul
+            )
+
+        return out
+
+    def __repr__(self):
+        return (
+            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
+        )
+
+
+class ScaledLeakyReLU(nn.Module):
+    def __init__(self, negative_slope=0.2):
+        super().__init__()
+
+        self.negative_slope = negative_slope
+
+    def forward(self, input):
+        out = F.leaky_relu(input, negative_slope=self.negative_slope)
+
+        return out * math.sqrt(2)
+
+
+class ModulatedConv2d(nn.Module):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        style_dim,
+        demodulate=True,
+        upsample=False,
+        downsample=False,
+        blur_kernel=[1, 3, 3, 1],
+    ):
+        super().__init__()
+
+        self.eps = 1e-8
+        self.kernel_size = kernel_size
+        self.in_channel = in_channel
+        self.out_channel = out_channel
+        self.upsample = upsample
+        self.downsample = downsample
+
+        if upsample:
+            factor = 2
+            p = (len(blur_kernel) - factor) - (kernel_size - 1)
+            pad0 = (p + 1) // 2 + factor - 1
+            pad1 = p // 2 + 1
+
+            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+
+        if downsample:
+            factor = 2
+            p = (len(blur_kernel) - factor) + (kernel_size - 1)
+            pad0 = (p + 1) // 2
+            pad1 = p // 2
+
+            self.blur = Blur(blur_kernel, pad=(pad0, pad1))
+
+        fan_in = in_channel * kernel_size ** 2
+        self.scale = 1 / math.sqrt(fan_in)
+        self.padding = kernel_size // 2
+
+        self.weight = nn.Parameter(
+            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+        )
+
+        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+        self.demodulate = demodulate
+
+    def __repr__(self):
+        return (
+            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
+            f'upsample={self.upsample}, downsample={self.downsample})'
+        )
+
+    def forward(self, input, style):
+        batch, in_channel, height, width = input.shape
+
+        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+        weight = self.scale * self.weight * style
+
+        if self.demodulate:
+            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
+
+        weight = weight.view(
+            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+        )
+
+        if self.upsample:
+            input = input.view(1, batch * in_channel, height, width)
+            weight = weight.view(
+                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+            )
+            weight = weight.transpose(1, 2).reshape(
+                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+            )
+            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
+            _, _, height, width = out.shape
+            out = out.view(batch, self.out_channel, height, width)
+            out = self.blur(out)
+
+        elif self.downsample:
+            input = self.blur(input)
+            _, _, height, width = input.shape
+            input = input.view(1, batch * in_channel, height, width)
+            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+            _, _, height, width = out.shape
+            out = out.view(batch, self.out_channel, height, width)
+
+        else:
+            input = input.view(1, batch * in_channel, height, width)
+            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+            _, _, height, width = out.shape
+            out = out.view(batch, self.out_channel, height, width)
+
+        return out
+
+
+class NoiseInjection(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.zeros(1))
+
+    def forward(self, image, noise=None):
+        if noise is None:
+            batch, _, height, width = image.shape
+            noise = image.new_empty(batch, 1, height, width).normal_()
+
+        return image + self.weight * noise
+
+
+class ConstantInput(nn.Module):
+    def __init__(self, channel, size=4):
+        super().__init__()
+
+        self.input = nn.Parameter(torch.randn(1, channel, size, size))
+
+    def forward(self, input):
+        batch = input.shape[0]
+        out = self.input.repeat(batch, 1, 1, 1)
+
+        return out
+
+
+class StyledConv(nn.Module):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        style_dim,
+        upsample=False,
+        blur_kernel=[1, 3, 3, 1],
+        demodulate=True,
+    ):
+        super().__init__()
+
+        self.conv = ModulatedConv2d(
+            in_channel,
+            out_channel,
+            kernel_size,
+            style_dim,
+            upsample=upsample,
+            blur_kernel=blur_kernel,
+            demodulate=demodulate,
+        )
+
+        self.noise = NoiseInjection()
+        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+        # self.activate = ScaledLeakyReLU(0.2)
+        self.activate = FusedLeakyReLU(out_channel)
+
+    def forward(self, input, style, noise=None):
+        out = self.conv(input, style)
+        out = self.noise(out, noise=noise)
+        # out = out + self.bias
+        out = self.activate(out)
+
+        return out
+
+
+class ToRGB(nn.Module):
+    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+        super().__init__()
+
+        if upsample:
+            self.upsample = Upsample(blur_kernel)
+
+        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
+        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+    def forward(self, input, style, skip=None):
+        out = self.conv(input, style)
+        out = out + self.bias
+
+        if skip is not None:
+            skip = self.upsample(skip)
+
+            out = out + skip
+
+        return out
+
+
+class Generator(nn.Module):
+    def __init__(
+        self,
+        size,
+        style_dim,
+        n_mlp,
+        channel_multiplier=2,
+        blur_kernel=[1, 3, 3, 1],
+        lr_mlp=0.01,
+    ):
+        super().__init__()
+
+        self.size = size
+
+        self.style_dim = style_dim
+
+        layers = [PixelNorm()]
+
+        for i in range(n_mlp):
+            layers.append(
+                EqualLinear(
+                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
+                )
+            )
+
+        self.style = nn.Sequential(*layers)
+
+        self.channels = {
+            4: 512,
+            8: 512,
+            16: 512,
+            32: 512,
+            64: 256 * channel_multiplier,
+            128: 128 * channel_multiplier,
+            256: 64 * channel_multiplier,
+            512: 32 * channel_multiplier,
+            1024: 16 * channel_multiplier,
+        }
+
+        self.input = ConstantInput(self.channels[4])
+        self.conv1 = StyledConv(
+            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
+        )
+        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
+
+        self.log_size = int(math.log(size, 2))
+        self.num_layers = (self.log_size - 2) * 2 + 1
+
+        self.convs = nn.ModuleList()
+        self.upsamples = nn.ModuleList()
+        self.to_rgbs = nn.ModuleList()
+        self.noises = nn.Module()
+
+        in_channel = self.channels[4]
+
+        for layer_idx in range(self.num_layers):
+            res = (layer_idx + 5) // 2
+            shape = [1, 1, 2 ** res, 2 ** res]
+            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
+
+        for i in range(3, self.log_size + 1):
+            out_channel = self.channels[2 ** i]
+
+            self.convs.append(
+                StyledConv(
+                    in_channel,
+                    out_channel,
+                    3,
+                    style_dim,
+                    upsample=True,
+                    blur_kernel=blur_kernel,
+                )
+            )
+
+            self.convs.append(
+                StyledConv(
+                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+                )
+            )
+
+            self.to_rgbs.append(ToRGB(out_channel, style_dim))
+
+            in_channel = out_channel
+
+        self.n_latent = self.log_size * 2 - 2
+
+    def make_noise(self):
+        device = self.input.input.device
+
+        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
+
+        for i in range(3, self.log_size + 1):
+            for _ in range(2):
+                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
+
+        return noises
+
+    def mean_latent(self, n_latent):
+        latent_in = torch.randn(
+            n_latent, self.style_dim, device=self.input.input.device
+        )
+        latent = self.style(latent_in).mean(0, keepdim=True)
+
+        return latent
+
+    def get_latent(self, input):
+        return self.style(input)
+
+    def forward(
+        self,
+        styles,
+        return_latents=False,
+        return_features=False,
+        inject_index=None,
+        truncation=1,
+        truncation_latent=None,
+        input_is_latent=False,
+        noise=None,
+        randomize_noise=True,
+    ):
+        if not input_is_latent:
+            styles = [self.style(s) for s in styles]
+
+        if noise is None:
+            if randomize_noise:
+                noise = [None] * self.num_layers
+            else:
+                noise = [
+                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
+                ]
+
+        if truncation < 1:
+            style_t = []
+
+            for style in styles:
+                style_t.append(
+                    truncation_latent + truncation * (style - truncation_latent)
+                )
+
+            styles = style_t
+
+        if len(styles) < 2:
+            inject_index = self.n_latent
+
+            if styles[0].ndim < 3:
+                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            else:
+                latent = styles[0]
+
+        else:
+            if inject_index is None:
+                inject_index = random.randint(1, self.n_latent - 1)
+
+            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
+
+            latent = torch.cat([latent, latent2], 1)
+
+        out = self.input(latent)
+        out = self.conv1(out, latent[:, 0], noise=noise[0])
+
+        skip = self.to_rgb1(out, latent[:, 1])
+
+        i = 1
+        for conv1, conv2, noise1, noise2, to_rgb in zip(
+            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+        ):
+            out = conv1(out, latent[:, i], noise=noise1)
+            out = conv2(out, latent[:, i + 1], noise=noise2)
+            skip = to_rgb(out, latent[:, i + 2], skip)
+
+            i += 2
+
+        image = skip
+
+        if return_latents:
+            return image, latent
+        elif return_features:
+            return image, out
+        else:
+            return image, None
+
+
+class ConvLayer(nn.Sequential):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        downsample=False,
+        blur_kernel=[1, 3, 3, 1],
+        bias=True,
+        activate=True,
+    ):
+        layers = []
+
+        if downsample:
+            factor = 2
+            p = (len(blur_kernel) - factor) + (kernel_size - 1)
+            pad0 = (p + 1) // 2
+            pad1 = p // 2
+
+            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+            stride = 2
+            self.padding = 0
+
+        else:
+            stride = 1
+            self.padding = kernel_size // 2
+
+        layers.append(
+            EqualConv2d(
+                in_channel,
+                out_channel,
+                kernel_size,
+                padding=self.padding,
+                stride=stride,
+                bias=bias and not activate,
+            )
+        )
+
+        if activate:
+            if bias:
+                layers.append(FusedLeakyReLU(out_channel))
+
+            else:
+                layers.append(ScaledLeakyReLU(0.2))
+
+        super().__init__(*layers)
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+        super().__init__()
+
+        self.conv1 = ConvLayer(in_channel, in_channel, 3)
+        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
+
+        self.skip = ConvLayer(
+            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
+        )
+
+    def forward(self, input):
+        out = self.conv1(input)
+        out = self.conv2(out)
+
+        skip = self.skip(input)
+        out = (out + skip) / math.sqrt(2)
+
+        return out
+
+
+class Discriminator(nn.Module):
+    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+        super().__init__()
+
+        channels = {
+            4: 512,
+            8: 512,
+            16: 512,
+            32: 512,
+            64: 256 * channel_multiplier,
+            128: 128 * channel_multiplier,
+            256: 64 * channel_multiplier,
+            512: 32 * channel_multiplier,
+            1024: 16 * channel_multiplier,
+        }
+
+        convs = [ConvLayer(3, channels[size], 1)]
+
+        log_size = int(math.log(size, 2))
+
+        in_channel = channels[size]
+
+        for i in range(log_size, 2, -1):
+            out_channel = channels[2 ** (i - 1)]
+
+            convs.append(ResBlock(in_channel, out_channel, blur_kernel))
+
+            in_channel = out_channel
+
+        self.convs = nn.Sequential(*convs)
+
+        self.stddev_group = 4
+        self.stddev_feat = 1
+
+        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+        self.final_linear = nn.Sequential(
+            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
+            EqualLinear(channels[4], 1),
+        )
+
+    def forward(self, input):
+        out = self.convs(input)
+
+        batch, channel, height, width = out.shape
+        group = min(batch, self.stddev_group)
+        stddev = out.view(
+            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
+        )
+        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+        stddev = stddev.repeat(group, 1, height, width)
+        out = torch.cat([out, stddev], 1)
+
+        out = self.final_conv(out)
+
+        out = out.view(batch, -1)
+        out = self.final_linear(out)
+
+        return out
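As a rough usage sketch (not part of this commit), the Generator above can be driven from random z codes. Importing models.stylegan2.op JIT-compiles the CUDA kernels added below, so this assumes a GPU environment and the repo's own import path:

import torch
from models.stylegan2.model import Generator

# 1024px FFHQ-sized generator: 512-dim styles and an 8-layer mapping MLP
g = Generator(size=1024, style_dim=512, n_mlp=8).cuda().eval()

z = torch.randn(1, 512, device='cuda')        # one latent code in Z space
with torch.no_grad():
    image, _ = g([z], input_is_latent=False)  # forward() returns (image, latents or None)
print(image.shape)                            # torch.Size([1, 3, 1024, 1024])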
encoder4editing/models/stylegan2/op/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+from .upfirdn2d import upfirdn2d
encoder4editing/models/stylegan2/op/fused_act.py
ADDED
@@ -0,0 +1,85 @@
+import os
+
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+
+module_path = os.path.dirname(__file__)
+fused = load(
+    'fused',
+    sources=[
+        os.path.join(module_path, 'fused_bias_act.cpp'),
+        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+    ],
+)
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+    @staticmethod
+    def forward(ctx, grad_output, out, negative_slope, scale):
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        empty = grad_output.new_empty(0)
+
+        grad_input = fused.fused_bias_act(
+            grad_output, empty, out, 3, 1, negative_slope, scale
+        )
+
+        dim = [0]
+
+        if grad_input.ndim > 2:
+            dim += list(range(2, grad_input.ndim))
+
+        grad_bias = grad_input.sum(dim).detach()
+
+        return grad_input, grad_bias
+
+    @staticmethod
+    def backward(ctx, gradgrad_input, gradgrad_bias):
+        out, = ctx.saved_tensors
+        gradgrad_out = fused.fused_bias_act(
+            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+        )
+
+        return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+    @staticmethod
+    def forward(ctx, input, bias, negative_slope, scale):
+        empty = input.new_empty(0)
+        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        out, = ctx.saved_tensors
+
+        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
+            grad_output, out, ctx.negative_slope, ctx.scale
+        )
+
+        return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+        super().__init__()
+
+        self.bias = nn.Parameter(torch.zeros(channel))
+        self.negative_slope = negative_slope
+        self.scale = scale
+
+    def forward(self, input):
+        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
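For intuition, fused_leaky_relu above fuses three steps (add a per-channel bias, apply a leaky ReLU, rescale by sqrt(2)) into a single CUDA kernel. A plain-PyTorch equivalent for a 4D activation map, given only as a reference and not used by the code path in this repo:

import torch
from torch.nn import functional as F

def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # bias is broadcast over (N, C, H, W); the final scale keeps activation variance stable
    return F.leaky_relu(x + bias.view(1, -1, 1, 1), negative_slope=negative_slope) * scale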
encoder4editing/models/stylegan2/op/fused_bias_act.cpp
ADDED
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale) {
+    CHECK_CUDA(input);
+    CHECK_CUDA(bias);
+
+    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
encoder4editing/models/stylegan2/op/fused_bias_act_kernel.cu
ADDED
@@ -0,0 +1,99 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+    scalar_t zero = 0.0;
+
+    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+        scalar_t x = p_x[xi];
+
+        if (use_bias) {
+            x += p_b[(xi / step_b) % size_b];
+        }
+
+        scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+        scalar_t y;
+
+        switch (act * 10 + grad) {
+            default:
+            case 10: y = x; break;
+            case 11: y = x; break;
+            case 12: y = 0.0; break;
+
+            case 30: y = (x > 0.0) ? x : x * alpha; break;
+            case 31: y = (ref > 0.0) ? x : x * alpha; break;
+            case 32: y = 0.0; break;
+        }
+
+        out[xi] = y * scale;
+    }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale) {
+    int curDevice = -1;
+    cudaGetDevice(&curDevice);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+    auto x = input.contiguous();
+    auto b = bias.contiguous();
+    auto ref = refer.contiguous();
+
+    int use_bias = b.numel() ? 1 : 0;
+    int use_ref = ref.numel() ? 1 : 0;
+
+    int size_x = x.numel();
+    int size_b = b.numel();
+    int step_b = 1;
+
+    for (int i = 1 + 1; i < x.dim(); i++) {
+        step_b *= x.size(i);
+    }
+
+    int loop_x = 4;
+    int block_size = 4 * 32;
+    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+    auto y = torch::empty_like(x);
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+            y.data_ptr<scalar_t>(),
+            x.data_ptr<scalar_t>(),
+            b.data_ptr<scalar_t>(),
+            ref.data_ptr<scalar_t>(),
+            act,
+            grad,
+            alpha,
+            scale,
+            loop_x,
+            size_x,
+            step_b,
+            size_b,
+            use_bias,
+            use_ref
+        );
+    });
+
+    return y;
+}
encoder4editing/models/stylegan2/op/upfirdn2d.cpp
ADDED
@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+    int up_x, int up_y, int down_x, int down_y,
+    int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+    int up_x, int up_y, int down_x, int down_y,
+    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+    CHECK_CUDA(input);
+    CHECK_CUDA(kernel);
+
+    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}