jmanhype committed on
Commit 06e9d12 · 0 Parent(s)

Initial commit without binary files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.

Files changed (50)
  1. .gitignore +65 -0
  2. .gitmodules +11 -0
  3. CHANGES +5 -0
  4. Dockerfile +18 -0
  5. LICENSE +175 -0
  6. MMCM +1 -0
  7. README-zh.md +465 -0
  8. README.md +37 -0
  9. configs/model/T2I_all_model.py +15 -0
  10. configs/model/ip_adapter.py +66 -0
  11. configs/model/lcm_model.py +17 -0
  12. configs/model/motion_model.py +22 -0
  13. configs/model/negative_prompt.py +32 -0
  14. configs/model/referencenet.py +14 -0
  15. configs/tasks/example.yaml +210 -0
  16. controlnet_aux +1 -0
  17. data/models/musev_structure.png +0 -0
  18. data/models/parallel_denoise.png +0 -0
  19. diffusers +1 -0
  20. environment.yml +312 -0
  21. musev/__init__.py +9 -0
  22. musev/auto_prompt/__init__.py +0 -0
  23. musev/auto_prompt/attributes/__init__.py +8 -0
  24. musev/auto_prompt/attributes/attr2template.py +127 -0
  25. musev/auto_prompt/attributes/attributes.py +227 -0
  26. musev/auto_prompt/attributes/human.py +424 -0
  27. musev/auto_prompt/attributes/render.py +33 -0
  28. musev/auto_prompt/attributes/style.py +12 -0
  29. musev/auto_prompt/human.py +40 -0
  30. musev/auto_prompt/load_template.py +37 -0
  31. musev/auto_prompt/util.py +25 -0
  32. musev/data/__init__.py +0 -0
  33. musev/data/data_util.py +681 -0
  34. musev/logging.conf +32 -0
  35. musev/models/__init__.py +3 -0
  36. musev/models/attention.py +431 -0
  37. musev/models/attention_processor.py +750 -0
  38. musev/models/controlnet.py +399 -0
  39. musev/models/embeddings.py +87 -0
  40. musev/models/facein_loader.py +120 -0
  41. musev/models/ip_adapter_face_loader.py +179 -0
  42. musev/models/ip_adapter_loader.py +340 -0
  43. musev/models/referencenet.py +1216 -0
  44. musev/models/referencenet_loader.py +124 -0
  45. musev/models/resnet.py +135 -0
  46. musev/models/super_model.py +253 -0
  47. musev/models/temporal_transformer.py +308 -0
  48. musev/models/text_model.py +40 -0
  49. musev/models/transformer_2d.py +445 -0
  50. musev/models/unet_2d_blocks.py +1537 -0
.gitignore ADDED
@@ -0,0 +1,65 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Checkpoints
+ checkpoints/
+
+ # Logs
+ *.log
+ logs/
+ tensorboard/
+
+ # Environment
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Large files
+ data/result_video/
+ *.mp4
+ *.png
+ *.jpg
+ *.jpeg
+ *.gif
+ *.webp
+ *.avi
+ *.mov
+ *.mkv
+ *.flv
+ *.wmv
+
+ # Demo and source files
+ data/demo/
+ data/source_video/
+ data/images/
.gitmodules ADDED
@@ -0,0 +1,11 @@
+ [submodule "MMCM"]
+     path = MMCM
+     url = https://github.com/TMElyralab/MMCM.git
+ [submodule "controlnet_aux"]
+     path = controlnet_aux
+     url = https://github.com/TMElyralab/controlnet_aux.git
+     branch = tme
+ [submodule "diffusers"]
+     path = diffusers
+     url = https://github.com/TMElyralab/diffusers.git
+     branch = tme
CHANGES ADDED
@@ -0,0 +1,5 @@
+ Version 1.0.0 (2024.03.27)
+
+ * init musev, support video generation with text and image
+ * controlnet_aux: enrich interface and function of dwpose.
+ * diffusers: controlnet support latent instead of images only.
Dockerfile ADDED
@@ -0,0 +1,18 @@
+ FROM anchorxia/musev:1.0.0
+
+ # MAINTAINER maintainer information
+ LABEL MAINTAINER="anchorxia"
+ LABEL Email="[email protected]"
+ LABEL Description="musev gpu runtime image, base docker is pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel"
+ ARG DEBIAN_FRONTEND=noninteractive
+
+ USER root
+
+ SHELL ["/bin/bash", "--login", "-c"]
+
+ RUN . /opt/conda/etc/profile.d/conda.sh \
+     && echo "source activate musev" >> ~/.bashrc \
+     && conda activate musev \
+     && conda env list \
+     && pip --no-cache-dir install cuid gradio==4.12 spaces
+ USER root
LICENSE ADDED
@@ -0,0 +1,175 @@
1
+
2
+ MIT License
3
+
4
+ Copyright (c) 2024 Tencent Music Entertainment Group
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
23
+
24
+
25
+ Other dependencies and licenses:
26
+
27
+
28
+ Open Source Software Licensed under the MIT License:
29
+ --------------------------------------------------------------------
30
+ 1. BriVL-BUA-applications
31
+ Files:https://github.com/chuhaojin/BriVL-BUA-applications
32
+ License:MIT License
33
+ Copyright (c) 2021 chuhaojin
34
+ For details:https://github.com/chuhaojin/BriVL-BUA-applications/blob/master/LICENSE
35
+
36
+ 2.deep-person-reid
37
+ Files:https://github.com/KaiyangZhou/deep-person-reid
38
+ License:MIT License
39
+ Copyright (c) 2018 Kaiyang Zhou
40
+ For details:https://github.com/KaiyangZhou/deep-person-reid/blob/master/LICENSE
41
+
42
+
43
+ Open Source Software Licensed under the Apache License Version 2.0:
44
+ --------------------------------------------------------------------
45
+ 1. diffusers
46
+ Files:https://github.com/huggingface/diffusers
47
+ License:Apache License 2.0
48
+ Copyright 2024 The HuggingFace Team. All rights reserved.
49
+ For details:https://github.com/huggingface/diffusers/blob/main/LICENSE
50
+ https://github.com/huggingface/diffusers/blob/main/setup.py
51
+
52
+
53
+ 2. controlnet_aux
54
+ Files: https://github.com/huggingface/controlnet_aux
55
+ License: Apache License 2.0
56
+ Copyright 2023 The HuggingFace Team. All rights reserved.
57
+ For details: https://github.com/huggingface/controlnet_aux/blob/master/LICENSE.txt
58
+ https://github.com/huggingface/controlnet_aux/blob/master/setup.py
59
+
60
+ 3. decord
61
+ Files:https://github.com/dmlc/decord
62
+ License:Apache License 2.0
63
+ For details:https://github.com/dmlc/decord/blob/master/LICENSE
64
+
65
+
66
+ Terms of the Apache License Version 2.0:
67
+ --------------------------------------------------------------------
68
+ Apache License
69
+
70
+ Version 2.0, January 2004
71
+
72
+ http://www.apache.org/licenses/
73
+
74
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
75
+ 1. Definitions.
76
+
77
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
78
+
79
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
80
+
81
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
82
+
83
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
84
+
85
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
86
+
87
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
88
+
89
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
90
+
91
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
92
+
93
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
94
+
95
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
96
+
97
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
98
+
99
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
100
+
101
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
102
+
103
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
104
+
105
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
106
+
107
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
108
+
109
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
110
+
111
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
112
+
113
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
114
+
115
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
116
+
117
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
118
+
119
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
120
+
121
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
122
+
123
+ END OF TERMS AND CONDITIONS
124
+
125
+
126
+
127
+ Open Source Software Licensed under the BSD 3-Clause License:
128
+ --------------------------------------------------------------------
129
+ 1. pynvml
130
+ Files:https://github.com/gpuopenanalytics/pynvml/tree/master
131
+ License:BSD 3-Clause
132
+ Copyright (c) 2011-2021, NVIDIA Corporation.
133
+ All rights reserved.
134
+ For details:https://github.com/gpuopenanalytics/pynvml/blob/master/LICENSE.txt
135
+
136
+
137
+ Terms of the BSD 3-Clause License:
138
+ --------------------------------------------------------------------
139
+ Redistribution and use in source and binary forms, with or without
140
+ modification, are permitted provided that the following conditions are met:
141
+
142
+ * Redistributions of source code must retain the above copyright notice, this
143
+ list of conditions and the following disclaimer.
144
+
145
+ * Redistributions in binary form must reproduce the above copyright notice,
146
+ this list of conditions and the following disclaimer in the documentation
147
+ and/or other materials provided with the distribution.
148
+
149
+ * Neither the name of the copyright holder nor the names of its
150
+ contributors may be used to endorse or promote products derived from
151
+ this software without specific prior written permission.
152
+
153
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
154
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
155
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
156
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
157
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
158
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
159
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
160
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
161
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
162
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
163
+
164
+
165
+ Other Open Source Software:
166
+ --------------------------------------------------------------------
167
+ 1.SceneSeg
168
+ Files:https://github.com/AnyiRao/SceneSeg/tree/master
169
+
170
+
171
+
172
+
173
+
174
+
175
+
MMCM ADDED
@@ -0,0 +1 @@
+ Subproject commit 1e2b6e6a848f0f116e8acaf0621c2ee64d3642ce
README-zh.md ADDED
@@ -0,0 +1,465 @@
+ # MuseV [English](README.md) [中文](README-zh.md)
+
+ <font size=5>MuseV: Infinite-length and High-Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising.
+ </br>
+ Zhiqiang Xia <sup>\*</sup>,
+ Zhaokang Chen<sup>\*</sup>,
+ Bin Wu<sup>†</sup>,
+ Chao Li,
+ Kwok-Wai Hung,
+ Chao Zhan,
+ Yingjie He,
+ Wenjiang Zhou
+ (<sup>*</sup>co-first author, <sup>†</sup>Corresponding Author, [email protected])
+ </font>
+
+ **[github](https://github.com/TMElyralab/MuseV)** **[huggingface](https://huggingface.co/TMElyralab/MuseV)** **[HuggingfaceSpace](https://huggingface.co/spaces/AnchorFake/MuseVDemo)** **[project](https://tmelyralab.github.io/MuseV_Page/)** **Technical report (coming soon)**
+
+
+ In March 2023 we came to believe that diffusion models can simulate the world, and we started building a visual world simulator on top of them. `MuseV` was a milestone reached around July 2023. Inspired by the progress of Sora, we decided to open-source MuseV. MuseV grew on the shoulders of open source, and we hope to give back to the community in the same way. Next, we will move on to the promising diffusion+transformer scheme.
+
+ We have released <a href="https://github.com/TMElyralab/MuseTalk" style="font-size:24px; color:red;">MuseTalk</a>. `MuseTalk` is a real-time, high-quality lip-sync model that can be combined with `MuseV` to build a complete `virtual human generation solution`. Stay tuned!
+
+ :new: We have newly released <a href="https://github.com/TMElyralab/MusePose" style="font-size:24px; color:red;">MusePose</a>. MusePose is an image-to-video generation framework for virtual humans that generates videos from control signals (poses). Together with MuseV and MuseTalk, we hope the community can join us on the way toward the vision of an end-to-end virtual human with full-body movement and interaction ability.
+
+ # Overview
+
+ `MuseV` is a diffusion-based virtual human video generation framework with the following features:
+
+ 1. Supports infinite-length generation with a novel visual-conditioned parallel denoising scheme, free of error accumulation, and especially well suited to scenes with a fixed camera.
+ 1. Provides pretrained virtual human video generation models trained on human-centric datasets.
+ 1. Supports image-to-video, text-to-image-to-video, and video-to-video generation.
+ 1. Compatible with the `Stable Diffusion` text-to-image ecosystem, including `base_model`, `lora`, `controlnet`, etc.
+ 1. Supports multi-reference-image techniques, including `IPAdapter`, `ReferenceOnly`, `ReferenceNet`, and `IPAdapterFaceID`.
+ 1. Training code will be released later.
+
+ # Important Updates
+ 1. `musev_referencenet_pose`: the model names for `unet` and `ip_adapter` were specified incorrectly; please use `musev_referencenet_pose` instead of `musev_referencenet`, and please use the latest main branch.
+
+ # Progress
+ - [03/27/2024] Released the `MuseV` project and the trained models `musev`, `muse_referencenet`, and `muse_referencenet_pose`.
+ - [03/30/2024] Added a [gui](https://huggingface.co/spaces/AnchorFake/MuseVDemo) on Hugging Face Space to generate videos interactively.
+
+ ## Models
+ ### Overview of the model structure
+ ![model_structure](./data/models/musev_structure.png)
+ ### Overview of the parallel denoising algorithm
+ ![parallel_denoise](./data//models/parallel_denoise.png)
+
+ ## Cases
+ All frames of the generated results are produced directly by `MuseV`, without any post-processing such as temporal or spatial super-resolution.
+ See [MuseVPage]() for more results.
+
+ <!-- # TODO: // use youtu video link? -->
+ All of the following cases are maintained in `configs/tasks/example.yaml` and can be reproduced directly.
+ The **[project](https://tmelyralab.github.io/)** page has more cases, including one- to two-minute long videos generated directly.
+
+ ### Video generation from text and image input
+ #### Human
+ <table class="center">
+ <tr style="font-weight: bolder;text-align:center;">
+ <td width="50%">image</td>
+ <td width="45%">video </td>
+ <td width="5%">prompt</td>
+ </tr>
+
+ <tr>
+ <td>
+ <img src=./data/images/yongen.jpeg width="400">
+ </td>
+ <td >
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/732cf1fd-25e7-494e-b462-969c9425d277" width="100" controls preload></video>
+ </td>
+ <td>(masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3)
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ <img src=./data/images/seaside4.jpeg width="400">
+ </td>
+ <td>
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/9b75a46c-f4e6-45ef-ad02-05729f091c8f" width="100" controls preload></video>
+ </td>
+ <td>
+ (masterpiece, best quality, highres:1), peaceful beautiful sea scene
+ </td>
+ </tr>
+ <tr>
+ <td>
+ <img src=./data/images/seaside_girl.jpeg width="400">
+ </td>
+ <td>
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/d0f3b401-09bf-4018-81c3-569ec24a4de9" width="100" controls preload></video>
+ </td>
+ <td>
+ (masterpiece, best quality, highres:1), peaceful beautiful sea scene
+ </td>
+ </tr>
+ <!-- guitar -->
+ <tr>
+ <td>
+ <img src=./data/images/boy_play_guitar.jpeg width="400">
+ </td>
+ <td>
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/61bf955e-7161-44c8-a498-8811c4f4eb4f" width="100" controls preload></video>
+ </td>
+ <td>
+ (masterpiece, best quality, highres:1), playing guitar
+ </td>
+ </tr>
+ <tr>
+ <td>
+ <img src=./data/images/girl_play_guitar2.jpeg width="400">
+ </td>
+ <td>
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/40982aa7-9f6a-4e44-8ef6-3f185d284e6a" width="100" controls preload></video>
+ </td>
+ <td>
+ (masterpiece, best quality, highres:1), playing guitar
+ </td>
+ </tr>
+ <!-- famous people -->
+ <tr>
+ <td>
+ <img src=./data/images/dufu.jpeg width="400">
+ </td>
+ <td>
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/28294baa-b996-420f-b1fb-046542adf87d" width="100" controls preload></video>
+ </td>
+ <td>
+ (masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3),Chinese ink painting style
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ <img src=./data/images/Mona_Lisa.jpg width="400">
+ </td>
+ <td>
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/1ce11da6-14c6-4dcd-b7f9-7a5f060d71fb" width="100" controls preload></video>
+ </td>
+ <td>
+ (masterpiece, best quality, highres:1),(1girl, solo:1),(beautiful face,
+ soft skin, costume:1),(eye blinks:{eye_blinks_factor}),(head wave:1.3)
+ </td>
+ </tr>
+ </table >
+
+ #### Scene
+ <table class="center">
+ <tr style="font-weight: bolder;text-align:center;">
+ <td width="35%">image</td>
+ <td width="50%">video</td>
+ <td width="15%">prompt</td>
+ </tr>
+
+ <tr>
+ <td>
+ <img src=./data/images/waterfall4.jpeg width="400">
+ </td>
+ <td>
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/852daeb6-6b58-4931-81f9-0dddfa1b4ea5" width="100" controls preload></video>
+ </td>
+ <td>
+ (masterpiece, best quality, highres:1), peaceful beautiful waterfall, an
+ endless waterfall
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ <img src=./data/images/seaside2.jpeg width="400">
+ </td>
+ <td>
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/4a4d527a-6203-411f-afe9-31c992d26816" width="100" controls preload></video>
+ </td>
+ <td>(masterpiece, best quality, highres:1), peaceful beautiful sea scene
+ </td>
+ </tr>
+ </table >
+
+ ### Video generation conditioned on an input video
+ In the current generation mode, the first-frame condition of the reference video needs to be aligned with the condition of the reference image; otherwise the first-frame information is broken and the result gets worse. So the typical workflow is:
+ 1. Choose a reference video;
+ 2. Run the first frame of the reference video through an image-to-image / controlnet pipeline (platforms such as `MJ` also work);
+ 3. Use the image generated in step 2 and the reference video to generate the video with MuseV;
+ 4.
+ **pose2video**
+
+ In the `duffy` case, the pose of the visual condition frame is not aligned with the first frame of the control video; the `posealign` module will solve this problem.
+
+ <table class="center">
+ <tr style="font-weight: bolder;text-align:center;">
+ <td width="25%">image</td>
+ <td width="65%">video</td>
+ <td width="10%">prompt</td>
+ </tr>
+ <tr>
+ <td>
+ <img src=./data/images/spark_girl.png width="200">
+ <img src=./data/images/cyber_girl.png width="200">
+ </td>
+ <td>
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/484cc69d-c316-4464-a55b-3df929780a8e" width="400" controls preload></video>
+ </td>
+ <td>
+ (masterpiece, best quality, highres:1) , a girl is dancing, animation
+ </td>
+ </tr>
+ <tr>
+ <td>
+ <img src=./data/images/duffy.png width="400">
+ </td>
+ <td>
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/c44682e6-aafc-4730-8fc1-72825c1bacf2" width="400" controls preload></video>
+ </td>
+ <td>
+ (masterpiece, best quality, highres:1), is dancing, animation
+ </td>
+ </tr>
+ </table >
+
+ ### MuseTalk
+
+ The character in `talk`, Sun Xinying (孙昕荧), is a well-known online influencer; you can follow her on [Douyin](https://www.douyin.com/user/MS4wLjABAAAAWDThbMPN_6Xmm_JgXexbOii1K-httbu2APdG8DvDyM8).
+
+ <table class="center">
+ <tr style="font-weight: bolder;">
+ <td width="35%">name</td>
+ <td width="50%">video</td>
+ </tr>
+
+ <tr>
+ <td>
+ talk
+ </td>
+ <td>
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/951188d1-4731-4e7f-bf40-03cacba17f2f" width="100" controls preload></video>
+ </td>
+ <tr>
+ <td>
+ sing
+ </td>
+ <td>
+ <video src="https://github.com/TMElyralab/MuseV/assets/163980830/50b8ffab-9307-4836-99e5-947e6ce7d112" width="100" controls preload></video>
+ </td>
+ </tr>
+ </table >
+
+
+ # TODO:
+ - [ ] technical report (coming soon).
+ - [ ] training code.
+ - [ ] diffusion-transformer generation framework.
+ - [ ] `posealign` module.
+
+ # Quick Start
+ Prepare the Python environment and install the extra packages, such as `diffusers`, `controlnet_aux`, and `mmcm`.
+
+ ## Third-party integrations
+ Some third-party integrations make installation and use easier; thanks to the third parties for their work.
+ Please note, however, that we have not verified, maintained, or followed up on third-party support; refer to this project for the actual behavior.
+ ### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseV)
+ ### [Windows all-in-one package](https://www.bilibili.com/video/BV1ux4y1v7pF/?vd_source=fe03b064abab17b79e22a692551405c3)
+ netdisk:https://www.123pan.com/s/Pf5Yjv-Bb9W3.html
+ code: glut
+
+ ## Prepare the environment
+ We recommend preparing the Python environment with `docker` first.
+
+ ### Prepare the Python environment
+ **Attention**: we have only tested with Docker; conda or other environments may run into problems. We will try our best to resolve them, but please still prefer `docker`.
+
+ #### Method 1: Docker
+ 1. Pull the Docker image
+ ```bash
+ docker pull anchorxia/musev:latest
+ ```
+ 2. Run the Docker container
+ ```bash
+ docker run --gpus all -it --entrypoint /bin/bash anchorxia/musev:latest
+ ```
+ The default conda environment after the container starts is `musev`.
+
+ #### Method 2: conda
+ Create the conda environment from environment.yml
+ ```
+ conda env create --name musev --file ./environment.yml
+ ```
+ #### Method 3: pip requirements
+ ```
+ pip install -r requirements.txt
+ ```
+ #### Prepare the [openmmlab](https://openmmlab.com/) packages
+ If you do not use Docker, the mmlab packages also need to be installed.
+ ```bash
+ pip install --no-cache-dir -U openmim
+ mim install mmengine
+ mim install "mmcv>=2.0.1"
+ mim install "mmdet>=3.1.0"
+ mim install "mmpose>=1.1.0"
+ ```
+
+ ### Prepare the packages we developed
+ #### Download
+ ```bash
+ git clone --recursive https://github.com/TMElyralab/MuseV.git
+ ```
+ #### Prepare PYTHONPATH
+ ```bash
+ current_dir=$(pwd)
+ export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV
+ export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/MMCM
+ export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/diffusers/src
+ export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/controlnet_aux/src
+ cd MuseV
+ ```
+
+ 1. `MMCM`: multimedia and cross-modal processing package.
+ 1. `diffusers`: a modified diffusers package based on [diffusers](https://github.com/huggingface/diffusers).
+ 1. `controlnet_aux`: a modified package based on [controlnet_aux](https://github.com/TMElyralab/controlnet_aux).
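
As an alternative to the `PYTHONPATH` exports above, the same search paths can be set from inside Python before importing the project packages. This is only a minimal sketch, assuming the repository was cloned into the current working directory as shown in the download step:

```python
import os
import sys

# Equivalent of the PYTHONPATH exports above, but only for the current process.
repo_root = os.path.join(os.getcwd(), "MuseV")
for sub_path in ("", "MMCM", "diffusers/src", "controlnet_aux/src"):
    sys.path.insert(0, os.path.join(repo_root, sub_path))

# After this, e.g. `import musev` should resolve to the checked-out sources.
```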
+
+
+ ## Download models
+ ```bash
+ git clone https://huggingface.co/TMElyralab/MuseV ./checkpoints
+ ```
+ - `motion`: several versions of the video generation model, trained on the small `ucf101` dataset and a small `webvid` subset, about 60K video-text pairs. GPU memory consumption is measured at `resolution` $=512*512$, `time_size=12`.
+     - `musev/unet`: this version only trains the `unet` motion module. Inference `GPU memory consumption` $\approx 8G$.
+     - `musev_referencenet`: this version trains the `unet` motion module, `referencenet`, and `IPAdapter`. Inference `GPU memory consumption` $\approx 12G$.
+         - `unet`: the `motion` module, with `to_k` and `to_v` in the `Attention` layers, following `IPAdapter`.
+         - `referencenet`: similar to `AnimateAnyone`.
+         - `ip_adapter_image_proj.bin`: the image feature projection layer, following `IPAdapter`.
+     - `musev_referencenet_pose`: based on `musev_referencenet`, this version fixes `referencenet` and `controlnet_pose` and trains `unet motion` and `IPAdapter`. Inference `GPU memory consumption` $\approx 12G$.
+ - `t2i/sd1.5`: the text2image model, whose parameters are frozen when training the motion module.
+     - `majicmixRealv6Fp16`: an example; it can be replaced with other t2i bases. Download from [majicmixRealv6Fp16](https://civitai.com/models/43331/majicmix-realistic).
+     - `fantasticmix_v10`: available at [fantasticmix_v10](https://civitai.com/models/22402?modelVersionId=26744).
+ - `IP-Adapter/models`: download from [IPAdapter](https://huggingface.co/h94/IP-Adapter/tree/main).
+     - `image_encoder`: the visual feature extraction model.
+     - `ip-adapter_sd15.bin`: pretrained weights of the original IPAdapter model.
+     - `ip-adapter-faceid_sd15.bin`: pretrained weights of the original IPAdapter model.
+
+ ## Inference
+
+ ### Prepare model paths
+ You can skip this step when you run the example tasks with the example inference commands.
+ This part sets model paths and abbreviations in config files, so that the inference scripts can use short names instead of full paths (a lookup sketch follows the list below).
+ - T2I SD: see `musev/configs/model/T2I_all_model.py`
+ - Motion Unet: see `musev/configs/model/motion_model.py`
+ - Task: see `musev/configs/tasks/example.yaml`
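
A minimal sketch of how such a name-to-path lookup can work, based on the `MODEL_CFG` dictionaries defined in `configs/model/T2I_all_model.py` and `configs/model/motion_model.py` in this commit; the loader function here is illustrative, not the project's actual loader:

```python
import importlib.util


def load_model_cfg(cfg_path: str) -> dict:
    """Load the MODEL_CFG dict from a config .py file given its path."""
    spec = importlib.util.spec_from_file_location("model_cfg", cfg_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.MODEL_CFG


# Resolve the abbreviation used on the command line (e.g. --sd_model_name)
# to the full checkpoint path configured in the file.
sd_cfg = load_model_cfg("./configs/model/T2I_all_model.py")
sd_path = sd_cfg["majicmixRealv6Fp16"]["sd"]
print(sd_path)  # .../checkpoints/t2i/sd1.5/majicmixRealv6Fp16
```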
+
+ ### musev_referencenet
+ #### Video generation from text and image input
+ ```bash
+ python scripts/inference/text2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev_referencenet --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --target_datas yongen --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --time_size 12 --fps 12
+ ```
+ **Common parameters**:
+ - `test_data_path`: path of the task file with the test cases
+ - `target_datas`: if a `name` in `test_data_path` is in `target_datas`, only those sub-tasks are run. The separator is `,`;
+ - `sd_model_cfg_path`: T2I sd model path; either a model config path or a model path.
+ - `sd_model_name`: sd model name, used to select the full model path in `sd_model_cfg_path`. Multiple model names separated by `,`, or `all`.
+ - `unet_model_cfg_path`: motion unet model config path or model path.
+ - `unet_model_name`: unet model name, used to get the model path in `unet_model_cfg_path` and to initialize the unet class instance in `musev/models/unet_loader.py`. Multiple model names separated by `,`, or `all`. If `unet_model_cfg_path` is a model path, `unet_name` must be supported in `musev/models/unet_loader.py`.
+ - `time_size`: the diffusion model generates one clip at a time; this is the number of frames per clip. Default `12`.
+ - `n_batch`: number of clips generated in the head-to-tail manner, $total\_frames=n\_batch * time\_size + n\_viscond$. Default `1`.
+ - `context_frames`: number of frames generated at once by a parallel-denoising sub-window. If `time_size` > `context_frames`, the parallel denoising logic is enabled and the `time_size` window is split into several sub-windows for parallel denoising. Default `12`.
+
+ There are two ways to generate a **long video**, and they can be used together (a small sketch of the frame accounting follows this list):
+ 1. `visual-conditioned parallel denoising`: set `n_batch=1` and `time_size` = the total number of frames you want.
+ 2. `traditional head-to-tail concatenation`: set `time_size` = `context_frames` = the number of frames per clip (`12`) and `context_overlap` = 0. `n_batch` clips are generated head-to-tail; error accumulates across the concatenation, so the larger `n_batch` is, the worse the final result.
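
A small sketch of the frame accounting behind the two modes, using the relation $total\_frames = n\_batch * time\_size + n\_viscond$ from the parameter list above. The window-splitting helper is only an illustration of the sub-window idea, not MuseV's actual scheduling code:

```python
def total_frames(n_batch: int, time_size: int, n_viscond: int = 1) -> int:
    # Relation given above: n_batch clips chained after the visual condition frame(s).
    return n_batch * time_size + n_viscond


def parallel_denoise_windows(time_size: int, context_frames: int, context_overlap: int):
    """Split one long time_size window into overlapping sub-windows (illustrative only)."""
    step = max(context_frames - context_overlap, 1)
    return [
        list(range(start, min(start + context_frames, time_size)))
        for start in range(0, time_size, step)
    ]


# Mode 1: visual-conditioned parallel denoising, one long window.
print(total_frames(n_batch=1, time_size=60))                       # 61 frames in one pass
print(parallel_denoise_windows(60, context_frames=12, context_overlap=4))

# Mode 2: traditional head-to-tail concatenation of short clips; error grows with n_batch.
print(total_frames(n_batch=5, time_size=12))
```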
+
+
+ **Model parameters**:
+ Supports `referencenet`, `IPAdapter`, `IPAdapterFaceID`, and `Facein`.
+ - `referencenet_model_name`: `referencenet` model name.
+ - `ImageClipVisionFeatureExtractor`: name of the `ImageEmbExtractor`, which extracts visual features in `IPAdapter`.
+ - `vision_clip_model_path`: model path of `ImageClipVisionFeatureExtractor`.
+ - `ip_adapter_model_name`: from `IPAdapter`; it is the `ImagePromptEmbProj`, used together with `ImageEmbExtractor`.
+ - `ip_adapter_face_model_name`: `IPAdapterFaceID`, from `IPAdapter`; `face_image_path` should be set.
+
+ **Some parameters that affect the motion range and the generation result**:
+ - `video_guidance_scale`: similar to text2image, it controls the balance between cond and uncond and has a large influence. Default `3.5`.
+ - `use_condition_image`: whether to use the given first frame for video generation. Default `True`.
+ - `redraw_condition_image`: whether to redraw the given first-frame image.
+ - `video_negative_prompt`: abbreviation of the full `negative_prompt` in the config file. Default `V2`.
+
+
+ #### Video generation from an input video
+ ```bash
+ python scripts/inference/video2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev_referencenet --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet -test_data_path ./configs/tasks/example.yaml --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12
+ ```
+ **Some important parameters**
+
+ Most parameters are the same as for `musev_text2video`. The parameters specific to `video2video` are:
+ 1. `video_path` needs to be set in `test_data`. `rgb video` and `controlnet_middle_video` are now supported.
+ - `which2video`: type of the reference video. If `video_middle`, only the `video_middle` (such as `pose` or `depth`) is used; if `video`, the video itself also takes part in the video noise initialization, similar to `img2image`.
+ - `controlnet_name`: whether to use a `controlnet condition`, e.g. `dwpose,depth`; for pose, `dwpose_body_hand` is recommended.
+ - `video_is_middle`: whether `video_path` is an `rgb video` or a `controlnet_middle_video`. It can be set for each `test_data` in `test_data_path`.
+ - `video_has_condition`: whether `condtion_images` is aligned with the first frame of `video_path`. If not, `condition_images` is generated first and then concatenated with the reference video for alignment. Currently only reference videos with `video_is_middle=True` are supported; it can be set per `test_data`.
+
+ All `controlnet_names` are maintained in [mmcm](https://github.com/TMElyralab/MMCM/blob/main/mmcm/vision/feature_extractor/controlnet.py#L513)
+ ```python
+ ['pose', 'pose_body', 'pose_hand', 'pose_face', 'pose_hand_body', 'pose_hand_face', 'dwpose', 'dwpose_face', 'dwpose_hand', 'dwpose_body', 'dwpose_body_hand', 'canny', 'tile', 'hed', 'hed_scribble', 'depth', 'pidi', 'normal_bae', 'lineart', 'lineart_anime', 'zoe', 'sam', 'mobile_sam', 'leres', 'content', 'face_detector']
+ ```
+
+ ### musev_referencenet_pose
+ Only used for `pose2video`.
+ Trained on top of `musev_referencenet`, fixing `referencenet`, `pose-controlnet`, and `T2I`, while training the `motion` module and `IPAdapter`.
+ ```bash
+ python scripts/inference/video2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev_referencenet_pose --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet_pose -test_data_path ./configs/tasks/example.yaml --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12
+ ```
+
+ ### musev
+ Only the motion module, without `referencenet`; it requires less GPU memory.
+ #### text2video
+ ```bash
+ python scripts/inference/text2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --target_datas yongen --time_size 12 --fps 12
+ ```
+ #### video2video
+ ```bash
+ python scripts/inference/video2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12
+ ```
+
+ ### Gradio demo
+ MuseV provides a gradio script that builds a GUI on a local machine for convenient video generation.
+
+ ```bash
+ cd scripts/gradio
+ python app.py
+ ```
+
+ # Acknowledgements
+ 1. During development, MuseV referred to and learned from many open-source works: [TuneAVideo](https://github.com/showlab/Tune-A-Video), [diffusers](https://github.com/huggingface/diffusers), [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone/tree/master/src/pipelines), [animatediff](https://github.com/guoyww/AnimateDiff), [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter), [AnimateAnyone](https://arxiv.org/abs/2311.17117), [VideoFusion](https://arxiv.org/abs/2303.08320), and [insightface](https://github.com/deepinsight/insightface).
+ 2. MuseV is built on the `ucf101` and `webvid` datasets.
+
+ Thanks to the open-source community for its contributions!
+
+ # Limitations
+
+ `MuseV` still has many items to optimize, including:
+
+ 1. Limited generalization ability. It is sensitive to the visual condition frame: some condition images work well, others poorly; some pretrained t2i models work well, others poorly.
+ 1. Limited types of video generation and a limited motion range, partly because of limited training data types. The released `MuseV` was trained on about 60K human text-video pairs at `512*320` resolution. At lower resolutions `MuseV` has a larger motion range but lower video quality; at higher resolutions the image quality is good but the motion range is smaller. Training on a larger, higher-resolution, higher-quality text-video dataset is likely to make `MuseV` better.
+ 1. Watermark issues caused by training on `webvid`. A cleaner dataset without watermarks would probably solve this.
+ 1. Limited types of long-video generation. Visual-conditioned parallel denoising removes the accumulated error of video generation, but the current method only suits relatively fixed-camera scenes.
+ 1. Undertrained referencenet and IP-Adapter, because of limited time and resources.
+ 1. The code structure is not ideal. `MuseV` supports rich and dynamic features, but the code is complex and has not been refactored; it takes time to become familiar with it.
+
+
+ <!-- # Contribution: no organized open-source co-construction for now -->
+ # Citation
+ ```bib
+ @article{musev,
+   title={MuseV: Infinite-length and High Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising},
+   author={Xia, Zhiqiang and Chen, Zhaokang and Wu, Bin and Li, Chao and Hung, Kwok-Wai and Zhan, Chao and He, Yingjie and Zhou, Wenjiang},
+   journal={arxiv},
+   year={2024}
+ }
+ ```
+ # Disclaimer/License
+ 1. `code`: the `MuseV` code is released under the `MIT` license; both academic and commercial use are allowed.
+ 1. `model`: the trained models are for non-commercial research purposes only.
+ 1. `other open-source models`: other open-source models used must comply with their own licenses, such as `insightface`, `IP-Adapter`, `ft-mse-vae`, etc.
+ 1. The test data are collected from the internet and are for non-commercial research purposes only.
+ 1. `AIGC`: this project aims to have a positive impact on AI-based video generation. Users are granted the freedom to create videos with this tool, but they should comply with local laws and use it responsibly. The developers assume no liability for potential misuse by users.
README.md ADDED
@@ -0,0 +1,37 @@
+ ---
+ title: MuseV Demo
+ emoji: 🎥
+ colorFrom: blue
+ colorTo: red
+ sdk: gradio
+ sdk_version: 3.50.2
+ app_file: scripts/gradio/app_gradio_space.py
+ pinned: false
+ ---
+
+ # MuseV Demo
+
+ This is a Hugging Face Space for MuseV: Infinite-length and High Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising.
+
+ ## Features
+
+ - Text-to-Video generation
+ - Visual condition support
+ - High-quality video generation
+
+ For more details, visit the [GitHub repository](https://github.com/TMElyralab/MuseV).
+
+ ## Usage
+
+ 1. Enter your prompt describing the video you want to generate
+ 2. Upload a reference image
+ 3. Adjust parameters like seed, FPS, dimensions, etc.
+ 4. Click generate and wait for the results
+
+ ## Model Details
+
+ The model will be automatically downloaded when you first run the demo.
+
+ ## Credits
+
+ Created by Lyra Lab, Tencent Music Entertainment
configs/model/T2I_all_model.py ADDED
@@ -0,0 +1,15 @@
+ import os
+
+
+ T2IDir = os.path.join(
+     os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "t2i"
+ )
+
+ MODEL_CFG = {
+     "majicmixRealv6Fp16": {
+         "sd": os.path.join(T2IDir, "sd1.5/majicmixRealv6Fp16"),
+     },
+     "fantasticmix_v10": {
+         "sd": os.path.join(T2IDir, "sd1.5/fantasticmix_v10"),
+     },
+ }
configs/model/ip_adapter.py ADDED
@@ -0,0 +1,66 @@
+ import os
+
+ IPAdapterModelDir = os.path.join(
+     os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "IP-Adapter"
+ )
+
+
+ MotionDir = os.path.join(
+     os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "motion"
+ )
+
+
+ MODEL_CFG = {
+     "IPAdapter": {
+         "ip_image_encoder": os.path.join(IPAdapterModelDir, "models/image_encoder"),
+         "ip_ckpt": os.path.join(IPAdapterModelDir, "ip-adapter_sd15.bin"),
+         "ip_scale": 1.0,
+         "clip_extra_context_tokens": 4,
+         "clip_embeddings_dim": 1024,
+         "desp": "",
+     },
+     "IPAdapterPlus": {
+         "ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"),
+         "ip_ckpt": os.path.join(IPAdapterModelDir, "ip-adapter-plus_sd15.bin"),
+         "ip_scale": 1.0,
+         "clip_extra_context_tokens": 16,
+         "clip_embeddings_dim": 1024,
+         "desp": "",
+     },
+     "IPAdapterPlus-face": {
+         "ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"),
+         "ip_ckpt": os.path.join(IPAdapterModelDir, "ip-adapter-plus-face_sd15.bin"),
+         "ip_scale": 1.0,
+         "clip_extra_context_tokens": 16,
+         "clip_embeddings_dim": 1024,
+         "desp": "",
+     },
+     "IPAdapterFaceID": {
+         "ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"),
+         "ip_ckpt": os.path.join(IPAdapterModelDir, "ip-adapter-faceid_sd15.bin"),
+         "ip_scale": 1.0,
+         "clip_extra_context_tokens": 4,
+         "clip_embeddings_dim": 512,
+         "desp": "",
+     },
+     "musev_referencenet": {
+         "ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"),
+         "ip_ckpt": os.path.join(
+             MotionDir, "musev_referencenet/ip_adapter_image_proj.bin"
+         ),
+         "ip_scale": 1.0,
+         "clip_extra_context_tokens": 4,
+         "clip_embeddings_dim": 1024,
+         "desp": "",
+     },
+     "musev_referencenet_pose": {
+         "ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"),
+         "ip_ckpt": os.path.join(
+             MotionDir, "musev_referencenet_pose/ip_adapter_image_proj.bin"
+         ),
+         "ip_scale": 1.0,
+         "clip_extra_context_tokens": 4,
+         "clip_embeddings_dim": 1024,
+         "desp": "",
+     },
+ }
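
The `ip_image_encoder` entries above point at the `image_encoder` folder downloaded from the IP-Adapter repository (see the README's download section). A minimal sketch of extracting a reference-image embedding with Hugging Face `transformers`, as an assumption about how such a CLIP-style vision checkpoint is typically consumed rather than MuseV's own `ImageClipVisionFeatureExtractor`:

```python
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Path follows the checkpoint layout described in the README (assumption).
encoder_path = "./checkpoints/IP-Adapter/models/image_encoder"
processor = CLIPImageProcessor()  # default CLIP preprocessing, for simplicity
encoder = CLIPVisionModelWithProjection.from_pretrained(encoder_path)

image = Image.open("./data/images/yongen.jpeg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")
# image_embeds should have the dimensionality referred to by clip_embeddings_dim above.
image_embeds = encoder(**inputs).image_embeds
print(image_embeds.shape)
```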
configs/model/lcm_model.py ADDED
@@ -0,0 +1,17 @@
+ import os
+
+
+ LCMDir = os.path.join(
+     os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "lcm"
+ )
+
+
+ MODEL_CFG = {
+     "lcm": {
+         os.path.join(LCMDir, "lcm-lora-sdv1-5/pytorch_lora_weights.safetensors"): {
+             "strength": 1.0,
+             "lora_block_weight": "ALL",
+             "strength_offset": 0,
+         },
+     },
+ }
configs/model/motion_model.py ADDED
@@ -0,0 +1,22 @@
+ import os
+
+
+ MotionDIr = os.path.join(
+     os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "motion"
+ )
+
+
+ MODEL_CFG = {
+     "musev": {
+         "unet": os.path.join(MotionDIr, "musev"),
+         "desp": "only train unet motion module, fix t2i",
+     },
+     "musev_referencenet": {
+         "unet": os.path.join(MotionDIr, "musev_referencenet"),
+         "desp": "train referencenet, IPAdapter and unet motion module, fix t2i",
+     },
+     "musev_referencenet_pose": {
+         "unet": os.path.join(MotionDIr, "musev_referencenet_pose"),
+         "desp": "train unet motion module and IPAdapter, fix t2i and referencenet",
+     },
+ }
configs/model/negative_prompt.py ADDED
@@ -0,0 +1,32 @@
+ Negative_Prompt_CFG = {
+     "Empty": {
+         "base_model": "",
+         "prompt": "",
+         "refer": "",
+     },
+     "V1": {
+         "base_model": "",
+         "prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, tail, watermarks",
+         "refer": "",
+     },
+     "V2": {
+         "base_model": "",
+         "prompt": "badhandv4, ng_deepnegative_v1_75t, (((multiple heads))), (((bad body))), (((two people))), ((extra arms)), ((deformed body)), (((sexy))), paintings,(((two heads))), ((big head)),sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (((nsfw))), nipples, extra fingers, (extra legs), (long neck), mutated hands, (fused fingers), (too many fingers)",
+         "refer": "Weiban",
+     },
+     "V3": {
+         "base_model": "",
+         "prompt": "badhandv4, ng_deepnegative_v1_75t, bad quality",
+         "refer": "",
+     },
+     "V4": {
+         "base_model": "",
+         "prompt": "badhandv4,ng_deepnegative_v1_75t,EasyNegativeV2,bad_prompt_version2-neg,bad quality",
+         "refer": "",
+     },
+     "V5": {
+         "base_model": "",
+         "prompt": "(((multiple heads))), (((bad body))), (((two people))), ((extra arms)), ((deformed body)), (((sexy))), paintings,(((two heads))), ((big head)),sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (((nsfw))), nipples, extra fingers, (extra legs), (long neck), mutated hands, (fused fingers), (too many fingers)",
+         "refer": "Weiban",
+     },
+ }
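
A minimal sketch of how a shorthand such as the `--video_negative_prompt V2` default mentioned in the READMEs could be resolved against this table; the helper below is illustrative, not the project's loader, and it assumes the repository root is on `sys.path` so `configs` is importable:

```python
from configs.model.negative_prompt import Negative_Prompt_CFG


def get_negative_prompt(name: str) -> str:
    # Fall back to the empty prompt if the shorthand is unknown.
    return Negative_Prompt_CFG.get(name, Negative_Prompt_CFG["Empty"])["prompt"]


print(get_negative_prompt("V2")[:60], "...")
```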
configs/model/referencenet.py ADDED
@@ -0,0 +1,14 @@
+ import os
+
+
+ MotionDIr = os.path.join(
+     os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "motion"
+ )
+
+
+ MODEL_CFG = {
+     "musev_referencenet": {
+         "net": os.path.join(MotionDIr, "musev_referencenet"),
+         "desp": "",
+     },
+ }
configs/tasks/example.yaml ADDED
@@ -0,0 +1,210 @@
+ # - name: task_name
+ #   condition_images: vision condition images path
+ #   video_path: str, default null, used for video2video
+ #   prompt: text to guide image generation
+ #   ipadapter_image: image_path for IP-Apdater
+ #   refer_image: image_path for referencenet, generally speaking, same as ipadapter_image
+ #   height: int # The shorter the image size, the larger the motion amplitude, and the lower video quality.
+ #   width: int # The longer the W&H, the smaller the motion amplitude, and the higher video quality.
+ #   img_length_ratio: float, generation video size is (height, width) * img_length_ratio
+
+ # text/image2video
+ - condition_images: ./data/images/yongen.jpeg
+   eye_blinks_factor: 1.8
+   height: 1308
+   img_length_ratio: 0.957
+   ipadapter_image: ${.condition_images}
+   name: yongen
+   prompt: (masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3)
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 736
+ - condition_images: ./data/images/jinkesi2.jpeg
+   eye_blinks_factor: 1.8
+   height: 714
+   img_length_ratio: 1.25
+   ipadapter_image: ${.condition_images}
+   name: jinkesi2
+   prompt: (masterpiece, best quality, highres:1),(1girl, solo:1),(beautiful face,
+     soft skin, costume:1),(eye blinks:{eye_blinks_factor}),(head wave:1.3)
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 563
+ - condition_images: ./data/images/seaside4.jpeg
+   eye_blinks_factor: 1.8
+   height: 317
+   img_length_ratio: 2.221
+   ipadapter_image: ${.condition_images}
+   name: seaside4
+   prompt: (masterpiece, best quality, highres:1), peaceful beautiful sea scene
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 564
+ - condition_images: ./data/images/seaside_girl.jpeg
+   eye_blinks_factor: 1.8
+   height: 736
+   img_length_ratio: 0.957
+   ipadapter_image: ${.condition_images}
+   name: seaside_girl
+   prompt: (masterpiece, best quality, highres:1), peaceful beautiful sea scene
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 736
+ - condition_images: ./data/images/boy_play_guitar.jpeg
+   eye_blinks_factor: 1.8
+   height: 846
+   img_length_ratio: 1.248
+   ipadapter_image: ${.condition_images}
+   name: boy_play_guitar
+   prompt: (masterpiece, best quality, highres:1), playing guitar
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 564
+ - condition_images: ./data/images/girl_play_guitar2.jpeg
+   eye_blinks_factor: 1.8
+   height: 1002
+   img_length_ratio: 1.248
+   ipadapter_image: ${.condition_images}
+   name: girl_play_guitar2
+   prompt: (masterpiece, best quality, highres:1), playing guitar
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 564
+ - condition_images: ./data/images/boy_play_guitar2.jpeg
+   eye_blinks_factor: 1.8
+   height: 630
+   img_length_ratio: 1.676
+   ipadapter_image: ${.condition_images}
+   name: boy_play_guitar2
+   prompt: (masterpiece, best quality, highres:1), playing guitar
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 420
+ - condition_images: ./data/images/girl_play_guitar4.jpeg
+   eye_blinks_factor: 1.8
+   height: 846
+   img_length_ratio: 1.248
+   ipadapter_image: ${.condition_images}
+   name: girl_play_guitar4
+   prompt: (masterpiece, best quality, highres:1), playing guitar
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 564
+ - condition_images: ./data/images/dufu.jpeg
+   eye_blinks_factor: 1.8
+   height: 500
+   img_length_ratio: 1.495
+   ipadapter_image: ${.condition_images}
+   name: dufu
+   prompt: (masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3),Chinese ink painting style
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 471
+ - condition_images: ./data/images/Mona_Lisa..jpg
+   eye_blinks_factor: 1.8
+   height: 894
+   img_length_ratio: 1.173
+   ipadapter_image: ${.condition_images}
+   name: Mona_Lisa.
+   prompt: (masterpiece, best quality, highres:1),(1girl, solo:1),(beautiful face,
+     soft skin, costume:1),(eye blinks:{eye_blinks_factor}),(head wave:1.3)
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 600
+ - condition_images: ./data/images/Portrait-of-Dr.-Gachet.jpg
+   eye_blinks_factor: 1.8
+   height: 985
+   img_length_ratio: 0.88
+   ipadapter_image: ${.condition_images}
+   name: Portrait-of-Dr.-Gachet
+   prompt: (masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3)
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 800
+ - condition_images: ./data/images/Self-Portrait-with-Cropped-Hair.jpg
+   eye_blinks_factor: 1.8
+   height: 565
+   img_length_ratio: 1.246
+   ipadapter_image: ${.condition_images}
+   name: Self-Portrait-with-Cropped-Hair
+   prompt: (masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3), animate
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 848
+ - condition_images: ./data/images/The-Laughing-Cavalier.jpg
+   eye_blinks_factor: 1.8
+   height: 1462
+   img_length_ratio: 0.587
+   ipadapter_image: ${.condition_images}
+   name: The-Laughing-Cavalier
+   prompt: (masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3)
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 1200
+
+ # scene
+ - condition_images: ./data/images/waterfall4.jpeg
+   eye_blinks_factor: 1.8
+   height: 846
+   img_length_ratio: 1.248
+   ipadapter_image: ${.condition_images}
+   name: waterfall4
+   prompt: (masterpiece, best quality, highres:1), peaceful beautiful waterfall, an
+     endless waterfall
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 564
+ - condition_images: ./data/images/river.jpeg
+   eye_blinks_factor: 1.8
+   height: 736
+   img_length_ratio: 0.957
+   ipadapter_image: ${.condition_images}
+   name: river
+   prompt: (masterpiece, best quality, highres:1), peaceful beautiful river
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 736
+ - condition_images: ./data/images/seaside2.jpeg
+   eye_blinks_factor: 1.8
+   height: 1313
+   img_length_ratio: 0.957
+   ipadapter_image: ${.condition_images}
+   name: seaside2
+   prompt: (masterpiece, best quality, highres:1), peaceful beautiful sea scene
+   refer_image: ${.condition_images}
+   video_path: null
+   width: 736
+
+ # video2video
+ - name: "dance1"
+   prompt: "(masterpiece, best quality, highres:1) , a girl is dancing, wearing a dress made of stars, animation"
+   video_path: ./data/source_video/video1_girl_poseseq.mp4
+   condition_images: ./data/images/spark_girl.png
+   refer_image: ${.condition_images}
+   ipadapter_image: ${.condition_images}
+   height: 960
+   width: 512
+   img_length_ratio: 1.0
+   video_is_middle: True # if true, means video_path is controlnet condition, not natural rgb video
+
+ - name: "dance2"
+   prompt: "(best quality), ((masterpiece)), (highres), illustration, original, extremely detailed wallpaper"
+   video_path: ./data/source_video/video1_girl_poseseq.mp4
+   condition_images: ./data/images/cyber_girl.png
+   refer_image: ${.condition_images}
+   ipadapter_image: ${.condition_images}
+   height: 960
+   width: 512
+   img_length_ratio: 1.0
+   video_is_middle: True # if true, means video_path is controlnet condition, not natural rgb video
+
+ - name: "duffy"
+   prompt: "(best quality), ((masterpiece)), (highres), illustration, original, extremely detailed wallpaper"
+   video_path: ./data/source_video/pose-for-Duffy-4.mp4
+   condition_images: ./data/images/duffy.png
+   refer_image: ${.condition_images}
+   ipadapter_image: ${.condition_images}
+   height: 1280
+   width: 704
+   img_length_ratio: 1.0
+   video_is_middle: True # if true, means video_path is controlnet condition, not natural rgb video
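
The `${.condition_images}` entries above look like OmegaConf-style relative interpolation, and the README describes filtering the sub-tasks by `--target_datas` name. A minimal sketch of reading this file under that assumption; this is not the project's actual loading code:

```python
from omegaconf import OmegaConf

tasks = OmegaConf.load("./configs/tasks/example.yaml")

# Keep only the sub-tasks whose `name` is listed in --target_datas ("," separated).
target_datas = "yongen,dance1".split(",")
selected = [task for task in tasks if task.name in target_datas]

for task in selected:
    # Interpolations such as `refer_image: ${.condition_images}` resolve on access.
    print(task.name, task.condition_images, task.refer_image)
```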
controlnet_aux ADDED
@@ -0,0 +1 @@
+ Subproject commit 54c6c49baf68bff290679f5bb896715f25932133
data/models/musev_structure.png ADDED
data/models/parallel_denoise.png ADDED
diffusers ADDED
@@ -0,0 +1 @@
1
+ Subproject commit abf2f8bf698a895cecc30a73c6ff4abb92fdce1c
environment.yml ADDED
@@ -0,0 +1,312 @@
1
+ name: musev
2
+ channels:
3
+ - https://repo.anaconda.com/pkgs/main
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - _openmp_mutex=5.1=1_gnu
8
+ - bzip2=1.0.8=h7b6447c_0
9
+ - ca-certificates=2023.12.12=h06a4308_0
10
+ - ld_impl_linux-64=2.38=h1181459_1
11
+ - libffi=3.3=he6710b0_2
12
+ - libgcc-ng=11.2.0=h1234567_1
13
+ - libgomp=11.2.0=h1234567_1
14
+ - libstdcxx-ng=11.2.0=h1234567_1
15
+ - libuuid=1.41.5=h5eee18b_0
16
+ - ncurses=6.4=h6a678d5_0
17
+ - openssl=1.1.1w=h7f8727e_0
18
+ - python=3.10.6=haa1d7c7_1
19
+ - readline=8.2=h5eee18b_0
20
+ - sqlite=3.41.2=h5eee18b_0
21
+ - tk=8.6.12=h1ccaba5_0
22
+ - xz=5.4.5=h5eee18b_0
23
+ - zlib=1.2.13=h5eee18b_0
24
+ - pip:
25
+ - absl-py==2.1.0
26
+ - accelerate==0.22.0
27
+ - addict==2.4.0
28
+ - aiofiles==23.2.1
29
+ - aiohttp==3.9.1
30
+ - aiosignal==1.3.1
31
+ - albumentations==1.3.1
32
+ - aliyun-python-sdk-core==2.14.0
33
+ - aliyun-python-sdk-kms==2.16.2
34
+ - altair==5.2.0
35
+ - antlr4-python3-runtime==4.9.3
36
+ - anyio==4.2.0
37
+ - appdirs==1.4.4
38
+ - argparse==1.4.0
39
+ - asttokens==2.4.1
40
+ - astunparse==1.6.3
41
+ - async-timeout==4.0.3
42
+ - attrs==23.2.0
43
+ - audioread==3.0.1
44
+ - basicsr==1.4.2
45
+ - beautifulsoup4==4.12.2
46
+ - bitsandbytes==0.41.1
47
+ - black==23.12.1
48
+ - blinker==1.7.0
49
+ - braceexpand==0.1.7
50
+ - cachetools==5.3.2
51
+ - certifi==2023.11.17
52
+ - cffi==1.16.0
53
+ - charset-normalizer==3.3.2
54
+ - chumpy==0.70
55
+ - click==8.1.7
56
+ - cmake==3.28.1
57
+ - colorama==0.4.6
58
+ - coloredlogs==15.0.1
59
+ - comm==0.2.1
60
+ - contourpy==1.2.0
61
+ - cos-python-sdk-v5==1.9.22
62
+ - coscmd==1.8.6.30
63
+ - crcmod==1.7
64
+ - cryptography==41.0.7
65
+ - cycler==0.12.1
66
+ - cython==3.0.2
67
+ - datetime==5.4
68
+ - debugpy==1.8.0
69
+ - decorator==4.4.2
70
+ - decord==0.6.0
71
+ - dill==0.3.7
72
+ - docker-pycreds==0.4.0
73
+ - dulwich==0.21.7
74
+ - easydict==1.11
75
+ - einops==0.7.0
76
+ - exceptiongroup==1.2.0
77
+ - executing==2.0.1
78
+ - fastapi==0.109.0
79
+ - ffmpeg==1.4
80
+ - ffmpeg-python==0.2.0
81
+ - ffmpy==0.3.1
82
+ - filelock==3.13.1
83
+ - flatbuffers==23.5.26
84
+ - fonttools==4.47.2
85
+ - frozenlist==1.4.1
86
+ - fsspec==2023.12.2
87
+ - ftfy==6.1.1
88
+ - future==0.18.3
89
+ - fuzzywuzzy==0.18.0
90
+ - fvcore==0.1.5.post20221221
91
+ - gast==0.4.0
92
+ - gdown==4.5.3
93
+ - gitdb==4.0.11
94
+ - gitpython==3.1.41
95
+ - google-auth==2.26.2
96
+ - google-auth-oauthlib==0.4.6
97
+ - google-pasta==0.2.0
98
+ - gradio==3.43.2
99
+ - gradio-client==0.5.0
100
+ - grpcio==1.60.0
101
+ - h11==0.14.0
102
+ - h5py==3.10.0
103
+ - httpcore==1.0.2
104
+ - httpx==0.26.0
105
+ - huggingface-hub==0.20.2
106
+ - humanfriendly==10.0
107
+ - idna==3.6
108
+ - imageio==2.31.1
109
+ - imageio-ffmpeg==0.4.8
110
+ - importlib-metadata==7.0.1
111
+ - importlib-resources==6.1.1
112
+ - iniconfig==2.0.0
113
+ - insightface==0.7.3
114
+ - invisible-watermark==0.1.5
115
+ - iopath==0.1.10
116
+ - ip-adapter==0.1.0
117
+ - iprogress==0.4
118
+ - ipykernel==6.29.0
119
+ - ipython==8.20.0
120
+ - ipywidgets==8.0.3
121
+ - jax==0.4.23
122
+ - jedi==0.19.1
123
+ - jinja2==3.1.3
124
+ - jmespath==0.10.0
125
+ - joblib==1.3.2
126
+ - json-tricks==3.17.3
127
+ - jsonschema==4.21.0
128
+ - jsonschema-specifications==2023.12.1
129
+ - jupyter-client==8.6.0
130
+ - jupyter-core==5.7.1
131
+ - jupyterlab-widgets==3.0.9
132
+ - keras==2.12.0
133
+ - kiwisolver==1.4.5
134
+ - kornia==0.7.0
135
+ - lazy-loader==0.3
136
+ - libclang==16.0.6
137
+ - librosa==0.10.1
138
+ - lightning-utilities==0.10.0
139
+ - lit==17.0.6
140
+ - llvmlite==0.41.1
141
+ - lmdb==1.4.1
142
+ - loguru==0.6.0
143
+ - markdown==3.5.2
144
+ - markdown-it-py==3.0.0
145
+ - markupsafe==2.0.1
146
+ - matplotlib==3.6.2
147
+ - matplotlib-inline==0.1.6
148
+ - mdurl==0.1.2
149
+ - mediapipe==0.10.3
150
+ - ml-dtypes==0.3.2
151
+ - model-index==0.1.11
152
+ - modelcards==0.1.6
153
+ - moviepy==1.0.3
154
+ - mpmath==1.3.0
155
+ - msgpack==1.0.7
156
+ - multidict==6.0.4
157
+ - munkres==1.1.4
158
+ - mypy-extensions==1.0.0
159
+ - nest-asyncio==1.5.9
160
+ - networkx==3.2.1
161
+ - ninja==1.11.1
162
+ - numba==0.58.1
163
+ - numpy==1.23.5
164
+ - oauthlib==3.2.2
165
+ - omegaconf==2.3.0
166
+ - onnx==1.14.1
167
+ - onnxruntime==1.15.1
168
+ - onnxsim==0.4.33
169
+ - open-clip-torch==2.20.0
170
+ - opencv-contrib-python==4.8.0.76
171
+ - opencv-python==4.9.0.80
172
+ - opencv-python-headless==4.9.0.80
173
+ - opendatalab==0.0.10
174
+ - openmim==0.3.9
175
+ - openxlab==0.0.34
176
+ - opt-einsum==3.3.0
177
+ - ordered-set==4.1.0
178
+ - orjson==3.9.10
179
+ - oss2==2.17.0
180
+ - packaging==23.2
181
+ - pandas==2.1.4
182
+ - parso==0.8.3
183
+ - pathspec==0.12.1
184
+ - pathtools==0.1.2
185
+ - pexpect==4.9.0
186
+ - pillow==10.2.0
187
+ - pip==23.3.1
188
+ - platformdirs==4.1.0
189
+ - pluggy==1.3.0
190
+ - pooch==1.8.0
191
+ - portalocker==2.8.2
192
+ - prettytable==3.9.0
193
+ - proglog==0.1.10
194
+ - prompt-toolkit==3.0.43
195
+ - protobuf==3.20.3
196
+ - psutil==5.9.7
197
+ - ptyprocess==0.7.0
198
+ - pure-eval==0.2.2
199
+ - pyarrow==14.0.2
200
+ - pyasn1==0.5.1
201
+ - pyasn1-modules==0.3.0
202
+ - pycocotools==2.0.7
203
+ - pycparser==2.21
204
+ - pycryptodome==3.20.0
205
+ - pydantic==1.10.2
206
+ - pydeck==0.8.1b0
207
+ - pydub==0.25.1
208
+ - pygments==2.17.2
209
+ - pynvml==11.5.0
210
+ - pyparsing==3.1.1
211
+ - pysocks==1.7.1
212
+ - pytest==7.4.4
213
+ - python-dateutil==2.8.2
214
+ - python-dotenv==1.0.0
215
+ - python-multipart==0.0.6
216
+ - pytorch-lightning==2.0.8
217
+ - pytube==15.0.0
218
+ - pytz==2023.3.post1
219
+ - pywavelets==1.5.0
220
+ - pyyaml==6.0.1
221
+ - pyzmq==25.1.2
222
+ - qudida==0.0.4
223
+ - redis==4.5.1
224
+ - referencing==0.32.1
225
+ - regex==2023.12.25
226
+ - requests==2.28.2
227
+ - requests-oauthlib==1.3.1
228
+ - rich==13.4.2
229
+ - rpds-py==0.17.1
230
+ - rsa==4.9
231
+ - safetensors==0.3.3
232
+ - scikit-image==0.22.0
233
+ - scikit-learn==1.3.2
234
+ - scipy==1.11.4
235
+ - semantic-version==2.10.0
236
+ - sentencepiece==0.1.99
237
+ - sentry-sdk==1.39.2
238
+ - setproctitle==1.3.3
239
+ - setuptools==60.2.0
240
+ - shapely==2.0.2
241
+ - six==1.16.0
242
+ - smmap==5.0.1
243
+ - sniffio==1.3.0
244
+ - sounddevice==0.4.6
245
+ - soundfile==0.12.1
246
+ - soupsieve==2.5
247
+ - soxr==0.3.7
248
+ - stack-data==0.6.3
249
+ - starlette==0.35.1
250
+ - streamlit==1.30.0
251
+ - streamlit-drawable-canvas==0.9.3
252
+ - sympy==1.12
253
+ - tabulate==0.9.0
254
+ - tb-nightly==2.11.0a20220906
255
+ - tenacity==8.2.3
256
+ - tensorboard==2.12.0
257
+ - tensorboard-data-server==0.6.1
258
+ - tensorboard-plugin-wit==1.8.1
259
+ - tensorflow==2.12.0
260
+ - tensorflow-estimator==2.12.0
261
+ - tensorflow-io-gcs-filesystem==0.35.0
262
+ - termcolor==2.4.0
263
+ - terminaltables==3.1.10
264
+ - test-tube==0.7.5
265
+ - threadpoolctl==3.2.0
266
+ - tifffile==2023.12.9
267
+ - timm==0.9.12
268
+ - tokenizers==0.13.3
269
+ - toml==0.10.2
270
+ - tomli==2.0.1
271
+ - toolz==0.12.0
272
+ - torch==2.0.1+cu118
273
+ - torch-tb-profiler==0.4.1
274
+ - torchmetrics==1.1.1
275
+ - torchvision==0.15.2+cu118
276
+ - tornado==6.4
277
+ - tqdm==4.65.2
278
+ - traitlets==5.14.1
279
+ - transformers==4.33.1
280
+ - triton==2.0.0
281
+ - typing-extensions==4.9.0
282
+ - tzdata==2023.4
283
+ - tzlocal==5.2
284
+ - urllib3==1.26.18
285
+ - urwid==2.4.2
286
+ - uvicorn==0.26.0
287
+ - validators==0.22.0
288
+ - wandb==0.15.10
289
+ - watchdog==3.0.0
290
+ - wcwidth==0.2.13
291
+ - webdataset==0.2.86
292
+ - webp==0.3.0
293
+ - websockets==11.0.3
294
+ - werkzeug==3.0.1
295
+ - wget==3.2
296
+ - wheel==0.41.2
297
+ - widgetsnbextension==4.0.9
298
+ - wrapt==1.14.1
299
+ - xformers==0.0.21
300
+ - xmltodict==0.13.0
301
+ - xtcocotools==1.14.3
302
+ - yacs==0.1.8
303
+ - yapf==0.40.2
304
+ - yarl==1.9.4
305
+ - zipp==3.17.0
306
+ - zope-interface==6.1
307
+ - fire==0.6.0
308
+ - cuid
309
+ - git+https://github.com/tencent-ailab/IP-Adapter.git@main
310
+ - git+https://github.com/openai/CLIP.git@main
311
+ prefix: /data/miniconda3/envs/musev
312
+
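The pinned torch==2.0.1+cu118 and torchvision==0.15.2+cu118 wheels imply a CUDA 11.8 build; the environment is presumably recreated with something like `conda env create -f environment.yml` followed by `conda activate musev`, and the machine-specific `prefix: /data/miniconda3/envs/musev` line may need adjusting or removing on other hosts.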
musev/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ import os
2
+ import logging
3
+ import logging.config
4
+
5
+ # load the logging configuration file
6
+ logging.config.fileConfig(os.path.join(os.path.dirname(__file__), "logging.conf"))
7
+
8
+ # create a logger
9
+ logger = logging.getLogger("musev")
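Because the logging configuration is applied at import time, submodules can pick up the same handlers by creating loggers under the "musev" namespace, as later modules do with `logging.getLogger(__name__)`. A minimal sketch (the module name here is hypothetical):

    import logging
    import musev  # importing the package runs fileConfig on musev/logging.conf

    logger = logging.getLogger("musev.my_module")  # hypothetical child logger
    logger.info("handlers and levels come from musev/logging.conf")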
musev/auto_prompt/__init__.py ADDED
File without changes
musev/auto_prompt/attributes/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from ...utils.register import Register
2
+
3
+ AttrRegister = Register(registry_name="attributes")
4
+
5
+ # must import like below to ensure that each class is registered with AttrRegister:
6
+ from .human import *
7
+ from .render import *
8
+ from .style import *
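The wildcard imports exist so that every class decorated with `@AttrRegister.register` ends up in the registry at import time. The `Register` class itself lives in `musev/utils/register.py`, which is outside this slice of the commit; the following is only a sketch of the interface the rest of the package relies on (decorator registration, `in`, and `[]` lookup), not the actual implementation:

    # sketch only; the real Register in musev/utils/register.py may differ in detail
    class Register:
        def __init__(self, registry_name: str):
            self.registry_name = registry_name
            self._registry = {}

        def register(self, cls):
            # used as @AttrRegister.register on attribute classes such as Sex or Render
            self._registry[getattr(cls, "name", cls.__name__)] = cls
            return cls

        def __contains__(self, key):  # supports `word in AttrRegister`
            return key in self._registry

        def __getitem__(self, key):  # supports AttrRegister[word]
            return self._registry[key]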
musev/auto_prompt/attributes/attr2template.py ADDED
@@ -0,0 +1,127 @@
1
+ r"""
2
+ 中文
3
+ 该模块将关键词字典转化为描述文本,生成完整的提词,从而降低对比实验成本、提升控制能力和效率。
4
+ 提词(prompy)对比实验会需要控制关键属性发生变化、其他属性不变的文本对。当需要控制的属性变量发生较大变化时,靠人为复制粘贴进行完成文本撰写工作量会非常大。
5
+ 该模块主要有三种类,分别是:
6
+ 1. `BaseAttribute2Text`: 单属性文本转换类
7
+ 2. `MultiAttr2Text` 多属性文本转化类,输出`List[Tuple[str, str]`。具体如何转换为文本在 `MultiAttr2PromptTemplate`中实现。
8
+ 3. `MultiAttr2PromptTemplate`:先将2生成的多属性文本字典列表转化为完整的文本,然后再使用内置的模板`template`拼接。拼接后的文本作为实际模型输入的提词。
9
+ 1. `template`字段若没有{},且有字符,则认为输入就是完整输入网络的`prompt`;
10
+ 2. `template`字段若含有{key},则认为是带关键词的字符串目标,多个属性由`template`字符串中顺序完全决定。关键词内容由表格中相关列通过`attr2text`转化而来;
11
+ 3. `template`字段有且只含有一个{},如`a portrait of {}`,则相关内容由 `PresetMultiAttr2PromptTemplate`中预定义好的`attrs`列表指定先后顺序;
12
+
13
+ English
14
+ This module converts a keyword dictionary into descriptive text, generating complete prompts to reduce the cost of comparison experiments, and improve control and efficiency.
15
+
16
+ Prompt-based comparison experiments require text pairs where the key attributes are controlled while other attributes remain constant. When the variable attributes to be controlled undergo significant changes, manually copying and pasting to write text can be very time-consuming.
17
+
18
+ This module mainly consists of three classes:
19
+
20
+ BaseAttribute2Text: A class for converting single attribute text.
21
+ MultiAttr2Text: A class for converting multi-attribute text, outputting List[Tuple[str, str]]. The specific implementation of how to convert to text is implemented in MultiAttr2PromptTemplate.
22
+ MultiAttr2PromptTemplate: First, the list of multi-attribute text dictionaries generated by 2 is converted into complete text, and then the built-in template template is used for concatenation. The concatenated text serves as the prompt for the actual model input.
23
+ If the template field does not contain {}, and there are characters, the input is considered the complete prompt for the network.
24
+ If the template field contains {key}, it is considered a string target with keywords, and the order of multiple attributes is completely determined by the template string. The keyword content is generated by attr2text from the relevant columns in the table.
25
+ If the template field contains only one {}, such as a portrait of {}, the relevant content is specified in the order defined by the attrs list predefined in PresetMultiAttr2PromptTemplate.
26
+ """
27
+
28
+ from typing import List, Tuple, Union
29
+
30
+ from mmcm.utils.str_util import (
31
+ has_key_brace,
32
+ merge_near_same_char,
33
+ get_word_from_key_brace_string,
34
+ )
35
+
36
+ from .attributes import MultiAttr2Text, merge_multi_attrtext, AttriributeIsText
37
+ from . import AttrRegister
38
+
39
+
40
+ class MultiAttr2PromptTemplate(object):
41
+ """
42
+ 将多属性转化为模型输入文本的实际类
43
+ The actual class that converts multiple attributes into model input text
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ template: str,
49
+ attr2text: MultiAttr2Text,
50
+ name: str,
51
+ ) -> None:
52
+ """
53
+ Args:
54
+ template (str): 提词模板, prompt template.
55
+ 如果`template`含有{key},则根据key来取值。 if the template field contains {key}, it means that the actual value for that part of the prompt will be determined by the corresponding key
56
+ 如果`template`有且只有1个{},则根据先后顺序对texts中的值进行拼接。if the template field in MultiAttr2PromptTemplate contains only one {} placeholder, such as "a portrait of {}", the order of the attributes is determined by the attrs list predefined in PresetMultiAttr2PromptTemplate. The values of the attributes in the texts list are concatenated in the order specified by the attrs list.
57
+ attr2text (MultiAttr2Text): 多属性转换类。Class for converting multiple attributes into text prompt.
58
+ name (str): 该多属性文本模板类的名字,便于记忆. Class Instance name
59
+ """
60
+ self.attr2text = attr2text
61
+ self.name = name
62
+ if template == "":
63
+ template = "{}"
64
+ self.template = template
65
+ self.template_has_key_brace = has_key_brace(template)
66
+
67
+ def __call__(self, attributes: dict) -> Union[str, List[str]]:
68
+ texts = self.attr2text(attributes)
69
+ if not isinstance(texts, list):
70
+ texts = [texts]
71
+ prompts = [merge_multi_attrtext(text, self.template) for text in texts]
72
+ prompts = [merge_near_same_char(prompt) for prompt in prompts]
73
+ if len(prompts) == 1:
74
+ prompts = prompts[0]
75
+ return prompts
76
+
77
+
78
+ class KeywordMultiAttr2PromptTemplate(MultiAttr2PromptTemplate):
79
+ def __init__(self, template: str, name: str = "keywords") -> None:
80
+ """关键词模板属性2文本转化类
81
+ 1. 获取关键词模板字符串中的关键词属性;
82
+ 2. 从import * 存储在locals()中变量中获取对应的类;
83
+ 3. 将集成了多属性转换类的`MultiAttr2Text`
84
+ Args:
85
+ template (str): 含有{key}的模板字符串
86
+ name (str, optional): 该模板字符串名字,暂无实际用处. Defaults to "keywords".
87
+
88
+ class for converting keyword template attributes to text
89
+ 1. Get the keyword attributes in the keyword template string;
90
+ 2. Get the corresponding class from the variables stored in locals() by import *;
91
+ 3. Build a `MultiAttr2Text` that integrates the attribute conversion classes
92
+ Args:
93
+ template (str): template string containing {key}
94
+ name (str, optional): the name of the template string, no actual use. Defaults to "keywords".
95
+ """
96
+ assert has_key_brace(
97
+ template
98
+ ), "template should have key brace, but given {}".format(template)
99
+ keywords = get_word_from_key_brace_string(template)
100
+ funcs = []
101
+ for word in keywords:
102
+ if word in AttrRegister:
103
+ func = AttrRegister[word](name=word)
104
+ else:
105
+ func = AttriributeIsText(name=word)
106
+ funcs.append(func)
107
+ attr2text = MultiAttr2Text(funcs, name=name)
108
+ super().__init__(template, attr2text, name)
109
+
110
+
111
+ class OnlySpacePromptTemplate(MultiAttr2PromptTemplate):
112
+ def __init__(self, template: str, name: str = "space_prompt") -> None:
113
+ """纯空模板,无论输入啥,都只返回空格字符串作为prompt。
114
+ Args:
115
+ template (str): 符合只输出空格字符串的模板,
116
+ name (str, optional): 该模板字符串名字,暂无实际用处. Defaults to "space_prompt".
117
+
118
+ Pure empty template, no matter what the input is, it will only return a space string as the prompt.
119
+ Args:
120
+ template (str): template that only outputs a space string,
121
+ name (str, optional): the name of the template string, no actual use. Defaults to "space_prompt".
122
+ """
123
+ attr2text = None
124
+ super().__init__(template, attr2text, name)
125
+
126
+ def __call__(self, attributes: dict) -> Union[str, List[str]]:
127
+ return ""
musev/auto_prompt/attributes/attributes.py ADDED
@@ -0,0 +1,227 @@
1
+ from copy import deepcopy
2
+ from typing import List, Tuple, Dict
3
+
4
+ from mmcm.utils.str_util import has_key_brace
5
+
6
+
7
+ class BaseAttribute2Text(object):
8
+ """
9
+ 属性转化为文本的基类,该类作用就是输入属性,转化为描述文本。
10
+ Base class for converting attributes to text which converts attributes to prompt text.
11
+ """
12
+
13
+ name = "base_attribute"
14
+
15
+ def __init__(self, name: str = None) -> None:
16
+ """这里类实例初始化设置`name`参数,主要是为了便于一些没有提前实现、通过字符串参数实现的新属性。
17
+ Theses class instances are initialized with the `name` parameter to facilitate the implementation of new attributes that are not implemented in advance and are implemented through string parameters.
18
+
19
+ Args:
20
+ name (str, optional): _description_. Defaults to None.
21
+ """
22
+ if name is not None:
23
+ self.name = name
24
+
25
+ def __call__(self, attributes) -> str:
26
+ raise NotImplementedError
27
+
28
+
29
+ class AttributeIsTextAndName(BaseAttribute2Text):
30
+ """
31
+ 属性文本转换功能类,将key和value拼接在一起作为文本.
32
+ class for converting attributes to text which concatenates the key and value together as text.
33
+ """
34
+
35
+ name = "attribute_is_text_name"
36
+
37
+ def __call__(self, attributes) -> str:
38
+ if attributes == "" or attributes is None:
39
+ return ""
40
+ attributes = attributes.split(",")
41
+ text = ", ".join(
42
+ [
43
+ "{} {}".format(attr, self.name) if attr != "" else ""
44
+ for attr in attributes
45
+ ]
46
+ )
47
+ return text
48
+
49
+
50
+ class AttriributeIsText(BaseAttribute2Text):
51
+ """
52
+ 属性文本转换功能类,将value作为文本.
53
+ class for converting attributes to text which only uses the value as text.
54
+ """
55
+
56
+ name = "attribute_is_text"
57
+
58
+ def __call__(self, attributes: str) -> str:
59
+ if attributes == "" or attributes is None:
60
+ return ""
61
+ attributes = str(attributes)
62
+ attributes = attributes.split(",")
63
+ text = ", ".join(["{}".format(attr) for attr in attributes])
64
+ return text
65
+
66
+
67
+ class MultiAttr2Text(object):
68
+ """将多属性组成的字典转换成完整的文本描述,目前采用简单的前后拼接方式,以`, `作为拼接符号
69
+ class for converting a dictionary of multiple attributes into a complete text description. Currently, a simple front and back splicing method is used, with `, ` as the splicing symbol.
70
+
71
+ Args:
72
+ object (_type_): _description_
73
+ """
74
+
75
+ def __init__(self, funcs: list, name) -> None:
76
+ """
77
+ Args:
78
+ funcs (list): 继承`BaseAttribute2Text`并实现了`__call__`函数的类. Inherited `BaseAttribute2Text` and implemented the `__call__` function of the class.
79
+ name (_type_): 该多属性的一个名字,可通过该类方便了解对应相关属性都是关于啥的。 name of the multi-attribute, which can be used to easily understand what the corresponding related attributes are about.
80
+ """
81
+ if not isinstance(funcs, list):
82
+ funcs = [funcs]
83
+ self.funcs = funcs
84
+ self.name = name
85
+
86
+ def __call__(
87
+ self, dct: dict, ignored_blank_str: bool = False
88
+ ) -> List[Tuple[str, str]]:
89
+ """
90
+ 有时候一个属性可能会返回多个文本,如 style cartoon会返回宫崎骏和皮克斯两种风格,采用外积增殖成多个字典。
91
+ sometimes an attribute may return multiple texts, such as style cartoon will return two styles, Miyazaki and Pixar, which are multiplied into multiple dictionaries by the outer product.
92
+ Args:
93
+ dct (dict): 多属性组成的字典,可能有self.funcs关注的属性也可能没有,self.funcs按照各自的名字按需提取关注的属性和值,并转化成文本.
94
+ Dict of multiple attributes, may or may not have the attributes that self.funcs is concerned with. self.funcs extracts the attributes and values of interest according to their respective names and converts them into text.
95
+ ignored_blank_str (bool): 如果某个attr2text返回的是空字符串,是否要过滤掉该属性。默认`False`.
96
+ If the text returned by an attr2text is an empty string, whether to filter out the attribute. Defaults to `False`.
97
+ Returns:
98
+ Union[List[List[Tuple[str, str]]], List[Tuple[str, str]]: 多组多属性文本字典列表. Multiple sets of multi-attribute text dictionaries.
99
+ """
100
+ attrs_lst = [[]]
101
+ for func in self.funcs:
102
+ if func.name in dct:
103
+ attrs = func(dct[func.name])
104
+ if isinstance(attrs, str):
105
+ for i in range(len(attrs_lst)):
106
+ attrs_lst[i].append((func.name, attrs))
107
+ else:
108
+ # an attribute may return multiple texts
109
+ n_attrs = len(attrs)
110
+ new_attrs_lst = []
111
+ for n in range(n_attrs):
112
+ attrs_lst_cp = deepcopy(attrs_lst)
113
+ for i in range(len(attrs_lst_cp)):
114
+ attrs_lst_cp[i].append((func.name, attrs[n]))
115
+ new_attrs_lst.extend(attrs_lst_cp)
116
+ attrs_lst = new_attrs_lst
117
+
118
+ texts = [
119
+ [
120
+ (attr, text)
121
+ for (attr, text) in attrs
122
+ if not (text == "" and ignored_blank_str)
123
+ ]
124
+ for attrs in attrs_lst
125
+ ]
126
+ return texts
127
+
128
+
129
+ def format_tuple_texts(template: str, texts: Tuple[str, str]) -> str:
130
+ """使用含有"{}" 的模板对多属性文本元组进行拼接,形成新文本
131
+ concatenate multiple attribute text tuples using a template containing "{}" to form a new text
132
+ Args:
133
+ template (str):
134
+ texts (Tuple[str, str]): 多属性文本元组. multiple attribute text tuples
135
+
136
+ Returns:
137
+ str: 拼接后的新文本, merged new text
138
+ """
139
+ merged_text = ", ".join([text[1] for text in texts if text[1] != ""])
140
+ merged_text = template.format(merged_text)
141
+ return merged_text
142
+
143
+
144
+ def format_dct_texts(template: str, texts: Dict[str, str]) -> str:
145
+ """使用含有"{key}" 的模板对多属性文本字典进行拼接,形成新文本
146
+ concatenate multiple attribute text dictionaries using a template containing "{key}" to form a new text
147
+ Args:
148
+ template (str):
149
+ texts (Tuple[str, str]): 多属性文本字典. multiple attribute text dictionaries
150
+
151
+ Returns:
152
+ str: 拼接后的新文本, merged new text
153
+ """
154
+ merged_text = template.format(**texts)
155
+ return merged_text
156
+
157
+
158
+ def merge_multi_attrtext(texts: List[Tuple[str, str]], template: str = None) -> str:
159
+ """对多属性文本元组进行拼接,形成新文本。
160
+ 如果`template`含有{key},则根据key来取值;
161
+ 如果`template`有且只有1个{},则根据先后顺序对texts中的值进行拼接。
162
+
163
+ concatenate multiple attribute text tuples to form a new text.
164
+ if `template` contains {key}, the value is taken according to the key;
165
+ if `template` contains only one {}, the values in texts are concatenated in order.
166
+ Args:
167
+ texts (List[Tuple[str, str]]): Tuple[str, str]第一个str是属性名,第二个str是属性转化的文本.
168
+ Tuple[str, str] The first str is the attribute name, and the second str is the text of the attribute conversion.
169
+ template (str, optional): template . Defaults to None.
170
+
171
+ Returns:
172
+ str: 拼接后的新文本, merged new text
173
+ """
174
+ if not isinstance(texts, List):
175
+ texts = [texts]
176
+ if template is None or template == "":
177
+ template = "{}"
178
+ if has_key_brace(template):
179
+ texts = {k: v for k, v in texts}
180
+ merged_text = format_dct_texts(template, texts)
181
+ else:
182
+ merged_text = format_tuple_texts(template, texts)
183
+ return merged_text
184
+
185
+
186
+ class PresetMultiAttr2Text(MultiAttr2Text):
187
+ """预置了多种关注属性转换的类,方便维护
188
+ class for multiple attribute conversion with multiple attention attributes preset for easy maintenance
189
+
190
+ """
191
+
192
+ preset_attributes = []
193
+
194
+ def __init__(
195
+ self, funcs: List = None, use_preset: bool = True, name: str = "preset"
196
+ ) -> None:
197
+ """虽然预置了关注的属性列表和转换类,但也允许定义示例时,进行更新。
198
+ 注意`self.preset_attributes`的元素只是类名字,以便减少实例化的资源消耗。而funcs是实例化后的属性转换列表。
199
+
200
+ Although the list of attention attributes and conversion classes is preset, it is also allowed to be updated when defining an instance.
201
+ Note that the elements of `self.preset_attributes` are only class names, in order to reduce the resource consumption of instantiation. And funcs is a list of instantiated attribute conversions.
202
+
203
+ Args:
204
+ funcs (List, optional): list of funcs . Defaults to None.
205
+ use_preset (bool, optional): _description_. Defaults to True.
206
+ name (str, optional): _description_. Defaults to "preset".
207
+ """
208
+ if use_preset:
209
+ preset_funcs = self.preset()
210
+ else:
211
+ preset_funcs = []
212
+ if funcs is None:
213
+ funcs = []
214
+ if not isinstance(funcs, list):
215
+ funcs = [funcs]
216
+ funcs_names = [func.name for func in funcs]
217
+ preset_funcs = [
218
+ preset_func
219
+ for preset_func in preset_funcs
220
+ if preset_func.name not in funcs_names
221
+ ]
222
+ funcs = funcs + preset_funcs
223
+ super().__init__(funcs, name)
224
+
225
+ def preset(self):
226
+ funcs = [cls() for cls in self.preset_attributes]
227
+ return funcs
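A short sketch of the two merge styles handled by merge_multi_attrtext above, with invented attribute texts:

    from musev.auto_prompt.attributes.attributes import merge_multi_attrtext

    texts = [("sex", "girl"), ("style", "Miyazaki style")]

    # anonymous template: values are joined in order with ", "
    merge_multi_attrtext(texts, "a portrait of {}")
    # -> "a portrait of girl, Miyazaki style"

    # keyed template: values are placed by attribute name
    merge_multi_attrtext(texts, "a {style} portrait of a {sex}")
    # -> "a Miyazaki style portrait of a girl"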
musev/auto_prompt/attributes/human.py ADDED
@@ -0,0 +1,424 @@
1
+ from copy import deepcopy
2
+ import numpy as np
3
+ import random
4
+ import json
5
+
6
+ from .attributes import (
7
+ MultiAttr2Text,
8
+ AttriributeIsText,
9
+ AttributeIsTextAndName,
10
+ PresetMultiAttr2Text,
11
+ )
12
+ from .style import Style
13
+ from .render import Render
14
+ from . import AttrRegister
15
+
16
+
17
+ __all__ = [
18
+ "Age",
19
+ "Sex",
20
+ "Singing",
21
+ "Country",
22
+ "Lighting",
23
+ "Headwear",
24
+ "Eyes",
25
+ "Irises",
26
+ "Hair",
27
+ "Skin",
28
+ "Face",
29
+ "Smile",
30
+ "Expression",
31
+ "Clothes",
32
+ "Nose",
33
+ "Mouth",
34
+ "Beard",
35
+ "Necklace",
36
+ "KeyWords",
37
+ "InsightFace",
38
+ "Caption",
39
+ "Env",
40
+ "Decoration",
41
+ "Festival",
42
+ "SpringHeadwear",
43
+ "SpringClothes",
44
+ "Animal",
45
+ ]
46
+
47
+
48
+ @AttrRegister.register
49
+ class Sex(AttriributeIsText):
50
+ name = "sex"
51
+
52
+ def __init__(self, name: str = None) -> None:
53
+ super().__init__(name)
54
+
55
+
56
+ @AttrRegister.register
57
+ class Headwear(AttriributeIsText):
58
+ name = "headwear"
59
+
60
+ def __init__(self, name: str = None) -> None:
61
+ super().__init__(name)
62
+
63
+
64
+ @AttrRegister.register
65
+ class Expression(AttriributeIsText):
66
+ name = "expression"
67
+
68
+ def __init__(self, name: str = None) -> None:
69
+ super().__init__(name)
70
+
71
+
72
+ @AttrRegister.register
73
+ class KeyWords(AttriributeIsText):
74
+ name = "keywords"
75
+
76
+ def __init__(self, name: str = None) -> None:
77
+ super().__init__(name)
78
+
79
+
80
+ @AttrRegister.register
81
+ class Singing(AttriributeIsText):
82
+ def __init__(self, name: str = "singing") -> None:
83
+ super().__init__(name)
84
+
85
+
86
+ @AttrRegister.register
87
+ class Country(AttriributeIsText):
88
+ name = "country"
89
+
90
+ def __init__(self, name: str = None) -> None:
91
+ super().__init__(name)
92
+
93
+
94
+ @AttrRegister.register
95
+ class Clothes(AttriributeIsText):
96
+ name = "clothes"
97
+
98
+ def __init__(self, name: str = None) -> None:
99
+ super().__init__(name)
100
+
101
+
102
+ @AttrRegister.register
103
+ class Age(AttributeIsTextAndName):
104
+ name = "age"
105
+
106
+ def __init__(self, name: str = None) -> None:
107
+ super().__init__(name)
108
+
109
+ def __call__(self, attributes: str) -> str:
110
+ if not isinstance(attributes, str):
111
+ attributes = str(attributes)
112
+ attributes = attributes.split(",")
113
+ text = ", ".join(
114
+ ["{}-year-old".format(attr) if attr != "" else "" for attr in attributes]
115
+ )
116
+ return text
117
+
118
+
119
+ @AttrRegister.register
120
+ class Eyes(AttributeIsTextAndName):
121
+ name = "eyes"
122
+
123
+ def __init__(self, name: str = None) -> None:
124
+ super().__init__(name)
125
+
126
+
127
+ @AttrRegister.register
128
+ class Hair(AttributeIsTextAndName):
129
+ name = "hair"
130
+
131
+ def __init__(self, name: str = None) -> None:
132
+ super().__init__(name)
133
+
134
+
135
+ @AttrRegister.register
136
+ class Background(AttributeIsTextAndName):
137
+ name = "background"
138
+
139
+ def __init__(self, name: str = None) -> None:
140
+ super().__init__(name)
141
+
142
+
143
+ @AttrRegister.register
144
+ class Skin(AttributeIsTextAndName):
145
+ name = "skin"
146
+
147
+ def __init__(self, name: str = None) -> None:
148
+ super().__init__(name)
149
+
150
+
151
+ @AttrRegister.register
152
+ class Face(AttributeIsTextAndName):
153
+ name = "face"
154
+
155
+ def __init__(self, name: str = None) -> None:
156
+ super().__init__(name)
157
+
158
+
159
+ @AttrRegister.register
160
+ class Smile(AttributeIsTextAndName):
161
+ name = "smile"
162
+
163
+ def __init__(self, name: str = None) -> None:
164
+ super().__init__(name)
165
+
166
+
167
+ @AttrRegister.register
168
+ class Nose(AttributeIsTextAndName):
169
+ name = "nose"
170
+
171
+ def __init__(self, name: str = None) -> None:
172
+ super().__init__(name)
173
+
174
+
175
+ @AttrRegister.register
176
+ class Mouth(AttributeIsTextAndName):
177
+ name = "mouth"
178
+
179
+ def __init__(self, name: str = None) -> None:
180
+ super().__init__(name)
181
+
182
+
183
+ @AttrRegister.register
184
+ class Beard(AttriributeIsText):
185
+ name = "beard"
186
+
187
+ def __init__(self, name: str = None) -> None:
188
+ super().__init__(name)
189
+
190
+
191
+ @AttrRegister.register
192
+ class Necklace(AttributeIsTextAndName):
193
+ name = "necklace"
194
+
195
+ def __init__(self, name: str = None) -> None:
196
+ super().__init__(name)
197
+
198
+
199
+ @AttrRegister.register
200
+ class Irises(AttributeIsTextAndName):
201
+ name = "irises"
202
+
203
+ def __init__(self, name: str = None) -> None:
204
+ super().__init__(name)
205
+
206
+
207
+ @AttrRegister.register
208
+ class Lighting(AttributeIsTextAndName):
209
+ name = "lighting"
210
+
211
+ def __init__(self, name: str = None) -> None:
212
+ super().__init__(name)
213
+
214
+
215
+ PresetPortraitAttributes = [
216
+ Age,
217
+ Sex,
218
+ Singing,
219
+ Country,
220
+ Lighting,
221
+ Headwear,
222
+ Eyes,
223
+ Irises,
224
+ Hair,
225
+ Skin,
226
+ Face,
227
+ Smile,
228
+ Expression,
229
+ Clothes,
230
+ Nose,
231
+ Mouth,
232
+ Beard,
233
+ Necklace,
234
+ Style,
235
+ KeyWords,
236
+ Render,
237
+ ]
238
+
239
+
240
+ class PortraitMultiAttr2Text(PresetMultiAttr2Text):
241
+ preset_attributes = PresetPortraitAttributes
242
+
243
+ def __init__(self, funcs: list = None, use_preset=True, name="portrait") -> None:
244
+ super().__init__(funcs, use_preset, name)
245
+
246
+
247
+ @AttrRegister.register
248
+ class InsightFace(AttriributeIsText):
249
+ name = "insight_face"
250
+ face_render_dict = {
251
+ "boy": "handsome,elegant",
252
+ "girl": "gorgeous,kawaii,colorful",
253
+ }
254
+ key_words = "delicate face,beautiful eyes"
255
+
256
+ def __call__(self, attributes: str) -> str:
257
+ """将insight faces 检测的结果转化成prompt
258
+ convert the results of insight faces detection to prompt
259
+ Args:
260
+ face_list (_type_): _description_
261
+
262
+ Returns:
263
+ _type_: _description_
264
+ """
265
+ attributes = json.loads(attributes)
266
+ face_list = attributes["info"]
267
+ if len(face_list) == 0:
268
+ return ""
269
+
270
+ if attributes["image_type"] == "body":
271
+ for face in face_list:
272
+ if "black" in face and face["black"]:
273
+ return "african,dark skin"
274
+ return ""
275
+
276
+ gender_dict = {"girl": 0, "boy": 0}
277
+ face_render_list = []
278
+ black = False
279
+
280
+ for face in face_list:
281
+ if face["ratio"] < 0.02:
282
+ continue
283
+
284
+ if face["gender"] == 0:
285
+ gender_dict["girl"] += 1
286
+ face_render_list.append(self.face_render_dict["girl"])
287
+ else:
288
+ gender_dict["boy"] += 1
289
+ face_render_list.append(self.face_render_dict["boy"])
290
+
291
+ if "black" in face and face["black"]:
292
+ black = True
293
+
294
+ if len(face_render_list) == 0:
295
+ return ""
296
+ elif len(face_render_list) == 1:
297
+ solo = True
298
+ else:
299
+ solo = False
300
+
301
+ gender = ""
302
+ for g, num in gender_dict.items():
303
+ if num > 0:
304
+ if gender:
305
+ gender += ", "
306
+ gender += "{}{}".format(num, g)
307
+ if num > 1:
308
+ gender += "s"
309
+
310
+ face_render_list = ",".join(face_render_list)
311
+ face_render_list = face_render_list.split(",")
312
+ face_render = list(set(face_render_list))
313
+ face_render.sort(key=face_render_list.index)
314
+ face_render = ",".join(face_render)
315
+ if gender_dict["girl"] == 0:
316
+ face_render = "male focus," + face_render
317
+
318
+ insightface_prompt = "{},{},{}".format(gender, face_render, self.key_words)
319
+
320
+ if solo:
321
+ insightface_prompt += ",solo"
322
+ if black:
323
+ insightface_prompt = "african,dark skin," + insightface_prompt
324
+
325
+ return insightface_prompt
326
+
327
+
328
+ @AttrRegister.register
329
+ class Caption(AttriributeIsText):
330
+ name = "caption"
331
+
332
+
333
+ @AttrRegister.register
334
+ class Env(AttriributeIsText):
335
+ name = "env"
336
+ envs_list = [
337
+ "east asian architecture",
338
+ "fireworks",
339
+ "snow, snowflakes",
340
+ "snowing, snowflakes",
341
+ ]
342
+
343
+ def __call__(self, attributes: str = None) -> str:
344
+ if attributes != "" and attributes != " " and attributes is not None:
345
+ return attributes
346
+ else:
347
+ return random.choice(self.envs_list)
348
+
349
+
350
+ @AttrRegister.register
351
+ class Decoration(AttriributeIsText):
352
+ name = "decoration"
353
+
354
+ def __init__(self, name: str = None) -> None:
355
+ self.decoration_list = [
356
+ "chinese knot",
357
+ "flowers",
358
+ "food",
359
+ "lanterns",
360
+ "red envelop",
361
+ ]
362
+ super().__init__(name)
363
+
364
+ def __call__(self, attributes: str = None) -> str:
365
+ if attributes != "" and attributes != " " and attributes is not None:
366
+ return attributes
367
+ else:
368
+ return random.choice(self.decoration_list)
369
+
370
+
371
+ @AttrRegister.register
372
+ class Festival(AttriributeIsText):
373
+ name = "festival"
374
+ festival_list = ["new year"]
375
+
376
+ def __init__(self, name: str = None) -> None:
377
+ super().__init__(name)
378
+
379
+ def __call__(self, attributes: str = None) -> str:
380
+ if attributes != "" and attributes != " " and attributes is not None:
381
+ return attributes
382
+ else:
383
+ return random.choice(self.festival_list)
384
+
385
+
386
+ @AttrRegister.register
387
+ class SpringHeadwear(AttriributeIsText):
388
+ name = "spring_headwear"
389
+ headwear_list = ["rabbit ears", "rabbit ears, fur hat"]
390
+
391
+ def __call__(self, attributes: str = None) -> str:
392
+ if attributes != "" and attributes != " " and attributes is not None:
393
+ return attributes
394
+ else:
395
+ return random.choice(self.headwear_list)
396
+
397
+
398
+ @AttrRegister.register
399
+ class SpringClothes(AttriributeIsText):
400
+ name = "spring_clothes"
401
+ clothes_list = [
402
+ "mittens,chinese clothes",
403
+ "mittens,fur trim",
404
+ "mittens,red scarf",
405
+ "mittens,winter clothes",
406
+ ]
407
+
408
+ def __call__(self, attributes: str = None) -> str:
409
+ if attributes != "" and attributes != " " and attributes is not None:
410
+ return attributes
411
+ else:
412
+ return random.choice(self.clothes_list)
413
+
414
+
415
+ @AttrRegister.register
416
+ class Animal(AttriributeIsText):
417
+ name = "animal"
418
+ animal_list = ["rabbit", "holding rabbits"]
419
+
420
+ def __call__(self, attributes: str = None) -> str:
421
+ if attributes != "" and attributes != " " and attributes is not None:
422
+ return attributes
423
+ else:
424
+ return random.choice(self.animal_list)
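PortraitMultiAttr2Text above simply bundles the preset converters; a small sketch of feeding it a task-style dictionary (values invented for illustration):

    from musev.auto_prompt.attributes.human import PortraitMultiAttr2Text

    attr2text = PortraitMultiAttr2Text()
    texts = attr2text({"age": "25", "sex": "girl", "hair": "long black"})
    # texts should be a single-entry list roughly like
    # [[("age", "25-year-old"), ("sex", "girl"), ("hair", "long black hair")]]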
musev/auto_prompt/attributes/render.py ADDED
@@ -0,0 +1,33 @@
1
+ from mmcm.utils.util import flatten
2
+
3
+ from .attributes import BaseAttribute2Text
4
+ from . import AttrRegister
5
+
6
+ __all__ = ["Render"]
7
+
8
+ RenderMap = {
9
+ "Epic": "artstation, epic environment, highly detailed, 8k, HD",
10
+ "HD": "8k, highly detailed",
11
+ "EpicHD": "hyper detailed, beautiful lighting, epic environment, octane render, cinematic, 8k",
12
+ "Digital": "detailed illustration, crisp lines, digital art, 8k, trending on artstation",
13
+ "Unreal1": "artstation, concept art, smooth, sharp focus, illustration, unreal engine 5, 8k",
14
+ "Unreal2": "concept art, octane render, artstation, epic environment, highly detailed, 8k",
15
+ }
16
+
17
+
18
+ @AttrRegister.register
19
+ class Render(BaseAttribute2Text):
20
+ name = "render"
21
+
22
+ def __init__(self, name: str = None) -> None:
23
+ super().__init__(name)
24
+
25
+ def __call__(self, attributes: str) -> str:
26
+ if attributes == "" or attributes is None:
27
+ return ""
28
+ attributes = attributes.split(",")
29
+ render = [RenderMap[attr] for attr in attributes if attr in RenderMap]
30
+ render = flatten(render, ignored_iterable_types=[str])
31
+ if len(render) == 1:
32
+ render = render[0]
33
+ return render
musev/auto_prompt/attributes/style.py ADDED
@@ -0,0 +1,12 @@
1
+ from .attributes import AttriributeIsText
2
+ from . import AttrRegister
3
+
4
+ __all__ = ["Style"]
5
+
6
+
7
+ @AttrRegister.register
8
+ class Style(AttriributeIsText):
9
+ name = "style"
10
+
11
+ def __init__(self, name: str = None) -> None:
12
+ super().__init__(name)
musev/auto_prompt/human.py ADDED
@@ -0,0 +1,40 @@
1
+ """负责按照人相关的属性转化成提词
2
+ """
3
+ from typing import List
4
+
5
+ from .attributes.human import PortraitMultiAttr2Text
6
+ from .attributes.attributes import BaseAttribute2Text
7
+ from .attributes.attr2template import MultiAttr2PromptTemplate
8
+
9
+
10
+ class PortraitAttr2PromptTemplate(MultiAttr2PromptTemplate):
11
+ """可以将任务字典转化为形象提词模板类
12
+ template class for converting task dictionaries into image prompt templates
13
+ Args:
14
+ MultiAttr2PromptTemplate (_type_): _description_
15
+ """
16
+
17
+ templates = "a portrait of {}"
18
+
19
+ def __init__(
20
+ self, templates: str = None, attr2text: List = None, name: str = "portrait"
21
+ ) -> None:
22
+ """
23
+
24
+ Args:
25
+ templates (str, optional): 形象提词模板,若为None,则使用默认的类属性. Defaults to None.
26
+ portrait prompt template, if None, the default class attribute is used.
27
+ attr2text (List, optional): 形象类需要新增、更新的属性列表,默认使用PortraitMultiAttr2Text中定义的形象属性. Defaults to None.
28
+ the list of attributes that need to be added or updated in the image class, by default, the image attributes defined in PortraitMultiAttr2Text are used.
29
+ name (str, optional): 该形象类的名字. Defaults to "portrait".
30
+ class name of this class instance
31
+ """
32
+ if (
33
+ attr2text is None
34
+ or isinstance(attr2text, list)
35
+ or isinstance(attr2text, BaseAttribute2Text)
36
+ ):
37
+ attr2text = PortraitMultiAttr2Text(funcs=attr2text)
38
+ if templates is None:
39
+ templates = self.templates
40
+ super().__init__(templates, attr2text, name=name)
musev/auto_prompt/load_template.py ADDED
@@ -0,0 +1,37 @@
1
+ from mmcm.utils.str_util import has_key_brace
2
+
3
+ from .human import PortraitAttr2PromptTemplate
4
+ from .attributes.attr2template import (
5
+ KeywordMultiAttr2PromptTemplate,
6
+ OnlySpacePromptTemplate,
7
+ )
8
+
9
+
10
+ def get_template_by_name(template: str, name: str = None):
11
+ """根据 template_name 确定 prompt 生成器类
12
+ choose prompt generator class according to template_name
13
+ Args:
14
+ name (str): template 的名字简称,便于指定. template name abbreviation, for easy reference
15
+
16
+ Raises:
17
+ ValueError: ValueError: 如果name不在支持的列表中,则报错. if name is not in the supported list, an error is reported.
18
+
19
+ Returns:
20
+ MultiAttr2PromptTemplate: 能够将任务字典转化为提词的 实现了__call__功能的类. class that can convert task dictionaries into prompts and implements the __call__ function
21
+
22
+ """
23
+ if template == "" or template is None:
24
+ template = OnlySpacePromptTemplate(template=template)
25
+ elif has_key_brace(template):
26
+ # if has_key_brace(template):
27
+ template = KeywordMultiAttr2PromptTemplate(template=template)
28
+ else:
29
+ if name == "portrait":
30
+ template = PortraitAttr2PromptTemplate(templates=template)
31
+ else:
32
+ raise ValueError(
33
+ "PresetAttr2PromptTemplate only support one of [portrait], but given {}".format(
34
+ name
35
+ )
36
+ )
37
+ return template
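The dispatch above means the prompt field of a task can be a finished prompt, a {key} template, or empty. A short sketch of the three branches (comments describe the expected dispatch):

    from musev.auto_prompt.load_template import get_template_by_name

    get_template_by_name("", name=None)                        # OnlySpacePromptTemplate
    get_template_by_name("a {style} portrait of a {sex}")      # KeywordMultiAttr2PromptTemplate
    get_template_by_name("a portrait of {}", name="portrait")  # PortraitAttr2PromptTemplate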
musev/auto_prompt/util.py ADDED
@@ -0,0 +1,25 @@
1
+ from copy import deepcopy
2
+ from typing import Dict, List
3
+
4
+ from .load_template import get_template_by_name
5
+
6
+
7
+ def generate_prompts(tasks: List[Dict]) -> List[Dict]:
8
+ new_tasks = []
9
+ for task in tasks:
10
+ task["origin_prompt"] = deepcopy(task["prompt"])
11
+ # if the prompt is already a plain, non-empty string, keep it; if it contains a {} template or is empty (the default empty template), generate it from a template
12
+ if "{" not in task["prompt"] and len(task["prompt"]) != 0:
13
+ new_tasks.append(task)
14
+ else:
15
+ template = get_template_by_name(
16
+ template=task["prompt"], name=task.get("template_name", None)
17
+ )
18
+ prompts = template(task)
19
+ if not isinstance(prompts, list) and isinstance(prompts, str):
20
+ prompts = [prompts]
21
+ for prompt in prompts:
22
+ task_cp = deepcopy(task)
23
+ task_cp["prompt"] = prompt
24
+ new_tasks.append(task_cp)
25
+ return new_tasks
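generate_prompts is the glue between the task YAML and the template classes: tasks whose prompt is already a plain, non-empty string pass through unchanged, while templated or empty prompts are expanded. A small sketch with invented tasks:

    from musev.auto_prompt.util import generate_prompts

    tasks = [
        {"prompt": "a portrait of a {age} {sex}", "age": "25", "sex": "girl"},
        {"prompt": "(masterpiece, best quality, highres:1), peaceful beautiful river"},
    ]
    new_tasks = generate_prompts(tasks)
    # the first prompt is expanded via KeywordMultiAttr2PromptTemplate, the second
    # is kept as-is; both tasks keep their original text under "origin_prompt"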
musev/data/__init__.py ADDED
File without changes
musev/data/data_util.py ADDED
@@ -0,0 +1,681 @@
1
+ from typing import List, Dict, Literal, Union, Tuple
2
+ import os
3
+ import string
4
+ import logging
5
+
6
+ import torch
7
+ import numpy as np
8
+ from einops import rearrange, repeat
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def generate_tasks_of_dir(
14
+ path: str,
15
+ output_dir: str,
16
+ exts: Tuple[str],
17
+ same_dir_name: bool = False,
18
+ **kwargs,
19
+ ) -> List[Dict]:
20
+ """covert video directory into tasks
21
+
22
+ Args:
23
+ path (str): _description_
24
+ output_dir (str): _description_
25
+ exts (Tuple[str]): _description_
26
+ same_dir_name (bool, optional): 存储路径是否保留和源视频相同的父文件名. Defaults to False.
27
+ whether keep the same parent dir name as the source video
28
+ Returns:
29
+ List[Dict]: _description_
30
+ """
31
+ tasks = []
32
+ for rootdir, dirs, files in os.walk(path):
33
+ for basename in files:
34
+ if basename.lower().endswith(exts):
35
+ video_path = os.path.join(rootdir, basename)
36
+ filename, ext = basename.split(".")
37
+ rootdir_name = os.path.basename(rootdir)
38
+ if same_dir_name:
39
+ save_path = os.path.join(
40
+ output_dir, rootdir_name, f"{filename}.h5py"
41
+ )
42
+ save_dir = os.path.join(output_dir, rootdir_name)
43
+ else:
44
+ save_path = os.path.join(output_dir, f"{filename}.h5py")
45
+ save_dir = output_dir
46
+ task = {
47
+ "video_path": video_path,
48
+ "output_path": save_path,
49
+ "output_dir": save_dir,
50
+ "filename": filename,
51
+ "ext": ext,
52
+ }
53
+ task.update(kwargs)
54
+ tasks.append(task)
55
+ return tasks
56
+
57
+
58
+ def sample_by_idx(
59
+ T: int,
60
+ n_sample: int,
61
+ sample_rate: int,
62
+ sample_start_idx: int = None,
63
+ change_sample_rate: bool = False,
64
+ seed: int = None,
65
+ whether_random: bool = True,
66
+ n_independent: int = 0,
67
+ ) -> List[int]:
68
+ """given a int to represent candidate list, sample n_sample with sample_rate from the candidate list
69
+
70
+ Args:
71
+ T (int): _description_
72
+ n_sample (int): 目标采样数目. sample number
73
+ sample_rate (int): 采样率, 每隔sample_rate个采样一个. sample interval, pick one per sample_rate number
74
+ sample_start_idx (int, optional): 采样开始位置的选择. start position to sample. Defaults to None.
75
+ change_sample_rate (bool, optional): 是否可以通过降低sample_rate的方式来完成采样. whether allow changing sample_rate to finish sample process. Defaults to False.
76
+ whether_random (bool, optional): 是否最后随机选择开始点. whether randomly choose sample start position. Defaults to True.
77
+
78
+ Raises:
79
+ ValueError: T / sample_rate should be larger than n_sample
80
+ Returns:
81
+ List[int]: 采样的索引位置. sampled index position
82
+ """
83
+ if T < n_sample:
84
+ raise ValueError(f"T({T}) < n_sample({n_sample})")
85
+ else:
86
+ if T / sample_rate < n_sample:
87
+ if not change_sample_rate:
88
+ raise ValueError(
89
+ f"T({T}) / sample_rate({sample_rate}) < n_sample({n_sample})"
90
+ )
91
+ else:
92
+ while T / sample_rate < n_sample:
93
+ sample_rate -= 1
94
+ logger.error(
95
+ f"sample_rate{sample_rate+1} is too large, decrease to {sample_rate}"
96
+ )
97
+ if sample_rate == 0:
98
+ raise ValueError("T / sample_rate < n_sample")
99
+
100
+ if sample_start_idx is None:
101
+ if whether_random:
102
+ sample_start_idx_candidates = np.arange(T - n_sample * sample_rate)
103
+ if seed is not None:
104
+ np.random.seed(seed)
105
+ sample_start_idx = np.random.choice(sample_start_idx_candidates, 1)[0]
106
+
107
+ else:
108
+ sample_start_idx = 0
109
+ sample_end_idx = sample_start_idx + sample_rate * n_sample
110
+ sample = list(range(sample_start_idx, sample_end_idx, sample_rate))
111
+ if n_independent == 0:
112
+ n_independent_sample = None
113
+ else:
114
+ left_candidate = np.array(
115
+ list(range(0, sample_start_idx)) + list(range(sample_end_idx, T))
116
+ )
117
+ if len(left_candidate) >= n_independent:
118
+ # use the remaining space at both ends to sample
119
+ n_independent_sample = np.random.choice(left_candidate, n_independent)
120
+ else:
121
+ # 当两端没有剩余采样空间时,使用任意不是sample中的帧
122
+ # if not enough space to sample, use any frame not in sample
123
+ left_candidate = np.array(list(set(range(T)) - set(sample)))
124
+ n_independent_sample = np.random.choice(left_candidate, n_independent)
125
+
126
+ return sample, sample_rate, n_independent_sample
127
+
128
+
129
+ def sample_tensor_by_idx(
130
+ tensor: Union[torch.Tensor, np.ndarray],
131
+ n_sample: int,
132
+ sample_rate: int,
133
+ sample_start_idx: int = 0,
134
+ change_sample_rate: bool = False,
135
+ seed: int = None,
136
+ dim: int = 0,
137
+ return_type: Literal["numpy", "torch"] = "torch",
138
+ whether_random: bool = True,
139
+ n_independent: int = 0,
140
+ ) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]:
141
+ """sample sub_tensor
142
+
143
+ Args:
144
+ tensor (Union[torch.Tensor, np.ndarray]): _description_
145
+ n_sample (int): _description_
146
+ sample_rate (int): _description_
147
+ sample_start_idx (int, optional): _description_. Defaults to 0.
148
+ change_sample_rate (bool, optional): _description_. Defaults to False.
149
+ seed (int, optional): _description_. Defaults to None.
150
+ dim (int, optional): _description_. Defaults to 0.
151
+ return_type (Literal["numpy", "torch"], optional): _description_. Defaults to "torch".
152
+ whether_random (bool, optional): _description_. Defaults to True.
153
+ n_independent (int, optional): 独立于n_sample的采样数量. Defaults to 0.
154
+ n_independent sample number that is independent of n_sample
155
+
156
+ Returns:
157
+ Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]: sampled tensor
158
+ """
159
+ if isinstance(tensor, np.ndarray):
160
+ tensor = torch.from_numpy(tensor)
161
+ T = tensor.shape[dim]
162
+ sample_idx, sample_rate, independent_sample_idx = sample_by_idx(
163
+ T,
164
+ n_sample,
165
+ sample_rate,
166
+ sample_start_idx,
167
+ change_sample_rate,
168
+ seed,
169
+ whether_random=whether_random,
170
+ n_independent=n_independent,
171
+ )
172
+ sample_idx = torch.LongTensor(sample_idx)
173
+ sample = torch.index_select(tensor, dim, sample_idx)
174
+ if independent_sample_idx is not None:
175
+ independent_sample_idx = torch.LongTensor(independent_sample_idx)
176
+ independent_sample = torch.index_select(tensor, dim, independent_sample_idx)
177
+ else:
178
+ independent_sample = None
179
+ independent_sample_idx = None
180
+ if return_type == "numpy":
181
+ sample = sample.cpu().numpy()
182
+ return sample, sample_idx, sample_rate, independent_sample, independent_sample_idx
183
+
184
+
185
+ def concat_two_tensor(
186
+ data1: torch.Tensor,
187
+ data2: torch.Tensor,
188
+ dim: int,
189
+ method: Literal[
190
+ "first_in_first_out", "first_in_last_out", "intertwine", "index"
191
+ ] = "first_in_first_out",
192
+ data1_index: torch.long = None,
193
+ data2_index: torch.long = None,
194
+ return_index: bool = False,
195
+ ):
196
+ """concat two tensor along dim with given method
197
+
198
+ Args:
199
+ data1 (torch.Tensor): first in data
200
+ data2 (torch.Tensor): last in data
201
+ dim (int): _description_
202
+ method (Literal["first_in_first_out", "first_in_last_out", "intertwine"], optional): _description_. Defaults to "first_in_first_out".
203
+
204
+ Raises:
205
+ NotImplementedError: unsupported method
206
+ ValueError: unsupported method
207
+
208
+ Returns:
209
+ _type_: _description_
210
+ """
211
+ len_data1 = data1.shape[dim]
212
+ len_data2 = data2.shape[dim]
213
+
214
+ if method == "first_in_first_out":
215
+ res = torch.concat([data1, data2], dim=dim)
216
+ data1_index = range(len_data1)
217
+ data2_index = [len_data1 + x for x in range(len_data2)]
218
+ elif method == "first_in_last_out":
219
+ res = torch.concat([data2, data1], dim=dim)
220
+ data2_index = range(len_data2)
221
+ data1_index = [len_data2 + x for x in range(len_data1)]
222
+ elif method == "intertwine":
223
+ raise NotImplementedError("intertwine")
224
+ elif method == "index":
225
+ res = concat_two_tensor_with_index(
226
+ data1=data1,
227
+ data1_index=data1_index,
228
+ data2=data2,
229
+ data2_index=data2_index,
230
+ dim=dim,
231
+ )
232
+ else:
233
+ raise ValueError(
234
+ "only support first_in_first_out, first_in_last_out, intertwine, index"
235
+ )
236
+ if return_index:
237
+ return res, data1_index, data2_index
238
+ else:
239
+ return res
240
+
241
+
242
+ def concat_two_tensor_with_index(
243
+ data1: torch.Tensor,
244
+ data1_index: torch.LongTensor,
245
+ data2: torch.Tensor,
246
+ data2_index: torch.LongTensor,
247
+ dim: int,
248
+ ) -> torch.Tensor:
249
+ """_summary_
250
+
251
+ Args:
252
+ data1 (torch.Tensor): b1*c1*h1*w1*...
253
+ data1_index (torch.LongTensor): N, if dim=1, N=c1
254
+ data2 (torch.Tensor): b2*c2*h2*w2*...
255
+ data2_index (torch.LongTensor): M, if dim=1, M=c2
256
+ dim (int): int
257
+
258
+ Returns:
259
+ torch.Tensor: b*c*h*w*..., if dim=1, b=b1=b2, c=c1+c2, h=h1=h2, w=w1=w2,...
260
+ """
261
+ shape1 = list(data1.shape)
262
+ shape2 = list(data2.shape)
263
+ target_shape = list(shape1)
264
+ target_shape[dim] = shape1[dim] + shape2[dim]
265
+ target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype)
266
+ target = batch_index_copy(target, dim=dim, index=data1_index, source=data1)
267
+ target = batch_index_copy(target, dim=dim, index=data2_index, source=data2)
268
+ return target
269
+
270
+
271
+ def repeat_index_to_target_size(
272
+ index: torch.LongTensor, target_size: int
273
+ ) -> torch.LongTensor:
274
+ if len(index.shape) == 1:
275
+ index = repeat(index, "n -> b n", b=target_size)
276
+ if len(index.shape) == 2:
277
+ remainder = target_size % index.shape[0]
278
+ assert (
279
+ remainder == 0
280
+ ), f"target_size % index.shape[0] must be zero, but give {target_size % index.shape[0]}"
281
+ index = repeat(index, "b n -> (b c) n", c=int(target_size / index.shape[0]))
282
+ return index
283
+
284
+
285
+ def batch_concat_two_tensor_with_index(
286
+ data1: torch.Tensor,
287
+ data1_index: torch.LongTensor,
288
+ data2: torch.Tensor,
289
+ data2_index: torch.LongTensor,
290
+ dim: int,
291
+ ) -> torch.Tensor:
292
+ return concat_two_tensor_with_index(data1, data1_index, data2, data2_index, dim)
293
+
294
+
295
+ def interwine_two_tensor(
296
+ data1: torch.Tensor,
297
+ data2: torch.Tensor,
298
+ dim: int,
299
+ return_index: bool = False,
300
+ ) -> torch.Tensor:
301
+ shape1 = list(data1.shape)
302
+ shape2 = list(data2.shape)
303
+ target_shape = list(shape1)
304
+ target_shape[dim] = shape1[dim] + shape2[dim]
305
+ target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype)
306
+ data1_reshape = torch.swapaxes(data1, 0, dim)
307
+ data2_reshape = torch.swapaxes(data2, 0, dim)
308
+ target = torch.swapaxes(target, 0, dim)
309
+ total_index = set(range(target_shape[dim]))
310
+ data1_index = range(0, 2 * shape1[dim], 2)
311
+ data2_index = sorted(list(set(total_index) - set(data1_index)))
312
+ data1_index = torch.LongTensor(data1_index)
313
+ data2_index = torch.LongTensor(data2_index)
314
+ target[data1_index, ...] = data1_reshape
315
+ target[data2_index, ...] = data2_reshape
316
+ target = torch.swapaxes(target, 0, dim)
317
+ if return_index:
318
+ return target, data1_index, data2_index
319
+ else:
320
+ return target
321
+
322
+
323
+ def split_index(
324
+ indexs: torch.Tensor,
325
+ n_first: int = None,
326
+ n_last: int = None,
327
+ method: Literal[
328
+ "first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
329
+ ] = "first_in_first_out",
330
+ ):
331
+ """_summary_
332
+
333
+ Args:
334
+ indexs (List): _description_
335
+ n_first (int): _description_
336
+ n_last (int): _description_
337
+ method (Literal["first_in_first_out", "first_in_last_out", "intertwine", "index"], optional): _description_. Defaults to "first_in_first_out".
338
+
339
+ Raises:
340
+ NotImplementedError: _description_
341
+
342
+ Returns:
343
+ first_index: _description_
344
+ last_index:
345
+ """
346
+ # assert (
347
+ # n_first is None and n_last is None
348
+ # ), "must assign one value for n_first or n_last"
349
+ n_total = len(indexs)
350
+ if n_first is None:
351
+ n_first = n_total - n_last
352
+ if n_last is None:
353
+ n_last = n_total - n_first
354
+ assert len(indexs) == n_first + n_last
355
+ if method == "first_in_first_out":
356
+ first_index = indexs[:n_first]
357
+ last_index = indexs[n_first:]
358
+ elif method == "first_in_last_out":
359
+ first_index = indexs[n_last:]
360
+ last_index = indexs[:n_last]
361
+ elif method == "intertwine":
362
+ raise NotImplementedError
363
+ elif method == "random":
364
+ idx_ = torch.randperm(len(indexs))
365
+ first_index = indexs[idx_[:n_first]]
366
+ last_index = indexs[idx_[n_first:]]
367
+ return first_index, last_index
368
+
369
+
370
+ def split_tensor(
371
+ tensor: torch.Tensor,
372
+ dim: int,
373
+ n_first=None,
374
+ n_last=None,
375
+ method: Literal[
376
+ "first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
377
+ ] = "first_in_first_out",
378
+ need_return_index: bool = False,
379
+ ):
380
+ device = tensor.device
381
+ total = tensor.shape[dim]
382
+ if n_first is None:
383
+ n_first = total - n_last
384
+ if n_last is None:
385
+ n_last = total - n_first
386
+ indexs = torch.arange(
387
+ total,
388
+ dtype=torch.long,
389
+ device=device,
390
+ )
391
+ (
392
+ first_index,
393
+ last_index,
394
+ ) = split_index(
395
+ indexs=indexs,
396
+ n_first=n_first,
397
+ method=method,
398
+ )
399
+ first_tensor = torch.index_select(tensor, dim=dim, index=first_index)
400
+ last_tensor = torch.index_select(tensor, dim=dim, index=last_index)
401
+ if need_return_index:
402
+ return (
403
+ first_tensor,
404
+ last_tensor,
405
+ first_index,
406
+ last_index,
407
+ )
408
+ else:
409
+ return (first_tensor, last_tensor)
410
+
411
+
412
+ # TODO: optimization of batch_index_select still to be determined
413
+ def batch_index_select(
414
+ tensor: torch.Tensor, index: torch.LongTensor, dim: int
415
+ ) -> torch.Tensor:
416
+ """_summary_
417
+
418
+ Args:
419
+ tensor (torch.Tensor): D1*D2*D3*D4...
420
+ index (torch.LongTensor): D1*N or N, N<= tensor.shape[dim]
421
+ dim (int): dim to select
422
+
423
+ Returns:
424
+ torch.Tensor: D1*...*N*...
425
+ """
426
+ # TODO: now only support N same for every d1
427
+ if len(index.shape) == 1:
428
+ return torch.index_select(tensor, dim=dim, index=index)
429
+ else:
430
+ index = repeat_index_to_target_size(index, tensor.shape[0])
431
+ out = []
432
+ for i in torch.arange(tensor.shape[0]):
433
+ sub_tensor = tensor[i]
434
+ sub_index = index[i]
435
+ d = torch.index_select(sub_tensor, dim=dim - 1, index=sub_index)
436
+ out.append(d)
437
+ return torch.stack(out).to(dtype=tensor.dtype)
438
+
439
+
440
+ def batch_index_copy(
441
+ tensor: torch.Tensor, dim: int, index: torch.LongTensor, source: torch.Tensor
442
+ ) -> torch.Tensor:
443
+ """_summary_
444
+
445
+ Args:
446
+ tensor (torch.Tensor): b*c*h
447
+ dim (int):
448
+ index (torch.LongTensor): b*d,
449
+ source (torch.Tensor):
450
+ b*d*h*..., if dim=1
451
+ b*c*d*..., if dim=2
452
+
453
+ Returns:
454
+ torch.Tensor: b*c*d*...
455
+ """
456
+ if len(index.shape) == 1:
457
+ tensor.index_copy_(dim=dim, index=index, source=source)
458
+ else:
459
+ index = repeat_index_to_target_size(index, tensor.shape[0])
460
+
461
+ batch_size = tensor.shape[0]
462
+ for b in torch.arange(batch_size):
463
+ sub_index = index[b]
464
+ sub_source = source[b]
465
+ sub_tensor = tensor[b]
466
+ sub_tensor.index_copy_(dim=dim - 1, index=sub_index, source=sub_source)
467
+ tensor[b] = sub_tensor
468
+ return tensor
469
+
470
+
471
+ def batch_index_fill(
472
+ tensor: torch.Tensor,
473
+ dim: int,
474
+ index: torch.LongTensor,
475
+ value: Literal[torch.Tensor, torch.float],
476
+ ) -> torch.Tensor:
477
+ """_summary_
478
+
479
+ Args:
480
+ tensor (torch.Tensor): b*c*h
481
+ dim (int):
482
+ index (torch.LongTensor): b*d,
483
+ value (torch.Tensor): b
484
+
485
+ Returns:
486
+ torch.Tensor: b*c*d*...
487
+ """
488
+ index = repeat_index_to_target_size(index, tensor.shape[0])
489
+ batch_size = tensor.shape[0]
490
+ for b in torch.arange(batch_size):
491
+ sub_index = index[b]
492
+ sub_value = value[b] if isinstance(value, torch.Tensor) else value
493
+ sub_tensor = tensor[b]
494
+ sub_tensor.index_fill_(dim - 1, sub_index, sub_value)
495
+ tensor[b] = sub_tensor
496
+ return tensor
497
+
498
+
499
+ def adaptive_instance_normalization(
500
+ src: torch.Tensor,
501
+ dst: torch.Tensor,
502
+ eps: float = 1e-6,
503
+ ):
504
+ """
505
+ Args:
506
+ src (torch.Tensor): b c t h w
507
+ dst (torch.Tensor): b c t h w
508
+ """
509
+ ndim = src.ndim
510
+ if ndim == 5:
511
+ dim = (2, 3, 4)
512
+ elif ndim == 4:
513
+ dim = (2, 3)
514
+ elif ndim == 3:
515
+ dim = 2
516
+ else:
517
+ raise ValueError("only support ndim in [3,4,5], but given {ndim}")
518
+ var, mean = torch.var_mean(src, dim=dim, keepdim=True, correction=0)
519
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
520
+ dst = align_repeat_tensor_single_dim(dst, src.shape[0], dim=0)
521
+ # torch.var_mean returns (var, mean); keep the unpacking order consistent with the src statistics above
+ var_acc, mean_acc = torch.var_mean(dst, dim=dim, keepdim=True, correction=0)
522
+ # mean_acc = sum(mean_acc) / float(len(mean_acc))
523
+ # var_acc = sum(var_acc) / float(len(var_acc))
524
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
525
+ src = (((src - mean) / std) * std_acc) + mean_acc
526
+ return src
527
+
528
+
529
+ def adaptive_instance_normalization_with_ref(
530
+ src: torch.LongTensor,
531
+ dst: torch.LongTensor,
532
+ style_fidelity: float = 0.5,
533
+ do_classifier_free_guidance: bool = True,
534
+ ):
535
+ # logger.debug(
536
+ # f"src={src.shape}, min={src.min()}, max={src.max()}, mean={src.mean()}, \n"
537
+ # f"dst={src.shape}, min={dst.min()}, max={dst.max()}, mean={dst.mean()}"
538
+ # )
539
+ batch_size = src.shape[0] // 2
540
+ uc_mask = torch.Tensor([1] * batch_size + [0] * batch_size).type_as(src).bool()
541
+ src_uc = adaptive_instance_normalization(src, dst)
542
+ src_c = src_uc.clone()
543
+ # TODO: this branch assumes do_classifier_free_guidance and style_fidelity > 0 are both True
544
+ if do_classifier_free_guidance and style_fidelity > 0:
545
+ src_c[uc_mask] = src[uc_mask]
546
+ src = style_fidelity * src_c + (1.0 - style_fidelity) * src_uc
547
+ return src
548
+
549
+
550
+ def batch_adain_conditioned_tensor(
551
+ tensor: torch.Tensor,
552
+ src_index: torch.LongTensor,
553
+ dst_index: torch.LongTensor,
554
+ keep_dim: bool = True,
555
+ num_frames: int = None,
556
+ dim: int = 2,
557
+ style_fidelity: float = 0.5,
558
+ do_classifier_free_guidance: bool = True,
559
+ need_style_fidelity: bool = False,
560
+ ):
561
+ """_summary_
562
+
563
+ Args:
564
+ tensor (torch.Tensor): b c t h w
565
+ src_index (torch.LongTensor): _description_
566
+ dst_index (torch.LongTensor): _description_
567
+ keep_dim (bool, optional): _description_. Defaults to True.
568
+
569
+ Returns:
570
+ _type_: _description_
571
+ """
572
+ ndim = tensor.ndim
573
+ dtype = tensor.dtype
574
+ if ndim == 4 and num_frames is not None:
575
+ tensor = rearrange(tensor, "(b t) c h w-> b c t h w ", t=num_frames)
576
+ src = batch_index_select(tensor, dim=dim, index=src_index).contiguous()
577
+ dst = batch_index_select(tensor, dim=dim, index=dst_index).contiguous()
578
+ if need_style_fidelity:
579
+ src = adaptive_instance_normalization_with_ref(
580
+ src=src,
581
+ dst=dst,
582
+ style_fidelity=style_fidelity,
583
+ do_classifier_free_guidance=do_classifier_free_guidance,
584
+ # note: adaptive_instance_normalization_with_ref does not accept need_style_fidelity, so it is not forwarded
585
+ )
586
+ else:
587
+ src = adaptive_instance_normalization(
588
+ src=src,
589
+ dst=dst,
590
+ )
591
+ if keep_dim:
592
+ src = batch_concat_two_tensor_with_index(
593
+ src.to(dtype=dtype),
594
+ src_index,
595
+ dst.to(dtype=dtype),
596
+ dst_index,
597
+ dim=dim,
598
+ )
599
+
600
+ if ndim == 4 and num_frames is not None:
601
+ src = rearrange(src, "b c t h w ->(b t) c h w")
602
+ return src
603
+
604
+
605
+ def align_repeat_tensor_single_dim(
606
+ src: torch.Tensor,
607
+ target_length: int,
608
+ dim: int = 0,
609
+ n_src_base_length: int = 1,
610
+ src_base_index: List[int] = None,
611
+ ) -> torch.Tensor:
612
+ """沿着 dim 纬度, 补齐 src 的长度到目标 target_length。
613
+ 当 src 长度不如 target_length 时, 取其中 前 n_src_base_length 然后 repeat 到 target_length
614
+
615
+ align length of src to target_length along dim
616
+ when src length is less than target_length, take the first n_src_base_length and repeat to target_length
617
+
618
+ Args:
619
+ src (torch.Tensor): 输入 tensor, input tensor
620
+ target_length (int): 目标长度, target_length
621
+ dim (int, optional): 处理纬度, target dim . Defaults to 0.
622
+ n_src_base_length (int, optional): src 的基本单元长度, basic length of src. Defaults to 1.
623
+
624
+ Returns:
625
+ torch.Tensor: _description_
626
+ """
627
+ src_dim_length = src.shape[dim]
628
+ if target_length > src_dim_length:
629
+ if target_length % src_dim_length == 0:
630
+ new = src.repeat_interleave(
631
+ repeats=target_length // src_dim_length, dim=dim
632
+ )
633
+ else:
634
+ if src_base_index is None and n_src_base_length is not None:
635
+ src_base_index = torch.arange(n_src_base_length)
636
+
637
+ new = src.index_select(
638
+ dim=dim,
639
+ index=torch.LongTensor(src_base_index).to(device=src.device),
640
+ )
641
+ new = new.repeat_interleave(
642
+ repeats=target_length // len(src_base_index),
643
+ dim=dim,
644
+ )
645
+ elif target_length < src_dim_length:
646
+ new = src.index_select(
647
+ dim=dim,
648
+ index=torch.LongTensor(torch.arange(target_length)).to(device=src.device),
649
+ )
650
+ else:
651
+ new = src
652
+ return new
653
+
654
+
655
+ def fuse_part_tensor(
656
+ src: torch.Tensor,
657
+ dst: torch.Tensor,
658
+ overlap: int,
659
+ weight: float = 0.5,
660
+ skip_step: int = 0,
661
+ ) -> torch.Tensor:
662
+ """fuse overstep tensor with weight of src into dst
663
+ out = src_fused_part * weight + dst * (1-weight) for overlap
664
+
665
+ Args:
666
+ src (torch.Tensor): b c t h w
667
+ dst (torch.Tensor): b c t h w
668
+ overlap (int): 1
669
+ weight (float, optional): weight of src tensor part. Defaults to 0.5.
670
+
671
+ Returns:
672
+ torch.Tensor: fused tensor
673
+ """
674
+ if overlap == 0:
675
+ return dst
676
+ else:
677
+ dst[:, :, skip_step : skip_step + overlap] = (
678
+ weight * src[:, :, -overlap:]
679
+ + (1 - weight) * dst[:, :, skip_step : skip_step + overlap]
680
+ )
681
+ return dst
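A minimal usage sketch of the tensor helpers added in musev/data/data_util.py above (interwine_two_tensor, split_tensor, batch_index_select); the shapes below are illustrative assumptions, not values taken from MuseV configs:

import torch
from musev.data.data_util import (
    batch_index_select,
    interwine_two_tensor,
    split_tensor,
)

frames_a = torch.randn(2, 4, 8, 16, 16)  # b c t h w, made-up shapes
frames_b = torch.randn(2, 4, 8, 16, 16)

# interleave the two clips along the time axis (t becomes 16, a/b alternate)
mixed, idx_a, idx_b = interwine_two_tensor(frames_a, frames_b, dim=2, return_index=True)

# split the first 2 frames (e.g. vision-condition frames) from the rest
cond, rest, cond_idx, rest_idx = split_tensor(
    mixed, dim=2, n_first=2, method="first_in_first_out", need_return_index=True
)

# per-sample frame selection with a b x n index tensor
index = torch.tensor([[0, 3], [1, 2]])
picked = batch_index_select(mixed, index=index, dim=2)
print(mixed.shape, cond.shape, rest.shape, picked.shape)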
musev/logging.conf ADDED
@@ -0,0 +1,32 @@
+ [loggers]
+ keys=root,musev
+
+ [handlers]
+ keys=consoleHandler
+
+ [formatters]
+ keys=musevFormatter
+
+ [logger_root]
+ level=INFO
+ handlers=consoleHandler
+
+ # keep the logger level as low as possible
+ [logger_musev]
+ level=DEBUG
+ handlers=consoleHandler
+ qualname=musev
+ propagate=0
+
+ # set the handler level higher than the logger level
+ [handler_consoleHandler]
+ class=StreamHandler
+ level=DEBUG
+ # level=INFO
+
+ formatter=musevFormatter
+ args=(sys.stdout,)
+
+ [formatter_musevFormatter]
+ format=%(asctime)s- %(name)s:%(lineno)d- %(levelname)s- %(message)s
+ datefmt=
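A rough sketch of how a fileConfig-style config like this is typically loaded (the path handling below is an assumption for illustration; the repo presumably wires this up when the musev package is imported):

import logging
import logging.config
import os

# hypothetical path resolution; point this at wherever musev/logging.conf lives in your checkout
conf_path = os.path.join(os.path.dirname(__file__), "musev", "logging.conf")
logging.config.fileConfig(conf_path, disable_existing_loggers=False)

logger = logging.getLogger("musev")  # qualname defined in [logger_musev]
logger.debug("musev logger configured: DEBUG records go to stdout via consoleHandler")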
musev/models/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from ..utils.register import Register
+
+ Model_Register = Register(registry_name="torch_model")
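Model_Register is used as a class decorator by the attention processors later in this commit; a small sketch of that registration pattern, assuming only the decorator behaviour visible in this diff (the lookup API of Register itself is not shown here):

import torch.nn as nn
from musev.models import Model_Register

@Model_Register.register  # same decorator style as BaseIPAttnProcessor below
class MyToyProcessor(nn.Module):
    """Hypothetical processor registered under the "torch_model" registry."""

    def __init__(self):
        super().__init__()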
musev/models/attention.py ADDED
@@ -0,0 +1,431 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Adapted from https://github.com/huggingface/diffusers/blob/64bf5d33b7ef1b1deac256bed7bd99b55020c4e0/src/diffusers/models/attention.py
16
+ from __future__ import annotations
17
+ from copy import deepcopy
18
+
19
+ from typing import Any, Dict, List, Literal, Optional, Callable, Tuple
20
+ import logging
21
+ from einops import rearrange
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from torch import nn
26
+
27
+ from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
28
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
29
+ from diffusers.models.attention_processor import Attention as DiffusersAttention
30
+ from diffusers.models.attention import (
31
+ BasicTransformerBlock as DiffusersBasicTransformerBlock,
32
+ AdaLayerNormZero,
33
+ AdaLayerNorm,
34
+ FeedForward,
35
+ )
36
+ from diffusers.models.attention_processor import AttnProcessor
37
+
38
+ from .attention_processor import IPAttention, BaseIPAttnProcessor
39
+
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ def not_use_xformers_anyway(
45
+ use_memory_efficient_attention_xformers: bool,
46
+ attention_op: Optional[Callable] = None,
47
+ ):
48
+ return None
49
+
50
+
51
+ @maybe_allow_in_graph
52
+ class BasicTransformerBlock(DiffusersBasicTransformerBlock):
53
+ print_idx = 0
54
+
55
+ def __init__(
56
+ self,
57
+ dim: int,
58
+ num_attention_heads: int,
59
+ attention_head_dim: int,
60
+ dropout=0,
61
+ cross_attention_dim: int | None = None,
62
+ activation_fn: str = "geglu",
63
+ num_embeds_ada_norm: int | None = None,
64
+ attention_bias: bool = False,
65
+ only_cross_attention: bool = False,
66
+ double_self_attention: bool = False,
67
+ upcast_attention: bool = False,
68
+ norm_elementwise_affine: bool = True,
69
+ norm_type: str = "layer_norm",
70
+ final_dropout: bool = False,
71
+ attention_type: str = "default",
72
+ allow_xformers: bool = True,
73
+ cross_attn_temporal_cond: bool = False,
74
+ image_scale: float = 1.0,
75
+ processor: AttnProcessor | None = None,
76
+ ip_adapter_cross_attn: bool = False,
77
+ need_t2i_facein: bool = False,
78
+ need_t2i_ip_adapter_face: bool = False,
79
+ ):
80
+ if not only_cross_attention and double_self_attention:
81
+ cross_attention_dim = None
82
+ super().__init__(
83
+ dim,
84
+ num_attention_heads,
85
+ attention_head_dim,
86
+ dropout,
87
+ cross_attention_dim,
88
+ activation_fn,
89
+ num_embeds_ada_norm,
90
+ attention_bias,
91
+ only_cross_attention,
92
+ double_self_attention,
93
+ upcast_attention,
94
+ norm_elementwise_affine,
95
+ norm_type,
96
+ final_dropout,
97
+ attention_type,
98
+ )
99
+
100
+ self.attn1 = IPAttention(
101
+ query_dim=dim,
102
+ heads=num_attention_heads,
103
+ dim_head=attention_head_dim,
104
+ dropout=dropout,
105
+ bias=attention_bias,
106
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
107
+ upcast_attention=upcast_attention,
108
+ cross_attn_temporal_cond=cross_attn_temporal_cond,
109
+ image_scale=image_scale,
110
+ ip_adapter_dim=cross_attention_dim
111
+ if only_cross_attention
112
+ else attention_head_dim,
113
+ facein_dim=cross_attention_dim
114
+ if only_cross_attention
115
+ else attention_head_dim,
116
+ processor=processor,
117
+ )
118
+ # 2. Cross-Attn
119
+ if cross_attention_dim is not None or double_self_attention:
120
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
121
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
122
+ # the second cross attention block.
123
+ self.norm2 = (
124
+ AdaLayerNorm(dim, num_embeds_ada_norm)
125
+ if self.use_ada_layer_norm
126
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
127
+ )
128
+
129
+ self.attn2 = IPAttention(
130
+ query_dim=dim,
131
+ cross_attention_dim=cross_attention_dim
132
+ if not double_self_attention
133
+ else None,
134
+ heads=num_attention_heads,
135
+ dim_head=attention_head_dim,
136
+ dropout=dropout,
137
+ bias=attention_bias,
138
+ upcast_attention=upcast_attention,
139
+ cross_attn_temporal_cond=ip_adapter_cross_attn,
140
+ need_t2i_facein=need_t2i_facein,
141
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
142
+ image_scale=image_scale,
143
+ ip_adapter_dim=cross_attention_dim
144
+ if not double_self_attention
145
+ else attention_head_dim,
146
+ facein_dim=cross_attention_dim
147
+ if not double_self_attention
148
+ else attention_head_dim,
149
+ ip_adapter_face_dim=cross_attention_dim
150
+ if not double_self_attention
151
+ else attention_head_dim,
152
+ processor=processor,
153
+ ) # is self-attn if encoder_hidden_states is none
154
+ else:
155
+ self.norm2 = None
156
+ self.attn2 = None
157
+ if self.attn1 is not None:
158
+ if not allow_xformers:
159
+ self.attn1.set_use_memory_efficient_attention_xformers = (
160
+ not_use_xformers_anyway
161
+ )
162
+ if self.attn2 is not None:
163
+ if not allow_xformers:
164
+ self.attn2.set_use_memory_efficient_attention_xformers = (
165
+ not_use_xformers_anyway
166
+ )
167
+ self.double_self_attention = double_self_attention
168
+ self.only_cross_attention = only_cross_attention
169
+ self.cross_attn_temporal_cond = cross_attn_temporal_cond
170
+ self.image_scale = image_scale
171
+
172
+ def forward(
173
+ self,
174
+ hidden_states: torch.FloatTensor,
175
+ attention_mask: Optional[torch.FloatTensor] = None,
176
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
177
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
178
+ timestep: Optional[torch.LongTensor] = None,
179
+ cross_attention_kwargs: Dict[str, Any] = None,
180
+ class_labels: Optional[torch.LongTensor] = None,
181
+ self_attn_block_embs: Optional[Tuple[List[torch.Tensor], List[None]]] = None,
182
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
183
+ ) -> torch.FloatTensor:
184
+ # Notice that normalization is always applied before the real computation in the following blocks.
185
+ # 0. Self-Attention
186
+ if self.use_ada_layer_norm:
187
+ norm_hidden_states = self.norm1(hidden_states, timestep)
188
+ elif self.use_ada_layer_norm_zero:
189
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
190
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
191
+ )
192
+ else:
193
+ norm_hidden_states = self.norm1(hidden_states)
194
+
195
+ # 1. Retrieve lora scale.
196
+ lora_scale = (
197
+ cross_attention_kwargs.get("scale", 1.0)
198
+ if cross_attention_kwargs is not None
199
+ else 1.0
200
+ )
201
+
202
+ if cross_attention_kwargs is None:
203
+ cross_attention_kwargs = {}
204
+ # 特殊AttnProcessor需要的入参 在 cross_attention_kwargs 准备
205
+ # special AttnProcessor needs input parameters in cross_attention_kwargs
206
+ original_cross_attention_kwargs = {
207
+ k: v
208
+ for k, v in cross_attention_kwargs.items()
209
+ if k
210
+ not in [
211
+ "num_frames",
212
+ "sample_index",
213
+ "vision_conditon_frames_sample_index",
214
+ "vision_cond",
215
+ "vision_clip_emb",
216
+ "ip_adapter_scale",
217
+ "face_emb",
218
+ "facein_scale",
219
+ "ip_adapter_face_emb",
220
+ "ip_adapter_face_scale",
221
+ "do_classifier_free_guidance",
222
+ ]
223
+ }
224
+
225
+ if "do_classifier_free_guidance" in cross_attention_kwargs:
226
+ do_classifier_free_guidance = cross_attention_kwargs[
227
+ "do_classifier_free_guidance"
228
+ ]
229
+ else:
230
+ do_classifier_free_guidance = False
231
+
232
+ # 2. Prepare GLIGEN inputs
233
+ original_cross_attention_kwargs = (
234
+ original_cross_attention_kwargs.copy()
235
+ if original_cross_attention_kwargs is not None
236
+ else {}
237
+ )
238
+ gligen_kwargs = original_cross_attention_kwargs.pop("gligen", None)
239
+
240
+ # 返回self_attn的结果,适用于referencenet的输出给其他Unet来使用
241
+ # return the result of self_attn, which is suitable for the output of referencenet to be used by other Unet
242
+ if (
243
+ self_attn_block_embs is not None
244
+ and self_attn_block_embs_mode.lower() == "write"
245
+ ):
246
+ # self_attn_block_emb = self.attn1.head_to_batch_dim(attn_output, out_dim=4)
247
+ self_attn_block_emb = norm_hidden_states
248
+ if not hasattr(self, "spatial_self_attn_idx"):
249
+ raise ValueError(
250
+ "must call unet.insert_spatial_self_attn_idx to generate spatial attn index"
251
+ )
252
+ basick_transformer_idx = self.spatial_self_attn_idx
253
+ if self.print_idx == 0:
254
+ logger.debug(
255
+ f"self_attn_block_embs, self_attn_block_embs_mode={self_attn_block_embs_mode}, "
256
+ f"basick_transformer_idx={basick_transformer_idx}, length={len(self_attn_block_embs)}, shape={self_attn_block_emb.shape}, "
257
+ # f"attn1 processor, {type(self.attn1.processor)}"
258
+ )
259
+ self_attn_block_embs[basick_transformer_idx] = self_attn_block_emb
260
+
261
+ # read and put referencenet emb into cross_attention_kwargs, which would be fused into attn_processor
262
+ if (
263
+ self_attn_block_embs is not None
264
+ and self_attn_block_embs_mode.lower() == "read"
265
+ ):
266
+ if not hasattr(self, "spatial_self_attn_idx"):
+ raise ValueError(
+ "must call unet.insert_spatial_self_attn_idx to generate spatial attn index"
+ )
+ basick_transformer_idx = self.spatial_self_attn_idx
271
+ if self.print_idx == 0:
272
+ logger.debug(
273
+ f"refer_self_attn_emb: , self_attn_block_embs_mode={self_attn_block_embs_mode}, "
274
+ f"length={len(self_attn_block_embs)}, idx={basick_transformer_idx}, "
275
+ # f"attn1 processor, {type(self.attn1.processor)}, "
276
+ )
277
+ ref_emb = self_attn_block_embs[basick_transformer_idx]
278
+ cross_attention_kwargs["refer_emb"] = ref_emb
279
+ if self.print_idx == 0:
280
+ logger.debug(
281
+ f"unet attention read, {self.spatial_self_attn_idx}",
282
+ )
283
+ # ------------------------------warning-----------------------
284
+ # 这两行由于使用了ref_emb会导致和checkpoint_train相关的训练错误,具体未知,留在这里作为警示
285
+ # the commented-out code below will cause a training error, keep it here as a warning
286
+ # logger.debug(f"ref_emb shape,{ref_emb.shape}, {ref_emb.mean()}")
287
+ # logger.debug(
288
+ # f"norm_hidden_states shape, {norm_hidden_states.shape}, {norm_hidden_states.mean()}",
289
+ # )
290
+ if self.attn1 is None:
291
+ self.print_idx += 1
292
+ return norm_hidden_states
293
+ attn_output = self.attn1(
294
+ norm_hidden_states,
295
+ encoder_hidden_states=encoder_hidden_states
296
+ if self.only_cross_attention
297
+ else None,
298
+ attention_mask=attention_mask,
299
+ **(
300
+ cross_attention_kwargs
301
+ if isinstance(self.attn1.processor, BaseIPAttnProcessor)
302
+ else original_cross_attention_kwargs
303
+ ),
304
+ )
305
+
306
+ if self.use_ada_layer_norm_zero:
307
+ attn_output = gate_msa.unsqueeze(1) * attn_output
308
+ hidden_states = attn_output + hidden_states
309
+
310
+ # 推断的时候,对于uncondition_部分独立生成,排除掉 refer_emb,
311
+ # 首帧等的影响,避免生成参考了refer_emb、首帧等,又在uncond上去除了
312
+ # at inference time, the unconditional part is regenerated independently so that refer_emb,
+ # the first frame, etc. do not leak into it and then have to be removed again in the pipeline
314
+ # refer to moore-animate anyone
315
+
316
+ # do_classifier_free_guidance = False
317
+ if self.print_idx == 0:
318
+ logger.debug(f"do_classifier_free_guidance={do_classifier_free_guidance},")
319
+ if do_classifier_free_guidance:
320
+ hidden_states_c = attn_output.clone()
321
+ _uc_mask = (
322
+ torch.Tensor(
323
+ [1] * (norm_hidden_states.shape[0] // 2)
324
+ + [0] * (norm_hidden_states.shape[0] // 2)
325
+ )
326
+ .to(norm_hidden_states.device)
327
+ .bool()
328
+ )
329
+ hidden_states_c[_uc_mask] = self.attn1(
330
+ norm_hidden_states[_uc_mask],
331
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
332
+ attention_mask=attention_mask,
333
+ )
334
+ attn_output = hidden_states_c.clone()
335
+
336
+ if "refer_emb" in cross_attention_kwargs:
337
+ del cross_attention_kwargs["refer_emb"]
338
+
339
+ # 2.5 GLIGEN Control
340
+ if gligen_kwargs is not None:
341
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
342
+ # 2.5 ends
343
+
344
+ # 3. Cross-Attention
345
+ if self.attn2 is not None:
346
+ norm_hidden_states = (
347
+ self.norm2(hidden_states, timestep)
348
+ if self.use_ada_layer_norm
349
+ else self.norm2(hidden_states)
350
+ )
351
+
352
+ # 特殊AttnProcessor需要的入参 在 cross_attention_kwargs 准备
353
+ # special AttnProcessor needs input parameters in cross_attention_kwargs
354
+ attn_output = self.attn2(
355
+ norm_hidden_states,
356
+ encoder_hidden_states=encoder_hidden_states
357
+ if not self.double_self_attention
358
+ else None,
359
+ attention_mask=encoder_attention_mask,
360
+ **(
361
+ original_cross_attention_kwargs
362
+ if not isinstance(self.attn2.processor, BaseIPAttnProcessor)
363
+ else cross_attention_kwargs
364
+ ),
365
+ )
366
+ if self.print_idx == 0:
367
+ logger.debug(
368
+ f"encoder_hidden_states, type={type(encoder_hidden_states)}"
369
+ )
370
+ if encoder_hidden_states is not None:
371
+ logger.debug(
372
+ f"encoder_hidden_states, ={encoder_hidden_states.shape}"
373
+ )
374
+
375
+ # encoder_hidden_states_tmp = (
376
+ # encoder_hidden_states
377
+ # if not self.double_self_attention
378
+ # else norm_hidden_states
379
+ # )
380
+ # if do_classifier_free_guidance:
381
+ # hidden_states_c = attn_output.clone()
382
+ # _uc_mask = (
383
+ # torch.Tensor(
384
+ # [1] * (norm_hidden_states.shape[0] // 2)
385
+ # + [0] * (norm_hidden_states.shape[0] // 2)
386
+ # )
387
+ # .to(norm_hidden_states.device)
388
+ # .bool()
389
+ # )
390
+ # hidden_states_c[_uc_mask] = self.attn2(
391
+ # norm_hidden_states[_uc_mask],
392
+ # encoder_hidden_states=encoder_hidden_states_tmp[_uc_mask],
393
+ # attention_mask=attention_mask,
394
+ # )
395
+ # attn_output = hidden_states_c.clone()
396
+ hidden_states = attn_output + hidden_states
397
+ # 4. Feed-forward
398
+ if self.norm3 is not None and self.ff is not None:
399
+ norm_hidden_states = self.norm3(hidden_states)
400
+ if self.use_ada_layer_norm_zero:
401
+ norm_hidden_states = (
402
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
403
+ )
404
+ if self._chunk_size is not None:
405
+ # "feed_forward_chunk_size" can be used to save memory
406
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
407
+ raise ValueError(
408
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
409
+ )
410
+
411
+ num_chunks = (
412
+ norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
413
+ )
414
+ ff_output = torch.cat(
415
+ [
416
+ self.ff(hid_slice, scale=lora_scale)
417
+ for hid_slice in norm_hidden_states.chunk(
418
+ num_chunks, dim=self._chunk_dim
419
+ )
420
+ ],
421
+ dim=self._chunk_dim,
422
+ )
423
+ else:
424
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
425
+
426
+ if self.use_ada_layer_norm_zero:
427
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
428
+
429
+ hidden_states = ff_output + hidden_states
430
+ self.print_idx += 1
431
+ return hidden_states
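The classifier-free-guidance branch in BasicTransformerBlock.forward assumes the first half of the batch is the unconditional part; the snippet below is a self-contained illustration of that masking pattern with made-up shapes, not MuseV code:

import torch

hidden = torch.randn(4, 16, 8)  # (2*b, seq, dim): first half uncond, second half cond
uc_mask = (
    torch.Tensor([1] * (hidden.shape[0] // 2) + [0] * (hidden.shape[0] // 2))
    .to(hidden.device)
    .bool()
)

# the unconditional rows can be recomputed in isolation, e.g. re-running self-attention
# without refer_emb / vision-condition frames for those rows only
out = hidden.clone()
out[uc_mask] = hidden[uc_mask] * 0.0  # placeholder for the recomputed values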
musev/models/attention_processor.py ADDED
@@ -0,0 +1,750 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """该模型是自定义的attn_processor,实现特殊功能的 Attn功能。
16
+ 相对而言,开源代码经常会重新定义Attention 类,
17
+
18
+ This module implements special AttnProcessor function with custom attn_processor class.
19
+ By contrast, other open-source code usually redefines the Attention class itself.
20
+ """
21
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
22
+ from __future__ import annotations
23
+
24
+ import time
25
+ from typing import Any, Callable, Optional
26
+ import logging
27
+
28
+ from einops import rearrange, repeat
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ import xformers
33
+ from diffusers.models.lora import LoRACompatibleLinear
34
+
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from diffusers.models.attention_processor import (
37
+ Attention as DiffusersAttention,
38
+ AttnProcessor,
39
+ AttnProcessor2_0,
40
+ )
41
+ from ..data.data_util import (
42
+ batch_concat_two_tensor_with_index,
43
+ batch_index_select,
44
+ align_repeat_tensor_single_dim,
45
+ batch_adain_conditioned_tensor,
46
+ )
47
+
48
+ from . import Model_Register
49
+
50
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
51
+
52
+
53
+ @maybe_allow_in_graph
54
+ class IPAttention(DiffusersAttention):
55
+ r"""
56
+ Modified Attention class with extra IP-Adapter layers such as to_k_ip and to_v_ip.
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ query_dim: int,
62
+ cross_attention_dim: int | None = None,
63
+ heads: int = 8,
64
+ dim_head: int = 64,
65
+ dropout: float = 0,
66
+ bias=False,
67
+ upcast_attention: bool = False,
68
+ upcast_softmax: bool = False,
69
+ cross_attention_norm: str | None = None,
70
+ cross_attention_norm_num_groups: int = 32,
71
+ added_kv_proj_dim: int | None = None,
72
+ norm_num_groups: int | None = None,
73
+ spatial_norm_dim: int | None = None,
74
+ out_bias: bool = True,
75
+ scale_qk: bool = True,
76
+ only_cross_attention: bool = False,
77
+ eps: float = 0.00001,
78
+ rescale_output_factor: float = 1,
79
+ residual_connection: bool = False,
80
+ _from_deprecated_attn_block=False,
81
+ processor: AttnProcessor | None = None,
82
+ cross_attn_temporal_cond: bool = False,
83
+ image_scale: float = 1.0,
84
+ ip_adapter_dim: int = None,
85
+ need_t2i_facein: bool = False,
86
+ facein_dim: int = None,
87
+ need_t2i_ip_adapter_face: bool = False,
88
+ ip_adapter_face_dim: int = None,
89
+ ):
90
+ super().__init__(
91
+ query_dim,
92
+ cross_attention_dim,
93
+ heads,
94
+ dim_head,
95
+ dropout,
96
+ bias,
97
+ upcast_attention,
98
+ upcast_softmax,
99
+ cross_attention_norm,
100
+ cross_attention_norm_num_groups,
101
+ added_kv_proj_dim,
102
+ norm_num_groups,
103
+ spatial_norm_dim,
104
+ out_bias,
105
+ scale_qk,
106
+ only_cross_attention,
107
+ eps,
108
+ rescale_output_factor,
109
+ residual_connection,
110
+ _from_deprecated_attn_block,
111
+ processor,
112
+ )
113
+ self.cross_attn_temporal_cond = cross_attn_temporal_cond
114
+ self.image_scale = image_scale
115
+ # ip_adapter conditioned on the first (reference) frame
117
+ if cross_attn_temporal_cond:
118
+ self.to_k_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False)
119
+ self.to_v_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False)
120
+ # facein
121
+ self.need_t2i_facein = need_t2i_facein
122
+ self.facein_dim = facein_dim
123
+ if need_t2i_facein:
124
+ raise NotImplementedError("facein")
125
+
126
+ # ip_adapter_face
127
+ self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face
128
+ self.ip_adapter_face_dim = ip_adapter_face_dim
129
+ if need_t2i_ip_adapter_face:
130
+ self.ip_adapter_face_to_k_ip = LoRACompatibleLinear(
131
+ ip_adapter_face_dim, query_dim, bias=False
132
+ )
133
+ self.ip_adapter_face_to_v_ip = LoRACompatibleLinear(
134
+ ip_adapter_face_dim, query_dim, bias=False
135
+ )
136
+
137
+ def set_use_memory_efficient_attention_xformers(
138
+ self,
139
+ use_memory_efficient_attention_xformers: bool,
140
+ attention_op: Callable[..., Any] | None = None,
141
+ ):
142
+ if (
143
+ "XFormers" in self.processor.__class__.__name__
144
+ or "IP" in self.processor.__class__.__name__
145
+ ):
146
+ pass
147
+ else:
148
+ return super().set_use_memory_efficient_attention_xformers(
149
+ use_memory_efficient_attention_xformers, attention_op
150
+ )
151
+
152
+
153
+ @Model_Register.register
154
+ class BaseIPAttnProcessor(nn.Module):
155
+ print_idx = 0
156
+
157
+ def __init__(self, *args, **kwargs) -> None:
158
+ super().__init__(*args, **kwargs)
159
+
160
+
161
+ @Model_Register.register
162
+ class T2IReferencenetIPAdapterXFormersAttnProcessor(BaseIPAttnProcessor):
163
+ r"""
164
+ IPAdapter for the ref_image-facing self_attn
165
+ """
166
+ print_idx = 0
167
+
168
+ def __init__(
169
+ self,
170
+ attention_op: Optional[Callable] = None,
171
+ ):
172
+ super().__init__()
173
+
174
+ self.attention_op = attention_op
175
+
176
+ def __call__(
177
+ self,
178
+ attn: IPAttention,
179
+ hidden_states: torch.FloatTensor,
180
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
181
+ attention_mask: Optional[torch.FloatTensor] = None,
182
+ temb: Optional[torch.FloatTensor] = None,
183
+ scale: float = 1.0,
184
+ num_frames: int = None,
185
+ sample_index: torch.LongTensor = None,
186
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
187
+ refer_emb: torch.Tensor = None,
188
+ vision_clip_emb: torch.Tensor = None,
189
+ ip_adapter_scale: float = 1.0,
190
+ face_emb: torch.Tensor = None,
191
+ facein_scale: float = 1.0,
192
+ ip_adapter_face_emb: torch.Tensor = None,
193
+ ip_adapter_face_scale: float = 1.0,
194
+ do_classifier_free_guidance: bool = False,
195
+ ):
196
+ residual = hidden_states
197
+
198
+ if attn.spatial_norm is not None:
199
+ hidden_states = attn.spatial_norm(hidden_states, temb)
200
+
201
+ input_ndim = hidden_states.ndim
202
+
203
+ if input_ndim == 4:
204
+ batch_size, channel, height, width = hidden_states.shape
205
+ hidden_states = hidden_states.view(
206
+ batch_size, channel, height * width
207
+ ).transpose(1, 2)
208
+
209
+ batch_size, key_tokens, _ = (
210
+ hidden_states.shape
211
+ if encoder_hidden_states is None
212
+ else encoder_hidden_states.shape
213
+ )
214
+
215
+ attention_mask = attn.prepare_attention_mask(
216
+ attention_mask, key_tokens, batch_size
217
+ )
218
+ if attention_mask is not None:
219
+ # expand our mask's singleton query_tokens dimension:
220
+ # [batch*heads, 1, key_tokens] ->
221
+ # [batch*heads, query_tokens, key_tokens]
222
+ # so that it can be added as a bias onto the attention scores that xformers computes:
223
+ # [batch*heads, query_tokens, key_tokens]
224
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
225
+ _, query_tokens, _ = hidden_states.shape
226
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
227
+
228
+ if attn.group_norm is not None:
229
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
230
+ 1, 2
231
+ )
232
+
233
+ query = attn.to_q(hidden_states, scale=scale)
234
+
235
+ if encoder_hidden_states is None:
236
+ encoder_hidden_states = hidden_states
237
+ elif attn.norm_cross:
238
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
239
+ encoder_hidden_states
240
+ )
241
+ encoder_hidden_states = align_repeat_tensor_single_dim(
242
+ encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
243
+ )
244
+ key = attn.to_k(encoder_hidden_states, scale=scale)
245
+ value = attn.to_v(encoder_hidden_states, scale=scale)
246
+
247
+ # for facein
248
+ if self.print_idx == 0:
249
+ logger.debug(
250
+ f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(face_emb)={type(face_emb)}, facein_scale={facein_scale}"
251
+ )
252
+ if facein_scale > 0 and face_emb is not None:
253
+ raise NotImplementedError("facein")
254
+
255
+ query = attn.head_to_batch_dim(query).contiguous()
256
+ key = attn.head_to_batch_dim(key).contiguous()
257
+ value = attn.head_to_batch_dim(value).contiguous()
258
+ hidden_states = xformers.ops.memory_efficient_attention(
259
+ query,
260
+ key,
261
+ value,
262
+ attn_bias=attention_mask,
263
+ op=self.attention_op,
264
+ scale=attn.scale,
265
+ )
266
+
267
+ # ip-adapter start
268
+ if self.print_idx == 0:
269
+ logger.debug(
270
+ f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(vision_clip_emb)={type(vision_clip_emb)}"
271
+ )
272
+ if ip_adapter_scale > 0 and vision_clip_emb is not None:
273
+ if self.print_idx == 0:
274
+ logger.debug(
275
+ f"T2I cross_attn, ipadapter, vision_clip_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}"
276
+ )
277
+ ip_key = attn.to_k_ip(vision_clip_emb)
278
+ ip_value = attn.to_v_ip(vision_clip_emb)
279
+ ip_key = align_repeat_tensor_single_dim(
280
+ ip_key, target_length=batch_size, dim=0
281
+ )
282
+ ip_value = align_repeat_tensor_single_dim(
283
+ ip_value, target_length=batch_size, dim=0
284
+ )
285
+ ip_key = attn.head_to_batch_dim(ip_key).contiguous()
286
+ ip_value = attn.head_to_batch_dim(ip_value).contiguous()
287
+ if self.print_idx == 0:
288
+ logger.debug(
289
+ f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}"
290
+ )
291
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
292
+ hidden_states_from_ip = xformers.ops.memory_efficient_attention(
293
+ query,
294
+ ip_key,
295
+ ip_value,
296
+ attn_bias=attention_mask,
297
+ op=self.attention_op,
298
+ scale=attn.scale,
299
+ )
300
+ hidden_states = hidden_states + ip_adapter_scale * hidden_states_from_ip
301
+ # ip-adapter end
302
+
303
+ # ip-adapter face start
304
+ if self.print_idx == 0:
305
+ logger.debug(
306
+ f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(ip_adapter_face_emb)={type(ip_adapter_face_emb)}"
307
+ )
308
+ if ip_adapter_face_scale > 0 and ip_adapter_face_emb is not None:
309
+ if self.print_idx == 0:
310
+ logger.debug(
311
+ f"T2I cross_attn, ipadapter face, ip_adapter_face_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}"
312
+ )
313
+ ip_key = attn.ip_adapter_face_to_k_ip(ip_adapter_face_emb)
314
+ ip_value = attn.ip_adapter_face_to_v_ip(ip_adapter_face_emb)
315
+ ip_key = align_repeat_tensor_single_dim(
316
+ ip_key, target_length=batch_size, dim=0
317
+ )
318
+ ip_value = align_repeat_tensor_single_dim(
319
+ ip_value, target_length=batch_size, dim=0
320
+ )
321
+ ip_key = attn.head_to_batch_dim(ip_key).contiguous()
322
+ ip_value = attn.head_to_batch_dim(ip_value).contiguous()
323
+ if self.print_idx == 0:
324
+ logger.debug(
325
+ f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}"
326
+ )
327
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
328
+ hidden_states_from_ip = xformers.ops.memory_efficient_attention(
329
+ query,
330
+ ip_key,
331
+ ip_value,
332
+ attn_bias=attention_mask,
333
+ op=self.attention_op,
334
+ scale=attn.scale,
335
+ )
336
+ hidden_states = (
337
+ hidden_states + ip_adapter_face_scale * hidden_states_from_ip
338
+ )
339
+ # ip-adapter face end
340
+
341
+ hidden_states = hidden_states.to(query.dtype)
342
+ hidden_states = attn.batch_to_head_dim(hidden_states)
343
+
344
+ # linear proj
345
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
346
+ # dropout
347
+ hidden_states = attn.to_out[1](hidden_states)
348
+
349
+ if input_ndim == 4:
350
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
351
+ batch_size, channel, height, width
352
+ )
353
+
354
+ if attn.residual_connection:
355
+ hidden_states = hidden_states + residual
356
+
357
+ hidden_states = hidden_states / attn.rescale_output_factor
358
+ self.print_idx += 1
359
+ return hidden_states
360
+
361
+
362
+ @Model_Register.register
363
+ class NonParamT2ISelfReferenceXFormersAttnProcessor(BaseIPAttnProcessor):
364
+ r"""
365
+ 面向首帧的 referenceonly attn,适用于 T2I的 self_attn
366
+ referenceonly with vis_cond as key, value, in t2i self_attn.
367
+ """
368
+ print_idx = 0
369
+
370
+ def __init__(
371
+ self,
372
+ attention_op: Optional[Callable] = None,
373
+ ):
374
+ super().__init__()
375
+
376
+ self.attention_op = attention_op
377
+
378
+ def __call__(
379
+ self,
380
+ attn: IPAttention,
381
+ hidden_states: torch.FloatTensor,
382
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
383
+ attention_mask: Optional[torch.FloatTensor] = None,
384
+ temb: Optional[torch.FloatTensor] = None,
385
+ scale: float = 1.0,
386
+ num_frames: int = None,
387
+ sample_index: torch.LongTensor = None,
388
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
389
+ refer_emb: torch.Tensor = None,
390
+ face_emb: torch.Tensor = None,
391
+ vision_clip_emb: torch.Tensor = None,
392
+ ip_adapter_scale: float = 1.0,
393
+ facein_scale: float = 1.0,
394
+ ip_adapter_face_emb: torch.Tensor = None,
395
+ ip_adapter_face_scale: float = 1.0,
396
+ do_classifier_free_guidance: bool = False,
397
+ ):
398
+ residual = hidden_states
399
+
400
+ if attn.spatial_norm is not None:
401
+ hidden_states = attn.spatial_norm(hidden_states, temb)
402
+
403
+ input_ndim = hidden_states.ndim
404
+
405
+ if input_ndim == 4:
406
+ batch_size, channel, height, width = hidden_states.shape
407
+ hidden_states = hidden_states.view(
408
+ batch_size, channel, height * width
409
+ ).transpose(1, 2)
410
+
411
+ batch_size, key_tokens, _ = (
412
+ hidden_states.shape
413
+ if encoder_hidden_states is None
414
+ else encoder_hidden_states.shape
415
+ )
416
+
417
+ attention_mask = attn.prepare_attention_mask(
418
+ attention_mask, key_tokens, batch_size
419
+ )
420
+ if attention_mask is not None:
421
+ # expand our mask's singleton query_tokens dimension:
422
+ # [batch*heads, 1, key_tokens] ->
423
+ # [batch*heads, query_tokens, key_tokens]
424
+ # so that it can be added as a bias onto the attention scores that xformers computes:
425
+ # [batch*heads, query_tokens, key_tokens]
426
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
427
+ _, query_tokens, _ = hidden_states.shape
428
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
429
+
430
+ # vision_cond in same unet attn start
431
+ if (
432
+ vision_conditon_frames_sample_index is not None and num_frames > 1
433
+ ) or refer_emb is not None:
434
+ batchsize_timesize = hidden_states.shape[0]
435
+ if self.print_idx == 0:
436
+ logger.debug(
437
+ f"NonParamT2ISelfReferenceXFormersAttnProcessor 0, hidden_states={hidden_states.shape}, vision_conditon_frames_sample_index={vision_conditon_frames_sample_index}"
438
+ )
439
+ encoder_hidden_states = rearrange(
440
+ hidden_states, "(b t) hw c -> b t hw c", t=num_frames
441
+ )
442
+ # if False:
443
+ if vision_conditon_frames_sample_index is not None and num_frames > 1:
444
+ ip_hidden_states = batch_index_select(
445
+ encoder_hidden_states,
446
+ dim=1,
447
+ index=vision_conditon_frames_sample_index,
448
+ ).contiguous()
449
+ if self.print_idx == 0:
450
+ logger.debug(
451
+ f"NonParamT2ISelfReferenceXFormersAttnProcessor 1, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
452
+ )
453
+ #
454
+ ip_hidden_states = rearrange(
455
+ ip_hidden_states, "b t hw c -> b 1 (t hw) c"
456
+ )
457
+ ip_hidden_states = align_repeat_tensor_single_dim(
458
+ ip_hidden_states,
459
+ dim=1,
460
+ target_length=num_frames,
461
+ )
462
+ # b t hw c -> b t hw + hw c
463
+ if self.print_idx == 0:
464
+ logger.debug(
465
+ f"NonParamT2ISelfReferenceXFormersAttnProcessor 2, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
466
+ )
467
+ encoder_hidden_states = torch.concat(
468
+ [encoder_hidden_states, ip_hidden_states], dim=2
469
+ )
470
+ if self.print_idx == 0:
471
+ logger.debug(
472
+ f"NonParamT2ISelfReferenceXFormersAttnProcessor 3, hidden_states={hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
473
+ )
474
+ # if False:
475
+ if refer_emb is not None: # and num_frames > 1:
476
+ refer_emb = rearrange(refer_emb, "b c t h w->b 1 (t h w) c")
477
+ refer_emb = align_repeat_tensor_single_dim(
478
+ refer_emb, target_length=num_frames, dim=1
479
+ )
480
+ if self.print_idx == 0:
481
+ logger.debug(
482
+ f"NonParamT2ISelfReferenceXFormersAttnProcessor4, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}"
483
+ )
484
+ encoder_hidden_states = torch.concat(
485
+ [encoder_hidden_states, refer_emb], dim=2
486
+ )
487
+ if self.print_idx == 0:
488
+ logger.debug(
489
+ f"NonParamT2ISelfReferenceXFormersAttnProcessor5, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}"
490
+ )
491
+ encoder_hidden_states = rearrange(
492
+ encoder_hidden_states, "b t hw c -> (b t) hw c"
493
+ )
494
+ # vision_cond in same unet attn end
495
+
496
+ if attn.group_norm is not None:
497
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
498
+ 1, 2
499
+ )
500
+
501
+ query = attn.to_q(hidden_states, scale=scale)
502
+
503
+ if encoder_hidden_states is None:
504
+ encoder_hidden_states = hidden_states
505
+ elif attn.norm_cross:
506
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
507
+ encoder_hidden_states
508
+ )
509
+ encoder_hidden_states = align_repeat_tensor_single_dim(
510
+ encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
511
+ )
512
+ key = attn.to_k(encoder_hidden_states, scale=scale)
513
+ value = attn.to_v(encoder_hidden_states, scale=scale)
514
+
515
+ query = attn.head_to_batch_dim(query).contiguous()
516
+ key = attn.head_to_batch_dim(key).contiguous()
517
+ value = attn.head_to_batch_dim(value).contiguous()
518
+
519
+ hidden_states = xformers.ops.memory_efficient_attention(
520
+ query,
521
+ key,
522
+ value,
523
+ attn_bias=attention_mask,
524
+ op=self.attention_op,
525
+ scale=attn.scale,
526
+ )
527
+ hidden_states = hidden_states.to(query.dtype)
528
+ hidden_states = attn.batch_to_head_dim(hidden_states)
529
+
530
+ # linear proj
531
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
532
+ # dropout
533
+ hidden_states = attn.to_out[1](hidden_states)
534
+
535
+ if input_ndim == 4:
536
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
537
+ batch_size, channel, height, width
538
+ )
539
+
540
+ if attn.residual_connection:
541
+ hidden_states = hidden_states + residual
542
+
543
+ hidden_states = hidden_states / attn.rescale_output_factor
544
+ self.print_idx += 1
545
+
546
+ return hidden_states
547
+
548
+
549
+ @Model_Register.register
550
+ class NonParamReferenceIPXFormersAttnProcessor(
551
+ NonParamT2ISelfReferenceXFormersAttnProcessor
552
+ ):
553
+ def __init__(self, attention_op: Callable[..., Any] | None = None):
554
+ super().__init__(attention_op)
555
+
556
+
557
+ @maybe_allow_in_graph
558
+ class ReferEmbFuseAttention(IPAttention):
559
+ """使用 attention 融合 refernet 中的 emb 到 unet 对应的 latens 中
560
+ # TODO: 目前只支持 bt hw c 的融合,后续考虑增加对 视频 bhw t c、b thw c的融合
561
+ residual_connection: bool = True, 默认, 从不产生影响开始学习
562
+
563
+ use attention to fuse referencenet emb into unet latents
564
+ # TODO: by now, only support bt hw c, later consider to support bhw t c, b thw c
565
+ residual_connection: bool = True, default, start from no effect
566
+
567
+ Args:
568
+ IPAttention (_type_): _description_
569
+ """
570
+
571
+ print_idx = 0
572
+
573
+ def __init__(
574
+ self,
575
+ query_dim: int,
576
+ cross_attention_dim: int | None = None,
577
+ heads: int = 8,
578
+ dim_head: int = 64,
579
+ dropout: float = 0,
580
+ bias=False,
581
+ upcast_attention: bool = False,
582
+ upcast_softmax: bool = False,
583
+ cross_attention_norm: str | None = None,
584
+ cross_attention_norm_num_groups: int = 32,
585
+ added_kv_proj_dim: int | None = None,
586
+ norm_num_groups: int | None = None,
587
+ spatial_norm_dim: int | None = None,
588
+ out_bias: bool = True,
589
+ scale_qk: bool = True,
590
+ only_cross_attention: bool = False,
591
+ eps: float = 0.00001,
592
+ rescale_output_factor: float = 1,
593
+ residual_connection: bool = True,
594
+ _from_deprecated_attn_block=False,
595
+ processor: AttnProcessor | None = None,
596
+ cross_attn_temporal_cond: bool = False,
597
+ image_scale: float = 1,
598
+ ):
599
+ super().__init__(
600
+ query_dim,
601
+ cross_attention_dim,
602
+ heads,
603
+ dim_head,
604
+ dropout,
605
+ bias,
606
+ upcast_attention,
607
+ upcast_softmax,
608
+ cross_attention_norm,
609
+ cross_attention_norm_num_groups,
610
+ added_kv_proj_dim,
611
+ norm_num_groups,
612
+ spatial_norm_dim,
613
+ out_bias,
614
+ scale_qk,
615
+ only_cross_attention,
616
+ eps,
617
+ rescale_output_factor,
618
+ residual_connection,
619
+ _from_deprecated_attn_block,
620
+ processor,
621
+ cross_attn_temporal_cond,
622
+ image_scale,
623
+ )
624
+ self.processor = None
625
+ # together with the residual connection, zero-init keeps this layer from affecting the original output at the start of training
626
+ nn.init.zeros_(self.to_out[0].weight)
627
+ nn.init.zeros_(self.to_out[0].bias)
628
+
629
+ def forward(
630
+ self,
631
+ hidden_states: torch.FloatTensor,
632
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
633
+ attention_mask: Optional[torch.FloatTensor] = None,
634
+ temb: Optional[torch.FloatTensor] = None,
635
+ scale: float = 1.0,
636
+ num_frames: int = None,
637
+ ) -> torch.Tensor:
638
+ """fuse referencenet emb b c t2 h2 w2 into unet latents b c t1 h1 w1 with attn
639
+ refer to musev/models/attention_processor.py::NonParamT2ISelfReferenceXFormersAttnProcessor
640
+
641
+ Args:
642
+ hidden_states (torch.FloatTensor): unet latents, (b t1) c h1 w1
643
+ encoder_hidden_states (Optional[torch.FloatTensor], optional): referencenet emb b c2 t2 h2 w2. Defaults to None.
644
+ attention_mask (Optional[torch.FloatTensor], optional): _description_. Defaults to None.
645
+ temb (Optional[torch.FloatTensor], optional): _description_. Defaults to None.
646
+ scale (float, optional): _description_. Defaults to 1.0.
647
+ num_frames (int, optional): _description_. Defaults to None.
648
+
649
+ Returns:
650
+ torch.Tensor: _description_
651
+ """
652
+ residual = hidden_states
653
+ # start
654
+ hidden_states = rearrange(
655
+ hidden_states, "(b t) c h w -> b c t h w", t=num_frames
656
+ )
657
+ batch_size, channel, t1, height, width = hidden_states.shape
658
+ if self.print_idx == 0:
659
+ logger.debug(
660
+ f"hidden_states={hidden_states.shape},encoder_hidden_states={encoder_hidden_states.shape}"
661
+ )
662
+ # concat with hidden_states b c t1 h1 w1 in hw channel into bt (t2 + 1)hw c
663
+ encoder_hidden_states = rearrange(
664
+ encoder_hidden_states, " b c t2 h w-> b (t2 h w) c"
665
+ )
666
+ encoder_hidden_states = repeat(
667
+ encoder_hidden_states, " b t2hw c -> (b t) t2hw c", t=t1
668
+ )
669
+ hidden_states = rearrange(hidden_states, " b c t h w-> (b t) (h w) c")
670
+ # bt (t2+1)hw d
671
+ encoder_hidden_states = torch.concat(
672
+ [encoder_hidden_states, hidden_states], dim=1
673
+ )
674
+ # encoder_hidden_states = align_repeat_tensor_single_dim(
675
+ # encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
676
+ # )
677
+ # end
678
+
679
+ if self.spatial_norm is not None:
680
+ hidden_states = self.spatial_norm(hidden_states, temb)
681
+
682
+ _, key_tokens, _ = (
683
+ hidden_states.shape
684
+ if encoder_hidden_states is None
685
+ else encoder_hidden_states.shape
686
+ )
687
+
688
+ attention_mask = self.prepare_attention_mask(
689
+ attention_mask, key_tokens, batch_size
690
+ )
691
+ if attention_mask is not None:
692
+ # expand our mask's singleton query_tokens dimension:
693
+ # [batch*heads, 1, key_tokens] ->
694
+ # [batch*heads, query_tokens, key_tokens]
695
+ # so that it can be added as a bias onto the attention scores that xformers computes:
696
+ # [batch*heads, query_tokens, key_tokens]
697
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
698
+ _, query_tokens, _ = hidden_states.shape
699
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
700
+
701
+ if self.group_norm is not None:
702
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(
703
+ 1, 2
704
+ )
705
+
706
+ query = self.to_q(hidden_states, scale=scale)
707
+
708
+ if encoder_hidden_states is None:
709
+ encoder_hidden_states = hidden_states
710
+ elif self.norm_cross:
711
+ encoder_hidden_states = self.norm_encoder_hidden_states(
712
+ encoder_hidden_states
713
+ )
714
+
715
+ key = self.to_k(encoder_hidden_states, scale=scale)
716
+ value = self.to_v(encoder_hidden_states, scale=scale)
717
+
718
+ query = self.head_to_batch_dim(query).contiguous()
719
+ key = self.head_to_batch_dim(key).contiguous()
720
+ value = self.head_to_batch_dim(value).contiguous()
721
+
722
+ # query: b t hw d
723
+ # key/value: bt (t1+1)hw d
724
+ hidden_states = xformers.ops.memory_efficient_attention(
725
+ query,
726
+ key,
727
+ value,
728
+ attn_bias=attention_mask,
729
+ scale=self.scale,
730
+ )
731
+ hidden_states = hidden_states.to(query.dtype)
732
+ hidden_states = self.batch_to_head_dim(hidden_states)
733
+
734
+ # linear proj
735
+ hidden_states = self.to_out[0](hidden_states, scale=scale)
736
+ # dropout
737
+ hidden_states = self.to_out[1](hidden_states)
738
+
739
+ hidden_states = rearrange(
740
+ hidden_states,
741
+ "bt (h w) c-> bt c h w",
742
+ h=height,
743
+ w=width,
744
+ )
745
+ if self.residual_connection:
746
+ hidden_states = hidden_states + residual
747
+
748
+ hidden_states = hidden_states / self.rescale_output_factor
749
+ self.print_idx += 1
750
+ return hidden_states
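ReferEmbFuseAttention zero-initialises its output projection so that, together with residual_connection=True, the newly added fusion layer starts as an identity mapping; a toy demonstration of that design choice (independent of the classes above):

import torch
import torch.nn as nn

proj = nn.Linear(8, 8)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)

x = torch.randn(2, 8)
out = x + proj(x)  # residual + zero-initialised projection
assert torch.allclose(out, x)  # the fused branch contributes nothing until it is trained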
musev/models/controlnet.py ADDED
@@ -0,0 +1,399 @@
1
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
2
+ import warnings
3
+ import os
4
+
5
+
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ import PIL
10
+ from einops import rearrange, repeat
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.init as init
14
+ from diffusers.models.controlnet import ControlNetModel
15
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
16
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
17
+ from diffusers.utils.torch_utils import is_compiled_module
18
+
19
+
20
+ class ControlnetPredictor(object):
21
+ def __init__(self, controlnet_model_path: str, *args, **kwargs):
22
+ """Controlnet 推断函数,用于提取 controlnet backbone的emb,避免训练时重复抽取
23
+ Controlnet inference predictor, used to extract the emb of the controlnet backbone to avoid repeated extraction during training
24
+ Args:
25
+ controlnet_model_path (str): controlnet 模型路径. controlnet model path.
26
+ """
27
+ super(ControlnetPredictor, self).__init__(*args, **kwargs)
28
+ self.controlnet = ControlNetModel.from_pretrained(
29
+ controlnet_model_path,
30
+ )
31
+
32
+ def prepare_image(
33
+ self,
34
+ image, # b c t h w
35
+ width,
36
+ height,
37
+ batch_size,
38
+ num_images_per_prompt,
39
+ device,
40
+ dtype,
41
+ do_classifier_free_guidance=False,
42
+ guess_mode=False,
43
+ ):
44
+ if height is None:
45
+ height = image.shape[-2]
46
+ if width is None:
47
+ width = image.shape[-1]
48
+ width, height = (
49
+ x - x % self.control_image_processor.vae_scale_factor
50
+ for x in (width, height)
51
+ )
52
+ image = rearrange(image, "b c t h w-> (b t) c h w")
53
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
54
+
55
+ image = torch.nn.functional.interpolate(
56
+ image,
57
+ size=(height, width),
58
+ mode="bilinear",
59
+ )
62
+
63
+ do_normalize = self.control_image_processor.config.do_normalize
64
+ if image.min() < 0:
65
+ warnings.warn(
66
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
67
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
68
+ FutureWarning,
69
+ )
70
+ do_normalize = False
71
+
72
+ if do_normalize:
73
+ image = self.control_image_processor.normalize(image)
74
+
75
+ image_batch_size = image.shape[0]
76
+
77
+ if image_batch_size == 1:
78
+ repeat_by = batch_size
79
+ else:
80
+ # image batch size is the same as prompt batch size
81
+ repeat_by = num_images_per_prompt
82
+
83
+ image = image.repeat_interleave(repeat_by, dim=0)
84
+
85
+ image = image.to(device=device, dtype=dtype)
86
+
87
+ if do_classifier_free_guidance and not guess_mode:
88
+ image = torch.cat([image] * 2)
89
+
90
+ return image
91
+
92
+ @torch.no_grad()
93
+ def __call__(
94
+ self,
95
+ batch_size: int,
96
+ device: str,
97
+ dtype: torch.dtype,
98
+ timesteps: List[float],
99
+ i: int,
100
+ scheduler: KarrasDiffusionSchedulers,
101
+ prompt_embeds: torch.Tensor,
102
+ do_classifier_free_guidance: bool = False,
103
+ # 2b co t ho wo
104
+ latent_model_input: torch.Tensor = None,
105
+ # b co t ho wo
106
+ latents: torch.Tensor = None,
107
+ # b c t h w
108
+ image: Union[
109
+ torch.FloatTensor,
110
+ PIL.Image.Image,
111
+ np.ndarray,
112
+ List[torch.FloatTensor],
113
+ List[PIL.Image.Image],
114
+ List[np.ndarray],
115
+ ] = None,
116
+ # b c t(1) hi wi
117
+ controlnet_condition_frames: Optional[torch.FloatTensor] = None,
118
+ # b c t ho wo
119
+ controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None,
120
+ # b c t(1) ho wo
121
+ controlnet_condition_latents: Optional[torch.FloatTensor] = None,
122
+ height: Optional[int] = None,
123
+ width: Optional[int] = None,
124
+ num_videos_per_prompt: Optional[int] = 1,
125
+ return_dict: bool = True,
126
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
127
+ guess_mode: bool = False,
128
+ control_guidance_start: Union[float, List[float]] = 0.0,
129
+ control_guidance_end: Union[float, List[float]] = 1.0,
130
+ latent_index: torch.LongTensor = None,
131
+ vision_condition_latent_index: torch.LongTensor = None,
132
+ **kwargs,
133
+ ):
134
+ assert (
135
+ image is not None or controlnet_latents is not None
136
+ ), "should set one of image and controlnet_latents"
137
+
138
+ controlnet = (
139
+ self.controlnet._orig_mod
140
+ if is_compiled_module(self.controlnet)
141
+ else self.controlnet
142
+ )
143
+
144
+ # align format for control guidance
145
+ if not isinstance(control_guidance_start, list) and isinstance(
146
+ control_guidance_end, list
147
+ ):
148
+ control_guidance_start = len(control_guidance_end) * [
149
+ control_guidance_start
150
+ ]
151
+ elif not isinstance(control_guidance_end, list) and isinstance(
152
+ control_guidance_start, list
153
+ ):
154
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
155
+ elif not isinstance(control_guidance_start, list) and not isinstance(
156
+ control_guidance_end, list
157
+ ):
158
+ mult = (
159
+ len(controlnet.nets)
160
+ if isinstance(controlnet, MultiControlNetModel)
161
+ else 1
162
+ )
163
+ control_guidance_start, control_guidance_end = mult * [
164
+ control_guidance_start
165
+ ], mult * [control_guidance_end]
166
+
167
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(
168
+ controlnet_conditioning_scale, float
169
+ ):
170
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
171
+ controlnet.nets
172
+ )
173
+
174
+ global_pool_conditions = (
175
+ controlnet.config.global_pool_conditions
176
+ if isinstance(controlnet, ControlNetModel)
177
+ else controlnet.nets[0].config.global_pool_conditions
178
+ )
179
+ guess_mode = guess_mode or global_pool_conditions
180
+
181
+ # 4. Prepare image
182
+ if isinstance(controlnet, ControlNetModel):
183
+ if (
184
+ controlnet_latents is not None
185
+ and controlnet_condition_latents is not None
186
+ ):
187
+ if isinstance(controlnet_latents, np.ndarray):
188
+ controlnet_latents = torch.from_numpy(controlnet_latents)
189
+ if isinstance(controlnet_condition_latents, np.ndarray):
190
+ controlnet_condition_latents = torch.from_numpy(
191
+ controlnet_condition_latents
192
+ )
193
+ # TODO: concat with index
194
+ controlnet_latents = torch.concat(
195
+ [controlnet_condition_latents, controlnet_latents], dim=2
196
+ )
197
+ if not guess_mode and do_classifier_free_guidance:
198
+ controlnet_latents = torch.concat([controlnet_latents] * 2, dim=0)
199
+ controlnet_latents = rearrange(
200
+ controlnet_latents, "b c t h w->(b t) c h w"
201
+ )
202
+ controlnet_latents = controlnet_latents.to(device=device, dtype=dtype)
203
+ else:
204
+ # TODO: concat with index
206
+ if controlnet_condition_frames is not None:
207
+ if isinstance(controlnet_condition_frames, np.ndarray):
208
+ image = np.concatenate(
209
+ [controlnet_condition_frames, image], axis=2
210
+ )
211
+ image = self.prepare_image(
212
+ image=image,
213
+ width=width,
214
+ height=height,
215
+ batch_size=batch_size * num_videos_per_prompt,
216
+ num_images_per_prompt=num_videos_per_prompt,
217
+ device=device,
218
+ dtype=controlnet.dtype,
219
+ do_classifier_free_guidance=do_classifier_free_guidance,
220
+ guess_mode=guess_mode,
221
+ )
222
+ height, width = image.shape[-2:]
223
+ elif isinstance(controlnet, MultiControlNetModel):
224
+ images = []
225
+ # TODO: 支持直接使用controlnet_latent而不是frames
226
+ # TODO: support using controlnet_latent directly instead of frames
227
+ if controlnet_latents is not None:
228
+ raise NotImplementedError
229
+ else:
230
+ for i, image_ in enumerate(image):
231
+ if controlnet_condition_frames is not None and isinstance(
232
+ controlnet_condition_frames, list
233
+ ):
234
+ if isinstance(controlnet_condition_frames[i], np.ndarray):
235
+ image_ = np.concatenate(
236
+ [controlnet_condition_frames[i], image_], axis=2
237
+ )
238
+ image_ = self.prepare_image(
239
+ image=image_,
240
+ width=width,
241
+ height=height,
242
+ batch_size=batch_size * num_videos_per_prompt,
243
+ num_images_per_prompt=num_videos_per_prompt,
244
+ device=device,
245
+ dtype=controlnet.dtype,
246
+ do_classifier_free_guidance=do_classifier_free_guidance,
247
+ guess_mode=guess_mode,
248
+ )
249
+
250
+ images.append(image_)
251
+
252
+ image = images
253
+ height, width = image[0].shape[-2:]
254
+ else:
255
+ assert False
256
+
257
+ # 7.1 Create tensor stating which controlnets to keep
258
+ controlnet_keep = []
259
+ for i in range(len(timesteps)):
260
+ keeps = [
261
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
262
+ for s, e in zip(control_guidance_start, control_guidance_end)
263
+ ]
264
+ controlnet_keep.append(
265
+ keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
266
+ )
267
+
268
+ t = timesteps[i]
269
+
270
+ # controlnet(s) inference
271
+ if guess_mode and do_classifier_free_guidance:
272
+ # Infer ControlNet only for the conditional batch.
273
+ control_model_input = latents
274
+ control_model_input = scheduler.scale_model_input(control_model_input, t)
275
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
276
+ else:
277
+ control_model_input = latent_model_input
278
+ controlnet_prompt_embeds = prompt_embeds
279
+ if isinstance(controlnet_keep[i], list):
280
+ cond_scale = [
281
+ c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])
282
+ ]
283
+ else:
284
+ cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
285
+ control_model_input_reshape = rearrange(
286
+ control_model_input, "b c t h w -> (b t) c h w"
287
+ )
288
+ encoder_hidden_states_repeat = repeat(
289
+ controlnet_prompt_embeds,
290
+ "b n q->(b t) n q",
291
+ t=control_model_input.shape[2],
292
+ )
293
+
294
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
295
+ control_model_input_reshape,
296
+ t,
297
+ encoder_hidden_states_repeat,
298
+ controlnet_cond=image,
299
+ controlnet_cond_latents=controlnet_latents,
300
+ conditioning_scale=cond_scale,
301
+ guess_mode=guess_mode,
302
+ return_dict=False,
303
+ )
304
+
305
+ return down_block_res_samples, mid_block_res_sample
306
+
307
+
308
+ class InflatedConv3d(nn.Conv2d):
309
+ def forward(self, x):
310
+ video_length = x.shape[2]
311
+
312
+ x = rearrange(x, "b c f h w -> (b f) c h w")
313
+ x = super().forward(x)
314
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
315
+
316
+ return x
317
+
318
+
319
+ def zero_module(module):
320
+ # Zero out the parameters of a module and return it.
321
+ for p in module.parameters():
322
+ p.detach().zero_()
323
+ return module
324
+
325
+
326
+ class PoseGuider(ModelMixin):
327
+ def __init__(
328
+ self,
329
+ conditioning_embedding_channels: int,
330
+ conditioning_channels: int = 3,
331
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
332
+ ):
333
+ super().__init__()
334
+ self.conv_in = InflatedConv3d(
335
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
336
+ )
337
+
338
+ self.blocks = nn.ModuleList([])
339
+
340
+ for i in range(len(block_out_channels) - 1):
341
+ channel_in = block_out_channels[i]
342
+ channel_out = block_out_channels[i + 1]
343
+ self.blocks.append(
344
+ InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
345
+ )
346
+ self.blocks.append(
347
+ InflatedConv3d(
348
+ channel_in, channel_out, kernel_size=3, padding=1, stride=2
349
+ )
350
+ )
351
+
352
+ self.conv_out = zero_module(
353
+ InflatedConv3d(
354
+ block_out_channels[-1],
355
+ conditioning_embedding_channels,
356
+ kernel_size=3,
357
+ padding=1,
358
+ )
359
+ )
360
+
361
+ def forward(self, conditioning):
362
+ embedding = self.conv_in(conditioning)
363
+ embedding = F.silu(embedding)
364
+
365
+ for block in self.blocks:
366
+ embedding = block(embedding)
367
+ embedding = F.silu(embedding)
368
+
369
+ embedding = self.conv_out(embedding)
370
+
371
+ return embedding
372
+
373
+ @classmethod
374
+ def from_pretrained(
375
+ cls,
376
+ pretrained_model_path,
377
+ conditioning_embedding_channels: int,
378
+ conditioning_channels: int = 3,
379
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
380
+ ):
381
+ if not os.path.exists(pretrained_model_path):
382
+ print(f"There is no model file in {pretrained_model_path}")
383
+ print(
384
+ f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ..."
385
+ )
386
+
387
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
388
+ model = PoseGuider(
389
+ conditioning_embedding_channels=conditioning_embedding_channels,
390
+ conditioning_channels=conditioning_channels,
391
+ block_out_channels=block_out_channels,
392
+ )
393
+
394
+ m, u = model.load_state_dict(state_dict, strict=False)
395
+ # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
396
+ params = [p.numel() for n, p in model.named_parameters()]
397
+ print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")
398
+
399
+ return model
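
As a quick sanity check of the `PoseGuider` defined above: with the default `block_out_channels=(16, 32, 64, 128)` it applies three stride-2 convolutions, so a pose video is downsampled 8x spatially and projected to `conditioning_embedding_channels`, and the final conv is zero-initialized by `zero_module`. The import path matches the file added in this commit; the tensor sizes below are assumptions for illustration.

```python
import torch

from musev.models.controlnet import PoseGuider  # module path as added in this commit

pose_guider = PoseGuider(conditioning_embedding_channels=320)
pose_video = torch.randn(1, 3, 16, 512, 512)  # b c f h w
with torch.no_grad():
    emb = pose_guider(pose_video)
print(emb.shape)  # torch.Size([1, 320, 16, 64, 64]); all zeros until trained, since conv_out starts at zero
```
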
musev/models/embeddings.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from einops import rearrange
16
+ import torch
17
+ from torch.nn import functional as F
18
+ import numpy as np
19
+
20
+ from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid
21
+
22
+
23
+ # ref diffusers.models.embeddings.get_2d_sincos_pos_embed
24
+ def get_2d_sincos_pos_embed(
25
+ embed_dim,
26
+ grid_size_w,
27
+ grid_size_h,
28
+ cls_token=False,
29
+ extra_tokens=0,
30
+ norm_length: bool = False,
31
+ max_length: float = 2048,
32
+ ):
33
+ """
34
+ grid_size_w, grid_size_h: width and height of the grid. return: pos_embed: [grid_size_w*grid_size_h, embed_dim] or
35
+ [extra_tokens+grid_size_w*grid_size_h, embed_dim] (w/ or w/o cls_token)
36
+ """
37
+ if norm_length and grid_size_h <= max_length and grid_size_w <= max_length:
38
+ grid_h = np.linspace(0, max_length, grid_size_h)
39
+ grid_w = np.linspace(0, max_length, grid_size_w)
40
+ else:
41
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
42
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
43
+ grid = np.meshgrid(grid_h, grid_w) # here h goes first
44
+ grid = np.stack(grid, axis=0)
45
+
46
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
47
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
48
+ if cls_token and extra_tokens > 0:
49
+ pos_embed = np.concatenate(
50
+ [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
51
+ )
52
+ return pos_embed
53
+
54
+
55
+ def resize_spatial_position_emb(
56
+ emb: torch.Tensor,
57
+ height: int,
58
+ width: int,
59
+ scale: float = None,
60
+ target_height: int = None,
61
+ target_width: int = None,
62
+ ) -> torch.Tensor:
63
+ """_summary_
64
+
65
+ Args:
66
+ emb (torch.Tensor): b ( h w) d
67
+ height (int): _description_
68
+ width (int): _description_
69
+ scale (float, optional): _description_. Defaults to None.
70
+ target_height (int, optional): _description_. Defaults to None.
71
+ target_width (int, optional): _description_. Defaults to None.
72
+
73
+ Returns:
74
+ torch.Tensor: b (target_height target_width) d
75
+ """
76
+ if scale is not None:
77
+ target_height = int(height * scale)
78
+ target_width = int(width * scale)
79
+ emb = rearrange(emb, "(h w) (b d) ->b d h w", h=height, b=1)
80
+ emb = F.interpolate(
81
+ emb,
82
+ size=(target_height, target_width),
83
+ mode="bicubic",
84
+ align_corners=False,
85
+ )
86
+ emb = rearrange(emb, "b d h w-> (h w) (b d)")
87
+ return emb
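
A small usage sketch for the two helpers above (the resolutions are arbitrary assumptions): build a sin-cos table for a 16x16 latent grid, then resize it bicubically to 24x24 tokens when the spatial resolution changes.

```python
import torch

from musev.models.embeddings import (  # module path as added in this commit
    get_2d_sincos_pos_embed,
    resize_spatial_position_emb,
)

pos = get_2d_sincos_pos_embed(embed_dim=320, grid_size_w=16, grid_size_h=16)
pos = torch.from_numpy(pos).float()  # (16 * 16, 320)

resized = resize_spatial_position_emb(
    pos, height=16, width=16, target_height=24, target_width=24
)
print(pos.shape, resized.shape)  # torch.Size([256, 320]) torch.Size([576, 320])
```
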
musev/models/facein_loader.py ADDED
@@ -0,0 +1,120 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, Iterable, Union
3
+ import PIL
4
+ import cv2
5
+ import torch
6
+ import argparse
7
+ import datetime
8
+ import logging
9
+ import inspect
10
+ import math
11
+ import os
12
+ import shutil
13
+ from typing import Dict, List, Optional, Tuple
14
+ from pprint import pprint
15
+ from collections import OrderedDict
16
+ from dataclasses import dataclass
17
+ import gc
18
+ import time
19
+
20
+ import numpy as np
21
+ from omegaconf import OmegaConf
22
+ from omegaconf import SCMode
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from einops import rearrange, repeat
28
+ import pandas as pd
29
+ import h5py
30
+ from diffusers.models.modeling_utils import load_state_dict
31
+ from diffusers.utils import (
32
+ logging,
33
+ )
34
+ from diffusers.utils.import_utils import is_xformers_available
35
+
36
+ from mmcm.vision.feature_extractor.clip_vision_extractor import (
37
+ ImageClipVisionFeatureExtractor,
38
+ ImageClipVisionFeatureExtractorV2,
39
+ )
40
+ from mmcm.vision.feature_extractor.insight_face_extractor import InsightFaceExtractor
41
+
42
+ from ip_adapter.resampler import Resampler
43
+ from ip_adapter.ip_adapter import ImageProjModel
44
+
45
+ from .unet_loader import update_unet_with_sd
46
+ from .unet_3d_condition import UNet3DConditionModel
47
+ from .ip_adapter_loader import ip_adapter_keys_list
48
+
49
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
+
51
+
52
+ # refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
53
+ unet_keys_list = [
54
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
55
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
56
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
57
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
58
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
59
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
60
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
61
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
62
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
63
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
64
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
65
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
66
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
67
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
68
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
69
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
70
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
71
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
72
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
73
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
74
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
75
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
76
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
77
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
78
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
79
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
80
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
81
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
82
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
83
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
84
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
85
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
86
+ ]
87
+
88
+
89
+ UNET2IPAadapter_Keys_MAPIING = {
90
+ k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
91
+ }
92
+
93
+
94
+ def load_facein_extractor_and_proj_by_name(
95
+ model_name: str,
96
+ ip_ckpt: Tuple[str, nn.Module],
97
+ ip_image_encoder: Tuple[str, nn.Module] = None,
98
+ cross_attention_dim: int = 768,
99
+ clip_embeddings_dim: int = 512,
100
+ clip_extra_context_tokens: int = 1,
101
+ ip_scale: float = 0.0,
102
+ dtype: torch.dtype = torch.float16,
103
+ device: str = "cuda",
104
+ unet: nn.Module = None,
105
+ ) -> nn.Module:
106
+ pass
107
+
108
+
109
+ def update_unet_facein_cross_attn_param(
110
+ unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
111
+ ) -> None:
112
+ """use independent ip_adapter attn 中的 to_k, to_v in unet
113
+ ip_adapter: like ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']的字典
114
+
115
+
116
+ Args:
117
+ unet (UNet3DConditionModel): _description_
118
+ ip_adapter_state_dict (Dict): _description_
119
+ """
120
+ pass
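
The loader above revolves around `UNET2IPAadapter_Keys_MAPIING`: FaceIn/IP-Adapter checkpoints store cross-attention weights under numeric keys, which are zipped against the UNet attention-processor paths so each tensor can be copied into the matching `facein_to_k_ip`/`facein_to_v_ip` parameter. A self-contained sketch of that pairing (lists truncated to two keys; no new behavior):

```python
unet_keys = [
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight",
]
ip_adapter_keys = ["1.to_k_ip.weight", "1.to_v_ip.weight"]

unet2ip = dict(zip(unet_keys, ip_adapter_keys))
for unet_key, ip_key in unet2ip.items():
    print(f"{ip_key} -> {unet_key}")
```
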
musev/models/ip_adapter_face_loader.py ADDED
@@ -0,0 +1,179 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, Iterable, Union
3
+ import PIL
4
+ import cv2
5
+ import torch
6
+ import argparse
7
+ import datetime
8
+ import logging
9
+ import inspect
10
+ import math
11
+ import os
12
+ import shutil
13
+ from typing import Dict, List, Optional, Tuple
14
+ from pprint import pprint
15
+ from collections import OrderedDict
16
+ from dataclasses import dataclass
17
+ import gc
18
+ import time
19
+
20
+ import numpy as np
21
+ from omegaconf import OmegaConf
22
+ from omegaconf import SCMode
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from einops import rearrange, repeat
28
+ import pandas as pd
29
+ import h5py
30
+ from diffusers.models.modeling_utils import load_state_dict
31
+ from diffusers.utils import (
32
+ logging,
33
+ )
34
+ from diffusers.utils.import_utils import is_xformers_available
35
+
36
+ from ip_adapter.resampler import Resampler
37
+ from ip_adapter.ip_adapter import ImageProjModel
38
+ from ip_adapter.ip_adapter_faceid import ProjPlusModel, MLPProjModel
39
+
40
+ from mmcm.vision.feature_extractor.clip_vision_extractor import (
41
+ ImageClipVisionFeatureExtractor,
42
+ ImageClipVisionFeatureExtractorV2,
43
+ )
44
+ from mmcm.vision.feature_extractor.insight_face_extractor import (
45
+ InsightFaceExtractorNormEmb,
46
+ )
47
+
48
+
49
+ from .unet_loader import update_unet_with_sd
50
+ from .unet_3d_condition import UNet3DConditionModel
51
+ from .ip_adapter_loader import ip_adapter_keys_list
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+
56
+ # refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
57
+ unet_keys_list = [
58
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
59
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
60
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
61
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
62
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
63
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
64
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
65
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
66
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
67
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
68
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
69
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
70
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
71
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
72
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
73
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
74
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
75
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
76
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
77
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
78
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
79
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
80
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
81
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
82
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
83
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
84
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
85
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
86
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
87
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
88
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight",
89
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight",
90
+ ]
91
+
92
+
93
+ UNET2IPAadapter_Keys_MAPIING = {
94
+ k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
95
+ }
96
+
97
+
98
+ def load_ip_adapter_face_extractor_and_proj_by_name(
99
+ model_name: str,
100
+ ip_ckpt: Tuple[str, nn.Module],
101
+ ip_image_encoder: Tuple[str, nn.Module] = None,
102
+ cross_attention_dim: int = 768,
103
+ clip_embeddings_dim: int = 1024,
104
+ clip_extra_context_tokens: int = 4,
105
+ ip_scale: float = 0.0,
106
+ dtype: torch.dtype = torch.float16,
107
+ device: str = "cuda",
108
+ unet: nn.Module = None,
109
+ ) -> nn.Module:
110
+ if model_name == "IPAdapterFaceID":
111
+ if ip_image_encoder is not None:
112
+ ip_adapter_face_emb_extractor = InsightFaceExtractorNormEmb(
113
+ pretrained_model_name_or_path=ip_image_encoder,
114
+ dtype=dtype,
115
+ device=device,
116
+ )
117
+ else:
118
+ ip_adapter_face_emb_extractor = None
119
+ ip_adapter_image_proj = MLPProjModel(
120
+ cross_attention_dim=cross_attention_dim,
121
+ id_embeddings_dim=clip_embeddings_dim,
122
+ num_tokens=clip_extra_context_tokens,
123
+ ).to(device, dtype=dtype)
124
+ else:
125
+ raise ValueError(
126
+ f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus, IPAdapterFaceID"
127
+ )
128
+ ip_adapter_state_dict = torch.load(
129
+ ip_ckpt,
130
+ map_location="cpu",
131
+ )
132
+ ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
133
+ if unet is not None and "ip_adapter" in ip_adapter_state_dict:
134
+ update_unet_ip_adapter_cross_attn_param(
135
+ unet,
136
+ ip_adapter_state_dict["ip_adapter"],
137
+ )
138
+ logger.info(
139
+ f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
140
+ )
141
+ return (
142
+ ip_adapter_face_emb_extractor,
143
+ ip_adapter_image_proj,
144
+ )
145
+
146
+
147
+ def update_unet_ip_adapter_cross_attn_param(
148
+ unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
149
+ ) -> None:
150
+ """use independent ip_adapter attn 中的 to_k, to_v in unet
151
+ ip_adapter: like ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']
152
+
153
+
154
+ Args:
155
+ unet (UNet3DConditionModel): _description_
156
+ ip_adapter_state_dict (Dict): _description_
157
+ """
158
+ unet_spatial_cross_atnns = unet.spatial_cross_attns[0]
159
+ unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns}
160
+ for i, (unet_key_more, ip_adapter_key) in enumerate(
161
+ UNET2IPAadapter_Keys_MAPIING.items()
162
+ ):
163
+ ip_adapter_value = ip_adapter_state_dict[ip_adapter_key]
164
+ unet_key_more_spit = unet_key_more.split(".")
165
+ unet_key = ".".join(unet_key_more_spit[:-3])
166
+ suffix = ".".join(unet_key_more_spit[-3:])
167
+ logger.debug(
168
+ f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}",
169
+ )
170
+ if ".ip_adapter_face_to_k" in suffix:
171
+ with torch.no_grad():
172
+ unet_spatial_cross_atnns_dct[
173
+ unet_key
174
+ ].ip_adapter_face_to_k_ip.weight.copy_(ip_adapter_value.data)
175
+ else:
176
+ with torch.no_grad():
177
+ unet_spatial_cross_atnns_dct[
178
+ unet_key
179
+ ].ip_adapter_face_to_v_ip.weight.copy_(ip_adapter_value.data)
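
`load_ip_adapter_face_extractor_and_proj_by_name` above expects the checkpoint layout used by upstream IP-Adapter releases: a single torch file whose `"image_proj"` entry feeds the projection model and whose `"ip_adapter"` entry holds the per-layer `to_k_ip`/`to_v_ip` weights copied into the UNet. The file path below is a placeholder, not something shipped with this commit.

```python
import torch

# Inspect an IP-Adapter-FaceID checkpoint before passing it as `ip_ckpt`.
ckpt = torch.load("path/to/ip-adapter-faceid.bin", map_location="cpu")  # placeholder path
print(sorted(ckpt.keys()))                  # ['image_proj', 'ip_adapter']
print(list(ckpt["ip_adapter"].keys())[:4])  # e.g. ['1.to_k_ip.weight', '1.to_v_ip.weight', ...]
```
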
musev/models/ip_adapter_loader.py ADDED
@@ -0,0 +1,340 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, Iterable, Union
3
+ import PIL
4
+ import cv2
5
+ import torch
6
+ import argparse
7
+ import datetime
8
+ import logging
9
+ import inspect
10
+ import math
11
+ import os
12
+ import shutil
13
+ from typing import Dict, List, Optional, Tuple
14
+ from pprint import pprint
15
+ from collections import OrderedDict
16
+ from dataclasses import dataclass
17
+ import gc
18
+ import time
19
+
20
+ import numpy as np
21
+ from omegaconf import OmegaConf
22
+ from omegaconf import SCMode
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from einops import rearrange, repeat
28
+ import pandas as pd
29
+ import h5py
30
+ from diffusers.models.modeling_utils import load_state_dict
31
+ from diffusers.utils import (
32
+ logging,
33
+ )
34
+ from diffusers.utils.import_utils import is_xformers_available
35
+
36
+ from mmcm.vision.feature_extractor import clip_vision_extractor
37
+ from mmcm.vision.feature_extractor.clip_vision_extractor import (
38
+ ImageClipVisionFeatureExtractor,
39
+ ImageClipVisionFeatureExtractorV2,
40
+ VerstailSDLastHiddenState2ImageEmb,
41
+ )
42
+
43
+ from ip_adapter.resampler import Resampler
44
+ from ip_adapter.ip_adapter import ImageProjModel
45
+
46
+ from .unet_loader import update_unet_with_sd
47
+ from .unet_3d_condition import UNet3DConditionModel
48
+
49
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
+
51
+
52
+ def load_vision_clip_encoder_by_name(
53
+ ip_image_encoder: Tuple[str, nn.Module] = None,
54
+ dtype: torch.dtype = torch.float16,
55
+ device: str = "cuda",
56
+ vision_clip_extractor_class_name: str = None,
57
+ ) -> nn.Module:
58
+ if vision_clip_extractor_class_name is not None:
59
+ vision_clip_extractor = getattr(
60
+ clip_vision_extractor, vision_clip_extractor_class_name
61
+ )(
62
+ pretrained_model_name_or_path=ip_image_encoder,
63
+ dtype=dtype,
64
+ device=device,
65
+ )
66
+ else:
67
+ vision_clip_extractor = None
68
+ return vision_clip_extractor
69
+
70
+
71
+ def load_ip_adapter_image_proj_by_name(
72
+ model_name: str,
73
+ ip_ckpt: Tuple[str, nn.Module] = None,
74
+ cross_attention_dim: int = 768,
75
+ clip_embeddings_dim: int = 1024,
76
+ clip_extra_context_tokens: int = 4,
77
+ ip_scale: float = 0.0,
78
+ dtype: torch.dtype = torch.float16,
79
+ device: str = "cuda",
80
+ unet: nn.Module = None,
81
+ vision_clip_extractor_class_name: str = None,
82
+ ip_image_encoder: Tuple[str, nn.Module] = None,
83
+ ) -> nn.Module:
84
+ if model_name in [
85
+ "IPAdapter",
86
+ "musev_referencenet",
87
+ "musev_referencenet_pose",
88
+ ]:
89
+ ip_adapter_image_proj = ImageProjModel(
90
+ cross_attention_dim=cross_attention_dim,
91
+ clip_embeddings_dim=clip_embeddings_dim,
92
+ clip_extra_context_tokens=clip_extra_context_tokens,
93
+ )
94
+
95
+ elif model_name == "IPAdapterPlus":
96
+ vision_clip_extractor = ImageClipVisionFeatureExtractorV2(
97
+ pretrained_model_name_or_path=ip_image_encoder,
98
+ dtype=dtype,
99
+ device=device,
100
+ )
101
+ ip_adapter_image_proj = Resampler(
102
+ dim=cross_attention_dim,
103
+ depth=4,
104
+ dim_head=64,
105
+ heads=12,
106
+ num_queries=clip_extra_context_tokens,
107
+ embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size,
108
+ output_dim=cross_attention_dim,
109
+ ff_mult=4,
110
+ )
111
+ elif model_name in [
112
+ "VerstailSDLastHiddenState2ImageEmb",
113
+ "OriginLastHiddenState2ImageEmbd",
114
+ "OriginLastHiddenState2Poolout",
115
+ ]:
116
+ ip_adapter_image_proj = getattr(
117
+ clip_vision_extractor, model_name
118
+ ).from_pretrained(ip_image_encoder)
119
+ else:
120
+ raise ValueError(
121
+ f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus, VerstailSDLastHiddenState2ImageEmb"
122
+ )
123
+ if ip_ckpt is not None:
124
+ ip_adapter_state_dict = torch.load(
125
+ ip_ckpt,
126
+ map_location="cpu",
127
+ )
128
+ ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
129
+ if (
130
+ unet is not None
131
+ and unet.ip_adapter_cross_attn
132
+ and "ip_adapter" in ip_adapter_state_dict
133
+ ):
134
+ update_unet_ip_adapter_cross_attn_param(
135
+ unet, ip_adapter_state_dict["ip_adapter"]
136
+ )
137
+ logger.info(
138
+ f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
139
+ )
140
+ return ip_adapter_image_proj
141
+
142
+
143
+ def load_ip_adapter_vision_clip_encoder_by_name(
144
+ model_name: str,
145
+ ip_ckpt: Tuple[str, nn.Module],
146
+ ip_image_encoder: Tuple[str, nn.Module] = None,
147
+ cross_attention_dim: int = 768,
148
+ clip_embeddings_dim: int = 1024,
149
+ clip_extra_context_tokens: int = 4,
150
+ ip_scale: float = 0.0,
151
+ dtype: torch.dtype = torch.float16,
152
+ device: str = "cuda",
153
+ unet: nn.Module = None,
154
+ vision_clip_extractor_class_name: str = None,
155
+ ) -> nn.Module:
156
+ if vision_clip_extractor_class_name is not None:
157
+ vision_clip_extractor = getattr(
158
+ clip_vision_extractor, vision_clip_extractor_class_name
159
+ )(
160
+ pretrained_model_name_or_path=ip_image_encoder,
161
+ dtype=dtype,
162
+ device=device,
163
+ )
164
+ else:
165
+ vision_clip_extractor = None
166
+ if model_name in [
167
+ "IPAdapter",
168
+ "musev_referencenet",
169
+ ]:
170
+ if ip_image_encoder is not None:
171
+ if vision_clip_extractor_class_name is None:
172
+ vision_clip_extractor = ImageClipVisionFeatureExtractor(
173
+ pretrained_model_name_or_path=ip_image_encoder,
174
+ dtype=dtype,
175
+ device=device,
176
+ )
177
+ else:
178
+ vision_clip_extractor = None
179
+ ip_adapter_image_proj = ImageProjModel(
180
+ cross_attention_dim=cross_attention_dim,
181
+ clip_embeddings_dim=clip_embeddings_dim,
182
+ clip_extra_context_tokens=clip_extra_context_tokens,
183
+ )
184
+
185
+ elif model_name == "IPAdapterPlus":
186
+ if ip_image_encoder is not None:
187
+ if vision_clip_extractor_class_name is None:
188
+ vision_clip_extractor = ImageClipVisionFeatureExtractorV2(
189
+ pretrained_model_name_or_path=ip_image_encoder,
190
+ dtype=dtype,
191
+ device=device,
192
+ )
193
+ else:
194
+ vision_clip_extractor = None
195
+ ip_adapter_image_proj = Resampler(
196
+ dim=cross_attention_dim,
197
+ depth=4,
198
+ dim_head=64,
199
+ heads=12,
200
+ num_queries=clip_extra_context_tokens,
201
+ embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size,
202
+ output_dim=cross_attention_dim,
203
+ ff_mult=4,
204
+ ).to(dtype=torch.float16)
205
+ else:
206
+ raise ValueError(
207
+ f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus"
208
+ )
209
+ ip_adapter_state_dict = torch.load(
210
+ ip_ckpt,
211
+ map_location="cpu",
212
+ )
213
+ ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
214
+ if (
215
+ unet is not None
216
+ and unet.ip_adapter_cross_attn
217
+ and "ip_adapter" in ip_adapter_state_dict
218
+ ):
219
+ update_unet_ip_adapter_cross_attn_param(
220
+ unet, ip_adapter_state_dict["ip_adapter"]
221
+ )
222
+ logger.info(
223
+ f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
224
+ )
225
+ return (
226
+ vision_clip_extractor,
227
+ ip_adapter_image_proj,
228
+ )
229
+
230
+
231
+ # refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
232
+ unet_keys_list = [
233
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
234
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
235
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
236
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
237
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
238
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
239
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
240
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
241
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
242
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
243
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
244
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
245
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
246
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
247
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
248
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
249
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
250
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
251
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
252
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
253
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
254
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
255
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
256
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
257
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
258
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
259
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
260
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
261
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
262
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
263
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
264
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
265
+ ]
266
+
267
+
268
+ ip_adapter_keys_list = [
269
+ "1.to_k_ip.weight",
270
+ "1.to_v_ip.weight",
271
+ "3.to_k_ip.weight",
272
+ "3.to_v_ip.weight",
273
+ "5.to_k_ip.weight",
274
+ "5.to_v_ip.weight",
275
+ "7.to_k_ip.weight",
276
+ "7.to_v_ip.weight",
277
+ "9.to_k_ip.weight",
278
+ "9.to_v_ip.weight",
279
+ "11.to_k_ip.weight",
280
+ "11.to_v_ip.weight",
281
+ "13.to_k_ip.weight",
282
+ "13.to_v_ip.weight",
283
+ "15.to_k_ip.weight",
284
+ "15.to_v_ip.weight",
285
+ "17.to_k_ip.weight",
286
+ "17.to_v_ip.weight",
287
+ "19.to_k_ip.weight",
288
+ "19.to_v_ip.weight",
289
+ "21.to_k_ip.weight",
290
+ "21.to_v_ip.weight",
291
+ "23.to_k_ip.weight",
292
+ "23.to_v_ip.weight",
293
+ "25.to_k_ip.weight",
294
+ "25.to_v_ip.weight",
295
+ "27.to_k_ip.weight",
296
+ "27.to_v_ip.weight",
297
+ "29.to_k_ip.weight",
298
+ "29.to_v_ip.weight",
299
+ "31.to_k_ip.weight",
300
+ "31.to_v_ip.weight",
301
+ ]
302
+
303
+ UNET2IPAadapter_Keys_MAPIING = {
304
+ k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
305
+ }
306
+
307
+
308
+ def update_unet_ip_adapter_cross_attn_param(
309
+ unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
310
+ ) -> None:
311
+ """use independent ip_adapter attn 中的 to_k, to_v in unet
312
+ ip_adapter: dict whose keys are ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']
313
+
314
+
315
+ Args:
316
+ unet (UNet3DConditionModel): _description_
317
+ ip_adapter_state_dict (Dict): _description_
318
+ """
319
+ unet_spatial_cross_atnns = unet.spatial_cross_attns[0]
320
+ unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns}
321
+ for i, (unet_key_more, ip_adapter_key) in enumerate(
322
+ UNET2IPAadapter_Keys_MAPIING.items()
323
+ ):
324
+ ip_adapter_value = ip_adapter_state_dict[ip_adapter_key]
325
+ unet_key_more_spit = unet_key_more.split(".")
326
+ unet_key = ".".join(unet_key_more_spit[:-3])
327
+ suffix = ".".join(unet_key_more_spit[-3:])
328
+ logger.debug(
329
+ f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}",
330
+ )
331
+ if "to_k" in suffix:
332
+ with torch.no_grad():
333
+ unet_spatial_cross_atnns_dct[unet_key].to_k_ip.weight.copy_(
334
+ ip_adapter_value.data
335
+ )
336
+ else:
337
+ with torch.no_grad():
338
+ unet_spatial_cross_atnns_dct[unet_key].to_v_ip.weight.copy_(
339
+ ip_adapter_value.data
340
+ )
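
A hedged usage sketch for `load_ip_adapter_image_proj_by_name` above, using the plain `"IPAdapter"` branch: it builds an `ImageProjModel` that turns one CLIP image embedding into `clip_extra_context_tokens` extra cross-attention tokens. With `ip_ckpt=None` nothing is loaded or patched; passing a real checkpoint path also loads the `"image_proj"` weights and, when a UNet is given, copies its `to_k_ip`/`to_v_ip` parameters. The argument values are assumptions for illustration, and the import requires the repo's MMCM / ip_adapter dependencies to be installed.

```python
import torch

from musev.models.ip_adapter_loader import load_ip_adapter_image_proj_by_name

image_proj = load_ip_adapter_image_proj_by_name(
    model_name="IPAdapter",
    ip_ckpt=None,                 # set to a real IP-Adapter .bin to load weights
    cross_attention_dim=768,
    clip_embeddings_dim=1024,
    clip_extra_context_tokens=4,
)
image_emb = torch.randn(1, 1024)  # a CLIP image embedding
tokens = image_proj(image_emb)
print(tokens.shape)               # torch.Size([1, 4, 768])
```
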
musev/models/referencenet.py ADDED
@@ -0,0 +1,1216 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+ import logging
19
+
20
+ import torch
21
+ from diffusers.models.attention_processor import Attention, AttnProcessor
22
+ from einops import rearrange, repeat
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ import xformers
26
+ from diffusers.models.lora import LoRACompatibleLinear
27
+ from diffusers.models.unet_2d_condition import (
28
+ UNet2DConditionModel,
29
+ UNet2DConditionOutput,
30
+ )
31
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
32
+ from diffusers.utils.constants import USE_PEFT_BACKEND
33
+ from diffusers.utils.deprecation_utils import deprecate
34
+ from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict
37
+ from diffusers.loaders import UNet2DConditionLoadersMixin
38
+ from diffusers.utils import (
39
+ USE_PEFT_BACKEND,
40
+ BaseOutput,
41
+ deprecate,
42
+ scale_lora_layers,
43
+ unscale_lora_layers,
44
+ )
45
+ from diffusers.models.activations import get_activation
46
+ from diffusers.models.attention_processor import (
47
+ ADDED_KV_ATTENTION_PROCESSORS,
48
+ CROSS_ATTENTION_PROCESSORS,
49
+ AttentionProcessor,
50
+ AttnAddedKVProcessor,
51
+ AttnProcessor,
52
+ )
53
+ from diffusers.models.embeddings import (
54
+ GaussianFourierProjection,
55
+ ImageHintTimeEmbedding,
56
+ ImageProjection,
57
+ ImageTimeEmbedding,
58
+ PositionNet,
59
+ TextImageProjection,
60
+ TextImageTimeEmbedding,
61
+ TextTimeEmbedding,
62
+ TimestepEmbedding,
63
+ Timesteps,
64
+ )
65
+ from diffusers.models.modeling_utils import ModelMixin
66
+
67
+
68
+ from ..data.data_util import align_repeat_tensor_single_dim
69
+ from .unet_3d_condition import UNet3DConditionModel
70
+ from .attention import BasicTransformerBlock, IPAttention
71
+ from .unet_2d_blocks import (
72
+ UNetMidBlock2D,
73
+ UNetMidBlock2DCrossAttn,
74
+ UNetMidBlock2DSimpleCrossAttn,
75
+ get_down_block,
76
+ get_up_block,
77
+ )
78
+
79
+ from . import Model_Register
80
+
81
+
82
+ logger = logging.getLogger(__name__)
83
+
84
+
85
+ @Model_Register.register
86
+ class ReferenceNet2D(UNet2DConditionModel, nn.Module):
87
+ """继承 UNet2DConditionModel. 新增功能,类似controlnet 返回模型中间特征,用于后续作用
88
+ Inherit Unet2DConditionModel. Add new functions, similar to controlnet, return the intermediate features of the model for subsequent effects
89
+ Args:
90
+ UNet2DConditionModel (_type_): _description_
91
+ """
92
+
93
+ _supports_gradient_checkpointing = True
94
+ print_idx = 0
95
+
96
+ @register_to_config
97
+ def __init__(
98
+ self,
99
+ sample_size: int | None = None,
100
+ in_channels: int = 4,
101
+ out_channels: int = 4,
102
+ center_input_sample: bool = False,
103
+ flip_sin_to_cos: bool = True,
104
+ freq_shift: int = 0,
105
+ down_block_types: Tuple[str] = (
106
+ "CrossAttnDownBlock2D",
107
+ "CrossAttnDownBlock2D",
108
+ "CrossAttnDownBlock2D",
109
+ "DownBlock2D",
110
+ ),
111
+ mid_block_type: str | None = "UNetMidBlock2DCrossAttn",
112
+ up_block_types: Tuple[str] = (
113
+ "UpBlock2D",
114
+ "CrossAttnUpBlock2D",
115
+ "CrossAttnUpBlock2D",
116
+ "CrossAttnUpBlock2D",
117
+ ),
118
+ only_cross_attention: bool | Tuple[bool] = False,
119
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
120
+ layers_per_block: int | Tuple[int] = 2,
121
+ downsample_padding: int = 1,
122
+ mid_block_scale_factor: float = 1,
123
+ dropout: float = 0,
124
+ act_fn: str = "silu",
125
+ norm_num_groups: int | None = 32,
126
+ norm_eps: float = 0.00001,
127
+ cross_attention_dim: int | Tuple[int] = 1280,
128
+ transformer_layers_per_block: int | Tuple[int] | Tuple[Tuple] = 1,
129
+ reverse_transformer_layers_per_block: Tuple[Tuple[int]] | None = None,
130
+ encoder_hid_dim: int | None = None,
131
+ encoder_hid_dim_type: str | None = None,
132
+ attention_head_dim: int | Tuple[int] = 8,
133
+ num_attention_heads: int | Tuple[int] | None = None,
134
+ dual_cross_attention: bool = False,
135
+ use_linear_projection: bool = False,
136
+ class_embed_type: str | None = None,
137
+ addition_embed_type: str | None = None,
138
+ addition_time_embed_dim: int | None = None,
139
+ num_class_embeds: int | None = None,
140
+ upcast_attention: bool = False,
141
+ resnet_time_scale_shift: str = "default",
142
+ resnet_skip_time_act: bool = False,
143
+ resnet_out_scale_factor: int = 1,
144
+ time_embedding_type: str = "positional",
145
+ time_embedding_dim: int | None = None,
146
+ time_embedding_act_fn: str | None = None,
147
+ timestep_post_act: str | None = None,
148
+ time_cond_proj_dim: int | None = None,
149
+ conv_in_kernel: int = 3,
150
+ conv_out_kernel: int = 3,
151
+ projection_class_embeddings_input_dim: int | None = None,
152
+ attention_type: str = "default",
153
+ class_embeddings_concat: bool = False,
154
+ mid_block_only_cross_attention: bool | None = None,
155
+ cross_attention_norm: str | None = None,
156
+ addition_embed_type_num_heads=64,
157
+ need_self_attn_block_embs: bool = False,
158
+ need_block_embs: bool = False,
159
+ ):
160
+ super().__init__()
161
+
162
+ self.sample_size = sample_size
163
+
164
+ if num_attention_heads is not None:
165
+ raise ValueError(
166
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
167
+ )
168
+
169
+ # If `num_attention_heads` is not defined (which is the case for most models)
170
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
171
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
172
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
173
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
174
+ # which is why we correct for the naming here.
175
+ num_attention_heads = num_attention_heads or attention_head_dim
176
+
177
+ # Check inputs
178
+ if len(down_block_types) != len(up_block_types):
179
+ raise ValueError(
180
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
181
+ )
182
+
183
+ if len(block_out_channels) != len(down_block_types):
184
+ raise ValueError(
185
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
186
+ )
187
+
188
+ if not isinstance(only_cross_attention, bool) and len(
189
+ only_cross_attention
190
+ ) != len(down_block_types):
191
+ raise ValueError(
192
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
193
+ )
194
+
195
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
196
+ down_block_types
197
+ ):
198
+ raise ValueError(
199
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
200
+ )
201
+
202
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
203
+ down_block_types
204
+ ):
205
+ raise ValueError(
206
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
207
+ )
208
+
209
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
210
+ down_block_types
211
+ ):
212
+ raise ValueError(
213
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
214
+ )
215
+
216
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
217
+ down_block_types
218
+ ):
219
+ raise ValueError(
220
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
221
+ )
222
+ if (
223
+ isinstance(transformer_layers_per_block, list)
224
+ and reverse_transformer_layers_per_block is None
225
+ ):
226
+ for layer_number_per_block in transformer_layers_per_block:
227
+ if isinstance(layer_number_per_block, list):
228
+ raise ValueError(
229
+ "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
230
+ )
231
+
232
+ # input
233
+ conv_in_padding = (conv_in_kernel - 1) // 2
234
+ self.conv_in = nn.Conv2d(
235
+ in_channels,
236
+ block_out_channels[0],
237
+ kernel_size=conv_in_kernel,
238
+ padding=conv_in_padding,
239
+ )
240
+
241
+ # time
242
+ if time_embedding_type == "fourier":
243
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
244
+ if time_embed_dim % 2 != 0:
245
+ raise ValueError(
246
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
247
+ )
248
+ self.time_proj = GaussianFourierProjection(
249
+ time_embed_dim // 2,
250
+ set_W_to_weight=False,
251
+ log=False,
252
+ flip_sin_to_cos=flip_sin_to_cos,
253
+ )
254
+ timestep_input_dim = time_embed_dim
255
+ elif time_embedding_type == "positional":
256
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
257
+
258
+ self.time_proj = Timesteps(
259
+ block_out_channels[0], flip_sin_to_cos, freq_shift
260
+ )
261
+ timestep_input_dim = block_out_channels[0]
262
+ else:
263
+ raise ValueError(
264
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
265
+ )
266
+
267
+ self.time_embedding = TimestepEmbedding(
268
+ timestep_input_dim,
269
+ time_embed_dim,
270
+ act_fn=act_fn,
271
+ post_act_fn=timestep_post_act,
272
+ cond_proj_dim=time_cond_proj_dim,
273
+ )
274
+
275
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
276
+ encoder_hid_dim_type = "text_proj"
277
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
278
+ logger.info(
279
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
280
+ )
281
+
282
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
283
+ raise ValueError(
284
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
285
+ )
286
+
287
+ if encoder_hid_dim_type == "text_proj":
288
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
289
+ elif encoder_hid_dim_type == "text_image_proj":
290
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
291
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
292
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
293
+ self.encoder_hid_proj = TextImageProjection(
294
+ text_embed_dim=encoder_hid_dim,
295
+ image_embed_dim=cross_attention_dim,
296
+ cross_attention_dim=cross_attention_dim,
297
+ )
298
+ elif encoder_hid_dim_type == "image_proj":
299
+ # Kandinsky 2.2
300
+ self.encoder_hid_proj = ImageProjection(
301
+ image_embed_dim=encoder_hid_dim,
302
+ cross_attention_dim=cross_attention_dim,
303
+ )
304
+ elif encoder_hid_dim_type is not None:
305
+ raise ValueError(
306
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
307
+ )
308
+ else:
309
+ self.encoder_hid_proj = None
310
+
311
+ # class embedding
312
+ if class_embed_type is None and num_class_embeds is not None:
313
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
314
+ elif class_embed_type == "timestep":
315
+ self.class_embedding = TimestepEmbedding(
316
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
317
+ )
318
+ elif class_embed_type == "identity":
319
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
320
+ elif class_embed_type == "projection":
321
+ if projection_class_embeddings_input_dim is None:
322
+ raise ValueError(
323
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
324
+ )
325
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
326
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
327
+ # 2. it projects from an arbitrary input dimension.
328
+ #
329
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
330
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
331
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
332
+ self.class_embedding = TimestepEmbedding(
333
+ projection_class_embeddings_input_dim, time_embed_dim
334
+ )
335
+ elif class_embed_type == "simple_projection":
336
+ if projection_class_embeddings_input_dim is None:
337
+ raise ValueError(
338
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
339
+ )
340
+ self.class_embedding = nn.Linear(
341
+ projection_class_embeddings_input_dim, time_embed_dim
342
+ )
343
+ else:
344
+ self.class_embedding = None
345
+
346
+ if addition_embed_type == "text":
347
+ if encoder_hid_dim is not None:
348
+ text_time_embedding_from_dim = encoder_hid_dim
349
+ else:
350
+ text_time_embedding_from_dim = cross_attention_dim
351
+
352
+ self.add_embedding = TextTimeEmbedding(
353
+ text_time_embedding_from_dim,
354
+ time_embed_dim,
355
+ num_heads=addition_embed_type_num_heads,
356
+ )
357
+ elif addition_embed_type == "text_image":
358
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
359
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
360
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
361
+ self.add_embedding = TextImageTimeEmbedding(
362
+ text_embed_dim=cross_attention_dim,
363
+ image_embed_dim=cross_attention_dim,
364
+ time_embed_dim=time_embed_dim,
365
+ )
366
+ elif addition_embed_type == "text_time":
367
+ self.add_time_proj = Timesteps(
368
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
369
+ )
370
+ self.add_embedding = TimestepEmbedding(
371
+ projection_class_embeddings_input_dim, time_embed_dim
372
+ )
373
+ elif addition_embed_type == "image":
374
+ # Kandinsky 2.2
375
+ self.add_embedding = ImageTimeEmbedding(
376
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
377
+ )
378
+ elif addition_embed_type == "image_hint":
379
+ # Kandinsky 2.2 ControlNet
380
+ self.add_embedding = ImageHintTimeEmbedding(
381
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
382
+ )
383
+ elif addition_embed_type is not None:
384
+ raise ValueError(
385
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
386
+ )
387
+
388
+ if time_embedding_act_fn is None:
389
+ self.time_embed_act = None
390
+ else:
391
+ self.time_embed_act = get_activation(time_embedding_act_fn)
392
+
393
+ self.down_blocks = nn.ModuleList([])
394
+ self.up_blocks = nn.ModuleList([])
395
+
396
+ if isinstance(only_cross_attention, bool):
397
+ if mid_block_only_cross_attention is None:
398
+ mid_block_only_cross_attention = only_cross_attention
399
+
400
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
401
+
402
+ if mid_block_only_cross_attention is None:
403
+ mid_block_only_cross_attention = False
404
+
405
+ if isinstance(num_attention_heads, int):
406
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
407
+
408
+ if isinstance(attention_head_dim, int):
409
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
410
+
411
+ if isinstance(cross_attention_dim, int):
412
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
413
+
414
+ if isinstance(layers_per_block, int):
415
+ layers_per_block = [layers_per_block] * len(down_block_types)
416
+
417
+ if isinstance(transformer_layers_per_block, int):
418
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
419
+ down_block_types
420
+ )
421
+
422
+ if class_embeddings_concat:
423
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
424
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
425
+ # regular time embeddings
426
+ blocks_time_embed_dim = time_embed_dim * 2
427
+ else:
428
+ blocks_time_embed_dim = time_embed_dim
429
+
430
+ # down
431
+ output_channel = block_out_channels[0]
432
+ for i, down_block_type in enumerate(down_block_types):
433
+ input_channel = output_channel
434
+ output_channel = block_out_channels[i]
435
+ is_final_block = i == len(block_out_channels) - 1
436
+
437
+ down_block = get_down_block(
438
+ down_block_type,
439
+ num_layers=layers_per_block[i],
440
+ transformer_layers_per_block=transformer_layers_per_block[i],
441
+ in_channels=input_channel,
442
+ out_channels=output_channel,
443
+ temb_channels=blocks_time_embed_dim,
444
+ add_downsample=not is_final_block,
445
+ resnet_eps=norm_eps,
446
+ resnet_act_fn=act_fn,
447
+ resnet_groups=norm_num_groups,
448
+ cross_attention_dim=cross_attention_dim[i],
449
+ num_attention_heads=num_attention_heads[i],
450
+ downsample_padding=downsample_padding,
451
+ dual_cross_attention=dual_cross_attention,
452
+ use_linear_projection=use_linear_projection,
453
+ only_cross_attention=only_cross_attention[i],
454
+ upcast_attention=upcast_attention,
455
+ resnet_time_scale_shift=resnet_time_scale_shift,
456
+ attention_type=attention_type,
457
+ resnet_skip_time_act=resnet_skip_time_act,
458
+ resnet_out_scale_factor=resnet_out_scale_factor,
459
+ cross_attention_norm=cross_attention_norm,
460
+ attention_head_dim=attention_head_dim[i]
461
+ if attention_head_dim[i] is not None
462
+ else output_channel,
463
+ dropout=dropout,
464
+ )
465
+ self.down_blocks.append(down_block)
466
+
467
+ # mid
468
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
469
+ self.mid_block = UNetMidBlock2DCrossAttn(
470
+ transformer_layers_per_block=transformer_layers_per_block[-1],
471
+ in_channels=block_out_channels[-1],
472
+ temb_channels=blocks_time_embed_dim,
473
+ dropout=dropout,
474
+ resnet_eps=norm_eps,
475
+ resnet_act_fn=act_fn,
476
+ output_scale_factor=mid_block_scale_factor,
477
+ resnet_time_scale_shift=resnet_time_scale_shift,
478
+ cross_attention_dim=cross_attention_dim[-1],
479
+ num_attention_heads=num_attention_heads[-1],
480
+ resnet_groups=norm_num_groups,
481
+ dual_cross_attention=dual_cross_attention,
482
+ use_linear_projection=use_linear_projection,
483
+ upcast_attention=upcast_attention,
484
+ attention_type=attention_type,
485
+ )
486
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
487
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
488
+ in_channels=block_out_channels[-1],
489
+ temb_channels=blocks_time_embed_dim,
490
+ dropout=dropout,
491
+ resnet_eps=norm_eps,
492
+ resnet_act_fn=act_fn,
493
+ output_scale_factor=mid_block_scale_factor,
494
+ cross_attention_dim=cross_attention_dim[-1],
495
+ attention_head_dim=attention_head_dim[-1],
496
+ resnet_groups=norm_num_groups,
497
+ resnet_time_scale_shift=resnet_time_scale_shift,
498
+ skip_time_act=resnet_skip_time_act,
499
+ only_cross_attention=mid_block_only_cross_attention,
500
+ cross_attention_norm=cross_attention_norm,
501
+ )
502
+ elif mid_block_type == "UNetMidBlock2D":
503
+ self.mid_block = UNetMidBlock2D(
504
+ in_channels=block_out_channels[-1],
505
+ temb_channels=blocks_time_embed_dim,
506
+ dropout=dropout,
507
+ num_layers=0,
508
+ resnet_eps=norm_eps,
509
+ resnet_act_fn=act_fn,
510
+ output_scale_factor=mid_block_scale_factor,
511
+ resnet_groups=norm_num_groups,
512
+ resnet_time_scale_shift=resnet_time_scale_shift,
513
+ add_attention=False,
514
+ )
515
+ elif mid_block_type is None:
516
+ self.mid_block = None
517
+ else:
518
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
519
+
520
+ # count how many layers upsample the images
521
+ self.num_upsamplers = 0
522
+
523
+ # up
524
+ reversed_block_out_channels = list(reversed(block_out_channels))
525
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
526
+ reversed_layers_per_block = list(reversed(layers_per_block))
527
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
528
+ reversed_transformer_layers_per_block = (
529
+ list(reversed(transformer_layers_per_block))
530
+ if reverse_transformer_layers_per_block is None
531
+ else reverse_transformer_layers_per_block
532
+ )
533
+ only_cross_attention = list(reversed(only_cross_attention))
534
+
535
+ output_channel = reversed_block_out_channels[0]
536
+ for i, up_block_type in enumerate(up_block_types):
537
+ is_final_block = i == len(block_out_channels) - 1
538
+
539
+ prev_output_channel = output_channel
540
+ output_channel = reversed_block_out_channels[i]
541
+ input_channel = reversed_block_out_channels[
542
+ min(i + 1, len(block_out_channels) - 1)
543
+ ]
544
+
545
+ # add upsample block for all BUT final layer
546
+ if not is_final_block:
547
+ add_upsample = True
548
+ self.num_upsamplers += 1
549
+ else:
550
+ add_upsample = False
551
+
552
+ up_block = get_up_block(
553
+ up_block_type,
554
+ num_layers=reversed_layers_per_block[i] + 1,
555
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
556
+ in_channels=input_channel,
557
+ out_channels=output_channel,
558
+ prev_output_channel=prev_output_channel,
559
+ temb_channels=blocks_time_embed_dim,
560
+ add_upsample=add_upsample,
561
+ resnet_eps=norm_eps,
562
+ resnet_act_fn=act_fn,
563
+ resolution_idx=i,
564
+ resnet_groups=norm_num_groups,
565
+ cross_attention_dim=reversed_cross_attention_dim[i],
566
+ num_attention_heads=reversed_num_attention_heads[i],
567
+ dual_cross_attention=dual_cross_attention,
568
+ use_linear_projection=use_linear_projection,
569
+ only_cross_attention=only_cross_attention[i],
570
+ upcast_attention=upcast_attention,
571
+ resnet_time_scale_shift=resnet_time_scale_shift,
572
+ attention_type=attention_type,
573
+ resnet_skip_time_act=resnet_skip_time_act,
574
+ resnet_out_scale_factor=resnet_out_scale_factor,
575
+ cross_attention_norm=cross_attention_norm,
576
+ attention_head_dim=attention_head_dim[i]
577
+ if attention_head_dim[i] is not None
578
+ else output_channel,
579
+ dropout=dropout,
580
+ )
581
+ self.up_blocks.append(up_block)
582
+ prev_output_channel = output_channel
583
+
584
+ # out
585
+ if norm_num_groups is not None:
586
+ self.conv_norm_out = nn.GroupNorm(
587
+ num_channels=block_out_channels[0],
588
+ num_groups=norm_num_groups,
589
+ eps=norm_eps,
590
+ )
591
+
592
+ self.conv_act = get_activation(act_fn)
593
+
594
+ else:
595
+ self.conv_norm_out = None
596
+ self.conv_act = None
597
+
598
+ conv_out_padding = (conv_out_kernel - 1) // 2
599
+ self.conv_out = nn.Conv2d(
600
+ block_out_channels[0],
601
+ out_channels,
602
+ kernel_size=conv_out_kernel,
603
+ padding=conv_out_padding,
604
+ )
605
+
606
+ if attention_type in ["gated", "gated-text-image"]:
607
+ positive_len = 768
608
+ if isinstance(cross_attention_dim, int):
609
+ positive_len = cross_attention_dim
610
+ elif isinstance(cross_attention_dim, tuple) or isinstance(
611
+ cross_attention_dim, list
612
+ ):
613
+ positive_len = cross_attention_dim[0]
614
+
615
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
616
+ self.position_net = PositionNet(
617
+ positive_len=positive_len,
618
+ out_dim=cross_attention_dim,
619
+ feature_type=feature_type,
620
+ )
621
+ self.need_block_embs = need_block_embs
622
+ self.need_self_attn_block_embs = need_self_attn_block_embs
623
+
624
+ # only keep the referencenet layers that are used; set the other layers to None
625
+ self.conv_norm_out = None
626
+ self.conv_act = None
627
+ self.conv_out = None
628
+
629
+ self.up_blocks[-1].attentions[-1].proj_out = None
630
+ self.up_blocks[-1].attentions[-1].transformer_blocks[-1].attn1 = None
631
+ self.up_blocks[-1].attentions[-1].transformer_blocks[-1].attn2 = None
632
+ self.up_blocks[-1].attentions[-1].transformer_blocks[-1].norm2 = None
633
+ self.up_blocks[-1].attentions[-1].transformer_blocks[-1].ff = None
634
+ self.up_blocks[-1].attentions[-1].transformer_blocks[-1].norm3 = None
635
+ if not self.need_self_attn_block_embs:
636
+ self.up_blocks = None
637
+
638
+ self.insert_spatial_self_attn_idx()
639
+
640
+ def forward(
641
+ self,
642
+ sample: torch.FloatTensor,
643
+ timestep: Union[torch.Tensor, float, int],
644
+ encoder_hidden_states: torch.Tensor,
645
+ class_labels: Optional[torch.Tensor] = None,
646
+ timestep_cond: Optional[torch.Tensor] = None,
647
+ attention_mask: Optional[torch.Tensor] = None,
648
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
649
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
650
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
651
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
652
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
653
+ encoder_attention_mask: Optional[torch.Tensor] = None,
654
+ return_dict: bool = True,
655
+ # newly added parameters start
656
+ num_frames: int = None,
657
+ return_ndim: int = 5,
658
+ # newly added parameters end
659
+ ) -> Union[UNet2DConditionOutput, Tuple]:
660
+ r"""
661
+ The [`UNet2DConditionModel`] forward method.
662
+
663
+ Args:
664
+ sample (`torch.FloatTensor`):
665
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
666
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
667
+ encoder_hidden_states (`torch.FloatTensor`):
668
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
669
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
670
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
671
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
672
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
673
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
674
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
675
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
676
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
677
+ negative values to the attention scores corresponding to "discard" tokens.
678
+ cross_attention_kwargs (`dict`, *optional*):
679
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
680
+ `self.processor` in
681
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
682
+ added_cond_kwargs: (`dict`, *optional*):
683
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
684
+ are passed along to the UNet blocks.
685
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
686
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
687
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
688
+ A tensor that if specified is added to the residual of the middle unet block.
689
+ encoder_attention_mask (`torch.Tensor`):
690
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
691
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
692
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
693
+ return_dict (`bool`, *optional*, defaults to `True`):
694
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
695
+ tuple.
696
+ cross_attention_kwargs (`dict`, *optional*):
697
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
698
+ added_cond_kwargs: (`dict`, *optional*):
699
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
700
+ are passed along to the UNet blocks.
701
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
702
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
703
+ example from ControlNet side model(s)
704
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
705
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
706
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
707
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
708
+
709
+ Returns:
710
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
711
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
712
+ a `tuple` is returned where the first element is the sample tensor.
713
+ """
714
+
715
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
716
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
717
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
718
+ # on the fly if necessary.
719
+ default_overall_up_factor = 2**self.num_upsamplers
720
+
721
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
722
+ forward_upsample_size = False
723
+ upsample_size = None
724
+
725
+ for dim in sample.shape[-2:]:
726
+ if dim % default_overall_up_factor != 0:
727
+ # Forward upsample size to force interpolation output size.
728
+ forward_upsample_size = True
729
+ break
730
+
731
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
732
+ # expects mask of shape:
733
+ # [batch, key_tokens]
734
+ # adds singleton query_tokens dimension:
735
+ # [batch, 1, key_tokens]
736
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
737
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
738
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
739
+ if attention_mask is not None:
740
+ # assume that mask is expressed as:
741
+ # (1 = keep, 0 = discard)
742
+ # convert mask into a bias that can be added to attention scores:
743
+ # (keep = +0, discard = -10000.0)
744
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
745
+ attention_mask = attention_mask.unsqueeze(1)
746
+
747
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
748
+ if encoder_attention_mask is not None:
749
+ encoder_attention_mask = (
750
+ 1 - encoder_attention_mask.to(sample.dtype)
751
+ ) * -10000.0
752
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
753
+
754
+ # 0. center input if necessary
755
+ if self.config.center_input_sample:
756
+ sample = 2 * sample - 1.0
757
+
758
+ # 1. time
759
+ timesteps = timestep
760
+ if not torch.is_tensor(timesteps):
761
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
762
+ # This would be a good case for the `match` statement (Python 3.10+)
763
+ is_mps = sample.device.type == "mps"
764
+ if isinstance(timestep, float):
765
+ dtype = torch.float32 if is_mps else torch.float64
766
+ else:
767
+ dtype = torch.int32 if is_mps else torch.int64
768
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
769
+ elif len(timesteps.shape) == 0:
770
+ timesteps = timesteps[None].to(sample.device)
771
+
772
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
773
+ timesteps = timesteps.expand(sample.shape[0])
774
+
775
+ t_emb = self.time_proj(timesteps)
776
+
777
+ # `Timesteps` does not contain any weights and will always return f32 tensors
778
+ # but time_embedding might actually be running in fp16. so we need to cast here.
779
+ # there might be better ways to encapsulate this.
780
+ t_emb = t_emb.to(dtype=sample.dtype)
781
+
782
+ emb = self.time_embedding(t_emb, timestep_cond)
783
+ aug_emb = None
784
+
785
+ if self.class_embedding is not None:
786
+ if class_labels is None:
787
+ raise ValueError(
788
+ "class_labels should be provided when num_class_embeds > 0"
789
+ )
790
+
791
+ if self.config.class_embed_type == "timestep":
792
+ class_labels = self.time_proj(class_labels)
793
+
794
+ # `Timesteps` does not contain any weights and will always return f32 tensors
795
+ # there might be better ways to encapsulate this.
796
+ class_labels = class_labels.to(dtype=sample.dtype)
797
+
798
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
799
+
800
+ if self.config.class_embeddings_concat:
801
+ emb = torch.cat([emb, class_emb], dim=-1)
802
+ else:
803
+ emb = emb + class_emb
804
+
805
+ if self.config.addition_embed_type == "text":
806
+ aug_emb = self.add_embedding(encoder_hidden_states)
807
+ elif self.config.addition_embed_type == "text_image":
808
+ # Kandinsky 2.1 - style
809
+ if "image_embeds" not in added_cond_kwargs:
810
+ raise ValueError(
811
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
812
+ )
813
+
814
+ image_embs = added_cond_kwargs.get("image_embeds")
815
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
816
+ aug_emb = self.add_embedding(text_embs, image_embs)
817
+ elif self.config.addition_embed_type == "text_time":
818
+ # SDXL - style
819
+ if "text_embeds" not in added_cond_kwargs:
820
+ raise ValueError(
821
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
822
+ )
823
+ text_embeds = added_cond_kwargs.get("text_embeds")
824
+ if "time_ids" not in added_cond_kwargs:
825
+ raise ValueError(
826
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
827
+ )
828
+ time_ids = added_cond_kwargs.get("time_ids")
829
+ time_embeds = self.add_time_proj(time_ids.flatten())
830
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
831
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
832
+ add_embeds = add_embeds.to(emb.dtype)
833
+ aug_emb = self.add_embedding(add_embeds)
834
+ elif self.config.addition_embed_type == "image":
835
+ # Kandinsky 2.2 - style
836
+ if "image_embeds" not in added_cond_kwargs:
837
+ raise ValueError(
838
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
839
+ )
840
+ image_embs = added_cond_kwargs.get("image_embeds")
841
+ aug_emb = self.add_embedding(image_embs)
842
+ elif self.config.addition_embed_type == "image_hint":
843
+ # Kandinsky 2.2 - style
844
+ if (
845
+ "image_embeds" not in added_cond_kwargs
846
+ or "hint" not in added_cond_kwargs
847
+ ):
848
+ raise ValueError(
849
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
850
+ )
851
+ image_embs = added_cond_kwargs.get("image_embeds")
852
+ hint = added_cond_kwargs.get("hint")
853
+ aug_emb, hint = self.add_embedding(image_embs, hint)
854
+ sample = torch.cat([sample, hint], dim=1)
855
+
856
+ emb = emb + aug_emb if aug_emb is not None else emb
857
+
858
+ if self.time_embed_act is not None:
859
+ emb = self.time_embed_act(emb)
860
+
861
+ if (
862
+ self.encoder_hid_proj is not None
863
+ and self.config.encoder_hid_dim_type == "text_proj"
864
+ ):
865
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
866
+ elif (
867
+ self.encoder_hid_proj is not None
868
+ and self.config.encoder_hid_dim_type == "text_image_proj"
869
+ ):
870
+ # Kandinsky 2.1 - style
871
+ if "image_embeds" not in added_cond_kwargs:
872
+ raise ValueError(
873
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
874
+ )
875
+
876
+ image_embeds = added_cond_kwargs.get("image_embeds")
877
+ encoder_hidden_states = self.encoder_hid_proj(
878
+ encoder_hidden_states, image_embeds
879
+ )
880
+ elif (
881
+ self.encoder_hid_proj is not None
882
+ and self.config.encoder_hid_dim_type == "image_proj"
883
+ ):
884
+ # Kandinsky 2.2 - style
885
+ if "image_embeds" not in added_cond_kwargs:
886
+ raise ValueError(
887
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
888
+ )
889
+ image_embeds = added_cond_kwargs.get("image_embeds")
890
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
891
+ elif (
892
+ self.encoder_hid_proj is not None
893
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
894
+ ):
895
+ if "image_embeds" not in added_cond_kwargs:
896
+ raise ValueError(
897
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
898
+ )
899
+ image_embeds = added_cond_kwargs.get("image_embeds")
900
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
901
+ encoder_hidden_states.dtype
902
+ )
903
+ encoder_hidden_states = torch.cat(
904
+ [encoder_hidden_states, image_embeds], dim=1
905
+ )
906
+
907
+ # need_self_attn_block_embs
908
+ # initialize; during the unet forward pass, self_attn_block_embs keep getting appended and must be cleared after use
910
+ if self.need_self_attn_block_embs:
911
+ self_attn_block_embs = [None] * self.self_attn_num
912
+ else:
913
+ self_attn_block_embs = None
914
+ # 2. pre-process
915
+ sample = self.conv_in(sample)
916
+ if self.print_idx == 0:
917
+ logger.debug(f"after conv in sample={sample.mean()}")
918
+ # 2.5 GLIGEN position net
919
+ if (
920
+ cross_attention_kwargs is not None
921
+ and cross_attention_kwargs.get("gligen", None) is not None
922
+ ):
923
+ cross_attention_kwargs = cross_attention_kwargs.copy()
924
+ gligen_args = cross_attention_kwargs.pop("gligen")
925
+ cross_attention_kwargs["gligen"] = {
926
+ "objs": self.position_net(**gligen_args)
927
+ }
928
+
929
+ # 3. down
930
+ lora_scale = (
931
+ cross_attention_kwargs.get("scale", 1.0)
932
+ if cross_attention_kwargs is not None
933
+ else 1.0
934
+ )
935
+ if USE_PEFT_BACKEND:
936
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
937
+ scale_lora_layers(self, lora_scale)
938
+
939
+ is_controlnet = (
940
+ mid_block_additional_residual is not None
941
+ and down_block_additional_residuals is not None
942
+ )
943
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
944
+ is_adapter = down_intrablock_additional_residuals is not None
945
+ # maintain backward compatibility for legacy usage, where
946
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
947
+ # but can only use one or the other
948
+ if (
949
+ not is_adapter
950
+ and mid_block_additional_residual is None
951
+ and down_block_additional_residuals is not None
952
+ ):
953
+ deprecate(
954
+ "T2I should not use down_block_additional_residuals",
955
+ "1.3.0",
956
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
957
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
958
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
959
+ standard_warn=False,
960
+ )
961
+ down_intrablock_additional_residuals = down_block_additional_residuals
962
+ is_adapter = True
963
+
964
+ down_block_res_samples = (sample,)
965
+ for i_downsample_block, downsample_block in enumerate(self.down_blocks):
966
+ if (
967
+ hasattr(downsample_block, "has_cross_attention")
968
+ and downsample_block.has_cross_attention
969
+ ):
970
+ # For t2i-adapter CrossAttnDownBlock2D
971
+ additional_residuals = {}
972
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
973
+ additional_residuals[
974
+ "additional_residuals"
975
+ ] = down_intrablock_additional_residuals.pop(0)
976
+ if self.print_idx == 0:
977
+ logger.debug(
978
+ f"downsample_block {i_downsample_block} sample={sample.mean()}"
979
+ )
980
+ sample, res_samples = downsample_block(
981
+ hidden_states=sample,
982
+ temb=emb,
983
+ encoder_hidden_states=encoder_hidden_states,
984
+ attention_mask=attention_mask,
985
+ cross_attention_kwargs=cross_attention_kwargs,
986
+ encoder_attention_mask=encoder_attention_mask,
987
+ **additional_residuals,
988
+ self_attn_block_embs=self_attn_block_embs,
989
+ )
990
+ else:
991
+ sample, res_samples = downsample_block(
992
+ hidden_states=sample,
993
+ temb=emb,
994
+ scale=lora_scale,
995
+ self_attn_block_embs=self_attn_block_embs,
996
+ )
997
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
998
+ sample += down_intrablock_additional_residuals.pop(0)
999
+
1000
+ down_block_res_samples += res_samples
1001
+
1002
+ if is_controlnet:
1003
+ new_down_block_res_samples = ()
1004
+
1005
+ for down_block_res_sample, down_block_additional_residual in zip(
1006
+ down_block_res_samples, down_block_additional_residuals
1007
+ ):
1008
+ down_block_res_sample = (
1009
+ down_block_res_sample + down_block_additional_residual
1010
+ )
1011
+ new_down_block_res_samples = new_down_block_res_samples + (
1012
+ down_block_res_sample,
1013
+ )
1014
+
1015
+ down_block_res_samples = new_down_block_res_samples
1016
+
1017
+ # update code start
1018
+ def reshape_return_emb(tmp_emb):
1019
+ if return_ndim == 4:
1020
+ return tmp_emb
1021
+ elif return_ndim == 5:
1022
+ return rearrange(tmp_emb, "(b t) c h w-> b c t h w", t=num_frames)
1023
+ else:
1024
+ raise ValueError(
1025
+ f"reshape_emb only support 4, 5 but given {return_ndim}"
1026
+ )
1027
+
1028
+ if self.need_block_embs:
1029
+ return_down_block_res_samples = [
1030
+ reshape_return_emb(tmp_emb) for tmp_emb in down_block_res_samples
1031
+ ]
1032
+ else:
1033
+ return_down_block_res_samples = None
1034
+ # update code end
1035
+
1036
+ # 4. mid
1037
+ if self.mid_block is not None:
1038
+ if (
1039
+ hasattr(self.mid_block, "has_cross_attention")
1040
+ and self.mid_block.has_cross_attention
1041
+ ):
1042
+ sample = self.mid_block(
1043
+ sample,
1044
+ emb,
1045
+ encoder_hidden_states=encoder_hidden_states,
1046
+ attention_mask=attention_mask,
1047
+ cross_attention_kwargs=cross_attention_kwargs,
1048
+ encoder_attention_mask=encoder_attention_mask,
1049
+ self_attn_block_embs=self_attn_block_embs,
1050
+ )
1051
+ else:
1052
+ sample = self.mid_block(sample, emb)
1053
+
1054
+ # To support T2I-Adapter-XL
1055
+ if (
1056
+ is_adapter
1057
+ and len(down_intrablock_additional_residuals) > 0
1058
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1059
+ ):
1060
+ sample += down_intrablock_additional_residuals.pop(0)
1061
+
1062
+ if is_controlnet:
1063
+ sample = sample + mid_block_additional_residual
1064
+
1065
+ if self.need_block_embs:
1066
+ return_mid_block_res_samples = reshape_return_emb(sample)
1067
+ logger.debug(
1068
+ f"return_mid_block_res_samples, is_leaf={return_mid_block_res_samples.is_leaf}, requires_grad={return_mid_block_res_samples.requires_grad}"
1069
+ )
1070
+ else:
1071
+ return_mid_block_res_samples = None
1072
+
1073
+ if self.up_blocks is not None:
1074
+ # update code end
1075
+
1076
+ # 5. up
1077
+ for i, upsample_block in enumerate(self.up_blocks):
1078
+ is_final_block = i == len(self.up_blocks) - 1
1079
+
1080
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1081
+ down_block_res_samples = down_block_res_samples[
1082
+ : -len(upsample_block.resnets)
1083
+ ]
1084
+
1085
+ # if we have not reached the final block and need to forward the
1086
+ # upsample size, we do it here
1087
+ if not is_final_block and forward_upsample_size:
1088
+ upsample_size = down_block_res_samples[-1].shape[2:]
1089
+
1090
+ if (
1091
+ hasattr(upsample_block, "has_cross_attention")
1092
+ and upsample_block.has_cross_attention
1093
+ ):
1094
+ sample = upsample_block(
1095
+ hidden_states=sample,
1096
+ temb=emb,
1097
+ res_hidden_states_tuple=res_samples,
1098
+ encoder_hidden_states=encoder_hidden_states,
1099
+ cross_attention_kwargs=cross_attention_kwargs,
1100
+ upsample_size=upsample_size,
1101
+ attention_mask=attention_mask,
1102
+ encoder_attention_mask=encoder_attention_mask,
1103
+ self_attn_block_embs=self_attn_block_embs,
1104
+ )
1105
+ else:
1106
+ sample = upsample_block(
1107
+ hidden_states=sample,
1108
+ temb=emb,
1109
+ res_hidden_states_tuple=res_samples,
1110
+ upsample_size=upsample_size,
1111
+ scale=lora_scale,
1112
+ self_attn_block_embs=self_attn_block_embs,
1113
+ )
1114
+
1115
+ # update code start
1116
+ if self.need_block_embs or self.need_self_attn_block_embs:
1117
+ if self_attn_block_embs is not None:
1118
+ self_attn_block_embs = [
1119
+ reshape_return_emb(tmp_emb=tmp_emb)
1120
+ for tmp_emb in self_attn_block_embs
1121
+ ]
1122
+ self.print_idx += 1
1123
+ return (
1124
+ return_down_block_res_samples,
1125
+ return_mid_block_res_samples,
1126
+ self_attn_block_embs,
1127
+ )
1128
+
1129
+ if not self.need_block_embs and not self.need_self_attn_block_embs:
1130
+ # 6. post-process
1131
+ if self.conv_norm_out:
1132
+ sample = self.conv_norm_out(sample)
1133
+ sample = self.conv_act(sample)
1134
+ sample = self.conv_out(sample)
1135
+
1136
+ if USE_PEFT_BACKEND:
1137
+ # remove `lora_scale` from each PEFT layer
1138
+ unscale_lora_layers(self, lora_scale)
1139
+ self.print_idx += 1
1140
+ if not return_dict:
1141
+ return (sample,)
1142
+
1143
+ return UNet2DConditionOutput(sample=sample)
1144
+
1145
+ def insert_spatial_self_attn_idx(self):
1146
+ attns, basic_transformers = self.spatial_self_attns
1147
+ self.self_attn_num = len(attns)
1148
+ for i, (name, layer) in enumerate(attns):
1149
+ logger.debug(f"{self.__class__.__name__}, {i}, {name}, {type(layer)}")
1150
+ if layer is not None:
1151
+ layer.spatial_self_attn_idx = i
1152
+ for i, (name, layer) in enumerate(basic_transformers):
1153
+ logger.debug(f"{self.__class__.__name__}, {i}, {name}, {type(layer)}")
1154
+ if layer is not None:
1155
+ layer.spatial_self_attn_idx = i
1156
+
1157
+ @property
1158
+ def spatial_self_attns(
1159
+ self,
1160
+ ) -> List[Tuple[str, Attention]]:
1161
+ attns, spatial_transformers = self.get_self_attns(
1162
+ include="attentions", exclude="temp_attentions"
1163
+ )
1164
+ attns = sorted(attns)
1165
+ spatial_transformers = sorted(spatial_transformers)
1166
+ return attns, spatial_transformers
1167
+
1168
+ def get_self_attns(
1169
+ self, include: str = None, exclude: str = None
1170
+ ) -> List[Tuple[str, Attention]]:
1171
+ r"""
1172
+ Returns:
1173
+ a tuple of two lists: (name, Attention) pairs for the spatial self-attention layers, and
1174
+ (name, BasicTransformerBlock) pairs for the transformer blocks that contain them, indexed by module name.
1175
+ """
1176
+ # set recursively
1177
+ attns = []
1178
+ spatial_transformers = []
1179
+
1180
+ def fn_recursive_add_attns(
1181
+ name: str,
1182
+ module: torch.nn.Module,
1183
+ attns: List[Tuple[str, Attention]],
1184
+ spatial_transformers: List[Tuple[str, BasicTransformerBlock]],
1185
+ ):
1186
+ is_target = False
1187
+ if isinstance(module, BasicTransformerBlock) and hasattr(module, "attn1"):
1188
+ is_target = True
1189
+ if include is not None:
1190
+ is_target = include in name
1191
+ if exclude is not None:
1192
+ is_target = exclude not in name
1193
+ if is_target:
1194
+ attns.append([f"{name}.attn1", module.attn1])
1195
+ spatial_transformers.append([f"{name}", module])
1196
+ for sub_name, child in module.named_children():
1197
+ fn_recursive_add_attns(
1198
+ f"{name}.{sub_name}", child, attns, spatial_transformers
1199
+ )
1200
+
1201
+ return attns
1202
+
1203
+ for name, module in self.named_children():
1204
+ fn_recursive_add_attns(name, module, attns, spatial_transformers)
1205
+
1206
+ return attns, spatial_transformers
1207
+
1208
+
1209
+ class ReferenceNet3D(UNet3DConditionModel):
1210
+ """继承 UNet3DConditionModel, 用于提取中间emb用于后续作用。
1211
+ Inherit Unet3DConditionModel, used to extract the middle emb for subsequent actions.
1212
+ Args:
1213
+ UNet3DConditionModel (_type_): _description_
1214
+ """
1215
+
1216
+ pass
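
A minimal standalone sketch of the attention-mask handling used in the forward pass above (plain PyTorch, illustrative sizes; not code from this repository): masks given as 1 = keep / 0 = discard are turned into additive biases with a singleton query-token axis before they reach the attention layers.

    import torch

    # mask convention used by the forward pass: 1 = keep, 0 = discard
    attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.float32)  # (batch, key_tokens)

    # keep -> 0.0, discard -> -10000.0, then add a singleton query_tokens dimension
    bias = (1 - attention_mask) * -10000.0
    bias = bias.unsqueeze(1)  # (batch, 1, key_tokens), broadcastable over attention scores

    print(bias)  # tensor([[[-0., -0., -0., -10000.]]])
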
musev/models/referencenet_loader.py ADDED
@@ -0,0 +1,124 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, Iterable, Union
3
+ import PIL
4
+ import cv2
5
+ import torch
6
+ import argparse
7
+ import datetime
8
+ import logging
9
+ import inspect
10
+ import math
11
+ import os
12
+ import shutil
13
+ from typing import Dict, List, Optional, Tuple
14
+ from pprint import pprint
15
+ from collections import OrderedDict
16
+ from dataclasses import dataclass
17
+ import gc
18
+ import time
19
+
20
+ import numpy as np
21
+ from omegaconf import OmegaConf
22
+ from omegaconf import SCMode
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from einops import rearrange, repeat
28
+ import pandas as pd
29
+ import h5py
30
+ from diffusers.models.modeling_utils import load_state_dict
31
+ from diffusers.utils import (
32
+ logging,
33
+ )
34
+ from diffusers.utils.import_utils import is_xformers_available
35
+
36
+ from .referencenet import ReferenceNet2D
37
+ from .unet_loader import update_unet_with_sd
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ def load_referencenet(
44
+ sd_referencenet_model: Union[str, nn.Module],
45
+ sd_model: nn.Module = None,
46
+ need_self_attn_block_embs: bool = False,
47
+ need_block_embs: bool = False,
48
+ dtype: torch.dtype = torch.float16,
49
+ cross_attention_dim: int = 768,
50
+ subfolder: str = "unet",
51
+ ):
52
+ """
53
+ Loads the ReferenceNet model.
54
+
55
+ Args:
56
+ sd_referencenet_model (Union[str, nn.Module]): The pretrained ReferenceNet model or the path to the model.
57
+ sd_model (nn.Module, optional): The sd_model to update the ReferenceNet with. Defaults to None.
58
+ need_self_attn_block_embs (bool, optional): Whether to compute self-attention block embeddings. Defaults to False.
59
+ need_block_embs (bool, optional): Whether to compute block embeddings. Defaults to False.
60
+ dtype (torch.dtype, optional): The data type of the tensors. Defaults to torch.float16.
61
+ cross_attention_dim (int, optional): The dimension of the cross-attention. Defaults to 768.
62
+ subfolder (str, optional): The subfolder of the model. Defaults to "unet".
63
+
64
+ Returns:
65
+ nn.Module: The loaded ReferenceNet model.
66
+ """
67
+
68
+ if isinstance(sd_referencenet_model, str):
69
+ referencenet = ReferenceNet2D.from_pretrained(
70
+ sd_referencenet_model,
71
+ subfolder=subfolder,
72
+ need_self_attn_block_embs=need_self_attn_block_embs,
73
+ need_block_embs=need_block_embs,
74
+ torch_dtype=dtype,
75
+ cross_attention_dim=cross_attention_dim,
76
+ )
77
+ elif isinstance(sd_referencenet_model, nn.Module):
78
+ referencenet = sd_referencenet_model
79
+ if sd_model is not None:
80
+ referencenet = update_unet_with_sd(referencenet, sd_model)
81
+ return referencenet
82
+
83
+
84
+ def load_referencenet_by_name(
85
+ model_name: str,
86
+ sd_referencenet_model: Union[str, nn.Module],
87
+ sd_model: nn.Module = None,
88
+ cross_attention_dim: int = 768,
89
+ dtype: torch.dtype = torch.float16,
90
+ ) -> nn.Module:
91
+ """通过模型名字 初始化 referencenet,载入预训练参数,
92
+ 如希望后续通过简单名字就可以使用预训练模型,需要在这里完成定义
93
+ init referencenet with model_name.
94
+ if you want to use pretrained model with simple name, you need to define it here.
95
+ Args:
96
+ model_name (str): _description_
97
+ sd_unet_model (Tuple[str, nn.Module]): _description_
98
+ sd_model (Tuple[str, nn.Module]): _description_
99
+ cross_attention_dim (int, optional): _description_. Defaults to 768.
100
+ dtype (torch.dtype, optional): _description_. Defaults to torch.float16.
101
+
102
+ Raises:
103
+ ValueError: _description_
104
+
105
+ Returns:
106
+ nn.Module: _description_
107
+ """
108
+ if model_name in [
109
+ "musev_referencenet",
110
+ ]:
111
+ unet = load_referencenet(
112
+ sd_referencenet_model=sd_referencenet_model,
113
+ sd_model=sd_model,
114
+ cross_attention_dim=cross_attention_dim,
115
+ dtype=dtype,
116
+ need_self_attn_block_embs=False,
117
+ need_block_embs=True,
118
+ subfolder="referencenet",
119
+ )
120
+ else:
121
+ raise ValueError(
122
+ f"unsupport model_name={model_name}, only support ReferenceNet_V0_block13, ReferenceNet_V1_block13, ReferenceNet_V2_block13, ReferenceNet_V0_sefattn16"
123
+ )
124
+ return unet
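
A minimal usage sketch for the loader above; the checkpoint directory is hypothetical and must contain a `referencenet` subfolder with the MuseV ReferenceNet weights:

    import torch
    from musev.models.referencenet_loader import load_referencenet_by_name

    referencenet = load_referencenet_by_name(
        model_name="musev_referencenet",
        sd_referencenet_model="./checkpoints/musev_referencenet",  # hypothetical local path
        cross_attention_dim=768,
        dtype=torch.float16,
    )
    referencenet = referencenet.to("cuda").eval()  # fp16 weights, so move to GPU before use
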
musev/models/resnet.py ADDED
@@ -0,0 +1,135 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py
17
+ from __future__ import annotations
18
+
19
+ from functools import partial
20
+ from typing import Optional
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from einops import rearrange, repeat
26
+
27
+ from diffusers.models.resnet import TemporalConvLayer as DiffusersTemporalConvLayer
28
+ from ..data.data_util import batch_index_fill, batch_index_select
29
+ from . import Model_Register
30
+
31
+
32
+ @Model_Register.register
33
+ class TemporalConvLayer(nn.Module):
34
+ """
35
+ Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
36
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ in_dim,
42
+ out_dim=None,
43
+ dropout=0.0,
44
+ keep_content_condition: bool = False,
45
+ femb_channels: Optional[int] = None,
46
+ need_temporal_weight: bool = True,
47
+ ):
48
+ super().__init__()
49
+ out_dim = out_dim or in_dim
50
+ self.in_dim = in_dim
51
+ self.out_dim = out_dim
52
+ self.keep_content_condition = keep_content_condition
53
+ self.femb_channels = femb_channels
54
+ self.need_temporal_weight = need_temporal_weight
55
+ # conv layers
56
+ self.conv1 = nn.Sequential(
57
+ nn.GroupNorm(32, in_dim),
58
+ nn.SiLU(),
59
+ nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
60
+ )
61
+ self.conv2 = nn.Sequential(
62
+ nn.GroupNorm(32, out_dim),
63
+ nn.SiLU(),
64
+ nn.Dropout(dropout),
65
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
66
+ )
67
+ self.conv3 = nn.Sequential(
68
+ nn.GroupNorm(32, out_dim),
69
+ nn.SiLU(),
70
+ nn.Dropout(dropout),
71
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
72
+ )
73
+ self.conv4 = nn.Sequential(
74
+ nn.GroupNorm(32, out_dim),
75
+ nn.SiLU(),
76
+ nn.Dropout(dropout),
77
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
78
+ )
79
+
80
+ # zero out the last layer params,so the conv block is identity
81
+ # nn.init.zeros_(self.conv4[-1].weight)
82
+ # nn.init.zeros_(self.conv4[-1].bias)
83
+ self.temporal_weight = nn.Parameter(
84
+ torch.tensor(
85
+ [
86
+ 1e-5,
87
+ ]
88
+ )
89
+ ) # initialize parameter with 0
90
+ # zero out the last layer params,so the conv block is identity
91
+ nn.init.zeros_(self.conv4[-1].weight)
92
+ nn.init.zeros_(self.conv4[-1].bias)
93
+ self.skip_temporal_layers = False # Whether to skip temporal layer
94
+
95
+ def forward(
96
+ self,
97
+ hidden_states,
98
+ num_frames=1,
99
+ sample_index: torch.LongTensor = None,
100
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
101
+ femb: torch.Tensor = None,
102
+ ):
103
+ if self.skip_temporal_layers is True:
104
+ return hidden_states
105
+ hidden_states_dtype = hidden_states.dtype
106
+ hidden_states = rearrange(
107
+ hidden_states, "(b t) c h w -> b c t h w", t=num_frames
108
+ )
109
+ identity = hidden_states
110
+ hidden_states = self.conv1(hidden_states)
111
+ hidden_states = self.conv2(hidden_states)
112
+ hidden_states = self.conv3(hidden_states)
113
+ hidden_states = self.conv4(hidden_states)
114
+ # keep the frames that correspond to the condition, so preceding content frames are preserved and consistency improves
115
+ if self.keep_content_condition:
116
+ mask = torch.ones_like(hidden_states, device=hidden_states.device)
117
+ mask = batch_index_fill(
118
+ mask, dim=2, index=vision_conditon_frames_sample_index, value=0
119
+ )
120
+ if self.need_temporal_weight:
121
+ hidden_states = (
122
+ identity + torch.abs(self.temporal_weight) * mask * hidden_states
123
+ )
124
+ else:
125
+ hidden_states = identity + mask * hidden_states
126
+ else:
127
+ if self.need_temporal_weight:
128
+ hidden_states = (
129
+ identity + torch.abs(self.temporal_weight) * hidden_states
130
+ )
131
+ else:
132
+ hidden_states = identity + hidden_states
133
+ hidden_states = rearrange(hidden_states, " b c t h w -> (b t) c h w")
134
+ hidden_states = hidden_states.to(dtype=hidden_states_dtype)
135
+ return hidden_states
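
TemporalConvLayer above works on features laid out as `(b t) c h w` and temporarily folds the frame axis out of the batch dimension so the Conv3d stack can mix information across time. A standalone sketch of just that reshaping (illustrative sizes, assuming torch and einops are installed):

    import torch
    from einops import rearrange

    b, t, c, h, w = 2, 8, 4, 16, 16
    x = torch.randn(b * t, c, h, w)                      # (b t) c h w, the per-frame layout the 2D UNet uses
    x5 = rearrange(x, "(b t) c h w -> b c t h w", t=t)   # expose the temporal axis for Conv3d
    assert x5.shape == (b, c, t, h, w)
    x4 = rearrange(x5, "b c t h w -> (b t) c h w")       # fold frames back into the batch dimension
    assert torch.equal(x4, x)
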
musev/models/super_model.py ADDED
@@ -0,0 +1,253 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+
5
+ from typing import Any, Dict, Tuple, Union, Optional
6
+ from einops import rearrange, repeat
7
+ from torch import nn
8
+ import torch
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict
12
+
13
+ from ..data.data_util import align_repeat_tensor_single_dim
14
+
15
+ from .unet_3d_condition import UNet3DConditionModel
16
+ from .referencenet import ReferenceNet2D
17
+ from ip_adapter.ip_adapter import ImageProjModel
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class SuperUNet3DConditionModel(nn.Module):
23
+ """封装了各种子模型的超模型,与 diffusers 的 pipeline 很像,只不过这里是模型定义。
24
+ 主要作用
25
+ 1. 将支持controlnet、referencenet等功能的计算封装起来,简洁些;
26
+ 2. 便于 accelerator 的分布式训练;
27
+
28
+ wrap the sub-models, such as unet, referencenet, controlnet, vae, text_encoder, tokenizer, text_emb_extractor, clip_vision_extractor, ip_adapter_image_proj
29
+ 1. support controlnet, referencenet, etc.
30
+ 2. support accelerator distributed training
31
+ """
32
+
33
+ _supports_gradient_checkpointing = True
34
+ print_idx = 0
35
+
36
+ # @register_to_config
37
+ def __init__(
38
+ self,
39
+ unet: nn.Module,
40
+ referencenet: nn.Module = None,
41
+ controlnet: nn.Module = None,
42
+ vae: nn.Module = None,
43
+ text_encoder: nn.Module = None,
44
+ tokenizer: nn.Module = None,
45
+ text_emb_extractor: nn.Module = None,
46
+ clip_vision_extractor: nn.Module = None,
47
+ ip_adapter_image_proj: nn.Module = None,
48
+ ) -> None:
49
+ """_summary_
50
+
51
+ Args:
52
+ unet (nn.Module): _description_
53
+ referencenet (nn.Module, optional): _description_. Defaults to None.
54
+ controlnet (nn.Module, optional): _description_. Defaults to None.
55
+ vae (nn.Module, optional): _description_. Defaults to None.
56
+ text_encoder (nn.Module, optional): _description_. Defaults to None.
57
+ tokenizer (nn.Module, optional): _description_. Defaults to None.
58
+ text_emb_extractor (nn.Module, optional): wrap text_encoder and tokenizer for str2emb. Defaults to None.
59
+ clip_vision_extractor (nn.Module, optional): _description_. Defaults to None.
60
+ """
61
+ super().__init__()
62
+ self.unet = unet
63
+ self.referencenet = referencenet
64
+ self.controlnet = controlnet
65
+ self.vae = vae
66
+ self.text_encoder = text_encoder
67
+ self.tokenizer = tokenizer
68
+ self.text_emb_extractor = text_emb_extractor
69
+ self.clip_vision_extractor = clip_vision_extractor
70
+ self.ip_adapter_image_proj = ip_adapter_image_proj
71
+
72
+ def forward(
73
+ self,
74
+ unet_params: Dict,
75
+ encoder_hidden_states: torch.Tensor,
76
+ referencenet_params: Dict = None,
77
+ controlnet_params: Dict = None,
78
+ controlnet_scale: float = 1.0,
79
+ vision_clip_emb: Union[torch.Tensor, None] = None,
80
+ prompt_only_use_image_prompt: bool = False,
81
+ ):
82
+ """_summary_
83
+
84
+ Args:
85
+ unet_params (Dict): _description_
86
+ encoder_hidden_states (torch.Tensor): b t n d
87
+ referencenet_params (Dict, optional): _description_. Defaults to None.
88
+ controlnet_params (Dict, optional): _description_. Defaults to None.
89
+ controlnet_scale (float, optional): _description_. Defaults to 1.0.
90
+ vision_clip_emb (Union[torch.Tensor, None], optional): b t d. Defaults to None.
91
+ prompt_only_use_image_prompt (bool, optional): _description_. Defaults to False.
92
+
93
+ Returns:
94
+ _type_: _description_
95
+ """
96
+ batch_size = unet_params["sample"].shape[0]
97
+ time_size = unet_params["sample"].shape[2]
98
+
99
+ # ip_adapter_cross_attn, prepare image prompt
100
+ if vision_clip_emb is not None:
101
+ # b t n d -> b t n d
102
+ if self.print_idx == 0:
103
+ logger.debug(
104
+ f"vision_clip_emb, before ip_adapter_image_proj, shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
105
+ )
106
+ if vision_clip_emb.ndim == 3:
107
+ vision_clip_emb = rearrange(vision_clip_emb, "b t d-> b t 1 d")
108
+ if self.ip_adapter_image_proj is not None:
109
+ vision_clip_emb = rearrange(vision_clip_emb, "b t n d ->(b t) n d")
110
+ vision_clip_emb = self.ip_adapter_image_proj(vision_clip_emb)
111
+ if self.print_idx == 0:
112
+ logger.debug(
113
+ f"vision_clip_emb, after ip_adapter_image_proj shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
114
+ )
115
+ if vision_clip_emb.ndim == 2:
116
+ vision_clip_emb = rearrange(vision_clip_emb, "b d-> b 1 d")
117
+ vision_clip_emb = rearrange(
118
+ vision_clip_emb, "(b t) n d -> b t n d", b=batch_size
119
+ )
120
+ vision_clip_emb = align_repeat_tensor_single_dim(
121
+ vision_clip_emb, target_length=time_size, dim=1
122
+ )
123
+ if self.print_idx == 0:
124
+ logger.debug(
125
+ f"vision_clip_emb, after reshape shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
126
+ )
127
+
128
+ if vision_clip_emb is None and encoder_hidden_states is not None:
129
+ vision_clip_emb = encoder_hidden_states
130
+ if vision_clip_emb is not None and encoder_hidden_states is None:
131
+ encoder_hidden_states = vision_clip_emb
132
+ # 当 prompt_only_use_image_prompt 为True时,
133
+ # 1. referencenet 都使用 vision_clip_emb
134
+ # 2. unet 如果没有dual_cross_attn,使用vision_clip_emb,有时不更新
135
+ # 3. controlnet 当前使用 text_prompt
136
+
137
+ # when prompt_only_use_image_prompt True,
138
+ # 1. referencenet use vision_clip_emb
139
+ # 2. unet use vision_clip_emb if no dual_cross_attn, sometimes not update
140
+ # 3. controlnet use text_prompt
141
+
142
+ # extract referencenet emb
143
+ if self.referencenet is not None and referencenet_params is not None:
144
+ referencenet_encoder_hidden_states = align_repeat_tensor_single_dim(
145
+ vision_clip_emb,
146
+ target_length=referencenet_params["num_frames"],
147
+ dim=1,
148
+ )
149
+ referencenet_params["encoder_hidden_states"] = rearrange(
150
+ referencenet_encoder_hidden_states, "b t n d->(b t) n d"
151
+ )
152
+ referencenet_out = self.referencenet(**referencenet_params)
153
+ (
154
+ down_block_refer_embs,
155
+ mid_block_refer_emb,
156
+ refer_self_attn_emb,
157
+ ) = referencenet_out
158
+ if down_block_refer_embs is not None:
159
+ if self.print_idx == 0:
160
+ logger.debug(
161
+ f"len(down_block_refer_embs)={len(down_block_refer_embs)}"
162
+ )
163
+ for i, down_emb in enumerate(down_block_refer_embs):
164
+ if self.print_idx == 0:
165
+ logger.debug(
166
+ f"down_emb, {i}, {down_emb.shape}, mean={down_emb.mean()}"
167
+ )
168
+ else:
169
+ if self.print_idx == 0:
170
+ logger.debug(f"down_block_refer_embs is None")
171
+ if mid_block_refer_emb is not None:
172
+ if self.print_idx == 0:
173
+ logger.debug(
174
+ f"mid_block_refer_emb, {mid_block_refer_emb.shape}, mean={mid_block_refer_emb.mean()}"
175
+ )
176
+ else:
177
+ if self.print_idx == 0:
178
+ logger.debug(f"mid_block_refer_emb is None")
179
+ if refer_self_attn_emb is not None:
180
+ if self.print_idx == 0:
181
+ logger.debug(f"refer_self_attn_emb, num={len(refer_self_attn_emb)}")
182
+ for i, self_attn_emb in enumerate(refer_self_attn_emb):
183
+ if self.print_idx == 0:
184
+ logger.debug(
185
+ f"referencenet, self_attn_emb, {i}th, shape={self_attn_emb.shape}, mean={self_attn_emb.mean()}"
186
+ )
187
+ else:
188
+ if self.print_idx == 0:
189
+ logger.debug(f"refer_self_attn_emb is None")
190
+ else:
191
+ down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb = (
192
+ None,
193
+ None,
194
+ None,
195
+ )
196
+
197
+ # extract controlnet emb
198
+ if self.controlnet is not None and controlnet_params is not None:
199
+ controlnet_encoder_hidden_states = align_repeat_tensor_single_dim(
200
+ encoder_hidden_states,
201
+ target_length=unet_params["sample"].shape[2],
202
+ dim=1,
203
+ )
204
+ controlnet_params["encoder_hidden_states"] = rearrange(
205
+ controlnet_encoder_hidden_states, " b t n d -> (b t) n d"
206
+ )
207
+ (
208
+ down_block_additional_residuals,
209
+ mid_block_additional_residual,
210
+ ) = self.controlnet(**controlnet_params)
211
+ if controlnet_scale != 1.0:
212
+ down_block_additional_residuals = [
213
+ x * controlnet_scale for x in down_block_additional_residuals
214
+ ]
215
+ mid_block_additional_residual = (
216
+ mid_block_additional_residual * controlnet_scale
217
+ )
218
+ for i, down_block_additional_residual in enumerate(
219
+ down_block_additional_residuals
220
+ ):
221
+ if self.print_idx == 0:
222
+ logger.debug(
223
+ f"{i}, down_block_additional_residual mean={torch.mean(down_block_additional_residual)}"
224
+ )
225
+
226
+ if self.print_idx == 0:
227
+ logger.debug(
228
+ f"mid_block_additional_residual mean={torch.mean(mid_block_additional_residual)}"
229
+ )
230
+ else:
231
+ down_block_additional_residuals = None
232
+ mid_block_additional_residual = None
233
+
234
+ if prompt_only_use_image_prompt and vision_clip_emb is not None:
235
+ encoder_hidden_states = vision_clip_emb
236
+
237
+ # run unet
238
+ out = self.unet(
239
+ **unet_params,
240
+ down_block_refer_embs=down_block_refer_embs,
241
+ mid_block_refer_emb=mid_block_refer_emb,
242
+ refer_self_attn_emb=refer_self_attn_emb,
243
+ down_block_additional_residuals=down_block_additional_residuals,
244
+ mid_block_additional_residual=mid_block_additional_residual,
245
+ encoder_hidden_states=encoder_hidden_states,
246
+ vision_clip_emb=vision_clip_emb,
247
+ )
248
+ self.print_idx += 1
249
+ return out
250
+
251
+ def _set_gradient_checkpointing(self, module, value=False):
252
+ if isinstance(module, (UNet3DConditionModel, ReferenceNet2D)):
253
+ module.gradient_checkpointing = value
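A minimal sketch (illustrative only, not part of the diff) of the prompt-fallback logic in the forward above: whichever of the text/image conditions is missing is backfilled from the other, and when prompt_only_use_image_prompt is set the unet conditions on the image embedding. The function name and tensor shapes below are assumptions made for this example.

# illustrative sketch only; mirrors the fallback logic in the super-model forward above
from typing import Optional, Tuple
import torch

def resolve_prompt_embeddings(
    encoder_hidden_states: Optional[torch.Tensor],
    vision_clip_emb: Optional[torch.Tensor],
    prompt_only_use_image_prompt: bool,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # backfill: if only one of the two conditions is given, reuse it for the other
    if vision_clip_emb is None and encoder_hidden_states is not None:
        vision_clip_emb = encoder_hidden_states
    if vision_clip_emb is not None and encoder_hidden_states is None:
        encoder_hidden_states = vision_clip_emb
    # optionally force the unet to condition on the image prompt only
    if prompt_only_use_image_prompt and vision_clip_emb is not None:
        encoder_hidden_states = vision_clip_emb
    return encoder_hidden_states, vision_clip_emb

text_emb = torch.randn(2, 16, 77, 768)  # hypothetical (b, t, n, d) text embedding
text_emb, image_emb = resolve_prompt_embeddings(text_emb, None, prompt_only_use_image_prompt=False)
assert torch.equal(text_emb, image_emb)  # image prompt was backfilled from the text prompt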
musev/models/temporal_transformer.py ADDED
@@ -0,0 +1,308 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/transformer_temporal.py
16
+ from __future__ import annotations
17
+ from copy import deepcopy
18
+ from dataclasses import dataclass
19
+ from typing import List, Literal, Optional
20
+ import logging
21
+
22
+ import torch
23
+ from torch import nn
24
+ from einops import rearrange, repeat
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.utils import BaseOutput
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.models.transformer_temporal import (
30
+ TransformerTemporalModelOutput,
31
+ TransformerTemporalModel as DiffusersTransformerTemporalModel,
32
+ )
33
+ from diffusers.models.attention_processor import AttnProcessor
34
+
35
+ from mmcm.utils.gpu_util import get_gpu_status
36
+ from ..data.data_util import (
37
+ batch_concat_two_tensor_with_index,
38
+ batch_index_fill,
39
+ batch_index_select,
40
+ concat_two_tensor,
41
+ align_repeat_tensor_single_dim,
42
+ )
43
+ from ..utils.attention_util import generate_sparse_causcal_attn_mask
44
+ from .attention import BasicTransformerBlock
45
+ from .attention_processor import (
46
+ BaseIPAttnProcessor,
47
+ )
48
+ from . import Model_Register
49
+
50
+ # https://github.com/facebookresearch/xformers/issues/845
51
+ # if bs*n_frames*w*h is too large, xformers raises an error, so allow_xformers is disabled for transformer_temporal
53
+ logger = logging.getLogger(__name__)
54
+
55
+
56
+ @Model_Register.register
57
+ class TransformerTemporalModel(ModelMixin, ConfigMixin):
58
+ """
59
+ Transformer model for video-like data.
60
+
61
+ Parameters:
62
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
63
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
64
+ in_channels (`int`, *optional*):
65
+ Pass if the input is continuous. The number of channels in the input and output.
66
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
67
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
68
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
69
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
70
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
71
+ `ImagePositionalEmbeddings`.
72
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
73
+ attention_bias (`bool`, *optional*):
74
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
75
+ double_self_attention (`bool`, *optional*):
76
+ Configure if each TransformerBlock should contain two self-attention layers
77
+ """
78
+
79
+ @register_to_config
80
+ def __init__(
81
+ self,
82
+ num_attention_heads: int = 16,
83
+ attention_head_dim: int = 88,
84
+ in_channels: Optional[int] = None,
85
+ out_channels: Optional[int] = None,
86
+ num_layers: int = 1,
87
+ femb_channels: Optional[int] = None,
88
+ dropout: float = 0.0,
89
+ norm_num_groups: int = 32,
90
+ cross_attention_dim: Optional[int] = None,
91
+ attention_bias: bool = False,
92
+ sample_size: Optional[int] = None,
93
+ activation_fn: str = "geglu",
94
+ norm_elementwise_affine: bool = True,
95
+ double_self_attention: bool = True,
96
+ allow_xformers: bool = False,
97
+ only_cross_attention: bool = False,
98
+ keep_content_condition: bool = False,
99
+ need_spatial_position_emb: bool = False,
100
+ need_temporal_weight: bool = True,
101
+ self_attn_mask: str = None,
102
+ # TODO: 运行参数,有待改到forward里面去
103
+ # TODO: running parameters, need to be moved to forward
104
+ image_scale: float = 1.0,
105
+ processor: AttnProcessor | None = None,
106
+ remove_femb_non_linear: bool = False,
107
+ ):
108
+ super().__init__()
109
+
110
+ self.num_attention_heads = num_attention_heads
111
+ self.attention_head_dim = attention_head_dim
112
+
113
+ inner_dim = num_attention_heads * attention_head_dim
114
+ self.inner_dim = inner_dim
115
+ self.in_channels = in_channels
116
+
117
+ self.norm = torch.nn.GroupNorm(
118
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
119
+ )
120
+
121
+ self.proj_in = nn.Linear(in_channels, inner_dim)
122
+
123
+ # 2. Define temporal positional embedding
124
+ self.frame_emb_proj = torch.nn.Linear(femb_channels, inner_dim)
125
+ self.remove_femb_non_linear = remove_femb_non_linear
126
+ if not remove_femb_non_linear:
127
+ self.nonlinearity = nn.SiLU()
128
+
129
+ # spatial_position_emb uses the same parameter configuration as femb (femb_channels)
130
+ self.need_spatial_position_emb = need_spatial_position_emb
131
+ if need_spatial_position_emb:
132
+ self.spatial_position_emb_proj = torch.nn.Linear(femb_channels, inner_dim)
133
+ # 3. Define transformers blocks
134
+ # TODO: this implementation is not ideal and should be optimized
136
+ self.need_ipadapter = False
137
+ self.cross_attn_temporal_cond = False
138
+ self.allow_xformers = allow_xformers
139
+ if processor is not None and isinstance(processor, BaseIPAttnProcessor):
140
+ self.cross_attn_temporal_cond = True
141
+ self.allow_xformers = False
142
+ if "NonParam" not in processor.__class__.__name__:
143
+ self.need_ipadapter = True
144
+
145
+ self.transformer_blocks = nn.ModuleList(
146
+ [
147
+ BasicTransformerBlock(
148
+ inner_dim,
149
+ num_attention_heads,
150
+ attention_head_dim,
151
+ dropout=dropout,
152
+ cross_attention_dim=cross_attention_dim,
153
+ activation_fn=activation_fn,
154
+ attention_bias=attention_bias,
155
+ double_self_attention=double_self_attention,
156
+ norm_elementwise_affine=norm_elementwise_affine,
157
+ allow_xformers=allow_xformers,
158
+ only_cross_attention=only_cross_attention,
159
+ cross_attn_temporal_cond=self.need_ipadapter,
160
+ image_scale=image_scale,
161
+ processor=processor,
162
+ )
163
+ for d in range(num_layers)
164
+ ]
165
+ )
166
+
167
+ self.proj_out = nn.Linear(inner_dim, in_channels)
168
+
169
+ self.need_temporal_weight = need_temporal_weight
170
+ if need_temporal_weight:
171
+ self.temporal_weight = nn.Parameter(
172
+ torch.tensor(
173
+ [
174
+ 1e-5,
175
+ ]
176
+ )
177
+ ) # initialize the temporal weight near zero (1e-5)
178
+ self.skip_temporal_layers = False # Whether to skip temporal layer
179
+ self.keep_content_condition = keep_content_condition
180
+ self.self_attn_mask = self_attn_mask
181
+ self.only_cross_attention = only_cross_attention
182
+ self.double_self_attention = double_self_attention
183
+ self.cross_attention_dim = cross_attention_dim
184
+ self.image_scale = image_scale
185
+ # zero out the last layer params so the block is initialized as an identity mapping
186
+ nn.init.zeros_(self.proj_out.weight)
187
+ nn.init.zeros_(self.proj_out.bias)
188
+
189
+ def forward(
190
+ self,
191
+ hidden_states,
192
+ femb,
193
+ encoder_hidden_states=None,
194
+ timestep=None,
195
+ class_labels=None,
196
+ num_frames=1,
197
+ cross_attention_kwargs=None,
198
+ sample_index: torch.LongTensor = None,
199
+ vision_conditon_frames_sample_index: torch.LongTensor = None,
200
+ spatial_position_emb: torch.Tensor = None,
201
+ return_dict: bool = True,
202
+ ):
203
+ """
204
+ Args:
205
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
206
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
207
+ hidden_states
208
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
209
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
210
+ self-attention.
211
+ timestep ( `torch.long`, *optional*):
212
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
213
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
214
+ Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
215
+ conditioning.
216
+ return_dict (`bool`, *optional*, defaults to `True`):
217
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
218
+
219
+ Returns:
220
+ [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`:
221
+ [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`.
222
+ When returning a tuple, the first element is the sample tensor.
223
+ """
224
+ if self.skip_temporal_layers is True:
225
+ if not return_dict:
226
+ return (hidden_states,)
227
+
228
+ return TransformerTemporalModelOutput(sample=hidden_states)
229
+
230
+ # 1. Input
231
+ batch_frames, channel, height, width = hidden_states.shape
232
+ batch_size = batch_frames // num_frames
233
+
234
+ hidden_states = rearrange(
235
+ hidden_states, "(b t) c h w -> b c t h w", b=batch_size
236
+ )
237
+ residual = hidden_states
238
+
239
+ hidden_states = self.norm(hidden_states)
240
+
241
+ hidden_states = rearrange(hidden_states, "b c t h w -> (b h w) t c")
242
+
243
+ hidden_states = self.proj_in(hidden_states)
244
+
245
+ # 2 Positional embedding
246
+ # adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py#L574
247
+ if not self.remove_femb_non_linear:
248
+ femb = self.nonlinearity(femb)
249
+ femb = self.frame_emb_proj(femb)
250
+ femb = align_repeat_tensor_single_dim(femb, hidden_states.shape[0], dim=0)
251
+ hidden_states = hidden_states + femb
252
+
253
+ # 3. Blocks
254
+ if (
255
+ (self.only_cross_attention or not self.double_self_attention)
256
+ and self.cross_attention_dim is not None
257
+ and encoder_hidden_states is not None
258
+ ):
259
+ encoder_hidden_states = align_repeat_tensor_single_dim(
260
+ encoder_hidden_states,
261
+ hidden_states.shape[0],
262
+ dim=0,
263
+ n_src_base_length=batch_size,
264
+ )
265
+
266
+ for i, block in enumerate(self.transformer_blocks):
267
+ hidden_states = block(
268
+ hidden_states,
269
+ encoder_hidden_states=encoder_hidden_states,
270
+ timestep=timestep,
271
+ cross_attention_kwargs=cross_attention_kwargs,
272
+ class_labels=class_labels,
273
+ )
274
+
275
+ # 4. Output
276
+ hidden_states = self.proj_out(hidden_states)
277
+ hidden_states = rearrange(
278
+ hidden_states, "(b h w) t c -> b c t h w", b=batch_size, h=height, w=width
279
+ ).contiguous()
280
+
281
+ # keep the frames corresponding to the condition so preceding content frames are preserved, improving consistency
283
+ if (
284
+ vision_conditon_frames_sample_index is not None
285
+ and self.keep_content_condition
286
+ ):
287
+ mask = torch.ones_like(hidden_states, device=hidden_states.device)
288
+ mask = batch_index_fill(
289
+ mask, dim=2, index=vision_conditon_frames_sample_index, value=0
290
+ )
291
+ if self.need_temporal_weight:
292
+ output = (
293
+ residual + torch.abs(self.temporal_weight) * mask * hidden_states
294
+ )
295
+ else:
296
+ output = residual + mask * hidden_states
297
+ else:
298
+ if self.need_temporal_weight:
299
+ output = residual + torch.abs(self.temporal_weight) * hidden_states
300
+ else:
301
+ output = residual + hidden_states
302
+
303
+ # output = torch.abs(self.temporal_weight) * hidden_states + residual
304
+ output = rearrange(output, "b c t h w -> (b t) c h w")
305
+ if not return_dict:
306
+ return (output,)
307
+
308
+ return TransformerTemporalModelOutput(sample=output)
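A hedged usage sketch for the temporal transformer added above. It assumes the musev package and its submodules (MMCM, the TME diffusers fork) are installed; the hyperparameter values are illustrative, not the values used in the released configs.

import torch
from musev.models.temporal_transformer import TransformerTemporalModel

# inner_dim = num_attention_heads * attention_head_dim = 320
model = TransformerTemporalModel(
    num_attention_heads=8,
    attention_head_dim=40,
    in_channels=320,           # must be divisible by norm_num_groups (32)
    femb_channels=1280,        # width of the per-frame (temporal position) embedding
    cross_attention_dim=None,  # temporal self-attention only
)

b, t, c, h, w = 1, 4, 320, 8, 8
hidden_states = torch.randn(b * t, c, h, w)  # forward expects (b*t, c, h, w)
femb = torch.randn(b, t, 1280)               # projected to inner_dim inside forward
sample = model(hidden_states, femb=femb, num_frames=t, return_dict=False)[0]
print(sample.shape)                          # expected: (b*t, c, h, w)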
musev/models/text_model.py ADDED
@@ -0,0 +1,40 @@
1
+ from typing import Any, Dict
2
+ from torch import nn
3
+
4
+
5
+ class TextEmbExtractor(nn.Module):
6
+ def __init__(self, tokenizer, text_encoder) -> None:
7
+ super(TextEmbExtractor, self).__init__()
8
+ self.tokenizer = tokenizer
9
+ self.text_encoder = text_encoder
10
+
11
+ def forward(
12
+ self,
13
+ texts,
14
+ text_params: Dict = None,
15
+ ):
16
+ if text_params is None:
17
+ text_params = {}
18
+ special_prompt_input = self.tokenizer(
19
+ texts,
20
+ max_length=self.tokenizer.model_max_length,
21
+ padding="max_length",
22
+ truncation=True,
23
+ return_tensors="pt",
24
+ )
25
+ if (
26
+ hasattr(self.text_encoder.config, "use_attention_mask")
27
+ and self.text_encoder.config.use_attention_mask
28
+ ):
29
+ attention_mask = special_prompt_input.attention_mask.to(
30
+ self.text_encoder.device
31
+ )
32
+ else:
33
+ attention_mask = None
34
+
35
+ embeddings = self.text_encoder(
36
+ special_prompt_input.input_ids.to(self.text_encoder.device),
37
+ attention_mask=attention_mask,
38
+ **text_params
39
+ )
40
+ return embeddings
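A hedged usage sketch for TextEmbExtractor. The checkpoint id below is an assumption chosen for illustration; any tokenizer/text_encoder pair compatible with the pipelines would work the same way.

from transformers import CLIPTextModel, CLIPTokenizer
from musev.models.text_model import TextEmbExtractor

pretrained = "runwayml/stable-diffusion-v1-5"  # hypothetical choice for the example
tokenizer = CLIPTokenizer.from_pretrained(pretrained, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained, subfolder="text_encoder")

extractor = TextEmbExtractor(tokenizer=tokenizer, text_encoder=text_encoder)
embeddings = extractor(["a girl, dancing on the beach"])
print(embeddings[0].shape)  # last_hidden_state: (batch, model_max_length, hidden_size)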
musev/models/transformer_2d.py ADDED
@@ -0,0 +1,445 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from __future__ import annotations
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Literal, Optional
17
+ import logging
18
+
19
+ from einops import rearrange
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from torch import nn
24
+
25
+ from diffusers.models.transformer_2d import (
26
+ Transformer2DModelOutput,
27
+ Transformer2DModel as DiffusersTransformer2DModel,
28
+ )
29
+
30
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
31
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
32
+ from diffusers.utils import BaseOutput, deprecate
33
+ from diffusers.models.attention import (
34
+ BasicTransformerBlock as DiffusersBasicTransformerBlock,
35
+ )
36
+ from diffusers.models.embeddings import PatchEmbed
37
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
38
+ from diffusers.models.modeling_utils import ModelMixin
39
+ from diffusers.utils.constants import USE_PEFT_BACKEND
40
+
41
+ from .attention import BasicTransformerBlock
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+ # this module is almost identical to diffusers/models/transformer_2d.py; the changes are:
46
+ # 1. the custom BasicTransformerBlock from .attention replaces the diffusers one
47
+ # 2. self_attn_block_embs is added to forward to extract the embeddings from self_attn
53
+
54
+
55
+ class Transformer2DModel(DiffusersTransformer2DModel):
56
+ """
57
+ A 2D Transformer model for image-like data.
58
+
59
+ Parameters:
60
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
61
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
62
+ in_channels (`int`, *optional*):
63
+ The number of channels in the input and output (specify if the input is **continuous**).
64
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
65
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
66
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
67
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
68
+ This is fixed during training since it is used to learn a number of position embeddings.
69
+ num_vector_embeds (`int`, *optional*):
70
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
71
+ Includes the class for the masked latent pixel.
72
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
73
+ num_embeds_ada_norm ( `int`, *optional*):
74
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
75
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
76
+ added to the hidden states.
77
+
78
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
79
+ attention_bias (`bool`, *optional*):
80
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
81
+ """
82
+
83
+ @register_to_config
84
+ def __init__(
85
+ self,
86
+ num_attention_heads: int = 16,
87
+ attention_head_dim: int = 88,
88
+ in_channels: int | None = None,
89
+ out_channels: int | None = None,
90
+ num_layers: int = 1,
91
+ dropout: float = 0,
92
+ norm_num_groups: int = 32,
93
+ cross_attention_dim: int | None = None,
94
+ attention_bias: bool = False,
95
+ sample_size: int | None = None,
96
+ num_vector_embeds: int | None = None,
97
+ patch_size: int | None = None,
98
+ activation_fn: str = "geglu",
99
+ num_embeds_ada_norm: int | None = None,
100
+ use_linear_projection: bool = False,
101
+ only_cross_attention: bool = False,
102
+ double_self_attention: bool = False,
103
+ upcast_attention: bool = False,
104
+ norm_type: str = "layer_norm",
105
+ norm_elementwise_affine: bool = True,
106
+ attention_type: str = "default",
107
+ cross_attn_temporal_cond: bool = False,
108
+ ip_adapter_cross_attn: bool = False,
109
+ need_t2i_facein: bool = False,
110
+ need_t2i_ip_adapter_face: bool = False,
111
+ image_scale: float = 1.0,
112
+ ):
113
+ super().__init__(
114
+ num_attention_heads,
115
+ attention_head_dim,
116
+ in_channels,
117
+ out_channels,
118
+ num_layers,
119
+ dropout,
120
+ norm_num_groups,
121
+ cross_attention_dim,
122
+ attention_bias,
123
+ sample_size,
124
+ num_vector_embeds,
125
+ patch_size,
126
+ activation_fn,
127
+ num_embeds_ada_norm,
128
+ use_linear_projection,
129
+ only_cross_attention,
130
+ double_self_attention,
131
+ upcast_attention,
132
+ norm_type,
133
+ norm_elementwise_affine,
134
+ attention_type,
135
+ )
136
+ inner_dim = num_attention_heads * attention_head_dim
137
+ self.transformer_blocks = nn.ModuleList(
138
+ [
139
+ BasicTransformerBlock(
140
+ inner_dim,
141
+ num_attention_heads,
142
+ attention_head_dim,
143
+ dropout=dropout,
144
+ cross_attention_dim=cross_attention_dim,
145
+ activation_fn=activation_fn,
146
+ num_embeds_ada_norm=num_embeds_ada_norm,
147
+ attention_bias=attention_bias,
148
+ only_cross_attention=only_cross_attention,
149
+ double_self_attention=double_self_attention,
150
+ upcast_attention=upcast_attention,
151
+ norm_type=norm_type,
152
+ norm_elementwise_affine=norm_elementwise_affine,
153
+ attention_type=attention_type,
154
+ cross_attn_temporal_cond=cross_attn_temporal_cond,
155
+ ip_adapter_cross_attn=ip_adapter_cross_attn,
156
+ need_t2i_facein=need_t2i_facein,
157
+ need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
158
+ image_scale=image_scale,
159
+ )
160
+ for d in range(num_layers)
161
+ ]
162
+ )
163
+ self.num_layers = num_layers
164
+ self.cross_attn_temporal_cond = cross_attn_temporal_cond
165
+ self.ip_adapter_cross_attn = ip_adapter_cross_attn
166
+
167
+ self.need_t2i_facein = need_t2i_facein
168
+ self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face
169
+ self.image_scale = image_scale
170
+ self.print_idx = 0
171
+
172
+ def forward(
173
+ self,
174
+ hidden_states: torch.Tensor,
175
+ encoder_hidden_states: Optional[torch.Tensor] = None,
176
+ timestep: Optional[torch.LongTensor] = None,
177
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
178
+ class_labels: Optional[torch.LongTensor] = None,
179
+ cross_attention_kwargs: Dict[str, Any] = None,
180
+ attention_mask: Optional[torch.Tensor] = None,
181
+ encoder_attention_mask: Optional[torch.Tensor] = None,
182
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
183
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
184
+ return_dict: bool = True,
185
+ ):
186
+ """
187
+ The [`Transformer2DModel`] forward method.
188
+
189
+ Args:
190
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
191
+ Input `hidden_states`.
192
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
193
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
194
+ self-attention.
195
+ timestep ( `torch.LongTensor`, *optional*):
196
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
197
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
198
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
199
+ `AdaLayerZeroNorm`.
200
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
201
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
202
+ `self.processor` in
203
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
204
+ attention_mask ( `torch.Tensor`, *optional*):
205
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
206
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
207
+ negative values to the attention scores corresponding to "discard" tokens.
208
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
209
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
210
+
211
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
212
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
213
+
214
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
215
+ above. This bias will be added to the cross-attention scores.
216
+ return_dict (`bool`, *optional*, defaults to `True`):
217
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
218
+ tuple.
219
+
220
+ Returns:
221
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
222
+ `tuple` where the first element is the sample tensor.
223
+ """
224
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
225
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
226
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
227
+ # expects mask of shape:
228
+ # [batch, key_tokens]
229
+ # adds singleton query_tokens dimension:
230
+ # [batch, 1, key_tokens]
231
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
232
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
233
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
234
+ if attention_mask is not None and attention_mask.ndim == 2:
235
+ # assume that mask is expressed as:
236
+ # (1 = keep, 0 = discard)
237
+ # convert mask into a bias that can be added to attention scores:
238
+ # (keep = +0, discard = -10000.0)
239
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
240
+ attention_mask = attention_mask.unsqueeze(1)
241
+
242
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
243
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
244
+ encoder_attention_mask = (
245
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
246
+ ) * -10000.0
247
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
248
+
249
+ # Retrieve lora scale.
250
+ lora_scale = (
251
+ cross_attention_kwargs.get("scale", 1.0)
252
+ if cross_attention_kwargs is not None
253
+ else 1.0
254
+ )
255
+
256
+ # 1. Input
257
+ if self.is_input_continuous:
258
+ batch, _, height, width = hidden_states.shape
259
+ residual = hidden_states
260
+
261
+ hidden_states = self.norm(hidden_states)
262
+ if not self.use_linear_projection:
263
+ hidden_states = (
264
+ self.proj_in(hidden_states, scale=lora_scale)
265
+ if not USE_PEFT_BACKEND
266
+ else self.proj_in(hidden_states)
267
+ )
268
+ inner_dim = hidden_states.shape[1]
269
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
270
+ batch, height * width, inner_dim
271
+ )
272
+ else:
273
+ inner_dim = hidden_states.shape[1]
274
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
275
+ batch, height * width, inner_dim
276
+ )
277
+ hidden_states = (
278
+ self.proj_in(hidden_states, scale=lora_scale)
279
+ if not USE_PEFT_BACKEND
280
+ else self.proj_in(hidden_states)
281
+ )
282
+
283
+ elif self.is_input_vectorized:
284
+ hidden_states = self.latent_image_embedding(hidden_states)
285
+ elif self.is_input_patches:
286
+ height, width = (
287
+ hidden_states.shape[-2] // self.patch_size,
288
+ hidden_states.shape[-1] // self.patch_size,
289
+ )
290
+ hidden_states = self.pos_embed(hidden_states)
291
+
292
+ if self.adaln_single is not None:
293
+ if self.use_additional_conditions and added_cond_kwargs is None:
294
+ raise ValueError(
295
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
296
+ )
297
+ batch_size = hidden_states.shape[0]
298
+ timestep, embedded_timestep = self.adaln_single(
299
+ timestep,
300
+ added_cond_kwargs,
301
+ batch_size=batch_size,
302
+ hidden_dtype=hidden_states.dtype,
303
+ )
304
+
305
+ # 2. Blocks
306
+ if self.caption_projection is not None:
307
+ batch_size = hidden_states.shape[0]
308
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
309
+ encoder_hidden_states = encoder_hidden_states.view(
310
+ batch_size, -1, hidden_states.shape[-1]
311
+ )
312
+
313
+ for block in self.transformer_blocks:
314
+ if self.training and self.gradient_checkpointing:
315
+ hidden_states = torch.utils.checkpoint.checkpoint(
316
+ block,
317
+ hidden_states,
318
+ attention_mask,
319
+ encoder_hidden_states,
320
+ encoder_attention_mask,
321
+ timestep,
322
+ cross_attention_kwargs,
323
+ class_labels,
324
+ self_attn_block_embs,
325
+ self_attn_block_embs_mode,
326
+ use_reentrant=False,
327
+ )
328
+ else:
329
+ hidden_states = block(
330
+ hidden_states,
331
+ attention_mask=attention_mask,
332
+ encoder_hidden_states=encoder_hidden_states,
333
+ encoder_attention_mask=encoder_attention_mask,
334
+ timestep=timestep,
335
+ cross_attention_kwargs=cross_attention_kwargs,
336
+ class_labels=class_labels,
337
+ self_attn_block_embs=self_attn_block_embs,
338
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
339
+ )
340
+ # reshape the written self_attn_emb from (b*t, h*w, c) to (b*t, c, h, w)
341
+ if (
342
+ self_attn_block_embs is not None
343
+ and self_attn_block_embs_mode.lower() == "write"
344
+ ):
345
+ self_attn_idx = block.spatial_self_attn_idx
346
+ if self.print_idx == 0:
347
+ logger.debug(
348
+ f"self_attn_block_embs, num={len(self_attn_block_embs)}, before, shape={self_attn_block_embs[self_attn_idx].shape}, height={height}, width={width}"
349
+ )
350
+ self_attn_block_embs[self_attn_idx] = rearrange(
351
+ self_attn_block_embs[self_attn_idx],
352
+ "bt (h w) c->bt c h w",
353
+ h=height,
354
+ w=width,
355
+ )
356
+ if self.print_idx == 0:
357
+ logger.debug(
358
+ f"self_attn_block_embs, num={len(self_attn_block_embs)}, after ,shape={self_attn_block_embs[self_attn_idx].shape}, height={height}, width={width}"
359
+ )
360
+
361
+ if self.proj_out is None:
362
+ return hidden_states
363
+
364
+ # 3. Output
365
+ if self.is_input_continuous:
366
+ if not self.use_linear_projection:
367
+ hidden_states = (
368
+ hidden_states.reshape(batch, height, width, inner_dim)
369
+ .permute(0, 3, 1, 2)
370
+ .contiguous()
371
+ )
372
+ hidden_states = (
373
+ self.proj_out(hidden_states, scale=lora_scale)
374
+ if not USE_PEFT_BACKEND
375
+ else self.proj_out(hidden_states)
376
+ )
377
+ else:
378
+ hidden_states = (
379
+ self.proj_out(hidden_states, scale=lora_scale)
380
+ if not USE_PEFT_BACKEND
381
+ else self.proj_out(hidden_states)
382
+ )
383
+ hidden_states = (
384
+ hidden_states.reshape(batch, height, width, inner_dim)
385
+ .permute(0, 3, 1, 2)
386
+ .contiguous()
387
+ )
388
+
389
+ output = hidden_states + residual
390
+ elif self.is_input_vectorized:
391
+ hidden_states = self.norm_out(hidden_states)
392
+ logits = self.out(hidden_states)
393
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
394
+ logits = logits.permute(0, 2, 1)
395
+
396
+ # log(p(x_0))
397
+ output = F.log_softmax(logits.double(), dim=1).float()
398
+
399
+ if self.is_input_patches:
400
+ if self.config.norm_type != "ada_norm_single":
401
+ conditioning = self.transformer_blocks[0].norm1.emb(
402
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
403
+ )
404
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
405
+ hidden_states = (
406
+ self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
407
+ )
408
+ hidden_states = self.proj_out_2(hidden_states)
409
+ elif self.config.norm_type == "ada_norm_single":
410
+ shift, scale = (
411
+ self.scale_shift_table[None] + embedded_timestep[:, None]
412
+ ).chunk(2, dim=1)
413
+ hidden_states = self.norm_out(hidden_states)
414
+ # Modulation
415
+ hidden_states = hidden_states * (1 + scale) + shift
416
+ hidden_states = self.proj_out(hidden_states)
417
+ hidden_states = hidden_states.squeeze(1)
418
+
419
+ # unpatchify
420
+ if self.adaln_single is None:
421
+ height = width = int(hidden_states.shape[1] ** 0.5)
422
+ hidden_states = hidden_states.reshape(
423
+ shape=(
424
+ -1,
425
+ height,
426
+ width,
427
+ self.patch_size,
428
+ self.patch_size,
429
+ self.out_channels,
430
+ )
431
+ )
432
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
433
+ output = hidden_states.reshape(
434
+ shape=(
435
+ -1,
436
+ self.out_channels,
437
+ height * self.patch_size,
438
+ width * self.patch_size,
439
+ )
440
+ )
441
+ self.print_idx += 1
442
+ if not return_dict:
443
+ return (output,)
444
+
445
+ return Transformer2DModelOutput(sample=output)
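A small standalone sketch of the reshaping applied to each entry written into self_attn_block_embs above: the attention blocks store (b*t, h*w, c) tensors, and this module rearranges them to (b*t, c, h, w) so that referencenet-style consumers can read them back. Shapes are illustrative.

import torch
from einops import rearrange

height, width, channels = 8, 8, 320
# one slot per spatial self-attention block; a single hypothetical entry here
self_attn_block_embs = [torch.randn(4, height * width, channels)]
self_attn_block_embs[0] = rearrange(
    self_attn_block_embs[0], "bt (h w) c -> bt c h w", h=height, w=width
)
print(self_attn_block_embs[0].shape)  # torch.Size([4, 320, 8, 8])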
musev/models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1537 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Literal, Optional, Tuple, Union, List
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ from diffusers.utils.torch_utils import apply_freeu
23
+ from diffusers.models.activations import get_activation
24
+ from diffusers.models.attention_processor import (
25
+ Attention,
26
+ AttnAddedKVProcessor,
27
+ AttnAddedKVProcessor2_0,
28
+ )
29
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
30
+ from diffusers.models.normalization import AdaGroupNorm
31
+ from diffusers.models.resnet import (
32
+ Downsample2D,
33
+ FirDownsample2D,
34
+ FirUpsample2D,
35
+ KDownsample2D,
36
+ KUpsample2D,
37
+ ResnetBlock2D,
38
+ Upsample2D,
39
+ )
40
+ from diffusers.models.unet_2d_blocks import (
41
+ AttnDownBlock2D,
42
+ AttnDownEncoderBlock2D,
43
+ AttnSkipDownBlock2D,
44
+ AttnSkipUpBlock2D,
45
+ AttnUpBlock2D,
46
+ AttnUpDecoderBlock2D,
47
+ DownEncoderBlock2D,
48
+ KCrossAttnDownBlock2D,
49
+ KCrossAttnUpBlock2D,
50
+ KDownBlock2D,
51
+ KUpBlock2D,
52
+ ResnetDownsampleBlock2D,
53
+ ResnetUpsampleBlock2D,
54
+ SimpleCrossAttnDownBlock2D,
55
+ SimpleCrossAttnUpBlock2D,
56
+ SkipDownBlock2D,
57
+ SkipUpBlock2D,
58
+ UpDecoderBlock2D,
59
+ )
60
+
61
+ from .transformer_2d import Transformer2DModel
62
+
63
+
64
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
65
+
66
+
67
+ def get_down_block(
68
+ down_block_type: str,
69
+ num_layers: int,
70
+ in_channels: int,
71
+ out_channels: int,
72
+ temb_channels: int,
73
+ add_downsample: bool,
74
+ resnet_eps: float,
75
+ resnet_act_fn: str,
76
+ transformer_layers_per_block: int = 1,
77
+ num_attention_heads: Optional[int] = None,
78
+ resnet_groups: Optional[int] = None,
79
+ cross_attention_dim: Optional[int] = None,
80
+ downsample_padding: Optional[int] = None,
81
+ dual_cross_attention: bool = False,
82
+ use_linear_projection: bool = False,
83
+ only_cross_attention: bool = False,
84
+ upcast_attention: bool = False,
85
+ resnet_time_scale_shift: str = "default",
86
+ attention_type: str = "default",
87
+ resnet_skip_time_act: bool = False,
88
+ resnet_out_scale_factor: float = 1.0,
89
+ cross_attention_norm: Optional[str] = None,
90
+ attention_head_dim: Optional[int] = None,
91
+ downsample_type: Optional[str] = None,
92
+ dropout: float = 0.0,
93
+ ):
94
+ # If attn head dim is not defined, we default it to the number of heads
95
+ if attention_head_dim is None:
96
+ logger.warn(
97
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
98
+ )
99
+ attention_head_dim = num_attention_heads
100
+
101
+ down_block_type = (
102
+ down_block_type[7:]
103
+ if down_block_type.startswith("UNetRes")
104
+ else down_block_type
105
+ )
106
+ if down_block_type == "DownBlock2D":
107
+ return DownBlock2D(
108
+ num_layers=num_layers,
109
+ in_channels=in_channels,
110
+ out_channels=out_channels,
111
+ temb_channels=temb_channels,
112
+ dropout=dropout,
113
+ add_downsample=add_downsample,
114
+ resnet_eps=resnet_eps,
115
+ resnet_act_fn=resnet_act_fn,
116
+ resnet_groups=resnet_groups,
117
+ downsample_padding=downsample_padding,
118
+ resnet_time_scale_shift=resnet_time_scale_shift,
119
+ )
120
+ elif down_block_type == "ResnetDownsampleBlock2D":
121
+ return ResnetDownsampleBlock2D(
122
+ num_layers=num_layers,
123
+ in_channels=in_channels,
124
+ out_channels=out_channels,
125
+ temb_channels=temb_channels,
126
+ dropout=dropout,
127
+ add_downsample=add_downsample,
128
+ resnet_eps=resnet_eps,
129
+ resnet_act_fn=resnet_act_fn,
130
+ resnet_groups=resnet_groups,
131
+ resnet_time_scale_shift=resnet_time_scale_shift,
132
+ skip_time_act=resnet_skip_time_act,
133
+ output_scale_factor=resnet_out_scale_factor,
134
+ )
135
+ elif down_block_type == "AttnDownBlock2D":
136
+ if add_downsample is False:
137
+ downsample_type = None
138
+ else:
139
+ downsample_type = downsample_type or "conv" # default to 'conv'
140
+ return AttnDownBlock2D(
141
+ num_layers=num_layers,
142
+ in_channels=in_channels,
143
+ out_channels=out_channels,
144
+ temb_channels=temb_channels,
145
+ dropout=dropout,
146
+ resnet_eps=resnet_eps,
147
+ resnet_act_fn=resnet_act_fn,
148
+ resnet_groups=resnet_groups,
149
+ downsample_padding=downsample_padding,
150
+ attention_head_dim=attention_head_dim,
151
+ resnet_time_scale_shift=resnet_time_scale_shift,
152
+ downsample_type=downsample_type,
153
+ )
154
+ elif down_block_type == "CrossAttnDownBlock2D":
155
+ if cross_attention_dim is None:
156
+ raise ValueError(
157
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
158
+ )
159
+ return CrossAttnDownBlock2D(
160
+ num_layers=num_layers,
161
+ transformer_layers_per_block=transformer_layers_per_block,
162
+ in_channels=in_channels,
163
+ out_channels=out_channels,
164
+ temb_channels=temb_channels,
165
+ dropout=dropout,
166
+ add_downsample=add_downsample,
167
+ resnet_eps=resnet_eps,
168
+ resnet_act_fn=resnet_act_fn,
169
+ resnet_groups=resnet_groups,
170
+ downsample_padding=downsample_padding,
171
+ cross_attention_dim=cross_attention_dim,
172
+ num_attention_heads=num_attention_heads,
173
+ dual_cross_attention=dual_cross_attention,
174
+ use_linear_projection=use_linear_projection,
175
+ only_cross_attention=only_cross_attention,
176
+ upcast_attention=upcast_attention,
177
+ resnet_time_scale_shift=resnet_time_scale_shift,
178
+ attention_type=attention_type,
179
+ )
180
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
181
+ if cross_attention_dim is None:
182
+ raise ValueError(
183
+ "cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D"
184
+ )
185
+ return SimpleCrossAttnDownBlock2D(
186
+ num_layers=num_layers,
187
+ in_channels=in_channels,
188
+ out_channels=out_channels,
189
+ temb_channels=temb_channels,
190
+ dropout=dropout,
191
+ add_downsample=add_downsample,
192
+ resnet_eps=resnet_eps,
193
+ resnet_act_fn=resnet_act_fn,
194
+ resnet_groups=resnet_groups,
195
+ cross_attention_dim=cross_attention_dim,
196
+ attention_head_dim=attention_head_dim,
197
+ resnet_time_scale_shift=resnet_time_scale_shift,
198
+ skip_time_act=resnet_skip_time_act,
199
+ output_scale_factor=resnet_out_scale_factor,
200
+ only_cross_attention=only_cross_attention,
201
+ cross_attention_norm=cross_attention_norm,
202
+ )
203
+ elif down_block_type == "SkipDownBlock2D":
204
+ return SkipDownBlock2D(
205
+ num_layers=num_layers,
206
+ in_channels=in_channels,
207
+ out_channels=out_channels,
208
+ temb_channels=temb_channels,
209
+ dropout=dropout,
210
+ add_downsample=add_downsample,
211
+ resnet_eps=resnet_eps,
212
+ resnet_act_fn=resnet_act_fn,
213
+ downsample_padding=downsample_padding,
214
+ resnet_time_scale_shift=resnet_time_scale_shift,
215
+ )
216
+ elif down_block_type == "AttnSkipDownBlock2D":
217
+ return AttnSkipDownBlock2D(
218
+ num_layers=num_layers,
219
+ in_channels=in_channels,
220
+ out_channels=out_channels,
221
+ temb_channels=temb_channels,
222
+ dropout=dropout,
223
+ add_downsample=add_downsample,
224
+ resnet_eps=resnet_eps,
225
+ resnet_act_fn=resnet_act_fn,
226
+ attention_head_dim=attention_head_dim,
227
+ resnet_time_scale_shift=resnet_time_scale_shift,
228
+ )
229
+ elif down_block_type == "DownEncoderBlock2D":
230
+ return DownEncoderBlock2D(
231
+ num_layers=num_layers,
232
+ in_channels=in_channels,
233
+ out_channels=out_channels,
234
+ dropout=dropout,
235
+ add_downsample=add_downsample,
236
+ resnet_eps=resnet_eps,
237
+ resnet_act_fn=resnet_act_fn,
238
+ resnet_groups=resnet_groups,
239
+ downsample_padding=downsample_padding,
240
+ resnet_time_scale_shift=resnet_time_scale_shift,
241
+ )
242
+ elif down_block_type == "AttnDownEncoderBlock2D":
243
+ return AttnDownEncoderBlock2D(
244
+ num_layers=num_layers,
245
+ in_channels=in_channels,
246
+ out_channels=out_channels,
247
+ dropout=dropout,
248
+ add_downsample=add_downsample,
249
+ resnet_eps=resnet_eps,
250
+ resnet_act_fn=resnet_act_fn,
251
+ resnet_groups=resnet_groups,
252
+ downsample_padding=downsample_padding,
253
+ attention_head_dim=attention_head_dim,
254
+ resnet_time_scale_shift=resnet_time_scale_shift,
255
+ )
256
+ elif down_block_type == "KDownBlock2D":
257
+ return KDownBlock2D(
258
+ num_layers=num_layers,
259
+ in_channels=in_channels,
260
+ out_channels=out_channels,
261
+ temb_channels=temb_channels,
262
+ dropout=dropout,
263
+ add_downsample=add_downsample,
264
+ resnet_eps=resnet_eps,
265
+ resnet_act_fn=resnet_act_fn,
266
+ )
267
+ elif down_block_type == "KCrossAttnDownBlock2D":
268
+ return KCrossAttnDownBlock2D(
269
+ num_layers=num_layers,
270
+ in_channels=in_channels,
271
+ out_channels=out_channels,
272
+ temb_channels=temb_channels,
273
+ dropout=dropout,
274
+ add_downsample=add_downsample,
275
+ resnet_eps=resnet_eps,
276
+ resnet_act_fn=resnet_act_fn,
277
+ cross_attention_dim=cross_attention_dim,
278
+ attention_head_dim=attention_head_dim,
279
+ add_self_attention=True if not add_downsample else False,
280
+ )
281
+ raise ValueError(f"{down_block_type} does not exist.")
282
+
283
+
284
+ def get_up_block(
285
+ up_block_type: str,
286
+ num_layers: int,
287
+ in_channels: int,
288
+ out_channels: int,
289
+ prev_output_channel: int,
290
+ temb_channels: int,
291
+ add_upsample: bool,
292
+ resnet_eps: float,
293
+ resnet_act_fn: str,
294
+ resolution_idx: Optional[int] = None,
295
+ transformer_layers_per_block: int = 1,
296
+ num_attention_heads: Optional[int] = None,
297
+ resnet_groups: Optional[int] = None,
298
+ cross_attention_dim: Optional[int] = None,
299
+ dual_cross_attention: bool = False,
300
+ use_linear_projection: bool = False,
301
+ only_cross_attention: bool = False,
302
+ upcast_attention: bool = False,
303
+ resnet_time_scale_shift: str = "default",
304
+ attention_type: str = "default",
305
+ resnet_skip_time_act: bool = False,
306
+ resnet_out_scale_factor: float = 1.0,
307
+ cross_attention_norm: Optional[str] = None,
308
+ attention_head_dim: Optional[int] = None,
309
+ upsample_type: Optional[str] = None,
310
+ dropout: float = 0.0,
311
+ ) -> nn.Module:
312
+ # If attn head dim is not defined, we default it to the number of heads
313
+ if attention_head_dim is None:
314
+ logger.warn(
315
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
316
+ )
317
+ attention_head_dim = num_attention_heads
318
+
319
+ up_block_type = (
320
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
321
+ )
322
+ if up_block_type == "UpBlock2D":
323
+ return UpBlock2D(
324
+ num_layers=num_layers,
325
+ in_channels=in_channels,
326
+ out_channels=out_channels,
327
+ prev_output_channel=prev_output_channel,
328
+ temb_channels=temb_channels,
329
+ resolution_idx=resolution_idx,
330
+ dropout=dropout,
331
+ add_upsample=add_upsample,
332
+ resnet_eps=resnet_eps,
333
+ resnet_act_fn=resnet_act_fn,
334
+ resnet_groups=resnet_groups,
335
+ resnet_time_scale_shift=resnet_time_scale_shift,
336
+ )
337
+ elif up_block_type == "ResnetUpsampleBlock2D":
338
+ return ResnetUpsampleBlock2D(
339
+ num_layers=num_layers,
340
+ in_channels=in_channels,
341
+ out_channels=out_channels,
342
+ prev_output_channel=prev_output_channel,
343
+ temb_channels=temb_channels,
344
+ resolution_idx=resolution_idx,
345
+ dropout=dropout,
346
+ add_upsample=add_upsample,
347
+ resnet_eps=resnet_eps,
348
+ resnet_act_fn=resnet_act_fn,
349
+ resnet_groups=resnet_groups,
350
+ resnet_time_scale_shift=resnet_time_scale_shift,
351
+ skip_time_act=resnet_skip_time_act,
352
+ output_scale_factor=resnet_out_scale_factor,
353
+ )
354
+ elif up_block_type == "CrossAttnUpBlock2D":
355
+ if cross_attention_dim is None:
356
+ raise ValueError(
357
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
358
+ )
359
+ return CrossAttnUpBlock2D(
360
+ num_layers=num_layers,
361
+ transformer_layers_per_block=transformer_layers_per_block,
362
+ in_channels=in_channels,
363
+ out_channels=out_channels,
364
+ prev_output_channel=prev_output_channel,
365
+ temb_channels=temb_channels,
366
+ resolution_idx=resolution_idx,
367
+ dropout=dropout,
368
+ add_upsample=add_upsample,
369
+ resnet_eps=resnet_eps,
370
+ resnet_act_fn=resnet_act_fn,
371
+ resnet_groups=resnet_groups,
372
+ cross_attention_dim=cross_attention_dim,
373
+ num_attention_heads=num_attention_heads,
374
+ dual_cross_attention=dual_cross_attention,
375
+ use_linear_projection=use_linear_projection,
376
+ only_cross_attention=only_cross_attention,
377
+ upcast_attention=upcast_attention,
378
+ resnet_time_scale_shift=resnet_time_scale_shift,
379
+ attention_type=attention_type,
380
+ )
381
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
382
+ if cross_attention_dim is None:
383
+ raise ValueError(
384
+ "cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D"
385
+ )
386
+ return SimpleCrossAttnUpBlock2D(
387
+ num_layers=num_layers,
388
+ in_channels=in_channels,
389
+ out_channels=out_channels,
390
+ prev_output_channel=prev_output_channel,
391
+ temb_channels=temb_channels,
392
+ resolution_idx=resolution_idx,
393
+ dropout=dropout,
394
+ add_upsample=add_upsample,
395
+ resnet_eps=resnet_eps,
396
+ resnet_act_fn=resnet_act_fn,
397
+ resnet_groups=resnet_groups,
398
+ cross_attention_dim=cross_attention_dim,
399
+ attention_head_dim=attention_head_dim,
400
+ resnet_time_scale_shift=resnet_time_scale_shift,
401
+ skip_time_act=resnet_skip_time_act,
402
+ output_scale_factor=resnet_out_scale_factor,
403
+ only_cross_attention=only_cross_attention,
404
+ cross_attention_norm=cross_attention_norm,
405
+ )
406
+ elif up_block_type == "AttnUpBlock2D":
407
+ if add_upsample is False:
408
+ upsample_type = None
409
+ else:
410
+ upsample_type = upsample_type or "conv" # default to 'conv'
411
+
412
+ return AttnUpBlock2D(
413
+ num_layers=num_layers,
414
+ in_channels=in_channels,
415
+ out_channels=out_channels,
416
+ prev_output_channel=prev_output_channel,
417
+ temb_channels=temb_channels,
418
+ resolution_idx=resolution_idx,
419
+ dropout=dropout,
420
+ resnet_eps=resnet_eps,
421
+ resnet_act_fn=resnet_act_fn,
422
+ resnet_groups=resnet_groups,
423
+ attention_head_dim=attention_head_dim,
424
+ resnet_time_scale_shift=resnet_time_scale_shift,
425
+ upsample_type=upsample_type,
426
+ )
427
+ elif up_block_type == "SkipUpBlock2D":
428
+ return SkipUpBlock2D(
429
+ num_layers=num_layers,
430
+ in_channels=in_channels,
431
+ out_channels=out_channels,
432
+ prev_output_channel=prev_output_channel,
433
+ temb_channels=temb_channels,
434
+ resolution_idx=resolution_idx,
435
+ dropout=dropout,
436
+ add_upsample=add_upsample,
437
+ resnet_eps=resnet_eps,
438
+ resnet_act_fn=resnet_act_fn,
439
+ resnet_time_scale_shift=resnet_time_scale_shift,
440
+ )
441
+ elif up_block_type == "AttnSkipUpBlock2D":
442
+ return AttnSkipUpBlock2D(
443
+ num_layers=num_layers,
444
+ in_channels=in_channels,
445
+ out_channels=out_channels,
446
+ prev_output_channel=prev_output_channel,
447
+ temb_channels=temb_channels,
448
+ resolution_idx=resolution_idx,
449
+ dropout=dropout,
450
+ add_upsample=add_upsample,
451
+ resnet_eps=resnet_eps,
452
+ resnet_act_fn=resnet_act_fn,
453
+ attention_head_dim=attention_head_dim,
454
+ resnet_time_scale_shift=resnet_time_scale_shift,
455
+ )
456
+ elif up_block_type == "UpDecoderBlock2D":
457
+ return UpDecoderBlock2D(
458
+ num_layers=num_layers,
459
+ in_channels=in_channels,
460
+ out_channels=out_channels,
461
+ resolution_idx=resolution_idx,
462
+ dropout=dropout,
463
+ add_upsample=add_upsample,
464
+ resnet_eps=resnet_eps,
465
+ resnet_act_fn=resnet_act_fn,
466
+ resnet_groups=resnet_groups,
467
+ resnet_time_scale_shift=resnet_time_scale_shift,
468
+ temb_channels=temb_channels,
469
+ )
470
+ elif up_block_type == "AttnUpDecoderBlock2D":
471
+ return AttnUpDecoderBlock2D(
472
+ num_layers=num_layers,
473
+ in_channels=in_channels,
474
+ out_channels=out_channels,
475
+ resolution_idx=resolution_idx,
476
+ dropout=dropout,
477
+ add_upsample=add_upsample,
478
+ resnet_eps=resnet_eps,
479
+ resnet_act_fn=resnet_act_fn,
480
+ resnet_groups=resnet_groups,
481
+ attention_head_dim=attention_head_dim,
482
+ resnet_time_scale_shift=resnet_time_scale_shift,
483
+ temb_channels=temb_channels,
484
+ )
485
+ elif up_block_type == "KUpBlock2D":
486
+ return KUpBlock2D(
487
+ num_layers=num_layers,
488
+ in_channels=in_channels,
489
+ out_channels=out_channels,
490
+ temb_channels=temb_channels,
491
+ resolution_idx=resolution_idx,
492
+ dropout=dropout,
493
+ add_upsample=add_upsample,
494
+ resnet_eps=resnet_eps,
495
+ resnet_act_fn=resnet_act_fn,
496
+ )
497
+ elif up_block_type == "KCrossAttnUpBlock2D":
498
+ return KCrossAttnUpBlock2D(
499
+ num_layers=num_layers,
500
+ in_channels=in_channels,
501
+ out_channels=out_channels,
502
+ temb_channels=temb_channels,
503
+ resolution_idx=resolution_idx,
504
+ dropout=dropout,
505
+ add_upsample=add_upsample,
506
+ resnet_eps=resnet_eps,
507
+ resnet_act_fn=resnet_act_fn,
508
+ cross_attention_dim=cross_attention_dim,
509
+ attention_head_dim=attention_head_dim,
510
+ )
511
+
512
+ raise ValueError(f"{up_block_type} does not exist.")
513
+
514
+
515
+ class UNetMidBlock2D(nn.Module):
516
+ """
517
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
518
+
519
+ Args:
520
+ in_channels (`int`): The number of input channels.
521
+ temb_channels (`int`): The number of temporal embedding channels.
522
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
523
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
524
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
525
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
526
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
527
+ model on tasks with long-range temporal dependencies.
528
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
529
+ resnet_groups (`int`, *optional*, defaults to 32):
530
+ The number of groups to use in the group normalization layers of the resnet blocks.
531
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
532
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
533
+ Whether to use pre-normalization for the resnet blocks.
534
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
535
+ attention_head_dim (`int`, *optional*, defaults to 1):
536
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
537
+ the number of input channels.
538
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
539
+
540
+ Returns:
541
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
542
+ in_channels, height, width)`.
543
+
544
+ """
545
+
546
+ def __init__(
547
+ self,
548
+ in_channels: int,
549
+ temb_channels: int,
550
+ dropout: float = 0.0,
551
+ num_layers: int = 1,
552
+ resnet_eps: float = 1e-6,
553
+ resnet_time_scale_shift: str = "default", # default, spatial
554
+ resnet_act_fn: str = "swish",
555
+ resnet_groups: int = 32,
556
+ attn_groups: Optional[int] = None,
557
+ resnet_pre_norm: bool = True,
558
+ add_attention: bool = True,
559
+ attention_head_dim: int = 1,
560
+ output_scale_factor: float = 1.0,
561
+ ):
562
+ super().__init__()
563
+ resnet_groups = (
564
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
565
+ )
566
+ self.add_attention = add_attention
567
+
568
+ if attn_groups is None:
569
+ attn_groups = (
570
+ resnet_groups if resnet_time_scale_shift == "default" else None
571
+ )
572
+
573
+ # there is always at least one resnet
574
+ resnets = [
575
+ ResnetBlock2D(
576
+ in_channels=in_channels,
577
+ out_channels=in_channels,
578
+ temb_channels=temb_channels,
579
+ eps=resnet_eps,
580
+ groups=resnet_groups,
581
+ dropout=dropout,
582
+ time_embedding_norm=resnet_time_scale_shift,
583
+ non_linearity=resnet_act_fn,
584
+ output_scale_factor=output_scale_factor,
585
+ pre_norm=resnet_pre_norm,
586
+ )
587
+ ]
588
+ attentions = []
589
+
590
+ if attention_head_dim is None:
591
+ logger.warning(
592
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
593
+ )
594
+ attention_head_dim = in_channels
595
+
596
+ for _ in range(num_layers):
597
+ if self.add_attention:
598
+ attentions.append(
599
+ Attention(
600
+ in_channels,
601
+ heads=in_channels // attention_head_dim,
602
+ dim_head=attention_head_dim,
603
+ rescale_output_factor=output_scale_factor,
604
+ eps=resnet_eps,
605
+ norm_num_groups=attn_groups,
606
+ spatial_norm_dim=temb_channels
607
+ if resnet_time_scale_shift == "spatial"
608
+ else None,
609
+ residual_connection=True,
610
+ bias=True,
611
+ upcast_softmax=True,
612
+ _from_deprecated_attn_block=True,
613
+ )
614
+ )
615
+ else:
616
+ attentions.append(None)
617
+
618
+ resnets.append(
619
+ ResnetBlock2D(
620
+ in_channels=in_channels,
621
+ out_channels=in_channels,
622
+ temb_channels=temb_channels,
623
+ eps=resnet_eps,
624
+ groups=resnet_groups,
625
+ dropout=dropout,
626
+ time_embedding_norm=resnet_time_scale_shift,
627
+ non_linearity=resnet_act_fn,
628
+ output_scale_factor=output_scale_factor,
629
+ pre_norm=resnet_pre_norm,
630
+ )
631
+ )
632
+
633
+ self.attentions = nn.ModuleList(attentions)
634
+ self.resnets = nn.ModuleList(resnets)
635
+
636
+ def forward(
637
+ self,
638
+ hidden_states: torch.FloatTensor,
639
+ temb: Optional[torch.FloatTensor] = None,
640
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
641
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
642
+ ) -> torch.FloatTensor:
643
+ hidden_states = self.resnets[0](hidden_states, temb)
644
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
645
+ if attn is not None:
646
+ hidden_states = attn(
647
+ hidden_states,
648
+ temb=temb,
649
+ self_attn_block_embs=self_attn_block_embs,
650
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
651
+ )
652
+ hidden_states = resnet(hidden_states, temb)
653
+
654
+ return hidden_states
655
+
656
+
657
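A minimal usage sketch for `UNetMidBlock2D` as defined above (illustrative only, not part of this commit): it assumes the package is importable as `musev`, matching this repo's `musev/models/unet_2d_blocks.py`, and the channel sizes are arbitrary choices that satisfy the group-norm (divisible by 32) and attention-head-dim constraints.

    import torch
    from musev.models.unet_2d_blocks import UNetMidBlock2D  # assumed import path

    mid_block = UNetMidBlock2D(
        in_channels=64,        # divisible by the default resnet_groups (32)
        temb_channels=128,
        attention_head_dim=8,  # 64 // 8 = 8 attention heads
    )
    sample = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
    temb = torch.randn(2, 128)           # time embedding
    out = mid_block(sample, temb)        # same shape as `sample`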
+ class UNetMidBlock2DCrossAttn(nn.Module):
658
+ def __init__(
659
+ self,
660
+ in_channels: int,
661
+ temb_channels: int,
662
+ dropout: float = 0.0,
663
+ num_layers: int = 1,
664
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
665
+ resnet_eps: float = 1e-6,
666
+ resnet_time_scale_shift: str = "default",
667
+ resnet_act_fn: str = "swish",
668
+ resnet_groups: int = 32,
669
+ resnet_pre_norm: bool = True,
670
+ num_attention_heads: int = 1,
671
+ output_scale_factor: float = 1.0,
672
+ cross_attention_dim: int = 1280,
673
+ dual_cross_attention: bool = False,
674
+ use_linear_projection: bool = False,
675
+ upcast_attention: bool = False,
676
+ attention_type: str = "default",
677
+ ):
678
+ super().__init__()
679
+
680
+ self.has_cross_attention = True
681
+ self.num_attention_heads = num_attention_heads
682
+ resnet_groups = (
683
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
684
+ )
685
+
686
+ # support for variable transformer layers per block
687
+ if isinstance(transformer_layers_per_block, int):
688
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
689
+
690
+ # there is always at least one resnet
691
+ resnets = [
692
+ ResnetBlock2D(
693
+ in_channels=in_channels,
694
+ out_channels=in_channels,
695
+ temb_channels=temb_channels,
696
+ eps=resnet_eps,
697
+ groups=resnet_groups,
698
+ dropout=dropout,
699
+ time_embedding_norm=resnet_time_scale_shift,
700
+ non_linearity=resnet_act_fn,
701
+ output_scale_factor=output_scale_factor,
702
+ pre_norm=resnet_pre_norm,
703
+ )
704
+ ]
705
+ attentions = []
706
+
707
+ for i in range(num_layers):
708
+ if not dual_cross_attention:
709
+ attentions.append(
710
+ Transformer2DModel(
711
+ num_attention_heads,
712
+ in_channels // num_attention_heads,
713
+ in_channels=in_channels,
714
+ num_layers=transformer_layers_per_block[i],
715
+ cross_attention_dim=cross_attention_dim,
716
+ norm_num_groups=resnet_groups,
717
+ use_linear_projection=use_linear_projection,
718
+ upcast_attention=upcast_attention,
719
+ attention_type=attention_type,
720
+ )
721
+ )
722
+ else:
723
+ attentions.append(
724
+ DualTransformer2DModel(
725
+ num_attention_heads,
726
+ in_channels // num_attention_heads,
727
+ in_channels=in_channels,
728
+ num_layers=1,
729
+ cross_attention_dim=cross_attention_dim,
730
+ norm_num_groups=resnet_groups,
731
+ )
732
+ )
733
+ resnets.append(
734
+ ResnetBlock2D(
735
+ in_channels=in_channels,
736
+ out_channels=in_channels,
737
+ temb_channels=temb_channels,
738
+ eps=resnet_eps,
739
+ groups=resnet_groups,
740
+ dropout=dropout,
741
+ time_embedding_norm=resnet_time_scale_shift,
742
+ non_linearity=resnet_act_fn,
743
+ output_scale_factor=output_scale_factor,
744
+ pre_norm=resnet_pre_norm,
745
+ )
746
+ )
747
+
748
+ self.attentions = nn.ModuleList(attentions)
749
+ self.resnets = nn.ModuleList(resnets)
750
+
751
+ self.gradient_checkpointing = False
752
+
753
+ def forward(
754
+ self,
755
+ hidden_states: torch.FloatTensor,
756
+ temb: Optional[torch.FloatTensor] = None,
757
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
758
+ attention_mask: Optional[torch.FloatTensor] = None,
759
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
760
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
761
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
762
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
763
+ ) -> torch.FloatTensor:
764
+ lora_scale = (
765
+ cross_attention_kwargs.get("scale", 1.0)
766
+ if cross_attention_kwargs is not None
767
+ else 1.0
768
+ )
769
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
770
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
771
+ if self.training and self.gradient_checkpointing:
772
+
773
+ def create_custom_forward(module, return_dict=None):
774
+ def custom_forward(*inputs):
775
+ if return_dict is not None:
776
+ return module(*inputs, return_dict=return_dict)
777
+ else:
778
+ return module(*inputs)
779
+
780
+ return custom_forward
781
+
782
+ ckpt_kwargs: Dict[str, Any] = (
783
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
784
+ )
785
+ hidden_states = attn(
786
+ hidden_states,
787
+ encoder_hidden_states=encoder_hidden_states,
788
+ cross_attention_kwargs=cross_attention_kwargs,
789
+ attention_mask=attention_mask,
790
+ encoder_attention_mask=encoder_attention_mask,
791
+ return_dict=False,
792
+ self_attn_block_embs=self_attn_block_embs,
793
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
794
+ )[0]
795
+ hidden_states = torch.utils.checkpoint.checkpoint(
796
+ create_custom_forward(resnet),
797
+ hidden_states,
798
+ temb,
799
+ **ckpt_kwargs,
800
+ )
801
+ else:
802
+ hidden_states = attn(
803
+ hidden_states,
804
+ encoder_hidden_states=encoder_hidden_states,
805
+ cross_attention_kwargs=cross_attention_kwargs,
806
+ attention_mask=attention_mask,
807
+ encoder_attention_mask=encoder_attention_mask,
808
+ return_dict=False,
809
+ self_attn_block_embs=self_attn_block_embs,
810
+ )[0]
811
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
812
+
813
+ return hidden_states
814
+
815
+
816
+ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
817
+ def __init__(
818
+ self,
819
+ in_channels: int,
820
+ temb_channels: int,
821
+ dropout: float = 0.0,
822
+ num_layers: int = 1,
823
+ resnet_eps: float = 1e-6,
824
+ resnet_time_scale_shift: str = "default",
825
+ resnet_act_fn: str = "swish",
826
+ resnet_groups: int = 32,
827
+ resnet_pre_norm: bool = True,
828
+ attention_head_dim: int = 1,
829
+ output_scale_factor: float = 1.0,
830
+ cross_attention_dim: int = 1280,
831
+ skip_time_act: bool = False,
832
+ only_cross_attention: bool = False,
833
+ cross_attention_norm: Optional[str] = None,
834
+ ):
835
+ super().__init__()
836
+
837
+ self.has_cross_attention = True
838
+
839
+ self.attention_head_dim = attention_head_dim
840
+ resnet_groups = (
841
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
842
+ )
843
+
844
+ self.num_heads = in_channels // self.attention_head_dim
845
+
846
+ # there is always at least one resnet
847
+ resnets = [
848
+ ResnetBlock2D(
849
+ in_channels=in_channels,
850
+ out_channels=in_channels,
851
+ temb_channels=temb_channels,
852
+ eps=resnet_eps,
853
+ groups=resnet_groups,
854
+ dropout=dropout,
855
+ time_embedding_norm=resnet_time_scale_shift,
856
+ non_linearity=resnet_act_fn,
857
+ output_scale_factor=output_scale_factor,
858
+ pre_norm=resnet_pre_norm,
859
+ skip_time_act=skip_time_act,
860
+ )
861
+ ]
862
+ attentions = []
863
+
864
+ for _ in range(num_layers):
865
+ processor = (
866
+ AttnAddedKVProcessor2_0()
867
+ if hasattr(F, "scaled_dot_product_attention")
868
+ else AttnAddedKVProcessor()
869
+ )
870
+
871
+ attentions.append(
872
+ Attention(
873
+ query_dim=in_channels,
874
+ cross_attention_dim=in_channels,
875
+ heads=self.num_heads,
876
+ dim_head=self.attention_head_dim,
877
+ added_kv_proj_dim=cross_attention_dim,
878
+ norm_num_groups=resnet_groups,
879
+ bias=True,
880
+ upcast_softmax=True,
881
+ only_cross_attention=only_cross_attention,
882
+ cross_attention_norm=cross_attention_norm,
883
+ processor=processor,
884
+ )
885
+ )
886
+ resnets.append(
887
+ ResnetBlock2D(
888
+ in_channels=in_channels,
889
+ out_channels=in_channels,
890
+ temb_channels=temb_channels,
891
+ eps=resnet_eps,
892
+ groups=resnet_groups,
893
+ dropout=dropout,
894
+ time_embedding_norm=resnet_time_scale_shift,
895
+ non_linearity=resnet_act_fn,
896
+ output_scale_factor=output_scale_factor,
897
+ pre_norm=resnet_pre_norm,
898
+ skip_time_act=skip_time_act,
899
+ )
900
+ )
901
+
902
+ self.attentions = nn.ModuleList(attentions)
903
+ self.resnets = nn.ModuleList(resnets)
904
+
905
+ def forward(
906
+ self,
907
+ hidden_states: torch.FloatTensor,
908
+ temb: Optional[torch.FloatTensor] = None,
909
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
910
+ attention_mask: Optional[torch.FloatTensor] = None,
911
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
912
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
913
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
914
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
915
+ ) -> torch.FloatTensor:
916
+ cross_attention_kwargs = (
917
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
918
+ )
919
+ lora_scale = cross_attention_kwargs.get("scale", 1.0)
920
+
921
+ if attention_mask is None:
922
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
923
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
924
+ else:
925
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
926
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
927
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
928
+ # then we can simplify this whole if/else block to:
929
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
930
+ mask = attention_mask
931
+
932
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
933
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
934
+ # attn
935
+ hidden_states = attn(
936
+ hidden_states,
937
+ encoder_hidden_states=encoder_hidden_states,
938
+ attention_mask=mask,
939
+ **cross_attention_kwargs,
940
+ self_attn_block_embs=self_attn_block_embs,
941
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
942
+ )
943
+
944
+ # resnet
945
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
946
+
947
+ return hidden_states
948
+
949
+
950
+ class CrossAttnDownBlock2D(nn.Module):
951
+ print_idx = 0
952
+
953
+ def __init__(
954
+ self,
955
+ in_channels: int,
956
+ out_channels: int,
957
+ temb_channels: int,
958
+ dropout: float = 0.0,
959
+ num_layers: int = 1,
960
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
961
+ resnet_eps: float = 1e-6,
962
+ resnet_time_scale_shift: str = "default",
963
+ resnet_act_fn: str = "swish",
964
+ resnet_groups: int = 32,
965
+ resnet_pre_norm: bool = True,
966
+ num_attention_heads: int = 1,
967
+ cross_attention_dim: int = 1280,
968
+ output_scale_factor: float = 1.0,
969
+ downsample_padding: int = 1,
970
+ add_downsample: bool = True,
971
+ dual_cross_attention: bool = False,
972
+ use_linear_projection: bool = False,
973
+ only_cross_attention: bool = False,
974
+ upcast_attention: bool = False,
975
+ attention_type: str = "default",
976
+ ):
977
+ super().__init__()
978
+ resnets = []
979
+ attentions = []
980
+
981
+ self.has_cross_attention = True
982
+ self.num_attention_heads = num_attention_heads
983
+ if isinstance(transformer_layers_per_block, int):
984
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
985
+
986
+ for i in range(num_layers):
987
+ in_channels = in_channels if i == 0 else out_channels
988
+ resnets.append(
989
+ ResnetBlock2D(
990
+ in_channels=in_channels,
991
+ out_channels=out_channels,
992
+ temb_channels=temb_channels,
993
+ eps=resnet_eps,
994
+ groups=resnet_groups,
995
+ dropout=dropout,
996
+ time_embedding_norm=resnet_time_scale_shift,
997
+ non_linearity=resnet_act_fn,
998
+ output_scale_factor=output_scale_factor,
999
+ pre_norm=resnet_pre_norm,
1000
+ )
1001
+ )
1002
+ if not dual_cross_attention:
1003
+ attentions.append(
1004
+ Transformer2DModel(
1005
+ num_attention_heads,
1006
+ out_channels // num_attention_heads,
1007
+ in_channels=out_channels,
1008
+ num_layers=transformer_layers_per_block[i],
1009
+ cross_attention_dim=cross_attention_dim,
1010
+ norm_num_groups=resnet_groups,
1011
+ use_linear_projection=use_linear_projection,
1012
+ only_cross_attention=only_cross_attention,
1013
+ upcast_attention=upcast_attention,
1014
+ attention_type=attention_type,
1015
+ )
1016
+ )
1017
+ else:
1018
+ attentions.append(
1019
+ DualTransformer2DModel(
1020
+ num_attention_heads,
1021
+ out_channels // num_attention_heads,
1022
+ in_channels=out_channels,
1023
+ num_layers=1,
1024
+ cross_attention_dim=cross_attention_dim,
1025
+ norm_num_groups=resnet_groups,
1026
+ )
1027
+ )
1028
+ self.attentions = nn.ModuleList(attentions)
1029
+ self.resnets = nn.ModuleList(resnets)
1030
+
1031
+ if add_downsample:
1032
+ self.downsamplers = nn.ModuleList(
1033
+ [
1034
+ Downsample2D(
1035
+ out_channels,
1036
+ use_conv=True,
1037
+ out_channels=out_channels,
1038
+ padding=downsample_padding,
1039
+ name="op",
1040
+ )
1041
+ ]
1042
+ )
1043
+ else:
1044
+ self.downsamplers = None
1045
+
1046
+ self.gradient_checkpointing = False
1047
+
1048
+ def forward(
1049
+ self,
1050
+ hidden_states: torch.FloatTensor,
1051
+ temb: Optional[torch.FloatTensor] = None,
1052
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1053
+ attention_mask: Optional[torch.FloatTensor] = None,
1054
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1055
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1056
+ additional_residuals: Optional[torch.FloatTensor] = None,
1057
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
1058
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
1059
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1060
+ output_states = ()
1061
+
1062
+ lora_scale = (
1063
+ cross_attention_kwargs.get("scale", 1.0)
1064
+ if cross_attention_kwargs is not None
1065
+ else 1.0
1066
+ )
1067
+
1068
+ blocks = list(zip(self.resnets, self.attentions))
1069
+
1070
+ for i, (resnet, attn) in enumerate(blocks):
1071
+ if self.training and self.gradient_checkpointing:
1072
+
1073
+ def create_custom_forward(module, return_dict=None):
1074
+ def custom_forward(*inputs):
1075
+ if return_dict is not None:
1076
+ return module(*inputs, return_dict=return_dict)
1077
+ else:
1078
+ return module(*inputs)
1079
+
1080
+ return custom_forward
1081
+
1082
+ ckpt_kwargs: Dict[str, Any] = (
1083
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1084
+ )
1085
+ hidden_states = torch.utils.checkpoint.checkpoint(
1086
+ create_custom_forward(resnet),
1087
+ hidden_states,
1088
+ temb,
1089
+ **ckpt_kwargs,
1090
+ )
1091
+ if self.print_idx == 0:
1092
+ logger.debug(f"unet3d after resnet {hidden_states.mean()}")
1093
+
1094
+ hidden_states = attn(
1095
+ hidden_states,
1096
+ encoder_hidden_states=encoder_hidden_states,
1097
+ cross_attention_kwargs=cross_attention_kwargs,
1098
+ attention_mask=attention_mask,
1099
+ encoder_attention_mask=encoder_attention_mask,
1100
+ return_dict=False,
1101
+ self_attn_block_embs=self_attn_block_embs,
1102
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
1103
+ )[0]
1104
+ else:
1105
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1106
+ if self.print_idx == 0:
1107
+ logger.debug(f"unet3d after resnet {hidden_states.mean()}")
1108
+ hidden_states = attn(
1109
+ hidden_states,
1110
+ encoder_hidden_states=encoder_hidden_states,
1111
+ cross_attention_kwargs=cross_attention_kwargs,
1112
+ attention_mask=attention_mask,
1113
+ encoder_attention_mask=encoder_attention_mask,
1114
+ return_dict=False,
1115
+ self_attn_block_embs=self_attn_block_embs,
1116
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
1117
+ )[0]
1118
+
1119
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
1120
+ if i == len(blocks) - 1 and additional_residuals is not None:
1121
+ hidden_states = hidden_states + additional_residuals
1122
+
1123
+ output_states = output_states + (hidden_states,)
1124
+
1125
+ if self.downsamplers is not None:
1126
+ for downsampler in self.downsamplers:
1127
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
1128
+
1129
+ output_states = output_states + (hidden_states,)
1130
+
1131
+ self.print_idx += 1
1132
+ return hidden_states, output_states
1133
+
1134
+
1135
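A sketch of how `CrossAttnDownBlock2D` above is typically consumed (illustrative only): it returns the downsampled features together with the per-layer residual states that the up blocks later reuse as skip connections. Channel counts and the text-embedding shape are assumptions.

    import torch
    from musev.models.unet_2d_blocks import CrossAttnDownBlock2D  # assumed import path

    down_block = CrossAttnDownBlock2D(
        in_channels=320,
        out_channels=640,
        temb_channels=1280,
        num_layers=2,
        num_attention_heads=8,
        cross_attention_dim=768,
    )
    sample = torch.randn(1, 320, 32, 32)
    temb = torch.randn(1, 1280)
    text_emb = torch.randn(1, 77, 768)

    hidden, skips = down_block(sample, temb, encoder_hidden_states=text_emb)
    # `skips` holds one tensor per resnet/attention pair plus the downsampled
    # output (3 tensors here); the up blocks consume them in reverse order.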
+ class DownBlock2D(nn.Module):
1136
+ def __init__(
1137
+ self,
1138
+ in_channels: int,
1139
+ out_channels: int,
1140
+ temb_channels: int,
1141
+ dropout: float = 0.0,
1142
+ num_layers: int = 1,
1143
+ resnet_eps: float = 1e-6,
1144
+ resnet_time_scale_shift: str = "default",
1145
+ resnet_act_fn: str = "swish",
1146
+ resnet_groups: int = 32,
1147
+ resnet_pre_norm: bool = True,
1148
+ output_scale_factor: float = 1.0,
1149
+ add_downsample: bool = True,
1150
+ downsample_padding: int = 1,
1151
+ ):
1152
+ super().__init__()
1153
+ resnets = []
1154
+
1155
+ for i in range(num_layers):
1156
+ in_channels = in_channels if i == 0 else out_channels
1157
+ resnets.append(
1158
+ ResnetBlock2D(
1159
+ in_channels=in_channels,
1160
+ out_channels=out_channels,
1161
+ temb_channels=temb_channels,
1162
+ eps=resnet_eps,
1163
+ groups=resnet_groups,
1164
+ dropout=dropout,
1165
+ time_embedding_norm=resnet_time_scale_shift,
1166
+ non_linearity=resnet_act_fn,
1167
+ output_scale_factor=output_scale_factor,
1168
+ pre_norm=resnet_pre_norm,
1169
+ )
1170
+ )
1171
+
1172
+ self.resnets = nn.ModuleList(resnets)
1173
+
1174
+ if add_downsample:
1175
+ self.downsamplers = nn.ModuleList(
1176
+ [
1177
+ Downsample2D(
1178
+ out_channels,
1179
+ use_conv=True,
1180
+ out_channels=out_channels,
1181
+ padding=downsample_padding,
1182
+ name="op",
1183
+ )
1184
+ ]
1185
+ )
1186
+ else:
1187
+ self.downsamplers = None
1188
+
1189
+ self.gradient_checkpointing = False
1190
+
1191
+ def forward(
1192
+ self,
1193
+ hidden_states: torch.FloatTensor,
1194
+ temb: Optional[torch.FloatTensor] = None,
1195
+ scale: float = 1.0,
1196
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
1197
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
1198
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1199
+ output_states = ()
1200
+
1201
+ for resnet in self.resnets:
1202
+ if self.training and self.gradient_checkpointing:
1203
+
1204
+ def create_custom_forward(module):
1205
+ def custom_forward(*inputs):
1206
+ return module(*inputs)
1207
+
1208
+ return custom_forward
1209
+
1210
+ if is_torch_version(">=", "1.11.0"):
1211
+ hidden_states = torch.utils.checkpoint.checkpoint(
1212
+ create_custom_forward(resnet),
1213
+ hidden_states,
1214
+ temb,
1215
+ use_reentrant=False,
1216
+ )
1217
+ else:
1218
+ hidden_states = torch.utils.checkpoint.checkpoint(
1219
+ create_custom_forward(resnet), hidden_states, temb
1220
+ )
1221
+ else:
1222
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1223
+
1224
+ output_states = output_states + (hidden_states,)
1225
+
1226
+ if self.downsamplers is not None:
1227
+ for downsampler in self.downsamplers:
1228
+ hidden_states = downsampler(hidden_states, scale=scale)
1229
+
1230
+ output_states = output_states + (hidden_states,)
1231
+
1232
+ return hidden_states, output_states
1233
+
1234
+
1235
+ class CrossAttnUpBlock2D(nn.Module):
1236
+ def __init__(
1237
+ self,
1238
+ in_channels: int,
1239
+ out_channels: int,
1240
+ prev_output_channel: int,
1241
+ temb_channels: int,
1242
+ resolution_idx: Optional[int] = None,
1243
+ dropout: float = 0.0,
1244
+ num_layers: int = 1,
1245
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
1246
+ resnet_eps: float = 1e-6,
1247
+ resnet_time_scale_shift: str = "default",
1248
+ resnet_act_fn: str = "swish",
1249
+ resnet_groups: int = 32,
1250
+ resnet_pre_norm: bool = True,
1251
+ num_attention_heads: int = 1,
1252
+ cross_attention_dim: int = 1280,
1253
+ output_scale_factor: float = 1.0,
1254
+ add_upsample: bool = True,
1255
+ dual_cross_attention: bool = False,
1256
+ use_linear_projection: bool = False,
1257
+ only_cross_attention: bool = False,
1258
+ upcast_attention: bool = False,
1259
+ attention_type: str = "default",
1260
+ ):
1261
+ super().__init__()
1262
+ resnets = []
1263
+ attentions = []
1264
+
1265
+ self.has_cross_attention = True
1266
+ self.num_attention_heads = num_attention_heads
1267
+
1268
+ if isinstance(transformer_layers_per_block, int):
1269
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
1270
+
1271
+ for i in range(num_layers):
1272
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1273
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1274
+
1275
+ resnets.append(
1276
+ ResnetBlock2D(
1277
+ in_channels=resnet_in_channels + res_skip_channels,
1278
+ out_channels=out_channels,
1279
+ temb_channels=temb_channels,
1280
+ eps=resnet_eps,
1281
+ groups=resnet_groups,
1282
+ dropout=dropout,
1283
+ time_embedding_norm=resnet_time_scale_shift,
1284
+ non_linearity=resnet_act_fn,
1285
+ output_scale_factor=output_scale_factor,
1286
+ pre_norm=resnet_pre_norm,
1287
+ )
1288
+ )
1289
+ if not dual_cross_attention:
1290
+ attentions.append(
1291
+ Transformer2DModel(
1292
+ num_attention_heads,
1293
+ out_channels // num_attention_heads,
1294
+ in_channels=out_channels,
1295
+ num_layers=transformer_layers_per_block[i],
1296
+ cross_attention_dim=cross_attention_dim,
1297
+ norm_num_groups=resnet_groups,
1298
+ use_linear_projection=use_linear_projection,
1299
+ only_cross_attention=only_cross_attention,
1300
+ upcast_attention=upcast_attention,
1301
+ attention_type=attention_type,
1302
+ )
1303
+ )
1304
+ else:
1305
+ attentions.append(
1306
+ DualTransformer2DModel(
1307
+ num_attention_heads,
1308
+ out_channels // num_attention_heads,
1309
+ in_channels=out_channels,
1310
+ num_layers=1,
1311
+ cross_attention_dim=cross_attention_dim,
1312
+ norm_num_groups=resnet_groups,
1313
+ )
1314
+ )
1315
+ self.attentions = nn.ModuleList(attentions)
1316
+ self.resnets = nn.ModuleList(resnets)
1317
+
1318
+ if add_upsample:
1319
+ self.upsamplers = nn.ModuleList(
1320
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1321
+ )
1322
+ else:
1323
+ self.upsamplers = None
1324
+
1325
+ self.gradient_checkpointing = False
1326
+ self.resolution_idx = resolution_idx
1327
+
1328
+ def forward(
1329
+ self,
1330
+ hidden_states: torch.FloatTensor,
1331
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1332
+ temb: Optional[torch.FloatTensor] = None,
1333
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1334
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1335
+ upsample_size: Optional[int] = None,
1336
+ attention_mask: Optional[torch.FloatTensor] = None,
1337
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1338
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
1339
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
1340
+ ) -> torch.FloatTensor:
1341
+ lora_scale = (
1342
+ cross_attention_kwargs.get("scale", 1.0)
1343
+ if cross_attention_kwargs is not None
1344
+ else 1.0
1345
+ )
1346
+ is_freeu_enabled = (
1347
+ getattr(self, "s1", None)
1348
+ and getattr(self, "s2", None)
1349
+ and getattr(self, "b1", None)
1350
+ and getattr(self, "b2", None)
1351
+ )
1352
+
1353
+ for resnet, attn in zip(self.resnets, self.attentions):
1354
+ # pop res hidden states
1355
+ res_hidden_states = res_hidden_states_tuple[-1]
1356
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1357
+
1358
+ # FreeU: Only operate on the first two stages
1359
+ if is_freeu_enabled:
1360
+ hidden_states, res_hidden_states = apply_freeu(
1361
+ self.resolution_idx,
1362
+ hidden_states,
1363
+ res_hidden_states,
1364
+ s1=self.s1,
1365
+ s2=self.s2,
1366
+ b1=self.b1,
1367
+ b2=self.b2,
1368
+ )
1369
+
1370
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1371
+
1372
+ if self.training and self.gradient_checkpointing:
1373
+
1374
+ def create_custom_forward(module, return_dict=None):
1375
+ def custom_forward(*inputs):
1376
+ if return_dict is not None:
1377
+ return module(*inputs, return_dict=return_dict)
1378
+ else:
1379
+ return module(*inputs)
1380
+
1381
+ return custom_forward
1382
+
1383
+ ckpt_kwargs: Dict[str, Any] = (
1384
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1385
+ )
1386
+ hidden_states = torch.utils.checkpoint.checkpoint(
1387
+ create_custom_forward(resnet),
1388
+ hidden_states,
1389
+ temb,
1390
+ **ckpt_kwargs,
1391
+ )
1392
+ hidden_states = attn(
1393
+ hidden_states,
1394
+ encoder_hidden_states=encoder_hidden_states,
1395
+ cross_attention_kwargs=cross_attention_kwargs,
1396
+ attention_mask=attention_mask,
1397
+ encoder_attention_mask=encoder_attention_mask,
1398
+ return_dict=False,
1399
+ self_attn_block_embs=self_attn_block_embs,
1400
+ self_attn_block_embs_mode=self_attn_block_embs_mode,
1401
+ )[0]
1402
+ else:
1403
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1404
+ hidden_states = attn(
1405
+ hidden_states,
1406
+ encoder_hidden_states=encoder_hidden_states,
1407
+ cross_attention_kwargs=cross_attention_kwargs,
1408
+ attention_mask=attention_mask,
1409
+ encoder_attention_mask=encoder_attention_mask,
1410
+ return_dict=False,
1411
+ self_attn_block_embs=self_attn_block_embs,
1412
+ )[0]
1413
+
1414
+ if self.upsamplers is not None:
1415
+ for upsampler in self.upsamplers:
1416
+ hidden_states = upsampler(
1417
+ hidden_states, upsample_size, scale=lora_scale
1418
+ )
1419
+
1420
+ return hidden_states
1421
+
1422
+
1423
+ class UpBlock2D(nn.Module):
1424
+ def __init__(
1425
+ self,
1426
+ in_channels: int,
1427
+ prev_output_channel: int,
1428
+ out_channels: int,
1429
+ temb_channels: int,
1430
+ resolution_idx: Optional[int] = None,
1431
+ dropout: float = 0.0,
1432
+ num_layers: int = 1,
1433
+ resnet_eps: float = 1e-6,
1434
+ resnet_time_scale_shift: str = "default",
1435
+ resnet_act_fn: str = "swish",
1436
+ resnet_groups: int = 32,
1437
+ resnet_pre_norm: bool = True,
1438
+ output_scale_factor: float = 1.0,
1439
+ add_upsample: bool = True,
1440
+ ):
1441
+ super().__init__()
1442
+ resnets = []
1443
+
1444
+ for i in range(num_layers):
1445
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1446
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1447
+
1448
+ resnets.append(
1449
+ ResnetBlock2D(
1450
+ in_channels=resnet_in_channels + res_skip_channels,
1451
+ out_channels=out_channels,
1452
+ temb_channels=temb_channels,
1453
+ eps=resnet_eps,
1454
+ groups=resnet_groups,
1455
+ dropout=dropout,
1456
+ time_embedding_norm=resnet_time_scale_shift,
1457
+ non_linearity=resnet_act_fn,
1458
+ output_scale_factor=output_scale_factor,
1459
+ pre_norm=resnet_pre_norm,
1460
+ )
1461
+ )
1462
+
1463
+ self.resnets = nn.ModuleList(resnets)
1464
+
1465
+ if add_upsample:
1466
+ self.upsamplers = nn.ModuleList(
1467
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1468
+ )
1469
+ else:
1470
+ self.upsamplers = None
1471
+
1472
+ self.gradient_checkpointing = False
1473
+ self.resolution_idx = resolution_idx
1474
+
1475
+ def forward(
1476
+ self,
1477
+ hidden_states: torch.FloatTensor,
1478
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1479
+ temb: Optional[torch.FloatTensor] = None,
1480
+ upsample_size: Optional[int] = None,
1481
+ scale: float = 1.0,
1482
+ self_attn_block_embs: Optional[List[torch.Tensor]] = None,
1483
+ self_attn_block_embs_mode: Literal["read", "write"] = "write",
1484
+ ) -> torch.FloatTensor:
1485
+ is_freeu_enabled = (
1486
+ getattr(self, "s1", None)
1487
+ and getattr(self, "s2", None)
1488
+ and getattr(self, "b1", None)
1489
+ and getattr(self, "b2", None)
1490
+ )
1491
+
1492
+ for resnet in self.resnets:
1493
+ # pop res hidden states
1494
+ res_hidden_states = res_hidden_states_tuple[-1]
1495
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1496
+
1497
+ # FreeU: Only operate on the first two stages
1498
+ if is_freeu_enabled:
1499
+ hidden_states, res_hidden_states = apply_freeu(
1500
+ self.resolution_idx,
1501
+ hidden_states,
1502
+ res_hidden_states,
1503
+ s1=self.s1,
1504
+ s2=self.s2,
1505
+ b1=self.b1,
1506
+ b2=self.b2,
1507
+ )
1508
+
1509
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1510
+
1511
+ if self.training and self.gradient_checkpointing:
1512
+
1513
+ def create_custom_forward(module):
1514
+ def custom_forward(*inputs):
1515
+ return module(*inputs)
1516
+
1517
+ return custom_forward
1518
+
1519
+ if is_torch_version(">=", "1.11.0"):
1520
+ hidden_states = torch.utils.checkpoint.checkpoint(
1521
+ create_custom_forward(resnet),
1522
+ hidden_states,
1523
+ temb,
1524
+ use_reentrant=False,
1525
+ )
1526
+ else:
1527
+ hidden_states = torch.utils.checkpoint.checkpoint(
1528
+ create_custom_forward(resnet), hidden_states, temb
1529
+ )
1530
+ else:
1531
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1532
+
1533
+ if self.upsamplers is not None:
1534
+ for upsampler in self.upsamplers:
1535
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1536
+
1537
+ return hidden_states
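To close the loop, a sketch of how `UpBlock2D` above consumes the residual tuple gathered on the way down (illustrative only): skips are popped from the end of the tuple, so the most recently produced residual is concatenated first. All sizes are assumptions chosen to satisfy the channel math of the two resnets.

    import torch
    from musev.models.unet_2d_blocks import UpBlock2D  # assumed import path

    up_block = UpBlock2D(
        in_channels=320,          # channels of the matching down-block input
        prev_output_channel=640,  # channels arriving from the block below
        out_channels=640,
        temb_channels=1280,
        num_layers=2,
    )
    hidden = torch.randn(1, 640, 16, 16)
    temb = torch.randn(1, 1280)
    # The last element is used first: the 640-ch skip feeds the first resnet
    # (640 + 640 in), then the 320-ch skip feeds the second (640 + 320 in).
    skips = (torch.randn(1, 320, 16, 16), torch.randn(1, 640, 16, 16))
    out = up_block(hidden, skips, temb)  # upsampled to (1, 640, 32, 32)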