jmanhype
commited on
Commit
·
06e9d12
0
Parent(s):
Initial commit without binary files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +65 -0
- .gitmodules +11 -0
- CHANGES +5 -0
- Dockerfile +18 -0
- LICENSE +175 -0
- MMCM +1 -0
- README-zh.md +465 -0
- README.md +37 -0
- configs/model/T2I_all_model.py +15 -0
- configs/model/ip_adapter.py +66 -0
- configs/model/lcm_model.py +17 -0
- configs/model/motion_model.py +22 -0
- configs/model/negative_prompt.py +32 -0
- configs/model/referencenet.py +14 -0
- configs/tasks/example.yaml +210 -0
- controlnet_aux +1 -0
- data/models/musev_structure.png +0 -0
- data/models/parallel_denoise.png +0 -0
- diffusers +1 -0
- environment.yml +312 -0
- musev/__init__.py +9 -0
- musev/auto_prompt/__init__.py +0 -0
- musev/auto_prompt/attributes/__init__.py +8 -0
- musev/auto_prompt/attributes/attr2template.py +127 -0
- musev/auto_prompt/attributes/attributes.py +227 -0
- musev/auto_prompt/attributes/human.py +424 -0
- musev/auto_prompt/attributes/render.py +33 -0
- musev/auto_prompt/attributes/style.py +12 -0
- musev/auto_prompt/human.py +40 -0
- musev/auto_prompt/load_template.py +37 -0
- musev/auto_prompt/util.py +25 -0
- musev/data/__init__.py +0 -0
- musev/data/data_util.py +681 -0
- musev/logging.conf +32 -0
- musev/models/__init__.py +3 -0
- musev/models/attention.py +431 -0
- musev/models/attention_processor.py +750 -0
- musev/models/controlnet.py +399 -0
- musev/models/embeddings.py +87 -0
- musev/models/facein_loader.py +120 -0
- musev/models/ip_adapter_face_loader.py +179 -0
- musev/models/ip_adapter_loader.py +340 -0
- musev/models/referencenet.py +1216 -0
- musev/models/referencenet_loader.py +124 -0
- musev/models/resnet.py +135 -0
- musev/models/super_model.py +253 -0
- musev/models/temporal_transformer.py +308 -0
- musev/models/text_model.py +40 -0
- musev/models/transformer_2d.py +445 -0
- musev/models/unet_2d_blocks.py +1537 -0
.gitignore
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
*.so
|
6 |
+
.Python
|
7 |
+
build/
|
8 |
+
develop-eggs/
|
9 |
+
dist/
|
10 |
+
downloads/
|
11 |
+
eggs/
|
12 |
+
.eggs/
|
13 |
+
lib/
|
14 |
+
lib64/
|
15 |
+
parts/
|
16 |
+
sdist/
|
17 |
+
var/
|
18 |
+
wheels/
|
19 |
+
*.egg-info/
|
20 |
+
.installed.cfg
|
21 |
+
*.egg
|
22 |
+
|
23 |
+
# Checkpoints
|
24 |
+
checkpoints/
|
25 |
+
|
26 |
+
# Logs
|
27 |
+
*.log
|
28 |
+
logs/
|
29 |
+
tensorboard/
|
30 |
+
|
31 |
+
# Environment
|
32 |
+
.env
|
33 |
+
.venv
|
34 |
+
env/
|
35 |
+
venv/
|
36 |
+
ENV/
|
37 |
+
|
38 |
+
# IDE
|
39 |
+
.idea/
|
40 |
+
.vscode/
|
41 |
+
*.swp
|
42 |
+
*.swo
|
43 |
+
|
44 |
+
# OS
|
45 |
+
.DS_Store
|
46 |
+
Thumbs.db
|
47 |
+
|
48 |
+
# Large files
|
49 |
+
data/result_video/
|
50 |
+
*.mp4
|
51 |
+
*.png
|
52 |
+
*.jpg
|
53 |
+
*.jpeg
|
54 |
+
*.gif
|
55 |
+
*.webp
|
56 |
+
*.avi
|
57 |
+
*.mov
|
58 |
+
*.mkv
|
59 |
+
*.flv
|
60 |
+
*.wmv
|
61 |
+
|
62 |
+
# Demo and source files
|
63 |
+
data/demo/
|
64 |
+
data/source_video/
|
65 |
+
data/images/
|
.gitmodules
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "MMCM"]
|
2 |
+
path = MMCM
|
3 |
+
url = https://github.com/TMElyralab/MMCM.git
|
4 |
+
[submodule "controlnet_aux"]
|
5 |
+
path = controlnet_aux
|
6 |
+
url = https://github.com/TMElyralab/controlnet_aux.git
|
7 |
+
branch = tme
|
8 |
+
[submodule "diffusers"]
|
9 |
+
path = diffusers
|
10 |
+
url = https://github.com/TMElyralab/diffusers.git
|
11 |
+
branch = tme
|
CHANGES
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Version 1.0.0 (2024.03.27)
|
2 |
+
|
3 |
+
* init musev, support video generation with text and image
|
4 |
+
* controlnet_aux: enrich interface and function of dwpose.
|
5 |
+
* diffusers: controlnet support latent instead of images only.
|
Dockerfile
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM anchorxia/musev:1.0.0
|
2 |
+
|
3 |
+
#MAINTAINER 维护者信息
|
4 |
+
LABEL MAINTAINER="anchorxia"
|
5 |
+
LABEL Email="[email protected]"
|
6 |
+
LABEL Description="musev gpu runtime image, base docker is pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel"
|
7 |
+
ARG DEBIAN_FRONTEND=noninteractive
|
8 |
+
|
9 |
+
USER root
|
10 |
+
|
11 |
+
SHELL ["/bin/bash", "--login", "-c"]
|
12 |
+
|
13 |
+
RUN . /opt/conda/etc/profile.d/conda.sh \
|
14 |
+
&& echo "source activate musev" >> ~/.bashrc \
|
15 |
+
&& conda activate musev \
|
16 |
+
&& conda env list \
|
17 |
+
&& pip --no-cache-dir install cuid gradio==4.12 spaces
|
18 |
+
USER root
|
LICENSE
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
MIT License
|
3 |
+
|
4 |
+
Copyright (c) 2024 Tencent Music Entertainment Group
|
5 |
+
|
6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
+
of this software and associated documentation files (the "Software"), to deal
|
8 |
+
in the Software without restriction, including without limitation the rights
|
9 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
+
copies of the Software, and to permit persons to whom the Software is
|
11 |
+
furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
The above copyright notice and this permission notice shall be included in all
|
14 |
+
copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
SOFTWARE.
|
23 |
+
|
24 |
+
|
25 |
+
Other dependencies and licenses:
|
26 |
+
|
27 |
+
|
28 |
+
Open Source Software Licensed under the MIT License:
|
29 |
+
--------------------------------------------------------------------
|
30 |
+
1. BriVL-BUA-applications
|
31 |
+
Files:https://github.com/chuhaojin/BriVL-BUA-applications
|
32 |
+
License:MIT License
|
33 |
+
Copyright (c) 2021 chuhaojin
|
34 |
+
For details:https://github.com/chuhaojin/BriVL-BUA-applications/blob/master/LICENSE
|
35 |
+
|
36 |
+
2.deep-person-reid
|
37 |
+
Files:https://github.com/KaiyangZhou/deep-person-reid
|
38 |
+
License:MIT License
|
39 |
+
Copyright (c) 2018 Kaiyang Zhou
|
40 |
+
For details:https://github.com/KaiyangZhou/deep-person-reid/blob/master/LICENSE
|
41 |
+
|
42 |
+
|
43 |
+
Open Source Software Licensed under the Apache License Version 2.0:
|
44 |
+
--------------------------------------------------------------------
|
45 |
+
1. diffusers
|
46 |
+
Files:https://github.com/huggingface/diffusers
|
47 |
+
License:Apache License 2.0
|
48 |
+
Copyright 2024 The HuggingFace Team. All rights reserved.
|
49 |
+
For details:https://github.com/huggingface/diffusers/blob/main/LICENSE
|
50 |
+
https://github.com/huggingface/diffusers/blob/main/setup.py
|
51 |
+
|
52 |
+
|
53 |
+
2. controlnet_aux
|
54 |
+
Files: https://github.com/huggingface/controlnet_aux
|
55 |
+
License: Apache License 2.0
|
56 |
+
Copyright 2023 The HuggingFace Team. All rights reserved.
|
57 |
+
For details: https://github.com/huggingface/controlnet_aux/blob/master/LICENSE.txt
|
58 |
+
https://github.com/huggingface/controlnet_aux/blob/master/setup.py
|
59 |
+
|
60 |
+
3. decord
|
61 |
+
Files:https://github.com/dmlc/decord
|
62 |
+
License:Apache License 2.0
|
63 |
+
For details:https://github.com/dmlc/decord/blob/master/LICENSE
|
64 |
+
|
65 |
+
|
66 |
+
Terms of the Apache License Version 2.0:
|
67 |
+
--------------------------------------------------------------------
|
68 |
+
Apache License
|
69 |
+
|
70 |
+
Version 2.0, January 2004
|
71 |
+
|
72 |
+
http://www.apache.org/licenses/
|
73 |
+
|
74 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
75 |
+
1. Definitions.
|
76 |
+
|
77 |
+
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
78 |
+
|
79 |
+
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
80 |
+
|
81 |
+
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
82 |
+
|
83 |
+
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
84 |
+
|
85 |
+
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
86 |
+
|
87 |
+
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
88 |
+
|
89 |
+
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
90 |
+
|
91 |
+
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
92 |
+
|
93 |
+
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
94 |
+
|
95 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
96 |
+
|
97 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
98 |
+
|
99 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
100 |
+
|
101 |
+
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
102 |
+
|
103 |
+
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
104 |
+
|
105 |
+
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
106 |
+
|
107 |
+
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
108 |
+
|
109 |
+
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
110 |
+
|
111 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
112 |
+
|
113 |
+
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
114 |
+
|
115 |
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
116 |
+
|
117 |
+
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
118 |
+
|
119 |
+
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
120 |
+
|
121 |
+
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
122 |
+
|
123 |
+
END OF TERMS AND CONDITIONS
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
Open Source Software Licensed under the BSD 3-Clause License:
|
128 |
+
--------------------------------------------------------------------
|
129 |
+
1. pynvml
|
130 |
+
Files:https://github.com/gpuopenanalytics/pynvml/tree/master
|
131 |
+
License:BSD 3-Clause
|
132 |
+
Copyright (c) 2011-2021, NVIDIA Corporation.
|
133 |
+
All rights reserved.
|
134 |
+
For details:https://github.com/gpuopenanalytics/pynvml/blob/master/LICENSE.txt
|
135 |
+
|
136 |
+
|
137 |
+
Terms of the BSD 3-Clause License:
|
138 |
+
--------------------------------------------------------------------
|
139 |
+
Redistribution and use in source and binary forms, with or without
|
140 |
+
modification, are permitted provided that the following conditions are met:
|
141 |
+
|
142 |
+
* Redistributions of source code must retain the above copyright notice, this
|
143 |
+
list of conditions and the following disclaimer.
|
144 |
+
|
145 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
146 |
+
this list of conditions and the following disclaimer in the documentation
|
147 |
+
and/or other materials provided with the distribution.
|
148 |
+
|
149 |
+
* Neither the name of the copyright holder nor the names of its
|
150 |
+
contributors may be used to endorse or promote products derived from
|
151 |
+
this software without specific prior written permission.
|
152 |
+
|
153 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
154 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
155 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
156 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
157 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
158 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
159 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
160 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
161 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
162 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
163 |
+
|
164 |
+
|
165 |
+
Other Open Source Software:
|
166 |
+
--------------------------------------------------------------------
|
167 |
+
1.SceneSeg
|
168 |
+
Files:https://github.com/AnyiRao/SceneSeg/tree/master
|
169 |
+
|
170 |
+
|
171 |
+
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
|
MMCM
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 1e2b6e6a848f0f116e8acaf0621c2ee64d3642ce
|
README-zh.md
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MuseV [English](README.md) [中文](README-zh.md)
|
2 |
+
|
3 |
+
<font size=5>MuseV:基于视觉条件并行去噪的无限长度和高保真虚拟人视频生成。
|
4 |
+
</br>
|
5 |
+
Zhiqiang Xia <sup>\*</sup>,
|
6 |
+
Zhaokang Chen<sup>\*</sup>,
|
7 |
+
Bin Wu<sup>†</sup>,
|
8 |
+
Chao Li,
|
9 |
+
Kwok-Wai Hung,
|
10 |
+
Chao Zhan,
|
11 |
+
Yingjie He,
|
12 |
+
Wenjiang Zhou
|
13 |
+
(<sup>*</sup>co-first author, <sup>†</sup>Corresponding Author, [email protected])
|
14 |
+
</font>
|
15 |
+
|
16 |
+
**[github](https://github.com/TMElyralab/MuseV)** **[huggingface](https://huggingface.co/TMElyralab/MuseV)** **[HuggingfaceSpace](https://huggingface.co/spaces/AnchorFake/MuseVDemo)** **[project](https://tmelyralab.github.io/MuseV_Page/)** **Technical report (comming soon)**
|
17 |
+
|
18 |
+
|
19 |
+
我们在2023年3月相信扩散模型可以模拟世界,也开始基于扩散模型研发世界视觉模拟器。`MuseV`是在 2023 年 7 月左右实现的一个里程碑。受到 Sora 进展的启发,我们决定开源 MuseV。MuseV 站在开源的肩膀上成长,也希望能够借此反馈社区。接下来,我们将转向有前景的扩散+变换器方案。
|
20 |
+
|
21 |
+
我们已经发布 <a href="https://github.com/TMElyralab/MuseTalk" style="font-size:24px; color:red;">MuseTalk</a>. `MuseTalk`是一个实时高质量的唇同步模型,可与 `MuseV` 一起构建完整的`虚拟人生成解决方案`。请保持关注!
|
22 |
+
|
23 |
+
:new: 我们新发布了<a href="https://github.com/TMElyralab/MusePose" style="font-size:24px; color:red;">MusePose</a>。 MusePose是一个用于虚拟人物的图像到视频生成框架,它可以根据控制信号(姿态)生成视频。结合 MuseV 和 MuseTalk,我们希望社区能够加入我们,一起迈向一个愿景:能够端到端生成具有全身运动和交互能力的虚拟人物。
|
24 |
+
|
25 |
+
# 概述
|
26 |
+
|
27 |
+
`MuseV` 是基于扩散模型的虚拟人视频生成框架,具有以下特点:
|
28 |
+
|
29 |
+
1. 支持使用新颖的视觉条件并行去噪方案进行无限长度生成,不会再有误差累计的问题,尤其适用于固定相机位的场景。
|
30 |
+
1. 提供了基于人物类型数据集训练的虚拟人视频生成预训练模型。
|
31 |
+
1. 支持图像到视频、文本到图像到视频、视频到视频的生成。
|
32 |
+
1. 兼容 `Stable Diffusion` 文图生成生态系统,包括 `base_model`、`lora`、`controlnet` 等。
|
33 |
+
1. 支持多参考图像技术,包括 `IPAdapter`、`ReferenceOnly`、`ReferenceNet`、`IPAdapterFaceID`。
|
34 |
+
1. 我们后面也会推出训练代码。
|
35 |
+
|
36 |
+
# 重要更新
|
37 |
+
1. `musev_referencenet_pose`: `unet`, `ip_adapter` 的模型名字指定错误,请使用 `musev_referencenet_pose`而不是`musev_referencenet`,请使用最新的main分支。
|
38 |
+
|
39 |
+
# 进展
|
40 |
+
- [2024年3月27日] 发布 `MuseV` 项目和训练好的模型 `musev`、`muse_referencenet`、`muse_referencenet_pose`。
|
41 |
+
- [03/30/2024] 在 huggingface space 上新增 [gui](https://huggingface.co/spaces/AnchorFake/MuseVDemo) 交互方式来生成视频.
|
42 |
+
|
43 |
+
## 模型
|
44 |
+
### 模型结构示意图
|
45 |
+

|
46 |
+
### 并行去噪算法示意图
|
47 |
+

|
48 |
+
|
49 |
+
## 测试用例
|
50 |
+
生成结果的所有帧直接由`MuseV`生成,没有时序超分辨、空间超分辨等任何后处理。
|
51 |
+
更多测试结果请看[MuseVPage]()
|
52 |
+
|
53 |
+
<!-- # TODO: // use youtu video link? -->
|
54 |
+
以下所有测试用例都维护在 `configs/tasks/example.yaml`,可以直接运行复现。
|
55 |
+
**[project](https://tmelyralab.github.io/)** 有更多测试用例,包括直接生成的、一两分钟的长视频。
|
56 |
+
|
57 |
+
### 输入文本、图像的视频生成
|
58 |
+
#### 人类
|
59 |
+
<table class="center">
|
60 |
+
<tr style="font-weight: bolder;text-align:center;">
|
61 |
+
<td width="50%">image</td>
|
62 |
+
<td width="45%">video </td>
|
63 |
+
<td width="5%">prompt</td>
|
64 |
+
</tr>
|
65 |
+
|
66 |
+
<tr>
|
67 |
+
<td>
|
68 |
+
<img src=./data/images/yongen.jpeg width="400">
|
69 |
+
</td>
|
70 |
+
<td >
|
71 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/732cf1fd-25e7-494e-b462-969c9425d277" width="100" controls preload></video>
|
72 |
+
</td>
|
73 |
+
<td>(masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3)
|
74 |
+
</td>
|
75 |
+
</tr>
|
76 |
+
|
77 |
+
<tr>
|
78 |
+
<td>
|
79 |
+
<img src=./data/images/seaside4.jpeg width="400">
|
80 |
+
</td>
|
81 |
+
<td>
|
82 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/9b75a46c-f4e6-45ef-ad02-05729f091c8f" width="100" controls preload></video>
|
83 |
+
</td>
|
84 |
+
<td>
|
85 |
+
(masterpiece, best quality, highres:1), peaceful beautiful sea scene
|
86 |
+
</td>
|
87 |
+
</tr>
|
88 |
+
<tr>
|
89 |
+
<td>
|
90 |
+
<img src=./data/images/seaside_girl.jpeg width="400">
|
91 |
+
</td>
|
92 |
+
<td>
|
93 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/d0f3b401-09bf-4018-81c3-569ec24a4de9" width="100" controls preload></video>
|
94 |
+
</td>
|
95 |
+
<td>
|
96 |
+
(masterpiece, best quality, highres:1), peaceful beautiful sea scene
|
97 |
+
</td>
|
98 |
+
</tr>
|
99 |
+
<!-- guitar -->
|
100 |
+
<tr>
|
101 |
+
<td>
|
102 |
+
<img src=./data/images/boy_play_guitar.jpeg width="400">
|
103 |
+
</td>
|
104 |
+
<td>
|
105 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/61bf955e-7161-44c8-a498-8811c4f4eb4f" width="100" controls preload></video>
|
106 |
+
</td>
|
107 |
+
<td>
|
108 |
+
(masterpiece, best quality, highres:1), playing guitar
|
109 |
+
</td>
|
110 |
+
</tr>
|
111 |
+
<tr>
|
112 |
+
<td>
|
113 |
+
<img src=./data/images/girl_play_guitar2.jpeg width="400">
|
114 |
+
</td>
|
115 |
+
<td>
|
116 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/40982aa7-9f6a-4e44-8ef6-3f185d284e6a" width="100" controls preload></video>
|
117 |
+
</td>
|
118 |
+
<td>
|
119 |
+
(masterpiece, best quality, highres:1), playing guitar
|
120 |
+
</td>
|
121 |
+
</tr>
|
122 |
+
<!-- famous people -->
|
123 |
+
<tr>
|
124 |
+
<td>
|
125 |
+
<img src=./data/images/dufu.jpeg width="400">
|
126 |
+
</td>
|
127 |
+
<td>
|
128 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/28294baa-b996-420f-b1fb-046542adf87d" width="100" controls preload></video>
|
129 |
+
</td>
|
130 |
+
<td>
|
131 |
+
(masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3),Chinese ink painting style
|
132 |
+
</td>
|
133 |
+
</tr>
|
134 |
+
|
135 |
+
<tr>
|
136 |
+
<td>
|
137 |
+
<img src=./data/images/Mona_Lisa.jpg width="400">
|
138 |
+
</td>
|
139 |
+
<td>
|
140 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/1ce11da6-14c6-4dcd-b7f9-7a5f060d71fb" width="100" controls preload></video>
|
141 |
+
</td>
|
142 |
+
<td>
|
143 |
+
(masterpiece, best quality, highres:1),(1girl, solo:1),(beautiful face,
|
144 |
+
soft skin, costume:1),(eye blinks:{eye_blinks_factor}),(head wave:1.3)
|
145 |
+
</td>
|
146 |
+
</tr>
|
147 |
+
</table >
|
148 |
+
|
149 |
+
#### 场景
|
150 |
+
<table class="center">
|
151 |
+
<tr style="font-weight: bolder;text-align:center;">
|
152 |
+
<td width="35%">image</td>
|
153 |
+
<td width="50%">video</td>
|
154 |
+
<td width="15%">prompt</td>
|
155 |
+
</tr>
|
156 |
+
|
157 |
+
<tr>
|
158 |
+
<td>
|
159 |
+
<img src=./data/images/waterfall4.jpeg width="400">
|
160 |
+
</td>
|
161 |
+
<td>
|
162 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/852daeb6-6b58-4931-81f9-0dddfa1b4ea5" width="100" controls preload></video>
|
163 |
+
</td>
|
164 |
+
<td>
|
165 |
+
(masterpiece, best quality, highres:1), peaceful beautiful waterfall, an
|
166 |
+
endless waterfall
|
167 |
+
</td>
|
168 |
+
</tr>
|
169 |
+
|
170 |
+
<tr>
|
171 |
+
<td>
|
172 |
+
<img src=./data/images/seaside2.jpeg width="400">
|
173 |
+
</td>
|
174 |
+
<td>
|
175 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/4a4d527a-6203-411f-afe9-31c992d26816" width="100" controls preload></video>
|
176 |
+
</td>
|
177 |
+
<td>(masterpiece, best quality, highres:1), peaceful beautiful sea scene
|
178 |
+
</td>
|
179 |
+
</tr>
|
180 |
+
</table >
|
181 |
+
|
182 |
+
### 输入视频条件的视频生成
|
183 |
+
当前生成模式下,需要参考视频的首帧条件和参考图像的首帧条件对齐,不然会破坏首帧的信息,效果会更差。所以一般生成流程是
|
184 |
+
1. 确定参考视频;
|
185 |
+
2. 用参考视频的首帧走图生图、controlnet流程,可以使用`MJ`等各种平台;
|
186 |
+
3. 拿2中的生成图、参考视频用MuseV生成视频;
|
187 |
+
4.
|
188 |
+
**pose2video**
|
189 |
+
|
190 |
+
`duffy` 的测试用例中,视觉条件帧的姿势与控制视频的第一帧不对齐。需要`posealign` 将解决这个问题。
|
191 |
+
|
192 |
+
<table class="center">
|
193 |
+
<tr style="font-weight: bolder;text-align:center;">
|
194 |
+
<td width="25%">image</td>
|
195 |
+
<td width="65%">video</td>
|
196 |
+
<td width="10%">prompt</td>
|
197 |
+
</tr>
|
198 |
+
<tr>
|
199 |
+
<td>
|
200 |
+
<img src=./data/images/spark_girl.png width="200">
|
201 |
+
<img src=./data/images/cyber_girl.png width="200">
|
202 |
+
</td>
|
203 |
+
<td>
|
204 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/484cc69d-c316-4464-a55b-3df929780a8e" width="400" controls preload></video>
|
205 |
+
</td>
|
206 |
+
<td>
|
207 |
+
(masterpiece, best quality, highres:1) , a girl is dancing, animation
|
208 |
+
</td>
|
209 |
+
</tr>
|
210 |
+
<tr>
|
211 |
+
<td>
|
212 |
+
<img src=./data/images/duffy.png width="400">
|
213 |
+
</td>
|
214 |
+
<td>
|
215 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/c44682e6-aafc-4730-8fc1-72825c1bacf2" width="400" controls preload></video>
|
216 |
+
</td>
|
217 |
+
<td>
|
218 |
+
(masterpiece, best quality, highres:1), is dancing, animation
|
219 |
+
</td>
|
220 |
+
</tr>
|
221 |
+
</table >
|
222 |
+
|
223 |
+
### MuseTalk
|
224 |
+
|
225 |
+
`talk`的角色`孙昕荧`著名的网络大V,可以在 [抖音](https://www.douyin.com/user/MS4wLjABAAAAWDThbMPN_6Xmm_JgXexbOii1K-httbu2APdG8DvDyM8) 关注。
|
226 |
+
|
227 |
+
<table class="center">
|
228 |
+
<tr style="font-weight: bolder;">
|
229 |
+
<td width="35%">name</td>
|
230 |
+
<td width="50%">video</td>
|
231 |
+
</tr>
|
232 |
+
|
233 |
+
<tr>
|
234 |
+
<td>
|
235 |
+
talk
|
236 |
+
</td>
|
237 |
+
<td>
|
238 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/951188d1-4731-4e7f-bf40-03cacba17f2f" width="100" controls preload></video>
|
239 |
+
</td>
|
240 |
+
<tr>
|
241 |
+
<td>
|
242 |
+
sing
|
243 |
+
</td>
|
244 |
+
<td>
|
245 |
+
<video src="https://github.com/TMElyralab/MuseV/assets/163980830/50b8ffab-9307-4836-99e5-947e6ce7d112" width="100" controls preload></video>
|
246 |
+
</td>
|
247 |
+
</tr>
|
248 |
+
</table >
|
249 |
+
|
250 |
+
|
251 |
+
# 待办事项:
|
252 |
+
- [ ] 技术报告(即将推出)。
|
253 |
+
- [ ] 训练代码。
|
254 |
+
- [ ] 扩散变换生成框架。
|
255 |
+
- [ ] `posealign` 模块。
|
256 |
+
|
257 |
+
# 快速入门
|
258 |
+
准备 Python 环境并安装额外的包,如 `diffusers`、`controlnet_aux`、`mmcm`。
|
259 |
+
|
260 |
+
## 第三方整合版
|
261 |
+
一些第三方的整合,方便大家安装、使用,感谢第三方的工作。
|
262 |
+
同时也希望注意,我们没有对第���方的支持做验证、维护和后续更新,具体效果请以本项目为准。
|
263 |
+
### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseV)
|
264 |
+
### [windows整合包](https://www.bilibili.com/video/BV1ux4y1v7pF/?vd_source=fe03b064abab17b79e22a692551405c3)
|
265 |
+
netdisk:https://www.123pan.com/s/Pf5Yjv-Bb9W3.html
|
266 |
+
code: glut
|
267 |
+
|
268 |
+
## 准备环境
|
269 |
+
建议您优先使用 `docker` 来准备 Python 环境。
|
270 |
+
|
271 |
+
### 准备 Python 环境
|
272 |
+
**注意**:我们只测试了 Docker,使用 conda 或其他环境可能会遇到问题。我们将尽力解决。但依然请优先使用 `docker`。
|
273 |
+
|
274 |
+
#### 方法 1:使用 Docker
|
275 |
+
1. 拉取 Docker 镜像
|
276 |
+
```bash
|
277 |
+
docker pull anchorxia/musev:latest
|
278 |
+
```
|
279 |
+
2. 运行 Docker 容器
|
280 |
+
```bash
|
281 |
+
docker run --gpus all -it --entrypoint /bin/bash anchorxia/musev:latest
|
282 |
+
```
|
283 |
+
docker启动后默认的 conda 环境是 `musev`。
|
284 |
+
|
285 |
+
#### 方法 2:使用 conda
|
286 |
+
从 environment.yaml 创建 conda 环境
|
287 |
+
```
|
288 |
+
conda env create --name musev --file ./environment.yml
|
289 |
+
```
|
290 |
+
#### 方法 3:使用 pip requirements
|
291 |
+
```
|
292 |
+
pip install -r requirements.txt
|
293 |
+
```
|
294 |
+
#### 准备 [openmmlab](https://openmmlab.com/) 包
|
295 |
+
如果不使用 Docker方式,还需要额外安装 mmlab 包。
|
296 |
+
```bash
|
297 |
+
pip install --no-cache-dir -U openmim
|
298 |
+
mim install mmengine
|
299 |
+
mim install "mmcv>=2.0.1"
|
300 |
+
mim install "mmdet>=3.1.0"
|
301 |
+
mim install "mmpose>=1.1.0"
|
302 |
+
```
|
303 |
+
|
304 |
+
### 准备我们开发的包
|
305 |
+
#### 下载
|
306 |
+
```bash
|
307 |
+
git clone --recursive https://github.com/TMElyralab/MuseV.git
|
308 |
+
```
|
309 |
+
#### 准备 PYTHONPATH
|
310 |
+
```bash
|
311 |
+
current_dir=$(pwd)
|
312 |
+
export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV
|
313 |
+
export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/MMCM
|
314 |
+
export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/diffusers/src
|
315 |
+
export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/controlnet_aux/src
|
316 |
+
cd MuseV
|
317 |
+
```
|
318 |
+
|
319 |
+
1. `MMCM`:多媒体、跨模态处理包。
|
320 |
+
1. `diffusers`:基于 [diffusers](https://github.com/huggingface/diffusers) 修改的 diffusers 包。
|
321 |
+
1. `controlnet_aux`:基于 [controlnet_aux](https://github.com/TMElyralab/controlnet_aux) 修改的包。
|
322 |
+
|
323 |
+
|
324 |
+
## 下载模型
|
325 |
+
```bash
|
326 |
+
git clone https://huggingface.co/TMElyralab/MuseV ./checkpoints
|
327 |
+
```
|
328 |
+
- `motion`:多个版本的视频生成模型。使用小数据集 `ucf101` 和小 `webvid` 数据子集进行训练,约 60K 个视频文本对。GPU 内存消耗测试在 `resolution` $=512*512,`time_size=12`。
|
329 |
+
- `musev/unet`:这个版本 仅训练 `unet` 运动模块。推断 `GPU 内存消耗` $\approx 8G$。
|
330 |
+
- `musev_referencenet`:这个版本训练 `unet` 运动模块、`referencenet`、`IPAdapter`。推断 `GPU 内存消耗` $\approx 12G$。
|
331 |
+
- `unet`:`motion` 模块,具有 `Attention` 层中的 `to_k`、`to_v`,参考 `IPAdapter`。
|
332 |
+
- `referencenet`:类似于 `AnimateAnyone`。
|
333 |
+
- `ip_adapter_image_proj.bin`:图像特征变换层,参考 `IPAdapter`。
|
334 |
+
- `musev_referencenet_pose`:这个版本基于 `musev_referencenet`,固定 `referencenet` 和 `controlnet_pose`,训练 `unet motion` 和 `IPAdapter`。推断 `GPU 内存消耗` $\approx 12G$。
|
335 |
+
- `t2i/sd1.5`:text2image 模型,在训练运动模块时参数被冻结。
|
336 |
+
- `majicmixRealv6Fp16`:示例,可以替换为其他 t2i 基础。从 [majicmixRealv6Fp16](https://civitai.com/models/43331/majicmix-realistic) 下载。
|
337 |
+
- `fantasticmix_v10`: 可在 [fantasticmix_v10](https://civitai.com/models/22402?modelVersionId=26744) 下载。
|
338 |
+
- `IP-Adapter/models`:从 [IPAdapter](https://huggingface.co/h94/IP-Adapter/tree/main) 下载。
|
339 |
+
- `image_encoder`:视觉特征抽取模型。
|
340 |
+
- `ip-adapter_sd15.bin`:原始 IPAdapter 模型预训练权重。
|
341 |
+
- `ip-adapter-faceid_sd15.bin`:原始 IPAdapter 模型预训练权重。
|
342 |
+
|
343 |
+
## 推理
|
344 |
+
|
345 |
+
### 准备模型路径
|
346 |
+
当使用示例推断命令运行示例任务时,可以跳过此步骤。
|
347 |
+
该模块主要是在配置文件中设置模型路径和缩写,以在推断脚本中使用简单缩写而不是完整路径。
|
348 |
+
- T2I SD:参考 `musev/configs/model/T2I_all_model.py`
|
349 |
+
- 运动 Unet:参考 `musev/configs/model/motion_model.py`
|
350 |
+
- 任务:参考 `musev/configs/tasks/example.yaml`
|
351 |
+
|
352 |
+
### musev_referencenet
|
353 |
+
#### 输入文本、图像的视频生成
|
354 |
+
```bash
|
355 |
+
python scripts/inference/text2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev_referencenet --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --target_datas yongen --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --time_size 12 --fps 12
|
356 |
+
```
|
357 |
+
**通用参数**:
|
358 |
+
- `test_data_path`:测试用例任务路径
|
359 |
+
- `target_datas`:如果 `test_data_path` 中的 `name` 在 `target_datas` 中,则只运行这些子任务。`sep` 是 `,`;
|
360 |
+
- `sd_model_cfg_path`:T2I sd 模型路径,模型配置路径或模型路径。
|
361 |
+
- `sd_model_name`:sd 模型名称,用于在 `sd_model_cfg_path` 中选择完整模型��径。使用 `,` 分隔的多个模型名称,或 `all`。
|
362 |
+
- `unet_model_cfg_path`:运动 unet 模型配置路径或模型路径。
|
363 |
+
- `unet_model_name`:unet 模型名称,用于获取 `unet_model_cfg_path` 中的模型路径,并在 `musev/models/unet_loader.py` 中初始化 unet 类实例。使用 `,` 分隔的多个模型名称,或 `all`。如果 `unet_model_cfg_path` 是模型路径,则 `unet_name` 必须在 `musev/models/unet_loader.py` 中支持。
|
364 |
+
- `time_size`:扩散模型每次生成一个片段,这里是一个片段的帧数。默认为 `12`。
|
365 |
+
- `n_batch`:首尾相连方式生成总片段数,$total\_frames=n\_batch * time\_size + n\_viscond$,默认为 `1`。
|
366 |
+
- `context_frames`: 并行去噪子窗口一次生成的帧数。如果 `time_size` > `context_frame`,则会启动并行去噪逻辑, `time_size` 窗口会分成多个子窗口进行并行去噪。默认为 `12`。
|
367 |
+
|
368 |
+
生成**长视频**,有两种方法,可以共同使用:
|
369 |
+
1. `视觉条件并行去噪`:设置 `n_batch=1`,`time_size` = 想要的所有帧。
|
370 |
+
2. `传统的首尾相连方式`:设置 `time_size` = `context_frames` = 一次片段的帧数 (`12`),`context_overlap` = 0。会首尾相连方式生成`n_batch`片段数,首尾相连存在误差累计,当`n_batch`越大,最后的结果越差。
|
371 |
+
|
372 |
+
|
373 |
+
**模型参数**:
|
374 |
+
支持 `referencenet`、`IPAdapter`、`IPAdapterFaceID`、`Facein`。
|
375 |
+
- `referencenet_model_name`:`referencenet` 模型名称。
|
376 |
+
- `ImageClipVisionFeatureExtractor`:`ImageEmbExtractor` 名称,在 `IPAdapter` 中提取视觉特征。
|
377 |
+
- `vision_clip_model_path`:`ImageClipVisionFeatureExtractor` 模型路径。
|
378 |
+
- `ip_adapter_model_name`:来自 `IPAdapter` 的,它是 `ImagePromptEmbProj`,与 `ImageEmbExtractor` 一起使用。
|
379 |
+
- `ip_adapter_face_model_name`:`IPAdapterFaceID`,来自 `IPAdapter`,应该设置 `face_image_path`。
|
380 |
+
|
381 |
+
**一些影响运动范围和生成结果的参数**:
|
382 |
+
- `video_guidance_scale`:类似于 text2image,控制 cond 和 uncond 之间的影响,影响较大,默认为 `3.5`。
|
383 |
+
- `use_condition_image`:是否使用给定的第一帧进行视频生成, 默认 `True`。
|
384 |
+
- `redraw_condition_image`:是否重新绘制给定的第一帧图像。
|
385 |
+
- `video_negative_prompt`:配置文件中全 `negative_prompt` 的缩写。默认为 `V2`。
|
386 |
+
|
387 |
+
|
388 |
+
#### 输入视频的视频生成
|
389 |
+
```bash
|
390 |
+
python scripts/inference/video2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev_referencenet --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet -test_data_path ./configs/tasks/example.yaml --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12
|
391 |
+
```
|
392 |
+
**一些重要参数**
|
393 |
+
|
394 |
+
大多数参数与 `musev_text2video` 相同。`video2video` 的特殊参数有:
|
395 |
+
1. 需要在 `test_data` 中设置 `video_path`。现在支持 `rgb video` 和 `controlnet_middle_video`。
|
396 |
+
- `which2video`: 参考视频类型。 如果是 `video_middle`,则只使用类似`pose`、`depth`的 `video_middle`;如果是 `video`, 视频本身也会参与视频噪声初始化,类似于`img2imge`。
|
397 |
+
- `controlnet_name`:是否使用 `controlnet condition`,例如 `dwpose,depth`, pose的话 优先建议使用`dwpose_body_hand`。
|
398 |
+
- `video_is_middle`:`video_path` 是 `rgb video` 还是 `controlnet_middle_video`。可以为 `test_data_path` 中的每个 `test_data` 设置。
|
399 |
+
- `video_has_condition`:condtion_images 是否与 video_path 的第一帧对齐。如果不是,则首先生成 `condition_images`,然后与参考视频拼接对齐。 目前仅支持参考视频是`video_is_middle=True`,可`test_data` 设置。
|
400 |
+
|
401 |
+
所有 `controlnet_names` 维护在 [mmcm](https://github.com/TMElyralab/MMCM/blob/main/mmcm/vision/feature_extractor/controlnet.py#L513)
|
402 |
+
```python
|
403 |
+
['pose', 'pose_body', 'pose_hand', 'pose_face', 'pose_hand_body', 'pose_hand_face', 'dwpose', 'dwpose_face', 'dwpose_hand', 'dwpose_body', 'dwpose_body_hand', 'canny', 'tile', 'hed', 'hed_scribble', 'depth', 'pidi', 'normal_bae', 'lineart', 'lineart_anime', 'zoe', 'sam', 'mobile_sam', 'leres', 'content', 'face_detector']
|
404 |
+
```
|
405 |
+
|
406 |
+
### musev_referencenet_pose
|
407 |
+
仅用于 `pose2video`
|
408 |
+
基于 `musev_referencenet` 训练,固定 `referencenet`、`pose-controlnet` 和 `T2I`,训练 `motion` 模块和 `IPAdapter`。
|
409 |
+
```bash
|
410 |
+
python scripts/inference/video2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev_referencenet_pose --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet_pose -test_data_path ./configs/tasks/example.yaml --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12
|
411 |
+
```
|
412 |
+
|
413 |
+
### musev
|
414 |
+
仅有动作模块,没有 referencenet,需要更少的 GPU 内存。
|
415 |
+
#### 文本到视频
|
416 |
+
```bash
|
417 |
+
python scripts/inference/text2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --target_datas yongen --time_size 12 --fps 12
|
418 |
+
```
|
419 |
+
#### 视频到视频
|
420 |
+
```bash
|
421 |
+
python scripts/inference/video2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12
|
422 |
+
```
|
423 |
+
|
424 |
+
### Gradio 演示
|
425 |
+
MuseV 提供 gradio 脚本,可在本地机器上生成 GUI,方便生成视频。
|
426 |
+
|
427 |
+
```bash
|
428 |
+
cd scripts/gradio
|
429 |
+
python app.py
|
430 |
+
```
|
431 |
+
|
432 |
+
# 致谢
|
433 |
+
1. MuseV 开发过程中参考学习了很多开源工作 [TuneAVideo](https://github.com/showlab/Tune-A-Video)、[diffusers](https://github.com/huggingface/diffusers)、[Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone/tree/master/src/pipelines)、[animatediff](https://github.com/guoyww/AnimateDiff)、[IP-Adapter](https://github.com/tencent-ailab/IP-Adapter)、[AnimateAnyone](https://arxiv.org/abs/2311.17117)、[VideoFusion](https://arxiv.org/abs/2303.08320) 和 [insightface](https://github.com/deepinsight/insightface)。
|
434 |
+
2. MuseV 基于 `ucf101` 和 `webvid` 数据集构建。
|
435 |
+
|
436 |
+
感谢开源社区的贡献!
|
437 |
+
|
438 |
+
# 限制
|
439 |
+
|
440 |
+
`MuseV` 仍然存在很多待优化项,包括:
|
441 |
+
|
442 |
+
1. 缺乏泛化能力。对视觉条件帧敏感,有些视觉条件图像表现良好,有些表现不佳。有些预训练的 t2i 模型表现良好,有些表现不佳。
|
443 |
+
1. 有限的视频生成类型和有限的动作范围,部分原因是训练数据类型有限。发布的 `MuseV` 已经在大约 6 万对分辨率为 `512*320` 的人类文本视频对上进行了训练。`MuseV` 在较低分辨率下具有更大的动作范围,但视频质量较低。`MuseV` 在高分辨率下画质很好、但动作范围较小。在更大、更高分辨率、更高质量的文本视频数据集上进行训练可能会使 `MuseV` 更好。
|
444 |
+
1. 因为使用 `webvid` 训练会有水印问题。使用没有水印的、更干净的数据集可能会解决这个问题。
|
445 |
+
1. 有限类型的长视频生成。视觉条件并行去噪可以解决视频生成的累积误差,但当前的方法只适用于相对固定的摄像机场景。
|
446 |
+
1. referencenet 和 IP-Adapter 训练不足,因为时间有限和资源有限。
|
447 |
+
1. 代码结构不够完善。`MuseV` 支持丰富而动态的功能,但代码复杂且未经过重构。熟悉需要时间。
|
448 |
+
|
449 |
+
|
450 |
+
<!-- # Contribution 暂时不需要组织开源共建 -->
|
451 |
+
# 引用
|
452 |
+
```bib
|
453 |
+
@article{musev,
|
454 |
+
title={MuseV: 基于视觉条件的并行去噪的无限长度和高保真虚拟人视频生成},
|
455 |
+
author={Xia, Zhiqiang and Chen, Zhaokang and Wu, Bin and Li, Chao and Hung, Kwok-Wai and Zhan, Chao and He, Yingjie and Zhou, Wenjiang},
|
456 |
+
journal={arxiv},
|
457 |
+
year={2024}
|
458 |
+
}
|
459 |
+
```
|
460 |
+
# 免责声明/许可
|
461 |
+
1. `代码`:`MuseV` 的代码采用 `MIT` 许可证发布,学术用途和商业用途都可以。
|
462 |
+
1. `模型`:训练好的模型仅供非商业研究目的使用。
|
463 |
+
1. `其他开源模型`:使用的其他开源模型必须遵守他们的许可证,如 `insightface`、`IP-Adapter`、`ft-mse-vae` 等。
|
464 |
+
1. 测试数据收集自互联网,仅供非商业研究目的使用。
|
465 |
+
1. `AIGC`:本项目旨在积极影响基于人工智能的视频生成领域。用户被授予使用此工具创建视频的自由,但他们应该遵守当地法律,并负责任地使用。开发人员不对用户可能的不当使用承担任何责任。
|
README.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: MuseV Demo
|
3 |
+
emoji: 🎥
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.50.2
|
8 |
+
app_file: scripts/gradio/app_gradio_space.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
# MuseV Demo
|
13 |
+
|
14 |
+
This is a Hugging Face Space for MuseV: Infinite-length and High Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising.
|
15 |
+
|
16 |
+
## Features
|
17 |
+
|
18 |
+
- Text-to-Video generation
|
19 |
+
- Visual condition support
|
20 |
+
- High-quality video generation
|
21 |
+
|
22 |
+
For more details, visit the [GitHub repository](https://github.com/TMElyralab/MuseV).
|
23 |
+
|
24 |
+
## Usage
|
25 |
+
|
26 |
+
1. Enter your prompt describing the video you want to generate
|
27 |
+
2. Upload a reference image
|
28 |
+
3. Adjust parameters like seed, FPS, dimensions, etc.
|
29 |
+
4. Click generate and wait for the results
|
30 |
+
|
31 |
+
## Model Details
|
32 |
+
|
33 |
+
The model will be automatically downloaded when you first run the demo.
|
34 |
+
|
35 |
+
## Credits
|
36 |
+
|
37 |
+
Created by Lyra Lab, Tencent Music Entertainment
|
configs/model/T2I_all_model.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
T2IDir = os.path.join(
|
5 |
+
os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "t2i"
|
6 |
+
)
|
7 |
+
|
8 |
+
MODEL_CFG = {
|
9 |
+
"majicmixRealv6Fp16": {
|
10 |
+
"sd": os.path.join(T2IDir, "sd1.5/majicmixRealv6Fp16"),
|
11 |
+
},
|
12 |
+
"fantasticmix_v10": {
|
13 |
+
"sd": os.path.join(T2IDir, "sd1.5/fantasticmix_v10"),
|
14 |
+
},
|
15 |
+
}
|
configs/model/ip_adapter.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
IPAdapterModelDir = os.path.join(
|
4 |
+
os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "IP-Adapter"
|
5 |
+
)
|
6 |
+
|
7 |
+
|
8 |
+
MotionDir = os.path.join(
|
9 |
+
os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "motion"
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
MODEL_CFG = {
|
14 |
+
"IPAdapter": {
|
15 |
+
"ip_image_encoder": os.path.join(IPAdapterModelDir, "models/image_encoder"),
|
16 |
+
"ip_ckpt": os.path.join(IPAdapterModelDir, "ip-adapter_sd15.bin"),
|
17 |
+
"ip_scale": 1.0,
|
18 |
+
"clip_extra_context_tokens": 4,
|
19 |
+
"clip_embeddings_dim": 1024,
|
20 |
+
"desp": "",
|
21 |
+
},
|
22 |
+
"IPAdapterPlus": {
|
23 |
+
"ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"),
|
24 |
+
"ip_ckpt": os.path.join(IPAdapterModelDir, "ip-adapter-plus_sd15.bin"),
|
25 |
+
"ip_scale": 1.0,
|
26 |
+
"clip_extra_context_tokens": 16,
|
27 |
+
"clip_embeddings_dim": 1024,
|
28 |
+
"desp": "",
|
29 |
+
},
|
30 |
+
"IPAdapterPlus-face": {
|
31 |
+
"ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"),
|
32 |
+
"ip_ckpt": os.path.join(IPAdapterModelDir, "ip-adapter-plus-face_sd15.bin"),
|
33 |
+
"ip_scale": 1.0,
|
34 |
+
"clip_extra_context_tokens": 16,
|
35 |
+
"clip_embeddings_dim": 1024,
|
36 |
+
"desp": "",
|
37 |
+
},
|
38 |
+
"IPAdapterFaceID": {
|
39 |
+
"ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"),
|
40 |
+
"ip_ckpt": os.path.join(IPAdapterModelDir, "ip-adapter-faceid_sd15.bin"),
|
41 |
+
"ip_scale": 1.0,
|
42 |
+
"clip_extra_context_tokens": 4,
|
43 |
+
"clip_embeddings_dim": 512,
|
44 |
+
"desp": "",
|
45 |
+
},
|
46 |
+
"musev_referencenet": {
|
47 |
+
"ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"),
|
48 |
+
"ip_ckpt": os.path.join(
|
49 |
+
MotionDir, "musev_referencenet/ip_adapter_image_proj.bin"
|
50 |
+
),
|
51 |
+
"ip_scale": 1.0,
|
52 |
+
"clip_extra_context_tokens": 4,
|
53 |
+
"clip_embeddings_dim": 1024,
|
54 |
+
"desp": "",
|
55 |
+
},
|
56 |
+
"musev_referencenet_pose": {
|
57 |
+
"ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"),
|
58 |
+
"ip_ckpt": os.path.join(
|
59 |
+
MotionDir, "musev_referencenet_pose/ip_adapter_image_proj.bin"
|
60 |
+
),
|
61 |
+
"ip_scale": 1.0,
|
62 |
+
"clip_extra_context_tokens": 4,
|
63 |
+
"clip_embeddings_dim": 1024,
|
64 |
+
"desp": "",
|
65 |
+
},
|
66 |
+
}
|
configs/model/lcm_model.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
LCMDir = os.path.join(
|
5 |
+
os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "lcm"
|
6 |
+
)
|
7 |
+
|
8 |
+
|
9 |
+
MODEL_CFG = {
|
10 |
+
"lcm": {
|
11 |
+
os.path.join(LCMDir, "lcm-lora-sdv1-5/pytorch_lora_weights.safetensors"): {
|
12 |
+
"strength": 1.0,
|
13 |
+
"lora_block_weight": "ALL",
|
14 |
+
"strength_offset": 0,
|
15 |
+
},
|
16 |
+
},
|
17 |
+
}
|
configs/model/motion_model.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
MotionDIr = os.path.join(
|
5 |
+
os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "motion"
|
6 |
+
)
|
7 |
+
|
8 |
+
|
9 |
+
MODEL_CFG = {
|
10 |
+
"musev": {
|
11 |
+
"unet": os.path.join(MotionDIr, "musev"),
|
12 |
+
"desp": "only train unet motion module, fix t2i",
|
13 |
+
},
|
14 |
+
"musev_referencenet": {
|
15 |
+
"unet": os.path.join(MotionDIr, "musev_referencenet"),
|
16 |
+
"desp": "train referencenet, IPAdapter and unet motion module, fix t2i",
|
17 |
+
},
|
18 |
+
"musev_referencenet_pose": {
|
19 |
+
"unet": os.path.join(MotionDIr, "musev_referencenet_pose"),
|
20 |
+
"desp": "train unet motion module and IPAdapter, fix t2i and referencenet",
|
21 |
+
},
|
22 |
+
}
|
configs/model/negative_prompt.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Negative_Prompt_CFG = {
|
2 |
+
"Empty": {
|
3 |
+
"base_model": "",
|
4 |
+
"prompt": "",
|
5 |
+
"refer": "",
|
6 |
+
},
|
7 |
+
"V1": {
|
8 |
+
"base_model": "",
|
9 |
+
"prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, tail, watermarks",
|
10 |
+
"refer": "",
|
11 |
+
},
|
12 |
+
"V2": {
|
13 |
+
"base_model": "",
|
14 |
+
"prompt": "badhandv4, ng_deepnegative_v1_75t, (((multiple heads))), (((bad body))), (((two people))), ((extra arms)), ((deformed body)), (((sexy))), paintings,(((two heads))), ((big head)),sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (((nsfw))), nipples, extra fingers, (extra legs), (long neck), mutated hands, (fused fingers), (too many fingers)",
|
15 |
+
"refer": "Weiban",
|
16 |
+
},
|
17 |
+
"V3": {
|
18 |
+
"base_model": "",
|
19 |
+
"prompt": "badhandv4, ng_deepnegative_v1_75t, bad quality",
|
20 |
+
"refer": "",
|
21 |
+
},
|
22 |
+
"V4": {
|
23 |
+
"base_model": "",
|
24 |
+
"prompt": "badhandv4,ng_deepnegative_v1_75t,EasyNegativeV2,bad_prompt_version2-neg,bad quality",
|
25 |
+
"refer": "",
|
26 |
+
},
|
27 |
+
"V5": {
|
28 |
+
"base_model": "",
|
29 |
+
"prompt": "(((multiple heads))), (((bad body))), (((two people))), ((extra arms)), ((deformed body)), (((sexy))), paintings,(((two heads))), ((big head)),sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (((nsfw))), nipples, extra fingers, (extra legs), (long neck), mutated hands, (fused fingers), (too many fingers)",
|
30 |
+
"refer": "Weiban",
|
31 |
+
},
|
32 |
+
}
|
configs/model/referencenet.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
MotionDIr = os.path.join(
|
5 |
+
os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "motion"
|
6 |
+
)
|
7 |
+
|
8 |
+
|
9 |
+
MODEL_CFG = {
|
10 |
+
"musev_referencenet": {
|
11 |
+
"net": os.path.join(MotionDIr, "musev_referencenet"),
|
12 |
+
"desp": "",
|
13 |
+
},
|
14 |
+
}
|
configs/tasks/example.yaml
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# - name: task_name
|
2 |
+
# condition_images: vision condition images path
|
3 |
+
# video_path: str, default null, used for video2video
|
4 |
+
# prompt: text to guide image generation
|
5 |
+
# ipadapter_image: image_path for IP-Apdater
|
6 |
+
# refer_image: image_path for referencenet, generally speaking, same as ipadapter_image
|
7 |
+
# height: int # The shorter the image size, the larger the motion amplitude, and the lower video quality.
|
8 |
+
# width: int # The longer the W&H, the smaller the motion amplitude, and the higher video quality.
|
9 |
+
# img_length_ratio: float, generation video size is (height, width) * img_length_ratio
|
10 |
+
|
11 |
+
# text/image2video
|
12 |
+
- condition_images: ./data/images/yongen.jpeg
|
13 |
+
eye_blinks_factor: 1.8
|
14 |
+
height: 1308
|
15 |
+
img_length_ratio: 0.957
|
16 |
+
ipadapter_image: ${.condition_images}
|
17 |
+
name: yongen
|
18 |
+
prompt: (masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3)
|
19 |
+
refer_image: ${.condition_images}
|
20 |
+
video_path: null
|
21 |
+
width: 736
|
22 |
+
- condition_images: ./data/images/jinkesi2.jpeg
|
23 |
+
eye_blinks_factor: 1.8
|
24 |
+
height: 714
|
25 |
+
img_length_ratio: 1.25
|
26 |
+
ipadapter_image: ${.condition_images}
|
27 |
+
name: jinkesi2
|
28 |
+
prompt: (masterpiece, best quality, highres:1),(1girl, solo:1),(beautiful face,
|
29 |
+
soft skin, costume:1),(eye blinks:{eye_blinks_factor}),(head wave:1.3)
|
30 |
+
refer_image: ${.condition_images}
|
31 |
+
video_path: null
|
32 |
+
width: 563
|
33 |
+
- condition_images: ./data/images/seaside4.jpeg
|
34 |
+
eye_blinks_factor: 1.8
|
35 |
+
height: 317
|
36 |
+
img_length_ratio: 2.221
|
37 |
+
ipadapter_image: ${.condition_images}
|
38 |
+
name: seaside4
|
39 |
+
prompt: (masterpiece, best quality, highres:1), peaceful beautiful sea scene
|
40 |
+
refer_image: ${.condition_images}
|
41 |
+
video_path: null
|
42 |
+
width: 564
|
43 |
+
- condition_images: ./data/images/seaside_girl.jpeg
|
44 |
+
eye_blinks_factor: 1.8
|
45 |
+
height: 736
|
46 |
+
img_length_ratio: 0.957
|
47 |
+
ipadapter_image: ${.condition_images}
|
48 |
+
name: seaside_girl
|
49 |
+
prompt: (masterpiece, best quality, highres:1), peaceful beautiful sea scene
|
50 |
+
refer_image: ${.condition_images}
|
51 |
+
video_path: null
|
52 |
+
width: 736
|
53 |
+
- condition_images: ./data/images/boy_play_guitar.jpeg
|
54 |
+
eye_blinks_factor: 1.8
|
55 |
+
height: 846
|
56 |
+
img_length_ratio: 1.248
|
57 |
+
ipadapter_image: ${.condition_images}
|
58 |
+
name: boy_play_guitar
|
59 |
+
prompt: (masterpiece, best quality, highres:1), playing guitar
|
60 |
+
refer_image: ${.condition_images}
|
61 |
+
video_path: null
|
62 |
+
width: 564
|
63 |
+
- condition_images: ./data/images/girl_play_guitar2.jpeg
|
64 |
+
eye_blinks_factor: 1.8
|
65 |
+
height: 1002
|
66 |
+
img_length_ratio: 1.248
|
67 |
+
ipadapter_image: ${.condition_images}
|
68 |
+
name: girl_play_guitar2
|
69 |
+
prompt: (masterpiece, best quality, highres:1), playing guitar
|
70 |
+
refer_image: ${.condition_images}
|
71 |
+
video_path: null
|
72 |
+
width: 564
|
73 |
+
- condition_images: ./data/images/boy_play_guitar2.jpeg
|
74 |
+
eye_blinks_factor: 1.8
|
75 |
+
height: 630
|
76 |
+
img_length_ratio: 1.676
|
77 |
+
ipadapter_image: ${.condition_images}
|
78 |
+
name: boy_play_guitar2
|
79 |
+
prompt: (masterpiece, best quality, highres:1), playing guitar
|
80 |
+
refer_image: ${.condition_images}
|
81 |
+
video_path: null
|
82 |
+
width: 420
|
83 |
+
- condition_images: ./data/images/girl_play_guitar4.jpeg
|
84 |
+
eye_blinks_factor: 1.8
|
85 |
+
height: 846
|
86 |
+
img_length_ratio: 1.248
|
87 |
+
ipadapter_image: ${.condition_images}
|
88 |
+
name: girl_play_guitar4
|
89 |
+
prompt: (masterpiece, best quality, highres:1), playing guitar
|
90 |
+
refer_image: ${.condition_images}
|
91 |
+
video_path: null
|
92 |
+
width: 564
|
93 |
+
- condition_images: ./data/images/dufu.jpeg
|
94 |
+
eye_blinks_factor: 1.8
|
95 |
+
height: 500
|
96 |
+
img_length_ratio: 1.495
|
97 |
+
ipadapter_image: ${.condition_images}
|
98 |
+
name: dufu
|
99 |
+
prompt: (masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3),Chinese ink painting style
|
100 |
+
refer_image: ${.condition_images}
|
101 |
+
video_path: null
|
102 |
+
width: 471
|
103 |
+
- condition_images: ./data/images/Mona_Lisa..jpg
|
104 |
+
eye_blinks_factor: 1.8
|
105 |
+
height: 894
|
106 |
+
img_length_ratio: 1.173
|
107 |
+
ipadapter_image: ${.condition_images}
|
108 |
+
name: Mona_Lisa.
|
109 |
+
prompt: (masterpiece, best quality, highres:1),(1girl, solo:1),(beautiful face,
|
110 |
+
soft skin, costume:1),(eye blinks:{eye_blinks_factor}),(head wave:1.3)
|
111 |
+
refer_image: ${.condition_images}
|
112 |
+
video_path: null
|
113 |
+
width: 600
|
114 |
+
- condition_images: ./data/images/Portrait-of-Dr.-Gachet.jpg
|
115 |
+
eye_blinks_factor: 1.8
|
116 |
+
height: 985
|
117 |
+
img_length_ratio: 0.88
|
118 |
+
ipadapter_image: ${.condition_images}
|
119 |
+
name: Portrait-of-Dr.-Gachet
|
120 |
+
prompt: (masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3)
|
121 |
+
refer_image: ${.condition_images}
|
122 |
+
video_path: null
|
123 |
+
width: 800
|
124 |
+
- condition_images: ./data/images/Self-Portrait-with-Cropped-Hair.jpg
|
125 |
+
eye_blinks_factor: 1.8
|
126 |
+
height: 565
|
127 |
+
img_length_ratio: 1.246
|
128 |
+
ipadapter_image: ${.condition_images}
|
129 |
+
name: Self-Portrait-with-Cropped-Hair
|
130 |
+
prompt: (masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3), animate
|
131 |
+
refer_image: ${.condition_images}
|
132 |
+
video_path: null
|
133 |
+
width: 848
|
134 |
+
- condition_images: ./data/images/The-Laughing-Cavalier.jpg
|
135 |
+
eye_blinks_factor: 1.8
|
136 |
+
height: 1462
|
137 |
+
img_length_ratio: 0.587
|
138 |
+
ipadapter_image: ${.condition_images}
|
139 |
+
name: The-Laughing-Cavalier
|
140 |
+
prompt: (masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3)
|
141 |
+
refer_image: ${.condition_images}
|
142 |
+
video_path: null
|
143 |
+
width: 1200
|
144 |
+
|
145 |
+
# scene
|
146 |
+
- condition_images: ./data/images/waterfall4.jpeg
|
147 |
+
eye_blinks_factor: 1.8
|
148 |
+
height: 846
|
149 |
+
img_length_ratio: 1.248
|
150 |
+
ipadapter_image: ${.condition_images}
|
151 |
+
name: waterfall4
|
152 |
+
prompt: (masterpiece, best quality, highres:1), peaceful beautiful waterfall, an
|
153 |
+
endless waterfall
|
154 |
+
refer_image: ${.condition_images}
|
155 |
+
video_path: null
|
156 |
+
width: 564
|
157 |
+
- condition_images: ./data/images/river.jpeg
|
158 |
+
eye_blinks_factor: 1.8
|
159 |
+
height: 736
|
160 |
+
img_length_ratio: 0.957
|
161 |
+
ipadapter_image: ${.condition_images}
|
162 |
+
name: river
|
163 |
+
prompt: (masterpiece, best quality, highres:1), peaceful beautiful river
|
164 |
+
refer_image: ${.condition_images}
|
165 |
+
video_path: null
|
166 |
+
width: 736
|
167 |
+
- condition_images: ./data/images/seaside2.jpeg
|
168 |
+
eye_blinks_factor: 1.8
|
169 |
+
height: 1313
|
170 |
+
img_length_ratio: 0.957
|
171 |
+
ipadapter_image: ${.condition_images}
|
172 |
+
name: seaside2
|
173 |
+
prompt: (masterpiece, best quality, highres:1), peaceful beautiful sea scene
|
174 |
+
refer_image: ${.condition_images}
|
175 |
+
video_path: null
|
176 |
+
width: 736
|
177 |
+
|
178 |
+
# video2video
|
179 |
+
- name: "dance1"
|
180 |
+
prompt: "(masterpiece, best quality, highres:1) , a girl is dancing, wearing a dress made of stars, animation"
|
181 |
+
video_path: ./data/source_video/video1_girl_poseseq.mp4
|
182 |
+
condition_images: ./data/images/spark_girl.png
|
183 |
+
refer_image: ${.condition_images}
|
184 |
+
ipadapter_image: ${.condition_images}
|
185 |
+
height: 960
|
186 |
+
width: 512
|
187 |
+
img_length_ratio: 1.0
|
188 |
+
video_is_middle: True # if true, means video_path is controlnet condition, not natural rgb video
|
189 |
+
|
190 |
+
- name: "dance2"
|
191 |
+
prompt: "(best quality), ((masterpiece)), (highres), illustration, original, extremely detailed wallpaper"
|
192 |
+
video_path: ./data/source_video/video1_girl_poseseq.mp4
|
193 |
+
condition_images: ./data/images/cyber_girl.png
|
194 |
+
refer_image: ${.condition_images}
|
195 |
+
ipadapter_image: ${.condition_images}
|
196 |
+
height: 960
|
197 |
+
width: 512
|
198 |
+
img_length_ratio: 1.0
|
199 |
+
video_is_middle: True # if true, means video_path is controlnet condition, not natural rgb video
|
200 |
+
|
201 |
+
- name: "duffy"
|
202 |
+
prompt: "(best quality), ((masterpiece)), (highres), illustration, original, extremely detailed wallpaper"
|
203 |
+
video_path: ./data/source_video/pose-for-Duffy-4.mp4
|
204 |
+
condition_images: ./data/images/duffy.png
|
205 |
+
refer_image: ${.condition_images}
|
206 |
+
ipadapter_image: ${.condition_images}
|
207 |
+
height: 1280
|
208 |
+
width: 704
|
209 |
+
img_length_ratio: 1.0
|
210 |
+
video_is_middle: True # if true, means video_path is controlnet condition, not natural rgb video
|
controlnet_aux
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 54c6c49baf68bff290679f5bb896715f25932133
|
data/models/musev_structure.png
ADDED
![]() |
data/models/parallel_denoise.png
ADDED
![]() |
diffusers
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit abf2f8bf698a895cecc30a73c6ff4abb92fdce1c
|
environment.yml
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: musev
|
2 |
+
channels:
|
3 |
+
- https://repo.anaconda.com/pkgs/main
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- _libgcc_mutex=0.1=main
|
7 |
+
- _openmp_mutex=5.1=1_gnu
|
8 |
+
- bzip2=1.0.8=h7b6447c_0
|
9 |
+
- ca-certificates=2023.12.12=h06a4308_0
|
10 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
11 |
+
- libffi=3.3=he6710b0_2
|
12 |
+
- libgcc-ng=11.2.0=h1234567_1
|
13 |
+
- libgomp=11.2.0=h1234567_1
|
14 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
15 |
+
- libuuid=1.41.5=h5eee18b_0
|
16 |
+
- ncurses=6.4=h6a678d5_0
|
17 |
+
- openssl=1.1.1w=h7f8727e_0
|
18 |
+
- python=3.10.6=haa1d7c7_1
|
19 |
+
- readline=8.2=h5eee18b_0
|
20 |
+
- sqlite=3.41.2=h5eee18b_0
|
21 |
+
- tk=8.6.12=h1ccaba5_0
|
22 |
+
- xz=5.4.5=h5eee18b_0
|
23 |
+
- zlib=1.2.13=h5eee18b_0
|
24 |
+
- pip:
|
25 |
+
- absl-py==2.1.0
|
26 |
+
- accelerate==0.22.0
|
27 |
+
- addict==2.4.0
|
28 |
+
- aiofiles==23.2.1
|
29 |
+
- aiohttp==3.9.1
|
30 |
+
- aiosignal==1.3.1
|
31 |
+
- albumentations==1.3.1
|
32 |
+
- aliyun-python-sdk-core==2.14.0
|
33 |
+
- aliyun-python-sdk-kms==2.16.2
|
34 |
+
- altair==5.2.0
|
35 |
+
- antlr4-python3-runtime==4.9.3
|
36 |
+
- anyio==4.2.0
|
37 |
+
- appdirs==1.4.4
|
38 |
+
- argparse==1.4.0
|
39 |
+
- asttokens==2.4.1
|
40 |
+
- astunparse==1.6.3
|
41 |
+
- async-timeout==4.0.3
|
42 |
+
- attrs==23.2.0
|
43 |
+
- audioread==3.0.1
|
44 |
+
- basicsr==1.4.2
|
45 |
+
- beautifulsoup4==4.12.2
|
46 |
+
- bitsandbytes==0.41.1
|
47 |
+
- black==23.12.1
|
48 |
+
- blinker==1.7.0
|
49 |
+
- braceexpand==0.1.7
|
50 |
+
- cachetools==5.3.2
|
51 |
+
- certifi==2023.11.17
|
52 |
+
- cffi==1.16.0
|
53 |
+
- charset-normalizer==3.3.2
|
54 |
+
- chumpy==0.70
|
55 |
+
- click==8.1.7
|
56 |
+
- cmake==3.28.1
|
57 |
+
- colorama==0.4.6
|
58 |
+
- coloredlogs==15.0.1
|
59 |
+
- comm==0.2.1
|
60 |
+
- contourpy==1.2.0
|
61 |
+
- cos-python-sdk-v5==1.9.22
|
62 |
+
- coscmd==1.8.6.30
|
63 |
+
- crcmod==1.7
|
64 |
+
- cryptography==41.0.7
|
65 |
+
- cycler==0.12.1
|
66 |
+
- cython==3.0.2
|
67 |
+
- datetime==5.4
|
68 |
+
- debugpy==1.8.0
|
69 |
+
- decorator==4.4.2
|
70 |
+
- decord==0.6.0
|
71 |
+
- dill==0.3.7
|
72 |
+
- docker-pycreds==0.4.0
|
73 |
+
- dulwich==0.21.7
|
74 |
+
- easydict==1.11
|
75 |
+
- einops==0.7.0
|
76 |
+
- exceptiongroup==1.2.0
|
77 |
+
- executing==2.0.1
|
78 |
+
- fastapi==0.109.0
|
79 |
+
- ffmpeg==1.4
|
80 |
+
- ffmpeg-python==0.2.0
|
81 |
+
- ffmpy==0.3.1
|
82 |
+
- filelock==3.13.1
|
83 |
+
- flatbuffers==23.5.26
|
84 |
+
- fonttools==4.47.2
|
85 |
+
- frozenlist==1.4.1
|
86 |
+
- fsspec==2023.12.2
|
87 |
+
- ftfy==6.1.1
|
88 |
+
- future==0.18.3
|
89 |
+
- fuzzywuzzy==0.18.0
|
90 |
+
- fvcore==0.1.5.post20221221
|
91 |
+
- gast==0.4.0
|
92 |
+
- gdown==4.5.3
|
93 |
+
- gitdb==4.0.11
|
94 |
+
- gitpython==3.1.41
|
95 |
+
- google-auth==2.26.2
|
96 |
+
- google-auth-oauthlib==0.4.6
|
97 |
+
- google-pasta==0.2.0
|
98 |
+
- gradio==3.43.2
|
99 |
+
- gradio-client==0.5.0
|
100 |
+
- grpcio==1.60.0
|
101 |
+
- h11==0.14.0
|
102 |
+
- h5py==3.10.0
|
103 |
+
- httpcore==1.0.2
|
104 |
+
- httpx==0.26.0
|
105 |
+
- huggingface-hub==0.20.2
|
106 |
+
- humanfriendly==10.0
|
107 |
+
- idna==3.6
|
108 |
+
- imageio==2.31.1
|
109 |
+
- imageio-ffmpeg==0.4.8
|
110 |
+
- importlib-metadata==7.0.1
|
111 |
+
- importlib-resources==6.1.1
|
112 |
+
- iniconfig==2.0.0
|
113 |
+
- insightface==0.7.3
|
114 |
+
- invisible-watermark==0.1.5
|
115 |
+
- iopath==0.1.10
|
116 |
+
- ip-adapter==0.1.0
|
117 |
+
- iprogress==0.4
|
118 |
+
- ipykernel==6.29.0
|
119 |
+
- ipython==8.20.0
|
120 |
+
- ipywidgets==8.0.3
|
121 |
+
- jax==0.4.23
|
122 |
+
- jedi==0.19.1
|
123 |
+
- jinja2==3.1.3
|
124 |
+
- jmespath==0.10.0
|
125 |
+
- joblib==1.3.2
|
126 |
+
- json-tricks==3.17.3
|
127 |
+
- jsonschema==4.21.0
|
128 |
+
- jsonschema-specifications==2023.12.1
|
129 |
+
- jupyter-client==8.6.0
|
130 |
+
- jupyter-core==5.7.1
|
131 |
+
- jupyterlab-widgets==3.0.9
|
132 |
+
- keras==2.12.0
|
133 |
+
- kiwisolver==1.4.5
|
134 |
+
- kornia==0.7.0
|
135 |
+
- lazy-loader==0.3
|
136 |
+
- libclang==16.0.6
|
137 |
+
- librosa==0.10.1
|
138 |
+
- lightning-utilities==0.10.0
|
139 |
+
- lit==17.0.6
|
140 |
+
- llvmlite==0.41.1
|
141 |
+
- lmdb==1.4.1
|
142 |
+
- loguru==0.6.0
|
143 |
+
- markdown==3.5.2
|
144 |
+
- markdown-it-py==3.0.0
|
145 |
+
- markupsafe==2.0.1
|
146 |
+
- matplotlib==3.6.2
|
147 |
+
- matplotlib-inline==0.1.6
|
148 |
+
- mdurl==0.1.2
|
149 |
+
- mediapipe==0.10.3
|
150 |
+
- ml-dtypes==0.3.2
|
151 |
+
- model-index==0.1.11
|
152 |
+
- modelcards==0.1.6
|
153 |
+
- moviepy==1.0.3
|
154 |
+
- mpmath==1.3.0
|
155 |
+
- msgpack==1.0.7
|
156 |
+
- multidict==6.0.4
|
157 |
+
- munkres==1.1.4
|
158 |
+
- mypy-extensions==1.0.0
|
159 |
+
- nest-asyncio==1.5.9
|
160 |
+
- networkx==3.2.1
|
161 |
+
- ninja==1.11.1
|
162 |
+
- numba==0.58.1
|
163 |
+
- numpy==1.23.5
|
164 |
+
- oauthlib==3.2.2
|
165 |
+
- omegaconf==2.3.0
|
166 |
+
- onnx==1.14.1
|
167 |
+
- onnxruntime==1.15.1
|
168 |
+
- onnxsim==0.4.33
|
169 |
+
- open-clip-torch==2.20.0
|
170 |
+
- opencv-contrib-python==4.8.0.76
|
171 |
+
- opencv-python==4.9.0.80
|
172 |
+
- opencv-python-headless==4.9.0.80
|
173 |
+
- opendatalab==0.0.10
|
174 |
+
- openmim==0.3.9
|
175 |
+
- openxlab==0.0.34
|
176 |
+
- opt-einsum==3.3.0
|
177 |
+
- ordered-set==4.1.0
|
178 |
+
- orjson==3.9.10
|
179 |
+
- oss2==2.17.0
|
180 |
+
- packaging==23.2
|
181 |
+
- pandas==2.1.4
|
182 |
+
- parso==0.8.3
|
183 |
+
- pathspec==0.12.1
|
184 |
+
- pathtools==0.1.2
|
185 |
+
- pexpect==4.9.0
|
186 |
+
- pillow==10.2.0
|
187 |
+
- pip==23.3.1
|
188 |
+
- platformdirs==4.1.0
|
189 |
+
- pluggy==1.3.0
|
190 |
+
- pooch==1.8.0
|
191 |
+
- portalocker==2.8.2
|
192 |
+
- prettytable==3.9.0
|
193 |
+
- proglog==0.1.10
|
194 |
+
- prompt-toolkit==3.0.43
|
195 |
+
- protobuf==3.20.3
|
196 |
+
- psutil==5.9.7
|
197 |
+
- ptyprocess==0.7.0
|
198 |
+
- pure-eval==0.2.2
|
199 |
+
- pyarrow==14.0.2
|
200 |
+
- pyasn1==0.5.1
|
201 |
+
- pyasn1-modules==0.3.0
|
202 |
+
- pycocotools==2.0.7
|
203 |
+
- pycparser==2.21
|
204 |
+
- pycryptodome==3.20.0
|
205 |
+
- pydantic==1.10.2
|
206 |
+
- pydeck==0.8.1b0
|
207 |
+
- pydub==0.25.1
|
208 |
+
- pygments==2.17.2
|
209 |
+
- pynvml==11.5.0
|
210 |
+
- pyparsing==3.1.1
|
211 |
+
- pysocks==1.7.1
|
212 |
+
- pytest==7.4.4
|
213 |
+
- python-dateutil==2.8.2
|
214 |
+
- python-dotenv==1.0.0
|
215 |
+
- python-multipart==0.0.6
|
216 |
+
- pytorch-lightning==2.0.8
|
217 |
+
- pytube==15.0.0
|
218 |
+
- pytz==2023.3.post1
|
219 |
+
- pywavelets==1.5.0
|
220 |
+
- pyyaml==6.0.1
|
221 |
+
- pyzmq==25.1.2
|
222 |
+
- qudida==0.0.4
|
223 |
+
- redis==4.5.1
|
224 |
+
- referencing==0.32.1
|
225 |
+
- regex==2023.12.25
|
226 |
+
- requests==2.28.2
|
227 |
+
- requests-oauthlib==1.3.1
|
228 |
+
- rich==13.4.2
|
229 |
+
- rpds-py==0.17.1
|
230 |
+
- rsa==4.9
|
231 |
+
- safetensors==0.3.3
|
232 |
+
- scikit-image==0.22.0
|
233 |
+
- scikit-learn==1.3.2
|
234 |
+
- scipy==1.11.4
|
235 |
+
- semantic-version==2.10.0
|
236 |
+
- sentencepiece==0.1.99
|
237 |
+
- sentry-sdk==1.39.2
|
238 |
+
- setproctitle==1.3.3
|
239 |
+
- setuptools==60.2.0
|
240 |
+
- shapely==2.0.2
|
241 |
+
- six==1.16.0
|
242 |
+
- smmap==5.0.1
|
243 |
+
- sniffio==1.3.0
|
244 |
+
- sounddevice==0.4.6
|
245 |
+
- soundfile==0.12.1
|
246 |
+
- soupsieve==2.5
|
247 |
+
- soxr==0.3.7
|
248 |
+
- stack-data==0.6.3
|
249 |
+
- starlette==0.35.1
|
250 |
+
- streamlit==1.30.0
|
251 |
+
- streamlit-drawable-canvas==0.9.3
|
252 |
+
- sympy==1.12
|
253 |
+
- tabulate==0.9.0
|
254 |
+
- tb-nightly==2.11.0a20220906
|
255 |
+
- tenacity==8.2.3
|
256 |
+
- tensorboard==2.12.0
|
257 |
+
- tensorboard-data-server==0.6.1
|
258 |
+
- tensorboard-plugin-wit==1.8.1
|
259 |
+
- tensorflow==2.12.0
|
260 |
+
- tensorflow-estimator==2.12.0
|
261 |
+
- tensorflow-io-gcs-filesystem==0.35.0
|
262 |
+
- termcolor==2.4.0
|
263 |
+
- terminaltables==3.1.10
|
264 |
+
- test-tube==0.7.5
|
265 |
+
- threadpoolctl==3.2.0
|
266 |
+
- tifffile==2023.12.9
|
267 |
+
- timm==0.9.12
|
268 |
+
- tokenizers==0.13.3
|
269 |
+
- toml==0.10.2
|
270 |
+
- tomli==2.0.1
|
271 |
+
- toolz==0.12.0
|
272 |
+
- torch==2.0.1+cu118
|
273 |
+
- torch-tb-profiler==0.4.1
|
274 |
+
- torchmetrics==1.1.1
|
275 |
+
- torchvision==0.15.2+cu118
|
276 |
+
- tornado==6.4
|
277 |
+
- tqdm==4.65.2
|
278 |
+
- traitlets==5.14.1
|
279 |
+
- transformers==4.33.1
|
280 |
+
- triton==2.0.0
|
281 |
+
- typing-extensions==4.9.0
|
282 |
+
- tzdata==2023.4
|
283 |
+
- tzlocal==5.2
|
284 |
+
- urllib3==1.26.18
|
285 |
+
- urwid==2.4.2
|
286 |
+
- uvicorn==0.26.0
|
287 |
+
- validators==0.22.0
|
288 |
+
- wandb==0.15.10
|
289 |
+
- watchdog==3.0.0
|
290 |
+
- wcwidth==0.2.13
|
291 |
+
- webdataset==0.2.86
|
292 |
+
- webp==0.3.0
|
293 |
+
- websockets==11.0.3
|
294 |
+
- werkzeug==3.0.1
|
295 |
+
- wget==3.2
|
296 |
+
- wheel==0.41.2
|
297 |
+
- widgetsnbextension==4.0.9
|
298 |
+
- wrapt==1.14.1
|
299 |
+
- xformers==0.0.21
|
300 |
+
- xmltodict==0.13.0
|
301 |
+
- xtcocotools==1.14.3
|
302 |
+
- yacs==0.1.8
|
303 |
+
- yapf==0.40.2
|
304 |
+
- yarl==1.9.4
|
305 |
+
- zipp==3.17.0
|
306 |
+
- zope-interface==6.1
|
307 |
+
- fire==0.6.0
|
308 |
+
- cuid
|
309 |
+
- git+https://github.com/tencent-ailab/IP-Adapter.git@main
|
310 |
+
- git+https://github.com/openai/CLIP.git@main
|
311 |
+
prefix: /data/miniconda3/envs/musev
|
312 |
+
|
musev/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import logging.config
|
4 |
+
|
5 |
+
# 读取日志配置文件内容
|
6 |
+
logging.config.fileConfig(os.path.join(os.path.dirname(__file__), "logging.conf"))
|
7 |
+
|
8 |
+
# 创建一个日志器logger
|
9 |
+
logger = logging.getLogger("musev")
|
musev/auto_prompt/__init__.py
ADDED
File without changes
|
musev/auto_prompt/attributes/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ...utils.register import Register
|
2 |
+
|
3 |
+
AttrRegister = Register(registry_name="attributes")
|
4 |
+
|
5 |
+
# must import like bellow to ensure that each class is registered with AttrRegister:
|
6 |
+
from .human import *
|
7 |
+
from .render import *
|
8 |
+
from .style import *
|
musev/auto_prompt/attributes/attr2template.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""
|
2 |
+
中文
|
3 |
+
该模块将关键词字典转化为描述文本,生成完整的提词,从而降低对比实验成本、提升控制能力和效率。
|
4 |
+
提词(prompy)对比实验会需要控制关键属性发生变化、其他属性不变的文本对。当需要控制的属性变量发生较大变化时,靠人为复制粘贴进行完成文本撰写工作量会非常大。
|
5 |
+
该模块主要有三种类,分别是:
|
6 |
+
1. `BaseAttribute2Text`: 单属性文本转换类
|
7 |
+
2. `MultiAttr2Text` 多属性文本转化类,输出`List[Tuple[str, str]`。具体如何转换为文本在 `MultiAttr2PromptTemplate`中实现。
|
8 |
+
3. `MultiAttr2PromptTemplate`:先将2生成的多属性文本字典列表转化为完整的文本,然后再使用内置的模板`template`拼接。拼接后的文本作为实际模型输入的提词。
|
9 |
+
1. `template`字段若没有{},且有字符,则认为输入就是完整输入网络的`prompt`;
|
10 |
+
2. `template`字段若含有{key},则认为是带关键词的字符串目标,多个属性由`template`字符串中顺序完全决定。关键词内容由表格中相关列通过`attr2text`转化而来;
|
11 |
+
3. `template`字段有且只含有一个{},如`a portrait of {}`,则相关内容由 `PresetMultiAttr2PromptTemplate`中预定义好的`attrs`列表指定先后顺序;
|
12 |
+
|
13 |
+
English
|
14 |
+
This module converts a keyword dictionary into descriptive text, generating complete prompts to reduce the cost of comparison experiments, and improve control and efficiency.
|
15 |
+
|
16 |
+
Prompt-based comparison experiments require text pairs where the key attributes are controlled while other attributes remain constant. When the variable attributes to be controlled undergo significant changes, manually copying and pasting to write text can be very time-consuming.
|
17 |
+
|
18 |
+
This module mainly consists of three classes:
|
19 |
+
|
20 |
+
BaseAttribute2Text: A class for converting single attribute text.
|
21 |
+
MultiAttr2Text: A class for converting multi-attribute text, outputting List[Tuple[str, str]]. The specific implementation of how to convert to text is implemented in MultiAttr2PromptTemplate.
|
22 |
+
MultiAttr2PromptTemplate: First, the list of multi-attribute text dictionaries generated by 2 is converted into complete text, and then the built-in template template is used for concatenation. The concatenated text serves as the prompt for the actual model input.
|
23 |
+
If the template field does not contain {}, and there are characters, the input is considered the complete prompt for the network.
|
24 |
+
If the template field contains {key}, it is considered a string target with keywords, and the order of multiple attributes is completely determined by the template string. The keyword content is generated by attr2text from the relevant columns in the table.
|
25 |
+
If the template field contains only one {}, such as a portrait of {}, the relevant content is specified in the order defined by the attrs list predefined in PresetMultiAttr2PromptTemplate.
|
26 |
+
"""
|
27 |
+
|
28 |
+
from typing import List, Tuple, Union
|
29 |
+
|
30 |
+
from mmcm.utils.str_util import (
|
31 |
+
has_key_brace,
|
32 |
+
merge_near_same_char,
|
33 |
+
get_word_from_key_brace_string,
|
34 |
+
)
|
35 |
+
|
36 |
+
from .attributes import MultiAttr2Text, merge_multi_attrtext, AttriributeIsText
|
37 |
+
from . import AttrRegister
|
38 |
+
|
39 |
+
|
40 |
+
class MultiAttr2PromptTemplate(object):
|
41 |
+
"""
|
42 |
+
将多属性转化为模型输入文本的实际类
|
43 |
+
The actual class that converts multiple attributes into model input text is
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
template: str,
|
49 |
+
attr2text: MultiAttr2Text,
|
50 |
+
name: str,
|
51 |
+
) -> None:
|
52 |
+
"""
|
53 |
+
Args:
|
54 |
+
template (str): 提词模板, prompt template.
|
55 |
+
如果`template`含有{key},则根据key来取值。 if the template field contains {key}, it means that the actual value for that part of the prompt will be determined by the corresponding key
|
56 |
+
如果`template`有且只有1个{},则根据先后顺序对texts中的值进行拼接。if the template field in MultiAttr2PromptTemplate contains only one {} placeholder, such as "a portrait of {}", the order of the attributes is determined by the attrs list predefined in PresetMultiAttr2PromptTemplate. The values of the attributes in the texts list are concatenated in the order specified by the attrs list.
|
57 |
+
attr2text (MultiAttr2Text): 多属性转换类。Class for converting multiple attributes into text prompt.
|
58 |
+
name (str): 该多属性文本模板类的名字,便于记忆. Class Instance name
|
59 |
+
"""
|
60 |
+
self.attr2text = attr2text
|
61 |
+
self.name = name
|
62 |
+
if template == "":
|
63 |
+
template = "{}"
|
64 |
+
self.template = template
|
65 |
+
self.template_has_key_brace = has_key_brace(template)
|
66 |
+
|
67 |
+
def __call__(self, attributes: dict) -> Union[str, List[str]]:
|
68 |
+
texts = self.attr2text(attributes)
|
69 |
+
if not isinstance(texts, list):
|
70 |
+
texts = [texts]
|
71 |
+
prompts = [merge_multi_attrtext(text, self.template) for text in texts]
|
72 |
+
prompts = [merge_near_same_char(prompt) for prompt in prompts]
|
73 |
+
if len(prompts) == 1:
|
74 |
+
prompts = prompts[0]
|
75 |
+
return prompts
|
76 |
+
|
77 |
+
|
78 |
+
class KeywordMultiAttr2PromptTemplate(MultiAttr2PromptTemplate):
|
79 |
+
def __init__(self, template: str, name: str = "keywords") -> None:
|
80 |
+
"""关键词模板属性2文本转化类
|
81 |
+
1. 获取关键词模板字符串中的关键词属性;
|
82 |
+
2. 从import * 存储在locals()中变量中获取对应的类;
|
83 |
+
3. 将集成了多属性转换类的`MultiAttr2Text`
|
84 |
+
Args:
|
85 |
+
template (str): 含有{key}的模板字符串
|
86 |
+
name (str, optional): 该模板字符串名字,暂无实际用处. Defaults to "keywords".
|
87 |
+
|
88 |
+
class for converting keyword template attributes to text
|
89 |
+
1. Get the keyword attributes in the keyword template string;
|
90 |
+
2. Get the corresponding class from the variables stored in locals() by import *;
|
91 |
+
3. The `MultiAttr2Text` integrated with multiple attribute conversion classes
|
92 |
+
Args:
|
93 |
+
template (str): template string containing {key}
|
94 |
+
name (str, optional): the name of the template string, no actual use. Defaults to "keywords".
|
95 |
+
"""
|
96 |
+
assert has_key_brace(
|
97 |
+
template
|
98 |
+
), "template should have key brace, but given {}".format(template)
|
99 |
+
keywords = get_word_from_key_brace_string(template)
|
100 |
+
funcs = []
|
101 |
+
for word in keywords:
|
102 |
+
if word in AttrRegister:
|
103 |
+
func = AttrRegister[word](name=word)
|
104 |
+
else:
|
105 |
+
func = AttriributeIsText(name=word)
|
106 |
+
funcs.append(func)
|
107 |
+
attr2text = MultiAttr2Text(funcs, name=name)
|
108 |
+
super().__init__(template, attr2text, name)
|
109 |
+
|
110 |
+
|
111 |
+
class OnlySpacePromptTemplate(MultiAttr2PromptTemplate):
|
112 |
+
def __init__(self, template: str, name: str = "space_prompt") -> None:
|
113 |
+
"""纯空模板,无论输入啥,都只返回空格字符串作为prompt。
|
114 |
+
Args:
|
115 |
+
template (str): 符合只输出空格字符串的模板,
|
116 |
+
name (str, optional): 该模板字符串名字,暂无实际用处. Defaults to "space_prompt".
|
117 |
+
|
118 |
+
Pure empty template, no matter what the input is, it will only return a space string as the prompt.
|
119 |
+
Args:
|
120 |
+
template (str): template that only outputs a space string,
|
121 |
+
name (str, optional): the name of the template string, no actual use. Defaults to "space_prompt".
|
122 |
+
"""
|
123 |
+
attr2text = None
|
124 |
+
super().__init__(template, attr2text, name)
|
125 |
+
|
126 |
+
def __call__(self, attributes: dict) -> Union[str, List[str]]:
|
127 |
+
return ""
|
musev/auto_prompt/attributes/attributes.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from typing import List, Tuple, Dict
|
3 |
+
|
4 |
+
from mmcm.utils.str_util import has_key_brace
|
5 |
+
|
6 |
+
|
7 |
+
class BaseAttribute2Text(object):
|
8 |
+
"""
|
9 |
+
属性转化为文本的基类,该类作用就是输入属性,转化为描述文本。
|
10 |
+
Base class for converting attributes to text which converts attributes to prompt text.
|
11 |
+
"""
|
12 |
+
|
13 |
+
name = "base_attribute"
|
14 |
+
|
15 |
+
def __init__(self, name: str = None) -> None:
|
16 |
+
"""这里类实例初始化设置`name`参数,主要是为了便于一些没有提前实现、通过字符串参数实现的新属性。
|
17 |
+
Theses class instances are initialized with the `name` parameter to facilitate the implementation of new attributes that are not implemented in advance and are implemented through string parameters.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
name (str, optional): _description_. Defaults to None.
|
21 |
+
"""
|
22 |
+
if name is not None:
|
23 |
+
self.name = name
|
24 |
+
|
25 |
+
def __call__(self, attributes) -> str:
|
26 |
+
raise NotImplementedError
|
27 |
+
|
28 |
+
|
29 |
+
class AttributeIsTextAndName(BaseAttribute2Text):
|
30 |
+
"""
|
31 |
+
属性文本转换功能类,将key和value拼接在一起作为文本.
|
32 |
+
class for converting attributes to text which concatenates the key and value together as text.
|
33 |
+
"""
|
34 |
+
|
35 |
+
name = "attribute_is_text_name"
|
36 |
+
|
37 |
+
def __call__(self, attributes) -> str:
|
38 |
+
if attributes == "" or attributes is None:
|
39 |
+
return ""
|
40 |
+
attributes = attributes.split(",")
|
41 |
+
text = ", ".join(
|
42 |
+
[
|
43 |
+
"{} {}".format(attr, self.name) if attr != "" else ""
|
44 |
+
for attr in attributes
|
45 |
+
]
|
46 |
+
)
|
47 |
+
return text
|
48 |
+
|
49 |
+
|
50 |
+
class AttriributeIsText(BaseAttribute2Text):
|
51 |
+
"""
|
52 |
+
属性文本转换功能类,将value作为文本.
|
53 |
+
class for converting attributes to text which only uses the value as text.
|
54 |
+
"""
|
55 |
+
|
56 |
+
name = "attribute_is_text"
|
57 |
+
|
58 |
+
def __call__(self, attributes: str) -> str:
|
59 |
+
if attributes == "" or attributes is None:
|
60 |
+
return ""
|
61 |
+
attributes = str(attributes)
|
62 |
+
attributes = attributes.split(",")
|
63 |
+
text = ", ".join(["{}".format(attr) for attr in attributes])
|
64 |
+
return text
|
65 |
+
|
66 |
+
|
67 |
+
class MultiAttr2Text(object):
|
68 |
+
"""将多属性组成的字典转换成完整的文本描述,目前采用简单的前后拼接方式,以`, `作为拼接符号
|
69 |
+
class for converting a dictionary of multiple attributes into a complete text description. Currently, a simple front and back splicing method is used, with `, ` as the splicing symbol.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
object (_type_): _description_
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(self, funcs: list, name) -> None:
|
76 |
+
"""
|
77 |
+
Args:
|
78 |
+
funcs (list): 继承`BaseAttribute2Text`并实现了`__call__`函数的类. Inherited `BaseAttribute2Text` and implemented the `__call__` function of the class.
|
79 |
+
name (_type_): 该多属性的一个名字,可通过该类方便了解对应相关属性都是关于啥的。 name of the multi-attribute, which can be used to easily understand what the corresponding related attributes are about.
|
80 |
+
"""
|
81 |
+
if not isinstance(funcs, list):
|
82 |
+
funcs = [funcs]
|
83 |
+
self.funcs = funcs
|
84 |
+
self.name = name
|
85 |
+
|
86 |
+
def __call__(
|
87 |
+
self, dct: dict, ignored_blank_str: bool = False
|
88 |
+
) -> List[Tuple[str, str]]:
|
89 |
+
"""
|
90 |
+
有时候一个属性可能会返回多个文本,如 style cartoon会返回宫崎骏和皮克斯两种风格,采用外积增殖成多个字典。
|
91 |
+
sometimes an attribute may return multiple texts, such as style cartoon will return two styles, Miyazaki and Pixar, which are multiplied into multiple dictionaries by the outer product.
|
92 |
+
Args:
|
93 |
+
dct (dict): 多属性组成的字典,可能有self.funcs关注的属性也可能没有,self.funcs按照各自的名字按需提取关注的属性和值,并转化成文本.
|
94 |
+
Dict of multiple attributes, may or may not have the attributes that self.funcs is concerned with. self.funcs extracts the attributes and values of interest according to their respective names and converts them into text.
|
95 |
+
ignored_blank_str (bool): 如果某个attr2text返回的是空字符串,是否要过滤掉该属性。默认`False`.
|
96 |
+
If the text returned by an attr2text is an empty string, whether to filter out the attribute. Defaults to `False`.
|
97 |
+
Returns:
|
98 |
+
Union[List[List[Tuple[str, str]]], List[Tuple[str, str]]: 多组多属性文本字典列表. Multiple sets of multi-attribute text dictionaries.
|
99 |
+
"""
|
100 |
+
attrs_lst = [[]]
|
101 |
+
for func in self.funcs:
|
102 |
+
if func.name in dct:
|
103 |
+
attrs = func(dct[func.name])
|
104 |
+
if isinstance(attrs, str):
|
105 |
+
for i in range(len(attrs_lst)):
|
106 |
+
attrs_lst[i].append((func.name, attrs))
|
107 |
+
else:
|
108 |
+
# 一个属性可能会返回多个文本
|
109 |
+
n_attrs = len(attrs)
|
110 |
+
new_attrs_lst = []
|
111 |
+
for n in range(n_attrs):
|
112 |
+
attrs_lst_cp = deepcopy(attrs_lst)
|
113 |
+
for i in range(len(attrs_lst_cp)):
|
114 |
+
attrs_lst_cp[i].append((func.name, attrs[n]))
|
115 |
+
new_attrs_lst.extend(attrs_lst_cp)
|
116 |
+
attrs_lst = new_attrs_lst
|
117 |
+
|
118 |
+
texts = [
|
119 |
+
[
|
120 |
+
(attr, text)
|
121 |
+
for (attr, text) in attrs
|
122 |
+
if not (text == "" and ignored_blank_str)
|
123 |
+
]
|
124 |
+
for attrs in attrs_lst
|
125 |
+
]
|
126 |
+
return texts
|
127 |
+
|
128 |
+
|
129 |
+
def format_tuple_texts(template: str, texts: Tuple[str, str]) -> str:
|
130 |
+
"""使用含有"{}" 的模板对多属性文本元组进行拼接,形成新文本
|
131 |
+
concatenate multiple attribute text tuples using a template containing "{}" to form a new text
|
132 |
+
Args:
|
133 |
+
template (str):
|
134 |
+
texts (Tuple[str, str]): 多属性文本元组. multiple attribute text tuples
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
str: 拼接后的新文本, merged new text
|
138 |
+
"""
|
139 |
+
merged_text = ", ".join([text[1] for text in texts if text[1] != ""])
|
140 |
+
merged_text = template.format(merged_text)
|
141 |
+
return merged_text
|
142 |
+
|
143 |
+
|
144 |
+
def format_dct_texts(template: str, texts: Dict[str, str]) -> str:
|
145 |
+
"""使用含有"{key}" 的模板对多属性文本字典进行拼接,形成新文本
|
146 |
+
concatenate multiple attribute text dictionaries using a template containing "{key}" to form a new text
|
147 |
+
Args:
|
148 |
+
template (str):
|
149 |
+
texts (Tuple[str, str]): 多属性文本字典. multiple attribute text dictionaries
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
str: 拼接后的新文本, merged new text
|
153 |
+
"""
|
154 |
+
merged_text = template.format(**texts)
|
155 |
+
return merged_text
|
156 |
+
|
157 |
+
|
158 |
+
def merge_multi_attrtext(texts: List[Tuple[str, str]], template: str = None) -> str:
|
159 |
+
"""对多属性文本元组进行拼接,形成新文本。
|
160 |
+
如果`template`含有{key},则根据key来取值;
|
161 |
+
如果`template`有且只有1个{},则根据先后顺序对texts中的值进行拼接。
|
162 |
+
|
163 |
+
concatenate multiple attribute text tuples to form a new text.
|
164 |
+
if `template` contains {key}, the value is taken according to the key;
|
165 |
+
if `template` contains only one {}, the values in texts are concatenated in order.
|
166 |
+
Args:
|
167 |
+
texts (List[Tuple[str, str]]): Tuple[str, str]第一个str是属性名,第二个str是属性转化的文本.
|
168 |
+
Tuple[str, str] The first str is the attribute name, and the second str is the text of the attribute conversion.
|
169 |
+
template (str, optional): template . Defaults to None.
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
str: 拼接后的新文本, merged new text
|
173 |
+
"""
|
174 |
+
if not isinstance(texts, List):
|
175 |
+
texts = [texts]
|
176 |
+
if template is None or template == "":
|
177 |
+
template = "{}"
|
178 |
+
if has_key_brace(template):
|
179 |
+
texts = {k: v for k, v in texts}
|
180 |
+
merged_text = format_dct_texts(template, texts)
|
181 |
+
else:
|
182 |
+
merged_text = format_tuple_texts(template, texts)
|
183 |
+
return merged_text
|
184 |
+
|
185 |
+
|
186 |
+
class PresetMultiAttr2Text(MultiAttr2Text):
|
187 |
+
"""预置了多种关注属性转换的类,方便维护
|
188 |
+
class for multiple attribute conversion with multiple attention attributes preset for easy maintenance
|
189 |
+
|
190 |
+
"""
|
191 |
+
|
192 |
+
preset_attributes = []
|
193 |
+
|
194 |
+
def __init__(
|
195 |
+
self, funcs: List = None, use_preset: bool = True, name: str = "preset"
|
196 |
+
) -> None:
|
197 |
+
"""虽然预置了关注的属性列表和转换类,但也允许定义示例时,进行更新。
|
198 |
+
注意`self.preset_attributes`的元素只是类名字,以便减少实例化的资源消耗。而funcs是实例化后的属性转换列表。
|
199 |
+
|
200 |
+
Although the list of attention attributes and conversion classes is preset, it is also allowed to be updated when defining an instance.
|
201 |
+
Note that the elements of `self.preset_attributes` are only class names, in order to reduce the resource consumption of instantiation. And funcs is a list of instantiated attribute conversions.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
funcs (List, optional): list of funcs . Defaults to None.
|
205 |
+
use_preset (bool, optional): _description_. Defaults to True.
|
206 |
+
name (str, optional): _description_. Defaults to "preset".
|
207 |
+
"""
|
208 |
+
if use_preset:
|
209 |
+
preset_funcs = self.preset()
|
210 |
+
else:
|
211 |
+
preset_funcs = []
|
212 |
+
if funcs is None:
|
213 |
+
funcs = []
|
214 |
+
if not isinstance(funcs, list):
|
215 |
+
funcs = [funcs]
|
216 |
+
funcs_names = [func.name for func in funcs]
|
217 |
+
preset_funcs = [
|
218 |
+
preset_func
|
219 |
+
for preset_func in preset_funcs
|
220 |
+
if preset_func.name not in funcs_names
|
221 |
+
]
|
222 |
+
funcs = funcs + preset_funcs
|
223 |
+
super().__init__(funcs, name)
|
224 |
+
|
225 |
+
def preset(self):
|
226 |
+
funcs = [cls() for cls in self.preset_attributes]
|
227 |
+
return funcs
|
musev/auto_prompt/attributes/human.py
ADDED
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import json
|
5 |
+
|
6 |
+
from .attributes import (
|
7 |
+
MultiAttr2Text,
|
8 |
+
AttriributeIsText,
|
9 |
+
AttributeIsTextAndName,
|
10 |
+
PresetMultiAttr2Text,
|
11 |
+
)
|
12 |
+
from .style import Style
|
13 |
+
from .render import Render
|
14 |
+
from . import AttrRegister
|
15 |
+
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"Age",
|
19 |
+
"Sex",
|
20 |
+
"Singing",
|
21 |
+
"Country",
|
22 |
+
"Lighting",
|
23 |
+
"Headwear",
|
24 |
+
"Eyes",
|
25 |
+
"Irises",
|
26 |
+
"Hair",
|
27 |
+
"Skin",
|
28 |
+
"Face",
|
29 |
+
"Smile",
|
30 |
+
"Expression",
|
31 |
+
"Clothes",
|
32 |
+
"Nose",
|
33 |
+
"Mouth",
|
34 |
+
"Beard",
|
35 |
+
"Necklace",
|
36 |
+
"KeyWords",
|
37 |
+
"InsightFace",
|
38 |
+
"Caption",
|
39 |
+
"Env",
|
40 |
+
"Decoration",
|
41 |
+
"Festival",
|
42 |
+
"SpringHeadwear",
|
43 |
+
"SpringClothes",
|
44 |
+
"Animal",
|
45 |
+
]
|
46 |
+
|
47 |
+
|
48 |
+
@AttrRegister.register
|
49 |
+
class Sex(AttriributeIsText):
|
50 |
+
name = "sex"
|
51 |
+
|
52 |
+
def __init__(self, name: str = None) -> None:
|
53 |
+
super().__init__(name)
|
54 |
+
|
55 |
+
|
56 |
+
@AttrRegister.register
|
57 |
+
class Headwear(AttriributeIsText):
|
58 |
+
name = "headwear"
|
59 |
+
|
60 |
+
def __init__(self, name: str = None) -> None:
|
61 |
+
super().__init__(name)
|
62 |
+
|
63 |
+
|
64 |
+
@AttrRegister.register
|
65 |
+
class Expression(AttriributeIsText):
|
66 |
+
name = "expression"
|
67 |
+
|
68 |
+
def __init__(self, name: str = None) -> None:
|
69 |
+
super().__init__(name)
|
70 |
+
|
71 |
+
|
72 |
+
@AttrRegister.register
|
73 |
+
class KeyWords(AttriributeIsText):
|
74 |
+
name = "keywords"
|
75 |
+
|
76 |
+
def __init__(self, name: str = None) -> None:
|
77 |
+
super().__init__(name)
|
78 |
+
|
79 |
+
|
80 |
+
@AttrRegister.register
|
81 |
+
class Singing(AttriributeIsText):
|
82 |
+
def __init__(self, name: str = "singing") -> None:
|
83 |
+
super().__init__(name)
|
84 |
+
|
85 |
+
|
86 |
+
@AttrRegister.register
|
87 |
+
class Country(AttriributeIsText):
|
88 |
+
name = "country"
|
89 |
+
|
90 |
+
def __init__(self, name: str = None) -> None:
|
91 |
+
super().__init__(name)
|
92 |
+
|
93 |
+
|
94 |
+
@AttrRegister.register
|
95 |
+
class Clothes(AttriributeIsText):
|
96 |
+
name = "clothes"
|
97 |
+
|
98 |
+
def __init__(self, name: str = None) -> None:
|
99 |
+
super().__init__(name)
|
100 |
+
|
101 |
+
|
102 |
+
@AttrRegister.register
|
103 |
+
class Age(AttributeIsTextAndName):
|
104 |
+
name = "age"
|
105 |
+
|
106 |
+
def __init__(self, name: str = None) -> None:
|
107 |
+
super().__init__(name)
|
108 |
+
|
109 |
+
def __call__(self, attributes: str) -> str:
|
110 |
+
if not isinstance(attributes, str):
|
111 |
+
attributes = str(attributes)
|
112 |
+
attributes = attributes.split(",")
|
113 |
+
text = ", ".join(
|
114 |
+
["{}-year-old".format(attr) if attr != "" else "" for attr in attributes]
|
115 |
+
)
|
116 |
+
return text
|
117 |
+
|
118 |
+
|
119 |
+
@AttrRegister.register
|
120 |
+
class Eyes(AttributeIsTextAndName):
|
121 |
+
name = "eyes"
|
122 |
+
|
123 |
+
def __init__(self, name: str = None) -> None:
|
124 |
+
super().__init__(name)
|
125 |
+
|
126 |
+
|
127 |
+
@AttrRegister.register
|
128 |
+
class Hair(AttributeIsTextAndName):
|
129 |
+
name = "hair"
|
130 |
+
|
131 |
+
def __init__(self, name: str = None) -> None:
|
132 |
+
super().__init__(name)
|
133 |
+
|
134 |
+
|
135 |
+
@AttrRegister.register
|
136 |
+
class Background(AttributeIsTextAndName):
|
137 |
+
name = "background"
|
138 |
+
|
139 |
+
def __init__(self, name: str = None) -> None:
|
140 |
+
super().__init__(name)
|
141 |
+
|
142 |
+
|
143 |
+
@AttrRegister.register
|
144 |
+
class Skin(AttributeIsTextAndName):
|
145 |
+
name = "skin"
|
146 |
+
|
147 |
+
def __init__(self, name: str = None) -> None:
|
148 |
+
super().__init__(name)
|
149 |
+
|
150 |
+
|
151 |
+
@AttrRegister.register
|
152 |
+
class Face(AttributeIsTextAndName):
|
153 |
+
name = "face"
|
154 |
+
|
155 |
+
def __init__(self, name: str = None) -> None:
|
156 |
+
super().__init__(name)
|
157 |
+
|
158 |
+
|
159 |
+
@AttrRegister.register
|
160 |
+
class Smile(AttributeIsTextAndName):
|
161 |
+
name = "smile"
|
162 |
+
|
163 |
+
def __init__(self, name: str = None) -> None:
|
164 |
+
super().__init__(name)
|
165 |
+
|
166 |
+
|
167 |
+
@AttrRegister.register
|
168 |
+
class Nose(AttributeIsTextAndName):
|
169 |
+
name = "nose"
|
170 |
+
|
171 |
+
def __init__(self, name: str = None) -> None:
|
172 |
+
super().__init__(name)
|
173 |
+
|
174 |
+
|
175 |
+
@AttrRegister.register
|
176 |
+
class Mouth(AttributeIsTextAndName):
|
177 |
+
name = "mouth"
|
178 |
+
|
179 |
+
def __init__(self, name: str = None) -> None:
|
180 |
+
super().__init__(name)
|
181 |
+
|
182 |
+
|
183 |
+
@AttrRegister.register
|
184 |
+
class Beard(AttriributeIsText):
|
185 |
+
name = "beard"
|
186 |
+
|
187 |
+
def __init__(self, name: str = None) -> None:
|
188 |
+
super().__init__(name)
|
189 |
+
|
190 |
+
|
191 |
+
@AttrRegister.register
|
192 |
+
class Necklace(AttributeIsTextAndName):
|
193 |
+
name = "necklace"
|
194 |
+
|
195 |
+
def __init__(self, name: str = None) -> None:
|
196 |
+
super().__init__(name)
|
197 |
+
|
198 |
+
|
199 |
+
@AttrRegister.register
|
200 |
+
class Irises(AttributeIsTextAndName):
|
201 |
+
name = "irises"
|
202 |
+
|
203 |
+
def __init__(self, name: str = None) -> None:
|
204 |
+
super().__init__(name)
|
205 |
+
|
206 |
+
|
207 |
+
@AttrRegister.register
|
208 |
+
class Lighting(AttributeIsTextAndName):
|
209 |
+
name = "lighting"
|
210 |
+
|
211 |
+
def __init__(self, name: str = None) -> None:
|
212 |
+
super().__init__(name)
|
213 |
+
|
214 |
+
|
215 |
+
PresetPortraitAttributes = [
|
216 |
+
Age,
|
217 |
+
Sex,
|
218 |
+
Singing,
|
219 |
+
Country,
|
220 |
+
Lighting,
|
221 |
+
Headwear,
|
222 |
+
Eyes,
|
223 |
+
Irises,
|
224 |
+
Hair,
|
225 |
+
Skin,
|
226 |
+
Face,
|
227 |
+
Smile,
|
228 |
+
Expression,
|
229 |
+
Clothes,
|
230 |
+
Nose,
|
231 |
+
Mouth,
|
232 |
+
Beard,
|
233 |
+
Necklace,
|
234 |
+
Style,
|
235 |
+
KeyWords,
|
236 |
+
Render,
|
237 |
+
]
|
238 |
+
|
239 |
+
|
240 |
+
class PortraitMultiAttr2Text(PresetMultiAttr2Text):
|
241 |
+
preset_attributes = PresetPortraitAttributes
|
242 |
+
|
243 |
+
def __init__(self, funcs: list = None, use_preset=True, name="portrait") -> None:
|
244 |
+
super().__init__(funcs, use_preset, name)
|
245 |
+
|
246 |
+
|
247 |
+
@AttrRegister.register
|
248 |
+
class InsightFace(AttriributeIsText):
|
249 |
+
name = "insight_face"
|
250 |
+
face_render_dict = {
|
251 |
+
"boy": "handsome,elegant",
|
252 |
+
"girl": "gorgeous,kawaii,colorful",
|
253 |
+
}
|
254 |
+
key_words = "delicate face,beautiful eyes"
|
255 |
+
|
256 |
+
def __call__(self, attributes: str) -> str:
|
257 |
+
"""将insight faces 检测的结果转化成prompt
|
258 |
+
convert the results of insight faces detection to prompt
|
259 |
+
Args:
|
260 |
+
face_list (_type_): _description_
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
_type_: _description_
|
264 |
+
"""
|
265 |
+
attributes = json.loads(attributes)
|
266 |
+
face_list = attributes["info"]
|
267 |
+
if len(face_list) == 0:
|
268 |
+
return ""
|
269 |
+
|
270 |
+
if attributes["image_type"] == "body":
|
271 |
+
for face in face_list:
|
272 |
+
if "black" in face and face["black"]:
|
273 |
+
return "african,dark skin"
|
274 |
+
return ""
|
275 |
+
|
276 |
+
gender_dict = {"girl": 0, "boy": 0}
|
277 |
+
face_render_list = []
|
278 |
+
black = False
|
279 |
+
|
280 |
+
for face in face_list:
|
281 |
+
if face["ratio"] < 0.02:
|
282 |
+
continue
|
283 |
+
|
284 |
+
if face["gender"] == 0:
|
285 |
+
gender_dict["girl"] += 1
|
286 |
+
face_render_list.append(self.face_render_dict["girl"])
|
287 |
+
else:
|
288 |
+
gender_dict["boy"] += 1
|
289 |
+
face_render_list.append(self.face_render_dict["boy"])
|
290 |
+
|
291 |
+
if "black" in face and face["black"]:
|
292 |
+
black = True
|
293 |
+
|
294 |
+
if len(face_render_list) == 0:
|
295 |
+
return ""
|
296 |
+
elif len(face_render_list) == 1:
|
297 |
+
solo = True
|
298 |
+
else:
|
299 |
+
solo = False
|
300 |
+
|
301 |
+
gender = ""
|
302 |
+
for g, num in gender_dict.items():
|
303 |
+
if num > 0:
|
304 |
+
if gender:
|
305 |
+
gender += ", "
|
306 |
+
gender += "{}{}".format(num, g)
|
307 |
+
if num > 1:
|
308 |
+
gender += "s"
|
309 |
+
|
310 |
+
face_render_list = ",".join(face_render_list)
|
311 |
+
face_render_list = face_render_list.split(",")
|
312 |
+
face_render = list(set(face_render_list))
|
313 |
+
face_render.sort(key=face_render_list.index)
|
314 |
+
face_render = ",".join(face_render)
|
315 |
+
if gender_dict["girl"] == 0:
|
316 |
+
face_render = "male focus," + face_render
|
317 |
+
|
318 |
+
insightface_prompt = "{},{},{}".format(gender, face_render, self.key_words)
|
319 |
+
|
320 |
+
if solo:
|
321 |
+
insightface_prompt += ",solo"
|
322 |
+
if black:
|
323 |
+
insightface_prompt = "african,dark skin," + insightface_prompt
|
324 |
+
|
325 |
+
return insightface_prompt
|
326 |
+
|
327 |
+
|
328 |
+
@AttrRegister.register
|
329 |
+
class Caption(AttriributeIsText):
|
330 |
+
name = "caption"
|
331 |
+
|
332 |
+
|
333 |
+
@AttrRegister.register
|
334 |
+
class Env(AttriributeIsText):
|
335 |
+
name = "env"
|
336 |
+
envs_list = [
|
337 |
+
"east asian architecture",
|
338 |
+
"fireworks",
|
339 |
+
"snow, snowflakes",
|
340 |
+
"snowing, snowflakes",
|
341 |
+
]
|
342 |
+
|
343 |
+
def __call__(self, attributes: str = None) -> str:
|
344 |
+
if attributes != "" and attributes != " " and attributes is not None:
|
345 |
+
return attributes
|
346 |
+
else:
|
347 |
+
return random.choice(self.envs_list)
|
348 |
+
|
349 |
+
|
350 |
+
@AttrRegister.register
|
351 |
+
class Decoration(AttriributeIsText):
|
352 |
+
name = "decoration"
|
353 |
+
|
354 |
+
def __init__(self, name: str = None) -> None:
|
355 |
+
self.decoration_list = [
|
356 |
+
"chinese knot",
|
357 |
+
"flowers",
|
358 |
+
"food",
|
359 |
+
"lanterns",
|
360 |
+
"red envelop",
|
361 |
+
]
|
362 |
+
super().__init__(name)
|
363 |
+
|
364 |
+
def __call__(self, attributes: str = None) -> str:
|
365 |
+
if attributes != "" and attributes != " " and attributes is not None:
|
366 |
+
return attributes
|
367 |
+
else:
|
368 |
+
return random.choice(self.decoration_list)
|
369 |
+
|
370 |
+
|
371 |
+
@AttrRegister.register
|
372 |
+
class Festival(AttriributeIsText):
|
373 |
+
name = "festival"
|
374 |
+
festival_list = ["new year"]
|
375 |
+
|
376 |
+
def __init__(self, name: str = None) -> None:
|
377 |
+
super().__init__(name)
|
378 |
+
|
379 |
+
def __call__(self, attributes: str = None) -> str:
|
380 |
+
if attributes != "" and attributes != " " and attributes is not None:
|
381 |
+
return attributes
|
382 |
+
else:
|
383 |
+
return random.choice(self.festival_list)
|
384 |
+
|
385 |
+
|
386 |
+
@AttrRegister.register
|
387 |
+
class SpringHeadwear(AttriributeIsText):
|
388 |
+
name = "spring_headwear"
|
389 |
+
headwear_list = ["rabbit ears", "rabbit ears, fur hat"]
|
390 |
+
|
391 |
+
def __call__(self, attributes: str = None) -> str:
|
392 |
+
if attributes != "" and attributes != " " and attributes is not None:
|
393 |
+
return attributes
|
394 |
+
else:
|
395 |
+
return random.choice(self.headwear_list)
|
396 |
+
|
397 |
+
|
398 |
+
@AttrRegister.register
|
399 |
+
class SpringClothes(AttriributeIsText):
|
400 |
+
name = "spring_clothes"
|
401 |
+
clothes_list = [
|
402 |
+
"mittens,chinese clothes",
|
403 |
+
"mittens,fur trim",
|
404 |
+
"mittens,red scarf",
|
405 |
+
"mittens,winter clothes",
|
406 |
+
]
|
407 |
+
|
408 |
+
def __call__(self, attributes: str = None) -> str:
|
409 |
+
if attributes != "" and attributes != " " and attributes is not None:
|
410 |
+
return attributes
|
411 |
+
else:
|
412 |
+
return random.choice(self.clothes_list)
|
413 |
+
|
414 |
+
|
415 |
+
@AttrRegister.register
|
416 |
+
class Animal(AttriributeIsText):
|
417 |
+
name = "animal"
|
418 |
+
animal_list = ["rabbit", "holding rabbits"]
|
419 |
+
|
420 |
+
def __call__(self, attributes: str = None) -> str:
|
421 |
+
if attributes != "" and attributes != " " and attributes is not None:
|
422 |
+
return attributes
|
423 |
+
else:
|
424 |
+
return random.choice(self.animal_list)
|
musev/auto_prompt/attributes/render.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmcm.utils.util import flatten
|
2 |
+
|
3 |
+
from .attributes import BaseAttribute2Text
|
4 |
+
from . import AttrRegister
|
5 |
+
|
6 |
+
__all__ = ["Render"]
|
7 |
+
|
8 |
+
RenderMap = {
|
9 |
+
"Epic": "artstation, epic environment, highly detailed, 8k, HD",
|
10 |
+
"HD": "8k, highly detailed",
|
11 |
+
"EpicHD": "hyper detailed, beautiful lighting, epic environment, octane render, cinematic, 8k",
|
12 |
+
"Digital": "detailed illustration, crisp lines, digital art, 8k, trending on artstation",
|
13 |
+
"Unreal1": "artstation, concept art, smooth, sharp focus, illustration, unreal engine 5, 8k",
|
14 |
+
"Unreal2": "concept art, octane render, artstation, epic environment, highly detailed, 8k",
|
15 |
+
}
|
16 |
+
|
17 |
+
|
18 |
+
@AttrRegister.register
|
19 |
+
class Render(BaseAttribute2Text):
|
20 |
+
name = "render"
|
21 |
+
|
22 |
+
def __init__(self, name: str = None) -> None:
|
23 |
+
super().__init__(name)
|
24 |
+
|
25 |
+
def __call__(self, attributes: str) -> str:
|
26 |
+
if attributes == "" or attributes is None:
|
27 |
+
return ""
|
28 |
+
attributes = attributes.split(",")
|
29 |
+
render = [RenderMap[attr] for attr in attributes if attr in RenderMap]
|
30 |
+
render = flatten(render, ignored_iterable_types=[str])
|
31 |
+
if len(render) == 1:
|
32 |
+
render = render[0]
|
33 |
+
return render
|
musev/auto_prompt/attributes/style.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .attributes import AttriributeIsText
|
2 |
+
from . import AttrRegister
|
3 |
+
|
4 |
+
__all__ = ["Style"]
|
5 |
+
|
6 |
+
|
7 |
+
@AttrRegister.register
|
8 |
+
class Style(AttriributeIsText):
|
9 |
+
name = "style"
|
10 |
+
|
11 |
+
def __init__(self, name: str = None) -> None:
|
12 |
+
super().__init__(name)
|
musev/auto_prompt/human.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""负责按照人相关的属性转化成提词
|
2 |
+
"""
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
from .attributes.human import PortraitMultiAttr2Text
|
6 |
+
from .attributes.attributes import BaseAttribute2Text
|
7 |
+
from .attributes.attr2template import MultiAttr2PromptTemplate
|
8 |
+
|
9 |
+
|
10 |
+
class PortraitAttr2PromptTemplate(MultiAttr2PromptTemplate):
|
11 |
+
"""可以将任务字典转化为形象提词模板类
|
12 |
+
template class for converting task dictionaries into image prompt templates
|
13 |
+
Args:
|
14 |
+
MultiAttr2PromptTemplate (_type_): _description_
|
15 |
+
"""
|
16 |
+
|
17 |
+
templates = "a portrait of {}"
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self, templates: str = None, attr2text: List = None, name: str = "portrait"
|
21 |
+
) -> None:
|
22 |
+
"""
|
23 |
+
|
24 |
+
Args:
|
25 |
+
templates (str, optional): 形象提词模板,若为None,则使用默认的类属性. Defaults to None.
|
26 |
+
portrait prompt template, if None, the default class attribute is used.
|
27 |
+
attr2text (List, optional): 形象类需要新增、更新的属性列表,默认使用PortraitMultiAttr2Text中定义的形象属性. Defaults to None.
|
28 |
+
the list of attributes that need to be added or updated in the image class, by default, the image attributes defined in PortraitMultiAttr2Text are used.
|
29 |
+
name (str, optional): 该形象类的名字. Defaults to "portrait".
|
30 |
+
class name of this class instance
|
31 |
+
"""
|
32 |
+
if (
|
33 |
+
attr2text is None
|
34 |
+
or isinstance(attr2text, list)
|
35 |
+
or isinstance(attr2text, BaseAttribute2Text)
|
36 |
+
):
|
37 |
+
attr2text = PortraitMultiAttr2Text(funcs=attr2text)
|
38 |
+
if templates is None:
|
39 |
+
templates = self.templates
|
40 |
+
super().__init__(templates, attr2text, name=name)
|
musev/auto_prompt/load_template.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmcm.utils.str_util import has_key_brace
|
2 |
+
|
3 |
+
from .human import PortraitAttr2PromptTemplate
|
4 |
+
from .attributes.attr2template import (
|
5 |
+
KeywordMultiAttr2PromptTemplate,
|
6 |
+
OnlySpacePromptTemplate,
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
def get_template_by_name(template: str, name: str = None):
|
11 |
+
"""根据 template_name 确定 prompt 生成器类
|
12 |
+
choose prompt generator class according to template_name
|
13 |
+
Args:
|
14 |
+
name (str): template 的名字简称,便于指定. template name abbreviation, for easy reference
|
15 |
+
|
16 |
+
Raises:
|
17 |
+
ValueError: ValueError: 如果name不在支持的列表中,则报错. if name is not in the supported list, an error is reported.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
MultiAttr2PromptTemplate: 能够将任务字典转化为提词的 实现了__call__功能的类. class that can convert task dictionaries into prompts and implements the __call__ function
|
21 |
+
|
22 |
+
"""
|
23 |
+
if template == "" or template is None:
|
24 |
+
template = OnlySpacePromptTemplate(template=template)
|
25 |
+
elif has_key_brace(template):
|
26 |
+
# if has_key_brace(template):
|
27 |
+
template = KeywordMultiAttr2PromptTemplate(template=template)
|
28 |
+
else:
|
29 |
+
if name == "portrait":
|
30 |
+
template = PortraitAttr2PromptTemplate(templates=template)
|
31 |
+
else:
|
32 |
+
raise ValueError(
|
33 |
+
"PresetAttr2PromptTemplate only support one of [portrait], but given {}".format(
|
34 |
+
name
|
35 |
+
)
|
36 |
+
)
|
37 |
+
return template
|
musev/auto_prompt/util.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from typing import Dict, List
|
3 |
+
|
4 |
+
from .load_template import get_template_by_name
|
5 |
+
|
6 |
+
|
7 |
+
def generate_prompts(tasks: List[Dict]) -> List[Dict]:
|
8 |
+
new_tasks = []
|
9 |
+
for task in tasks:
|
10 |
+
task["origin_prompt"] = deepcopy(task["prompt"])
|
11 |
+
# 如果prompt单元值含有模板 {},或者 没有填写任何值(默认为空模板),则使用原prompt值
|
12 |
+
if "{" not in task["prompt"] and len(task["prompt"]) != 0:
|
13 |
+
new_tasks.append(task)
|
14 |
+
else:
|
15 |
+
template = get_template_by_name(
|
16 |
+
template=task["prompt"], name=task.get("template_name", None)
|
17 |
+
)
|
18 |
+
prompts = template(task)
|
19 |
+
if not isinstance(prompts, list) and isinstance(prompts, str):
|
20 |
+
prompts = [prompts]
|
21 |
+
for prompt in prompts:
|
22 |
+
task_cp = deepcopy(task)
|
23 |
+
task_cp["prompt"] = prompt
|
24 |
+
new_tasks.append(task_cp)
|
25 |
+
return new_tasks
|
musev/data/__init__.py
ADDED
File without changes
|
musev/data/data_util.py
ADDED
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict, Literal, Union, Tuple
|
2 |
+
import os
|
3 |
+
import string
|
4 |
+
import logging
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
def generate_tasks_of_dir(
|
14 |
+
path: str,
|
15 |
+
output_dir: str,
|
16 |
+
exts: Tuple[str],
|
17 |
+
same_dir_name: bool = False,
|
18 |
+
**kwargs,
|
19 |
+
) -> List[Dict]:
|
20 |
+
"""covert video directory into tasks
|
21 |
+
|
22 |
+
Args:
|
23 |
+
path (str): _description_
|
24 |
+
output_dir (str): _description_
|
25 |
+
exts (Tuple[str]): _description_
|
26 |
+
same_dir_name (bool, optional): 存储路径是否保留和源视频相同的父文件名. Defaults to False.
|
27 |
+
whether keep the same parent dir name as the source video
|
28 |
+
Returns:
|
29 |
+
List[Dict]: _description_
|
30 |
+
"""
|
31 |
+
tasks = []
|
32 |
+
for rootdir, dirs, files in os.walk(path):
|
33 |
+
for basename in files:
|
34 |
+
if basename.lower().endswith(exts):
|
35 |
+
video_path = os.path.join(rootdir, basename)
|
36 |
+
filename, ext = basename.split(".")
|
37 |
+
rootdir_name = os.path.basename(rootdir)
|
38 |
+
if same_dir_name:
|
39 |
+
save_path = os.path.join(
|
40 |
+
output_dir, rootdir_name, f"{filename}.h5py"
|
41 |
+
)
|
42 |
+
save_dir = os.path.join(output_dir, rootdir_name)
|
43 |
+
else:
|
44 |
+
save_path = os.path.join(output_dir, f"{filename}.h5py")
|
45 |
+
save_dir = output_dir
|
46 |
+
task = {
|
47 |
+
"video_path": video_path,
|
48 |
+
"output_path": save_path,
|
49 |
+
"output_dir": save_dir,
|
50 |
+
"filename": filename,
|
51 |
+
"ext": ext,
|
52 |
+
}
|
53 |
+
task.update(kwargs)
|
54 |
+
tasks.append(task)
|
55 |
+
return tasks
|
56 |
+
|
57 |
+
|
58 |
+
def sample_by_idx(
|
59 |
+
T: int,
|
60 |
+
n_sample: int,
|
61 |
+
sample_rate: int,
|
62 |
+
sample_start_idx: int = None,
|
63 |
+
change_sample_rate: bool = False,
|
64 |
+
seed: int = None,
|
65 |
+
whether_random: bool = True,
|
66 |
+
n_independent: int = 0,
|
67 |
+
) -> List[int]:
|
68 |
+
"""given a int to represent candidate list, sample n_sample with sample_rate from the candidate list
|
69 |
+
|
70 |
+
Args:
|
71 |
+
T (int): _description_
|
72 |
+
n_sample (int): 目标采样数目. sample number
|
73 |
+
sample_rate (int): 采样率, 每隔sample_rate个采样一个. sample interval, pick one per sample_rate number
|
74 |
+
sample_start_idx (int, optional): 采样开始位置的选择. start position to sample . Defaults to 0.
|
75 |
+
change_sample_rate (bool, optional): 是否可以通过降低sample_rate的方式来完成采样. whether allow changing sample_rate to finish sample process. Defaults to False.
|
76 |
+
whether_random (bool, optional): 是否最后随机选择开始点. whether randomly choose sample start position. Defaults to False.
|
77 |
+
|
78 |
+
Raises:
|
79 |
+
ValueError: T / sample_rate should be larger than n_sample
|
80 |
+
Returns:
|
81 |
+
List[int]: 采样的索引位置. sampled index position
|
82 |
+
"""
|
83 |
+
if T < n_sample:
|
84 |
+
raise ValueError(f"T({T}) < n_sample({n_sample})")
|
85 |
+
else:
|
86 |
+
if T / sample_rate < n_sample:
|
87 |
+
if not change_sample_rate:
|
88 |
+
raise ValueError(
|
89 |
+
f"T({T}) / sample_rate({sample_rate}) < n_sample({n_sample})"
|
90 |
+
)
|
91 |
+
else:
|
92 |
+
while T / sample_rate < n_sample:
|
93 |
+
sample_rate -= 1
|
94 |
+
logger.error(
|
95 |
+
f"sample_rate{sample_rate+1} is too large, decrease to {sample_rate}"
|
96 |
+
)
|
97 |
+
if sample_rate == 0:
|
98 |
+
raise ValueError("T / sample_rate < n_sample")
|
99 |
+
|
100 |
+
if sample_start_idx is None:
|
101 |
+
if whether_random:
|
102 |
+
sample_start_idx_candidates = np.arange(T - n_sample * sample_rate)
|
103 |
+
if seed is not None:
|
104 |
+
np.random.seed(seed)
|
105 |
+
sample_start_idx = np.random.choice(sample_start_idx_candidates, 1)[0]
|
106 |
+
|
107 |
+
else:
|
108 |
+
sample_start_idx = 0
|
109 |
+
sample_end_idx = sample_start_idx + sample_rate * n_sample
|
110 |
+
sample = list(range(sample_start_idx, sample_end_idx, sample_rate))
|
111 |
+
if n_independent == 0:
|
112 |
+
n_independent_sample = None
|
113 |
+
else:
|
114 |
+
left_candidate = np.array(
|
115 |
+
list(range(0, sample_start_idx)) + list(range(sample_end_idx, T))
|
116 |
+
)
|
117 |
+
if len(left_candidate) >= n_independent:
|
118 |
+
# 使用两端的剩余空间采样, use the left space to sample
|
119 |
+
n_independent_sample = np.random.choice(left_candidate, n_independent)
|
120 |
+
else:
|
121 |
+
# 当两端没有剩余采样空间时,使用任意不是sample中的帧
|
122 |
+
# if no enough space to sample, use any frame not in sample
|
123 |
+
left_candidate = np.array(list(set(range(T) - set(sample))))
|
124 |
+
n_independent_sample = np.random.choice(left_candidate, n_independent)
|
125 |
+
|
126 |
+
return sample, sample_rate, n_independent_sample
|
127 |
+
|
128 |
+
|
129 |
+
def sample_tensor_by_idx(
|
130 |
+
tensor: Union[torch.Tensor, np.ndarray],
|
131 |
+
n_sample: int,
|
132 |
+
sample_rate: int,
|
133 |
+
sample_start_idx: int = 0,
|
134 |
+
change_sample_rate: bool = False,
|
135 |
+
seed: int = None,
|
136 |
+
dim: int = 0,
|
137 |
+
return_type: Literal["numpy", "torch"] = "torch",
|
138 |
+
whether_random: bool = True,
|
139 |
+
n_independent: int = 0,
|
140 |
+
) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]:
|
141 |
+
"""sample sub_tensor
|
142 |
+
|
143 |
+
Args:
|
144 |
+
tensor (Union[torch.Tensor, np.ndarray]): _description_
|
145 |
+
n_sample (int): _description_
|
146 |
+
sample_rate (int): _description_
|
147 |
+
sample_start_idx (int, optional): _description_. Defaults to 0.
|
148 |
+
change_sample_rate (bool, optional): _description_. Defaults to False.
|
149 |
+
seed (int, optional): _description_. Defaults to None.
|
150 |
+
dim (int, optional): _description_. Defaults to 0.
|
151 |
+
return_type (Literal["numpy", "torch"], optional): _description_. Defaults to "torch".
|
152 |
+
whether_random (bool, optional): _description_. Defaults to True.
|
153 |
+
n_independent (int, optional): 独立于n_sample的采样数量. Defaults to 0.
|
154 |
+
n_independent sample number that is independent of n_sample
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]: sampled tensor
|
158 |
+
"""
|
159 |
+
if isinstance(tensor, np.ndarray):
|
160 |
+
tensor = torch.from_numpy(tensor)
|
161 |
+
T = tensor.shape[dim]
|
162 |
+
sample_idx, sample_rate, independent_sample_idx = sample_by_idx(
|
163 |
+
T,
|
164 |
+
n_sample,
|
165 |
+
sample_rate,
|
166 |
+
sample_start_idx,
|
167 |
+
change_sample_rate,
|
168 |
+
seed,
|
169 |
+
whether_random=whether_random,
|
170 |
+
n_independent=n_independent,
|
171 |
+
)
|
172 |
+
sample_idx = torch.LongTensor(sample_idx)
|
173 |
+
sample = torch.index_select(tensor, dim, sample_idx)
|
174 |
+
if independent_sample_idx is not None:
|
175 |
+
independent_sample_idx = torch.LongTensor(independent_sample_idx)
|
176 |
+
independent_sample = torch.index_select(tensor, dim, independent_sample_idx)
|
177 |
+
else:
|
178 |
+
independent_sample = None
|
179 |
+
independent_sample_idx = None
|
180 |
+
if return_type == "numpy":
|
181 |
+
sample = sample.cpu().numpy()
|
182 |
+
return sample, sample_idx, sample_rate, independent_sample, independent_sample_idx
|
183 |
+
|
184 |
+
|
185 |
+
def concat_two_tensor(
|
186 |
+
data1: torch.Tensor,
|
187 |
+
data2: torch.Tensor,
|
188 |
+
dim: int,
|
189 |
+
method: Literal[
|
190 |
+
"first_in_first_out", "first_in_last_out", "intertwine", "index"
|
191 |
+
] = "first_in_first_out",
|
192 |
+
data1_index: torch.long = None,
|
193 |
+
data2_index: torch.long = None,
|
194 |
+
return_index: bool = False,
|
195 |
+
):
|
196 |
+
"""concat two tensor along dim with given method
|
197 |
+
|
198 |
+
Args:
|
199 |
+
data1 (torch.Tensor): first in data
|
200 |
+
data2 (torch.Tensor): last in data
|
201 |
+
dim (int): _description_
|
202 |
+
method (Literal[ "first_in_first_out", "first_in_last_out", "intertwine" ], optional): _description_. Defaults to "first_in_first_out".
|
203 |
+
|
204 |
+
Raises:
|
205 |
+
NotImplementedError: unsupported method
|
206 |
+
ValueError: unsupported method
|
207 |
+
|
208 |
+
Returns:
|
209 |
+
_type_: _description_
|
210 |
+
"""
|
211 |
+
len_data1 = data1.shape[dim]
|
212 |
+
len_data2 = data2.shape[dim]
|
213 |
+
|
214 |
+
if method == "first_in_first_out":
|
215 |
+
res = torch.concat([data1, data2], dim=dim)
|
216 |
+
data1_index = range(len_data1)
|
217 |
+
data2_index = [len_data1 + x for x in range(len_data2)]
|
218 |
+
elif method == "first_in_last_out":
|
219 |
+
res = torch.concat([data2, data1], dim=dim)
|
220 |
+
data2_index = range(len_data2)
|
221 |
+
data1_index = [len_data2 + x for x in range(len_data1)]
|
222 |
+
elif method == "intertwine":
|
223 |
+
raise NotImplementedError("intertwine")
|
224 |
+
elif method == "index":
|
225 |
+
res = concat_two_tensor_with_index(
|
226 |
+
data1=data1,
|
227 |
+
data1_index=data1_index,
|
228 |
+
data2=data2,
|
229 |
+
data2_index=data2_index,
|
230 |
+
dim=dim,
|
231 |
+
)
|
232 |
+
else:
|
233 |
+
raise ValueError(
|
234 |
+
"only support first_in_first_out, first_in_last_out, intertwine, index"
|
235 |
+
)
|
236 |
+
if return_index:
|
237 |
+
return res, data1_index, data2_index
|
238 |
+
else:
|
239 |
+
return res
|
240 |
+
|
241 |
+
|
242 |
+
def concat_two_tensor_with_index(
|
243 |
+
data1: torch.Tensor,
|
244 |
+
data1_index: torch.LongTensor,
|
245 |
+
data2: torch.Tensor,
|
246 |
+
data2_index: torch.LongTensor,
|
247 |
+
dim: int,
|
248 |
+
) -> torch.Tensor:
|
249 |
+
"""_summary_
|
250 |
+
|
251 |
+
Args:
|
252 |
+
data1 (torch.Tensor): b1*c1*h1*w1*...
|
253 |
+
data1_index (torch.LongTensor): N, if dim=1, N=c1
|
254 |
+
data2 (torch.Tensor): b2*c2*h2*w2*...
|
255 |
+
data2_index (torch.LongTensor): M, if dim=1, M=c2
|
256 |
+
dim (int): int
|
257 |
+
|
258 |
+
Returns:
|
259 |
+
torch.Tensor: b*c*h*w*..., if dim=1, b=b1=b2, c=c1+c2, h=h1=h2, w=w1=w2,...
|
260 |
+
"""
|
261 |
+
shape1 = list(data1.shape)
|
262 |
+
shape2 = list(data2.shape)
|
263 |
+
target_shape = list(shape1)
|
264 |
+
target_shape[dim] = shape1[dim] + shape2[dim]
|
265 |
+
target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype)
|
266 |
+
target = batch_index_copy(target, dim=dim, index=data1_index, source=data1)
|
267 |
+
target = batch_index_copy(target, dim=dim, index=data2_index, source=data2)
|
268 |
+
return target
|
269 |
+
|
270 |
+
|
271 |
+
def repeat_index_to_target_size(
|
272 |
+
index: torch.LongTensor, target_size: int
|
273 |
+
) -> torch.LongTensor:
|
274 |
+
if len(index.shape) == 1:
|
275 |
+
index = repeat(index, "n -> b n", b=target_size)
|
276 |
+
if len(index.shape) == 2:
|
277 |
+
remainder = target_size % index.shape[0]
|
278 |
+
assert (
|
279 |
+
remainder == 0
|
280 |
+
), f"target_size % index.shape[0] must be zero, but give {target_size % index.shape[0]}"
|
281 |
+
index = repeat(index, "b n -> (b c) n", c=int(target_size / index.shape[0]))
|
282 |
+
return index
|
283 |
+
|
284 |
+
|
285 |
+
def batch_concat_two_tensor_with_index(
|
286 |
+
data1: torch.Tensor,
|
287 |
+
data1_index: torch.LongTensor,
|
288 |
+
data2: torch.Tensor,
|
289 |
+
data2_index: torch.LongTensor,
|
290 |
+
dim: int,
|
291 |
+
) -> torch.Tensor:
|
292 |
+
return concat_two_tensor_with_index(data1, data1_index, data2, data2_index, dim)
|
293 |
+
|
294 |
+
|
295 |
+
def interwine_two_tensor(
|
296 |
+
data1: torch.Tensor,
|
297 |
+
data2: torch.Tensor,
|
298 |
+
dim: int,
|
299 |
+
return_index: bool = False,
|
300 |
+
) -> torch.Tensor:
|
301 |
+
shape1 = list(data1.shape)
|
302 |
+
shape2 = list(data2.shape)
|
303 |
+
target_shape = list(shape1)
|
304 |
+
target_shape[dim] = shape1[dim] + shape2[dim]
|
305 |
+
target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype)
|
306 |
+
data1_reshape = torch.swapaxes(data1, 0, dim)
|
307 |
+
data2_reshape = torch.swapaxes(data2, 0, dim)
|
308 |
+
target = torch.swapaxes(target, 0, dim)
|
309 |
+
total_index = set(range(target_shape[dim]))
|
310 |
+
data1_index = range(0, 2 * shape1[dim], 2)
|
311 |
+
data2_index = sorted(list(set(total_index) - set(data1_index)))
|
312 |
+
data1_index = torch.LongTensor(data1_index)
|
313 |
+
data2_index = torch.LongTensor(data2_index)
|
314 |
+
target[data1_index, ...] = data1_reshape
|
315 |
+
target[data2_index, ...] = data2_reshape
|
316 |
+
target = torch.swapaxes(target, 0, dim)
|
317 |
+
if return_index:
|
318 |
+
return target, data1_index, data2_index
|
319 |
+
else:
|
320 |
+
return target
|
321 |
+
|
322 |
+
|
323 |
+
def split_index(
|
324 |
+
indexs: torch.Tensor,
|
325 |
+
n_first: int = None,
|
326 |
+
n_last: int = None,
|
327 |
+
method: Literal[
|
328 |
+
"first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
|
329 |
+
] = "first_in_first_out",
|
330 |
+
):
|
331 |
+
"""_summary_
|
332 |
+
|
333 |
+
Args:
|
334 |
+
indexs (List): _description_
|
335 |
+
n_first (int): _description_
|
336 |
+
n_last (int): _description_
|
337 |
+
method (Literal[ "first_in_first_out", "first_in_last_out", "intertwine", "index" ], optional): _description_. Defaults to "first_in_first_out".
|
338 |
+
|
339 |
+
Raises:
|
340 |
+
NotImplementedError: _description_
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
first_index: _description_
|
344 |
+
last_index:
|
345 |
+
"""
|
346 |
+
# assert (
|
347 |
+
# n_first is None and n_last is None
|
348 |
+
# ), "must assign one value for n_first or n_last"
|
349 |
+
n_total = len(indexs)
|
350 |
+
if n_first is None:
|
351 |
+
n_first = n_total - n_last
|
352 |
+
if n_last is None:
|
353 |
+
n_last = n_total - n_first
|
354 |
+
assert len(indexs) == n_first + n_last
|
355 |
+
if method == "first_in_first_out":
|
356 |
+
first_index = indexs[:n_first]
|
357 |
+
last_index = indexs[n_first:]
|
358 |
+
elif method == "first_in_last_out":
|
359 |
+
first_index = indexs[n_last:]
|
360 |
+
last_index = indexs[:n_last]
|
361 |
+
elif method == "intertwine":
|
362 |
+
raise NotImplementedError
|
363 |
+
elif method == "random":
|
364 |
+
idx_ = torch.randperm(len(indexs))
|
365 |
+
first_index = indexs[idx_[:n_first]]
|
366 |
+
last_index = indexs[idx_[n_first:]]
|
367 |
+
return first_index, last_index
|
368 |
+
|
369 |
+
|
370 |
+
def split_tensor(
|
371 |
+
tensor: torch.Tensor,
|
372 |
+
dim: int,
|
373 |
+
n_first=None,
|
374 |
+
n_last=None,
|
375 |
+
method: Literal[
|
376 |
+
"first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
|
377 |
+
] = "first_in_first_out",
|
378 |
+
need_return_index: bool = False,
|
379 |
+
):
|
380 |
+
device = tensor.device
|
381 |
+
total = tensor.shape[dim]
|
382 |
+
if n_first is None:
|
383 |
+
n_first = total - n_last
|
384 |
+
if n_last is None:
|
385 |
+
n_last = total - n_first
|
386 |
+
indexs = torch.arange(
|
387 |
+
total,
|
388 |
+
dtype=torch.long,
|
389 |
+
device=device,
|
390 |
+
)
|
391 |
+
(
|
392 |
+
first_index,
|
393 |
+
last_index,
|
394 |
+
) = split_index(
|
395 |
+
indexs=indexs,
|
396 |
+
n_first=n_first,
|
397 |
+
method=method,
|
398 |
+
)
|
399 |
+
first_tensor = torch.index_select(tensor, dim=dim, index=first_index)
|
400 |
+
last_tensor = torch.index_select(tensor, dim=dim, index=last_index)
|
401 |
+
if need_return_index:
|
402 |
+
return (
|
403 |
+
first_tensor,
|
404 |
+
last_tensor,
|
405 |
+
first_index,
|
406 |
+
last_index,
|
407 |
+
)
|
408 |
+
else:
|
409 |
+
return (first_tensor, last_tensor)
|
410 |
+
|
411 |
+
|
412 |
+
# TODO: 待确定batch_index_select的优化
|
413 |
+
def batch_index_select(
|
414 |
+
tensor: torch.Tensor, index: torch.LongTensor, dim: int
|
415 |
+
) -> torch.Tensor:
|
416 |
+
"""_summary_
|
417 |
+
|
418 |
+
Args:
|
419 |
+
tensor (torch.Tensor): D1*D2*D3*D4...
|
420 |
+
index (torch.LongTensor): D1*N or N, N<= tensor.shape[dim]
|
421 |
+
dim (int): dim to select
|
422 |
+
|
423 |
+
Returns:
|
424 |
+
torch.Tensor: D1*...*N*...
|
425 |
+
"""
|
426 |
+
# TODO: now only support N same for every d1
|
427 |
+
if len(index.shape) == 1:
|
428 |
+
return torch.index_select(tensor, dim=dim, index=index)
|
429 |
+
else:
|
430 |
+
index = repeat_index_to_target_size(index, tensor.shape[0])
|
431 |
+
out = []
|
432 |
+
for i in torch.arange(tensor.shape[0]):
|
433 |
+
sub_tensor = tensor[i]
|
434 |
+
sub_index = index[i]
|
435 |
+
d = torch.index_select(sub_tensor, dim=dim - 1, index=sub_index)
|
436 |
+
out.append(d)
|
437 |
+
return torch.stack(out).to(dtype=tensor.dtype)
|
438 |
+
|
439 |
+
|
440 |
+
def batch_index_copy(
|
441 |
+
tensor: torch.Tensor, dim: int, index: torch.LongTensor, source: torch.Tensor
|
442 |
+
) -> torch.Tensor:
|
443 |
+
"""_summary_
|
444 |
+
|
445 |
+
Args:
|
446 |
+
tensor (torch.Tensor): b*c*h
|
447 |
+
dim (int):
|
448 |
+
index (torch.LongTensor): b*d,
|
449 |
+
source (torch.Tensor):
|
450 |
+
b*d*h*..., if dim=1
|
451 |
+
b*c*d*..., if dim=2
|
452 |
+
|
453 |
+
Returns:
|
454 |
+
torch.Tensor: b*c*d*...
|
455 |
+
"""
|
456 |
+
if len(index.shape) == 1:
|
457 |
+
tensor.index_copy_(dim=dim, index=index, source=source)
|
458 |
+
else:
|
459 |
+
index = repeat_index_to_target_size(index, tensor.shape[0])
|
460 |
+
|
461 |
+
batch_size = tensor.shape[0]
|
462 |
+
for b in torch.arange(batch_size):
|
463 |
+
sub_index = index[b]
|
464 |
+
sub_source = source[b]
|
465 |
+
sub_tensor = tensor[b]
|
466 |
+
sub_tensor.index_copy_(dim=dim - 1, index=sub_index, source=sub_source)
|
467 |
+
tensor[b] = sub_tensor
|
468 |
+
return tensor
|
469 |
+
|
470 |
+
|
471 |
+
def batch_index_fill(
|
472 |
+
tensor: torch.Tensor,
|
473 |
+
dim: int,
|
474 |
+
index: torch.LongTensor,
|
475 |
+
value: Literal[torch.Tensor, torch.float],
|
476 |
+
) -> torch.Tensor:
|
477 |
+
"""_summary_
|
478 |
+
|
479 |
+
Args:
|
480 |
+
tensor (torch.Tensor): b*c*h
|
481 |
+
dim (int):
|
482 |
+
index (torch.LongTensor): b*d,
|
483 |
+
value (torch.Tensor): b
|
484 |
+
|
485 |
+
Returns:
|
486 |
+
torch.Tensor: b*c*d*...
|
487 |
+
"""
|
488 |
+
index = repeat_index_to_target_size(index, tensor.shape[0])
|
489 |
+
batch_size = tensor.shape[0]
|
490 |
+
for b in torch.arange(batch_size):
|
491 |
+
sub_index = index[b]
|
492 |
+
sub_value = value[b] if isinstance(value, torch.Tensor) else value
|
493 |
+
sub_tensor = tensor[b]
|
494 |
+
sub_tensor.index_fill_(dim - 1, sub_index, sub_value)
|
495 |
+
tensor[b] = sub_tensor
|
496 |
+
return tensor
|
497 |
+
|
498 |
+
|
499 |
+
def adaptive_instance_normalization(
|
500 |
+
src: torch.Tensor,
|
501 |
+
dst: torch.Tensor,
|
502 |
+
eps: float = 1e-6,
|
503 |
+
):
|
504 |
+
"""
|
505 |
+
Args:
|
506 |
+
src (torch.Tensor): b c t h w
|
507 |
+
dst (torch.Tensor): b c t h w
|
508 |
+
"""
|
509 |
+
ndim = src.ndim
|
510 |
+
if ndim == 5:
|
511 |
+
dim = (2, 3, 4)
|
512 |
+
elif ndim == 4:
|
513 |
+
dim = (2, 3)
|
514 |
+
elif ndim == 3:
|
515 |
+
dim = 2
|
516 |
+
else:
|
517 |
+
raise ValueError("only support ndim in [3,4,5], but given {ndim}")
|
518 |
+
var, mean = torch.var_mean(src, dim=dim, keepdim=True, correction=0)
|
519 |
+
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
520 |
+
dst = align_repeat_tensor_single_dim(dst, src.shape[0], dim=0)
|
521 |
+
mean_acc, var_acc = torch.var_mean(dst, dim=dim, keepdim=True, correction=0)
|
522 |
+
# mean_acc = sum(mean_acc) / float(len(mean_acc))
|
523 |
+
# var_acc = sum(var_acc) / float(len(var_acc))
|
524 |
+
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
525 |
+
src = (((src - mean) / std) * std_acc) + mean_acc
|
526 |
+
return src
|
527 |
+
|
528 |
+
|
529 |
+
def adaptive_instance_normalization_with_ref(
|
530 |
+
src: torch.LongTensor,
|
531 |
+
dst: torch.LongTensor,
|
532 |
+
style_fidelity: float = 0.5,
|
533 |
+
do_classifier_free_guidance: bool = True,
|
534 |
+
):
|
535 |
+
# logger.debug(
|
536 |
+
# f"src={src.shape}, min={src.min()}, max={src.max()}, mean={src.mean()}, \n"
|
537 |
+
# f"dst={src.shape}, min={dst.min()}, max={dst.max()}, mean={dst.mean()}"
|
538 |
+
# )
|
539 |
+
batch_size = src.shape[0] // 2
|
540 |
+
uc_mask = torch.Tensor([1] * batch_size + [0] * batch_size).type_as(src).bool()
|
541 |
+
src_uc = adaptive_instance_normalization(src, dst)
|
542 |
+
src_c = src_uc.clone()
|
543 |
+
# TODO: 该部分默认 do_classifier_free_guidance and style_fidelity > 0 = True
|
544 |
+
if do_classifier_free_guidance and style_fidelity > 0:
|
545 |
+
src_c[uc_mask] = src[uc_mask]
|
546 |
+
src = style_fidelity * src_c + (1.0 - style_fidelity) * src_uc
|
547 |
+
return src
|
548 |
+
|
549 |
+
|
550 |
+
def batch_adain_conditioned_tensor(
|
551 |
+
tensor: torch.Tensor,
|
552 |
+
src_index: torch.LongTensor,
|
553 |
+
dst_index: torch.LongTensor,
|
554 |
+
keep_dim: bool = True,
|
555 |
+
num_frames: int = None,
|
556 |
+
dim: int = 2,
|
557 |
+
style_fidelity: float = 0.5,
|
558 |
+
do_classifier_free_guidance: bool = True,
|
559 |
+
need_style_fidelity: bool = False,
|
560 |
+
):
|
561 |
+
"""_summary_
|
562 |
+
|
563 |
+
Args:
|
564 |
+
tensor (torch.Tensor): b c t h w
|
565 |
+
src_index (torch.LongTensor): _description_
|
566 |
+
dst_index (torch.LongTensor): _description_
|
567 |
+
keep_dim (bool, optional): _description_. Defaults to True.
|
568 |
+
|
569 |
+
Returns:
|
570 |
+
_type_: _description_
|
571 |
+
"""
|
572 |
+
ndim = tensor.ndim
|
573 |
+
dtype = tensor.dtype
|
574 |
+
if ndim == 4 and num_frames is not None:
|
575 |
+
tensor = rearrange(tensor, "(b t) c h w-> b c t h w ", t=num_frames)
|
576 |
+
src = batch_index_select(tensor, dim=dim, index=src_index).contiguous()
|
577 |
+
dst = batch_index_select(tensor, dim=dim, index=dst_index).contiguous()
|
578 |
+
if need_style_fidelity:
|
579 |
+
src = adaptive_instance_normalization_with_ref(
|
580 |
+
src=src,
|
581 |
+
dst=dst,
|
582 |
+
style_fidelity=style_fidelity,
|
583 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
584 |
+
need_style_fidelity=need_style_fidelity,
|
585 |
+
)
|
586 |
+
else:
|
587 |
+
src = adaptive_instance_normalization(
|
588 |
+
src=src,
|
589 |
+
dst=dst,
|
590 |
+
)
|
591 |
+
if keep_dim:
|
592 |
+
src = batch_concat_two_tensor_with_index(
|
593 |
+
src.to(dtype=dtype),
|
594 |
+
src_index,
|
595 |
+
dst.to(dtype=dtype),
|
596 |
+
dst_index,
|
597 |
+
dim=dim,
|
598 |
+
)
|
599 |
+
|
600 |
+
if ndim == 4 and num_frames is not None:
|
601 |
+
src = rearrange(tensor, "b c t h w ->(b t) c h w")
|
602 |
+
return src
|
603 |
+
|
604 |
+
|
605 |
+
def align_repeat_tensor_single_dim(
|
606 |
+
src: torch.Tensor,
|
607 |
+
target_length: int,
|
608 |
+
dim: int = 0,
|
609 |
+
n_src_base_length: int = 1,
|
610 |
+
src_base_index: List[int] = None,
|
611 |
+
) -> torch.Tensor:
|
612 |
+
"""沿着 dim 纬度, 补齐 src 的长度到目标 target_length。
|
613 |
+
当 src 长度不如 target_length 时, 取其中 前 n_src_base_length 然后 repeat 到 target_length
|
614 |
+
|
615 |
+
align length of src to target_length along dim
|
616 |
+
when src length is less than target_length, take the first n_src_base_length and repeat to target_length
|
617 |
+
|
618 |
+
Args:
|
619 |
+
src (torch.Tensor): 输入 tensor, input tensor
|
620 |
+
target_length (int): 目标长度, target_length
|
621 |
+
dim (int, optional): 处理纬度, target dim . Defaults to 0.
|
622 |
+
n_src_base_length (int, optional): src 的基本单元长度, basic length of src. Defaults to 1.
|
623 |
+
|
624 |
+
Returns:
|
625 |
+
torch.Tensor: _description_
|
626 |
+
"""
|
627 |
+
src_dim_length = src.shape[dim]
|
628 |
+
if target_length > src_dim_length:
|
629 |
+
if target_length % src_dim_length == 0:
|
630 |
+
new = src.repeat_interleave(
|
631 |
+
repeats=target_length // src_dim_length, dim=dim
|
632 |
+
)
|
633 |
+
else:
|
634 |
+
if src_base_index is None and n_src_base_length is not None:
|
635 |
+
src_base_index = torch.arange(n_src_base_length)
|
636 |
+
|
637 |
+
new = src.index_select(
|
638 |
+
dim=dim,
|
639 |
+
index=torch.LongTensor(src_base_index).to(device=src.device),
|
640 |
+
)
|
641 |
+
new = new.repeat_interleave(
|
642 |
+
repeats=target_length // len(src_base_index),
|
643 |
+
dim=dim,
|
644 |
+
)
|
645 |
+
elif target_length < src_dim_length:
|
646 |
+
new = src.index_select(
|
647 |
+
dim=dim,
|
648 |
+
index=torch.LongTensor(torch.arange(target_length)).to(device=src.device),
|
649 |
+
)
|
650 |
+
else:
|
651 |
+
new = src
|
652 |
+
return new
|
653 |
+
|
654 |
+
|
655 |
+
def fuse_part_tensor(
|
656 |
+
src: torch.Tensor,
|
657 |
+
dst: torch.Tensor,
|
658 |
+
overlap: int,
|
659 |
+
weight: float = 0.5,
|
660 |
+
skip_step: int = 0,
|
661 |
+
) -> torch.Tensor:
|
662 |
+
"""fuse overstep tensor with weight of src into dst
|
663 |
+
out = src_fused_part * weight + dst * (1-weight) for overlap
|
664 |
+
|
665 |
+
Args:
|
666 |
+
src (torch.Tensor): b c t h w
|
667 |
+
dst (torch.Tensor): b c t h w
|
668 |
+
overlap (int): 1
|
669 |
+
weight (float, optional): weight of src tensor part. Defaults to 0.5.
|
670 |
+
|
671 |
+
Returns:
|
672 |
+
torch.Tensor: fused tensor
|
673 |
+
"""
|
674 |
+
if overlap == 0:
|
675 |
+
return dst
|
676 |
+
else:
|
677 |
+
dst[:, :, skip_step : skip_step + overlap] = (
|
678 |
+
weight * src[:, :, -overlap:]
|
679 |
+
+ (1 - weight) * dst[:, :, skip_step : skip_step + overlap]
|
680 |
+
)
|
681 |
+
return dst
|
musev/logging.conf
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[loggers]
|
2 |
+
keys=root,musev
|
3 |
+
|
4 |
+
[handlers]
|
5 |
+
keys=consoleHandler
|
6 |
+
|
7 |
+
[formatters]
|
8 |
+
keys=musevFormatter
|
9 |
+
|
10 |
+
[logger_root]
|
11 |
+
level=INFO
|
12 |
+
handlers=consoleHandler
|
13 |
+
|
14 |
+
# logger level 尽量设置低一点
|
15 |
+
[logger_musev]
|
16 |
+
level=DEBUG
|
17 |
+
handlers=consoleHandler
|
18 |
+
qualname=musev
|
19 |
+
propagate=0
|
20 |
+
|
21 |
+
# handler level 设置比 logger level高
|
22 |
+
[handler_consoleHandler]
|
23 |
+
class=StreamHandler
|
24 |
+
level=DEBUG
|
25 |
+
# level=INFO
|
26 |
+
|
27 |
+
formatter=musevFormatter
|
28 |
+
args=(sys.stdout,)
|
29 |
+
|
30 |
+
[formatter_musevFormatter]
|
31 |
+
format=%(asctime)s- %(name)s:%(lineno)d- %(levelname)s- %(message)s
|
32 |
+
datefmt=
|
musev/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from ..utils.register import Register
|
2 |
+
|
3 |
+
Model_Register = Register(registry_name="torch_model")
|
musev/models/attention.py
ADDED
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/64bf5d33b7ef1b1deac256bed7bd99b55020c4e0/src/diffusers/models/attention.py
|
16 |
+
from __future__ import annotations
|
17 |
+
from copy import deepcopy
|
18 |
+
|
19 |
+
from typing import Any, Dict, List, Literal, Optional, Callable, Tuple
|
20 |
+
import logging
|
21 |
+
from einops import rearrange
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.nn.functional as F
|
25 |
+
from torch import nn
|
26 |
+
|
27 |
+
from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
|
28 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
29 |
+
from diffusers.models.attention_processor import Attention as DiffusersAttention
|
30 |
+
from diffusers.models.attention import (
|
31 |
+
BasicTransformerBlock as DiffusersBasicTransformerBlock,
|
32 |
+
AdaLayerNormZero,
|
33 |
+
AdaLayerNorm,
|
34 |
+
FeedForward,
|
35 |
+
)
|
36 |
+
from diffusers.models.attention_processor import AttnProcessor
|
37 |
+
|
38 |
+
from .attention_processor import IPAttention, BaseIPAttnProcessor
|
39 |
+
|
40 |
+
|
41 |
+
logger = logging.getLogger(__name__)
|
42 |
+
|
43 |
+
|
44 |
+
def not_use_xformers_anyway(
|
45 |
+
use_memory_efficient_attention_xformers: bool,
|
46 |
+
attention_op: Optional[Callable] = None,
|
47 |
+
):
|
48 |
+
return None
|
49 |
+
|
50 |
+
|
51 |
+
@maybe_allow_in_graph
|
52 |
+
class BasicTransformerBlock(DiffusersBasicTransformerBlock):
|
53 |
+
print_idx = 0
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
dim: int,
|
58 |
+
num_attention_heads: int,
|
59 |
+
attention_head_dim: int,
|
60 |
+
dropout=0,
|
61 |
+
cross_attention_dim: int | None = None,
|
62 |
+
activation_fn: str = "geglu",
|
63 |
+
num_embeds_ada_norm: int | None = None,
|
64 |
+
attention_bias: bool = False,
|
65 |
+
only_cross_attention: bool = False,
|
66 |
+
double_self_attention: bool = False,
|
67 |
+
upcast_attention: bool = False,
|
68 |
+
norm_elementwise_affine: bool = True,
|
69 |
+
norm_type: str = "layer_norm",
|
70 |
+
final_dropout: bool = False,
|
71 |
+
attention_type: str = "default",
|
72 |
+
allow_xformers: bool = True,
|
73 |
+
cross_attn_temporal_cond: bool = False,
|
74 |
+
image_scale: float = 1.0,
|
75 |
+
processor: AttnProcessor | None = None,
|
76 |
+
ip_adapter_cross_attn: bool = False,
|
77 |
+
need_t2i_facein: bool = False,
|
78 |
+
need_t2i_ip_adapter_face: bool = False,
|
79 |
+
):
|
80 |
+
if not only_cross_attention and double_self_attention:
|
81 |
+
cross_attention_dim = None
|
82 |
+
super().__init__(
|
83 |
+
dim,
|
84 |
+
num_attention_heads,
|
85 |
+
attention_head_dim,
|
86 |
+
dropout,
|
87 |
+
cross_attention_dim,
|
88 |
+
activation_fn,
|
89 |
+
num_embeds_ada_norm,
|
90 |
+
attention_bias,
|
91 |
+
only_cross_attention,
|
92 |
+
double_self_attention,
|
93 |
+
upcast_attention,
|
94 |
+
norm_elementwise_affine,
|
95 |
+
norm_type,
|
96 |
+
final_dropout,
|
97 |
+
attention_type,
|
98 |
+
)
|
99 |
+
|
100 |
+
self.attn1 = IPAttention(
|
101 |
+
query_dim=dim,
|
102 |
+
heads=num_attention_heads,
|
103 |
+
dim_head=attention_head_dim,
|
104 |
+
dropout=dropout,
|
105 |
+
bias=attention_bias,
|
106 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
107 |
+
upcast_attention=upcast_attention,
|
108 |
+
cross_attn_temporal_cond=cross_attn_temporal_cond,
|
109 |
+
image_scale=image_scale,
|
110 |
+
ip_adapter_dim=cross_attention_dim
|
111 |
+
if only_cross_attention
|
112 |
+
else attention_head_dim,
|
113 |
+
facein_dim=cross_attention_dim
|
114 |
+
if only_cross_attention
|
115 |
+
else attention_head_dim,
|
116 |
+
processor=processor,
|
117 |
+
)
|
118 |
+
# 2. Cross-Attn
|
119 |
+
if cross_attention_dim is not None or double_self_attention:
|
120 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
121 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
122 |
+
# the second cross attention block.
|
123 |
+
self.norm2 = (
|
124 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
125 |
+
if self.use_ada_layer_norm
|
126 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
127 |
+
)
|
128 |
+
|
129 |
+
self.attn2 = IPAttention(
|
130 |
+
query_dim=dim,
|
131 |
+
cross_attention_dim=cross_attention_dim
|
132 |
+
if not double_self_attention
|
133 |
+
else None,
|
134 |
+
heads=num_attention_heads,
|
135 |
+
dim_head=attention_head_dim,
|
136 |
+
dropout=dropout,
|
137 |
+
bias=attention_bias,
|
138 |
+
upcast_attention=upcast_attention,
|
139 |
+
cross_attn_temporal_cond=ip_adapter_cross_attn,
|
140 |
+
need_t2i_facein=need_t2i_facein,
|
141 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
142 |
+
image_scale=image_scale,
|
143 |
+
ip_adapter_dim=cross_attention_dim
|
144 |
+
if not double_self_attention
|
145 |
+
else attention_head_dim,
|
146 |
+
facein_dim=cross_attention_dim
|
147 |
+
if not double_self_attention
|
148 |
+
else attention_head_dim,
|
149 |
+
ip_adapter_face_dim=cross_attention_dim
|
150 |
+
if not double_self_attention
|
151 |
+
else attention_head_dim,
|
152 |
+
processor=processor,
|
153 |
+
) # is self-attn if encoder_hidden_states is none
|
154 |
+
else:
|
155 |
+
self.norm2 = None
|
156 |
+
self.attn2 = None
|
157 |
+
if self.attn1 is not None:
|
158 |
+
if not allow_xformers:
|
159 |
+
self.attn1.set_use_memory_efficient_attention_xformers = (
|
160 |
+
not_use_xformers_anyway
|
161 |
+
)
|
162 |
+
if self.attn2 is not None:
|
163 |
+
if not allow_xformers:
|
164 |
+
self.attn2.set_use_memory_efficient_attention_xformers = (
|
165 |
+
not_use_xformers_anyway
|
166 |
+
)
|
167 |
+
self.double_self_attention = double_self_attention
|
168 |
+
self.only_cross_attention = only_cross_attention
|
169 |
+
self.cross_attn_temporal_cond = cross_attn_temporal_cond
|
170 |
+
self.image_scale = image_scale
|
171 |
+
|
172 |
+
def forward(
|
173 |
+
self,
|
174 |
+
hidden_states: torch.FloatTensor,
|
175 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
176 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
177 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
178 |
+
timestep: Optional[torch.LongTensor] = None,
|
179 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
180 |
+
class_labels: Optional[torch.LongTensor] = None,
|
181 |
+
self_attn_block_embs: Optional[Tuple[List[torch.Tensor], List[None]]] = None,
|
182 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
183 |
+
) -> torch.FloatTensor:
|
184 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
185 |
+
# 0. Self-Attention
|
186 |
+
if self.use_ada_layer_norm:
|
187 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
188 |
+
elif self.use_ada_layer_norm_zero:
|
189 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
190 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
norm_hidden_states = self.norm1(hidden_states)
|
194 |
+
|
195 |
+
# 1. Retrieve lora scale.
|
196 |
+
lora_scale = (
|
197 |
+
cross_attention_kwargs.get("scale", 1.0)
|
198 |
+
if cross_attention_kwargs is not None
|
199 |
+
else 1.0
|
200 |
+
)
|
201 |
+
|
202 |
+
if cross_attention_kwargs is None:
|
203 |
+
cross_attention_kwargs = {}
|
204 |
+
# 特殊AttnProcessor需要的入参 在 cross_attention_kwargs 准备
|
205 |
+
# special AttnProcessor needs input parameters in cross_attention_kwargs
|
206 |
+
original_cross_attention_kwargs = {
|
207 |
+
k: v
|
208 |
+
for k, v in cross_attention_kwargs.items()
|
209 |
+
if k
|
210 |
+
not in [
|
211 |
+
"num_frames",
|
212 |
+
"sample_index",
|
213 |
+
"vision_conditon_frames_sample_index",
|
214 |
+
"vision_cond",
|
215 |
+
"vision_clip_emb",
|
216 |
+
"ip_adapter_scale",
|
217 |
+
"face_emb",
|
218 |
+
"facein_scale",
|
219 |
+
"ip_adapter_face_emb",
|
220 |
+
"ip_adapter_face_scale",
|
221 |
+
"do_classifier_free_guidance",
|
222 |
+
]
|
223 |
+
}
|
224 |
+
|
225 |
+
if "do_classifier_free_guidance" in cross_attention_kwargs:
|
226 |
+
do_classifier_free_guidance = cross_attention_kwargs[
|
227 |
+
"do_classifier_free_guidance"
|
228 |
+
]
|
229 |
+
else:
|
230 |
+
do_classifier_free_guidance = False
|
231 |
+
|
232 |
+
# 2. Prepare GLIGEN inputs
|
233 |
+
original_cross_attention_kwargs = (
|
234 |
+
original_cross_attention_kwargs.copy()
|
235 |
+
if original_cross_attention_kwargs is not None
|
236 |
+
else {}
|
237 |
+
)
|
238 |
+
gligen_kwargs = original_cross_attention_kwargs.pop("gligen", None)
|
239 |
+
|
240 |
+
# 返回self_attn的结果,适用于referencenet的输出给其他Unet来使用
|
241 |
+
# return the result of self_attn, which is suitable for the output of referencenet to be used by other Unet
|
242 |
+
if (
|
243 |
+
self_attn_block_embs is not None
|
244 |
+
and self_attn_block_embs_mode.lower() == "write"
|
245 |
+
):
|
246 |
+
# self_attn_block_emb = self.attn1.head_to_batch_dim(attn_output, out_dim=4)
|
247 |
+
self_attn_block_emb = norm_hidden_states
|
248 |
+
if not hasattr(self, "spatial_self_attn_idx"):
|
249 |
+
raise ValueError(
|
250 |
+
"must call unet.insert_spatial_self_attn_idx to generate spatial attn index"
|
251 |
+
)
|
252 |
+
basick_transformer_idx = self.spatial_self_attn_idx
|
253 |
+
if self.print_idx == 0:
|
254 |
+
logger.debug(
|
255 |
+
f"self_attn_block_embs, self_attn_block_embs_mode={self_attn_block_embs_mode}, "
|
256 |
+
f"basick_transformer_idx={basick_transformer_idx}, length={len(self_attn_block_embs)}, shape={self_attn_block_emb.shape}, "
|
257 |
+
# f"attn1 processor, {type(self.attn1.processor)}"
|
258 |
+
)
|
259 |
+
self_attn_block_embs[basick_transformer_idx] = self_attn_block_emb
|
260 |
+
|
261 |
+
# read and put referencenet emb into cross_attention_kwargs, which would be fused into attn_processor
|
262 |
+
if (
|
263 |
+
self_attn_block_embs is not None
|
264 |
+
and self_attn_block_embs_mode.lower() == "read"
|
265 |
+
):
|
266 |
+
basick_transformer_idx = self.spatial_self_attn_idx
|
267 |
+
if not hasattr(self, "spatial_self_attn_idx"):
|
268 |
+
raise ValueError(
|
269 |
+
"must call unet.insert_spatial_self_attn_idx to generate spatial attn index"
|
270 |
+
)
|
271 |
+
if self.print_idx == 0:
|
272 |
+
logger.debug(
|
273 |
+
f"refer_self_attn_emb: , self_attn_block_embs_mode={self_attn_block_embs_mode}, "
|
274 |
+
f"length={len(self_attn_block_embs)}, idx={basick_transformer_idx}, "
|
275 |
+
# f"attn1 processor, {type(self.attn1.processor)}, "
|
276 |
+
)
|
277 |
+
ref_emb = self_attn_block_embs[basick_transformer_idx]
|
278 |
+
cross_attention_kwargs["refer_emb"] = ref_emb
|
279 |
+
if self.print_idx == 0:
|
280 |
+
logger.debug(
|
281 |
+
f"unet attention read, {self.spatial_self_attn_idx}",
|
282 |
+
)
|
283 |
+
# ------------------------------warning-----------------------
|
284 |
+
# 这两行由于使用了ref_emb会导致和checkpoint_train相关的训练错误,具体未知,留在这里作为警示
|
285 |
+
# bellow annoated code will cause training error, keep it here as a warning
|
286 |
+
# logger.debug(f"ref_emb shape,{ref_emb.shape}, {ref_emb.mean()}")
|
287 |
+
# logger.debug(
|
288 |
+
# f"norm_hidden_states shape, {norm_hidden_states.shape}, {norm_hidden_states.mean()}",
|
289 |
+
# )
|
290 |
+
if self.attn1 is None:
|
291 |
+
self.print_idx += 1
|
292 |
+
return norm_hidden_states
|
293 |
+
attn_output = self.attn1(
|
294 |
+
norm_hidden_states,
|
295 |
+
encoder_hidden_states=encoder_hidden_states
|
296 |
+
if self.only_cross_attention
|
297 |
+
else None,
|
298 |
+
attention_mask=attention_mask,
|
299 |
+
**(
|
300 |
+
cross_attention_kwargs
|
301 |
+
if isinstance(self.attn1.processor, BaseIPAttnProcessor)
|
302 |
+
else original_cross_attention_kwargs
|
303 |
+
),
|
304 |
+
)
|
305 |
+
|
306 |
+
if self.use_ada_layer_norm_zero:
|
307 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
308 |
+
hidden_states = attn_output + hidden_states
|
309 |
+
|
310 |
+
# 推断的时候,对于uncondition_部分独立生成,排除掉 refer_emb,
|
311 |
+
# 首帧等的影响,避免生成参考了refer_emb、首帧等,又在uncond上去除了
|
312 |
+
# in inference stage, eliminate influence of refer_emb, vis_cond on unconditionpart
|
313 |
+
# to avoid use that, and then eliminate in pipeline
|
314 |
+
# refer to moore-animate anyone
|
315 |
+
|
316 |
+
# do_classifier_free_guidance = False
|
317 |
+
if self.print_idx == 0:
|
318 |
+
logger.debug(f"do_classifier_free_guidance={do_classifier_free_guidance},")
|
319 |
+
if do_classifier_free_guidance:
|
320 |
+
hidden_states_c = attn_output.clone()
|
321 |
+
_uc_mask = (
|
322 |
+
torch.Tensor(
|
323 |
+
[1] * (norm_hidden_states.shape[0] // 2)
|
324 |
+
+ [0] * (norm_hidden_states.shape[0] // 2)
|
325 |
+
)
|
326 |
+
.to(norm_hidden_states.device)
|
327 |
+
.bool()
|
328 |
+
)
|
329 |
+
hidden_states_c[_uc_mask] = self.attn1(
|
330 |
+
norm_hidden_states[_uc_mask],
|
331 |
+
encoder_hidden_states=norm_hidden_states[_uc_mask],
|
332 |
+
attention_mask=attention_mask,
|
333 |
+
)
|
334 |
+
attn_output = hidden_states_c.clone()
|
335 |
+
|
336 |
+
if "refer_emb" in cross_attention_kwargs:
|
337 |
+
del cross_attention_kwargs["refer_emb"]
|
338 |
+
|
339 |
+
# 2.5 GLIGEN Control
|
340 |
+
if gligen_kwargs is not None:
|
341 |
+
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
342 |
+
# 2.5 ends
|
343 |
+
|
344 |
+
# 3. Cross-Attention
|
345 |
+
if self.attn2 is not None:
|
346 |
+
norm_hidden_states = (
|
347 |
+
self.norm2(hidden_states, timestep)
|
348 |
+
if self.use_ada_layer_norm
|
349 |
+
else self.norm2(hidden_states)
|
350 |
+
)
|
351 |
+
|
352 |
+
# 特殊AttnProcessor需要的入参 在 cross_attention_kwargs 准备
|
353 |
+
# special AttnProcessor needs input parameters in cross_attention_kwargs
|
354 |
+
attn_output = self.attn2(
|
355 |
+
norm_hidden_states,
|
356 |
+
encoder_hidden_states=encoder_hidden_states
|
357 |
+
if not self.double_self_attention
|
358 |
+
else None,
|
359 |
+
attention_mask=encoder_attention_mask,
|
360 |
+
**(
|
361 |
+
original_cross_attention_kwargs
|
362 |
+
if not isinstance(self.attn2.processor, BaseIPAttnProcessor)
|
363 |
+
else cross_attention_kwargs
|
364 |
+
),
|
365 |
+
)
|
366 |
+
if self.print_idx == 0:
|
367 |
+
logger.debug(
|
368 |
+
f"encoder_hidden_states, type={type(encoder_hidden_states)}"
|
369 |
+
)
|
370 |
+
if encoder_hidden_states is not None:
|
371 |
+
logger.debug(
|
372 |
+
f"encoder_hidden_states, ={encoder_hidden_states.shape}"
|
373 |
+
)
|
374 |
+
|
375 |
+
# encoder_hidden_states_tmp = (
|
376 |
+
# encoder_hidden_states
|
377 |
+
# if not self.double_self_attention
|
378 |
+
# else norm_hidden_states
|
379 |
+
# )
|
380 |
+
# if do_classifier_free_guidance:
|
381 |
+
# hidden_states_c = attn_output.clone()
|
382 |
+
# _uc_mask = (
|
383 |
+
# torch.Tensor(
|
384 |
+
# [1] * (norm_hidden_states.shape[0] // 2)
|
385 |
+
# + [0] * (norm_hidden_states.shape[0] // 2)
|
386 |
+
# )
|
387 |
+
# .to(norm_hidden_states.device)
|
388 |
+
# .bool()
|
389 |
+
# )
|
390 |
+
# hidden_states_c[_uc_mask] = self.attn2(
|
391 |
+
# norm_hidden_states[_uc_mask],
|
392 |
+
# encoder_hidden_states=encoder_hidden_states_tmp[_uc_mask],
|
393 |
+
# attention_mask=attention_mask,
|
394 |
+
# )
|
395 |
+
# attn_output = hidden_states_c.clone()
|
396 |
+
hidden_states = attn_output + hidden_states
|
397 |
+
# 4. Feed-forward
|
398 |
+
if self.norm3 is not None and self.ff is not None:
|
399 |
+
norm_hidden_states = self.norm3(hidden_states)
|
400 |
+
if self.use_ada_layer_norm_zero:
|
401 |
+
norm_hidden_states = (
|
402 |
+
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
403 |
+
)
|
404 |
+
if self._chunk_size is not None:
|
405 |
+
# "feed_forward_chunk_size" can be used to save memory
|
406 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
407 |
+
raise ValueError(
|
408 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
409 |
+
)
|
410 |
+
|
411 |
+
num_chunks = (
|
412 |
+
norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
413 |
+
)
|
414 |
+
ff_output = torch.cat(
|
415 |
+
[
|
416 |
+
self.ff(hid_slice, scale=lora_scale)
|
417 |
+
for hid_slice in norm_hidden_states.chunk(
|
418 |
+
num_chunks, dim=self._chunk_dim
|
419 |
+
)
|
420 |
+
],
|
421 |
+
dim=self._chunk_dim,
|
422 |
+
)
|
423 |
+
else:
|
424 |
+
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
|
425 |
+
|
426 |
+
if self.use_ada_layer_norm_zero:
|
427 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
428 |
+
|
429 |
+
hidden_states = ff_output + hidden_states
|
430 |
+
self.print_idx += 1
|
431 |
+
return hidden_states
|
musev/models/attention_processor.py
ADDED
@@ -0,0 +1,750 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""该模型是自定义的attn_processor,实现特殊功能的 Attn功能。
|
16 |
+
相对而言,开源代码经常会重新定义Attention 类,
|
17 |
+
|
18 |
+
This module implements special AttnProcessor function with custom attn_processor class.
|
19 |
+
While other open source code always modify Attention class.
|
20 |
+
"""
|
21 |
+
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
22 |
+
from __future__ import annotations
|
23 |
+
|
24 |
+
import time
|
25 |
+
from typing import Any, Callable, Optional
|
26 |
+
import logging
|
27 |
+
|
28 |
+
from einops import rearrange, repeat
|
29 |
+
import torch
|
30 |
+
import torch.nn as nn
|
31 |
+
import torch.nn.functional as F
|
32 |
+
import xformers
|
33 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
34 |
+
|
35 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
36 |
+
from diffusers.models.attention_processor import (
|
37 |
+
Attention as DiffusersAttention,
|
38 |
+
AttnProcessor,
|
39 |
+
AttnProcessor2_0,
|
40 |
+
)
|
41 |
+
from ..data.data_util import (
|
42 |
+
batch_concat_two_tensor_with_index,
|
43 |
+
batch_index_select,
|
44 |
+
align_repeat_tensor_single_dim,
|
45 |
+
batch_adain_conditioned_tensor,
|
46 |
+
)
|
47 |
+
|
48 |
+
from . import Model_Register
|
49 |
+
|
50 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
51 |
+
|
52 |
+
|
53 |
+
@maybe_allow_in_graph
|
54 |
+
class IPAttention(DiffusersAttention):
|
55 |
+
r"""
|
56 |
+
Modified Attention class which has special layer, like ip_apadapter_to_k, ip_apadapter_to_v,
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
query_dim: int,
|
62 |
+
cross_attention_dim: int | None = None,
|
63 |
+
heads: int = 8,
|
64 |
+
dim_head: int = 64,
|
65 |
+
dropout: float = 0,
|
66 |
+
bias=False,
|
67 |
+
upcast_attention: bool = False,
|
68 |
+
upcast_softmax: bool = False,
|
69 |
+
cross_attention_norm: str | None = None,
|
70 |
+
cross_attention_norm_num_groups: int = 32,
|
71 |
+
added_kv_proj_dim: int | None = None,
|
72 |
+
norm_num_groups: int | None = None,
|
73 |
+
spatial_norm_dim: int | None = None,
|
74 |
+
out_bias: bool = True,
|
75 |
+
scale_qk: bool = True,
|
76 |
+
only_cross_attention: bool = False,
|
77 |
+
eps: float = 0.00001,
|
78 |
+
rescale_output_factor: float = 1,
|
79 |
+
residual_connection: bool = False,
|
80 |
+
_from_deprecated_attn_block=False,
|
81 |
+
processor: AttnProcessor | None = None,
|
82 |
+
cross_attn_temporal_cond: bool = False,
|
83 |
+
image_scale: float = 1.0,
|
84 |
+
ip_adapter_dim: int = None,
|
85 |
+
need_t2i_facein: bool = False,
|
86 |
+
facein_dim: int = None,
|
87 |
+
need_t2i_ip_adapter_face: bool = False,
|
88 |
+
ip_adapter_face_dim: int = None,
|
89 |
+
):
|
90 |
+
super().__init__(
|
91 |
+
query_dim,
|
92 |
+
cross_attention_dim,
|
93 |
+
heads,
|
94 |
+
dim_head,
|
95 |
+
dropout,
|
96 |
+
bias,
|
97 |
+
upcast_attention,
|
98 |
+
upcast_softmax,
|
99 |
+
cross_attention_norm,
|
100 |
+
cross_attention_norm_num_groups,
|
101 |
+
added_kv_proj_dim,
|
102 |
+
norm_num_groups,
|
103 |
+
spatial_norm_dim,
|
104 |
+
out_bias,
|
105 |
+
scale_qk,
|
106 |
+
only_cross_attention,
|
107 |
+
eps,
|
108 |
+
rescale_output_factor,
|
109 |
+
residual_connection,
|
110 |
+
_from_deprecated_attn_block,
|
111 |
+
processor,
|
112 |
+
)
|
113 |
+
self.cross_attn_temporal_cond = cross_attn_temporal_cond
|
114 |
+
self.image_scale = image_scale
|
115 |
+
# 面向首帧的 ip_adapter
|
116 |
+
# ip_apdater
|
117 |
+
if cross_attn_temporal_cond:
|
118 |
+
self.to_k_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False)
|
119 |
+
self.to_v_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False)
|
120 |
+
# facein
|
121 |
+
self.need_t2i_facein = need_t2i_facein
|
122 |
+
self.facein_dim = facein_dim
|
123 |
+
if need_t2i_facein:
|
124 |
+
raise NotImplementedError("facein")
|
125 |
+
|
126 |
+
# ip_adapter_face
|
127 |
+
self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face
|
128 |
+
self.ip_adapter_face_dim = ip_adapter_face_dim
|
129 |
+
if need_t2i_ip_adapter_face:
|
130 |
+
self.ip_adapter_face_to_k_ip = LoRACompatibleLinear(
|
131 |
+
ip_adapter_face_dim, query_dim, bias=False
|
132 |
+
)
|
133 |
+
self.ip_adapter_face_to_v_ip = LoRACompatibleLinear(
|
134 |
+
ip_adapter_face_dim, query_dim, bias=False
|
135 |
+
)
|
136 |
+
|
137 |
+
def set_use_memory_efficient_attention_xformers(
|
138 |
+
self,
|
139 |
+
use_memory_efficient_attention_xformers: bool,
|
140 |
+
attention_op: Callable[..., Any] | None = None,
|
141 |
+
):
|
142 |
+
if (
|
143 |
+
"XFormers" in self.processor.__class__.__name__
|
144 |
+
or "IP" in self.processor.__class__.__name__
|
145 |
+
):
|
146 |
+
pass
|
147 |
+
else:
|
148 |
+
return super().set_use_memory_efficient_attention_xformers(
|
149 |
+
use_memory_efficient_attention_xformers, attention_op
|
150 |
+
)
|
151 |
+
|
152 |
+
|
153 |
+
@Model_Register.register
|
154 |
+
class BaseIPAttnProcessor(nn.Module):
|
155 |
+
print_idx = 0
|
156 |
+
|
157 |
+
def __init__(self, *args, **kwargs) -> None:
|
158 |
+
super().__init__(*args, **kwargs)
|
159 |
+
|
160 |
+
|
161 |
+
@Model_Register.register
|
162 |
+
class T2IReferencenetIPAdapterXFormersAttnProcessor(BaseIPAttnProcessor):
|
163 |
+
r"""
|
164 |
+
面向 ref_image的 self_attn的 IPAdapter
|
165 |
+
"""
|
166 |
+
print_idx = 0
|
167 |
+
|
168 |
+
def __init__(
|
169 |
+
self,
|
170 |
+
attention_op: Optional[Callable] = None,
|
171 |
+
):
|
172 |
+
super().__init__()
|
173 |
+
|
174 |
+
self.attention_op = attention_op
|
175 |
+
|
176 |
+
def __call__(
|
177 |
+
self,
|
178 |
+
attn: IPAttention,
|
179 |
+
hidden_states: torch.FloatTensor,
|
180 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
181 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
182 |
+
temb: Optional[torch.FloatTensor] = None,
|
183 |
+
scale: float = 1.0,
|
184 |
+
num_frames: int = None,
|
185 |
+
sample_index: torch.LongTensor = None,
|
186 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
187 |
+
refer_emb: torch.Tensor = None,
|
188 |
+
vision_clip_emb: torch.Tensor = None,
|
189 |
+
ip_adapter_scale: float = 1.0,
|
190 |
+
face_emb: torch.Tensor = None,
|
191 |
+
facein_scale: float = 1.0,
|
192 |
+
ip_adapter_face_emb: torch.Tensor = None,
|
193 |
+
ip_adapter_face_scale: float = 1.0,
|
194 |
+
do_classifier_free_guidance: bool = False,
|
195 |
+
):
|
196 |
+
residual = hidden_states
|
197 |
+
|
198 |
+
if attn.spatial_norm is not None:
|
199 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
200 |
+
|
201 |
+
input_ndim = hidden_states.ndim
|
202 |
+
|
203 |
+
if input_ndim == 4:
|
204 |
+
batch_size, channel, height, width = hidden_states.shape
|
205 |
+
hidden_states = hidden_states.view(
|
206 |
+
batch_size, channel, height * width
|
207 |
+
).transpose(1, 2)
|
208 |
+
|
209 |
+
batch_size, key_tokens, _ = (
|
210 |
+
hidden_states.shape
|
211 |
+
if encoder_hidden_states is None
|
212 |
+
else encoder_hidden_states.shape
|
213 |
+
)
|
214 |
+
|
215 |
+
attention_mask = attn.prepare_attention_mask(
|
216 |
+
attention_mask, key_tokens, batch_size
|
217 |
+
)
|
218 |
+
if attention_mask is not None:
|
219 |
+
# expand our mask's singleton query_tokens dimension:
|
220 |
+
# [batch*heads, 1, key_tokens] ->
|
221 |
+
# [batch*heads, query_tokens, key_tokens]
|
222 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
223 |
+
# [batch*heads, query_tokens, key_tokens]
|
224 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
225 |
+
_, query_tokens, _ = hidden_states.shape
|
226 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
227 |
+
|
228 |
+
if attn.group_norm is not None:
|
229 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
230 |
+
1, 2
|
231 |
+
)
|
232 |
+
|
233 |
+
query = attn.to_q(hidden_states, scale=scale)
|
234 |
+
|
235 |
+
if encoder_hidden_states is None:
|
236 |
+
encoder_hidden_states = hidden_states
|
237 |
+
elif attn.norm_cross:
|
238 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
239 |
+
encoder_hidden_states
|
240 |
+
)
|
241 |
+
encoder_hidden_states = align_repeat_tensor_single_dim(
|
242 |
+
encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
|
243 |
+
)
|
244 |
+
key = attn.to_k(encoder_hidden_states, scale=scale)
|
245 |
+
value = attn.to_v(encoder_hidden_states, scale=scale)
|
246 |
+
|
247 |
+
# for facein
|
248 |
+
if self.print_idx == 0:
|
249 |
+
logger.debug(
|
250 |
+
f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(face_emb)={type(face_emb)}, facein_scale={facein_scale}"
|
251 |
+
)
|
252 |
+
if facein_scale > 0 and face_emb is not None:
|
253 |
+
raise NotImplementedError("facein")
|
254 |
+
|
255 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
256 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
257 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
258 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
259 |
+
query,
|
260 |
+
key,
|
261 |
+
value,
|
262 |
+
attn_bias=attention_mask,
|
263 |
+
op=self.attention_op,
|
264 |
+
scale=attn.scale,
|
265 |
+
)
|
266 |
+
|
267 |
+
# ip-adapter start
|
268 |
+
if self.print_idx == 0:
|
269 |
+
logger.debug(
|
270 |
+
f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(vision_clip_emb)={type(vision_clip_emb)}"
|
271 |
+
)
|
272 |
+
if ip_adapter_scale > 0 and vision_clip_emb is not None:
|
273 |
+
if self.print_idx == 0:
|
274 |
+
logger.debug(
|
275 |
+
f"T2I cross_attn, ipadapter, vision_clip_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}"
|
276 |
+
)
|
277 |
+
ip_key = attn.to_k_ip(vision_clip_emb)
|
278 |
+
ip_value = attn.to_v_ip(vision_clip_emb)
|
279 |
+
ip_key = align_repeat_tensor_single_dim(
|
280 |
+
ip_key, target_length=batch_size, dim=0
|
281 |
+
)
|
282 |
+
ip_value = align_repeat_tensor_single_dim(
|
283 |
+
ip_value, target_length=batch_size, dim=0
|
284 |
+
)
|
285 |
+
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
|
286 |
+
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
|
287 |
+
if self.print_idx == 0:
|
288 |
+
logger.debug(
|
289 |
+
f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}"
|
290 |
+
)
|
291 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
292 |
+
hidden_states_from_ip = xformers.ops.memory_efficient_attention(
|
293 |
+
query,
|
294 |
+
ip_key,
|
295 |
+
ip_value,
|
296 |
+
attn_bias=attention_mask,
|
297 |
+
op=self.attention_op,
|
298 |
+
scale=attn.scale,
|
299 |
+
)
|
300 |
+
hidden_states = hidden_states + ip_adapter_scale * hidden_states_from_ip
|
301 |
+
# ip-adapter end
|
302 |
+
|
303 |
+
# ip-adapter face start
|
304 |
+
if self.print_idx == 0:
|
305 |
+
logger.debug(
|
306 |
+
f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(ip_adapter_face_emb)={type(ip_adapter_face_emb)}"
|
307 |
+
)
|
308 |
+
if ip_adapter_face_scale > 0 and ip_adapter_face_emb is not None:
|
309 |
+
if self.print_idx == 0:
|
310 |
+
logger.debug(
|
311 |
+
f"T2I cross_attn, ipadapter face, ip_adapter_face_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}"
|
312 |
+
)
|
313 |
+
ip_key = attn.ip_adapter_face_to_k_ip(ip_adapter_face_emb)
|
314 |
+
ip_value = attn.ip_adapter_face_to_v_ip(ip_adapter_face_emb)
|
315 |
+
ip_key = align_repeat_tensor_single_dim(
|
316 |
+
ip_key, target_length=batch_size, dim=0
|
317 |
+
)
|
318 |
+
ip_value = align_repeat_tensor_single_dim(
|
319 |
+
ip_value, target_length=batch_size, dim=0
|
320 |
+
)
|
321 |
+
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
|
322 |
+
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
|
323 |
+
if self.print_idx == 0:
|
324 |
+
logger.debug(
|
325 |
+
f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}"
|
326 |
+
)
|
327 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
328 |
+
hidden_states_from_ip = xformers.ops.memory_efficient_attention(
|
329 |
+
query,
|
330 |
+
ip_key,
|
331 |
+
ip_value,
|
332 |
+
attn_bias=attention_mask,
|
333 |
+
op=self.attention_op,
|
334 |
+
scale=attn.scale,
|
335 |
+
)
|
336 |
+
hidden_states = (
|
337 |
+
hidden_states + ip_adapter_face_scale * hidden_states_from_ip
|
338 |
+
)
|
339 |
+
# ip-adapter face end
|
340 |
+
|
341 |
+
hidden_states = hidden_states.to(query.dtype)
|
342 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
343 |
+
|
344 |
+
# linear proj
|
345 |
+
hidden_states = attn.to_out[0](hidden_states, scale=scale)
|
346 |
+
# dropout
|
347 |
+
hidden_states = attn.to_out[1](hidden_states)
|
348 |
+
|
349 |
+
if input_ndim == 4:
|
350 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
351 |
+
batch_size, channel, height, width
|
352 |
+
)
|
353 |
+
|
354 |
+
if attn.residual_connection:
|
355 |
+
hidden_states = hidden_states + residual
|
356 |
+
|
357 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
358 |
+
self.print_idx += 1
|
359 |
+
return hidden_states
|
360 |
+
|
361 |
+
|
362 |
+
@Model_Register.register
|
363 |
+
class NonParamT2ISelfReferenceXFormersAttnProcessor(BaseIPAttnProcessor):
|
364 |
+
r"""
|
365 |
+
面向首帧的 referenceonly attn,适用于 T2I的 self_attn
|
366 |
+
referenceonly with vis_cond as key, value, in t2i self_attn.
|
367 |
+
"""
|
368 |
+
print_idx = 0
|
369 |
+
|
370 |
+
def __init__(
|
371 |
+
self,
|
372 |
+
attention_op: Optional[Callable] = None,
|
373 |
+
):
|
374 |
+
super().__init__()
|
375 |
+
|
376 |
+
self.attention_op = attention_op
|
377 |
+
|
378 |
+
def __call__(
|
379 |
+
self,
|
380 |
+
attn: IPAttention,
|
381 |
+
hidden_states: torch.FloatTensor,
|
382 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
383 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
384 |
+
temb: Optional[torch.FloatTensor] = None,
|
385 |
+
scale: float = 1.0,
|
386 |
+
num_frames: int = None,
|
387 |
+
sample_index: torch.LongTensor = None,
|
388 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
389 |
+
refer_emb: torch.Tensor = None,
|
390 |
+
face_emb: torch.Tensor = None,
|
391 |
+
vision_clip_emb: torch.Tensor = None,
|
392 |
+
ip_adapter_scale: float = 1.0,
|
393 |
+
facein_scale: float = 1.0,
|
394 |
+
ip_adapter_face_emb: torch.Tensor = None,
|
395 |
+
ip_adapter_face_scale: float = 1.0,
|
396 |
+
do_classifier_free_guidance: bool = False,
|
397 |
+
):
|
398 |
+
residual = hidden_states
|
399 |
+
|
400 |
+
if attn.spatial_norm is not None:
|
401 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
402 |
+
|
403 |
+
input_ndim = hidden_states.ndim
|
404 |
+
|
405 |
+
if input_ndim == 4:
|
406 |
+
batch_size, channel, height, width = hidden_states.shape
|
407 |
+
hidden_states = hidden_states.view(
|
408 |
+
batch_size, channel, height * width
|
409 |
+
).transpose(1, 2)
|
410 |
+
|
411 |
+
batch_size, key_tokens, _ = (
|
412 |
+
hidden_states.shape
|
413 |
+
if encoder_hidden_states is None
|
414 |
+
else encoder_hidden_states.shape
|
415 |
+
)
|
416 |
+
|
417 |
+
attention_mask = attn.prepare_attention_mask(
|
418 |
+
attention_mask, key_tokens, batch_size
|
419 |
+
)
|
420 |
+
if attention_mask is not None:
|
421 |
+
# expand our mask's singleton query_tokens dimension:
|
422 |
+
# [batch*heads, 1, key_tokens] ->
|
423 |
+
# [batch*heads, query_tokens, key_tokens]
|
424 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
425 |
+
# [batch*heads, query_tokens, key_tokens]
|
426 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
427 |
+
_, query_tokens, _ = hidden_states.shape
|
428 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
429 |
+
|
430 |
+
# vision_cond in same unet attn start
|
431 |
+
if (
|
432 |
+
vision_conditon_frames_sample_index is not None and num_frames > 1
|
433 |
+
) or refer_emb is not None:
|
434 |
+
batchsize_timesize = hidden_states.shape[0]
|
435 |
+
if self.print_idx == 0:
|
436 |
+
logger.debug(
|
437 |
+
f"NonParamT2ISelfReferenceXFormersAttnProcessor 0, hidden_states={hidden_states.shape}, vision_conditon_frames_sample_index={vision_conditon_frames_sample_index}"
|
438 |
+
)
|
439 |
+
encoder_hidden_states = rearrange(
|
440 |
+
hidden_states, "(b t) hw c -> b t hw c", t=num_frames
|
441 |
+
)
|
442 |
+
# if False:
|
443 |
+
if vision_conditon_frames_sample_index is not None and num_frames > 1:
|
444 |
+
ip_hidden_states = batch_index_select(
|
445 |
+
encoder_hidden_states,
|
446 |
+
dim=1,
|
447 |
+
index=vision_conditon_frames_sample_index,
|
448 |
+
).contiguous()
|
449 |
+
if self.print_idx == 0:
|
450 |
+
logger.debug(
|
451 |
+
f"NonParamT2ISelfReferenceXFormersAttnProcessor 1, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
|
452 |
+
)
|
453 |
+
#
|
454 |
+
ip_hidden_states = rearrange(
|
455 |
+
ip_hidden_states, "b t hw c -> b 1 (t hw) c"
|
456 |
+
)
|
457 |
+
ip_hidden_states = align_repeat_tensor_single_dim(
|
458 |
+
ip_hidden_states,
|
459 |
+
dim=1,
|
460 |
+
target_length=num_frames,
|
461 |
+
)
|
462 |
+
# b t hw c -> b t hw + hw c
|
463 |
+
if self.print_idx == 0:
|
464 |
+
logger.debug(
|
465 |
+
f"NonParamT2ISelfReferenceXFormersAttnProcessor 2, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
|
466 |
+
)
|
467 |
+
encoder_hidden_states = torch.concat(
|
468 |
+
[encoder_hidden_states, ip_hidden_states], dim=2
|
469 |
+
)
|
470 |
+
if self.print_idx == 0:
|
471 |
+
logger.debug(
|
472 |
+
f"NonParamT2ISelfReferenceXFormersAttnProcessor 3, hidden_states={hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
|
473 |
+
)
|
474 |
+
# if False:
|
475 |
+
if refer_emb is not None: # and num_frames > 1:
|
476 |
+
refer_emb = rearrange(refer_emb, "b c t h w->b 1 (t h w) c")
|
477 |
+
refer_emb = align_repeat_tensor_single_dim(
|
478 |
+
refer_emb, target_length=num_frames, dim=1
|
479 |
+
)
|
480 |
+
if self.print_idx == 0:
|
481 |
+
logger.debug(
|
482 |
+
f"NonParamT2ISelfReferenceXFormersAttnProcessor4, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}"
|
483 |
+
)
|
484 |
+
encoder_hidden_states = torch.concat(
|
485 |
+
[encoder_hidden_states, refer_emb], dim=2
|
486 |
+
)
|
487 |
+
if self.print_idx == 0:
|
488 |
+
logger.debug(
|
489 |
+
f"NonParamT2ISelfReferenceXFormersAttnProcessor5, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}"
|
490 |
+
)
|
491 |
+
encoder_hidden_states = rearrange(
|
492 |
+
encoder_hidden_states, "b t hw c -> (b t) hw c"
|
493 |
+
)
|
494 |
+
# vision_cond in same unet attn end
|
495 |
+
|
496 |
+
if attn.group_norm is not None:
|
497 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
498 |
+
1, 2
|
499 |
+
)
|
500 |
+
|
501 |
+
query = attn.to_q(hidden_states, scale=scale)
|
502 |
+
|
503 |
+
if encoder_hidden_states is None:
|
504 |
+
encoder_hidden_states = hidden_states
|
505 |
+
elif attn.norm_cross:
|
506 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
507 |
+
encoder_hidden_states
|
508 |
+
)
|
509 |
+
encoder_hidden_states = align_repeat_tensor_single_dim(
|
510 |
+
encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
|
511 |
+
)
|
512 |
+
key = attn.to_k(encoder_hidden_states, scale=scale)
|
513 |
+
value = attn.to_v(encoder_hidden_states, scale=scale)
|
514 |
+
|
515 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
516 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
517 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
518 |
+
|
519 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
520 |
+
query,
|
521 |
+
key,
|
522 |
+
value,
|
523 |
+
attn_bias=attention_mask,
|
524 |
+
op=self.attention_op,
|
525 |
+
scale=attn.scale,
|
526 |
+
)
|
527 |
+
hidden_states = hidden_states.to(query.dtype)
|
528 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
529 |
+
|
530 |
+
# linear proj
|
531 |
+
hidden_states = attn.to_out[0](hidden_states, scale=scale)
|
532 |
+
# dropout
|
533 |
+
hidden_states = attn.to_out[1](hidden_states)
|
534 |
+
|
535 |
+
if input_ndim == 4:
|
536 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
537 |
+
batch_size, channel, height, width
|
538 |
+
)
|
539 |
+
|
540 |
+
if attn.residual_connection:
|
541 |
+
hidden_states = hidden_states + residual
|
542 |
+
|
543 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
544 |
+
self.print_idx += 1
|
545 |
+
|
546 |
+
return hidden_states
|
547 |
+
|
548 |
+
|
549 |
+
@Model_Register.register
|
550 |
+
class NonParamReferenceIPXFormersAttnProcessor(
|
551 |
+
NonParamT2ISelfReferenceXFormersAttnProcessor
|
552 |
+
):
|
553 |
+
def __init__(self, attention_op: Callable[..., Any] | None = None):
|
554 |
+
super().__init__(attention_op)
|
555 |
+
|
556 |
+
|
557 |
+
@maybe_allow_in_graph
|
558 |
+
class ReferEmbFuseAttention(IPAttention):
|
559 |
+
"""使用 attention 融合 refernet 中的 emb 到 unet 对应的 latens 中
|
560 |
+
# TODO: 目前只支持 bt hw c 的融合,后续考虑增加对 视频 bhw t c、b thw c的融合
|
561 |
+
residual_connection: bool = True, 默认, 从不产生影响开始学习
|
562 |
+
|
563 |
+
use attention to fuse referencenet emb into unet latents
|
564 |
+
# TODO: by now, only support bt hw c, later consider to support bhw t c, b thw c
|
565 |
+
residual_connection: bool = True, default, start from no effect
|
566 |
+
|
567 |
+
Args:
|
568 |
+
IPAttention (_type_): _description_
|
569 |
+
"""
|
570 |
+
|
571 |
+
print_idx = 0
|
572 |
+
|
573 |
+
def __init__(
|
574 |
+
self,
|
575 |
+
query_dim: int,
|
576 |
+
cross_attention_dim: int | None = None,
|
577 |
+
heads: int = 8,
|
578 |
+
dim_head: int = 64,
|
579 |
+
dropout: float = 0,
|
580 |
+
bias=False,
|
581 |
+
upcast_attention: bool = False,
|
582 |
+
upcast_softmax: bool = False,
|
583 |
+
cross_attention_norm: str | None = None,
|
584 |
+
cross_attention_norm_num_groups: int = 32,
|
585 |
+
added_kv_proj_dim: int | None = None,
|
586 |
+
norm_num_groups: int | None = None,
|
587 |
+
spatial_norm_dim: int | None = None,
|
588 |
+
out_bias: bool = True,
|
589 |
+
scale_qk: bool = True,
|
590 |
+
only_cross_attention: bool = False,
|
591 |
+
eps: float = 0.00001,
|
592 |
+
rescale_output_factor: float = 1,
|
593 |
+
residual_connection: bool = True,
|
594 |
+
_from_deprecated_attn_block=False,
|
595 |
+
processor: AttnProcessor | None = None,
|
596 |
+
cross_attn_temporal_cond: bool = False,
|
597 |
+
image_scale: float = 1,
|
598 |
+
):
|
599 |
+
super().__init__(
|
600 |
+
query_dim,
|
601 |
+
cross_attention_dim,
|
602 |
+
heads,
|
603 |
+
dim_head,
|
604 |
+
dropout,
|
605 |
+
bias,
|
606 |
+
upcast_attention,
|
607 |
+
upcast_softmax,
|
608 |
+
cross_attention_norm,
|
609 |
+
cross_attention_norm_num_groups,
|
610 |
+
added_kv_proj_dim,
|
611 |
+
norm_num_groups,
|
612 |
+
spatial_norm_dim,
|
613 |
+
out_bias,
|
614 |
+
scale_qk,
|
615 |
+
only_cross_attention,
|
616 |
+
eps,
|
617 |
+
rescale_output_factor,
|
618 |
+
residual_connection,
|
619 |
+
_from_deprecated_attn_block,
|
620 |
+
processor,
|
621 |
+
cross_attn_temporal_cond,
|
622 |
+
image_scale,
|
623 |
+
)
|
624 |
+
self.processor = None
|
625 |
+
# 配合residual,使一开始不影响之前结果
|
626 |
+
nn.init.zeros_(self.to_out[0].weight)
|
627 |
+
nn.init.zeros_(self.to_out[0].bias)
|
628 |
+
|
629 |
+
def forward(
|
630 |
+
self,
|
631 |
+
hidden_states: torch.FloatTensor,
|
632 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
633 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
634 |
+
temb: Optional[torch.FloatTensor] = None,
|
635 |
+
scale: float = 1.0,
|
636 |
+
num_frames: int = None,
|
637 |
+
) -> torch.Tensor:
|
638 |
+
"""fuse referencenet emb b c t2 h2 w2 into unet latents b c t1 h1 w1 with attn
|
639 |
+
refer to musev/models/attention_processor.py::NonParamT2ISelfReferenceXFormersAttnProcessor
|
640 |
+
|
641 |
+
Args:
|
642 |
+
hidden_states (torch.FloatTensor): unet latents, (b t1) c h1 w1
|
643 |
+
encoder_hidden_states (Optional[torch.FloatTensor], optional): referencenet emb b c2 t2 h2 w2. Defaults to None.
|
644 |
+
attention_mask (Optional[torch.FloatTensor], optional): _description_. Defaults to None.
|
645 |
+
temb (Optional[torch.FloatTensor], optional): _description_. Defaults to None.
|
646 |
+
scale (float, optional): _description_. Defaults to 1.0.
|
647 |
+
num_frames (int, optional): _description_. Defaults to None.
|
648 |
+
|
649 |
+
Returns:
|
650 |
+
torch.Tensor: _description_
|
651 |
+
"""
|
652 |
+
residual = hidden_states
|
653 |
+
# start
|
654 |
+
hidden_states = rearrange(
|
655 |
+
hidden_states, "(b t) c h w -> b c t h w", t=num_frames
|
656 |
+
)
|
657 |
+
batch_size, channel, t1, height, width = hidden_states.shape
|
658 |
+
if self.print_idx == 0:
|
659 |
+
logger.debug(
|
660 |
+
f"hidden_states={hidden_states.shape},encoder_hidden_states={encoder_hidden_states.shape}"
|
661 |
+
)
|
662 |
+
# concat with hidden_states b c t1 h1 w1 in hw channel into bt (t2 + 1)hw c
|
663 |
+
encoder_hidden_states = rearrange(
|
664 |
+
encoder_hidden_states, " b c t2 h w-> b (t2 h w) c"
|
665 |
+
)
|
666 |
+
encoder_hidden_states = repeat(
|
667 |
+
encoder_hidden_states, " b t2hw c -> (b t) t2hw c", t=t1
|
668 |
+
)
|
669 |
+
hidden_states = rearrange(hidden_states, " b c t h w-> (b t) (h w) c")
|
670 |
+
# bt (t2+1)hw d
|
671 |
+
encoder_hidden_states = torch.concat(
|
672 |
+
[encoder_hidden_states, hidden_states], dim=1
|
673 |
+
)
|
674 |
+
# encoder_hidden_states = align_repeat_tensor_single_dim(
|
675 |
+
# encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
|
676 |
+
# )
|
677 |
+
# end
|
678 |
+
|
679 |
+
if self.spatial_norm is not None:
|
680 |
+
hidden_states = self.spatial_norm(hidden_states, temb)
|
681 |
+
|
682 |
+
_, key_tokens, _ = (
|
683 |
+
hidden_states.shape
|
684 |
+
if encoder_hidden_states is None
|
685 |
+
else encoder_hidden_states.shape
|
686 |
+
)
|
687 |
+
|
688 |
+
attention_mask = self.prepare_attention_mask(
|
689 |
+
attention_mask, key_tokens, batch_size
|
690 |
+
)
|
691 |
+
if attention_mask is not None:
|
692 |
+
# expand our mask's singleton query_tokens dimension:
|
693 |
+
# [batch*heads, 1, key_tokens] ->
|
694 |
+
# [batch*heads, query_tokens, key_tokens]
|
695 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
696 |
+
# [batch*heads, query_tokens, key_tokens]
|
697 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
698 |
+
_, query_tokens, _ = hidden_states.shape
|
699 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
700 |
+
|
701 |
+
if self.group_norm is not None:
|
702 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(
|
703 |
+
1, 2
|
704 |
+
)
|
705 |
+
|
706 |
+
query = self.to_q(hidden_states, scale=scale)
|
707 |
+
|
708 |
+
if encoder_hidden_states is None:
|
709 |
+
encoder_hidden_states = hidden_states
|
710 |
+
elif self.norm_cross:
|
711 |
+
encoder_hidden_states = self.norm_encoder_hidden_states(
|
712 |
+
encoder_hidden_states
|
713 |
+
)
|
714 |
+
|
715 |
+
key = self.to_k(encoder_hidden_states, scale=scale)
|
716 |
+
value = self.to_v(encoder_hidden_states, scale=scale)
|
717 |
+
|
718 |
+
query = self.head_to_batch_dim(query).contiguous()
|
719 |
+
key = self.head_to_batch_dim(key).contiguous()
|
720 |
+
value = self.head_to_batch_dim(value).contiguous()
|
721 |
+
|
722 |
+
# query: b t hw d
|
723 |
+
# key/value: bt (t1+1)hw d
|
724 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
725 |
+
query,
|
726 |
+
key,
|
727 |
+
value,
|
728 |
+
attn_bias=attention_mask,
|
729 |
+
scale=self.scale,
|
730 |
+
)
|
731 |
+
hidden_states = hidden_states.to(query.dtype)
|
732 |
+
hidden_states = self.batch_to_head_dim(hidden_states)
|
733 |
+
|
734 |
+
# linear proj
|
735 |
+
hidden_states = self.to_out[0](hidden_states, scale=scale)
|
736 |
+
# dropout
|
737 |
+
hidden_states = self.to_out[1](hidden_states)
|
738 |
+
|
739 |
+
hidden_states = rearrange(
|
740 |
+
hidden_states,
|
741 |
+
"bt (h w) c-> bt c h w",
|
742 |
+
h=height,
|
743 |
+
w=width,
|
744 |
+
)
|
745 |
+
if self.residual_connection:
|
746 |
+
hidden_states = hidden_states + residual
|
747 |
+
|
748 |
+
hidden_states = hidden_states / self.rescale_output_factor
|
749 |
+
self.print_idx += 1
|
750 |
+
return hidden_states
|
musev/models/controlnet.py
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
2 |
+
import warnings
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from diffusers.models.modeling_utils import ModelMixin
|
9 |
+
import PIL
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn.init as init
|
14 |
+
from diffusers.models.controlnet import ControlNetModel
|
15 |
+
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
16 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
|
17 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
18 |
+
|
19 |
+
|
20 |
+
class ControlnetPredictor(object):
|
21 |
+
def __init__(self, controlnet_model_path: str, *args, **kwargs):
|
22 |
+
"""Controlnet 推断函数,用于提取 controlnet backbone的emb,避免训练时重复抽取
|
23 |
+
Controlnet inference predictor, used to extract the emb of the controlnet backbone to avoid repeated extraction during training
|
24 |
+
Args:
|
25 |
+
controlnet_model_path (str): controlnet 模型路径. controlnet model path.
|
26 |
+
"""
|
27 |
+
super(ControlnetPredictor, self).__init__(*args, **kwargs)
|
28 |
+
self.controlnet = ControlNetModel.from_pretrained(
|
29 |
+
controlnet_model_path,
|
30 |
+
)
|
31 |
+
|
32 |
+
def prepare_image(
|
33 |
+
self,
|
34 |
+
image, # b c t h w
|
35 |
+
width,
|
36 |
+
height,
|
37 |
+
batch_size,
|
38 |
+
num_images_per_prompt,
|
39 |
+
device,
|
40 |
+
dtype,
|
41 |
+
do_classifier_free_guidance=False,
|
42 |
+
guess_mode=False,
|
43 |
+
):
|
44 |
+
if height is None:
|
45 |
+
height = image.shape[-2]
|
46 |
+
if width is None:
|
47 |
+
width = image.shape[-1]
|
48 |
+
width, height = (
|
49 |
+
x - x % self.control_image_processor.vae_scale_factor
|
50 |
+
for x in (width, height)
|
51 |
+
)
|
52 |
+
image = rearrange(image, "b c t h w-> (b t) c h w")
|
53 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
|
54 |
+
|
55 |
+
image = (
|
56 |
+
torch.nn.functional.interpolate(
|
57 |
+
image,
|
58 |
+
size=(height, width),
|
59 |
+
mode="bilinear",
|
60 |
+
),
|
61 |
+
)
|
62 |
+
|
63 |
+
do_normalize = self.control_image_processor.config.do_normalize
|
64 |
+
if image.min() < 0:
|
65 |
+
warnings.warn(
|
66 |
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
67 |
+
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
68 |
+
FutureWarning,
|
69 |
+
)
|
70 |
+
do_normalize = False
|
71 |
+
|
72 |
+
if do_normalize:
|
73 |
+
image = self.control_image_processor.normalize(image)
|
74 |
+
|
75 |
+
image_batch_size = image.shape[0]
|
76 |
+
|
77 |
+
if image_batch_size == 1:
|
78 |
+
repeat_by = batch_size
|
79 |
+
else:
|
80 |
+
# image batch size is the same as prompt batch size
|
81 |
+
repeat_by = num_images_per_prompt
|
82 |
+
|
83 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
84 |
+
|
85 |
+
image = image.to(device=device, dtype=dtype)
|
86 |
+
|
87 |
+
if do_classifier_free_guidance and not guess_mode:
|
88 |
+
image = torch.cat([image] * 2)
|
89 |
+
|
90 |
+
return image
|
91 |
+
|
92 |
+
@torch.no_grad()
|
93 |
+
def __call__(
|
94 |
+
self,
|
95 |
+
batch_size: int,
|
96 |
+
device: str,
|
97 |
+
dtype: torch.dtype,
|
98 |
+
timesteps: List[float],
|
99 |
+
i: int,
|
100 |
+
scheduler: KarrasDiffusionSchedulers,
|
101 |
+
prompt_embeds: torch.Tensor,
|
102 |
+
do_classifier_free_guidance: bool = False,
|
103 |
+
# 2b co t ho wo
|
104 |
+
latent_model_input: torch.Tensor = None,
|
105 |
+
# b co t ho wo
|
106 |
+
latents: torch.Tensor = None,
|
107 |
+
# b c t h w
|
108 |
+
image: Union[
|
109 |
+
torch.FloatTensor,
|
110 |
+
PIL.Image.Image,
|
111 |
+
np.ndarray,
|
112 |
+
List[torch.FloatTensor],
|
113 |
+
List[PIL.Image.Image],
|
114 |
+
List[np.ndarray],
|
115 |
+
] = None,
|
116 |
+
# b c t(1) hi wi
|
117 |
+
controlnet_condition_frames: Optional[torch.FloatTensor] = None,
|
118 |
+
# b c t ho wo
|
119 |
+
controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None,
|
120 |
+
# b c t(1) ho wo
|
121 |
+
controlnet_condition_latents: Optional[torch.FloatTensor] = None,
|
122 |
+
height: Optional[int] = None,
|
123 |
+
width: Optional[int] = None,
|
124 |
+
num_videos_per_prompt: Optional[int] = 1,
|
125 |
+
return_dict: bool = True,
|
126 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
127 |
+
guess_mode: bool = False,
|
128 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
129 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
130 |
+
latent_index: torch.LongTensor = None,
|
131 |
+
vision_condition_latent_index: torch.LongTensor = None,
|
132 |
+
**kwargs,
|
133 |
+
):
|
134 |
+
assert (
|
135 |
+
image is None and controlnet_latents is None
|
136 |
+
), "should set one of image and controlnet_latents"
|
137 |
+
|
138 |
+
controlnet = (
|
139 |
+
self.controlnet._orig_mod
|
140 |
+
if is_compiled_module(self.controlnet)
|
141 |
+
else self.controlnet
|
142 |
+
)
|
143 |
+
|
144 |
+
# align format for control guidance
|
145 |
+
if not isinstance(control_guidance_start, list) and isinstance(
|
146 |
+
control_guidance_end, list
|
147 |
+
):
|
148 |
+
control_guidance_start = len(control_guidance_end) * [
|
149 |
+
control_guidance_start
|
150 |
+
]
|
151 |
+
elif not isinstance(control_guidance_end, list) and isinstance(
|
152 |
+
control_guidance_start, list
|
153 |
+
):
|
154 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
155 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(
|
156 |
+
control_guidance_end, list
|
157 |
+
):
|
158 |
+
mult = (
|
159 |
+
len(controlnet.nets)
|
160 |
+
if isinstance(controlnet, MultiControlNetModel)
|
161 |
+
else 1
|
162 |
+
)
|
163 |
+
control_guidance_start, control_guidance_end = mult * [
|
164 |
+
control_guidance_start
|
165 |
+
], mult * [control_guidance_end]
|
166 |
+
|
167 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(
|
168 |
+
controlnet_conditioning_scale, float
|
169 |
+
):
|
170 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
|
171 |
+
controlnet.nets
|
172 |
+
)
|
173 |
+
|
174 |
+
global_pool_conditions = (
|
175 |
+
controlnet.config.global_pool_conditions
|
176 |
+
if isinstance(controlnet, ControlNetModel)
|
177 |
+
else controlnet.nets[0].config.global_pool_conditions
|
178 |
+
)
|
179 |
+
guess_mode = guess_mode or global_pool_conditions
|
180 |
+
|
181 |
+
# 4. Prepare image
|
182 |
+
if isinstance(controlnet, ControlNetModel):
|
183 |
+
if (
|
184 |
+
controlnet_latents is not None
|
185 |
+
and controlnet_condition_latents is not None
|
186 |
+
):
|
187 |
+
if isinstance(controlnet_latents, np.ndarray):
|
188 |
+
controlnet_latents = torch.from_numpy(controlnet_latents)
|
189 |
+
if isinstance(controlnet_condition_latents, np.ndarray):
|
190 |
+
controlnet_condition_latents = torch.from_numpy(
|
191 |
+
controlnet_condition_latents
|
192 |
+
)
|
193 |
+
# TODO:使用index进行concat
|
194 |
+
controlnet_latents = torch.concat(
|
195 |
+
[controlnet_condition_latents, controlnet_latents], dim=2
|
196 |
+
)
|
197 |
+
if not guess_mode and do_classifier_free_guidance:
|
198 |
+
controlnet_latents = torch.concat([controlnet_latents] * 2, dim=0)
|
199 |
+
controlnet_latents = rearrange(
|
200 |
+
controlnet_latents, "b c t h w->(b t) c h w"
|
201 |
+
)
|
202 |
+
controlnet_latents = controlnet_latents.to(device=device, dtype=dtype)
|
203 |
+
else:
|
204 |
+
# TODO:使用index进行concat
|
205 |
+
# TODO: concat with index
|
206 |
+
if controlnet_condition_frames is not None:
|
207 |
+
if isinstance(controlnet_condition_frames, np.ndarray):
|
208 |
+
image = np.concatenate(
|
209 |
+
[controlnet_condition_frames, image], axis=2
|
210 |
+
)
|
211 |
+
image = self.prepare_image(
|
212 |
+
image=image,
|
213 |
+
width=width,
|
214 |
+
height=height,
|
215 |
+
batch_size=batch_size * num_videos_per_prompt,
|
216 |
+
num_images_per_prompt=num_videos_per_prompt,
|
217 |
+
device=device,
|
218 |
+
dtype=controlnet.dtype,
|
219 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
220 |
+
guess_mode=guess_mode,
|
221 |
+
)
|
222 |
+
height, width = image.shape[-2:]
|
223 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
224 |
+
images = []
|
225 |
+
# TODO: 支持直接使用controlnet_latent而不是frames
|
226 |
+
# TODO: support using controlnet_latent directly instead of frames
|
227 |
+
if controlnet_latents is not None:
|
228 |
+
raise NotImplementedError
|
229 |
+
else:
|
230 |
+
for i, image_ in enumerate(image):
|
231 |
+
if controlnet_condition_frames is not None and isinstance(
|
232 |
+
controlnet_condition_frames, list
|
233 |
+
):
|
234 |
+
if isinstance(controlnet_condition_frames[i], np.ndarray):
|
235 |
+
image_ = np.concatenate(
|
236 |
+
[controlnet_condition_frames[i], image_], axis=2
|
237 |
+
)
|
238 |
+
image_ = self.prepare_image(
|
239 |
+
image=image_,
|
240 |
+
width=width,
|
241 |
+
height=height,
|
242 |
+
batch_size=batch_size * num_videos_per_prompt,
|
243 |
+
num_images_per_prompt=num_videos_per_prompt,
|
244 |
+
device=device,
|
245 |
+
dtype=controlnet.dtype,
|
246 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
247 |
+
guess_mode=guess_mode,
|
248 |
+
)
|
249 |
+
|
250 |
+
images.append(image_)
|
251 |
+
|
252 |
+
image = images
|
253 |
+
height, width = image[0].shape[-2:]
|
254 |
+
else:
|
255 |
+
assert False
|
256 |
+
|
257 |
+
# 7.1 Create tensor stating which controlnets to keep
|
258 |
+
controlnet_keep = []
|
259 |
+
for i in range(len(timesteps)):
|
260 |
+
keeps = [
|
261 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
262 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
263 |
+
]
|
264 |
+
controlnet_keep.append(
|
265 |
+
keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
|
266 |
+
)
|
267 |
+
|
268 |
+
t = timesteps[i]
|
269 |
+
|
270 |
+
# controlnet(s) inference
|
271 |
+
if guess_mode and do_classifier_free_guidance:
|
272 |
+
# Infer ControlNet only for the conditional batch.
|
273 |
+
control_model_input = latents
|
274 |
+
control_model_input = scheduler.scale_model_input(control_model_input, t)
|
275 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
276 |
+
else:
|
277 |
+
control_model_input = latent_model_input
|
278 |
+
controlnet_prompt_embeds = prompt_embeds
|
279 |
+
if isinstance(controlnet_keep[i], list):
|
280 |
+
cond_scale = [
|
281 |
+
c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])
|
282 |
+
]
|
283 |
+
else:
|
284 |
+
cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
|
285 |
+
control_model_input_reshape = rearrange(
|
286 |
+
control_model_input, "b c t h w -> (b t) c h w"
|
287 |
+
)
|
288 |
+
encoder_hidden_states_repeat = repeat(
|
289 |
+
controlnet_prompt_embeds,
|
290 |
+
"b n q->(b t) n q",
|
291 |
+
t=control_model_input.shape[2],
|
292 |
+
)
|
293 |
+
|
294 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
295 |
+
control_model_input_reshape,
|
296 |
+
t,
|
297 |
+
encoder_hidden_states_repeat,
|
298 |
+
controlnet_cond=image,
|
299 |
+
controlnet_cond_latents=controlnet_latents,
|
300 |
+
conditioning_scale=cond_scale,
|
301 |
+
guess_mode=guess_mode,
|
302 |
+
return_dict=False,
|
303 |
+
)
|
304 |
+
|
305 |
+
return down_block_res_samples, mid_block_res_sample
|
306 |
+
|
307 |
+
|
308 |
+
class InflatedConv3d(nn.Conv2d):
|
309 |
+
def forward(self, x):
|
310 |
+
video_length = x.shape[2]
|
311 |
+
|
312 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
313 |
+
x = super().forward(x)
|
314 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
315 |
+
|
316 |
+
return x
|
317 |
+
|
318 |
+
|
319 |
+
def zero_module(module):
|
320 |
+
# Zero out the parameters of a module and return it.
|
321 |
+
for p in module.parameters():
|
322 |
+
p.detach().zero_()
|
323 |
+
return module
|
324 |
+
|
325 |
+
|
326 |
+
class PoseGuider(ModelMixin):
|
327 |
+
def __init__(
|
328 |
+
self,
|
329 |
+
conditioning_embedding_channels: int,
|
330 |
+
conditioning_channels: int = 3,
|
331 |
+
block_out_channels: Tuple[int] = (16, 32, 64, 128),
|
332 |
+
):
|
333 |
+
super().__init__()
|
334 |
+
self.conv_in = InflatedConv3d(
|
335 |
+
conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
|
336 |
+
)
|
337 |
+
|
338 |
+
self.blocks = nn.ModuleList([])
|
339 |
+
|
340 |
+
for i in range(len(block_out_channels) - 1):
|
341 |
+
channel_in = block_out_channels[i]
|
342 |
+
channel_out = block_out_channels[i + 1]
|
343 |
+
self.blocks.append(
|
344 |
+
InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
|
345 |
+
)
|
346 |
+
self.blocks.append(
|
347 |
+
InflatedConv3d(
|
348 |
+
channel_in, channel_out, kernel_size=3, padding=1, stride=2
|
349 |
+
)
|
350 |
+
)
|
351 |
+
|
352 |
+
self.conv_out = zero_module(
|
353 |
+
InflatedConv3d(
|
354 |
+
block_out_channels[-1],
|
355 |
+
conditioning_embedding_channels,
|
356 |
+
kernel_size=3,
|
357 |
+
padding=1,
|
358 |
+
)
|
359 |
+
)
|
360 |
+
|
361 |
+
def forward(self, conditioning):
|
362 |
+
embedding = self.conv_in(conditioning)
|
363 |
+
embedding = F.silu(embedding)
|
364 |
+
|
365 |
+
for block in self.blocks:
|
366 |
+
embedding = block(embedding)
|
367 |
+
embedding = F.silu(embedding)
|
368 |
+
|
369 |
+
embedding = self.conv_out(embedding)
|
370 |
+
|
371 |
+
return embedding
|
372 |
+
|
373 |
+
@classmethod
|
374 |
+
def from_pretrained(
|
375 |
+
cls,
|
376 |
+
pretrained_model_path,
|
377 |
+
conditioning_embedding_channels: int,
|
378 |
+
conditioning_channels: int = 3,
|
379 |
+
block_out_channels: Tuple[int] = (16, 32, 64, 128),
|
380 |
+
):
|
381 |
+
if not os.path.exists(pretrained_model_path):
|
382 |
+
print(f"There is no model file in {pretrained_model_path}")
|
383 |
+
print(
|
384 |
+
f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ..."
|
385 |
+
)
|
386 |
+
|
387 |
+
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
388 |
+
model = PoseGuider(
|
389 |
+
conditioning_embedding_channels=conditioning_embedding_channels,
|
390 |
+
conditioning_channels=conditioning_channels,
|
391 |
+
block_out_channels=block_out_channels,
|
392 |
+
)
|
393 |
+
|
394 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
395 |
+
# print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
396 |
+
params = [p.numel() for n, p in model.named_parameters()]
|
397 |
+
print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")
|
398 |
+
|
399 |
+
return model
|
musev/models/embeddings.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from einops import rearrange
|
16 |
+
import torch
|
17 |
+
from torch.nn import functional as F
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid
|
21 |
+
|
22 |
+
|
23 |
+
# ref diffusers.models.embeddings.get_2d_sincos_pos_embed
|
24 |
+
def get_2d_sincos_pos_embed(
|
25 |
+
embed_dim,
|
26 |
+
grid_size_w,
|
27 |
+
grid_size_h,
|
28 |
+
cls_token=False,
|
29 |
+
extra_tokens=0,
|
30 |
+
norm_length: bool = False,
|
31 |
+
max_length: float = 2048,
|
32 |
+
):
|
33 |
+
"""
|
34 |
+
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
35 |
+
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
36 |
+
"""
|
37 |
+
if norm_length and grid_size_h <= max_length and grid_size_w <= max_length:
|
38 |
+
grid_h = np.linspace(0, max_length, grid_size_h)
|
39 |
+
grid_w = np.linspace(0, max_length, grid_size_w)
|
40 |
+
else:
|
41 |
+
grid_h = np.arange(grid_size_h, dtype=np.float32)
|
42 |
+
grid_w = np.arange(grid_size_w, dtype=np.float32)
|
43 |
+
grid = np.meshgrid(grid_h, grid_w) # here h goes first
|
44 |
+
grid = np.stack(grid, axis=0)
|
45 |
+
|
46 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
47 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
48 |
+
if cls_token and extra_tokens > 0:
|
49 |
+
pos_embed = np.concatenate(
|
50 |
+
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
51 |
+
)
|
52 |
+
return pos_embed
|
53 |
+
|
54 |
+
|
55 |
+
def resize_spatial_position_emb(
|
56 |
+
emb: torch.Tensor,
|
57 |
+
height: int,
|
58 |
+
width: int,
|
59 |
+
scale: float = None,
|
60 |
+
target_height: int = None,
|
61 |
+
target_width: int = None,
|
62 |
+
) -> torch.Tensor:
|
63 |
+
"""_summary_
|
64 |
+
|
65 |
+
Args:
|
66 |
+
emb (torch.Tensor): b ( h w) d
|
67 |
+
height (int): _description_
|
68 |
+
width (int): _description_
|
69 |
+
scale (float, optional): _description_. Defaults to None.
|
70 |
+
target_height (int, optional): _description_. Defaults to None.
|
71 |
+
target_width (int, optional): _description_. Defaults to None.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
torch.Tensor: b (target_height target_width) d
|
75 |
+
"""
|
76 |
+
if scale is not None:
|
77 |
+
target_height = int(height * scale)
|
78 |
+
target_width = int(width * scale)
|
79 |
+
emb = rearrange(emb, "(h w) (b d) ->b d h w", h=height, b=1)
|
80 |
+
emb = F.interpolate(
|
81 |
+
emb,
|
82 |
+
size=(target_height, target_width),
|
83 |
+
mode="bicubic",
|
84 |
+
align_corners=False,
|
85 |
+
)
|
86 |
+
emb = rearrange(emb, "b d h w-> (h w) (b d)")
|
87 |
+
return emb
|
musev/models/facein_loader.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Any, Callable, Dict, Iterable, Union
|
3 |
+
import PIL
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import datetime
|
8 |
+
import logging
|
9 |
+
import inspect
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import shutil
|
13 |
+
from typing import Dict, List, Optional, Tuple
|
14 |
+
from pprint import pprint
|
15 |
+
from collections import OrderedDict
|
16 |
+
from dataclasses import dataclass
|
17 |
+
import gc
|
18 |
+
import time
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from omegaconf import SCMode
|
23 |
+
import torch
|
24 |
+
from torch import nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from einops import rearrange, repeat
|
28 |
+
import pandas as pd
|
29 |
+
import h5py
|
30 |
+
from diffusers.models.modeling_utils import load_state_dict
|
31 |
+
from diffusers.utils import (
|
32 |
+
logging,
|
33 |
+
)
|
34 |
+
from diffusers.utils.import_utils import is_xformers_available
|
35 |
+
|
36 |
+
from mmcm.vision.feature_extractor.clip_vision_extractor import (
|
37 |
+
ImageClipVisionFeatureExtractor,
|
38 |
+
ImageClipVisionFeatureExtractorV2,
|
39 |
+
)
|
40 |
+
from mmcm.vision.feature_extractor.insight_face_extractor import InsightFaceExtractor
|
41 |
+
|
42 |
+
from ip_adapter.resampler import Resampler
|
43 |
+
from ip_adapter.ip_adapter import ImageProjModel
|
44 |
+
|
45 |
+
from .unet_loader import update_unet_with_sd
|
46 |
+
from .unet_3d_condition import UNet3DConditionModel
|
47 |
+
from .ip_adapter_loader import ip_adapter_keys_list
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
50 |
+
|
51 |
+
|
52 |
+
# refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
|
53 |
+
unet_keys_list = [
|
54 |
+
"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
55 |
+
"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
56 |
+
"down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
57 |
+
"down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
58 |
+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
59 |
+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
60 |
+
"down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
61 |
+
"down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
62 |
+
"down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
63 |
+
"down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
64 |
+
"down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
65 |
+
"down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
66 |
+
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
67 |
+
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
68 |
+
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
69 |
+
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
70 |
+
"up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
71 |
+
"up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
72 |
+
"up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
73 |
+
"up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
74 |
+
"up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
75 |
+
"up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
76 |
+
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
77 |
+
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
78 |
+
"up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
79 |
+
"up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
80 |
+
"up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
81 |
+
"up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
82 |
+
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
83 |
+
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
84 |
+
"mid_block.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
|
85 |
+
"mid_block.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
|
86 |
+
]
|
87 |
+
|
88 |
+
|
89 |
+
UNET2IPAadapter_Keys_MAPIING = {
|
90 |
+
k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
|
91 |
+
}
|
92 |
+
|
93 |
+
|
94 |
+
def load_facein_extractor_and_proj_by_name(
|
95 |
+
model_name: str,
|
96 |
+
ip_ckpt: Tuple[str, nn.Module],
|
97 |
+
ip_image_encoder: Tuple[str, nn.Module] = None,
|
98 |
+
cross_attention_dim: int = 768,
|
99 |
+
clip_embeddings_dim: int = 512,
|
100 |
+
clip_extra_context_tokens: int = 1,
|
101 |
+
ip_scale: float = 0.0,
|
102 |
+
dtype: torch.dtype = torch.float16,
|
103 |
+
device: str = "cuda",
|
104 |
+
unet: nn.Module = None,
|
105 |
+
) -> nn.Module:
|
106 |
+
pass
|
107 |
+
|
108 |
+
|
109 |
+
def update_unet_facein_cross_attn_param(
|
110 |
+
unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
|
111 |
+
) -> None:
|
112 |
+
"""use independent ip_adapter attn 中的 to_k, to_v in unet
|
113 |
+
ip_adapter: like ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']的字典
|
114 |
+
|
115 |
+
|
116 |
+
Args:
|
117 |
+
unet (UNet3DConditionModel): _description_
|
118 |
+
ip_adapter_state_dict (Dict): _description_
|
119 |
+
"""
|
120 |
+
pass
|
musev/models/ip_adapter_face_loader.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Any, Callable, Dict, Iterable, Union
|
3 |
+
import PIL
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import datetime
|
8 |
+
import logging
|
9 |
+
import inspect
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import shutil
|
13 |
+
from typing import Dict, List, Optional, Tuple
|
14 |
+
from pprint import pprint
|
15 |
+
from collections import OrderedDict
|
16 |
+
from dataclasses import dataclass
|
17 |
+
import gc
|
18 |
+
import time
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from omegaconf import SCMode
|
23 |
+
import torch
|
24 |
+
from torch import nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from einops import rearrange, repeat
|
28 |
+
import pandas as pd
|
29 |
+
import h5py
|
30 |
+
from diffusers.models.modeling_utils import load_state_dict
|
31 |
+
from diffusers.utils import (
|
32 |
+
logging,
|
33 |
+
)
|
34 |
+
from diffusers.utils.import_utils import is_xformers_available
|
35 |
+
|
36 |
+
from ip_adapter.resampler import Resampler
|
37 |
+
from ip_adapter.ip_adapter import ImageProjModel
|
38 |
+
from ip_adapter.ip_adapter_faceid import ProjPlusModel, MLPProjModel
|
39 |
+
|
40 |
+
from mmcm.vision.feature_extractor.clip_vision_extractor import (
|
41 |
+
ImageClipVisionFeatureExtractor,
|
42 |
+
ImageClipVisionFeatureExtractorV2,
|
43 |
+
)
|
44 |
+
from mmcm.vision.feature_extractor.insight_face_extractor import (
|
45 |
+
InsightFaceExtractorNormEmb,
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
from .unet_loader import update_unet_with_sd
|
50 |
+
from .unet_3d_condition import UNet3DConditionModel
|
51 |
+
from .ip_adapter_loader import ip_adapter_keys_list
|
52 |
+
|
53 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
54 |
+
|
55 |
+
|
56 |
+
# refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
|
57 |
+
unet_keys_list = [
|
58 |
+
"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
59 |
+
"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
60 |
+
"down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
61 |
+
"down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
62 |
+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
63 |
+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
64 |
+
"down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
65 |
+
"down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
66 |
+
"down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
67 |
+
"down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
68 |
+
"down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
69 |
+
"down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
70 |
+
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
71 |
+
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
72 |
+
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
73 |
+
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
74 |
+
"up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
75 |
+
"up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
76 |
+
"up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
77 |
+
"up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
78 |
+
"up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
79 |
+
"up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
80 |
+
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
81 |
+
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
82 |
+
"up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
83 |
+
"up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
84 |
+
"up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
85 |
+
"up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
86 |
+
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
87 |
+
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
88 |
+
"mid_block.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
|
89 |
+
"mid_block.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
|
90 |
+
]
|
91 |
+
|
92 |
+
|
93 |
+
UNET2IPAadapter_Keys_MAPIING = {
|
94 |
+
k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
|
95 |
+
}
|
96 |
+
|
97 |
+
|
98 |
+
def load_ip_adapter_face_extractor_and_proj_by_name(
|
99 |
+
model_name: str,
|
100 |
+
ip_ckpt: Tuple[str, nn.Module],
|
101 |
+
ip_image_encoder: Tuple[str, nn.Module] = None,
|
102 |
+
cross_attention_dim: int = 768,
|
103 |
+
clip_embeddings_dim: int = 1024,
|
104 |
+
clip_extra_context_tokens: int = 4,
|
105 |
+
ip_scale: float = 0.0,
|
106 |
+
dtype: torch.dtype = torch.float16,
|
107 |
+
device: str = "cuda",
|
108 |
+
unet: nn.Module = None,
|
109 |
+
) -> nn.Module:
|
110 |
+
if model_name == "IPAdapterFaceID":
|
111 |
+
if ip_image_encoder is not None:
|
112 |
+
ip_adapter_face_emb_extractor = InsightFaceExtractorNormEmb(
|
113 |
+
pretrained_model_name_or_path=ip_image_encoder,
|
114 |
+
dtype=dtype,
|
115 |
+
device=device,
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
ip_adapter_face_emb_extractor = None
|
119 |
+
ip_adapter_image_proj = MLPProjModel(
|
120 |
+
cross_attention_dim=cross_attention_dim,
|
121 |
+
id_embeddings_dim=clip_embeddings_dim,
|
122 |
+
num_tokens=clip_extra_context_tokens,
|
123 |
+
).to(device, dtype=dtype)
|
124 |
+
else:
|
125 |
+
raise ValueError(
|
126 |
+
f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus, IPAdapterFaceID"
|
127 |
+
)
|
128 |
+
ip_adapter_state_dict = torch.load(
|
129 |
+
ip_ckpt,
|
130 |
+
map_location="cpu",
|
131 |
+
)
|
132 |
+
ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
|
133 |
+
if unet is not None and "ip_adapter" in ip_adapter_state_dict:
|
134 |
+
update_unet_ip_adapter_cross_attn_param(
|
135 |
+
unet,
|
136 |
+
ip_adapter_state_dict["ip_adapter"],
|
137 |
+
)
|
138 |
+
logger.info(
|
139 |
+
f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
|
140 |
+
)
|
141 |
+
return (
|
142 |
+
ip_adapter_face_emb_extractor,
|
143 |
+
ip_adapter_image_proj,
|
144 |
+
)
|
145 |
+
|
146 |
+
|
147 |
+
def update_unet_ip_adapter_cross_attn_param(
|
148 |
+
unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
|
149 |
+
) -> None:
|
150 |
+
"""use independent ip_adapter attn 中的 to_k, to_v in unet
|
151 |
+
ip_adapter: like ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']
|
152 |
+
|
153 |
+
|
154 |
+
Args:
|
155 |
+
unet (UNet3DConditionModel): _description_
|
156 |
+
ip_adapter_state_dict (Dict): _description_
|
157 |
+
"""
|
158 |
+
unet_spatial_cross_atnns = unet.spatial_cross_attns[0]
|
159 |
+
unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns}
|
160 |
+
for i, (unet_key_more, ip_adapter_key) in enumerate(
|
161 |
+
UNET2IPAadapter_Keys_MAPIING.items()
|
162 |
+
):
|
163 |
+
ip_adapter_value = ip_adapter_state_dict[ip_adapter_key]
|
164 |
+
unet_key_more_spit = unet_key_more.split(".")
|
165 |
+
unet_key = ".".join(unet_key_more_spit[:-3])
|
166 |
+
suffix = ".".join(unet_key_more_spit[-3:])
|
167 |
+
logger.debug(
|
168 |
+
f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}",
|
169 |
+
)
|
170 |
+
if ".ip_adapter_face_to_k" in suffix:
|
171 |
+
with torch.no_grad():
|
172 |
+
unet_spatial_cross_atnns_dct[
|
173 |
+
unet_key
|
174 |
+
].ip_adapter_face_to_k_ip.weight.copy_(ip_adapter_value.data)
|
175 |
+
else:
|
176 |
+
with torch.no_grad():
|
177 |
+
unet_spatial_cross_atnns_dct[
|
178 |
+
unet_key
|
179 |
+
].ip_adapter_face_to_v_ip.weight.copy_(ip_adapter_value.data)
|
musev/models/ip_adapter_loader.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Any, Callable, Dict, Iterable, Union
|
3 |
+
import PIL
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import datetime
|
8 |
+
import logging
|
9 |
+
import inspect
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import shutil
|
13 |
+
from typing import Dict, List, Optional, Tuple
|
14 |
+
from pprint import pprint
|
15 |
+
from collections import OrderedDict
|
16 |
+
from dataclasses import dataclass
|
17 |
+
import gc
|
18 |
+
import time
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from omegaconf import SCMode
|
23 |
+
import torch
|
24 |
+
from torch import nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from einops import rearrange, repeat
|
28 |
+
import pandas as pd
|
29 |
+
import h5py
|
30 |
+
from diffusers.models.modeling_utils import load_state_dict
|
31 |
+
from diffusers.utils import (
|
32 |
+
logging,
|
33 |
+
)
|
34 |
+
from diffusers.utils.import_utils import is_xformers_available
|
35 |
+
|
36 |
+
from mmcm.vision.feature_extractor import clip_vision_extractor
|
37 |
+
from mmcm.vision.feature_extractor.clip_vision_extractor import (
|
38 |
+
ImageClipVisionFeatureExtractor,
|
39 |
+
ImageClipVisionFeatureExtractorV2,
|
40 |
+
VerstailSDLastHiddenState2ImageEmb,
|
41 |
+
)
|
42 |
+
|
43 |
+
from ip_adapter.resampler import Resampler
|
44 |
+
from ip_adapter.ip_adapter import ImageProjModel
|
45 |
+
|
46 |
+
from .unet_loader import update_unet_with_sd
|
47 |
+
from .unet_3d_condition import UNet3DConditionModel
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
50 |
+
|
51 |
+
|
52 |
+
def load_vision_clip_encoder_by_name(
|
53 |
+
ip_image_encoder: Tuple[str, nn.Module] = None,
|
54 |
+
dtype: torch.dtype = torch.float16,
|
55 |
+
device: str = "cuda",
|
56 |
+
vision_clip_extractor_class_name: str = None,
|
57 |
+
) -> nn.Module:
|
58 |
+
if vision_clip_extractor_class_name is not None:
|
59 |
+
vision_clip_extractor = getattr(
|
60 |
+
clip_vision_extractor, vision_clip_extractor_class_name
|
61 |
+
)(
|
62 |
+
pretrained_model_name_or_path=ip_image_encoder,
|
63 |
+
dtype=dtype,
|
64 |
+
device=device,
|
65 |
+
)
|
66 |
+
else:
|
67 |
+
vision_clip_extractor = None
|
68 |
+
return vision_clip_extractor
|
69 |
+
|
70 |
+
|
71 |
+
def load_ip_adapter_image_proj_by_name(
|
72 |
+
model_name: str,
|
73 |
+
ip_ckpt: Tuple[str, nn.Module] = None,
|
74 |
+
cross_attention_dim: int = 768,
|
75 |
+
clip_embeddings_dim: int = 1024,
|
76 |
+
clip_extra_context_tokens: int = 4,
|
77 |
+
ip_scale: float = 0.0,
|
78 |
+
dtype: torch.dtype = torch.float16,
|
79 |
+
device: str = "cuda",
|
80 |
+
unet: nn.Module = None,
|
81 |
+
vision_clip_extractor_class_name: str = None,
|
82 |
+
ip_image_encoder: Tuple[str, nn.Module] = None,
|
83 |
+
) -> nn.Module:
|
84 |
+
if model_name in [
|
85 |
+
"IPAdapter",
|
86 |
+
"musev_referencenet",
|
87 |
+
"musev_referencenet_pose",
|
88 |
+
]:
|
89 |
+
ip_adapter_image_proj = ImageProjModel(
|
90 |
+
cross_attention_dim=cross_attention_dim,
|
91 |
+
clip_embeddings_dim=clip_embeddings_dim,
|
92 |
+
clip_extra_context_tokens=clip_extra_context_tokens,
|
93 |
+
)
|
94 |
+
|
95 |
+
elif model_name == "IPAdapterPlus":
|
96 |
+
vision_clip_extractor = ImageClipVisionFeatureExtractorV2(
|
97 |
+
pretrained_model_name_or_path=ip_image_encoder,
|
98 |
+
dtype=dtype,
|
99 |
+
device=device,
|
100 |
+
)
|
101 |
+
ip_adapter_image_proj = Resampler(
|
102 |
+
dim=cross_attention_dim,
|
103 |
+
depth=4,
|
104 |
+
dim_head=64,
|
105 |
+
heads=12,
|
106 |
+
num_queries=clip_extra_context_tokens,
|
107 |
+
embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size,
|
108 |
+
output_dim=cross_attention_dim,
|
109 |
+
ff_mult=4,
|
110 |
+
)
|
111 |
+
elif model_name in [
|
112 |
+
"VerstailSDLastHiddenState2ImageEmb",
|
113 |
+
"OriginLastHiddenState2ImageEmbd",
|
114 |
+
"OriginLastHiddenState2Poolout",
|
115 |
+
]:
|
116 |
+
ip_adapter_image_proj = getattr(
|
117 |
+
clip_vision_extractor, model_name
|
118 |
+
).from_pretrained(ip_image_encoder)
|
119 |
+
else:
|
120 |
+
raise ValueError(
|
121 |
+
f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus, VerstailSDLastHiddenState2ImageEmb"
|
122 |
+
)
|
123 |
+
if ip_ckpt is not None:
|
124 |
+
ip_adapter_state_dict = torch.load(
|
125 |
+
ip_ckpt,
|
126 |
+
map_location="cpu",
|
127 |
+
)
|
128 |
+
ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
|
129 |
+
if (
|
130 |
+
unet is not None
|
131 |
+
and unet.ip_adapter_cross_attn
|
132 |
+
and "ip_adapter" in ip_adapter_state_dict
|
133 |
+
):
|
134 |
+
update_unet_ip_adapter_cross_attn_param(
|
135 |
+
unet, ip_adapter_state_dict["ip_adapter"]
|
136 |
+
)
|
137 |
+
logger.info(
|
138 |
+
f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
|
139 |
+
)
|
140 |
+
return ip_adapter_image_proj
|
141 |
+
|
142 |
+
|
143 |
+
def load_ip_adapter_vision_clip_encoder_by_name(
|
144 |
+
model_name: str,
|
145 |
+
ip_ckpt: Tuple[str, nn.Module],
|
146 |
+
ip_image_encoder: Tuple[str, nn.Module] = None,
|
147 |
+
cross_attention_dim: int = 768,
|
148 |
+
clip_embeddings_dim: int = 1024,
|
149 |
+
clip_extra_context_tokens: int = 4,
|
150 |
+
ip_scale: float = 0.0,
|
151 |
+
dtype: torch.dtype = torch.float16,
|
152 |
+
device: str = "cuda",
|
153 |
+
unet: nn.Module = None,
|
154 |
+
vision_clip_extractor_class_name: str = None,
|
155 |
+
) -> nn.Module:
|
156 |
+
if vision_clip_extractor_class_name is not None:
|
157 |
+
vision_clip_extractor = getattr(
|
158 |
+
clip_vision_extractor, vision_clip_extractor_class_name
|
159 |
+
)(
|
160 |
+
pretrained_model_name_or_path=ip_image_encoder,
|
161 |
+
dtype=dtype,
|
162 |
+
device=device,
|
163 |
+
)
|
164 |
+
else:
|
165 |
+
vision_clip_extractor = None
|
166 |
+
if model_name in [
|
167 |
+
"IPAdapter",
|
168 |
+
"musev_referencenet",
|
169 |
+
]:
|
170 |
+
if ip_image_encoder is not None:
|
171 |
+
if vision_clip_extractor_class_name is None:
|
172 |
+
vision_clip_extractor = ImageClipVisionFeatureExtractor(
|
173 |
+
pretrained_model_name_or_path=ip_image_encoder,
|
174 |
+
dtype=dtype,
|
175 |
+
device=device,
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
vision_clip_extractor = None
|
179 |
+
ip_adapter_image_proj = ImageProjModel(
|
180 |
+
cross_attention_dim=cross_attention_dim,
|
181 |
+
clip_embeddings_dim=clip_embeddings_dim,
|
182 |
+
clip_extra_context_tokens=clip_extra_context_tokens,
|
183 |
+
)
|
184 |
+
|
185 |
+
elif model_name == "IPAdapterPlus":
|
186 |
+
if ip_image_encoder is not None:
|
187 |
+
if vision_clip_extractor_class_name is None:
|
188 |
+
vision_clip_extractor = ImageClipVisionFeatureExtractorV2(
|
189 |
+
pretrained_model_name_or_path=ip_image_encoder,
|
190 |
+
dtype=dtype,
|
191 |
+
device=device,
|
192 |
+
)
|
193 |
+
else:
|
194 |
+
vision_clip_extractor = None
|
195 |
+
ip_adapter_image_proj = Resampler(
|
196 |
+
dim=cross_attention_dim,
|
197 |
+
depth=4,
|
198 |
+
dim_head=64,
|
199 |
+
heads=12,
|
200 |
+
num_queries=clip_extra_context_tokens,
|
201 |
+
embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size,
|
202 |
+
output_dim=cross_attention_dim,
|
203 |
+
ff_mult=4,
|
204 |
+
).to(dtype=torch.float16)
|
205 |
+
else:
|
206 |
+
raise ValueError(
|
207 |
+
f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus"
|
208 |
+
)
|
209 |
+
ip_adapter_state_dict = torch.load(
|
210 |
+
ip_ckpt,
|
211 |
+
map_location="cpu",
|
212 |
+
)
|
213 |
+
ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
|
214 |
+
if (
|
215 |
+
unet is not None
|
216 |
+
and unet.ip_adapter_cross_attn
|
217 |
+
and "ip_adapter" in ip_adapter_state_dict
|
218 |
+
):
|
219 |
+
update_unet_ip_adapter_cross_attn_param(
|
220 |
+
unet, ip_adapter_state_dict["ip_adapter"]
|
221 |
+
)
|
222 |
+
logger.info(
|
223 |
+
f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
|
224 |
+
)
|
225 |
+
return (
|
226 |
+
vision_clip_extractor,
|
227 |
+
ip_adapter_image_proj,
|
228 |
+
)
|
229 |
+
|
230 |
+
|
231 |
+
# refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
|
232 |
+
unet_keys_list = [
|
233 |
+
"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
234 |
+
"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
235 |
+
"down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
236 |
+
"down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
237 |
+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
238 |
+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
239 |
+
"down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
240 |
+
"down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
241 |
+
"down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
242 |
+
"down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
243 |
+
"down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
244 |
+
"down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
245 |
+
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
246 |
+
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
247 |
+
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
248 |
+
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
249 |
+
"up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
250 |
+
"up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
251 |
+
"up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
252 |
+
"up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
253 |
+
"up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
254 |
+
"up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
255 |
+
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
256 |
+
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
257 |
+
"up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
258 |
+
"up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
259 |
+
"up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
260 |
+
"up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
261 |
+
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
262 |
+
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
263 |
+
"mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
|
264 |
+
"mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
|
265 |
+
]
|
266 |
+
|
267 |
+
|
268 |
+
ip_adapter_keys_list = [
|
269 |
+
"1.to_k_ip.weight",
|
270 |
+
"1.to_v_ip.weight",
|
271 |
+
"3.to_k_ip.weight",
|
272 |
+
"3.to_v_ip.weight",
|
273 |
+
"5.to_k_ip.weight",
|
274 |
+
"5.to_v_ip.weight",
|
275 |
+
"7.to_k_ip.weight",
|
276 |
+
"7.to_v_ip.weight",
|
277 |
+
"9.to_k_ip.weight",
|
278 |
+
"9.to_v_ip.weight",
|
279 |
+
"11.to_k_ip.weight",
|
280 |
+
"11.to_v_ip.weight",
|
281 |
+
"13.to_k_ip.weight",
|
282 |
+
"13.to_v_ip.weight",
|
283 |
+
"15.to_k_ip.weight",
|
284 |
+
"15.to_v_ip.weight",
|
285 |
+
"17.to_k_ip.weight",
|
286 |
+
"17.to_v_ip.weight",
|
287 |
+
"19.to_k_ip.weight",
|
288 |
+
"19.to_v_ip.weight",
|
289 |
+
"21.to_k_ip.weight",
|
290 |
+
"21.to_v_ip.weight",
|
291 |
+
"23.to_k_ip.weight",
|
292 |
+
"23.to_v_ip.weight",
|
293 |
+
"25.to_k_ip.weight",
|
294 |
+
"25.to_v_ip.weight",
|
295 |
+
"27.to_k_ip.weight",
|
296 |
+
"27.to_v_ip.weight",
|
297 |
+
"29.to_k_ip.weight",
|
298 |
+
"29.to_v_ip.weight",
|
299 |
+
"31.to_k_ip.weight",
|
300 |
+
"31.to_v_ip.weight",
|
301 |
+
]
|
302 |
+
|
303 |
+
UNET2IPAadapter_Keys_MAPIING = {
|
304 |
+
k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
|
305 |
+
}
|
306 |
+
|
307 |
+
|
308 |
+
def update_unet_ip_adapter_cross_attn_param(
|
309 |
+
unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
|
310 |
+
) -> None:
|
311 |
+
"""use independent ip_adapter attn 中的 to_k, to_v in unet
|
312 |
+
ip_adapter: dict whose keys are ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']
|
313 |
+
|
314 |
+
|
315 |
+
Args:
|
316 |
+
unet (UNet3DConditionModel): _description_
|
317 |
+
ip_adapter_state_dict (Dict): _description_
|
318 |
+
"""
|
319 |
+
unet_spatial_cross_atnns = unet.spatial_cross_attns[0]
|
320 |
+
unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns}
|
321 |
+
for i, (unet_key_more, ip_adapter_key) in enumerate(
|
322 |
+
UNET2IPAadapter_Keys_MAPIING.items()
|
323 |
+
):
|
324 |
+
ip_adapter_value = ip_adapter_state_dict[ip_adapter_key]
|
325 |
+
unet_key_more_spit = unet_key_more.split(".")
|
326 |
+
unet_key = ".".join(unet_key_more_spit[:-3])
|
327 |
+
suffix = ".".join(unet_key_more_spit[-3:])
|
328 |
+
logger.debug(
|
329 |
+
f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}",
|
330 |
+
)
|
331 |
+
if "to_k" in suffix:
|
332 |
+
with torch.no_grad():
|
333 |
+
unet_spatial_cross_atnns_dct[unet_key].to_k_ip.weight.copy_(
|
334 |
+
ip_adapter_value.data
|
335 |
+
)
|
336 |
+
else:
|
337 |
+
with torch.no_grad():
|
338 |
+
unet_spatial_cross_atnns_dct[unet_key].to_v_ip.weight.copy_(
|
339 |
+
ip_adapter_value.data
|
340 |
+
)
|
musev/models/referencenet.py
ADDED
@@ -0,0 +1,1216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
18 |
+
import logging
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from diffusers.models.attention_processor import Attention, AttnProcessor
|
22 |
+
from einops import rearrange, repeat
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
import xformers
|
26 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
27 |
+
from diffusers.models.unet_2d_condition import (
|
28 |
+
UNet2DConditionModel,
|
29 |
+
UNet2DConditionOutput,
|
30 |
+
)
|
31 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
32 |
+
from diffusers.utils.constants import USE_PEFT_BACKEND
|
33 |
+
from diffusers.utils.deprecation_utils import deprecate
|
34 |
+
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers
|
35 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
36 |
+
from diffusers.models.modeling_utils import ModelMixin, load_state_dict
|
37 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
38 |
+
from diffusers.utils import (
|
39 |
+
USE_PEFT_BACKEND,
|
40 |
+
BaseOutput,
|
41 |
+
deprecate,
|
42 |
+
scale_lora_layers,
|
43 |
+
unscale_lora_layers,
|
44 |
+
)
|
45 |
+
from diffusers.models.activations import get_activation
|
46 |
+
from diffusers.models.attention_processor import (
|
47 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
48 |
+
CROSS_ATTENTION_PROCESSORS,
|
49 |
+
AttentionProcessor,
|
50 |
+
AttnAddedKVProcessor,
|
51 |
+
AttnProcessor,
|
52 |
+
)
|
53 |
+
from diffusers.models.embeddings import (
|
54 |
+
GaussianFourierProjection,
|
55 |
+
ImageHintTimeEmbedding,
|
56 |
+
ImageProjection,
|
57 |
+
ImageTimeEmbedding,
|
58 |
+
PositionNet,
|
59 |
+
TextImageProjection,
|
60 |
+
TextImageTimeEmbedding,
|
61 |
+
TextTimeEmbedding,
|
62 |
+
TimestepEmbedding,
|
63 |
+
Timesteps,
|
64 |
+
)
|
65 |
+
from diffusers.models.modeling_utils import ModelMixin
|
66 |
+
|
67 |
+
|
68 |
+
from ..data.data_util import align_repeat_tensor_single_dim
|
69 |
+
from .unet_3d_condition import UNet3DConditionModel
|
70 |
+
from .attention import BasicTransformerBlock, IPAttention
|
71 |
+
from .unet_2d_blocks import (
|
72 |
+
UNetMidBlock2D,
|
73 |
+
UNetMidBlock2DCrossAttn,
|
74 |
+
UNetMidBlock2DSimpleCrossAttn,
|
75 |
+
get_down_block,
|
76 |
+
get_up_block,
|
77 |
+
)
|
78 |
+
|
79 |
+
from . import Model_Register
|
80 |
+
|
81 |
+
|
82 |
+
logger = logging.getLogger(__name__)
|
83 |
+
|
84 |
+
|
85 |
+
@Model_Register.register
|
86 |
+
class ReferenceNet2D(UNet2DConditionModel, nn.Module):
|
87 |
+
"""继承 UNet2DConditionModel. 新增功能,类似controlnet 返回模型中间特征,用于后续作用
|
88 |
+
Inherit Unet2DConditionModel. Add new functions, similar to controlnet, return the intermediate features of the model for subsequent effects
|
89 |
+
Args:
|
90 |
+
UNet2DConditionModel (_type_): _description_
|
91 |
+
"""
|
92 |
+
|
93 |
+
_supports_gradient_checkpointing = True
|
94 |
+
print_idx = 0
|
95 |
+
|
96 |
+
@register_to_config
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
sample_size: int | None = None,
|
100 |
+
in_channels: int = 4,
|
101 |
+
out_channels: int = 4,
|
102 |
+
center_input_sample: bool = False,
|
103 |
+
flip_sin_to_cos: bool = True,
|
104 |
+
freq_shift: int = 0,
|
105 |
+
down_block_types: Tuple[str] = (
|
106 |
+
"CrossAttnDownBlock2D",
|
107 |
+
"CrossAttnDownBlock2D",
|
108 |
+
"CrossAttnDownBlock2D",
|
109 |
+
"DownBlock2D",
|
110 |
+
),
|
111 |
+
mid_block_type: str | None = "UNetMidBlock2DCrossAttn",
|
112 |
+
up_block_types: Tuple[str] = (
|
113 |
+
"UpBlock2D",
|
114 |
+
"CrossAttnUpBlock2D",
|
115 |
+
"CrossAttnUpBlock2D",
|
116 |
+
"CrossAttnUpBlock2D",
|
117 |
+
),
|
118 |
+
only_cross_attention: bool | Tuple[bool] = False,
|
119 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
120 |
+
layers_per_block: int | Tuple[int] = 2,
|
121 |
+
downsample_padding: int = 1,
|
122 |
+
mid_block_scale_factor: float = 1,
|
123 |
+
dropout: float = 0,
|
124 |
+
act_fn: str = "silu",
|
125 |
+
norm_num_groups: int | None = 32,
|
126 |
+
norm_eps: float = 0.00001,
|
127 |
+
cross_attention_dim: int | Tuple[int] = 1280,
|
128 |
+
transformer_layers_per_block: int | Tuple[int] | Tuple[Tuple] = 1,
|
129 |
+
reverse_transformer_layers_per_block: Tuple[Tuple[int]] | None = None,
|
130 |
+
encoder_hid_dim: int | None = None,
|
131 |
+
encoder_hid_dim_type: str | None = None,
|
132 |
+
attention_head_dim: int | Tuple[int] = 8,
|
133 |
+
num_attention_heads: int | Tuple[int] | None = None,
|
134 |
+
dual_cross_attention: bool = False,
|
135 |
+
use_linear_projection: bool = False,
|
136 |
+
class_embed_type: str | None = None,
|
137 |
+
addition_embed_type: str | None = None,
|
138 |
+
addition_time_embed_dim: int | None = None,
|
139 |
+
num_class_embeds: int | None = None,
|
140 |
+
upcast_attention: bool = False,
|
141 |
+
resnet_time_scale_shift: str = "default",
|
142 |
+
resnet_skip_time_act: bool = False,
|
143 |
+
resnet_out_scale_factor: int = 1,
|
144 |
+
time_embedding_type: str = "positional",
|
145 |
+
time_embedding_dim: int | None = None,
|
146 |
+
time_embedding_act_fn: str | None = None,
|
147 |
+
timestep_post_act: str | None = None,
|
148 |
+
time_cond_proj_dim: int | None = None,
|
149 |
+
conv_in_kernel: int = 3,
|
150 |
+
conv_out_kernel: int = 3,
|
151 |
+
projection_class_embeddings_input_dim: int | None = None,
|
152 |
+
attention_type: str = "default",
|
153 |
+
class_embeddings_concat: bool = False,
|
154 |
+
mid_block_only_cross_attention: bool | None = None,
|
155 |
+
cross_attention_norm: str | None = None,
|
156 |
+
addition_embed_type_num_heads=64,
|
157 |
+
need_self_attn_block_embs: bool = False,
|
158 |
+
need_block_embs: bool = False,
|
159 |
+
):
|
160 |
+
super().__init__()
|
161 |
+
|
162 |
+
self.sample_size = sample_size
|
163 |
+
|
164 |
+
if num_attention_heads is not None:
|
165 |
+
raise ValueError(
|
166 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
167 |
+
)
|
168 |
+
|
169 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
170 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
171 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
172 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
173 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
174 |
+
# which is why we correct for the naming here.
|
175 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
176 |
+
|
177 |
+
# Check inputs
|
178 |
+
if len(down_block_types) != len(up_block_types):
|
179 |
+
raise ValueError(
|
180 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
181 |
+
)
|
182 |
+
|
183 |
+
if len(block_out_channels) != len(down_block_types):
|
184 |
+
raise ValueError(
|
185 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
186 |
+
)
|
187 |
+
|
188 |
+
if not isinstance(only_cross_attention, bool) and len(
|
189 |
+
only_cross_attention
|
190 |
+
) != len(down_block_types):
|
191 |
+
raise ValueError(
|
192 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
193 |
+
)
|
194 |
+
|
195 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
|
196 |
+
down_block_types
|
197 |
+
):
|
198 |
+
raise ValueError(
|
199 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
200 |
+
)
|
201 |
+
|
202 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
|
203 |
+
down_block_types
|
204 |
+
):
|
205 |
+
raise ValueError(
|
206 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
207 |
+
)
|
208 |
+
|
209 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
|
210 |
+
down_block_types
|
211 |
+
):
|
212 |
+
raise ValueError(
|
213 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
214 |
+
)
|
215 |
+
|
216 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
|
217 |
+
down_block_types
|
218 |
+
):
|
219 |
+
raise ValueError(
|
220 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
221 |
+
)
|
222 |
+
if (
|
223 |
+
isinstance(transformer_layers_per_block, list)
|
224 |
+
and reverse_transformer_layers_per_block is None
|
225 |
+
):
|
226 |
+
for layer_number_per_block in transformer_layers_per_block:
|
227 |
+
if isinstance(layer_number_per_block, list):
|
228 |
+
raise ValueError(
|
229 |
+
"Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
|
230 |
+
)
|
231 |
+
|
232 |
+
# input
|
233 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
234 |
+
self.conv_in = nn.Conv2d(
|
235 |
+
in_channels,
|
236 |
+
block_out_channels[0],
|
237 |
+
kernel_size=conv_in_kernel,
|
238 |
+
padding=conv_in_padding,
|
239 |
+
)
|
240 |
+
|
241 |
+
# time
|
242 |
+
if time_embedding_type == "fourier":
|
243 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
244 |
+
if time_embed_dim % 2 != 0:
|
245 |
+
raise ValueError(
|
246 |
+
f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
|
247 |
+
)
|
248 |
+
self.time_proj = GaussianFourierProjection(
|
249 |
+
time_embed_dim // 2,
|
250 |
+
set_W_to_weight=False,
|
251 |
+
log=False,
|
252 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
253 |
+
)
|
254 |
+
timestep_input_dim = time_embed_dim
|
255 |
+
elif time_embedding_type == "positional":
|
256 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
257 |
+
|
258 |
+
self.time_proj = Timesteps(
|
259 |
+
block_out_channels[0], flip_sin_to_cos, freq_shift
|
260 |
+
)
|
261 |
+
timestep_input_dim = block_out_channels[0]
|
262 |
+
else:
|
263 |
+
raise ValueError(
|
264 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
265 |
+
)
|
266 |
+
|
267 |
+
self.time_embedding = TimestepEmbedding(
|
268 |
+
timestep_input_dim,
|
269 |
+
time_embed_dim,
|
270 |
+
act_fn=act_fn,
|
271 |
+
post_act_fn=timestep_post_act,
|
272 |
+
cond_proj_dim=time_cond_proj_dim,
|
273 |
+
)
|
274 |
+
|
275 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
276 |
+
encoder_hid_dim_type = "text_proj"
|
277 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
278 |
+
logger.info(
|
279 |
+
"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
|
280 |
+
)
|
281 |
+
|
282 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
283 |
+
raise ValueError(
|
284 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
285 |
+
)
|
286 |
+
|
287 |
+
if encoder_hid_dim_type == "text_proj":
|
288 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
289 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
290 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
291 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
292 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
293 |
+
self.encoder_hid_proj = TextImageProjection(
|
294 |
+
text_embed_dim=encoder_hid_dim,
|
295 |
+
image_embed_dim=cross_attention_dim,
|
296 |
+
cross_attention_dim=cross_attention_dim,
|
297 |
+
)
|
298 |
+
elif encoder_hid_dim_type == "image_proj":
|
299 |
+
# Kandinsky 2.2
|
300 |
+
self.encoder_hid_proj = ImageProjection(
|
301 |
+
image_embed_dim=encoder_hid_dim,
|
302 |
+
cross_attention_dim=cross_attention_dim,
|
303 |
+
)
|
304 |
+
elif encoder_hid_dim_type is not None:
|
305 |
+
raise ValueError(
|
306 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
307 |
+
)
|
308 |
+
else:
|
309 |
+
self.encoder_hid_proj = None
|
310 |
+
|
311 |
+
# class embedding
|
312 |
+
if class_embed_type is None and num_class_embeds is not None:
|
313 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
314 |
+
elif class_embed_type == "timestep":
|
315 |
+
self.class_embedding = TimestepEmbedding(
|
316 |
+
timestep_input_dim, time_embed_dim, act_fn=act_fn
|
317 |
+
)
|
318 |
+
elif class_embed_type == "identity":
|
319 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
320 |
+
elif class_embed_type == "projection":
|
321 |
+
if projection_class_embeddings_input_dim is None:
|
322 |
+
raise ValueError(
|
323 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
324 |
+
)
|
325 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
326 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
327 |
+
# 2. it projects from an arbitrary input dimension.
|
328 |
+
#
|
329 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
330 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
331 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
332 |
+
self.class_embedding = TimestepEmbedding(
|
333 |
+
projection_class_embeddings_input_dim, time_embed_dim
|
334 |
+
)
|
335 |
+
elif class_embed_type == "simple_projection":
|
336 |
+
if projection_class_embeddings_input_dim is None:
|
337 |
+
raise ValueError(
|
338 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
339 |
+
)
|
340 |
+
self.class_embedding = nn.Linear(
|
341 |
+
projection_class_embeddings_input_dim, time_embed_dim
|
342 |
+
)
|
343 |
+
else:
|
344 |
+
self.class_embedding = None
|
345 |
+
|
346 |
+
if addition_embed_type == "text":
|
347 |
+
if encoder_hid_dim is not None:
|
348 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
349 |
+
else:
|
350 |
+
text_time_embedding_from_dim = cross_attention_dim
|
351 |
+
|
352 |
+
self.add_embedding = TextTimeEmbedding(
|
353 |
+
text_time_embedding_from_dim,
|
354 |
+
time_embed_dim,
|
355 |
+
num_heads=addition_embed_type_num_heads,
|
356 |
+
)
|
357 |
+
elif addition_embed_type == "text_image":
|
358 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
359 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
360 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
361 |
+
self.add_embedding = TextImageTimeEmbedding(
|
362 |
+
text_embed_dim=cross_attention_dim,
|
363 |
+
image_embed_dim=cross_attention_dim,
|
364 |
+
time_embed_dim=time_embed_dim,
|
365 |
+
)
|
366 |
+
elif addition_embed_type == "text_time":
|
367 |
+
self.add_time_proj = Timesteps(
|
368 |
+
addition_time_embed_dim, flip_sin_to_cos, freq_shift
|
369 |
+
)
|
370 |
+
self.add_embedding = TimestepEmbedding(
|
371 |
+
projection_class_embeddings_input_dim, time_embed_dim
|
372 |
+
)
|
373 |
+
elif addition_embed_type == "image":
|
374 |
+
# Kandinsky 2.2
|
375 |
+
self.add_embedding = ImageTimeEmbedding(
|
376 |
+
image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
|
377 |
+
)
|
378 |
+
elif addition_embed_type == "image_hint":
|
379 |
+
# Kandinsky 2.2 ControlNet
|
380 |
+
self.add_embedding = ImageHintTimeEmbedding(
|
381 |
+
image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
|
382 |
+
)
|
383 |
+
elif addition_embed_type is not None:
|
384 |
+
raise ValueError(
|
385 |
+
f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
|
386 |
+
)
|
387 |
+
|
388 |
+
if time_embedding_act_fn is None:
|
389 |
+
self.time_embed_act = None
|
390 |
+
else:
|
391 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
392 |
+
|
393 |
+
self.down_blocks = nn.ModuleList([])
|
394 |
+
self.up_blocks = nn.ModuleList([])
|
395 |
+
|
396 |
+
if isinstance(only_cross_attention, bool):
|
397 |
+
if mid_block_only_cross_attention is None:
|
398 |
+
mid_block_only_cross_attention = only_cross_attention
|
399 |
+
|
400 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
401 |
+
|
402 |
+
if mid_block_only_cross_attention is None:
|
403 |
+
mid_block_only_cross_attention = False
|
404 |
+
|
405 |
+
if isinstance(num_attention_heads, int):
|
406 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
407 |
+
|
408 |
+
if isinstance(attention_head_dim, int):
|
409 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
410 |
+
|
411 |
+
if isinstance(cross_attention_dim, int):
|
412 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
413 |
+
|
414 |
+
if isinstance(layers_per_block, int):
|
415 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
416 |
+
|
417 |
+
if isinstance(transformer_layers_per_block, int):
|
418 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(
|
419 |
+
down_block_types
|
420 |
+
)
|
421 |
+
|
422 |
+
if class_embeddings_concat:
|
423 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
424 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
425 |
+
# regular time embeddings
|
426 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
427 |
+
else:
|
428 |
+
blocks_time_embed_dim = time_embed_dim
|
429 |
+
|
430 |
+
# down
|
431 |
+
output_channel = block_out_channels[0]
|
432 |
+
for i, down_block_type in enumerate(down_block_types):
|
433 |
+
input_channel = output_channel
|
434 |
+
output_channel = block_out_channels[i]
|
435 |
+
is_final_block = i == len(block_out_channels) - 1
|
436 |
+
|
437 |
+
down_block = get_down_block(
|
438 |
+
down_block_type,
|
439 |
+
num_layers=layers_per_block[i],
|
440 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
441 |
+
in_channels=input_channel,
|
442 |
+
out_channels=output_channel,
|
443 |
+
temb_channels=blocks_time_embed_dim,
|
444 |
+
add_downsample=not is_final_block,
|
445 |
+
resnet_eps=norm_eps,
|
446 |
+
resnet_act_fn=act_fn,
|
447 |
+
resnet_groups=norm_num_groups,
|
448 |
+
cross_attention_dim=cross_attention_dim[i],
|
449 |
+
num_attention_heads=num_attention_heads[i],
|
450 |
+
downsample_padding=downsample_padding,
|
451 |
+
dual_cross_attention=dual_cross_attention,
|
452 |
+
use_linear_projection=use_linear_projection,
|
453 |
+
only_cross_attention=only_cross_attention[i],
|
454 |
+
upcast_attention=upcast_attention,
|
455 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
456 |
+
attention_type=attention_type,
|
457 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
458 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
459 |
+
cross_attention_norm=cross_attention_norm,
|
460 |
+
attention_head_dim=attention_head_dim[i]
|
461 |
+
if attention_head_dim[i] is not None
|
462 |
+
else output_channel,
|
463 |
+
dropout=dropout,
|
464 |
+
)
|
465 |
+
self.down_blocks.append(down_block)
|
466 |
+
|
467 |
+
# mid
|
468 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
469 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
470 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
471 |
+
in_channels=block_out_channels[-1],
|
472 |
+
temb_channels=blocks_time_embed_dim,
|
473 |
+
dropout=dropout,
|
474 |
+
resnet_eps=norm_eps,
|
475 |
+
resnet_act_fn=act_fn,
|
476 |
+
output_scale_factor=mid_block_scale_factor,
|
477 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
478 |
+
cross_attention_dim=cross_attention_dim[-1],
|
479 |
+
num_attention_heads=num_attention_heads[-1],
|
480 |
+
resnet_groups=norm_num_groups,
|
481 |
+
dual_cross_attention=dual_cross_attention,
|
482 |
+
use_linear_projection=use_linear_projection,
|
483 |
+
upcast_attention=upcast_attention,
|
484 |
+
attention_type=attention_type,
|
485 |
+
)
|
486 |
+
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
487 |
+
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
488 |
+
in_channels=block_out_channels[-1],
|
489 |
+
temb_channels=blocks_time_embed_dim,
|
490 |
+
dropout=dropout,
|
491 |
+
resnet_eps=norm_eps,
|
492 |
+
resnet_act_fn=act_fn,
|
493 |
+
output_scale_factor=mid_block_scale_factor,
|
494 |
+
cross_attention_dim=cross_attention_dim[-1],
|
495 |
+
attention_head_dim=attention_head_dim[-1],
|
496 |
+
resnet_groups=norm_num_groups,
|
497 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
498 |
+
skip_time_act=resnet_skip_time_act,
|
499 |
+
only_cross_attention=mid_block_only_cross_attention,
|
500 |
+
cross_attention_norm=cross_attention_norm,
|
501 |
+
)
|
502 |
+
elif mid_block_type == "UNetMidBlock2D":
|
503 |
+
self.mid_block = UNetMidBlock2D(
|
504 |
+
in_channels=block_out_channels[-1],
|
505 |
+
temb_channels=blocks_time_embed_dim,
|
506 |
+
dropout=dropout,
|
507 |
+
num_layers=0,
|
508 |
+
resnet_eps=norm_eps,
|
509 |
+
resnet_act_fn=act_fn,
|
510 |
+
output_scale_factor=mid_block_scale_factor,
|
511 |
+
resnet_groups=norm_num_groups,
|
512 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
513 |
+
add_attention=False,
|
514 |
+
)
|
515 |
+
elif mid_block_type is None:
|
516 |
+
self.mid_block = None
|
517 |
+
else:
|
518 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
519 |
+
|
520 |
+
# count how many layers upsample the images
|
521 |
+
self.num_upsamplers = 0
|
522 |
+
|
523 |
+
# up
|
524 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
525 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
526 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
527 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
528 |
+
reversed_transformer_layers_per_block = (
|
529 |
+
list(reversed(transformer_layers_per_block))
|
530 |
+
if reverse_transformer_layers_per_block is None
|
531 |
+
else reverse_transformer_layers_per_block
|
532 |
+
)
|
533 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
534 |
+
|
535 |
+
output_channel = reversed_block_out_channels[0]
|
536 |
+
for i, up_block_type in enumerate(up_block_types):
|
537 |
+
is_final_block = i == len(block_out_channels) - 1
|
538 |
+
|
539 |
+
prev_output_channel = output_channel
|
540 |
+
output_channel = reversed_block_out_channels[i]
|
541 |
+
input_channel = reversed_block_out_channels[
|
542 |
+
min(i + 1, len(block_out_channels) - 1)
|
543 |
+
]
|
544 |
+
|
545 |
+
# add upsample block for all BUT final layer
|
546 |
+
if not is_final_block:
|
547 |
+
add_upsample = True
|
548 |
+
self.num_upsamplers += 1
|
549 |
+
else:
|
550 |
+
add_upsample = False
|
551 |
+
|
552 |
+
up_block = get_up_block(
|
553 |
+
up_block_type,
|
554 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
555 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
556 |
+
in_channels=input_channel,
|
557 |
+
out_channels=output_channel,
|
558 |
+
prev_output_channel=prev_output_channel,
|
559 |
+
temb_channels=blocks_time_embed_dim,
|
560 |
+
add_upsample=add_upsample,
|
561 |
+
resnet_eps=norm_eps,
|
562 |
+
resnet_act_fn=act_fn,
|
563 |
+
resolution_idx=i,
|
564 |
+
resnet_groups=norm_num_groups,
|
565 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
566 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
567 |
+
dual_cross_attention=dual_cross_attention,
|
568 |
+
use_linear_projection=use_linear_projection,
|
569 |
+
only_cross_attention=only_cross_attention[i],
|
570 |
+
upcast_attention=upcast_attention,
|
571 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
572 |
+
attention_type=attention_type,
|
573 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
574 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
575 |
+
cross_attention_norm=cross_attention_norm,
|
576 |
+
attention_head_dim=attention_head_dim[i]
|
577 |
+
if attention_head_dim[i] is not None
|
578 |
+
else output_channel,
|
579 |
+
dropout=dropout,
|
580 |
+
)
|
581 |
+
self.up_blocks.append(up_block)
|
582 |
+
prev_output_channel = output_channel
|
583 |
+
|
584 |
+
# out
|
585 |
+
if norm_num_groups is not None:
|
586 |
+
self.conv_norm_out = nn.GroupNorm(
|
587 |
+
num_channels=block_out_channels[0],
|
588 |
+
num_groups=norm_num_groups,
|
589 |
+
eps=norm_eps,
|
590 |
+
)
|
591 |
+
|
592 |
+
self.conv_act = get_activation(act_fn)
|
593 |
+
|
594 |
+
else:
|
595 |
+
self.conv_norm_out = None
|
596 |
+
self.conv_act = None
|
597 |
+
|
598 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
599 |
+
self.conv_out = nn.Conv2d(
|
600 |
+
block_out_channels[0],
|
601 |
+
out_channels,
|
602 |
+
kernel_size=conv_out_kernel,
|
603 |
+
padding=conv_out_padding,
|
604 |
+
)
|
605 |
+
|
606 |
+
if attention_type in ["gated", "gated-text-image"]:
|
607 |
+
positive_len = 768
|
608 |
+
if isinstance(cross_attention_dim, int):
|
609 |
+
positive_len = cross_attention_dim
|
610 |
+
elif isinstance(cross_attention_dim, tuple) or isinstance(
|
611 |
+
cross_attention_dim, list
|
612 |
+
):
|
613 |
+
positive_len = cross_attention_dim[0]
|
614 |
+
|
615 |
+
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
616 |
+
self.position_net = PositionNet(
|
617 |
+
positive_len=positive_len,
|
618 |
+
out_dim=cross_attention_dim,
|
619 |
+
feature_type=feature_type,
|
620 |
+
)
|
621 |
+
self.need_block_embs = need_block_embs
|
622 |
+
self.need_self_attn_block_embs = need_self_attn_block_embs
|
623 |
+
|
624 |
+
# only use referencenet soma layers, other layers set None
|
625 |
+
self.conv_norm_out = None
|
626 |
+
self.conv_act = None
|
627 |
+
self.conv_out = None
|
628 |
+
|
629 |
+
self.up_blocks[-1].attentions[-1].proj_out = None
|
630 |
+
self.up_blocks[-1].attentions[-1].transformer_blocks[-1].attn1 = None
|
631 |
+
self.up_blocks[-1].attentions[-1].transformer_blocks[-1].attn2 = None
|
632 |
+
self.up_blocks[-1].attentions[-1].transformer_blocks[-1].norm2 = None
|
633 |
+
self.up_blocks[-1].attentions[-1].transformer_blocks[-1].ff = None
|
634 |
+
self.up_blocks[-1].attentions[-1].transformer_blocks[-1].norm3 = None
|
635 |
+
if not self.need_self_attn_block_embs:
|
636 |
+
self.up_blocks = None
|
637 |
+
|
638 |
+
self.insert_spatial_self_attn_idx()
|
639 |
+
|
640 |
+
def forward(
|
641 |
+
self,
|
642 |
+
sample: torch.FloatTensor,
|
643 |
+
timestep: Union[torch.Tensor, float, int],
|
644 |
+
encoder_hidden_states: torch.Tensor,
|
645 |
+
class_labels: Optional[torch.Tensor] = None,
|
646 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
647 |
+
attention_mask: Optional[torch.Tensor] = None,
|
648 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
649 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
650 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
651 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
652 |
+
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
653 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
654 |
+
return_dict: bool = True,
|
655 |
+
# update new paramestes start
|
656 |
+
num_frames: int = None,
|
657 |
+
return_ndim: int = 5,
|
658 |
+
# update new paramestes end
|
659 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
660 |
+
r"""
|
661 |
+
The [`UNet2DConditionModel`] forward method.
|
662 |
+
|
663 |
+
Args:
|
664 |
+
sample (`torch.FloatTensor`):
|
665 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
666 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
667 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
668 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
669 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
670 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
671 |
+
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
672 |
+
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
673 |
+
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
674 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
675 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
676 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
677 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
678 |
+
cross_attention_kwargs (`dict`, *optional*):
|
679 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
680 |
+
`self.processor` in
|
681 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
682 |
+
added_cond_kwargs: (`dict`, *optional*):
|
683 |
+
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
684 |
+
are passed along to the UNet blocks.
|
685 |
+
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
686 |
+
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
687 |
+
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
688 |
+
A tensor that if specified is added to the residual of the middle unet block.
|
689 |
+
encoder_attention_mask (`torch.Tensor`):
|
690 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
691 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
692 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
693 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
694 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
695 |
+
tuple.
|
696 |
+
cross_attention_kwargs (`dict`, *optional*):
|
697 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
698 |
+
added_cond_kwargs: (`dict`, *optional*):
|
699 |
+
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
700 |
+
are passed along to the UNet blocks.
|
701 |
+
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
702 |
+
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
703 |
+
example from ControlNet side model(s)
|
704 |
+
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
705 |
+
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
706 |
+
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
707 |
+
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
708 |
+
|
709 |
+
Returns:
|
710 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
711 |
+
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
712 |
+
a `tuple` is returned where the first element is the sample tensor.
|
713 |
+
"""
|
714 |
+
|
715 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
716 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
717 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
718 |
+
# on the fly if necessary.
|
719 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
720 |
+
|
721 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
722 |
+
forward_upsample_size = False
|
723 |
+
upsample_size = None
|
724 |
+
|
725 |
+
for dim in sample.shape[-2:]:
|
726 |
+
if dim % default_overall_up_factor != 0:
|
727 |
+
# Forward upsample size to force interpolation output size.
|
728 |
+
forward_upsample_size = True
|
729 |
+
break
|
730 |
+
|
731 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
732 |
+
# expects mask of shape:
|
733 |
+
# [batch, key_tokens]
|
734 |
+
# adds singleton query_tokens dimension:
|
735 |
+
# [batch, 1, key_tokens]
|
736 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
737 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
738 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
739 |
+
if attention_mask is not None:
|
740 |
+
# assume that mask is expressed as:
|
741 |
+
# (1 = keep, 0 = discard)
|
742 |
+
# convert mask into a bias that can be added to attention scores:
|
743 |
+
# (keep = +0, discard = -10000.0)
|
744 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
745 |
+
attention_mask = attention_mask.unsqueeze(1)
|
746 |
+
|
747 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
748 |
+
if encoder_attention_mask is not None:
|
749 |
+
encoder_attention_mask = (
|
750 |
+
1 - encoder_attention_mask.to(sample.dtype)
|
751 |
+
) * -10000.0
|
752 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
753 |
+
|
754 |
+
# 0. center input if necessary
|
755 |
+
if self.config.center_input_sample:
|
756 |
+
sample = 2 * sample - 1.0
|
757 |
+
|
758 |
+
# 1. time
|
759 |
+
timesteps = timestep
|
760 |
+
if not torch.is_tensor(timesteps):
|
761 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
762 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
763 |
+
is_mps = sample.device.type == "mps"
|
764 |
+
if isinstance(timestep, float):
|
765 |
+
dtype = torch.float32 if is_mps else torch.float64
|
766 |
+
else:
|
767 |
+
dtype = torch.int32 if is_mps else torch.int64
|
768 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
769 |
+
elif len(timesteps.shape) == 0:
|
770 |
+
timesteps = timesteps[None].to(sample.device)
|
771 |
+
|
772 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
773 |
+
timesteps = timesteps.expand(sample.shape[0])
|
774 |
+
|
775 |
+
t_emb = self.time_proj(timesteps)
|
776 |
+
|
777 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
778 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
779 |
+
# there might be better ways to encapsulate this.
|
780 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
781 |
+
|
782 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
783 |
+
aug_emb = None
|
784 |
+
|
785 |
+
if self.class_embedding is not None:
|
786 |
+
if class_labels is None:
|
787 |
+
raise ValueError(
|
788 |
+
"class_labels should be provided when num_class_embeds > 0"
|
789 |
+
)
|
790 |
+
|
791 |
+
if self.config.class_embed_type == "timestep":
|
792 |
+
class_labels = self.time_proj(class_labels)
|
793 |
+
|
794 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
795 |
+
# there might be better ways to encapsulate this.
|
796 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
797 |
+
|
798 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
799 |
+
|
800 |
+
if self.config.class_embeddings_concat:
|
801 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
802 |
+
else:
|
803 |
+
emb = emb + class_emb
|
804 |
+
|
805 |
+
if self.config.addition_embed_type == "text":
|
806 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
807 |
+
elif self.config.addition_embed_type == "text_image":
|
808 |
+
# Kandinsky 2.1 - style
|
809 |
+
if "image_embeds" not in added_cond_kwargs:
|
810 |
+
raise ValueError(
|
811 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
812 |
+
)
|
813 |
+
|
814 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
815 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
816 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
817 |
+
elif self.config.addition_embed_type == "text_time":
|
818 |
+
# SDXL - style
|
819 |
+
if "text_embeds" not in added_cond_kwargs:
|
820 |
+
raise ValueError(
|
821 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
822 |
+
)
|
823 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
824 |
+
if "time_ids" not in added_cond_kwargs:
|
825 |
+
raise ValueError(
|
826 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
827 |
+
)
|
828 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
829 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
830 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
831 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
832 |
+
add_embeds = add_embeds.to(emb.dtype)
|
833 |
+
aug_emb = self.add_embedding(add_embeds)
|
834 |
+
elif self.config.addition_embed_type == "image":
|
835 |
+
# Kandinsky 2.2 - style
|
836 |
+
if "image_embeds" not in added_cond_kwargs:
|
837 |
+
raise ValueError(
|
838 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
839 |
+
)
|
840 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
841 |
+
aug_emb = self.add_embedding(image_embs)
|
842 |
+
elif self.config.addition_embed_type == "image_hint":
|
843 |
+
# Kandinsky 2.2 - style
|
844 |
+
if (
|
845 |
+
"image_embeds" not in added_cond_kwargs
|
846 |
+
or "hint" not in added_cond_kwargs
|
847 |
+
):
|
848 |
+
raise ValueError(
|
849 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
850 |
+
)
|
851 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
852 |
+
hint = added_cond_kwargs.get("hint")
|
853 |
+
aug_emb, hint = self.add_embedding(image_embs, hint)
|
854 |
+
sample = torch.cat([sample, hint], dim=1)
|
855 |
+
|
856 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
857 |
+
|
858 |
+
if self.time_embed_act is not None:
|
859 |
+
emb = self.time_embed_act(emb)
|
860 |
+
|
861 |
+
if (
|
862 |
+
self.encoder_hid_proj is not None
|
863 |
+
and self.config.encoder_hid_dim_type == "text_proj"
|
864 |
+
):
|
865 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
866 |
+
elif (
|
867 |
+
self.encoder_hid_proj is not None
|
868 |
+
and self.config.encoder_hid_dim_type == "text_image_proj"
|
869 |
+
):
|
870 |
+
# Kadinsky 2.1 - style
|
871 |
+
if "image_embeds" not in added_cond_kwargs:
|
872 |
+
raise ValueError(
|
873 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
874 |
+
)
|
875 |
+
|
876 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
877 |
+
encoder_hidden_states = self.encoder_hid_proj(
|
878 |
+
encoder_hidden_states, image_embeds
|
879 |
+
)
|
880 |
+
elif (
|
881 |
+
self.encoder_hid_proj is not None
|
882 |
+
and self.config.encoder_hid_dim_type == "image_proj"
|
883 |
+
):
|
884 |
+
# Kandinsky 2.2 - style
|
885 |
+
if "image_embeds" not in added_cond_kwargs:
|
886 |
+
raise ValueError(
|
887 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
888 |
+
)
|
889 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
890 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
891 |
+
elif (
|
892 |
+
self.encoder_hid_proj is not None
|
893 |
+
and self.config.encoder_hid_dim_type == "ip_image_proj"
|
894 |
+
):
|
895 |
+
if "image_embeds" not in added_cond_kwargs:
|
896 |
+
raise ValueError(
|
897 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
898 |
+
)
|
899 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
900 |
+
image_embeds = self.encoder_hid_proj(image_embeds).to(
|
901 |
+
encoder_hidden_states.dtype
|
902 |
+
)
|
903 |
+
encoder_hidden_states = torch.cat(
|
904 |
+
[encoder_hidden_states, image_embeds], dim=1
|
905 |
+
)
|
906 |
+
|
907 |
+
# need_self_attn_block_embs
|
908 |
+
# 初始化
|
909 |
+
# 或在unet中运算中会不断 append self_attn_blocks_embs,用完需要清理,
|
910 |
+
if self.need_self_attn_block_embs:
|
911 |
+
self_attn_block_embs = [None] * self.self_attn_num
|
912 |
+
else:
|
913 |
+
self_attn_block_embs = None
|
914 |
+
# 2. pre-process
|
915 |
+
sample = self.conv_in(sample)
|
916 |
+
if self.print_idx == 0:
|
917 |
+
logger.debug(f"after conv in sample={sample.mean()}")
|
918 |
+
# 2.5 GLIGEN position net
|
919 |
+
if (
|
920 |
+
cross_attention_kwargs is not None
|
921 |
+
and cross_attention_kwargs.get("gligen", None) is not None
|
922 |
+
):
|
923 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
924 |
+
gligen_args = cross_attention_kwargs.pop("gligen")
|
925 |
+
cross_attention_kwargs["gligen"] = {
|
926 |
+
"objs": self.position_net(**gligen_args)
|
927 |
+
}
|
928 |
+
|
929 |
+
# 3. down
|
930 |
+
lora_scale = (
|
931 |
+
cross_attention_kwargs.get("scale", 1.0)
|
932 |
+
if cross_attention_kwargs is not None
|
933 |
+
else 1.0
|
934 |
+
)
|
935 |
+
if USE_PEFT_BACKEND:
|
936 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
937 |
+
scale_lora_layers(self, lora_scale)
|
938 |
+
|
939 |
+
is_controlnet = (
|
940 |
+
mid_block_additional_residual is not None
|
941 |
+
and down_block_additional_residuals is not None
|
942 |
+
)
|
943 |
+
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
944 |
+
is_adapter = down_intrablock_additional_residuals is not None
|
945 |
+
# maintain backward compatibility for legacy usage, where
|
946 |
+
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
947 |
+
# but can only use one or the other
|
948 |
+
if (
|
949 |
+
not is_adapter
|
950 |
+
and mid_block_additional_residual is None
|
951 |
+
and down_block_additional_residuals is not None
|
952 |
+
):
|
953 |
+
deprecate(
|
954 |
+
"T2I should not use down_block_additional_residuals",
|
955 |
+
"1.3.0",
|
956 |
+
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
957 |
+
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
958 |
+
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
959 |
+
standard_warn=False,
|
960 |
+
)
|
961 |
+
down_intrablock_additional_residuals = down_block_additional_residuals
|
962 |
+
is_adapter = True
|
963 |
+
|
964 |
+
down_block_res_samples = (sample,)
|
965 |
+
for i_downsample_block, downsample_block in enumerate(self.down_blocks):
|
966 |
+
if (
|
967 |
+
hasattr(downsample_block, "has_cross_attention")
|
968 |
+
and downsample_block.has_cross_attention
|
969 |
+
):
|
970 |
+
# For t2i-adapter CrossAttnDownBlock2D
|
971 |
+
additional_residuals = {}
|
972 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
973 |
+
additional_residuals[
|
974 |
+
"additional_residuals"
|
975 |
+
] = down_intrablock_additional_residuals.pop(0)
|
976 |
+
if self.print_idx == 0:
|
977 |
+
logger.debug(
|
978 |
+
f"downsample_block {i_downsample_block} sample={sample.mean()}"
|
979 |
+
)
|
980 |
+
sample, res_samples = downsample_block(
|
981 |
+
hidden_states=sample,
|
982 |
+
temb=emb,
|
983 |
+
encoder_hidden_states=encoder_hidden_states,
|
984 |
+
attention_mask=attention_mask,
|
985 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
986 |
+
encoder_attention_mask=encoder_attention_mask,
|
987 |
+
**additional_residuals,
|
988 |
+
self_attn_block_embs=self_attn_block_embs,
|
989 |
+
)
|
990 |
+
else:
|
991 |
+
sample, res_samples = downsample_block(
|
992 |
+
hidden_states=sample,
|
993 |
+
temb=emb,
|
994 |
+
scale=lora_scale,
|
995 |
+
self_attn_block_embs=self_attn_block_embs,
|
996 |
+
)
|
997 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
998 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
999 |
+
|
1000 |
+
down_block_res_samples += res_samples
|
1001 |
+
|
1002 |
+
if is_controlnet:
|
1003 |
+
new_down_block_res_samples = ()
|
1004 |
+
|
1005 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
1006 |
+
down_block_res_samples, down_block_additional_residuals
|
1007 |
+
):
|
1008 |
+
down_block_res_sample = (
|
1009 |
+
down_block_res_sample + down_block_additional_residual
|
1010 |
+
)
|
1011 |
+
new_down_block_res_samples = new_down_block_res_samples + (
|
1012 |
+
down_block_res_sample,
|
1013 |
+
)
|
1014 |
+
|
1015 |
+
down_block_res_samples = new_down_block_res_samples
|
1016 |
+
|
1017 |
+
# update code start
|
1018 |
+
def reshape_return_emb(tmp_emb):
|
1019 |
+
if return_ndim == 4:
|
1020 |
+
return tmp_emb
|
1021 |
+
elif return_ndim == 5:
|
1022 |
+
return rearrange(tmp_emb, "(b t) c h w-> b c t h w", t=num_frames)
|
1023 |
+
else:
|
1024 |
+
raise ValueError(
|
1025 |
+
f"reshape_emb only support 4, 5 but given {return_ndim}"
|
1026 |
+
)
|
1027 |
+
|
1028 |
+
if self.need_block_embs:
|
1029 |
+
return_down_block_res_samples = [
|
1030 |
+
reshape_return_emb(tmp_emb) for tmp_emb in down_block_res_samples
|
1031 |
+
]
|
1032 |
+
else:
|
1033 |
+
return_down_block_res_samples = None
|
1034 |
+
# update code end
|
1035 |
+
|
1036 |
+
# 4. mid
|
1037 |
+
if self.mid_block is not None:
|
1038 |
+
if (
|
1039 |
+
hasattr(self.mid_block, "has_cross_attention")
|
1040 |
+
and self.mid_block.has_cross_attention
|
1041 |
+
):
|
1042 |
+
sample = self.mid_block(
|
1043 |
+
sample,
|
1044 |
+
emb,
|
1045 |
+
encoder_hidden_states=encoder_hidden_states,
|
1046 |
+
attention_mask=attention_mask,
|
1047 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1048 |
+
encoder_attention_mask=encoder_attention_mask,
|
1049 |
+
self_attn_block_embs=self_attn_block_embs,
|
1050 |
+
)
|
1051 |
+
else:
|
1052 |
+
sample = self.mid_block(sample, emb)
|
1053 |
+
|
1054 |
+
# To support T2I-Adapter-XL
|
1055 |
+
if (
|
1056 |
+
is_adapter
|
1057 |
+
and len(down_intrablock_additional_residuals) > 0
|
1058 |
+
and sample.shape == down_intrablock_additional_residuals[0].shape
|
1059 |
+
):
|
1060 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
1061 |
+
|
1062 |
+
if is_controlnet:
|
1063 |
+
sample = sample + mid_block_additional_residual
|
1064 |
+
|
1065 |
+
if self.need_block_embs:
|
1066 |
+
return_mid_block_res_samples = reshape_return_emb(sample)
|
1067 |
+
logger.debug(
|
1068 |
+
f"return_mid_block_res_samples, is_leaf={return_mid_block_res_samples.is_leaf}, requires_grad={return_mid_block_res_samples.requires_grad}"
|
1069 |
+
)
|
1070 |
+
else:
|
1071 |
+
return_mid_block_res_samples = None
|
1072 |
+
|
1073 |
+
if self.up_blocks is not None:
|
1074 |
+
# update code end
|
1075 |
+
|
1076 |
+
# 5. up
|
1077 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
1078 |
+
is_final_block = i == len(self.up_blocks) - 1
|
1079 |
+
|
1080 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1081 |
+
down_block_res_samples = down_block_res_samples[
|
1082 |
+
: -len(upsample_block.resnets)
|
1083 |
+
]
|
1084 |
+
|
1085 |
+
# if we have not reached the final block and need to forward the
|
1086 |
+
# upsample size, we do it here
|
1087 |
+
if not is_final_block and forward_upsample_size:
|
1088 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1089 |
+
|
1090 |
+
if (
|
1091 |
+
hasattr(upsample_block, "has_cross_attention")
|
1092 |
+
and upsample_block.has_cross_attention
|
1093 |
+
):
|
1094 |
+
sample = upsample_block(
|
1095 |
+
hidden_states=sample,
|
1096 |
+
temb=emb,
|
1097 |
+
res_hidden_states_tuple=res_samples,
|
1098 |
+
encoder_hidden_states=encoder_hidden_states,
|
1099 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1100 |
+
upsample_size=upsample_size,
|
1101 |
+
attention_mask=attention_mask,
|
1102 |
+
encoder_attention_mask=encoder_attention_mask,
|
1103 |
+
self_attn_block_embs=self_attn_block_embs,
|
1104 |
+
)
|
1105 |
+
else:
|
1106 |
+
sample = upsample_block(
|
1107 |
+
hidden_states=sample,
|
1108 |
+
temb=emb,
|
1109 |
+
res_hidden_states_tuple=res_samples,
|
1110 |
+
upsample_size=upsample_size,
|
1111 |
+
scale=lora_scale,
|
1112 |
+
self_attn_block_embs=self_attn_block_embs,
|
1113 |
+
)
|
1114 |
+
|
1115 |
+
# update code start
|
1116 |
+
if self.need_block_embs or self.need_self_attn_block_embs:
|
1117 |
+
if self_attn_block_embs is not None:
|
1118 |
+
self_attn_block_embs = [
|
1119 |
+
reshape_return_emb(tmp_emb=tmp_emb)
|
1120 |
+
for tmp_emb in self_attn_block_embs
|
1121 |
+
]
|
1122 |
+
self.print_idx += 1
|
1123 |
+
return (
|
1124 |
+
return_down_block_res_samples,
|
1125 |
+
return_mid_block_res_samples,
|
1126 |
+
self_attn_block_embs,
|
1127 |
+
)
|
1128 |
+
|
1129 |
+
if not self.need_block_embs and not self.need_self_attn_block_embs:
|
1130 |
+
# 6. post-process
|
1131 |
+
if self.conv_norm_out:
|
1132 |
+
sample = self.conv_norm_out(sample)
|
1133 |
+
sample = self.conv_act(sample)
|
1134 |
+
sample = self.conv_out(sample)
|
1135 |
+
|
1136 |
+
if USE_PEFT_BACKEND:
|
1137 |
+
# remove `lora_scale` from each PEFT layer
|
1138 |
+
unscale_lora_layers(self, lora_scale)
|
1139 |
+
self.print_idx += 1
|
1140 |
+
if not return_dict:
|
1141 |
+
return (sample,)
|
1142 |
+
|
1143 |
+
return UNet2DConditionOutput(sample=sample)
|
1144 |
+
|
1145 |
+
def insert_spatial_self_attn_idx(self):
|
1146 |
+
attns, basic_transformers = self.spatial_self_attns
|
1147 |
+
self.self_attn_num = len(attns)
|
1148 |
+
for i, (name, layer) in enumerate(attns):
|
1149 |
+
logger.debug(f"{self.__class__.__name__}, {i}, {name}, {type(layer)}")
|
1150 |
+
if layer is not None:
|
1151 |
+
layer.spatial_self_attn_idx = i
|
1152 |
+
for i, (name, layer) in enumerate(basic_transformers):
|
1153 |
+
logger.debug(f"{self.__class__.__name__}, {i}, {name}, {type(layer)}")
|
1154 |
+
if layer is not None:
|
1155 |
+
layer.spatial_self_attn_idx = i
|
1156 |
+
|
1157 |
+
@property
|
1158 |
+
def spatial_self_attns(
|
1159 |
+
self,
|
1160 |
+
) -> List[Tuple[str, Attention]]:
|
1161 |
+
attns, spatial_transformers = self.get_self_attns(
|
1162 |
+
include="attentions", exclude="temp_attentions"
|
1163 |
+
)
|
1164 |
+
attns = sorted(attns)
|
1165 |
+
spatial_transformers = sorted(spatial_transformers)
|
1166 |
+
return attns, spatial_transformers
|
1167 |
+
|
1168 |
+
def get_self_attns(
|
1169 |
+
self, include: str = None, exclude: str = None
|
1170 |
+
) -> List[Tuple[str, Attention]]:
|
1171 |
+
r"""
|
1172 |
+
Returns:
|
1173 |
+
`dict` of attention attns: A dictionary containing all attention attns used in the model with
|
1174 |
+
indexed by its weight name.
|
1175 |
+
"""
|
1176 |
+
# set recursively
|
1177 |
+
attns = []
|
1178 |
+
spatial_transformers = []
|
1179 |
+
|
1180 |
+
def fn_recursive_add_attns(
|
1181 |
+
name: str,
|
1182 |
+
module: torch.nn.Module,
|
1183 |
+
attns: List[Tuple[str, Attention]],
|
1184 |
+
spatial_transformers: List[Tuple[str, BasicTransformerBlock]],
|
1185 |
+
):
|
1186 |
+
is_target = False
|
1187 |
+
if isinstance(module, BasicTransformerBlock) and hasattr(module, "attn1"):
|
1188 |
+
is_target = True
|
1189 |
+
if include is not None:
|
1190 |
+
is_target = include in name
|
1191 |
+
if exclude is not None:
|
1192 |
+
is_target = exclude not in name
|
1193 |
+
if is_target:
|
1194 |
+
attns.append([f"{name}.attn1", module.attn1])
|
1195 |
+
spatial_transformers.append([f"{name}", module])
|
1196 |
+
for sub_name, child in module.named_children():
|
1197 |
+
fn_recursive_add_attns(
|
1198 |
+
f"{name}.{sub_name}", child, attns, spatial_transformers
|
1199 |
+
)
|
1200 |
+
|
1201 |
+
return attns
|
1202 |
+
|
1203 |
+
for name, module in self.named_children():
|
1204 |
+
fn_recursive_add_attns(name, module, attns, spatial_transformers)
|
1205 |
+
|
1206 |
+
return attns, spatial_transformers
|
1207 |
+
|
1208 |
+
|
1209 |
+
class ReferenceNet3D(UNet3DConditionModel):
|
1210 |
+
"""继承 UNet3DConditionModel, 用于提取中间emb用于后续作用。
|
1211 |
+
Inherit Unet3DConditionModel, used to extract the middle emb for subsequent actions.
|
1212 |
+
Args:
|
1213 |
+
UNet3DConditionModel (_type_): _description_
|
1214 |
+
"""
|
1215 |
+
|
1216 |
+
pass
|
musev/models/referencenet_loader.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Any, Callable, Dict, Iterable, Union
|
3 |
+
import PIL
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import datetime
|
8 |
+
import logging
|
9 |
+
import inspect
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import shutil
|
13 |
+
from typing import Dict, List, Optional, Tuple
|
14 |
+
from pprint import pprint
|
15 |
+
from collections import OrderedDict
|
16 |
+
from dataclasses import dataclass
|
17 |
+
import gc
|
18 |
+
import time
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from omegaconf import SCMode
|
23 |
+
import torch
|
24 |
+
from torch import nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from einops import rearrange, repeat
|
28 |
+
import pandas as pd
|
29 |
+
import h5py
|
30 |
+
from diffusers.models.modeling_utils import load_state_dict
|
31 |
+
from diffusers.utils import (
|
32 |
+
logging,
|
33 |
+
)
|
34 |
+
from diffusers.utils.import_utils import is_xformers_available
|
35 |
+
|
36 |
+
from .referencenet import ReferenceNet2D
|
37 |
+
from .unet_loader import update_unet_with_sd
|
38 |
+
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
41 |
+
|
42 |
+
|
43 |
+
def load_referencenet(
|
44 |
+
sd_referencenet_model: Tuple[str, nn.Module],
|
45 |
+
sd_model: nn.Module = None,
|
46 |
+
need_self_attn_block_embs: bool = False,
|
47 |
+
need_block_embs: bool = False,
|
48 |
+
dtype: torch.dtype = torch.float16,
|
49 |
+
cross_attention_dim: int = 768,
|
50 |
+
subfolder: str = "unet",
|
51 |
+
):
|
52 |
+
"""
|
53 |
+
Loads the ReferenceNet model.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
sd_referencenet_model (Tuple[str, nn.Module] or str): The pretrained ReferenceNet model or the path to the model.
|
57 |
+
sd_model (nn.Module, optional): The sd_model to update the ReferenceNet with. Defaults to None.
|
58 |
+
need_self_attn_block_embs (bool, optional): Whether to compute self-attention block embeddings. Defaults to False.
|
59 |
+
need_block_embs (bool, optional): Whether to compute block embeddings. Defaults to False.
|
60 |
+
dtype (torch.dtype, optional): The data type of the tensors. Defaults to torch.float16.
|
61 |
+
cross_attention_dim (int, optional): The dimension of the cross-attention. Defaults to 768.
|
62 |
+
subfolder (str, optional): The subfolder of the model. Defaults to "unet".
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
nn.Module: The loaded ReferenceNet model.
|
66 |
+
"""
|
67 |
+
|
68 |
+
if isinstance(sd_referencenet_model, str):
|
69 |
+
referencenet = ReferenceNet2D.from_pretrained(
|
70 |
+
sd_referencenet_model,
|
71 |
+
subfolder=subfolder,
|
72 |
+
need_self_attn_block_embs=need_self_attn_block_embs,
|
73 |
+
need_block_embs=need_block_embs,
|
74 |
+
torch_dtype=dtype,
|
75 |
+
cross_attention_dim=cross_attention_dim,
|
76 |
+
)
|
77 |
+
elif isinstance(sd_referencenet_model, nn.Module):
|
78 |
+
referencenet = sd_referencenet_model
|
79 |
+
if sd_model is not None:
|
80 |
+
referencenet = update_unet_with_sd(referencenet, sd_model)
|
81 |
+
return referencenet
|
82 |
+
|
83 |
+
|
84 |
+
def load_referencenet_by_name(
|
85 |
+
model_name: str,
|
86 |
+
sd_referencenet_model: Tuple[str, nn.Module],
|
87 |
+
sd_model: nn.Module = None,
|
88 |
+
cross_attention_dim: int = 768,
|
89 |
+
dtype: torch.dtype = torch.float16,
|
90 |
+
) -> nn.Module:
|
91 |
+
"""通过模型名字 初始化 referencenet,载入预训练参数,
|
92 |
+
如希望后续通过简单名字就可以使用预训练模型,需要在这里完成定义
|
93 |
+
init referencenet with model_name.
|
94 |
+
if you want to use pretrained model with simple name, you need to define it here.
|
95 |
+
Args:
|
96 |
+
model_name (str): _description_
|
97 |
+
sd_unet_model (Tuple[str, nn.Module]): _description_
|
98 |
+
sd_model (Tuple[str, nn.Module]): _description_
|
99 |
+
cross_attention_dim (int, optional): _description_. Defaults to 768.
|
100 |
+
dtype (torch.dtype, optional): _description_. Defaults to torch.float16.
|
101 |
+
|
102 |
+
Raises:
|
103 |
+
ValueError: _description_
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
nn.Module: _description_
|
107 |
+
"""
|
108 |
+
if model_name in [
|
109 |
+
"musev_referencenet",
|
110 |
+
]:
|
111 |
+
unet = load_referencenet(
|
112 |
+
sd_referencenet_model=sd_referencenet_model,
|
113 |
+
sd_model=sd_model,
|
114 |
+
cross_attention_dim=cross_attention_dim,
|
115 |
+
dtype=dtype,
|
116 |
+
need_self_attn_block_embs=False,
|
117 |
+
need_block_embs=True,
|
118 |
+
subfolder="referencenet",
|
119 |
+
)
|
120 |
+
else:
|
121 |
+
raise ValueError(
|
122 |
+
f"unsupport model_name={model_name}, only support ReferenceNet_V0_block13, ReferenceNet_V1_block13, ReferenceNet_V2_block13, ReferenceNet_V0_sefattn16"
|
123 |
+
)
|
124 |
+
return unet
|
musev/models/resnet.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py
|
17 |
+
from __future__ import annotations
|
18 |
+
|
19 |
+
from functools import partial
|
20 |
+
from typing import Optional
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
from einops import rearrange, repeat
|
26 |
+
|
27 |
+
from diffusers.models.resnet import TemporalConvLayer as DiffusersTemporalConvLayer
|
28 |
+
from ..data.data_util import batch_index_fill, batch_index_select
|
29 |
+
from . import Model_Register
|
30 |
+
|
31 |
+
|
32 |
+
@Model_Register.register
|
33 |
+
class TemporalConvLayer(nn.Module):
|
34 |
+
"""
|
35 |
+
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
|
36 |
+
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
in_dim,
|
42 |
+
out_dim=None,
|
43 |
+
dropout=0.0,
|
44 |
+
keep_content_condition: bool = False,
|
45 |
+
femb_channels: Optional[int] = None,
|
46 |
+
need_temporal_weight: bool = True,
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
out_dim = out_dim or in_dim
|
50 |
+
self.in_dim = in_dim
|
51 |
+
self.out_dim = out_dim
|
52 |
+
self.keep_content_condition = keep_content_condition
|
53 |
+
self.femb_channels = femb_channels
|
54 |
+
self.need_temporal_weight = need_temporal_weight
|
55 |
+
# conv layers
|
56 |
+
self.conv1 = nn.Sequential(
|
57 |
+
nn.GroupNorm(32, in_dim),
|
58 |
+
nn.SiLU(),
|
59 |
+
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
|
60 |
+
)
|
61 |
+
self.conv2 = nn.Sequential(
|
62 |
+
nn.GroupNorm(32, out_dim),
|
63 |
+
nn.SiLU(),
|
64 |
+
nn.Dropout(dropout),
|
65 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
66 |
+
)
|
67 |
+
self.conv3 = nn.Sequential(
|
68 |
+
nn.GroupNorm(32, out_dim),
|
69 |
+
nn.SiLU(),
|
70 |
+
nn.Dropout(dropout),
|
71 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
72 |
+
)
|
73 |
+
self.conv4 = nn.Sequential(
|
74 |
+
nn.GroupNorm(32, out_dim),
|
75 |
+
nn.SiLU(),
|
76 |
+
nn.Dropout(dropout),
|
77 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
78 |
+
)
|
79 |
+
|
80 |
+
# zero out the last layer params,so the conv block is identity
|
81 |
+
# nn.init.zeros_(self.conv4[-1].weight)
|
82 |
+
# nn.init.zeros_(self.conv4[-1].bias)
|
83 |
+
self.temporal_weight = nn.Parameter(
|
84 |
+
torch.tensor(
|
85 |
+
[
|
86 |
+
1e-5,
|
87 |
+
]
|
88 |
+
)
|
89 |
+
) # initialize parameter with 0
|
90 |
+
# zero out the last layer params,so the conv block is identity
|
91 |
+
nn.init.zeros_(self.conv4[-1].weight)
|
92 |
+
nn.init.zeros_(self.conv4[-1].bias)
|
93 |
+
self.skip_temporal_layers = False # Whether to skip temporal layer
|
94 |
+
|
95 |
+
def forward(
|
96 |
+
self,
|
97 |
+
hidden_states,
|
98 |
+
num_frames=1,
|
99 |
+
sample_index: torch.LongTensor = None,
|
100 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
101 |
+
femb: torch.Tensor = None,
|
102 |
+
):
|
103 |
+
if self.skip_temporal_layers is True:
|
104 |
+
return hidden_states
|
105 |
+
hidden_states_dtype = hidden_states.dtype
|
106 |
+
hidden_states = rearrange(
|
107 |
+
hidden_states, "(b t) c h w -> b c t h w", t=num_frames
|
108 |
+
)
|
109 |
+
identity = hidden_states
|
110 |
+
hidden_states = self.conv1(hidden_states)
|
111 |
+
hidden_states = self.conv2(hidden_states)
|
112 |
+
hidden_states = self.conv3(hidden_states)
|
113 |
+
hidden_states = self.conv4(hidden_states)
|
114 |
+
# 保留condition对应的frames,便于保持前序内容帧,提升一致性
|
115 |
+
if self.keep_content_condition:
|
116 |
+
mask = torch.ones_like(hidden_states, device=hidden_states.device)
|
117 |
+
mask = batch_index_fill(
|
118 |
+
mask, dim=2, index=vision_conditon_frames_sample_index, value=0
|
119 |
+
)
|
120 |
+
if self.need_temporal_weight:
|
121 |
+
hidden_states = (
|
122 |
+
identity + torch.abs(self.temporal_weight) * mask * hidden_states
|
123 |
+
)
|
124 |
+
else:
|
125 |
+
hidden_states = identity + mask * hidden_states
|
126 |
+
else:
|
127 |
+
if self.need_temporal_weight:
|
128 |
+
hidden_states = (
|
129 |
+
identity + torch.abs(self.temporal_weight) * hidden_states
|
130 |
+
)
|
131 |
+
else:
|
132 |
+
hidden_states = identity + hidden_states
|
133 |
+
hidden_states = rearrange(hidden_states, " b c t h w -> (b t) c h w")
|
134 |
+
hidden_states = hidden_states.to(dtype=hidden_states_dtype)
|
135 |
+
return hidden_states
|
musev/models/super_model.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import logging
|
4 |
+
|
5 |
+
from typing import Any, Dict, Tuple, Union, Optional
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from torch import nn
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
11 |
+
from diffusers.models.modeling_utils import ModelMixin, load_state_dict
|
12 |
+
|
13 |
+
from ..data.data_util import align_repeat_tensor_single_dim
|
14 |
+
|
15 |
+
from .unet_3d_condition import UNet3DConditionModel
|
16 |
+
from .referencenet import ReferenceNet2D
|
17 |
+
from ip_adapter.ip_adapter import ImageProjModel
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class SuperUNet3DConditionModel(nn.Module):
|
23 |
+
"""封装了各种子模型的超模型,与 diffusers 的 pipeline 很像,只不过这里是模型定义。
|
24 |
+
主要作用
|
25 |
+
1. 将支持controlnet、referencenet等功能的计算封装起来,简洁些;
|
26 |
+
2. 便于 accelerator 的分布式训练;
|
27 |
+
|
28 |
+
wrap the sub-models, such as unet, referencenet, controlnet, vae, text_encoder, tokenizer, text_emb_extractor, clip_vision_extractor, ip_adapter_image_proj
|
29 |
+
1. support controlnet, referencenet, etc.
|
30 |
+
2. support accelerator distributed training
|
31 |
+
"""
|
32 |
+
|
33 |
+
_supports_gradient_checkpointing = True
|
34 |
+
print_idx = 0
|
35 |
+
|
36 |
+
# @register_to_config
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
unet: nn.Module,
|
40 |
+
referencenet: nn.Module = None,
|
41 |
+
controlnet: nn.Module = None,
|
42 |
+
vae: nn.Module = None,
|
43 |
+
text_encoder: nn.Module = None,
|
44 |
+
tokenizer: nn.Module = None,
|
45 |
+
text_emb_extractor: nn.Module = None,
|
46 |
+
clip_vision_extractor: nn.Module = None,
|
47 |
+
ip_adapter_image_proj: nn.Module = None,
|
48 |
+
) -> None:
|
49 |
+
"""_summary_
|
50 |
+
|
51 |
+
Args:
|
52 |
+
unet (nn.Module): _description_
|
53 |
+
referencenet (nn.Module, optional): _description_. Defaults to None.
|
54 |
+
controlnet (nn.Module, optional): _description_. Defaults to None.
|
55 |
+
vae (nn.Module, optional): _description_. Defaults to None.
|
56 |
+
text_encoder (nn.Module, optional): _description_. Defaults to None.
|
57 |
+
tokenizer (nn.Module, optional): _description_. Defaults to None.
|
58 |
+
text_emb_extractor (nn.Module, optional): wrap text_encoder and tokenizer for str2emb. Defaults to None.
|
59 |
+
clip_vision_extractor (nn.Module, optional): _description_. Defaults to None.
|
60 |
+
"""
|
61 |
+
super().__init__()
|
62 |
+
self.unet = unet
|
63 |
+
self.referencenet = referencenet
|
64 |
+
self.controlnet = controlnet
|
65 |
+
self.vae = vae
|
66 |
+
self.text_encoder = text_encoder
|
67 |
+
self.tokenizer = tokenizer
|
68 |
+
self.text_emb_extractor = text_emb_extractor
|
69 |
+
self.clip_vision_extractor = clip_vision_extractor
|
70 |
+
self.ip_adapter_image_proj = ip_adapter_image_proj
|
71 |
+
|
72 |
+
def forward(
|
73 |
+
self,
|
74 |
+
unet_params: Dict,
|
75 |
+
encoder_hidden_states: torch.Tensor,
|
76 |
+
referencenet_params: Dict = None,
|
77 |
+
controlnet_params: Dict = None,
|
78 |
+
controlnet_scale: float = 1.0,
|
79 |
+
vision_clip_emb: Union[torch.Tensor, None] = None,
|
80 |
+
prompt_only_use_image_prompt: bool = False,
|
81 |
+
):
|
82 |
+
"""_summary_
|
83 |
+
|
84 |
+
Args:
|
85 |
+
unet_params (Dict): _description_
|
86 |
+
encoder_hidden_states (torch.Tensor): b t n d
|
87 |
+
referencenet_params (Dict, optional): _description_. Defaults to None.
|
88 |
+
controlnet_params (Dict, optional): _description_. Defaults to None.
|
89 |
+
controlnet_scale (float, optional): _description_. Defaults to 1.0.
|
90 |
+
vision_clip_emb (Union[torch.Tensor, None], optional): b t d. Defaults to None.
|
91 |
+
prompt_only_use_image_prompt (bool, optional): _description_. Defaults to False.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
_type_: _description_
|
95 |
+
"""
|
96 |
+
batch_size = unet_params["sample"].shape[0]
|
97 |
+
time_size = unet_params["sample"].shape[2]
|
98 |
+
|
99 |
+
# ip_adapter_cross_attn, prepare image prompt
|
100 |
+
if vision_clip_emb is not None:
|
101 |
+
# b t n d -> b t n d
|
102 |
+
if self.print_idx == 0:
|
103 |
+
logger.debug(
|
104 |
+
f"vision_clip_emb, before ip_adapter_image_proj, shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
|
105 |
+
)
|
106 |
+
if vision_clip_emb.ndim == 3:
|
107 |
+
vision_clip_emb = rearrange(vision_clip_emb, "b t d-> b t 1 d")
|
108 |
+
if self.ip_adapter_image_proj is not None:
|
109 |
+
vision_clip_emb = rearrange(vision_clip_emb, "b t n d ->(b t) n d")
|
110 |
+
vision_clip_emb = self.ip_adapter_image_proj(vision_clip_emb)
|
111 |
+
if self.print_idx == 0:
|
112 |
+
logger.debug(
|
113 |
+
f"vision_clip_emb, after ip_adapter_image_proj shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
|
114 |
+
)
|
115 |
+
if vision_clip_emb.ndim == 2:
|
116 |
+
vision_clip_emb = rearrange(vision_clip_emb, "b d-> b 1 d")
|
117 |
+
vision_clip_emb = rearrange(
|
118 |
+
vision_clip_emb, "(b t) n d -> b t n d", b=batch_size
|
119 |
+
)
|
120 |
+
vision_clip_emb = align_repeat_tensor_single_dim(
|
121 |
+
vision_clip_emb, target_length=time_size, dim=1
|
122 |
+
)
|
123 |
+
if self.print_idx == 0:
|
124 |
+
logger.debug(
|
125 |
+
f"vision_clip_emb, after reshape shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
|
126 |
+
)
|
127 |
+
|
128 |
+
if vision_clip_emb is None and encoder_hidden_states is not None:
|
129 |
+
vision_clip_emb = encoder_hidden_states
|
130 |
+
if vision_clip_emb is not None and encoder_hidden_states is None:
|
131 |
+
encoder_hidden_states = vision_clip_emb
|
132 |
+
# 当 prompt_only_use_image_prompt 为True时,
|
133 |
+
# 1. referencenet 都使用 vision_clip_emb
|
134 |
+
# 2. unet 如果没有dual_cross_attn,使用vision_clip_emb,有时不更新
|
135 |
+
# 3. controlnet 当前使用 text_prompt
|
136 |
+
|
137 |
+
# when prompt_only_use_image_prompt True,
|
138 |
+
# 1. referencenet use vision_clip_emb
|
139 |
+
# 2. unet use vision_clip_emb if no dual_cross_attn, sometimes not update
|
140 |
+
# 3. controlnet use text_prompt
|
141 |
+
|
142 |
+
# extract referencenet emb
|
143 |
+
if self.referencenet is not None and referencenet_params is not None:
|
144 |
+
referencenet_encoder_hidden_states = align_repeat_tensor_single_dim(
|
145 |
+
vision_clip_emb,
|
146 |
+
target_length=referencenet_params["num_frames"],
|
147 |
+
dim=1,
|
148 |
+
)
|
149 |
+
referencenet_params["encoder_hidden_states"] = rearrange(
|
150 |
+
referencenet_encoder_hidden_states, "b t n d->(b t) n d"
|
151 |
+
)
|
152 |
+
referencenet_out = self.referencenet(**referencenet_params)
|
153 |
+
(
|
154 |
+
down_block_refer_embs,
|
155 |
+
mid_block_refer_emb,
|
156 |
+
refer_self_attn_emb,
|
157 |
+
) = referencenet_out
|
158 |
+
if down_block_refer_embs is not None:
|
159 |
+
if self.print_idx == 0:
|
160 |
+
logger.debug(
|
161 |
+
f"len(down_block_refer_embs)={len(down_block_refer_embs)}"
|
162 |
+
)
|
163 |
+
for i, down_emb in enumerate(down_block_refer_embs):
|
164 |
+
if self.print_idx == 0:
|
165 |
+
logger.debug(
|
166 |
+
f"down_emb, {i}, {down_emb.shape}, mean={down_emb.mean()}"
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
if self.print_idx == 0:
|
170 |
+
logger.debug(f"down_block_refer_embs is None")
|
171 |
+
if mid_block_refer_emb is not None:
|
172 |
+
if self.print_idx == 0:
|
173 |
+
logger.debug(
|
174 |
+
f"mid_block_refer_emb, {mid_block_refer_emb.shape}, mean={mid_block_refer_emb.mean()}"
|
175 |
+
)
|
176 |
+
else:
|
177 |
+
if self.print_idx == 0:
|
178 |
+
logger.debug(f"mid_block_refer_emb is None")
|
179 |
+
if refer_self_attn_emb is not None:
|
180 |
+
if self.print_idx == 0:
|
181 |
+
logger.debug(f"refer_self_attn_emb, num={len(refer_self_attn_emb)}")
|
182 |
+
for i, self_attn_emb in enumerate(refer_self_attn_emb):
|
183 |
+
if self.print_idx == 0:
|
184 |
+
logger.debug(
|
185 |
+
f"referencenet, self_attn_emb, {i}th, shape={self_attn_emb.shape}, mean={self_attn_emb.mean()}"
|
186 |
+
)
|
187 |
+
else:
|
188 |
+
if self.print_idx == 0:
|
189 |
+
logger.debug(f"refer_self_attn_emb is None")
|
190 |
+
else:
|
191 |
+
down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb = (
|
192 |
+
None,
|
193 |
+
None,
|
194 |
+
None,
|
195 |
+
)
|
196 |
+
|
197 |
+
# extract controlnet emb
|
198 |
+
if self.controlnet is not None and controlnet_params is not None:
|
199 |
+
controlnet_encoder_hidden_states = align_repeat_tensor_single_dim(
|
200 |
+
encoder_hidden_states,
|
201 |
+
target_length=unet_params["sample"].shape[2],
|
202 |
+
dim=1,
|
203 |
+
)
|
204 |
+
controlnet_params["encoder_hidden_states"] = rearrange(
|
205 |
+
controlnet_encoder_hidden_states, " b t n d -> (b t) n d"
|
206 |
+
)
|
207 |
+
(
|
208 |
+
down_block_additional_residuals,
|
209 |
+
mid_block_additional_residual,
|
210 |
+
) = self.controlnet(**controlnet_params)
|
211 |
+
if controlnet_scale != 1.0:
|
212 |
+
down_block_additional_residuals = [
|
213 |
+
x * controlnet_scale for x in down_block_additional_residuals
|
214 |
+
]
|
215 |
+
mid_block_additional_residual = (
|
216 |
+
mid_block_additional_residual * controlnet_scale
|
217 |
+
)
|
218 |
+
for i, down_block_additional_residual in enumerate(
|
219 |
+
down_block_additional_residuals
|
220 |
+
):
|
221 |
+
if self.print_idx == 0:
|
222 |
+
logger.debug(
|
223 |
+
f"{i}, down_block_additional_residual mean={torch.mean(down_block_additional_residual)}"
|
224 |
+
)
|
225 |
+
|
226 |
+
if self.print_idx == 0:
|
227 |
+
logger.debug(
|
228 |
+
f"mid_block_additional_residual mean={torch.mean(mid_block_additional_residual)}"
|
229 |
+
)
|
230 |
+
else:
|
231 |
+
down_block_additional_residuals = None
|
232 |
+
mid_block_additional_residual = None
|
233 |
+
|
234 |
+
if prompt_only_use_image_prompt and vision_clip_emb is not None:
|
235 |
+
encoder_hidden_states = vision_clip_emb
|
236 |
+
|
237 |
+
# run unet
|
238 |
+
out = self.unet(
|
239 |
+
**unet_params,
|
240 |
+
down_block_refer_embs=down_block_refer_embs,
|
241 |
+
mid_block_refer_emb=mid_block_refer_emb,
|
242 |
+
refer_self_attn_emb=refer_self_attn_emb,
|
243 |
+
down_block_additional_residuals=down_block_additional_residuals,
|
244 |
+
mid_block_additional_residual=mid_block_additional_residual,
|
245 |
+
encoder_hidden_states=encoder_hidden_states,
|
246 |
+
vision_clip_emb=vision_clip_emb,
|
247 |
+
)
|
248 |
+
self.print_idx += 1
|
249 |
+
return out
|
250 |
+
|
251 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
252 |
+
if isinstance(module, (UNet3DConditionModel, ReferenceNet2D)):
|
253 |
+
module.gradient_checkpointing = value
|
musev/models/temporal_transformer.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/transformer_temporal.py
|
16 |
+
from __future__ import annotations
|
17 |
+
from copy import deepcopy
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import List, Literal, Optional
|
20 |
+
import logging
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch import nn
|
24 |
+
from einops import rearrange, repeat
|
25 |
+
|
26 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
27 |
+
from diffusers.utils import BaseOutput
|
28 |
+
from diffusers.models.modeling_utils import ModelMixin
|
29 |
+
from diffusers.models.transformer_temporal import (
|
30 |
+
TransformerTemporalModelOutput,
|
31 |
+
TransformerTemporalModel as DiffusersTransformerTemporalModel,
|
32 |
+
)
|
33 |
+
from diffusers.models.attention_processor import AttnProcessor
|
34 |
+
|
35 |
+
from mmcm.utils.gpu_util import get_gpu_status
|
36 |
+
from ..data.data_util import (
|
37 |
+
batch_concat_two_tensor_with_index,
|
38 |
+
batch_index_fill,
|
39 |
+
batch_index_select,
|
40 |
+
concat_two_tensor,
|
41 |
+
align_repeat_tensor_single_dim,
|
42 |
+
)
|
43 |
+
from ..utils.attention_util import generate_sparse_causcal_attn_mask
|
44 |
+
from .attention import BasicTransformerBlock
|
45 |
+
from .attention_processor import (
|
46 |
+
BaseIPAttnProcessor,
|
47 |
+
)
|
48 |
+
from . import Model_Register
|
49 |
+
|
50 |
+
# https://github.com/facebookresearch/xformers/issues/845
|
51 |
+
# 输入bs*n_frames*w*h太高,xformers报错。因此将transformer_temporal的allow_xformers均关掉
|
52 |
+
# if bs*n_frames*w*h to large, xformers will raise error. So we close the allow_xformers in transformer_temporal
|
53 |
+
logger = logging.getLogger(__name__)
|
54 |
+
|
55 |
+
|
56 |
+
@Model_Register.register
|
57 |
+
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
58 |
+
"""
|
59 |
+
Transformer model for video-like data.
|
60 |
+
|
61 |
+
Parameters:
|
62 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
63 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
64 |
+
in_channels (`int`, *optional*):
|
65 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
66 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
67 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
68 |
+
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
69 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
70 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
71 |
+
`ImagePositionalEmbeddings`.
|
72 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
73 |
+
attention_bias (`bool`, *optional*):
|
74 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
75 |
+
double_self_attention (`bool`, *optional*):
|
76 |
+
Configure if each TransformerBlock should contain two self-attention layers
|
77 |
+
"""
|
78 |
+
|
79 |
+
@register_to_config
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
num_attention_heads: int = 16,
|
83 |
+
attention_head_dim: int = 88,
|
84 |
+
in_channels: Optional[int] = None,
|
85 |
+
out_channels: Optional[int] = None,
|
86 |
+
num_layers: int = 1,
|
87 |
+
femb_channels: Optional[int] = None,
|
88 |
+
dropout: float = 0.0,
|
89 |
+
norm_num_groups: int = 32,
|
90 |
+
cross_attention_dim: Optional[int] = None,
|
91 |
+
attention_bias: bool = False,
|
92 |
+
sample_size: Optional[int] = None,
|
93 |
+
activation_fn: str = "geglu",
|
94 |
+
norm_elementwise_affine: bool = True,
|
95 |
+
double_self_attention: bool = True,
|
96 |
+
allow_xformers: bool = False,
|
97 |
+
only_cross_attention: bool = False,
|
98 |
+
keep_content_condition: bool = False,
|
99 |
+
need_spatial_position_emb: bool = False,
|
100 |
+
need_temporal_weight: bool = True,
|
101 |
+
self_attn_mask: str = None,
|
102 |
+
# TODO: 运行参数,有待改到forward里面去
|
103 |
+
# TODO: running parameters, need to be moved to forward
|
104 |
+
image_scale: float = 1.0,
|
105 |
+
processor: AttnProcessor | None = None,
|
106 |
+
remove_femb_non_linear: bool = False,
|
107 |
+
):
|
108 |
+
super().__init__()
|
109 |
+
|
110 |
+
self.num_attention_heads = num_attention_heads
|
111 |
+
self.attention_head_dim = attention_head_dim
|
112 |
+
|
113 |
+
inner_dim = num_attention_heads * attention_head_dim
|
114 |
+
self.inner_dim = inner_dim
|
115 |
+
self.in_channels = in_channels
|
116 |
+
|
117 |
+
self.norm = torch.nn.GroupNorm(
|
118 |
+
num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
119 |
+
)
|
120 |
+
|
121 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
122 |
+
|
123 |
+
# 2. Define temporal positional embedding
|
124 |
+
self.frame_emb_proj = torch.nn.Linear(femb_channels, inner_dim)
|
125 |
+
self.remove_femb_non_linear = remove_femb_non_linear
|
126 |
+
if not remove_femb_non_linear:
|
127 |
+
self.nonlinearity = nn.SiLU()
|
128 |
+
|
129 |
+
# spatial_position_emb 使用femb_的参数配置
|
130 |
+
self.need_spatial_position_emb = need_spatial_position_emb
|
131 |
+
if need_spatial_position_emb:
|
132 |
+
self.spatial_position_emb_proj = torch.nn.Linear(femb_channels, inner_dim)
|
133 |
+
# 3. Define transformers blocks
|
134 |
+
# TODO: 该实现方式不好,待优化
|
135 |
+
# TODO: bad implementation, need to be optimized
|
136 |
+
self.need_ipadapter = False
|
137 |
+
self.cross_attn_temporal_cond = False
|
138 |
+
self.allow_xformers = allow_xformers
|
139 |
+
if processor is not None and isinstance(processor, BaseIPAttnProcessor):
|
140 |
+
self.cross_attn_temporal_cond = True
|
141 |
+
self.allow_xformers = False
|
142 |
+
if "NonParam" not in processor.__class__.__name__:
|
143 |
+
self.need_ipadapter = True
|
144 |
+
|
145 |
+
self.transformer_blocks = nn.ModuleList(
|
146 |
+
[
|
147 |
+
BasicTransformerBlock(
|
148 |
+
inner_dim,
|
149 |
+
num_attention_heads,
|
150 |
+
attention_head_dim,
|
151 |
+
dropout=dropout,
|
152 |
+
cross_attention_dim=cross_attention_dim,
|
153 |
+
activation_fn=activation_fn,
|
154 |
+
attention_bias=attention_bias,
|
155 |
+
double_self_attention=double_self_attention,
|
156 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
157 |
+
allow_xformers=allow_xformers,
|
158 |
+
only_cross_attention=only_cross_attention,
|
159 |
+
cross_attn_temporal_cond=self.need_ipadapter,
|
160 |
+
image_scale=image_scale,
|
161 |
+
processor=processor,
|
162 |
+
)
|
163 |
+
for d in range(num_layers)
|
164 |
+
]
|
165 |
+
)
|
166 |
+
|
167 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
168 |
+
|
169 |
+
self.need_temporal_weight = need_temporal_weight
|
170 |
+
if need_temporal_weight:
|
171 |
+
self.temporal_weight = nn.Parameter(
|
172 |
+
torch.tensor(
|
173 |
+
[
|
174 |
+
1e-5,
|
175 |
+
]
|
176 |
+
)
|
177 |
+
) # initialize parameter with 0
|
178 |
+
self.skip_temporal_layers = False # Whether to skip temporal layer
|
179 |
+
self.keep_content_condition = keep_content_condition
|
180 |
+
self.self_attn_mask = self_attn_mask
|
181 |
+
self.only_cross_attention = only_cross_attention
|
182 |
+
self.double_self_attention = double_self_attention
|
183 |
+
self.cross_attention_dim = cross_attention_dim
|
184 |
+
self.image_scale = image_scale
|
185 |
+
# zero out the last layer params,so the conv block is identity
|
186 |
+
nn.init.zeros_(self.proj_out.weight)
|
187 |
+
nn.init.zeros_(self.proj_out.bias)
|
188 |
+
|
189 |
+
def forward(
|
190 |
+
self,
|
191 |
+
hidden_states,
|
192 |
+
femb,
|
193 |
+
encoder_hidden_states=None,
|
194 |
+
timestep=None,
|
195 |
+
class_labels=None,
|
196 |
+
num_frames=1,
|
197 |
+
cross_attention_kwargs=None,
|
198 |
+
sample_index: torch.LongTensor = None,
|
199 |
+
vision_conditon_frames_sample_index: torch.LongTensor = None,
|
200 |
+
spatial_position_emb: torch.Tensor = None,
|
201 |
+
return_dict: bool = True,
|
202 |
+
):
|
203 |
+
"""
|
204 |
+
Args:
|
205 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
206 |
+
When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
207 |
+
hidden_states
|
208 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
209 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
210 |
+
self-attention.
|
211 |
+
timestep ( `torch.long`, *optional*):
|
212 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
213 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
214 |
+
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
|
215 |
+
conditioning.
|
216 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
217 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
[`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`:
|
221 |
+
[`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`.
|
222 |
+
When returning a tuple, the first element is the sample tensor.
|
223 |
+
"""
|
224 |
+
if self.skip_temporal_layers is True:
|
225 |
+
if not return_dict:
|
226 |
+
return (hidden_states,)
|
227 |
+
|
228 |
+
return TransformerTemporalModelOutput(sample=hidden_states)
|
229 |
+
|
230 |
+
# 1. Input
|
231 |
+
batch_frames, channel, height, width = hidden_states.shape
|
232 |
+
batch_size = batch_frames // num_frames
|
233 |
+
|
234 |
+
hidden_states = rearrange(
|
235 |
+
hidden_states, "(b t) c h w -> b c t h w", b=batch_size
|
236 |
+
)
|
237 |
+
residual = hidden_states
|
238 |
+
|
239 |
+
hidden_states = self.norm(hidden_states)
|
240 |
+
|
241 |
+
hidden_states = rearrange(hidden_states, "b c t h w -> (b h w) t c")
|
242 |
+
|
243 |
+
hidden_states = self.proj_in(hidden_states)
|
244 |
+
|
245 |
+
# 2 Positional embedding
|
246 |
+
# adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py#L574
|
247 |
+
if not self.remove_femb_non_linear:
|
248 |
+
femb = self.nonlinearity(femb)
|
249 |
+
femb = self.frame_emb_proj(femb)
|
250 |
+
femb = align_repeat_tensor_single_dim(femb, hidden_states.shape[0], dim=0)
|
251 |
+
hidden_states = hidden_states + femb
|
252 |
+
|
253 |
+
# 3. Blocks
|
254 |
+
if (
|
255 |
+
(self.only_cross_attention or not self.double_self_attention)
|
256 |
+
and self.cross_attention_dim is not None
|
257 |
+
and encoder_hidden_states is not None
|
258 |
+
):
|
259 |
+
encoder_hidden_states = align_repeat_tensor_single_dim(
|
260 |
+
encoder_hidden_states,
|
261 |
+
hidden_states.shape[0],
|
262 |
+
dim=0,
|
263 |
+
n_src_base_length=batch_size,
|
264 |
+
)
|
265 |
+
|
266 |
+
for i, block in enumerate(self.transformer_blocks):
|
267 |
+
hidden_states = block(
|
268 |
+
hidden_states,
|
269 |
+
encoder_hidden_states=encoder_hidden_states,
|
270 |
+
timestep=timestep,
|
271 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
272 |
+
class_labels=class_labels,
|
273 |
+
)
|
274 |
+
|
275 |
+
# 4. Output
|
276 |
+
hidden_states = self.proj_out(hidden_states)
|
277 |
+
hidden_states = rearrange(
|
278 |
+
hidden_states, "(b h w) t c -> b c t h w", b=batch_size, h=height, w=width
|
279 |
+
).contiguous()
|
280 |
+
|
281 |
+
# 保留condition对应的frames,便于保持前序内容帧,提升一致性
|
282 |
+
# keep the frames corresponding to the condition to maintain the previous content frames and improve consistency
|
283 |
+
if (
|
284 |
+
vision_conditon_frames_sample_index is not None
|
285 |
+
and self.keep_content_condition
|
286 |
+
):
|
287 |
+
mask = torch.ones_like(hidden_states, device=hidden_states.device)
|
288 |
+
mask = batch_index_fill(
|
289 |
+
mask, dim=2, index=vision_conditon_frames_sample_index, value=0
|
290 |
+
)
|
291 |
+
if self.need_temporal_weight:
|
292 |
+
output = (
|
293 |
+
residual + torch.abs(self.temporal_weight) * mask * hidden_states
|
294 |
+
)
|
295 |
+
else:
|
296 |
+
output = residual + mask * hidden_states
|
297 |
+
else:
|
298 |
+
if self.need_temporal_weight:
|
299 |
+
output = residual + torch.abs(self.temporal_weight) * hidden_states
|
300 |
+
else:
|
301 |
+
output = residual + mask * hidden_states
|
302 |
+
|
303 |
+
# output = torch.abs(self.temporal_weight) * hidden_states + residual
|
304 |
+
output = rearrange(output, "b c t h w -> (b t) c h w")
|
305 |
+
if not return_dict:
|
306 |
+
return (output,)
|
307 |
+
|
308 |
+
return TransformerTemporalModelOutput(sample=output)
|
musev/models/text_model.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
class TextEmbExtractor(nn.Module):
|
6 |
+
def __init__(self, tokenizer, text_encoder) -> None:
|
7 |
+
super(TextEmbExtractor, self).__init__()
|
8 |
+
self.tokenizer = tokenizer
|
9 |
+
self.text_encoder = text_encoder
|
10 |
+
|
11 |
+
def forward(
|
12 |
+
self,
|
13 |
+
texts,
|
14 |
+
text_params: Dict = None,
|
15 |
+
):
|
16 |
+
if text_params is None:
|
17 |
+
text_params = {}
|
18 |
+
special_prompt_input = self.tokenizer(
|
19 |
+
texts,
|
20 |
+
max_length=self.tokenizer.model_max_length,
|
21 |
+
padding="max_length",
|
22 |
+
truncation=True,
|
23 |
+
return_tensors="pt",
|
24 |
+
)
|
25 |
+
if (
|
26 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
27 |
+
and self.text_encoder.config.use_attention_mask
|
28 |
+
):
|
29 |
+
attention_mask = special_prompt_input.attention_mask.to(
|
30 |
+
self.text_encoder.device
|
31 |
+
)
|
32 |
+
else:
|
33 |
+
attention_mask = None
|
34 |
+
|
35 |
+
embeddings = self.text_encoder(
|
36 |
+
special_prompt_input.input_ids.to(self.text_encoder.device),
|
37 |
+
attention_mask=attention_mask,
|
38 |
+
**text_params
|
39 |
+
)
|
40 |
+
return embeddings
|
musev/models/transformer_2d.py
ADDED
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from __future__ import annotations
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Any, Dict, List, Literal, Optional
|
17 |
+
import logging
|
18 |
+
|
19 |
+
from einops import rearrange
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from torch import nn
|
24 |
+
|
25 |
+
from diffusers.models.transformer_2d import (
|
26 |
+
Transformer2DModelOutput,
|
27 |
+
Transformer2DModel as DiffusersTransformer2DModel,
|
28 |
+
)
|
29 |
+
|
30 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
31 |
+
from diffusers.models.embeddings import ImagePositionalEmbeddings
|
32 |
+
from diffusers.utils import BaseOutput, deprecate
|
33 |
+
from diffusers.models.attention import (
|
34 |
+
BasicTransformerBlock as DiffusersBasicTransformerBlock,
|
35 |
+
)
|
36 |
+
from diffusers.models.embeddings import PatchEmbed
|
37 |
+
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
38 |
+
from diffusers.models.modeling_utils import ModelMixin
|
39 |
+
from diffusers.utils.constants import USE_PEFT_BACKEND
|
40 |
+
|
41 |
+
from .attention import BasicTransformerBlock
|
42 |
+
|
43 |
+
logger = logging.getLogger(__name__)
|
44 |
+
|
45 |
+
# 本部分 与 diffusers/models/transformer_2d.py 几乎一样
|
46 |
+
# 更新部分
|
47 |
+
# 1. 替换自定义 BasicTransformerBlock 类
|
48 |
+
# 2. 在forward 里增加了 self_attn_block_embs 用于 提取 self_attn 中的emb
|
49 |
+
|
50 |
+
# this module is same as diffusers/models/transformer_2d.py. The update part is
|
51 |
+
# 1 redefine BasicTransformerBlock
|
52 |
+
# 2. add self_attn_block_embs in forward to extract emb from self_attn
|
53 |
+
|
54 |
+
|
55 |
+
class Transformer2DModel(DiffusersTransformer2DModel):
|
56 |
+
"""
|
57 |
+
A 2D Transformer model for image-like data.
|
58 |
+
|
59 |
+
Parameters:
|
60 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
61 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
62 |
+
in_channels (`int`, *optional*):
|
63 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
64 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
65 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
66 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
67 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
68 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
69 |
+
num_vector_embeds (`int`, *optional*):
|
70 |
+
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
71 |
+
Includes the class for the masked latent pixel.
|
72 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
73 |
+
num_embeds_ada_norm ( `int`, *optional*):
|
74 |
+
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
75 |
+
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
76 |
+
added to the hidden states.
|
77 |
+
|
78 |
+
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
79 |
+
attention_bias (`bool`, *optional*):
|
80 |
+
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
81 |
+
"""
|
82 |
+
|
83 |
+
@register_to_config
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
num_attention_heads: int = 16,
|
87 |
+
attention_head_dim: int = 88,
|
88 |
+
in_channels: int | None = None,
|
89 |
+
out_channels: int | None = None,
|
90 |
+
num_layers: int = 1,
|
91 |
+
dropout: float = 0,
|
92 |
+
norm_num_groups: int = 32,
|
93 |
+
cross_attention_dim: int | None = None,
|
94 |
+
attention_bias: bool = False,
|
95 |
+
sample_size: int | None = None,
|
96 |
+
num_vector_embeds: int | None = None,
|
97 |
+
patch_size: int | None = None,
|
98 |
+
activation_fn: str = "geglu",
|
99 |
+
num_embeds_ada_norm: int | None = None,
|
100 |
+
use_linear_projection: bool = False,
|
101 |
+
only_cross_attention: bool = False,
|
102 |
+
double_self_attention: bool = False,
|
103 |
+
upcast_attention: bool = False,
|
104 |
+
norm_type: str = "layer_norm",
|
105 |
+
norm_elementwise_affine: bool = True,
|
106 |
+
attention_type: str = "default",
|
107 |
+
cross_attn_temporal_cond: bool = False,
|
108 |
+
ip_adapter_cross_attn: bool = False,
|
109 |
+
need_t2i_facein: bool = False,
|
110 |
+
need_t2i_ip_adapter_face: bool = False,
|
111 |
+
image_scale: float = 1.0,
|
112 |
+
):
|
113 |
+
super().__init__(
|
114 |
+
num_attention_heads,
|
115 |
+
attention_head_dim,
|
116 |
+
in_channels,
|
117 |
+
out_channels,
|
118 |
+
num_layers,
|
119 |
+
dropout,
|
120 |
+
norm_num_groups,
|
121 |
+
cross_attention_dim,
|
122 |
+
attention_bias,
|
123 |
+
sample_size,
|
124 |
+
num_vector_embeds,
|
125 |
+
patch_size,
|
126 |
+
activation_fn,
|
127 |
+
num_embeds_ada_norm,
|
128 |
+
use_linear_projection,
|
129 |
+
only_cross_attention,
|
130 |
+
double_self_attention,
|
131 |
+
upcast_attention,
|
132 |
+
norm_type,
|
133 |
+
norm_elementwise_affine,
|
134 |
+
attention_type,
|
135 |
+
)
|
136 |
+
inner_dim = num_attention_heads * attention_head_dim
|
137 |
+
self.transformer_blocks = nn.ModuleList(
|
138 |
+
[
|
139 |
+
BasicTransformerBlock(
|
140 |
+
inner_dim,
|
141 |
+
num_attention_heads,
|
142 |
+
attention_head_dim,
|
143 |
+
dropout=dropout,
|
144 |
+
cross_attention_dim=cross_attention_dim,
|
145 |
+
activation_fn=activation_fn,
|
146 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
147 |
+
attention_bias=attention_bias,
|
148 |
+
only_cross_attention=only_cross_attention,
|
149 |
+
double_self_attention=double_self_attention,
|
150 |
+
upcast_attention=upcast_attention,
|
151 |
+
norm_type=norm_type,
|
152 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
153 |
+
attention_type=attention_type,
|
154 |
+
cross_attn_temporal_cond=cross_attn_temporal_cond,
|
155 |
+
ip_adapter_cross_attn=ip_adapter_cross_attn,
|
156 |
+
need_t2i_facein=need_t2i_facein,
|
157 |
+
need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
|
158 |
+
image_scale=image_scale,
|
159 |
+
)
|
160 |
+
for d in range(num_layers)
|
161 |
+
]
|
162 |
+
)
|
163 |
+
self.num_layers = num_layers
|
164 |
+
self.cross_attn_temporal_cond = cross_attn_temporal_cond
|
165 |
+
self.ip_adapter_cross_attn = ip_adapter_cross_attn
|
166 |
+
|
167 |
+
self.need_t2i_facein = need_t2i_facein
|
168 |
+
self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face
|
169 |
+
self.image_scale = image_scale
|
170 |
+
self.print_idx = 0
|
171 |
+
|
172 |
+
def forward(
|
173 |
+
self,
|
174 |
+
hidden_states: torch.Tensor,
|
175 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
176 |
+
timestep: Optional[torch.LongTensor] = None,
|
177 |
+
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
178 |
+
class_labels: Optional[torch.LongTensor] = None,
|
179 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
180 |
+
attention_mask: Optional[torch.Tensor] = None,
|
181 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
182 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
183 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
184 |
+
return_dict: bool = True,
|
185 |
+
):
|
186 |
+
"""
|
187 |
+
The [`Transformer2DModel`] forward method.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
191 |
+
Input `hidden_states`.
|
192 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
193 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
194 |
+
self-attention.
|
195 |
+
timestep ( `torch.LongTensor`, *optional*):
|
196 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
197 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
198 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
199 |
+
`AdaLayerZeroNorm`.
|
200 |
+
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
201 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
202 |
+
`self.processor` in
|
203 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
204 |
+
attention_mask ( `torch.Tensor`, *optional*):
|
205 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
206 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
207 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
208 |
+
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
209 |
+
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
210 |
+
|
211 |
+
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
212 |
+
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
213 |
+
|
214 |
+
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
215 |
+
above. This bias will be added to the cross-attention scores.
|
216 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
217 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
218 |
+
tuple.
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
222 |
+
`tuple` where the first element is the sample tensor.
|
223 |
+
"""
|
224 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
225 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
226 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
227 |
+
# expects mask of shape:
|
228 |
+
# [batch, key_tokens]
|
229 |
+
# adds singleton query_tokens dimension:
|
230 |
+
# [batch, 1, key_tokens]
|
231 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
232 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
233 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
234 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
235 |
+
# assume that mask is expressed as:
|
236 |
+
# (1 = keep, 0 = discard)
|
237 |
+
# convert mask into a bias that can be added to attention scores:
|
238 |
+
# (keep = +0, discard = -10000.0)
|
239 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
240 |
+
attention_mask = attention_mask.unsqueeze(1)
|
241 |
+
|
242 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
243 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
244 |
+
encoder_attention_mask = (
|
245 |
+
1 - encoder_attention_mask.to(hidden_states.dtype)
|
246 |
+
) * -10000.0
|
247 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
248 |
+
|
249 |
+
# Retrieve lora scale.
|
250 |
+
lora_scale = (
|
251 |
+
cross_attention_kwargs.get("scale", 1.0)
|
252 |
+
if cross_attention_kwargs is not None
|
253 |
+
else 1.0
|
254 |
+
)
|
255 |
+
|
256 |
+
# 1. Input
|
257 |
+
if self.is_input_continuous:
|
258 |
+
batch, _, height, width = hidden_states.shape
|
259 |
+
residual = hidden_states
|
260 |
+
|
261 |
+
hidden_states = self.norm(hidden_states)
|
262 |
+
if not self.use_linear_projection:
|
263 |
+
hidden_states = (
|
264 |
+
self.proj_in(hidden_states, scale=lora_scale)
|
265 |
+
if not USE_PEFT_BACKEND
|
266 |
+
else self.proj_in(hidden_states)
|
267 |
+
)
|
268 |
+
inner_dim = hidden_states.shape[1]
|
269 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
270 |
+
batch, height * width, inner_dim
|
271 |
+
)
|
272 |
+
else:
|
273 |
+
inner_dim = hidden_states.shape[1]
|
274 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
275 |
+
batch, height * width, inner_dim
|
276 |
+
)
|
277 |
+
hidden_states = (
|
278 |
+
self.proj_in(hidden_states, scale=lora_scale)
|
279 |
+
if not USE_PEFT_BACKEND
|
280 |
+
else self.proj_in(hidden_states)
|
281 |
+
)
|
282 |
+
|
283 |
+
elif self.is_input_vectorized:
|
284 |
+
hidden_states = self.latent_image_embedding(hidden_states)
|
285 |
+
elif self.is_input_patches:
|
286 |
+
height, width = (
|
287 |
+
hidden_states.shape[-2] // self.patch_size,
|
288 |
+
hidden_states.shape[-1] // self.patch_size,
|
289 |
+
)
|
290 |
+
hidden_states = self.pos_embed(hidden_states)
|
291 |
+
|
292 |
+
if self.adaln_single is not None:
|
293 |
+
if self.use_additional_conditions and added_cond_kwargs is None:
|
294 |
+
raise ValueError(
|
295 |
+
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
296 |
+
)
|
297 |
+
batch_size = hidden_states.shape[0]
|
298 |
+
timestep, embedded_timestep = self.adaln_single(
|
299 |
+
timestep,
|
300 |
+
added_cond_kwargs,
|
301 |
+
batch_size=batch_size,
|
302 |
+
hidden_dtype=hidden_states.dtype,
|
303 |
+
)
|
304 |
+
|
305 |
+
# 2. Blocks
|
306 |
+
if self.caption_projection is not None:
|
307 |
+
batch_size = hidden_states.shape[0]
|
308 |
+
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
309 |
+
encoder_hidden_states = encoder_hidden_states.view(
|
310 |
+
batch_size, -1, hidden_states.shape[-1]
|
311 |
+
)
|
312 |
+
|
313 |
+
for block in self.transformer_blocks:
|
314 |
+
if self.training and self.gradient_checkpointing:
|
315 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
316 |
+
block,
|
317 |
+
hidden_states,
|
318 |
+
attention_mask,
|
319 |
+
encoder_hidden_states,
|
320 |
+
encoder_attention_mask,
|
321 |
+
timestep,
|
322 |
+
cross_attention_kwargs,
|
323 |
+
class_labels,
|
324 |
+
self_attn_block_embs,
|
325 |
+
self_attn_block_embs_mode,
|
326 |
+
use_reentrant=False,
|
327 |
+
)
|
328 |
+
else:
|
329 |
+
hidden_states = block(
|
330 |
+
hidden_states,
|
331 |
+
attention_mask=attention_mask,
|
332 |
+
encoder_hidden_states=encoder_hidden_states,
|
333 |
+
encoder_attention_mask=encoder_attention_mask,
|
334 |
+
timestep=timestep,
|
335 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
336 |
+
class_labels=class_labels,
|
337 |
+
self_attn_block_embs=self_attn_block_embs,
|
338 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
339 |
+
)
|
340 |
+
# 将 转换 self_attn_emb的尺寸
|
341 |
+
if (
|
342 |
+
self_attn_block_embs is not None
|
343 |
+
and self_attn_block_embs_mode.lower() == "write"
|
344 |
+
):
|
345 |
+
self_attn_idx = block.spatial_self_attn_idx
|
346 |
+
if self.print_idx == 0:
|
347 |
+
logger.debug(
|
348 |
+
f"self_attn_block_embs, num={len(self_attn_block_embs)}, before, shape={self_attn_block_embs[self_attn_idx].shape}, height={height}, width={width}"
|
349 |
+
)
|
350 |
+
self_attn_block_embs[self_attn_idx] = rearrange(
|
351 |
+
self_attn_block_embs[self_attn_idx],
|
352 |
+
"bt (h w) c->bt c h w",
|
353 |
+
h=height,
|
354 |
+
w=width,
|
355 |
+
)
|
356 |
+
if self.print_idx == 0:
|
357 |
+
logger.debug(
|
358 |
+
f"self_attn_block_embs, num={len(self_attn_block_embs)}, after ,shape={self_attn_block_embs[self_attn_idx].shape}, height={height}, width={width}"
|
359 |
+
)
|
360 |
+
|
361 |
+
if self.proj_out is None:
|
362 |
+
return hidden_states
|
363 |
+
|
364 |
+
# 3. Output
|
365 |
+
if self.is_input_continuous:
|
366 |
+
if not self.use_linear_projection:
|
367 |
+
hidden_states = (
|
368 |
+
hidden_states.reshape(batch, height, width, inner_dim)
|
369 |
+
.permute(0, 3, 1, 2)
|
370 |
+
.contiguous()
|
371 |
+
)
|
372 |
+
hidden_states = (
|
373 |
+
self.proj_out(hidden_states, scale=lora_scale)
|
374 |
+
if not USE_PEFT_BACKEND
|
375 |
+
else self.proj_out(hidden_states)
|
376 |
+
)
|
377 |
+
else:
|
378 |
+
hidden_states = (
|
379 |
+
self.proj_out(hidden_states, scale=lora_scale)
|
380 |
+
if not USE_PEFT_BACKEND
|
381 |
+
else self.proj_out(hidden_states)
|
382 |
+
)
|
383 |
+
hidden_states = (
|
384 |
+
hidden_states.reshape(batch, height, width, inner_dim)
|
385 |
+
.permute(0, 3, 1, 2)
|
386 |
+
.contiguous()
|
387 |
+
)
|
388 |
+
|
389 |
+
output = hidden_states + residual
|
390 |
+
elif self.is_input_vectorized:
|
391 |
+
hidden_states = self.norm_out(hidden_states)
|
392 |
+
logits = self.out(hidden_states)
|
393 |
+
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
394 |
+
logits = logits.permute(0, 2, 1)
|
395 |
+
|
396 |
+
# log(p(x_0))
|
397 |
+
output = F.log_softmax(logits.double(), dim=1).float()
|
398 |
+
|
399 |
+
if self.is_input_patches:
|
400 |
+
if self.config.norm_type != "ada_norm_single":
|
401 |
+
conditioning = self.transformer_blocks[0].norm1.emb(
|
402 |
+
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
403 |
+
)
|
404 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
405 |
+
hidden_states = (
|
406 |
+
self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
407 |
+
)
|
408 |
+
hidden_states = self.proj_out_2(hidden_states)
|
409 |
+
elif self.config.norm_type == "ada_norm_single":
|
410 |
+
shift, scale = (
|
411 |
+
self.scale_shift_table[None] + embedded_timestep[:, None]
|
412 |
+
).chunk(2, dim=1)
|
413 |
+
hidden_states = self.norm_out(hidden_states)
|
414 |
+
# Modulation
|
415 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
416 |
+
hidden_states = self.proj_out(hidden_states)
|
417 |
+
hidden_states = hidden_states.squeeze(1)
|
418 |
+
|
419 |
+
# unpatchify
|
420 |
+
if self.adaln_single is None:
|
421 |
+
height = width = int(hidden_states.shape[1] ** 0.5)
|
422 |
+
hidden_states = hidden_states.reshape(
|
423 |
+
shape=(
|
424 |
+
-1,
|
425 |
+
height,
|
426 |
+
width,
|
427 |
+
self.patch_size,
|
428 |
+
self.patch_size,
|
429 |
+
self.out_channels,
|
430 |
+
)
|
431 |
+
)
|
432 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
433 |
+
output = hidden_states.reshape(
|
434 |
+
shape=(
|
435 |
+
-1,
|
436 |
+
self.out_channels,
|
437 |
+
height * self.patch_size,
|
438 |
+
width * self.patch_size,
|
439 |
+
)
|
440 |
+
)
|
441 |
+
self.print_idx += 1
|
442 |
+
if not return_dict:
|
443 |
+
return (output,)
|
444 |
+
|
445 |
+
return Transformer2DModelOutput(sample=output)
|
musev/models/unet_2d_blocks.py
ADDED
@@ -0,0 +1,1537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Dict, Literal, Optional, Tuple, Union, List
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from diffusers.utils import is_torch_version, logging
|
22 |
+
from diffusers.utils.torch_utils import apply_freeu
|
23 |
+
from diffusers.models.activations import get_activation
|
24 |
+
from diffusers.models.attention_processor import (
|
25 |
+
Attention,
|
26 |
+
AttnAddedKVProcessor,
|
27 |
+
AttnAddedKVProcessor2_0,
|
28 |
+
)
|
29 |
+
from diffusers.models.dual_transformer_2d import DualTransformer2DModel
|
30 |
+
from diffusers.models.normalization import AdaGroupNorm
|
31 |
+
from diffusers.models.resnet import (
|
32 |
+
Downsample2D,
|
33 |
+
FirDownsample2D,
|
34 |
+
FirUpsample2D,
|
35 |
+
KDownsample2D,
|
36 |
+
KUpsample2D,
|
37 |
+
ResnetBlock2D,
|
38 |
+
Upsample2D,
|
39 |
+
)
|
40 |
+
from diffusers.models.unet_2d_blocks import (
|
41 |
+
AttnDownBlock2D,
|
42 |
+
AttnDownEncoderBlock2D,
|
43 |
+
AttnSkipDownBlock2D,
|
44 |
+
AttnSkipUpBlock2D,
|
45 |
+
AttnUpBlock2D,
|
46 |
+
AttnUpDecoderBlock2D,
|
47 |
+
DownEncoderBlock2D,
|
48 |
+
KCrossAttnDownBlock2D,
|
49 |
+
KCrossAttnUpBlock2D,
|
50 |
+
KDownBlock2D,
|
51 |
+
KUpBlock2D,
|
52 |
+
ResnetDownsampleBlock2D,
|
53 |
+
ResnetUpsampleBlock2D,
|
54 |
+
SimpleCrossAttnDownBlock2D,
|
55 |
+
SimpleCrossAttnUpBlock2D,
|
56 |
+
SkipDownBlock2D,
|
57 |
+
SkipUpBlock2D,
|
58 |
+
UpDecoderBlock2D,
|
59 |
+
)
|
60 |
+
|
61 |
+
from .transformer_2d import Transformer2DModel
|
62 |
+
|
63 |
+
|
64 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
65 |
+
|
66 |
+
|
67 |
+
def get_down_block(
|
68 |
+
down_block_type: str,
|
69 |
+
num_layers: int,
|
70 |
+
in_channels: int,
|
71 |
+
out_channels: int,
|
72 |
+
temb_channels: int,
|
73 |
+
add_downsample: bool,
|
74 |
+
resnet_eps: float,
|
75 |
+
resnet_act_fn: str,
|
76 |
+
transformer_layers_per_block: int = 1,
|
77 |
+
num_attention_heads: Optional[int] = None,
|
78 |
+
resnet_groups: Optional[int] = None,
|
79 |
+
cross_attention_dim: Optional[int] = None,
|
80 |
+
downsample_padding: Optional[int] = None,
|
81 |
+
dual_cross_attention: bool = False,
|
82 |
+
use_linear_projection: bool = False,
|
83 |
+
only_cross_attention: bool = False,
|
84 |
+
upcast_attention: bool = False,
|
85 |
+
resnet_time_scale_shift: str = "default",
|
86 |
+
attention_type: str = "default",
|
87 |
+
resnet_skip_time_act: bool = False,
|
88 |
+
resnet_out_scale_factor: float = 1.0,
|
89 |
+
cross_attention_norm: Optional[str] = None,
|
90 |
+
attention_head_dim: Optional[int] = None,
|
91 |
+
downsample_type: Optional[str] = None,
|
92 |
+
dropout: float = 0.0,
|
93 |
+
):
|
94 |
+
# If attn head dim is not defined, we default it to the number of heads
|
95 |
+
if attention_head_dim is None:
|
96 |
+
logger.warn(
|
97 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
98 |
+
)
|
99 |
+
attention_head_dim = num_attention_heads
|
100 |
+
|
101 |
+
down_block_type = (
|
102 |
+
down_block_type[7:]
|
103 |
+
if down_block_type.startswith("UNetRes")
|
104 |
+
else down_block_type
|
105 |
+
)
|
106 |
+
if down_block_type == "DownBlock2D":
|
107 |
+
return DownBlock2D(
|
108 |
+
num_layers=num_layers,
|
109 |
+
in_channels=in_channels,
|
110 |
+
out_channels=out_channels,
|
111 |
+
temb_channels=temb_channels,
|
112 |
+
dropout=dropout,
|
113 |
+
add_downsample=add_downsample,
|
114 |
+
resnet_eps=resnet_eps,
|
115 |
+
resnet_act_fn=resnet_act_fn,
|
116 |
+
resnet_groups=resnet_groups,
|
117 |
+
downsample_padding=downsample_padding,
|
118 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
119 |
+
)
|
120 |
+
elif down_block_type == "ResnetDownsampleBlock2D":
|
121 |
+
return ResnetDownsampleBlock2D(
|
122 |
+
num_layers=num_layers,
|
123 |
+
in_channels=in_channels,
|
124 |
+
out_channels=out_channels,
|
125 |
+
temb_channels=temb_channels,
|
126 |
+
dropout=dropout,
|
127 |
+
add_downsample=add_downsample,
|
128 |
+
resnet_eps=resnet_eps,
|
129 |
+
resnet_act_fn=resnet_act_fn,
|
130 |
+
resnet_groups=resnet_groups,
|
131 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
132 |
+
skip_time_act=resnet_skip_time_act,
|
133 |
+
output_scale_factor=resnet_out_scale_factor,
|
134 |
+
)
|
135 |
+
elif down_block_type == "AttnDownBlock2D":
|
136 |
+
if add_downsample is False:
|
137 |
+
downsample_type = None
|
138 |
+
else:
|
139 |
+
downsample_type = downsample_type or "conv" # default to 'conv'
|
140 |
+
return AttnDownBlock2D(
|
141 |
+
num_layers=num_layers,
|
142 |
+
in_channels=in_channels,
|
143 |
+
out_channels=out_channels,
|
144 |
+
temb_channels=temb_channels,
|
145 |
+
dropout=dropout,
|
146 |
+
resnet_eps=resnet_eps,
|
147 |
+
resnet_act_fn=resnet_act_fn,
|
148 |
+
resnet_groups=resnet_groups,
|
149 |
+
downsample_padding=downsample_padding,
|
150 |
+
attention_head_dim=attention_head_dim,
|
151 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
152 |
+
downsample_type=downsample_type,
|
153 |
+
)
|
154 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
155 |
+
if cross_attention_dim is None:
|
156 |
+
raise ValueError(
|
157 |
+
"cross_attention_dim must be specified for CrossAttnDownBlock2D"
|
158 |
+
)
|
159 |
+
return CrossAttnDownBlock2D(
|
160 |
+
num_layers=num_layers,
|
161 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
162 |
+
in_channels=in_channels,
|
163 |
+
out_channels=out_channels,
|
164 |
+
temb_channels=temb_channels,
|
165 |
+
dropout=dropout,
|
166 |
+
add_downsample=add_downsample,
|
167 |
+
resnet_eps=resnet_eps,
|
168 |
+
resnet_act_fn=resnet_act_fn,
|
169 |
+
resnet_groups=resnet_groups,
|
170 |
+
downsample_padding=downsample_padding,
|
171 |
+
cross_attention_dim=cross_attention_dim,
|
172 |
+
num_attention_heads=num_attention_heads,
|
173 |
+
dual_cross_attention=dual_cross_attention,
|
174 |
+
use_linear_projection=use_linear_projection,
|
175 |
+
only_cross_attention=only_cross_attention,
|
176 |
+
upcast_attention=upcast_attention,
|
177 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
178 |
+
attention_type=attention_type,
|
179 |
+
)
|
180 |
+
elif down_block_type == "SimpleCrossAttnDownBlock2D":
|
181 |
+
if cross_attention_dim is None:
|
182 |
+
raise ValueError(
|
183 |
+
"cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D"
|
184 |
+
)
|
185 |
+
return SimpleCrossAttnDownBlock2D(
|
186 |
+
num_layers=num_layers,
|
187 |
+
in_channels=in_channels,
|
188 |
+
out_channels=out_channels,
|
189 |
+
temb_channels=temb_channels,
|
190 |
+
dropout=dropout,
|
191 |
+
add_downsample=add_downsample,
|
192 |
+
resnet_eps=resnet_eps,
|
193 |
+
resnet_act_fn=resnet_act_fn,
|
194 |
+
resnet_groups=resnet_groups,
|
195 |
+
cross_attention_dim=cross_attention_dim,
|
196 |
+
attention_head_dim=attention_head_dim,
|
197 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
198 |
+
skip_time_act=resnet_skip_time_act,
|
199 |
+
output_scale_factor=resnet_out_scale_factor,
|
200 |
+
only_cross_attention=only_cross_attention,
|
201 |
+
cross_attention_norm=cross_attention_norm,
|
202 |
+
)
|
203 |
+
elif down_block_type == "SkipDownBlock2D":
|
204 |
+
return SkipDownBlock2D(
|
205 |
+
num_layers=num_layers,
|
206 |
+
in_channels=in_channels,
|
207 |
+
out_channels=out_channels,
|
208 |
+
temb_channels=temb_channels,
|
209 |
+
dropout=dropout,
|
210 |
+
add_downsample=add_downsample,
|
211 |
+
resnet_eps=resnet_eps,
|
212 |
+
resnet_act_fn=resnet_act_fn,
|
213 |
+
downsample_padding=downsample_padding,
|
214 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
215 |
+
)
|
216 |
+
elif down_block_type == "AttnSkipDownBlock2D":
|
217 |
+
return AttnSkipDownBlock2D(
|
218 |
+
num_layers=num_layers,
|
219 |
+
in_channels=in_channels,
|
220 |
+
out_channels=out_channels,
|
221 |
+
temb_channels=temb_channels,
|
222 |
+
dropout=dropout,
|
223 |
+
add_downsample=add_downsample,
|
224 |
+
resnet_eps=resnet_eps,
|
225 |
+
resnet_act_fn=resnet_act_fn,
|
226 |
+
attention_head_dim=attention_head_dim,
|
227 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
228 |
+
)
|
229 |
+
elif down_block_type == "DownEncoderBlock2D":
|
230 |
+
return DownEncoderBlock2D(
|
231 |
+
num_layers=num_layers,
|
232 |
+
in_channels=in_channels,
|
233 |
+
out_channels=out_channels,
|
234 |
+
dropout=dropout,
|
235 |
+
add_downsample=add_downsample,
|
236 |
+
resnet_eps=resnet_eps,
|
237 |
+
resnet_act_fn=resnet_act_fn,
|
238 |
+
resnet_groups=resnet_groups,
|
239 |
+
downsample_padding=downsample_padding,
|
240 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
241 |
+
)
|
242 |
+
elif down_block_type == "AttnDownEncoderBlock2D":
|
243 |
+
return AttnDownEncoderBlock2D(
|
244 |
+
num_layers=num_layers,
|
245 |
+
in_channels=in_channels,
|
246 |
+
out_channels=out_channels,
|
247 |
+
dropout=dropout,
|
248 |
+
add_downsample=add_downsample,
|
249 |
+
resnet_eps=resnet_eps,
|
250 |
+
resnet_act_fn=resnet_act_fn,
|
251 |
+
resnet_groups=resnet_groups,
|
252 |
+
downsample_padding=downsample_padding,
|
253 |
+
attention_head_dim=attention_head_dim,
|
254 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
255 |
+
)
|
256 |
+
elif down_block_type == "KDownBlock2D":
|
257 |
+
return KDownBlock2D(
|
258 |
+
num_layers=num_layers,
|
259 |
+
in_channels=in_channels,
|
260 |
+
out_channels=out_channels,
|
261 |
+
temb_channels=temb_channels,
|
262 |
+
dropout=dropout,
|
263 |
+
add_downsample=add_downsample,
|
264 |
+
resnet_eps=resnet_eps,
|
265 |
+
resnet_act_fn=resnet_act_fn,
|
266 |
+
)
|
267 |
+
elif down_block_type == "KCrossAttnDownBlock2D":
|
268 |
+
return KCrossAttnDownBlock2D(
|
269 |
+
num_layers=num_layers,
|
270 |
+
in_channels=in_channels,
|
271 |
+
out_channels=out_channels,
|
272 |
+
temb_channels=temb_channels,
|
273 |
+
dropout=dropout,
|
274 |
+
add_downsample=add_downsample,
|
275 |
+
resnet_eps=resnet_eps,
|
276 |
+
resnet_act_fn=resnet_act_fn,
|
277 |
+
cross_attention_dim=cross_attention_dim,
|
278 |
+
attention_head_dim=attention_head_dim,
|
279 |
+
add_self_attention=True if not add_downsample else False,
|
280 |
+
)
|
281 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
282 |
+
|
283 |
+
|
284 |
+
def get_up_block(
|
285 |
+
up_block_type: str,
|
286 |
+
num_layers: int,
|
287 |
+
in_channels: int,
|
288 |
+
out_channels: int,
|
289 |
+
prev_output_channel: int,
|
290 |
+
temb_channels: int,
|
291 |
+
add_upsample: bool,
|
292 |
+
resnet_eps: float,
|
293 |
+
resnet_act_fn: str,
|
294 |
+
resolution_idx: Optional[int] = None,
|
295 |
+
transformer_layers_per_block: int = 1,
|
296 |
+
num_attention_heads: Optional[int] = None,
|
297 |
+
resnet_groups: Optional[int] = None,
|
298 |
+
cross_attention_dim: Optional[int] = None,
|
299 |
+
dual_cross_attention: bool = False,
|
300 |
+
use_linear_projection: bool = False,
|
301 |
+
only_cross_attention: bool = False,
|
302 |
+
upcast_attention: bool = False,
|
303 |
+
resnet_time_scale_shift: str = "default",
|
304 |
+
attention_type: str = "default",
|
305 |
+
resnet_skip_time_act: bool = False,
|
306 |
+
resnet_out_scale_factor: float = 1.0,
|
307 |
+
cross_attention_norm: Optional[str] = None,
|
308 |
+
attention_head_dim: Optional[int] = None,
|
309 |
+
upsample_type: Optional[str] = None,
|
310 |
+
dropout: float = 0.0,
|
311 |
+
) -> nn.Module:
|
312 |
+
# If attn head dim is not defined, we default it to the number of heads
|
313 |
+
if attention_head_dim is None:
|
314 |
+
logger.warn(
|
315 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
316 |
+
)
|
317 |
+
attention_head_dim = num_attention_heads
|
318 |
+
|
319 |
+
up_block_type = (
|
320 |
+
up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
321 |
+
)
|
322 |
+
if up_block_type == "UpBlock2D":
|
323 |
+
return UpBlock2D(
|
324 |
+
num_layers=num_layers,
|
325 |
+
in_channels=in_channels,
|
326 |
+
out_channels=out_channels,
|
327 |
+
prev_output_channel=prev_output_channel,
|
328 |
+
temb_channels=temb_channels,
|
329 |
+
resolution_idx=resolution_idx,
|
330 |
+
dropout=dropout,
|
331 |
+
add_upsample=add_upsample,
|
332 |
+
resnet_eps=resnet_eps,
|
333 |
+
resnet_act_fn=resnet_act_fn,
|
334 |
+
resnet_groups=resnet_groups,
|
335 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
336 |
+
)
|
337 |
+
elif up_block_type == "ResnetUpsampleBlock2D":
|
338 |
+
return ResnetUpsampleBlock2D(
|
339 |
+
num_layers=num_layers,
|
340 |
+
in_channels=in_channels,
|
341 |
+
out_channels=out_channels,
|
342 |
+
prev_output_channel=prev_output_channel,
|
343 |
+
temb_channels=temb_channels,
|
344 |
+
resolution_idx=resolution_idx,
|
345 |
+
dropout=dropout,
|
346 |
+
add_upsample=add_upsample,
|
347 |
+
resnet_eps=resnet_eps,
|
348 |
+
resnet_act_fn=resnet_act_fn,
|
349 |
+
resnet_groups=resnet_groups,
|
350 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
351 |
+
skip_time_act=resnet_skip_time_act,
|
352 |
+
output_scale_factor=resnet_out_scale_factor,
|
353 |
+
)
|
354 |
+
elif up_block_type == "CrossAttnUpBlock2D":
|
355 |
+
if cross_attention_dim is None:
|
356 |
+
raise ValueError(
|
357 |
+
"cross_attention_dim must be specified for CrossAttnUpBlock2D"
|
358 |
+
)
|
359 |
+
return CrossAttnUpBlock2D(
|
360 |
+
num_layers=num_layers,
|
361 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
362 |
+
in_channels=in_channels,
|
363 |
+
out_channels=out_channels,
|
364 |
+
prev_output_channel=prev_output_channel,
|
365 |
+
temb_channels=temb_channels,
|
366 |
+
resolution_idx=resolution_idx,
|
367 |
+
dropout=dropout,
|
368 |
+
add_upsample=add_upsample,
|
369 |
+
resnet_eps=resnet_eps,
|
370 |
+
resnet_act_fn=resnet_act_fn,
|
371 |
+
resnet_groups=resnet_groups,
|
372 |
+
cross_attention_dim=cross_attention_dim,
|
373 |
+
num_attention_heads=num_attention_heads,
|
374 |
+
dual_cross_attention=dual_cross_attention,
|
375 |
+
use_linear_projection=use_linear_projection,
|
376 |
+
only_cross_attention=only_cross_attention,
|
377 |
+
upcast_attention=upcast_attention,
|
378 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
379 |
+
attention_type=attention_type,
|
380 |
+
)
|
381 |
+
elif up_block_type == "SimpleCrossAttnUpBlock2D":
|
382 |
+
if cross_attention_dim is None:
|
383 |
+
raise ValueError(
|
384 |
+
"cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D"
|
385 |
+
)
|
386 |
+
return SimpleCrossAttnUpBlock2D(
|
387 |
+
num_layers=num_layers,
|
388 |
+
in_channels=in_channels,
|
389 |
+
out_channels=out_channels,
|
390 |
+
prev_output_channel=prev_output_channel,
|
391 |
+
temb_channels=temb_channels,
|
392 |
+
resolution_idx=resolution_idx,
|
393 |
+
dropout=dropout,
|
394 |
+
add_upsample=add_upsample,
|
395 |
+
resnet_eps=resnet_eps,
|
396 |
+
resnet_act_fn=resnet_act_fn,
|
397 |
+
resnet_groups=resnet_groups,
|
398 |
+
cross_attention_dim=cross_attention_dim,
|
399 |
+
attention_head_dim=attention_head_dim,
|
400 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
401 |
+
skip_time_act=resnet_skip_time_act,
|
402 |
+
output_scale_factor=resnet_out_scale_factor,
|
403 |
+
only_cross_attention=only_cross_attention,
|
404 |
+
cross_attention_norm=cross_attention_norm,
|
405 |
+
)
|
406 |
+
elif up_block_type == "AttnUpBlock2D":
|
407 |
+
if add_upsample is False:
|
408 |
+
upsample_type = None
|
409 |
+
else:
|
410 |
+
upsample_type = upsample_type or "conv" # default to 'conv'
|
411 |
+
|
412 |
+
return AttnUpBlock2D(
|
413 |
+
num_layers=num_layers,
|
414 |
+
in_channels=in_channels,
|
415 |
+
out_channels=out_channels,
|
416 |
+
prev_output_channel=prev_output_channel,
|
417 |
+
temb_channels=temb_channels,
|
418 |
+
resolution_idx=resolution_idx,
|
419 |
+
dropout=dropout,
|
420 |
+
resnet_eps=resnet_eps,
|
421 |
+
resnet_act_fn=resnet_act_fn,
|
422 |
+
resnet_groups=resnet_groups,
|
423 |
+
attention_head_dim=attention_head_dim,
|
424 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
425 |
+
upsample_type=upsample_type,
|
426 |
+
)
|
427 |
+
elif up_block_type == "SkipUpBlock2D":
|
428 |
+
return SkipUpBlock2D(
|
429 |
+
num_layers=num_layers,
|
430 |
+
in_channels=in_channels,
|
431 |
+
out_channels=out_channels,
|
432 |
+
prev_output_channel=prev_output_channel,
|
433 |
+
temb_channels=temb_channels,
|
434 |
+
resolution_idx=resolution_idx,
|
435 |
+
dropout=dropout,
|
436 |
+
add_upsample=add_upsample,
|
437 |
+
resnet_eps=resnet_eps,
|
438 |
+
resnet_act_fn=resnet_act_fn,
|
439 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
440 |
+
)
|
441 |
+
elif up_block_type == "AttnSkipUpBlock2D":
|
442 |
+
return AttnSkipUpBlock2D(
|
443 |
+
num_layers=num_layers,
|
444 |
+
in_channels=in_channels,
|
445 |
+
out_channels=out_channels,
|
446 |
+
prev_output_channel=prev_output_channel,
|
447 |
+
temb_channels=temb_channels,
|
448 |
+
resolution_idx=resolution_idx,
|
449 |
+
dropout=dropout,
|
450 |
+
add_upsample=add_upsample,
|
451 |
+
resnet_eps=resnet_eps,
|
452 |
+
resnet_act_fn=resnet_act_fn,
|
453 |
+
attention_head_dim=attention_head_dim,
|
454 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
455 |
+
)
|
456 |
+
elif up_block_type == "UpDecoderBlock2D":
|
457 |
+
return UpDecoderBlock2D(
|
458 |
+
num_layers=num_layers,
|
459 |
+
in_channels=in_channels,
|
460 |
+
out_channels=out_channels,
|
461 |
+
resolution_idx=resolution_idx,
|
462 |
+
dropout=dropout,
|
463 |
+
add_upsample=add_upsample,
|
464 |
+
resnet_eps=resnet_eps,
|
465 |
+
resnet_act_fn=resnet_act_fn,
|
466 |
+
resnet_groups=resnet_groups,
|
467 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
468 |
+
temb_channels=temb_channels,
|
469 |
+
)
|
470 |
+
elif up_block_type == "AttnUpDecoderBlock2D":
|
471 |
+
return AttnUpDecoderBlock2D(
|
472 |
+
num_layers=num_layers,
|
473 |
+
in_channels=in_channels,
|
474 |
+
out_channels=out_channels,
|
475 |
+
resolution_idx=resolution_idx,
|
476 |
+
dropout=dropout,
|
477 |
+
add_upsample=add_upsample,
|
478 |
+
resnet_eps=resnet_eps,
|
479 |
+
resnet_act_fn=resnet_act_fn,
|
480 |
+
resnet_groups=resnet_groups,
|
481 |
+
attention_head_dim=attention_head_dim,
|
482 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
483 |
+
temb_channels=temb_channels,
|
484 |
+
)
|
485 |
+
elif up_block_type == "KUpBlock2D":
|
486 |
+
return KUpBlock2D(
|
487 |
+
num_layers=num_layers,
|
488 |
+
in_channels=in_channels,
|
489 |
+
out_channels=out_channels,
|
490 |
+
temb_channels=temb_channels,
|
491 |
+
resolution_idx=resolution_idx,
|
492 |
+
dropout=dropout,
|
493 |
+
add_upsample=add_upsample,
|
494 |
+
resnet_eps=resnet_eps,
|
495 |
+
resnet_act_fn=resnet_act_fn,
|
496 |
+
)
|
497 |
+
elif up_block_type == "KCrossAttnUpBlock2D":
|
498 |
+
return KCrossAttnUpBlock2D(
|
499 |
+
num_layers=num_layers,
|
500 |
+
in_channels=in_channels,
|
501 |
+
out_channels=out_channels,
|
502 |
+
temb_channels=temb_channels,
|
503 |
+
resolution_idx=resolution_idx,
|
504 |
+
dropout=dropout,
|
505 |
+
add_upsample=add_upsample,
|
506 |
+
resnet_eps=resnet_eps,
|
507 |
+
resnet_act_fn=resnet_act_fn,
|
508 |
+
cross_attention_dim=cross_attention_dim,
|
509 |
+
attention_head_dim=attention_head_dim,
|
510 |
+
)
|
511 |
+
|
512 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
513 |
+
|
514 |
+
|
515 |
+
class UNetMidBlock2D(nn.Module):
|
516 |
+
"""
|
517 |
+
A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
|
518 |
+
|
519 |
+
Args:
|
520 |
+
in_channels (`int`): The number of input channels.
|
521 |
+
temb_channels (`int`): The number of temporal embedding channels.
|
522 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
523 |
+
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
524 |
+
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
525 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
|
526 |
+
The type of normalization to apply to the time embeddings. This can help to improve the performance of the
|
527 |
+
model on tasks with long-range temporal dependencies.
|
528 |
+
resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
|
529 |
+
resnet_groups (`int`, *optional*, defaults to 32):
|
530 |
+
The number of groups to use in the group normalization layers of the resnet blocks.
|
531 |
+
attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
|
532 |
+
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
|
533 |
+
Whether to use pre-normalization for the resnet blocks.
|
534 |
+
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
|
535 |
+
attention_head_dim (`int`, *optional*, defaults to 1):
|
536 |
+
Dimension of a single attention head. The number of attention heads is determined based on this value and
|
537 |
+
the number of input channels.
|
538 |
+
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
|
539 |
+
|
540 |
+
Returns:
|
541 |
+
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
542 |
+
in_channels, height, width)`.
|
543 |
+
|
544 |
+
"""
|
545 |
+
|
546 |
+
def __init__(
|
547 |
+
self,
|
548 |
+
in_channels: int,
|
549 |
+
temb_channels: int,
|
550 |
+
dropout: float = 0.0,
|
551 |
+
num_layers: int = 1,
|
552 |
+
resnet_eps: float = 1e-6,
|
553 |
+
resnet_time_scale_shift: str = "default", # default, spatial
|
554 |
+
resnet_act_fn: str = "swish",
|
555 |
+
resnet_groups: int = 32,
|
556 |
+
attn_groups: Optional[int] = None,
|
557 |
+
resnet_pre_norm: bool = True,
|
558 |
+
add_attention: bool = True,
|
559 |
+
attention_head_dim: int = 1,
|
560 |
+
output_scale_factor: float = 1.0,
|
561 |
+
):
|
562 |
+
super().__init__()
|
563 |
+
resnet_groups = (
|
564 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
565 |
+
)
|
566 |
+
self.add_attention = add_attention
|
567 |
+
|
568 |
+
if attn_groups is None:
|
569 |
+
attn_groups = (
|
570 |
+
resnet_groups if resnet_time_scale_shift == "default" else None
|
571 |
+
)
|
572 |
+
|
573 |
+
# there is always at least one resnet
|
574 |
+
resnets = [
|
575 |
+
ResnetBlock2D(
|
576 |
+
in_channels=in_channels,
|
577 |
+
out_channels=in_channels,
|
578 |
+
temb_channels=temb_channels,
|
579 |
+
eps=resnet_eps,
|
580 |
+
groups=resnet_groups,
|
581 |
+
dropout=dropout,
|
582 |
+
time_embedding_norm=resnet_time_scale_shift,
|
583 |
+
non_linearity=resnet_act_fn,
|
584 |
+
output_scale_factor=output_scale_factor,
|
585 |
+
pre_norm=resnet_pre_norm,
|
586 |
+
)
|
587 |
+
]
|
588 |
+
attentions = []
|
589 |
+
|
590 |
+
if attention_head_dim is None:
|
591 |
+
logger.warn(
|
592 |
+
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
593 |
+
)
|
594 |
+
attention_head_dim = in_channels
|
595 |
+
|
596 |
+
for _ in range(num_layers):
|
597 |
+
if self.add_attention:
|
598 |
+
attentions.append(
|
599 |
+
Attention(
|
600 |
+
in_channels,
|
601 |
+
heads=in_channels // attention_head_dim,
|
602 |
+
dim_head=attention_head_dim,
|
603 |
+
rescale_output_factor=output_scale_factor,
|
604 |
+
eps=resnet_eps,
|
605 |
+
norm_num_groups=attn_groups,
|
606 |
+
spatial_norm_dim=temb_channels
|
607 |
+
if resnet_time_scale_shift == "spatial"
|
608 |
+
else None,
|
609 |
+
residual_connection=True,
|
610 |
+
bias=True,
|
611 |
+
upcast_softmax=True,
|
612 |
+
_from_deprecated_attn_block=True,
|
613 |
+
)
|
614 |
+
)
|
615 |
+
else:
|
616 |
+
attentions.append(None)
|
617 |
+
|
618 |
+
resnets.append(
|
619 |
+
ResnetBlock2D(
|
620 |
+
in_channels=in_channels,
|
621 |
+
out_channels=in_channels,
|
622 |
+
temb_channels=temb_channels,
|
623 |
+
eps=resnet_eps,
|
624 |
+
groups=resnet_groups,
|
625 |
+
dropout=dropout,
|
626 |
+
time_embedding_norm=resnet_time_scale_shift,
|
627 |
+
non_linearity=resnet_act_fn,
|
628 |
+
output_scale_factor=output_scale_factor,
|
629 |
+
pre_norm=resnet_pre_norm,
|
630 |
+
)
|
631 |
+
)
|
632 |
+
|
633 |
+
self.attentions = nn.ModuleList(attentions)
|
634 |
+
self.resnets = nn.ModuleList(resnets)
|
635 |
+
|
636 |
+
def forward(
|
637 |
+
self,
|
638 |
+
hidden_states: torch.FloatTensor,
|
639 |
+
temb: Optional[torch.FloatTensor] = None,
|
640 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
641 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
642 |
+
) -> torch.FloatTensor:
|
643 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
644 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
645 |
+
if attn is not None:
|
646 |
+
hidden_states = attn(
|
647 |
+
hidden_states,
|
648 |
+
temb=temb,
|
649 |
+
self_attn_block_embs=self_attn_block_embs,
|
650 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
651 |
+
)
|
652 |
+
hidden_states = resnet(hidden_states, temb)
|
653 |
+
|
654 |
+
return hidden_states
|
655 |
+
|
656 |
+
|
657 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
658 |
+
def __init__(
|
659 |
+
self,
|
660 |
+
in_channels: int,
|
661 |
+
temb_channels: int,
|
662 |
+
dropout: float = 0.0,
|
663 |
+
num_layers: int = 1,
|
664 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
665 |
+
resnet_eps: float = 1e-6,
|
666 |
+
resnet_time_scale_shift: str = "default",
|
667 |
+
resnet_act_fn: str = "swish",
|
668 |
+
resnet_groups: int = 32,
|
669 |
+
resnet_pre_norm: bool = True,
|
670 |
+
num_attention_heads: int = 1,
|
671 |
+
output_scale_factor: float = 1.0,
|
672 |
+
cross_attention_dim: int = 1280,
|
673 |
+
dual_cross_attention: bool = False,
|
674 |
+
use_linear_projection: bool = False,
|
675 |
+
upcast_attention: bool = False,
|
676 |
+
attention_type: str = "default",
|
677 |
+
):
|
678 |
+
super().__init__()
|
679 |
+
|
680 |
+
self.has_cross_attention = True
|
681 |
+
self.num_attention_heads = num_attention_heads
|
682 |
+
resnet_groups = (
|
683 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
684 |
+
)
|
685 |
+
|
686 |
+
# support for variable transformer layers per block
|
687 |
+
if isinstance(transformer_layers_per_block, int):
|
688 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
689 |
+
|
690 |
+
# there is always at least one resnet
|
691 |
+
resnets = [
|
692 |
+
ResnetBlock2D(
|
693 |
+
in_channels=in_channels,
|
694 |
+
out_channels=in_channels,
|
695 |
+
temb_channels=temb_channels,
|
696 |
+
eps=resnet_eps,
|
697 |
+
groups=resnet_groups,
|
698 |
+
dropout=dropout,
|
699 |
+
time_embedding_norm=resnet_time_scale_shift,
|
700 |
+
non_linearity=resnet_act_fn,
|
701 |
+
output_scale_factor=output_scale_factor,
|
702 |
+
pre_norm=resnet_pre_norm,
|
703 |
+
)
|
704 |
+
]
|
705 |
+
attentions = []
|
706 |
+
|
707 |
+
for i in range(num_layers):
|
708 |
+
if not dual_cross_attention:
|
709 |
+
attentions.append(
|
710 |
+
Transformer2DModel(
|
711 |
+
num_attention_heads,
|
712 |
+
in_channels // num_attention_heads,
|
713 |
+
in_channels=in_channels,
|
714 |
+
num_layers=transformer_layers_per_block[i],
|
715 |
+
cross_attention_dim=cross_attention_dim,
|
716 |
+
norm_num_groups=resnet_groups,
|
717 |
+
use_linear_projection=use_linear_projection,
|
718 |
+
upcast_attention=upcast_attention,
|
719 |
+
attention_type=attention_type,
|
720 |
+
)
|
721 |
+
)
|
722 |
+
else:
|
723 |
+
attentions.append(
|
724 |
+
DualTransformer2DModel(
|
725 |
+
num_attention_heads,
|
726 |
+
in_channels // num_attention_heads,
|
727 |
+
in_channels=in_channels,
|
728 |
+
num_layers=1,
|
729 |
+
cross_attention_dim=cross_attention_dim,
|
730 |
+
norm_num_groups=resnet_groups,
|
731 |
+
)
|
732 |
+
)
|
733 |
+
resnets.append(
|
734 |
+
ResnetBlock2D(
|
735 |
+
in_channels=in_channels,
|
736 |
+
out_channels=in_channels,
|
737 |
+
temb_channels=temb_channels,
|
738 |
+
eps=resnet_eps,
|
739 |
+
groups=resnet_groups,
|
740 |
+
dropout=dropout,
|
741 |
+
time_embedding_norm=resnet_time_scale_shift,
|
742 |
+
non_linearity=resnet_act_fn,
|
743 |
+
output_scale_factor=output_scale_factor,
|
744 |
+
pre_norm=resnet_pre_norm,
|
745 |
+
)
|
746 |
+
)
|
747 |
+
|
748 |
+
self.attentions = nn.ModuleList(attentions)
|
749 |
+
self.resnets = nn.ModuleList(resnets)
|
750 |
+
|
751 |
+
self.gradient_checkpointing = False
|
752 |
+
|
753 |
+
def forward(
|
754 |
+
self,
|
755 |
+
hidden_states: torch.FloatTensor,
|
756 |
+
temb: Optional[torch.FloatTensor] = None,
|
757 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
758 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
759 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
760 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
761 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
762 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
763 |
+
) -> torch.FloatTensor:
|
764 |
+
lora_scale = (
|
765 |
+
cross_attention_kwargs.get("scale", 1.0)
|
766 |
+
if cross_attention_kwargs is not None
|
767 |
+
else 1.0
|
768 |
+
)
|
769 |
+
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
770 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
771 |
+
if self.training and self.gradient_checkpointing:
|
772 |
+
|
773 |
+
def create_custom_forward(module, return_dict=None):
|
774 |
+
def custom_forward(*inputs):
|
775 |
+
if return_dict is not None:
|
776 |
+
return module(*inputs, return_dict=return_dict)
|
777 |
+
else:
|
778 |
+
return module(*inputs)
|
779 |
+
|
780 |
+
return custom_forward
|
781 |
+
|
782 |
+
ckpt_kwargs: Dict[str, Any] = (
|
783 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
784 |
+
)
|
785 |
+
hidden_states = attn(
|
786 |
+
hidden_states,
|
787 |
+
encoder_hidden_states=encoder_hidden_states,
|
788 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
789 |
+
attention_mask=attention_mask,
|
790 |
+
encoder_attention_mask=encoder_attention_mask,
|
791 |
+
return_dict=False,
|
792 |
+
self_attn_block_embs=self_attn_block_embs,
|
793 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
794 |
+
)[0]
|
795 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
796 |
+
create_custom_forward(resnet),
|
797 |
+
hidden_states,
|
798 |
+
temb,
|
799 |
+
**ckpt_kwargs,
|
800 |
+
)
|
801 |
+
else:
|
802 |
+
hidden_states = attn(
|
803 |
+
hidden_states,
|
804 |
+
encoder_hidden_states=encoder_hidden_states,
|
805 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
806 |
+
attention_mask=attention_mask,
|
807 |
+
encoder_attention_mask=encoder_attention_mask,
|
808 |
+
return_dict=False,
|
809 |
+
self_attn_block_embs=self_attn_block_embs,
|
810 |
+
)[0]
|
811 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
812 |
+
|
813 |
+
return hidden_states
|
814 |
+
|
815 |
+
|
816 |
+
class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
817 |
+
def __init__(
|
818 |
+
self,
|
819 |
+
in_channels: int,
|
820 |
+
temb_channels: int,
|
821 |
+
dropout: float = 0.0,
|
822 |
+
num_layers: int = 1,
|
823 |
+
resnet_eps: float = 1e-6,
|
824 |
+
resnet_time_scale_shift: str = "default",
|
825 |
+
resnet_act_fn: str = "swish",
|
826 |
+
resnet_groups: int = 32,
|
827 |
+
resnet_pre_norm: bool = True,
|
828 |
+
attention_head_dim: int = 1,
|
829 |
+
output_scale_factor: float = 1.0,
|
830 |
+
cross_attention_dim: int = 1280,
|
831 |
+
skip_time_act: bool = False,
|
832 |
+
only_cross_attention: bool = False,
|
833 |
+
cross_attention_norm: Optional[str] = None,
|
834 |
+
):
|
835 |
+
super().__init__()
|
836 |
+
|
837 |
+
self.has_cross_attention = True
|
838 |
+
|
839 |
+
self.attention_head_dim = attention_head_dim
|
840 |
+
resnet_groups = (
|
841 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
842 |
+
)
|
843 |
+
|
844 |
+
self.num_heads = in_channels // self.attention_head_dim
|
845 |
+
|
846 |
+
# there is always at least one resnet
|
847 |
+
resnets = [
|
848 |
+
ResnetBlock2D(
|
849 |
+
in_channels=in_channels,
|
850 |
+
out_channels=in_channels,
|
851 |
+
temb_channels=temb_channels,
|
852 |
+
eps=resnet_eps,
|
853 |
+
groups=resnet_groups,
|
854 |
+
dropout=dropout,
|
855 |
+
time_embedding_norm=resnet_time_scale_shift,
|
856 |
+
non_linearity=resnet_act_fn,
|
857 |
+
output_scale_factor=output_scale_factor,
|
858 |
+
pre_norm=resnet_pre_norm,
|
859 |
+
skip_time_act=skip_time_act,
|
860 |
+
)
|
861 |
+
]
|
862 |
+
attentions = []
|
863 |
+
|
864 |
+
for _ in range(num_layers):
|
865 |
+
processor = (
|
866 |
+
AttnAddedKVProcessor2_0()
|
867 |
+
if hasattr(F, "scaled_dot_product_attention")
|
868 |
+
else AttnAddedKVProcessor()
|
869 |
+
)
|
870 |
+
|
871 |
+
attentions.append(
|
872 |
+
Attention(
|
873 |
+
query_dim=in_channels,
|
874 |
+
cross_attention_dim=in_channels,
|
875 |
+
heads=self.num_heads,
|
876 |
+
dim_head=self.attention_head_dim,
|
877 |
+
added_kv_proj_dim=cross_attention_dim,
|
878 |
+
norm_num_groups=resnet_groups,
|
879 |
+
bias=True,
|
880 |
+
upcast_softmax=True,
|
881 |
+
only_cross_attention=only_cross_attention,
|
882 |
+
cross_attention_norm=cross_attention_norm,
|
883 |
+
processor=processor,
|
884 |
+
)
|
885 |
+
)
|
886 |
+
resnets.append(
|
887 |
+
ResnetBlock2D(
|
888 |
+
in_channels=in_channels,
|
889 |
+
out_channels=in_channels,
|
890 |
+
temb_channels=temb_channels,
|
891 |
+
eps=resnet_eps,
|
892 |
+
groups=resnet_groups,
|
893 |
+
dropout=dropout,
|
894 |
+
time_embedding_norm=resnet_time_scale_shift,
|
895 |
+
non_linearity=resnet_act_fn,
|
896 |
+
output_scale_factor=output_scale_factor,
|
897 |
+
pre_norm=resnet_pre_norm,
|
898 |
+
skip_time_act=skip_time_act,
|
899 |
+
)
|
900 |
+
)
|
901 |
+
|
902 |
+
self.attentions = nn.ModuleList(attentions)
|
903 |
+
self.resnets = nn.ModuleList(resnets)
|
904 |
+
|
905 |
+
def forward(
|
906 |
+
self,
|
907 |
+
hidden_states: torch.FloatTensor,
|
908 |
+
temb: Optional[torch.FloatTensor] = None,
|
909 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
910 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
911 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
912 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
913 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
914 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
915 |
+
) -> torch.FloatTensor:
|
916 |
+
cross_attention_kwargs = (
|
917 |
+
cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
918 |
+
)
|
919 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0)
|
920 |
+
|
921 |
+
if attention_mask is None:
|
922 |
+
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
923 |
+
mask = None if encoder_hidden_states is None else encoder_attention_mask
|
924 |
+
else:
|
925 |
+
# when attention_mask is defined: we don't even check for encoder_attention_mask.
|
926 |
+
# this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
|
927 |
+
# TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
|
928 |
+
# then we can simplify this whole if/else block to:
|
929 |
+
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
|
930 |
+
mask = attention_mask
|
931 |
+
|
932 |
+
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
933 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
934 |
+
# attn
|
935 |
+
hidden_states = attn(
|
936 |
+
hidden_states,
|
937 |
+
encoder_hidden_states=encoder_hidden_states,
|
938 |
+
attention_mask=mask,
|
939 |
+
**cross_attention_kwargs,
|
940 |
+
self_attn_block_embs=self_attn_block_embs,
|
941 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
942 |
+
)
|
943 |
+
|
944 |
+
# resnet
|
945 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
946 |
+
|
947 |
+
return hidden_states
|
948 |
+
|
949 |
+
|
950 |
+
class CrossAttnDownBlock2D(nn.Module):
|
951 |
+
print_idx = 0
|
952 |
+
|
953 |
+
def __init__(
|
954 |
+
self,
|
955 |
+
in_channels: int,
|
956 |
+
out_channels: int,
|
957 |
+
temb_channels: int,
|
958 |
+
dropout: float = 0.0,
|
959 |
+
num_layers: int = 1,
|
960 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
961 |
+
resnet_eps: float = 1e-6,
|
962 |
+
resnet_time_scale_shift: str = "default",
|
963 |
+
resnet_act_fn: str = "swish",
|
964 |
+
resnet_groups: int = 32,
|
965 |
+
resnet_pre_norm: bool = True,
|
966 |
+
num_attention_heads: int = 1,
|
967 |
+
cross_attention_dim: int = 1280,
|
968 |
+
output_scale_factor: float = 1.0,
|
969 |
+
downsample_padding: int = 1,
|
970 |
+
add_downsample: bool = True,
|
971 |
+
dual_cross_attention: bool = False,
|
972 |
+
use_linear_projection: bool = False,
|
973 |
+
only_cross_attention: bool = False,
|
974 |
+
upcast_attention: bool = False,
|
975 |
+
attention_type: str = "default",
|
976 |
+
):
|
977 |
+
super().__init__()
|
978 |
+
resnets = []
|
979 |
+
attentions = []
|
980 |
+
|
981 |
+
self.has_cross_attention = True
|
982 |
+
self.num_attention_heads = num_attention_heads
|
983 |
+
if isinstance(transformer_layers_per_block, int):
|
984 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
985 |
+
|
986 |
+
for i in range(num_layers):
|
987 |
+
in_channels = in_channels if i == 0 else out_channels
|
988 |
+
resnets.append(
|
989 |
+
ResnetBlock2D(
|
990 |
+
in_channels=in_channels,
|
991 |
+
out_channels=out_channels,
|
992 |
+
temb_channels=temb_channels,
|
993 |
+
eps=resnet_eps,
|
994 |
+
groups=resnet_groups,
|
995 |
+
dropout=dropout,
|
996 |
+
time_embedding_norm=resnet_time_scale_shift,
|
997 |
+
non_linearity=resnet_act_fn,
|
998 |
+
output_scale_factor=output_scale_factor,
|
999 |
+
pre_norm=resnet_pre_norm,
|
1000 |
+
)
|
1001 |
+
)
|
1002 |
+
if not dual_cross_attention:
|
1003 |
+
attentions.append(
|
1004 |
+
Transformer2DModel(
|
1005 |
+
num_attention_heads,
|
1006 |
+
out_channels // num_attention_heads,
|
1007 |
+
in_channels=out_channels,
|
1008 |
+
num_layers=transformer_layers_per_block[i],
|
1009 |
+
cross_attention_dim=cross_attention_dim,
|
1010 |
+
norm_num_groups=resnet_groups,
|
1011 |
+
use_linear_projection=use_linear_projection,
|
1012 |
+
only_cross_attention=only_cross_attention,
|
1013 |
+
upcast_attention=upcast_attention,
|
1014 |
+
attention_type=attention_type,
|
1015 |
+
)
|
1016 |
+
)
|
1017 |
+
else:
|
1018 |
+
attentions.append(
|
1019 |
+
DualTransformer2DModel(
|
1020 |
+
num_attention_heads,
|
1021 |
+
out_channels // num_attention_heads,
|
1022 |
+
in_channels=out_channels,
|
1023 |
+
num_layers=1,
|
1024 |
+
cross_attention_dim=cross_attention_dim,
|
1025 |
+
norm_num_groups=resnet_groups,
|
1026 |
+
)
|
1027 |
+
)
|
1028 |
+
self.attentions = nn.ModuleList(attentions)
|
1029 |
+
self.resnets = nn.ModuleList(resnets)
|
1030 |
+
|
1031 |
+
if add_downsample:
|
1032 |
+
self.downsamplers = nn.ModuleList(
|
1033 |
+
[
|
1034 |
+
Downsample2D(
|
1035 |
+
out_channels,
|
1036 |
+
use_conv=True,
|
1037 |
+
out_channels=out_channels,
|
1038 |
+
padding=downsample_padding,
|
1039 |
+
name="op",
|
1040 |
+
)
|
1041 |
+
]
|
1042 |
+
)
|
1043 |
+
else:
|
1044 |
+
self.downsamplers = None
|
1045 |
+
|
1046 |
+
self.gradient_checkpointing = False
|
1047 |
+
|
1048 |
+
def forward(
|
1049 |
+
self,
|
1050 |
+
hidden_states: torch.FloatTensor,
|
1051 |
+
temb: Optional[torch.FloatTensor] = None,
|
1052 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1053 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1054 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1055 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1056 |
+
additional_residuals: Optional[torch.FloatTensor] = None,
|
1057 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
1058 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
1059 |
+
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
1060 |
+
output_states = ()
|
1061 |
+
|
1062 |
+
lora_scale = (
|
1063 |
+
cross_attention_kwargs.get("scale", 1.0)
|
1064 |
+
if cross_attention_kwargs is not None
|
1065 |
+
else 1.0
|
1066 |
+
)
|
1067 |
+
|
1068 |
+
blocks = list(zip(self.resnets, self.attentions))
|
1069 |
+
|
1070 |
+
for i, (resnet, attn) in enumerate(blocks):
|
1071 |
+
if self.training and self.gradient_checkpointing:
|
1072 |
+
|
1073 |
+
def create_custom_forward(module, return_dict=None):
|
1074 |
+
def custom_forward(*inputs):
|
1075 |
+
if return_dict is not None:
|
1076 |
+
return module(*inputs, return_dict=return_dict)
|
1077 |
+
else:
|
1078 |
+
return module(*inputs)
|
1079 |
+
|
1080 |
+
return custom_forward
|
1081 |
+
|
1082 |
+
ckpt_kwargs: Dict[str, Any] = (
|
1083 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1084 |
+
)
|
1085 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1086 |
+
create_custom_forward(resnet),
|
1087 |
+
hidden_states,
|
1088 |
+
temb,
|
1089 |
+
**ckpt_kwargs,
|
1090 |
+
)
|
1091 |
+
if self.print_idx == 0:
|
1092 |
+
logger.debug(f"unet3d after resnet {hidden_states.mean()}")
|
1093 |
+
|
1094 |
+
hidden_states = attn(
|
1095 |
+
hidden_states,
|
1096 |
+
encoder_hidden_states=encoder_hidden_states,
|
1097 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1098 |
+
attention_mask=attention_mask,
|
1099 |
+
encoder_attention_mask=encoder_attention_mask,
|
1100 |
+
return_dict=False,
|
1101 |
+
self_attn_block_embs=self_attn_block_embs,
|
1102 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
1103 |
+
)[0]
|
1104 |
+
else:
|
1105 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
1106 |
+
if self.print_idx == 0:
|
1107 |
+
logger.debug(f"unet3d after resnet {hidden_states.mean()}")
|
1108 |
+
hidden_states = attn(
|
1109 |
+
hidden_states,
|
1110 |
+
encoder_hidden_states=encoder_hidden_states,
|
1111 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1112 |
+
attention_mask=attention_mask,
|
1113 |
+
encoder_attention_mask=encoder_attention_mask,
|
1114 |
+
return_dict=False,
|
1115 |
+
self_attn_block_embs=self_attn_block_embs,
|
1116 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
1117 |
+
)[0]
|
1118 |
+
|
1119 |
+
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
1120 |
+
if i == len(blocks) - 1 and additional_residuals is not None:
|
1121 |
+
hidden_states = hidden_states + additional_residuals
|
1122 |
+
|
1123 |
+
output_states = output_states + (hidden_states,)
|
1124 |
+
|
1125 |
+
if self.downsamplers is not None:
|
1126 |
+
for downsampler in self.downsamplers:
|
1127 |
+
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
1128 |
+
|
1129 |
+
output_states = output_states + (hidden_states,)
|
1130 |
+
|
1131 |
+
self.print_idx += 1
|
1132 |
+
return hidden_states, output_states
|
1133 |
+
|
1134 |
+
|
1135 |
+
class DownBlock2D(nn.Module):
|
1136 |
+
def __init__(
|
1137 |
+
self,
|
1138 |
+
in_channels: int,
|
1139 |
+
out_channels: int,
|
1140 |
+
temb_channels: int,
|
1141 |
+
dropout: float = 0.0,
|
1142 |
+
num_layers: int = 1,
|
1143 |
+
resnet_eps: float = 1e-6,
|
1144 |
+
resnet_time_scale_shift: str = "default",
|
1145 |
+
resnet_act_fn: str = "swish",
|
1146 |
+
resnet_groups: int = 32,
|
1147 |
+
resnet_pre_norm: bool = True,
|
1148 |
+
output_scale_factor: float = 1.0,
|
1149 |
+
add_downsample: bool = True,
|
1150 |
+
downsample_padding: int = 1,
|
1151 |
+
):
|
1152 |
+
super().__init__()
|
1153 |
+
resnets = []
|
1154 |
+
|
1155 |
+
for i in range(num_layers):
|
1156 |
+
in_channels = in_channels if i == 0 else out_channels
|
1157 |
+
resnets.append(
|
1158 |
+
ResnetBlock2D(
|
1159 |
+
in_channels=in_channels,
|
1160 |
+
out_channels=out_channels,
|
1161 |
+
temb_channels=temb_channels,
|
1162 |
+
eps=resnet_eps,
|
1163 |
+
groups=resnet_groups,
|
1164 |
+
dropout=dropout,
|
1165 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1166 |
+
non_linearity=resnet_act_fn,
|
1167 |
+
output_scale_factor=output_scale_factor,
|
1168 |
+
pre_norm=resnet_pre_norm,
|
1169 |
+
)
|
1170 |
+
)
|
1171 |
+
|
1172 |
+
self.resnets = nn.ModuleList(resnets)
|
1173 |
+
|
1174 |
+
if add_downsample:
|
1175 |
+
self.downsamplers = nn.ModuleList(
|
1176 |
+
[
|
1177 |
+
Downsample2D(
|
1178 |
+
out_channels,
|
1179 |
+
use_conv=True,
|
1180 |
+
out_channels=out_channels,
|
1181 |
+
padding=downsample_padding,
|
1182 |
+
name="op",
|
1183 |
+
)
|
1184 |
+
]
|
1185 |
+
)
|
1186 |
+
else:
|
1187 |
+
self.downsamplers = None
|
1188 |
+
|
1189 |
+
self.gradient_checkpointing = False
|
1190 |
+
|
1191 |
+
def forward(
|
1192 |
+
self,
|
1193 |
+
hidden_states: torch.FloatTensor,
|
1194 |
+
temb: Optional[torch.FloatTensor] = None,
|
1195 |
+
scale: float = 1.0,
|
1196 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
1197 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
1198 |
+
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
1199 |
+
output_states = ()
|
1200 |
+
|
1201 |
+
for resnet in self.resnets:
|
1202 |
+
if self.training and self.gradient_checkpointing:
|
1203 |
+
|
1204 |
+
def create_custom_forward(module):
|
1205 |
+
def custom_forward(*inputs):
|
1206 |
+
return module(*inputs)
|
1207 |
+
|
1208 |
+
return custom_forward
|
1209 |
+
|
1210 |
+
if is_torch_version(">=", "1.11.0"):
|
1211 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1212 |
+
create_custom_forward(resnet),
|
1213 |
+
hidden_states,
|
1214 |
+
temb,
|
1215 |
+
use_reentrant=False,
|
1216 |
+
)
|
1217 |
+
else:
|
1218 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1219 |
+
create_custom_forward(resnet), hidden_states, temb
|
1220 |
+
)
|
1221 |
+
else:
|
1222 |
+
hidden_states = resnet(hidden_states, temb, scale=scale)
|
1223 |
+
|
1224 |
+
output_states = output_states + (hidden_states,)
|
1225 |
+
|
1226 |
+
if self.downsamplers is not None:
|
1227 |
+
for downsampler in self.downsamplers:
|
1228 |
+
hidden_states = downsampler(hidden_states, scale=scale)
|
1229 |
+
|
1230 |
+
output_states = output_states + (hidden_states,)
|
1231 |
+
|
1232 |
+
return hidden_states, output_states
|
1233 |
+
|
1234 |
+
|
1235 |
+
class CrossAttnUpBlock2D(nn.Module):
|
1236 |
+
def __init__(
|
1237 |
+
self,
|
1238 |
+
in_channels: int,
|
1239 |
+
out_channels: int,
|
1240 |
+
prev_output_channel: int,
|
1241 |
+
temb_channels: int,
|
1242 |
+
resolution_idx: Optional[int] = None,
|
1243 |
+
dropout: float = 0.0,
|
1244 |
+
num_layers: int = 1,
|
1245 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
1246 |
+
resnet_eps: float = 1e-6,
|
1247 |
+
resnet_time_scale_shift: str = "default",
|
1248 |
+
resnet_act_fn: str = "swish",
|
1249 |
+
resnet_groups: int = 32,
|
1250 |
+
resnet_pre_norm: bool = True,
|
1251 |
+
num_attention_heads: int = 1,
|
1252 |
+
cross_attention_dim: int = 1280,
|
1253 |
+
output_scale_factor: float = 1.0,
|
1254 |
+
add_upsample: bool = True,
|
1255 |
+
dual_cross_attention: bool = False,
|
1256 |
+
use_linear_projection: bool = False,
|
1257 |
+
only_cross_attention: bool = False,
|
1258 |
+
upcast_attention: bool = False,
|
1259 |
+
attention_type: str = "default",
|
1260 |
+
):
|
1261 |
+
super().__init__()
|
1262 |
+
resnets = []
|
1263 |
+
attentions = []
|
1264 |
+
|
1265 |
+
self.has_cross_attention = True
|
1266 |
+
self.num_attention_heads = num_attention_heads
|
1267 |
+
|
1268 |
+
if isinstance(transformer_layers_per_block, int):
|
1269 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
1270 |
+
|
1271 |
+
for i in range(num_layers):
|
1272 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1273 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1274 |
+
|
1275 |
+
resnets.append(
|
1276 |
+
ResnetBlock2D(
|
1277 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1278 |
+
out_channels=out_channels,
|
1279 |
+
temb_channels=temb_channels,
|
1280 |
+
eps=resnet_eps,
|
1281 |
+
groups=resnet_groups,
|
1282 |
+
dropout=dropout,
|
1283 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1284 |
+
non_linearity=resnet_act_fn,
|
1285 |
+
output_scale_factor=output_scale_factor,
|
1286 |
+
pre_norm=resnet_pre_norm,
|
1287 |
+
)
|
1288 |
+
)
|
1289 |
+
if not dual_cross_attention:
|
1290 |
+
attentions.append(
|
1291 |
+
Transformer2DModel(
|
1292 |
+
num_attention_heads,
|
1293 |
+
out_channels // num_attention_heads,
|
1294 |
+
in_channels=out_channels,
|
1295 |
+
num_layers=transformer_layers_per_block[i],
|
1296 |
+
cross_attention_dim=cross_attention_dim,
|
1297 |
+
norm_num_groups=resnet_groups,
|
1298 |
+
use_linear_projection=use_linear_projection,
|
1299 |
+
only_cross_attention=only_cross_attention,
|
1300 |
+
upcast_attention=upcast_attention,
|
1301 |
+
attention_type=attention_type,
|
1302 |
+
)
|
1303 |
+
)
|
1304 |
+
else:
|
1305 |
+
attentions.append(
|
1306 |
+
DualTransformer2DModel(
|
1307 |
+
num_attention_heads,
|
1308 |
+
out_channels // num_attention_heads,
|
1309 |
+
in_channels=out_channels,
|
1310 |
+
num_layers=1,
|
1311 |
+
cross_attention_dim=cross_attention_dim,
|
1312 |
+
norm_num_groups=resnet_groups,
|
1313 |
+
)
|
1314 |
+
)
|
1315 |
+
self.attentions = nn.ModuleList(attentions)
|
1316 |
+
self.resnets = nn.ModuleList(resnets)
|
1317 |
+
|
1318 |
+
if add_upsample:
|
1319 |
+
self.upsamplers = nn.ModuleList(
|
1320 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
|
1321 |
+
)
|
1322 |
+
else:
|
1323 |
+
self.upsamplers = None
|
1324 |
+
|
1325 |
+
self.gradient_checkpointing = False
|
1326 |
+
self.resolution_idx = resolution_idx
|
1327 |
+
|
1328 |
+
def forward(
|
1329 |
+
self,
|
1330 |
+
hidden_states: torch.FloatTensor,
|
1331 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
1332 |
+
temb: Optional[torch.FloatTensor] = None,
|
1333 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1334 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1335 |
+
upsample_size: Optional[int] = None,
|
1336 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1337 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1338 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
1339 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
1340 |
+
) -> torch.FloatTensor:
|
1341 |
+
lora_scale = (
|
1342 |
+
cross_attention_kwargs.get("scale", 1.0)
|
1343 |
+
if cross_attention_kwargs is not None
|
1344 |
+
else 1.0
|
1345 |
+
)
|
1346 |
+
is_freeu_enabled = (
|
1347 |
+
getattr(self, "s1", None)
|
1348 |
+
and getattr(self, "s2", None)
|
1349 |
+
and getattr(self, "b1", None)
|
1350 |
+
and getattr(self, "b2", None)
|
1351 |
+
)
|
1352 |
+
|
1353 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1354 |
+
# pop res hidden states
|
1355 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1356 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1357 |
+
|
1358 |
+
# FreeU: Only operate on the first two stages
|
1359 |
+
if is_freeu_enabled:
|
1360 |
+
hidden_states, res_hidden_states = apply_freeu(
|
1361 |
+
self.resolution_idx,
|
1362 |
+
hidden_states,
|
1363 |
+
res_hidden_states,
|
1364 |
+
s1=self.s1,
|
1365 |
+
s2=self.s2,
|
1366 |
+
b1=self.b1,
|
1367 |
+
b2=self.b2,
|
1368 |
+
)
|
1369 |
+
|
1370 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1371 |
+
|
1372 |
+
if self.training and self.gradient_checkpointing:
|
1373 |
+
|
1374 |
+
def create_custom_forward(module, return_dict=None):
|
1375 |
+
def custom_forward(*inputs):
|
1376 |
+
if return_dict is not None:
|
1377 |
+
return module(*inputs, return_dict=return_dict)
|
1378 |
+
else:
|
1379 |
+
return module(*inputs)
|
1380 |
+
|
1381 |
+
return custom_forward
|
1382 |
+
|
1383 |
+
ckpt_kwargs: Dict[str, Any] = (
|
1384 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1385 |
+
)
|
1386 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1387 |
+
create_custom_forward(resnet),
|
1388 |
+
hidden_states,
|
1389 |
+
temb,
|
1390 |
+
**ckpt_kwargs,
|
1391 |
+
)
|
1392 |
+
hidden_states = attn(
|
1393 |
+
hidden_states,
|
1394 |
+
encoder_hidden_states=encoder_hidden_states,
|
1395 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1396 |
+
attention_mask=attention_mask,
|
1397 |
+
encoder_attention_mask=encoder_attention_mask,
|
1398 |
+
return_dict=False,
|
1399 |
+
self_attn_block_embs=self_attn_block_embs,
|
1400 |
+
self_attn_block_embs_mode=self_attn_block_embs_mode,
|
1401 |
+
)[0]
|
1402 |
+
else:
|
1403 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
1404 |
+
hidden_states = attn(
|
1405 |
+
hidden_states,
|
1406 |
+
encoder_hidden_states=encoder_hidden_states,
|
1407 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1408 |
+
attention_mask=attention_mask,
|
1409 |
+
encoder_attention_mask=encoder_attention_mask,
|
1410 |
+
return_dict=False,
|
1411 |
+
self_attn_block_embs=self_attn_block_embs,
|
1412 |
+
)[0]
|
1413 |
+
|
1414 |
+
if self.upsamplers is not None:
|
1415 |
+
for upsampler in self.upsamplers:
|
1416 |
+
hidden_states = upsampler(
|
1417 |
+
hidden_states, upsample_size, scale=lora_scale
|
1418 |
+
)
|
1419 |
+
|
1420 |
+
return hidden_states
|
1421 |
+
|
1422 |
+
|
1423 |
+
class UpBlock2D(nn.Module):
|
1424 |
+
def __init__(
|
1425 |
+
self,
|
1426 |
+
in_channels: int,
|
1427 |
+
prev_output_channel: int,
|
1428 |
+
out_channels: int,
|
1429 |
+
temb_channels: int,
|
1430 |
+
resolution_idx: Optional[int] = None,
|
1431 |
+
dropout: float = 0.0,
|
1432 |
+
num_layers: int = 1,
|
1433 |
+
resnet_eps: float = 1e-6,
|
1434 |
+
resnet_time_scale_shift: str = "default",
|
1435 |
+
resnet_act_fn: str = "swish",
|
1436 |
+
resnet_groups: int = 32,
|
1437 |
+
resnet_pre_norm: bool = True,
|
1438 |
+
output_scale_factor: float = 1.0,
|
1439 |
+
add_upsample: bool = True,
|
1440 |
+
):
|
1441 |
+
super().__init__()
|
1442 |
+
resnets = []
|
1443 |
+
|
1444 |
+
for i in range(num_layers):
|
1445 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1446 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1447 |
+
|
1448 |
+
resnets.append(
|
1449 |
+
ResnetBlock2D(
|
1450 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1451 |
+
out_channels=out_channels,
|
1452 |
+
temb_channels=temb_channels,
|
1453 |
+
eps=resnet_eps,
|
1454 |
+
groups=resnet_groups,
|
1455 |
+
dropout=dropout,
|
1456 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1457 |
+
non_linearity=resnet_act_fn,
|
1458 |
+
output_scale_factor=output_scale_factor,
|
1459 |
+
pre_norm=resnet_pre_norm,
|
1460 |
+
)
|
1461 |
+
)
|
1462 |
+
|
1463 |
+
self.resnets = nn.ModuleList(resnets)
|
1464 |
+
|
1465 |
+
if add_upsample:
|
1466 |
+
self.upsamplers = nn.ModuleList(
|
1467 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
|
1468 |
+
)
|
1469 |
+
else:
|
1470 |
+
self.upsamplers = None
|
1471 |
+
|
1472 |
+
self.gradient_checkpointing = False
|
1473 |
+
self.resolution_idx = resolution_idx
|
1474 |
+
|
1475 |
+
def forward(
|
1476 |
+
self,
|
1477 |
+
hidden_states: torch.FloatTensor,
|
1478 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
1479 |
+
temb: Optional[torch.FloatTensor] = None,
|
1480 |
+
upsample_size: Optional[int] = None,
|
1481 |
+
scale: float = 1.0,
|
1482 |
+
self_attn_block_embs: Optional[List[torch.Tensor]] = None,
|
1483 |
+
self_attn_block_embs_mode: Literal["read", "write"] = "write",
|
1484 |
+
) -> torch.FloatTensor:
|
1485 |
+
is_freeu_enabled = (
|
1486 |
+
getattr(self, "s1", None)
|
1487 |
+
and getattr(self, "s2", None)
|
1488 |
+
and getattr(self, "b1", None)
|
1489 |
+
and getattr(self, "b2", None)
|
1490 |
+
)
|
1491 |
+
|
1492 |
+
for resnet in self.resnets:
|
1493 |
+
# pop res hidden states
|
1494 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1495 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1496 |
+
|
1497 |
+
# FreeU: Only operate on the first two stages
|
1498 |
+
if is_freeu_enabled:
|
1499 |
+
hidden_states, res_hidden_states = apply_freeu(
|
1500 |
+
self.resolution_idx,
|
1501 |
+
hidden_states,
|
1502 |
+
res_hidden_states,
|
1503 |
+
s1=self.s1,
|
1504 |
+
s2=self.s2,
|
1505 |
+
b1=self.b1,
|
1506 |
+
b2=self.b2,
|
1507 |
+
)
|
1508 |
+
|
1509 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1510 |
+
|
1511 |
+
if self.training and self.gradient_checkpointing:
|
1512 |
+
|
1513 |
+
def create_custom_forward(module):
|
1514 |
+
def custom_forward(*inputs):
|
1515 |
+
return module(*inputs)
|
1516 |
+
|
1517 |
+
return custom_forward
|
1518 |
+
|
1519 |
+
if is_torch_version(">=", "1.11.0"):
|
1520 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1521 |
+
create_custom_forward(resnet),
|
1522 |
+
hidden_states,
|
1523 |
+
temb,
|
1524 |
+
use_reentrant=False,
|
1525 |
+
)
|
1526 |
+
else:
|
1527 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1528 |
+
create_custom_forward(resnet), hidden_states, temb
|
1529 |
+
)
|
1530 |
+
else:
|
1531 |
+
hidden_states = resnet(hidden_states, temb, scale=scale)
|
1532 |
+
|
1533 |
+
if self.upsamplers is not None:
|
1534 |
+
for upsampler in self.upsamplers:
|
1535 |
+
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
1536 |
+
|
1537 |
+
return hidden_states
|