diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..fe9379779360d4743a74477ee5384edf354f203b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,22 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+__assets__/boat_light.gif filter=lfs diff=lfs merge=lfs -text
+__assets__/cat_light.gif filter=lfs diff=lfs merge=lfs -text
+__assets__/man_light.gif filter=lfs diff=lfs merge=lfs -text
+__assets__/pipeline.png filter=lfs diff=lfs merge=lfs -text
+__assets__/title.png filter=lfs diff=lfs merge=lfs -text
+__assets__/water_light.gif filter=lfs diff=lfs merge=lfs -text
+input_animatediff/bear.mp4 filter=lfs diff=lfs merge=lfs -text
+input_animatediff/bloom.mp4 filter=lfs diff=lfs merge=lfs -text
+input_animatediff/boat.mp4 filter=lfs diff=lfs merge=lfs -text
+input_animatediff/car.mp4 filter=lfs diff=lfs merge=lfs -text
+input_animatediff/cat.mp4 filter=lfs diff=lfs merge=lfs -text
+input_animatediff/cat2.mp4 filter=lfs diff=lfs merge=lfs -text
+input_animatediff/coin.mp4 filter=lfs diff=lfs merge=lfs -text
+input_animatediff/cow.mp4 filter=lfs diff=lfs merge=lfs -text
+input_animatediff/flowers.mp4 filter=lfs diff=lfs merge=lfs -text
+input_animatediff/fox.mp4 filter=lfs diff=lfs merge=lfs -text
+input_animatediff/girl2.mp4 filter=lfs diff=lfs merge=lfs -text
+input_animatediff/man.mp4 filter=lfs diff=lfs merge=lfs -text
+input_animatediff/woman.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/ORIGINAL_README.md b/ORIGINAL_README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b8d980563a343c9410a9f0d05c11913cd5ea12ce
--- /dev/null
+++ b/ORIGINAL_README.md
@@ -0,0 +1,152 @@
+
+

+
+
+---
+### ⭐️ **Our team's works:** [[**MotionClone**](https://bujiazi.github.io/motionclone.github.io/)] [[**BroadWay**](https://bujiazi.github.io/BroadWay.github.io/)]
+
+## Light-A-Video: Training-free Video Relighting via Progressive Light Fusion
+This repository is the official implementation of Light-A-Video. It is a **training-free framework** that enables
+zero-shot illumination control of any given video sequence or foreground sequence.
+
+Click for the full abstract of Light-A-Video
+
+> Recent advancements in image relighting models, driven by large-scale datasets and pre-trained diffusion models,
+have enabled the imposition of consistent lighting.
+However, video relighting still lags, primarily due to the excessive training costs and the scarcity of diverse, high-quality video relighting datasets.
+A simple application of image relighting models on a frame-by-frame basis leads to several issues:
+lighting source inconsistency and relighted appearance inconsistency, resulting in flickers in the generated videos.
+In this work, we propose Light-A-Video, a training-free approach to achieve temporally smooth video relighting.
+Adapted from image relighting models, Light-A-Video introduces two key techniques to enhance lighting consistency.
+First, we design a Consistent Light Attention (CLA) module, which enhances cross-frame interactions within the self-attention layers
+to stabilize the generation of the background lighting source. Second, leveraging the physical principle of light transport independence,
+we apply linear blending between the source video’s appearance and the relighted appearance, using a Progressive Light Fusion (PLF) strategy to ensure smooth temporal transitions in illumination.
+Experiments show that Light-A-Video improves the temporal consistency of relighted video
+while maintaining the image quality, ensuring coherent lighting transitions across frames.
+
+
+**[Light-A-Video: Training-free Video Relighting via Progressive Light Fusion]()**
+
+[Yujie Zhou*](https://github.com/YujieOuO/),
+[Jiazi Bu*](https://github.com/Bujiazi/),
+[Pengyang Ling*](https://github.com/LPengYang/),
+[Pan Zhang†](https://panzhang0212.github.io/),
+[Tong Wu](https://wutong16.github.io/),
+[Qidong Huang](https://shikiw.github.io/),
+[Jinsong Li](https://li-jinsong.github.io/),
+[Xiaoyi Dong](https://scholar.google.com/citations?user=FscToE0AAAAJ&hl=en/),
+[Yuhang Zang](https://yuhangzang.github.io/),
+[Yuhang Cao](https://scholar.google.com/citations?hl=zh-CN&user=sJkqsqkAAAAJ),
+[Anyi Rao](https://anyirao.com/),
+[Jiaqi Wang](https://myownskyw7.github.io/),
+[Li Niu†](https://www.ustcnewly.com/)
+(*Equal Contribution)(†Corresponding Author)
+
+[](https://arxiv.org/abs/2502.08590)
+[](https://bujiazi.github.io/light-a-video.github.io/)
+
+## 📜 News
+
+**[2025/2/11]** Code is available now!
+
+**[2025/2/10]** The paper and project page are released!
+
+## 🏗️ Todo
+- [ ] Release a gradio demo.
+
+- [ ] Release Light-A-Video code with CogVideoX-2B pipeline.
+
+## 📚 Gallery
+We show more results in the [Project Page](https://bujiazi.github.io/light-a-video.github.io/).
+
+
+
+ ..., red and blue neon light |
+ ..., sunset over sea |
+
+
+  |
+  |
+
+
+ ..., sunlight through the blinds |
+ ..., in the forest, magic golden lit |
+
+
+  |
+  |
+
+
+
+
+## 🚀 Method Overview
+
+
+

+
+
+Light-A-Video leverages the capabilities of image relighting models and VDM motion priors to achieve temporally consistent video relighting.
+It integrates the **Consistent Light Attention** to stabilize lighting source generation and employs the **Progressive Light Fusion** strategy
+for smooth appearance transitions.
+
+## 🔧 Installations
+
+### Setup repository and conda environment
+
+```bash
+git clone https://github.com/bcmi/Light-A-Video.git
+cd Light-A-Video
+
+conda create -n lav python=3.10
+conda activate lav
+
+pip install -r requirements.txt
+```
+
+## 🔑 Pretrained Model Preparations
+- IC-Light: [Huggingface](https://huggingface.co/lllyasviel/ic-light)
+- SD RealisticVision: [Huggingface](https://huggingface.co/stablediffusionapi/realistic-vision-v51)
+- Animatediff Motion-Adapter-V-1.5.3: [Huggingface](https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3)
+
+Model downloading is automatic.
+
+## 🎈 Quick Start
+
+### Perform video relighting with customized illumination control
+```bash
+# relight
+python lav_relight.py --config "configs/relight/car.yaml"
+```
+### Perform foreground sequences relighting with background generation
+A script based on [SAM2](https://github.com/facebookresearch/sam2) is provided to extract foreground sequences from videos.
+```bash
+# extract foreground sequence
+python sam2.py --video_name car --x 255 --y 255
+
+# inpaint and relight
+python lav_paint.py --config "configs/relight_inpaint/car.yaml"
+```
+
+## 📎 Citation
+
+If you find our work helpful for your research, please consider giving a star ⭐ and citation 📝
+```bibtex
+@article{zhou2025light,
+ title={Light-A-Video: Training-free Video Relighting via Progressive Light Fusion},
+ author={Zhou, Yujie and Bu, Jiazi and Ling, Pengyang and Zhang, Pan and Wu, Tong and Huang, Qidong and Li, Jinsong and Dong, Xiaoyi and Zang, Yuhang and Cao, Yuhang and others},
+ journal={arXiv preprint arXiv:2502.08590},
+ year={2025}
+}
+```
+
+## 📣 Disclaimer
+
+This is the official code of Light-A-Video.
+All the copyrights of the demo images and audio belong to community users.
+Feel free to contact us if you would like us to remove them.
+
+## 💞 Acknowledgements
+The code is built upon the below repositories, we thank all the contributors for open-sourcing.
+* [IC-Light](https://github.com/lllyasviel/IC-Light)
+* [AnimateDiff](https://github.com/guoyww/AnimateDiff)
+* [CogVideoX](https://github.com/THUDM/CogVideo)
diff --git a/__assets__/boat_light.gif b/__assets__/boat_light.gif
new file mode 100644
index 0000000000000000000000000000000000000000..f32fc6555ddc31d1727edd7a4eb6c0218cf66703
--- /dev/null
+++ b/__assets__/boat_light.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9171d365d5a9643f7509364a71b566736dbe1e9fd758f11adc89617f042899d3
+size 2154893
diff --git a/__assets__/cat_light.gif b/__assets__/cat_light.gif
new file mode 100644
index 0000000000000000000000000000000000000000..c03e3e04c3aa2100a56dfc03a03b8aa8d4403e7e
--- /dev/null
+++ b/__assets__/cat_light.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b33dd1a609e0b37dd41e63a8889d57342b60fee4fe733221f112fb8fd17632a3
+size 4337712
diff --git a/__assets__/man_light.gif b/__assets__/man_light.gif
new file mode 100644
index 0000000000000000000000000000000000000000..887770a11d01c1e3a67626bc3903404b53c71556
--- /dev/null
+++ b/__assets__/man_light.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c207d125e8a82817d56ff86e28533ed22fdeb972bfee3160b0c19323e6a832a3
+size 3737125
diff --git a/__assets__/pipeline.png b/__assets__/pipeline.png
new file mode 100644
index 0000000000000000000000000000000000000000..01c59e67902ab1f2123063160f78a71a96bf6a80
--- /dev/null
+++ b/__assets__/pipeline.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5249b8ab938cc283f91b9a76411e8ef7ea48c48de63cb547a5ef1178d6092832
+size 1879354
diff --git a/__assets__/title.png b/__assets__/title.png
new file mode 100644
index 0000000000000000000000000000000000000000..9b298f9061f2d29e1877301a3cc6008680d8fb8f
--- /dev/null
+++ b/__assets__/title.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e631437a296c9c45b6418f0903c0905ddb5aa29f510d2326dfbbcaefec992b42
+size 138292
diff --git a/__assets__/water_light.gif b/__assets__/water_light.gif
new file mode 100644
index 0000000000000000000000000000000000000000..3e0a37b72106fce625b39fefe7fea6e14b86990b
--- /dev/null
+++ b/__assets__/water_light.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2d4c881aff47e88b611b9c78d17b89b5ccb29819de6ac7b860934bee1e5afe0e
+size 2942895
diff --git a/configs/relight/bear.yaml b/configs/relight/bear.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1d20842d925c101f367c0a7ef08e6f41959e3956
--- /dev/null
+++ b/configs/relight/bear.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "a bear walking on the rock, nature lighting, key light"
+video_path: "input_animatediff/bear.mp4"
+bg_source: "TOP" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/boat.yaml b/configs/relight/boat.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..01d87304cb0d8c0c27a7ed05cf945e4b348823f8
--- /dev/null
+++ b/configs/relight/boat.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "a boat floating on the sea, sunset"
+video_path: "input_animatediff/boat.mp4"
+bg_source: "TOP" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/car.yaml b/configs/relight/car.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a557aaf386e0a669a89562b6dbc87bc8db0c853b
--- /dev/null
+++ b/configs/relight/car.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "a car driving on the street, neon light"
+video_path: "input_animatediff/car.mp4"
+bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 2060
\ No newline at end of file
diff --git a/configs/relight/cat.yaml b/configs/relight/cat.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0572a99135d4a72983704e87e7e4bad42dcb5913
--- /dev/null
+++ b/configs/relight/cat.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "a cat, red and blue neon light"
+video_path: "input_animatediff/cat.mp4"
+bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/cow.yaml b/configs/relight/cow.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ffa82db8dba6c3fd7dc2246118fbc800e8b3856b
--- /dev/null
+++ b/configs/relight/cow.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "a cow drinking water in the river, sunset"
+video_path: "input_animatediff/cow.mp4"
+bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/flowers.yaml b/configs/relight/flowers.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1b3ca42905a031494cd31dfdb21213e8568354cc
--- /dev/null
+++ b/configs/relight/flowers.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality, unclear, blurry"
+relight_prompt: "A basket of flowers, sunshine, hard light"
+video_path: "input_animatediff/flowers.mp4"
+bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/fox.yaml b/configs/relight/fox.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8604aa4c0325439b257c86730198a026c52ed281
--- /dev/null
+++ b/configs/relight/fox.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "a fox, sunlight filtering through trees, dappled light"
+video_path: "input_animatediff/fox.mp4"
+bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/girl.yaml b/configs/relight/girl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..98429cafec8116192ce8260e5e80361a5e831694
--- /dev/null
+++ b/configs/relight/girl.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "a girl, magic lit, sci-fi RGB glowing, key lighting"
+video_path: "input_animatediff/girl.mp4"
+bg_source: "BOTTOM" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/girl2.yaml b/configs/relight/girl2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c4c1924724e0ff1a123f1a8d1fea93dc842a7115
--- /dev/null
+++ b/configs/relight/girl2.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "an anime girl, neon light"
+video_path: "input_animatediff/girl2.mp4"
+bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/juice.yaml b/configs/relight/juice.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..17b4b4ee9b84fd8073abb2bf123fc366f15ef26f
--- /dev/null
+++ b/configs/relight/juice.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "Pour juice into a glass, magic golden lit"
+video_path: "input_animatediff/juice.mp4"
+bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/man2.yaml b/configs/relight/man2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7ab13dd7751e9ff34b7b1a5ccaa998dbe3ac7ad2
--- /dev/null
+++ b/configs/relight/man2.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "handsome man with glasses, shadow from window, sunshine"
+video_path: "input_animatediff/man2.mp4"
+bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/man4.yaml b/configs/relight/man4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ff954182e9cdcfe185fec3fd8e3ca6863b0aeda7
--- /dev/null
+++ b/configs/relight/man4.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "handsome man with glasses, sunlight through the blinds"
+video_path: "input_animatediff/man4.mp4"
+bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/plane.yaml b/configs/relight/plane.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2e07d911177ae80dc3859606f6d6c26edf7fd62d
--- /dev/null
+++ b/configs/relight/plane.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "a plane on the runway, bottom neon light"
+video_path: "input_animatediff/plane.mp4"
+bg_source: "BOTTOM" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/toy.yaml b/configs/relight/toy.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..94d3eb6bfb2841ca8dd2590f5bcd022bcf211177
--- /dev/null
+++ b/configs/relight/toy.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "a maneki-neko toy, cozy bedroom illumination"
+video_path: "input_animatediff/toy.mp4"
+bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight/woman.yaml b/configs/relight/woman.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1729738d84286d3faec9380b170383fee7b08951
--- /dev/null
+++ b/configs/relight/woman.yaml
@@ -0,0 +1,13 @@
+n_prompt: "bad quality, worse quality"
+relight_prompt: "a woman with curly hair, natural lighting, warm atmosphere"
+video_path: "input_animatediff/woman.mp4"
+bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 25
+text_guide_scale: 2
+seed: 42
\ No newline at end of file
diff --git a/configs/relight_inpaint/bloom.yaml b/configs/relight_inpaint/bloom.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d4d82edc61a4bb5ba41a226fac5e351e0edc484d
--- /dev/null
+++ b/configs/relight_inpaint/bloom.yaml
@@ -0,0 +1,15 @@
+n_prompt: "bad quality, worse quality"
+inpaint_prompt: "a red flower blooming in the river"
+relight_prompt: "a red flower blooming in the river, nature lighting"
+
+video_path: "input_animatediff/bloom.mp4"
+bg_source: "TOP" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.4
+gamma: 0.5
+num_step: 50
+text_guide_scale: 4
+seed: 8776
\ No newline at end of file
diff --git a/configs/relight_inpaint/camera.yaml b/configs/relight_inpaint/camera.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ca1b430cd098b4f18b7d482b0d400bd485051f7f
--- /dev/null
+++ b/configs/relight_inpaint/camera.yaml
@@ -0,0 +1,15 @@
+n_prompt: "bad quality, worse quality"
+inpaint_prompt: "A tiny camera on a tray, cyberpunk"
+relight_prompt: "A tiny camera on a tray, cyberpunk, neon light"
+
+video_path: "input_animatediff/camera.mp4"
+bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.4
+gamma: 0.5
+num_step: 50
+text_guide_scale: 3
+seed: 1333
\ No newline at end of file
diff --git a/configs/relight_inpaint/car.yaml b/configs/relight_inpaint/car.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bae0400427aae1db1d9f85f827473e2a142018af
--- /dev/null
+++ b/configs/relight_inpaint/car.yaml
@@ -0,0 +1,15 @@
+n_prompt: "bad quality, worse quality"
+inpaint_prompt: "a car driving on the street"
+relight_prompt: "a car driving on the street, neon light"
+
+video_path: "input_animatediff/car.mp4"
+bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 50
+text_guide_scale: 2
+seed: 6561
\ No newline at end of file
diff --git a/configs/relight_inpaint/car_2.yaml b/configs/relight_inpaint/car_2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bcd5cbcb9413b8e03e06804b2582fa1e656c5241
--- /dev/null
+++ b/configs/relight_inpaint/car_2.yaml
@@ -0,0 +1,15 @@
+n_prompt: "bad quality, worse quality"
+inpaint_prompt: "a car driving on the beach, sunset over sea"
+relight_prompt: "a car driving on the beach, sunset over sea, left light, shadow"
+
+video_path: "input_animatediff/car.mp4"
+bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 50
+text_guide_scale: 2
+seed: 2409
\ No newline at end of file
diff --git a/configs/relight_inpaint/cat2.yaml b/configs/relight_inpaint/cat2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b66d9b86a489a6efad6d00ac21815b64b13eb73f
--- /dev/null
+++ b/configs/relight_inpaint/cat2.yaml
@@ -0,0 +1,15 @@
+n_prompt: "bad quality, worse quality"
+inpaint_prompt: "A cat walking on a runway, red and blue neon lights on both sides"
+relight_prompt: "A cat walking on a runway, red and blue neon lights on both sides, key light"
+
+video_path: "input_animatediff/cat2.mp4"
+bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 50
+text_guide_scale: 5
+seed: 2949
\ No newline at end of file
diff --git a/configs/relight_inpaint/coin.yaml b/configs/relight_inpaint/coin.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fbe01a9985bf957f86e7dc25c7268107adf7bf00
--- /dev/null
+++ b/configs/relight_inpaint/coin.yaml
@@ -0,0 +1,15 @@
+n_prompt: "bad quality, worse quality"
+inpaint_prompt: "A coin on the desk"
+relight_prompt: "A coin on the desk, natural lighting"
+
+video_path: "input_animatediff/coin.mp4"
+bg_source: "TOP" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.4
+gamma: 0.5
+num_step: 80
+text_guide_scale: 2
+seed: 4013
\ No newline at end of file
diff --git a/configs/relight_inpaint/dog2.yaml b/configs/relight_inpaint/dog2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ea731d1ba68e3b81b88b06cecd5c6a62494f5797
--- /dev/null
+++ b/configs/relight_inpaint/dog2.yaml
@@ -0,0 +1,15 @@
+n_prompt: "bad quality, worse quality"
+inpaint_prompt: "a dog in the room, sunshine from window"
+relight_prompt: "a dog in the room, sunshine from window"
+
+video_path: "input_animatediff/dog2.mp4"
+bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.4
+gamma: 0.5
+num_step: 50
+text_guide_scale: 2
+seed: 4550
\ No newline at end of file
diff --git a/configs/relight_inpaint/man3.yaml b/configs/relight_inpaint/man3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5d685403f389d0f0a5e051e2381cfa6a5283df65
--- /dev/null
+++ b/configs/relight_inpaint/man3.yaml
@@ -0,0 +1,15 @@
+n_prompt: "bad quality, worse quality"
+inpaint_prompt: "A man in the classroom"
+relight_prompt: "A man in the classroom, sunshine from the window"
+
+video_path: "input_animatediff/man3.mp4"
+bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 50
+text_guide_scale: 3
+seed: 3931
\ No newline at end of file
diff --git a/configs/relight_inpaint/man3_2.yaml b/configs/relight_inpaint/man3_2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..043125a62add97022215d2f86713f94aa319526b
--- /dev/null
+++ b/configs/relight_inpaint/man3_2.yaml
@@ -0,0 +1,15 @@
+n_prompt: "bad quality, worse quality"
+inpaint_prompt: "A man in a bar, left yellow and right purple neon lights"
+relight_prompt: "A man in a bar, left yellow and right purple neon lights, hard light"
+
+video_path: "input_animatediff/man3.mp4"
+bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 50
+text_guide_scale: 4
+seed: 9528
\ No newline at end of file
diff --git a/configs/relight_inpaint/water.yaml b/configs/relight_inpaint/water.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fdfc8cf72f0dc392c1e275081dda6a7d1b3eeb60
--- /dev/null
+++ b/configs/relight_inpaint/water.yaml
@@ -0,0 +1,15 @@
+n_prompt: "bad quality, worse quality"
+inpaint_prompt: "a glass of water, in the forest, magic golden lit"
+relight_prompt: "a glass of water, in the forest, magic golden lit, key light"
+
+video_path: "input_animatediff/water.mp4"
+bg_source: "TOP" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.4
+gamma: 0.5
+num_step: 50
+text_guide_scale: 4
+seed: 796
\ No newline at end of file
diff --git a/configs/relight_inpaint/wolf2.yaml b/configs/relight_inpaint/wolf2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..34b64b8e727fa19a2f7b39d87a5fbfa68c74ed49
--- /dev/null
+++ b/configs/relight_inpaint/wolf2.yaml
@@ -0,0 +1,15 @@
+n_prompt: "bad quality, worse quality"
+inpaint_prompt: "a wolf stands in an alley, detailed face, neon, Wong Kar-wai, warm"
+relight_prompt: "a wolf stands in an alley, detailed face, neon, Wong Kar-wai, warm, right light"
+
+video_path: "input_animatediff/wolf2.mp4"
+bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP
+save_path: "output"
+
+width: 512
+height: 512
+strength: 0.5
+gamma: 0.5
+num_step: 50
+text_guide_scale: 5
+seed: 2172
\ No newline at end of file
diff --git a/input_animatediff/bear.mp4 b/input_animatediff/bear.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..e0ff52199e2a5c0d2e73ed19ee37f4fc712a7195
--- /dev/null
+++ b/input_animatediff/bear.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bdcdc1757a085d6e0c43c60d650be0eb75dc65a9a0fe178a394d7b0f0131c20
+size 252475
diff --git a/input_animatediff/bloom.mp4 b/input_animatediff/bloom.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..28047d179277e26a5d44484ec871dabe16075b4a
--- /dev/null
+++ b/input_animatediff/bloom.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:227ae36b83a4ce33c844d0ee818e3c70bf0667c963886e8c79e8b21241e552c8
+size 110051
diff --git a/input_animatediff/boat.mp4 b/input_animatediff/boat.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..64a42126e9e2ede549e16de35e417d0dc6ae773c
--- /dev/null
+++ b/input_animatediff/boat.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:523849774e74df87699e450baa99fd3b7edf21176d9d3464756c3a302e4fc6a2
+size 163870
diff --git a/input_animatediff/camera.mp4 b/input_animatediff/camera.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b3801312189e3a67381563489b85545171f5d786
Binary files /dev/null and b/input_animatediff/camera.mp4 differ
diff --git a/input_animatediff/car.mp4 b/input_animatediff/car.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..5dffc40af5375ce66a6ad0ee52bcf1b33737d3cc
--- /dev/null
+++ b/input_animatediff/car.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5bd723eb892d1f0005016c933c6d0eb2f46a6e77d802e14934ff917c4f49db09
+size 221663
diff --git a/input_animatediff/cat.mp4 b/input_animatediff/cat.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..72c5f0d7b9b4d2bac2c66d22e00b7085816e7207
--- /dev/null
+++ b/input_animatediff/cat.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7984b62f1a070f779b11e0ae7b113be5bd170f0f4511b75acf2c9afaf204d6f
+size 136264
diff --git a/input_animatediff/cat2.mp4 b/input_animatediff/cat2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2d3aca6535626d3f59f468368292596f7857cd27
--- /dev/null
+++ b/input_animatediff/cat2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89ab4a1945d78fafa3c47447a0e969f9b88fd0f244da9089ef84a022d28e6db2
+size 169476
diff --git a/input_animatediff/coin.mp4 b/input_animatediff/coin.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a192ce924b2bb7fde15f31cc0e2783cb40cbdada
--- /dev/null
+++ b/input_animatediff/coin.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1271b6970cdc4cc9159bdf6b46655551d0494750b30f542a263d9219e601f6a8
+size 102631
diff --git a/input_animatediff/cow.mp4 b/input_animatediff/cow.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8c6ccd548e840ac94c153cb5e6f109a851aedc9a
--- /dev/null
+++ b/input_animatediff/cow.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6a55901e04b6fba73b4d29a9df4e40adc85e9a161d91ddb5b0ea9438b7d016c
+size 455415
diff --git a/input_animatediff/dog2.mp4 b/input_animatediff/dog2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..1b0544fea797141e51d388c0b0db451662b516fd
Binary files /dev/null and b/input_animatediff/dog2.mp4 differ
diff --git a/input_animatediff/flowers.mp4 b/input_animatediff/flowers.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3a4f14e6d516ff9f3f0f74b27b608cff02247b74
--- /dev/null
+++ b/input_animatediff/flowers.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dafaf1deb210e383bb9b8ae3ef81f0273a1f46e1b87008508617c6ec176fbca6
+size 163157
diff --git a/input_animatediff/fox.mp4 b/input_animatediff/fox.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..fe1de541be2dc4fa1ee87a2b67c834bfd1769bc5
--- /dev/null
+++ b/input_animatediff/fox.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:20b488cbb41e4b84e63b3a4337d034b281c7dff5c7fa2d4cacd5ae4de6241beb
+size 141847
diff --git a/input_animatediff/girl.mp4 b/input_animatediff/girl.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..6ba42eb9d79ec21c0c85ff86542e38cb7f914250
Binary files /dev/null and b/input_animatediff/girl.mp4 differ
diff --git a/input_animatediff/girl2.mp4 b/input_animatediff/girl2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..f8ed2f0415be2711879afad118f586565b611cd1
--- /dev/null
+++ b/input_animatediff/girl2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:699c09aba9345d8563ea1448c16f75ce237bda6d828d74169ecfc4adfc0d6528
+size 293061
diff --git a/input_animatediff/juice.mp4 b/input_animatediff/juice.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..9aafd14a6675599dcb2d166842ca6e59d051d5c4
Binary files /dev/null and b/input_animatediff/juice.mp4 differ
diff --git a/input_animatediff/man.mp4 b/input_animatediff/man.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0cb2faefa8f2387851d6ac59b3e9c9e32557255a
--- /dev/null
+++ b/input_animatediff/man.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c2d01c602bd34660fcc281f752ea4ccd9d100d0fa293e25093918d3f4fecb49
+size 141549
diff --git a/input_animatediff/man2.mp4 b/input_animatediff/man2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..6981b42b65fd87ca94dd6cc1b3a40a2beddc283e
Binary files /dev/null and b/input_animatediff/man2.mp4 differ
diff --git a/input_animatediff/man3.mp4 b/input_animatediff/man3.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..7a78924885a68800e1405dc48270fcc555a7fb9c
Binary files /dev/null and b/input_animatediff/man3.mp4 differ
diff --git a/input_animatediff/man4.mp4 b/input_animatediff/man4.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a8c0064db39a3422a77c9130ac382a2d3b9c965d
Binary files /dev/null and b/input_animatediff/man4.mp4 differ
diff --git a/input_animatediff/plane.mp4 b/input_animatediff/plane.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2917408aa00730d58ffd6a81e580ff132bf2e770
Binary files /dev/null and b/input_animatediff/plane.mp4 differ
diff --git a/input_animatediff/toy.mp4 b/input_animatediff/toy.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0327fec7ed68592c1f663aa9fbbc20e22cdd605a
Binary files /dev/null and b/input_animatediff/toy.mp4 differ
diff --git a/input_animatediff/wolf2.mp4 b/input_animatediff/wolf2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a32d5f3250600a9aefe2f58f6b11c41cf0cf6aab
Binary files /dev/null and b/input_animatediff/wolf2.mp4 differ
diff --git a/input_animatediff/woman.mp4 b/input_animatediff/woman.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b7e875db4e4ebec451a374147bd9ab7d9d071fe7
--- /dev/null
+++ b/input_animatediff/woman.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28b449fc38c1625354c116ba18603f815210c23d9f0862313aa75eba5dfa95af
+size 163666
diff --git a/lav_paint.py b/lav_paint.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef6429d3e13e71364237d93311ec8ea7b5e9ff5e
--- /dev/null
+++ b/lav_paint.py
@@ -0,0 +1,256 @@
+import os
+import torch
+import imageio
+import argparse
+from types import MethodType
+import safetensors.torch as sf
+import torch.nn.functional as F
+from omegaconf import OmegaConf
+from transformers import CLIPTextModel, CLIPTokenizer
+from diffusers import MotionAdapter, EulerAncestralDiscreteScheduler, AutoencoderKL
+from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
+from diffusers.models.attention_processor import AttnProcessor2_0
+from torch.hub import download_url_to_file
+
+from src.ic_light import BGSource
+from src.ic_light import Relighter
+from src.animatediff_inpaint_pipe import AnimateDiffVideoToVideoPipeline
+from src.ic_light_pipe import StableDiffusionImg2ImgPipeline
+from utils.tools import read_video, read_mask,set_all_seed, get_fg_video
+
+def main(args):
+
+ config = OmegaConf.load(args.config)
+ device = torch.device('cuda')
+ adopted_dtype = torch.float16
+ set_all_seed(42)
+
+ ## vdm model
+ adapter = MotionAdapter.from_pretrained(args.motion_adapter_model)
+
+ ## pipeline
+ pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(args.sd_model, motion_adapter=adapter)
+ eul_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
+ args.sd_model,
+ subfolder="scheduler",
+ beta_schedule="linear",
+ )
+
+ pipe.scheduler = eul_scheduler
+ pipe.enable_vae_slicing()
+ pipe = pipe.to(device=device, dtype=adopted_dtype)
+ pipe.vae.requires_grad_(False)
+ pipe.unet.requires_grad_(False)
+
+ ## ic-light model
+ tokenizer = CLIPTokenizer.from_pretrained(args.sd_model, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(args.sd_model, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(args.sd_model, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.sd_model, subfolder="unet")
+ with torch.no_grad():
+ new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
+ new_conv_in.weight.zero_() #torch.Size([320, 8, 3, 3])
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+ new_conv_in.bias = unet.conv_in.bias
+ unet.conv_in = new_conv_in
+ unet_original_forward = unet.forward
+
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
+
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
+ new_sample = torch.cat([sample, c_concat], dim=1)
+ kwargs['cross_attention_kwargs'] = {}
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
+ unet.forward = hooked_unet_forward
+
+ ## ic-light model loader
+ if not os.path.exists(args.ic_light_model):
+ download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors',
+ dst=args.ic_light_model)
+ sd_offset = sf.load_file(args.ic_light_model)
+ sd_origin = unet.state_dict()
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
+ unet.load_state_dict(sd_merged, strict=True)
+ del sd_offset, sd_origin, sd_merged
+ text_encoder = text_encoder.to(device=device, dtype=adopted_dtype)
+ vae = vae.to(device=device, dtype=adopted_dtype)
+ unet = unet.to(device=device, dtype=adopted_dtype)
+ unet.set_attn_processor(AttnProcessor2_0())
+ vae.set_attn_processor(AttnProcessor2_0())
+
+ # Consistent light attention
+ @torch.inference_mode()
+ def custom_forward_CLA(self,
+ hidden_states,
+ gamma=config.get("gamma", 0.5),
+ encoder_hidden_states=None,
+ attention_mask=None,
+ cross_attention_kwargs=None
+ ):
+
+ batch_size, sequence_length, channel = hidden_states.shape
+
+ residual = hidden_states
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = self.to_q(hidden_states)
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // self.heads
+ query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+ shape = query.shape
+
+ # addition key and value
+ mean_key = key.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True)
+ mean_value = value.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True)
+ mean_key = mean_key.expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3])
+ mean_value = mean_value.expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3])
+ add_hidden_state = F.scaled_dot_product_attention(query, mean_key, mean_value, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+ # mix
+ hidden_states = (1-gamma)*hidden_states + gamma*add_hidden_state
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = self.to_out[0](hidden_states)
+ hidden_states = self.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if self.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / self.rescale_output_factor
+ return hidden_states
+
+ ### attention
+ @torch.inference_mode()
+ def prep_unet_self_attention(unet):
+ for name, module in unet.named_modules():
+ module_name = type(module).__name__
+
+ name_split_list = name.split(".")
+ cond_1 = name_split_list[0] in "up_blocks"
+ cond_2 = name_split_list[-1] in ('attn1')
+
+ if "Attention" in module_name and cond_1 and cond_2:
+ cond_3 = name_split_list[1]
+ if cond_3 not in "3":
+ module.forward = MethodType(custom_forward_CLA, module)
+
+ return unet
+
+ ## consistency light attention
+ unet = prep_unet_self_attention(unet)
+
+ ## ic-light-scheduler
+ ic_light_scheduler = DPMSolverMultistepScheduler(
+ num_train_timesteps=1000,
+ beta_start=0.00085,
+ beta_end=0.012,
+ algorithm_type="sde-dpmsolver++",
+ use_karras_sigmas=True,
+ steps_offset=1
+ )
+ ic_light_pipe = StableDiffusionImg2ImgPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=ic_light_scheduler,
+ safety_checker=None,
+ requires_safety_checker=False,
+ feature_extractor=None,
+ image_encoder=None
+ )
+ ic_light_pipe = ic_light_pipe.to(device)
+
+ ############################# params ######################################
+ strength = config.get("strength", 0.5)
+ num_step = config.get("num_step", 50)
+ text_guide_scale = config.get("text_guide_scale", 4)
+ seed = config.get("seed")
+ image_width = config.get("width", 512)
+ image_height = config.get("height", 512)
+ n_prompt = config.get("n_prompt", "")
+ inpaint_prompt = config.get("inpaint_prompt", "")
+ relight_prompt = config.get("relight_prompt", "")
+ video_path = config.get("video_path", "")
+ bg_source = BGSource[config.get("bg_source")]
+ save_path = config.get("save_path")
+
+ ############################## infer #####################################
+ generator = torch.manual_seed(seed)
+ video_name = os.path.basename(video_path)
+ video_list, video_name = read_video(video_path, image_width, image_height)
+ mask_folder = os.path.join("masks_animatediff", video_name.split('.')[-2])
+ mask_list = read_mask(mask_folder)
+
+ print("################## begin ##################")
+ ## get foreground video
+ fg_video_tensor = get_fg_video(video_list, mask_list, device, adopted_dtype) ## torch.Size([16, 3, 512, 512])
+
+ with torch.no_grad():
+ relighter = Relighter(
+ pipeline=ic_light_pipe,
+ relight_prompt=relight_prompt,
+ bg_source=bg_source,
+ generator=generator,
+ )
+ vdm_init_latent = relighter(fg_video_tensor)
+
+ ## infer
+ num_inference_steps = num_step
+ output = pipe(
+ ic_light_pipe=ic_light_pipe,
+ relight_prompt=relight_prompt,
+ bg_source=bg_source,
+ mask=mask_list,
+ vdm_init_latent=vdm_init_latent,
+ video=video_list,
+ prompt=inpaint_prompt,
+ strength=strength,
+ negative_prompt=n_prompt,
+ guidance_scale=text_guide_scale,
+ num_inference_steps=num_inference_steps,
+ height=image_height,
+ width=image_width,
+ generator=generator,
+ )
+
+ frames = output.frames[0]
+ results_path = f"{save_path}/inpaint_{video_name}"
+ imageio.mimwrite(results_path, frames, fps=8)
+ print(f"relight with bg generation! prompt:{relight_prompt}, light:{bg_source.value}, save in {results_path}.")
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--sd_model", type=str, default="stablediffusionapi/realistic-vision-v51")
+ parser.add_argument("--motion_adapter_model", type=str, default="guoyww/animatediff-motion-adapter-v1-5-3")
+ parser.add_argument("--ic_light_model", type=str, default="./models/iclight_sd15_fc.safetensors")
+
+ parser.add_argument("--config", type=str, default="configs/relight_inpaint/car.yaml", help="the config file for each sample.")
+
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/lav_relight.py b/lav_relight.py
new file mode 100644
index 0000000000000000000000000000000000000000..803c2d117b3b4b898b093961083d95ca249c0e58
--- /dev/null
+++ b/lav_relight.py
@@ -0,0 +1,240 @@
+import os
+import torch
+import imageio
+import argparse
+from types import MethodType
+import safetensors.torch as sf
+import torch.nn.functional as F
+from omegaconf import OmegaConf
+from transformers import CLIPTextModel, CLIPTokenizer
+from diffusers import MotionAdapter, EulerAncestralDiscreteScheduler, AutoencoderKL
+from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
+from diffusers.models.attention_processor import AttnProcessor2_0
+from torch.hub import download_url_to_file
+
+from src.ic_light import BGSource
+from src.animatediff_pipe import AnimateDiffVideoToVideoPipeline
+from src.ic_light_pipe import StableDiffusionImg2ImgPipeline
+from utils.tools import read_video, set_all_seed
+
+def main(args):
+
+ config = OmegaConf.load(args.config)
+ device = torch.device('cuda')
+ adopted_dtype = torch.float16
+ set_all_seed(42)
+
+ ## vdm model
+ adapter = MotionAdapter.from_pretrained(args.motion_adapter_model)
+
+ ## pipeline
+ pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(args.sd_model, motion_adapter=adapter)
+ eul_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
+ args.sd_model,
+ subfolder="scheduler",
+ beta_schedule="linear",
+ )
+
+ pipe.scheduler = eul_scheduler
+ pipe.enable_vae_slicing()
+ pipe = pipe.to(device=device, dtype=adopted_dtype)
+ pipe.vae.requires_grad_(False)
+ pipe.unet.requires_grad_(False)
+
+ ## ic-light model
+ tokenizer = CLIPTokenizer.from_pretrained(args.sd_model, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(args.sd_model, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(args.sd_model, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.sd_model, subfolder="unet")
+ with torch.no_grad():
+ new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
+ new_conv_in.weight.zero_() #torch.Size([320, 8, 3, 3])
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+ new_conv_in.bias = unet.conv_in.bias
+ unet.conv_in = new_conv_in
+ unet_original_forward = unet.forward
+
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
+
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
+ new_sample = torch.cat([sample, c_concat], dim=1)
+ kwargs['cross_attention_kwargs'] = {}
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
+ unet.forward = hooked_unet_forward
+
+ ## ic-light model loader
+ if not os.path.exists(args.ic_light_model):
+ download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors',
+ dst=args.ic_light_model)
+
+ sd_offset = sf.load_file(args.ic_light_model)
+ sd_origin = unet.state_dict()
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
+ unet.load_state_dict(sd_merged, strict=True)
+ del sd_offset, sd_origin, sd_merged
+ text_encoder = text_encoder.to(device=device, dtype=adopted_dtype)
+ vae = vae.to(device=device, dtype=adopted_dtype)
+ unet = unet.to(device=device, dtype=adopted_dtype)
+ unet.set_attn_processor(AttnProcessor2_0())
+ vae.set_attn_processor(AttnProcessor2_0())
+
+ # Consistent light attention
+ @torch.inference_mode()
+ def custom_forward_CLA(self,
+ hidden_states,
+ gamma=config.get("gamma", 0.5),
+ encoder_hidden_states=None,
+ attention_mask=None,
+ cross_attention_kwargs=None
+ ):
+
+ batch_size, sequence_length, channel = hidden_states.shape
+
+ residual = hidden_states
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = self.to_q(hidden_states)
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // self.heads
+ query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+ shape = query.shape
+
+ # addition key and value
+ mean_key = key.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True)
+ mean_value = value.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True)
+ mean_key = mean_key.expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3])
+ mean_value = mean_value.expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3])
+ add_hidden_state = F.scaled_dot_product_attention(query, mean_key, mean_value, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+ # mix
+ hidden_states = (1-gamma)*hidden_states + gamma*add_hidden_state
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = self.to_out[0](hidden_states)
+ hidden_states = self.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if self.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / self.rescale_output_factor
+ return hidden_states
+
+ ### attention
+ @torch.inference_mode()
+ def prep_unet_self_attention(unet):
+ for name, module in unet.named_modules():
+ module_name = type(module).__name__
+
+ name_split_list = name.split(".")
+ cond_1 = name_split_list[0] in "up_blocks"
+ cond_2 = name_split_list[-1] in ('attn1')
+
+ if "Attention" in module_name and cond_1 and cond_2:
+ cond_3 = name_split_list[1]
+ if cond_3 not in "3":
+ module.forward = MethodType(custom_forward_CLA, module)
+
+ return unet
+
+ ## consistency light attention
+ unet = prep_unet_self_attention(unet)
+
+ ## ic-light-scheduler
+ ic_light_scheduler = DPMSolverMultistepScheduler(
+ num_train_timesteps=1000,
+ beta_start=0.00085,
+ beta_end=0.012,
+ algorithm_type="sde-dpmsolver++",
+ use_karras_sigmas=True,
+ steps_offset=1
+ )
+ ic_light_pipe = StableDiffusionImg2ImgPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=ic_light_scheduler,
+ safety_checker=None,
+ requires_safety_checker=False,
+ feature_extractor=None,
+ image_encoder=None
+ )
+ ic_light_pipe = ic_light_pipe.to(device)
+
+ ############################# params ######################################
+ strength = config.get("strength", 0.5)
+ num_step = config.get("num_step", 25)
+ text_guide_scale = config.get("text_guide_scale", 2)
+ seed = config.get("seed")
+ image_width = config.get("width", 512)
+ image_height = config.get("height", 512)
+ n_prompt = config.get("n_prompt", "")
+ relight_prompt = config.get("relight_prompt", "")
+ video_path = config.get("video_path", "")
+ bg_source = BGSource[config.get("bg_source")]
+ save_path = config.get("save_path")
+
+ ############################## infer #####################################
+ generator = torch.manual_seed(seed)
+ video_name = os.path.basename(video_path)
+ video_list, video_name = read_video(video_path, image_width, image_height)
+
+ print("################## begin ##################")
+ with torch.no_grad():
+ num_inference_steps = int(round(num_step / strength))
+
+ output = pipe(
+ ic_light_pipe=ic_light_pipe,
+ relight_prompt=relight_prompt,
+ bg_source=bg_source,
+ video=video_list,
+ prompt=relight_prompt,
+ strength=strength,
+ negative_prompt=n_prompt,
+ guidance_scale=text_guide_scale,
+ num_inference_steps=num_inference_steps,
+ height=image_height,
+ width=image_width,
+ generator=generator,
+ )
+
+ frames = output.frames[0]
+ results_path = f"{save_path}/relight_{video_name}"
+ imageio.mimwrite(results_path, frames, fps=8)
+ print(f"relight with bg generation! prompt:{relight_prompt}, light:{bg_source.value}, save in {results_path}.")
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--sd_model", type=str, default="stablediffusionapi/realistic-vision-v51")
+ parser.add_argument("--motion_adapter_model", type=str, default="guoyww/animatediff-motion-adapter-v1-5-3")
+ parser.add_argument("--ic_light_model", type=str, default="./models/iclight_sd15_fc.safetensors")
+
+ parser.add_argument("--config", type=str, default="configs/relight/car.yaml", help="the config file for each sample.")
+
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/masks_animatediff/bloom/000.png b/masks_animatediff/bloom/000.png
new file mode 100644
index 0000000000000000000000000000000000000000..b532a540acf1bdb760c670cf468a49d9e6b8e89b
Binary files /dev/null and b/masks_animatediff/bloom/000.png differ
diff --git a/masks_animatediff/bloom/001.png b/masks_animatediff/bloom/001.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b4a7d8846b644b77475ca49c613d02f4825361c
Binary files /dev/null and b/masks_animatediff/bloom/001.png differ
diff --git a/masks_animatediff/bloom/002.png b/masks_animatediff/bloom/002.png
new file mode 100644
index 0000000000000000000000000000000000000000..b1a8a59809e5d93d7a7d55569e1a0ea91af46456
Binary files /dev/null and b/masks_animatediff/bloom/002.png differ
diff --git a/masks_animatediff/bloom/003.png b/masks_animatediff/bloom/003.png
new file mode 100644
index 0000000000000000000000000000000000000000..b7b192c715eb4735090562f3c77a060ec9b694e7
Binary files /dev/null and b/masks_animatediff/bloom/003.png differ
diff --git a/masks_animatediff/bloom/004.png b/masks_animatediff/bloom/004.png
new file mode 100644
index 0000000000000000000000000000000000000000..3f73a6b25f831b2eeb579288859353197377b6e9
Binary files /dev/null and b/masks_animatediff/bloom/004.png differ
diff --git a/masks_animatediff/bloom/005.png b/masks_animatediff/bloom/005.png
new file mode 100644
index 0000000000000000000000000000000000000000..b1fee5c9983abad319086d7e1ec8c80f4db7da09
Binary files /dev/null and b/masks_animatediff/bloom/005.png differ
diff --git a/masks_animatediff/bloom/006.png b/masks_animatediff/bloom/006.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc9f5947337cde384ba9deaa5f490154dc84c8a5
Binary files /dev/null and b/masks_animatediff/bloom/006.png differ
diff --git a/masks_animatediff/bloom/007.png b/masks_animatediff/bloom/007.png
new file mode 100644
index 0000000000000000000000000000000000000000..b0eae407ce699dbba960a38aa9718cebcfa52c8a
Binary files /dev/null and b/masks_animatediff/bloom/007.png differ
diff --git a/masks_animatediff/bloom/008.png b/masks_animatediff/bloom/008.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7a7d85e8b898a675f1a70f21713219618c99e7f
Binary files /dev/null and b/masks_animatediff/bloom/008.png differ
diff --git a/masks_animatediff/bloom/009.png b/masks_animatediff/bloom/009.png
new file mode 100644
index 0000000000000000000000000000000000000000..bfee78eaae723da5a269514aea059f873fdcea29
Binary files /dev/null and b/masks_animatediff/bloom/009.png differ
diff --git a/masks_animatediff/bloom/010.png b/masks_animatediff/bloom/010.png
new file mode 100644
index 0000000000000000000000000000000000000000..12c438f8875d8af7d46957015f71266ca1f4fe22
Binary files /dev/null and b/masks_animatediff/bloom/010.png differ
diff --git a/masks_animatediff/bloom/011.png b/masks_animatediff/bloom/011.png
new file mode 100644
index 0000000000000000000000000000000000000000..7a4ad0b371220ec8e7caeafe01302065d85bd7bf
Binary files /dev/null and b/masks_animatediff/bloom/011.png differ
diff --git a/masks_animatediff/bloom/012.png b/masks_animatediff/bloom/012.png
new file mode 100644
index 0000000000000000000000000000000000000000..3dcadd1bf6c76777cc56d4b787cf52af23f7855f
Binary files /dev/null and b/masks_animatediff/bloom/012.png differ
diff --git a/masks_animatediff/bloom/013.png b/masks_animatediff/bloom/013.png
new file mode 100644
index 0000000000000000000000000000000000000000..dc017574c69d20a860a1516a78cf9d87ec77b9b7
Binary files /dev/null and b/masks_animatediff/bloom/013.png differ
diff --git a/masks_animatediff/bloom/014.png b/masks_animatediff/bloom/014.png
new file mode 100644
index 0000000000000000000000000000000000000000..db61ab1315932d2059b5ec1fa2df3d1f2ad43220
Binary files /dev/null and b/masks_animatediff/bloom/014.png differ
diff --git a/masks_animatediff/bloom/015.png b/masks_animatediff/bloom/015.png
new file mode 100644
index 0000000000000000000000000000000000000000..9ffa569e01bb1ec46fd8c287b3cbf38175cf99a6
Binary files /dev/null and b/masks_animatediff/bloom/015.png differ
diff --git a/masks_animatediff/camera/000.png b/masks_animatediff/camera/000.png
new file mode 100644
index 0000000000000000000000000000000000000000..38aec99437f815998200dcf598005d9e030afb33
Binary files /dev/null and b/masks_animatediff/camera/000.png differ
diff --git a/masks_animatediff/camera/001.png b/masks_animatediff/camera/001.png
new file mode 100644
index 0000000000000000000000000000000000000000..9fb3fba11f8ec9640dc579ccdb820eb1fe8caccc
Binary files /dev/null and b/masks_animatediff/camera/001.png differ
diff --git a/masks_animatediff/camera/002.png b/masks_animatediff/camera/002.png
new file mode 100644
index 0000000000000000000000000000000000000000..b9bf84d11f14d00b7c9eced74ac3f4a0a1cb0124
Binary files /dev/null and b/masks_animatediff/camera/002.png differ
diff --git a/masks_animatediff/camera/003.png b/masks_animatediff/camera/003.png
new file mode 100644
index 0000000000000000000000000000000000000000..a4f45f75eae3b5de2ae0c3346a3f81dbd3dbbb83
Binary files /dev/null and b/masks_animatediff/camera/003.png differ
diff --git a/masks_animatediff/camera/004.png b/masks_animatediff/camera/004.png
new file mode 100644
index 0000000000000000000000000000000000000000..8a52d8b474a3e7c73bc765122f5e5c1345247013
Binary files /dev/null and b/masks_animatediff/camera/004.png differ
diff --git a/masks_animatediff/camera/005.png b/masks_animatediff/camera/005.png
new file mode 100644
index 0000000000000000000000000000000000000000..81b9561c65d1d3c45aceebccabffb195ffe78686
Binary files /dev/null and b/masks_animatediff/camera/005.png differ
diff --git a/masks_animatediff/camera/006.png b/masks_animatediff/camera/006.png
new file mode 100644
index 0000000000000000000000000000000000000000..35443bd535614b149489720bf547d0bdc8ca3d3a
Binary files /dev/null and b/masks_animatediff/camera/006.png differ
diff --git a/masks_animatediff/camera/007.png b/masks_animatediff/camera/007.png
new file mode 100644
index 0000000000000000000000000000000000000000..4eb31dde1ecce35198c01ae65d4cc5ea76628c5b
Binary files /dev/null and b/masks_animatediff/camera/007.png differ
diff --git a/masks_animatediff/camera/008.png b/masks_animatediff/camera/008.png
new file mode 100644
index 0000000000000000000000000000000000000000..ad53dec1080520bdbd9a28ebd0bcd7f15923275a
Binary files /dev/null and b/masks_animatediff/camera/008.png differ
diff --git a/masks_animatediff/camera/009.png b/masks_animatediff/camera/009.png
new file mode 100644
index 0000000000000000000000000000000000000000..085265013a1496407824b31ddebaaecac857669f
Binary files /dev/null and b/masks_animatediff/camera/009.png differ
diff --git a/masks_animatediff/camera/010.png b/masks_animatediff/camera/010.png
new file mode 100644
index 0000000000000000000000000000000000000000..56762f4273eb9bdc2d70bee42e47d3bb1330b3ac
Binary files /dev/null and b/masks_animatediff/camera/010.png differ
diff --git a/masks_animatediff/camera/011.png b/masks_animatediff/camera/011.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce43e8942fecf608f39f2ae4faa5c06e65ca6d72
Binary files /dev/null and b/masks_animatediff/camera/011.png differ
diff --git a/masks_animatediff/camera/012.png b/masks_animatediff/camera/012.png
new file mode 100644
index 0000000000000000000000000000000000000000..8f3778c1536f69e19bced93cfaeeabf4b57b5b46
Binary files /dev/null and b/masks_animatediff/camera/012.png differ
diff --git a/masks_animatediff/camera/013.png b/masks_animatediff/camera/013.png
new file mode 100644
index 0000000000000000000000000000000000000000..c6cd18829766ffea1483ad721a973b84ae34a4bd
Binary files /dev/null and b/masks_animatediff/camera/013.png differ
diff --git a/masks_animatediff/camera/014.png b/masks_animatediff/camera/014.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce4fc1b2d1ce3713230a4292b097e26219ba6d68
Binary files /dev/null and b/masks_animatediff/camera/014.png differ
diff --git a/masks_animatediff/camera/015.png b/masks_animatediff/camera/015.png
new file mode 100644
index 0000000000000000000000000000000000000000..3d86b9bc5f98f2284a442f13ad14752e9426dc98
Binary files /dev/null and b/masks_animatediff/camera/015.png differ
diff --git a/masks_animatediff/car/000.png b/masks_animatediff/car/000.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b4109107fac5e59c588ea24b996044d5cafe15c
Binary files /dev/null and b/masks_animatediff/car/000.png differ
diff --git a/masks_animatediff/car/001.png b/masks_animatediff/car/001.png
new file mode 100644
index 0000000000000000000000000000000000000000..7e8aa98ce53aa32faabe59b389a5956444e9cfb3
Binary files /dev/null and b/masks_animatediff/car/001.png differ
diff --git a/masks_animatediff/car/002.png b/masks_animatediff/car/002.png
new file mode 100644
index 0000000000000000000000000000000000000000..65c1a187544f0e99ceda72fc95e17743373752e4
Binary files /dev/null and b/masks_animatediff/car/002.png differ
diff --git a/masks_animatediff/car/003.png b/masks_animatediff/car/003.png
new file mode 100644
index 0000000000000000000000000000000000000000..f6b33d8f106982823873e7ac9494ca326b637a72
Binary files /dev/null and b/masks_animatediff/car/003.png differ
diff --git a/masks_animatediff/car/004.png b/masks_animatediff/car/004.png
new file mode 100644
index 0000000000000000000000000000000000000000..aff2b03fdf90221416dd39b6a1893654f2e4aa15
Binary files /dev/null and b/masks_animatediff/car/004.png differ
diff --git a/masks_animatediff/car/005.png b/masks_animatediff/car/005.png
new file mode 100644
index 0000000000000000000000000000000000000000..9eab771a4560fd4eea16cdb8a8df4f601705e20d
Binary files /dev/null and b/masks_animatediff/car/005.png differ
diff --git a/masks_animatediff/car/006.png b/masks_animatediff/car/006.png
new file mode 100644
index 0000000000000000000000000000000000000000..825086558a68860f82d684b20dd8399e7c2d9228
Binary files /dev/null and b/masks_animatediff/car/006.png differ
diff --git a/masks_animatediff/car/007.png b/masks_animatediff/car/007.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b68003b619f65dead2ffc5cb547dcf4a19d2b42
Binary files /dev/null and b/masks_animatediff/car/007.png differ
diff --git a/masks_animatediff/car/008.png b/masks_animatediff/car/008.png
new file mode 100644
index 0000000000000000000000000000000000000000..612782d39c47ff125cb714dcbeff76d6ddcb8f48
Binary files /dev/null and b/masks_animatediff/car/008.png differ
diff --git a/masks_animatediff/car/009.png b/masks_animatediff/car/009.png
new file mode 100644
index 0000000000000000000000000000000000000000..468c522586c2e845064e5c646519d66638e90887
Binary files /dev/null and b/masks_animatediff/car/009.png differ
diff --git a/masks_animatediff/car/010.png b/masks_animatediff/car/010.png
new file mode 100644
index 0000000000000000000000000000000000000000..4bd1289ad42fa6bffd8d8ec175acaf5cd778daef
Binary files /dev/null and b/masks_animatediff/car/010.png differ
diff --git a/masks_animatediff/car/011.png b/masks_animatediff/car/011.png
new file mode 100644
index 0000000000000000000000000000000000000000..a6d492b75f602da182f5c18ecfcd4924f598ae49
Binary files /dev/null and b/masks_animatediff/car/011.png differ
diff --git a/masks_animatediff/car/012.png b/masks_animatediff/car/012.png
new file mode 100644
index 0000000000000000000000000000000000000000..c59681eafe43ecc8ea8fe63baaf6fdb1e55227fb
Binary files /dev/null and b/masks_animatediff/car/012.png differ
diff --git a/masks_animatediff/car/013.png b/masks_animatediff/car/013.png
new file mode 100644
index 0000000000000000000000000000000000000000..967e948d8162e94a1dae775280bc2cc668ffbbb2
Binary files /dev/null and b/masks_animatediff/car/013.png differ
diff --git a/masks_animatediff/car/014.png b/masks_animatediff/car/014.png
new file mode 100644
index 0000000000000000000000000000000000000000..419b68d66c5db3cca460ad83b7ff7cdf6f8713ce
Binary files /dev/null and b/masks_animatediff/car/014.png differ
diff --git a/masks_animatediff/car/015.png b/masks_animatediff/car/015.png
new file mode 100644
index 0000000000000000000000000000000000000000..d7ffbb5e6fab742274fe01f9a5a5a0965919f335
Binary files /dev/null and b/masks_animatediff/car/015.png differ
diff --git a/masks_animatediff/cat2/000.png b/masks_animatediff/cat2/000.png
new file mode 100644
index 0000000000000000000000000000000000000000..95d67ee3de758c44be2e7c81ae2c864814573d5d
Binary files /dev/null and b/masks_animatediff/cat2/000.png differ
diff --git a/masks_animatediff/cat2/001.png b/masks_animatediff/cat2/001.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d08db2100ba59f707beef49d555bd9df9dcdaa5
Binary files /dev/null and b/masks_animatediff/cat2/001.png differ
diff --git a/masks_animatediff/cat2/002.png b/masks_animatediff/cat2/002.png
new file mode 100644
index 0000000000000000000000000000000000000000..c70fd1ef39c8657f983b6a8dda35f59ca5dd23ea
Binary files /dev/null and b/masks_animatediff/cat2/002.png differ
diff --git a/masks_animatediff/cat2/003.png b/masks_animatediff/cat2/003.png
new file mode 100644
index 0000000000000000000000000000000000000000..abf2e201bd4c21a42242575ed88df226010258b8
Binary files /dev/null and b/masks_animatediff/cat2/003.png differ
diff --git a/masks_animatediff/cat2/004.png b/masks_animatediff/cat2/004.png
new file mode 100644
index 0000000000000000000000000000000000000000..060ce94dfe74040af6cc8abfe0b5de8ccdd9e1e6
Binary files /dev/null and b/masks_animatediff/cat2/004.png differ
diff --git a/masks_animatediff/cat2/005.png b/masks_animatediff/cat2/005.png
new file mode 100644
index 0000000000000000000000000000000000000000..63a424962b6e0b18dea69541a52405f4a8bfb2b6
Binary files /dev/null and b/masks_animatediff/cat2/005.png differ
diff --git a/masks_animatediff/cat2/006.png b/masks_animatediff/cat2/006.png
new file mode 100644
index 0000000000000000000000000000000000000000..f826395ca8c178371ff37877bf2afc503bcd2a03
Binary files /dev/null and b/masks_animatediff/cat2/006.png differ
diff --git a/masks_animatediff/cat2/007.png b/masks_animatediff/cat2/007.png
new file mode 100644
index 0000000000000000000000000000000000000000..d56b4983ac161030c513846a7abcf3a4d2280998
Binary files /dev/null and b/masks_animatediff/cat2/007.png differ
diff --git a/masks_animatediff/cat2/008.png b/masks_animatediff/cat2/008.png
new file mode 100644
index 0000000000000000000000000000000000000000..e58c787c128ec3154f02a05d20607d08295e89d4
Binary files /dev/null and b/masks_animatediff/cat2/008.png differ
diff --git a/masks_animatediff/cat2/009.png b/masks_animatediff/cat2/009.png
new file mode 100644
index 0000000000000000000000000000000000000000..36ddbdcbcc09900c489cc7d9067a421bc496bfb5
Binary files /dev/null and b/masks_animatediff/cat2/009.png differ
diff --git a/masks_animatediff/cat2/010.png b/masks_animatediff/cat2/010.png
new file mode 100644
index 0000000000000000000000000000000000000000..f3d2c891defdfc7f731b38473e4bbe64d78d82c4
Binary files /dev/null and b/masks_animatediff/cat2/010.png differ
diff --git a/masks_animatediff/cat2/011.png b/masks_animatediff/cat2/011.png
new file mode 100644
index 0000000000000000000000000000000000000000..1140fc7aae1b7d064dce99c888a50357d08a5dad
Binary files /dev/null and b/masks_animatediff/cat2/011.png differ
diff --git a/masks_animatediff/cat2/012.png b/masks_animatediff/cat2/012.png
new file mode 100644
index 0000000000000000000000000000000000000000..d5a521b13a2d4c62d37e1fcaa8a2ec12e5b7b98f
Binary files /dev/null and b/masks_animatediff/cat2/012.png differ
diff --git a/masks_animatediff/cat2/013.png b/masks_animatediff/cat2/013.png
new file mode 100644
index 0000000000000000000000000000000000000000..c63d8fc21806f95f1425e1d3f397f9e6d2502a1c
Binary files /dev/null and b/masks_animatediff/cat2/013.png differ
diff --git a/masks_animatediff/cat2/014.png b/masks_animatediff/cat2/014.png
new file mode 100644
index 0000000000000000000000000000000000000000..52c5dace779eaec9b8865dcdee991994aa123a9d
Binary files /dev/null and b/masks_animatediff/cat2/014.png differ
diff --git a/masks_animatediff/cat2/015.png b/masks_animatediff/cat2/015.png
new file mode 100644
index 0000000000000000000000000000000000000000..72194cc25469383d113a44da4b00ac55e039dc3a
Binary files /dev/null and b/masks_animatediff/cat2/015.png differ
diff --git a/masks_animatediff/coin/000.png b/masks_animatediff/coin/000.png
new file mode 100644
index 0000000000000000000000000000000000000000..69193093e545b32683727cdd1115d684bd6f077d
Binary files /dev/null and b/masks_animatediff/coin/000.png differ
diff --git a/masks_animatediff/coin/001.png b/masks_animatediff/coin/001.png
new file mode 100644
index 0000000000000000000000000000000000000000..db98ac6f82515156520ee76b750743e8cd851d84
Binary files /dev/null and b/masks_animatediff/coin/001.png differ
diff --git a/masks_animatediff/coin/002.png b/masks_animatediff/coin/002.png
new file mode 100644
index 0000000000000000000000000000000000000000..dc63849557e26ff444bd14cbf875db475a4c94b4
Binary files /dev/null and b/masks_animatediff/coin/002.png differ
diff --git a/masks_animatediff/coin/003.png b/masks_animatediff/coin/003.png
new file mode 100644
index 0000000000000000000000000000000000000000..79eee9e87fc0f8507a206b11f381d396d6fec0c7
Binary files /dev/null and b/masks_animatediff/coin/003.png differ
diff --git a/masks_animatediff/coin/004.png b/masks_animatediff/coin/004.png
new file mode 100644
index 0000000000000000000000000000000000000000..4444694a52c9df88952f01871613ab65e2da5a5d
Binary files /dev/null and b/masks_animatediff/coin/004.png differ
diff --git a/masks_animatediff/coin/005.png b/masks_animatediff/coin/005.png
new file mode 100644
index 0000000000000000000000000000000000000000..393c84d99de48cbac613a2bc179b6bb6f816d2ab
Binary files /dev/null and b/masks_animatediff/coin/005.png differ
diff --git a/masks_animatediff/coin/006.png b/masks_animatediff/coin/006.png
new file mode 100644
index 0000000000000000000000000000000000000000..6c70da00496c288c50db53d69ba21bbf6ebe3c54
Binary files /dev/null and b/masks_animatediff/coin/006.png differ
diff --git a/masks_animatediff/coin/007.png b/masks_animatediff/coin/007.png
new file mode 100644
index 0000000000000000000000000000000000000000..a7673521101e5727eac19edbca57f94734a40971
Binary files /dev/null and b/masks_animatediff/coin/007.png differ
diff --git a/masks_animatediff/coin/008.png b/masks_animatediff/coin/008.png
new file mode 100644
index 0000000000000000000000000000000000000000..58a645bc2654e2fe1559049b6e67b9cf36103d12
Binary files /dev/null and b/masks_animatediff/coin/008.png differ
diff --git a/masks_animatediff/coin/009.png b/masks_animatediff/coin/009.png
new file mode 100644
index 0000000000000000000000000000000000000000..e9b103353151a362fb100f07773eec467a9f2a03
Binary files /dev/null and b/masks_animatediff/coin/009.png differ
diff --git a/masks_animatediff/coin/010.png b/masks_animatediff/coin/010.png
new file mode 100644
index 0000000000000000000000000000000000000000..b261ad98f532a5f5abbf9f082afbbfdb3e25d700
Binary files /dev/null and b/masks_animatediff/coin/010.png differ
diff --git a/masks_animatediff/coin/011.png b/masks_animatediff/coin/011.png
new file mode 100644
index 0000000000000000000000000000000000000000..e4454820ed3122c452f18ecc0ca969806a56fcdf
Binary files /dev/null and b/masks_animatediff/coin/011.png differ
diff --git a/masks_animatediff/coin/012.png b/masks_animatediff/coin/012.png
new file mode 100644
index 0000000000000000000000000000000000000000..932bf53a5823bd79c91543115e3b42adcb8e5a32
Binary files /dev/null and b/masks_animatediff/coin/012.png differ
diff --git a/masks_animatediff/coin/013.png b/masks_animatediff/coin/013.png
new file mode 100644
index 0000000000000000000000000000000000000000..953c17a25da20ae149cda8b3721e257576858259
Binary files /dev/null and b/masks_animatediff/coin/013.png differ
diff --git a/masks_animatediff/coin/014.png b/masks_animatediff/coin/014.png
new file mode 100644
index 0000000000000000000000000000000000000000..08ec3a83c525997738c9136f3179c03d2566a5f6
Binary files /dev/null and b/masks_animatediff/coin/014.png differ
diff --git a/masks_animatediff/coin/015.png b/masks_animatediff/coin/015.png
new file mode 100644
index 0000000000000000000000000000000000000000..e1ab6f305cf9018b3046c24798fe015e5cbc734b
Binary files /dev/null and b/masks_animatediff/coin/015.png differ
diff --git a/masks_animatediff/dog2/000.png b/masks_animatediff/dog2/000.png
new file mode 100644
index 0000000000000000000000000000000000000000..116c06d120b79bea3d3015525f64c779acc25034
Binary files /dev/null and b/masks_animatediff/dog2/000.png differ
diff --git a/masks_animatediff/dog2/001.png b/masks_animatediff/dog2/001.png
new file mode 100644
index 0000000000000000000000000000000000000000..b9440c57be2c2342b9b8836c5c0310edd3963bc2
Binary files /dev/null and b/masks_animatediff/dog2/001.png differ
diff --git a/masks_animatediff/dog2/002.png b/masks_animatediff/dog2/002.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce4e075254260a833deb6c6992c85f555a6adecc
Binary files /dev/null and b/masks_animatediff/dog2/002.png differ
diff --git a/masks_animatediff/dog2/003.png b/masks_animatediff/dog2/003.png
new file mode 100644
index 0000000000000000000000000000000000000000..dafb4bab6d0ce6a2ded722a92ba494b302624a56
Binary files /dev/null and b/masks_animatediff/dog2/003.png differ
diff --git a/masks_animatediff/dog2/004.png b/masks_animatediff/dog2/004.png
new file mode 100644
index 0000000000000000000000000000000000000000..6300b543533646332c0e083070200f9661f46b85
Binary files /dev/null and b/masks_animatediff/dog2/004.png differ
diff --git a/masks_animatediff/dog2/005.png b/masks_animatediff/dog2/005.png
new file mode 100644
index 0000000000000000000000000000000000000000..1535d20aaa43bba2cddfcb68f72005d0091ba8ae
Binary files /dev/null and b/masks_animatediff/dog2/005.png differ
diff --git a/masks_animatediff/dog2/006.png b/masks_animatediff/dog2/006.png
new file mode 100644
index 0000000000000000000000000000000000000000..327c727ee4078a2041090bb5b7b55dd23e7d1aad
Binary files /dev/null and b/masks_animatediff/dog2/006.png differ
diff --git a/masks_animatediff/dog2/007.png b/masks_animatediff/dog2/007.png
new file mode 100644
index 0000000000000000000000000000000000000000..ebeeb56b12f539c06f63a0bba991ddde4fc4ae5a
Binary files /dev/null and b/masks_animatediff/dog2/007.png differ
diff --git a/masks_animatediff/dog2/008.png b/masks_animatediff/dog2/008.png
new file mode 100644
index 0000000000000000000000000000000000000000..fcc5aa674ee5b47b160ca071ac4ff3030ad52c50
Binary files /dev/null and b/masks_animatediff/dog2/008.png differ
diff --git a/masks_animatediff/dog2/009.png b/masks_animatediff/dog2/009.png
new file mode 100644
index 0000000000000000000000000000000000000000..50182d3bf840ee707f8daee1f3ebf7cf080efe40
Binary files /dev/null and b/masks_animatediff/dog2/009.png differ
diff --git a/masks_animatediff/dog2/010.png b/masks_animatediff/dog2/010.png
new file mode 100644
index 0000000000000000000000000000000000000000..e5383c2b2af56d6feed6b7b8a535de30b448adb9
Binary files /dev/null and b/masks_animatediff/dog2/010.png differ
diff --git a/masks_animatediff/dog2/011.png b/masks_animatediff/dog2/011.png
new file mode 100644
index 0000000000000000000000000000000000000000..57307723a71118e5291873dd5a8915fc8c5d2afd
Binary files /dev/null and b/masks_animatediff/dog2/011.png differ
diff --git a/masks_animatediff/dog2/012.png b/masks_animatediff/dog2/012.png
new file mode 100644
index 0000000000000000000000000000000000000000..93491e6324b533fd29d378abb1a89c765b9f0b44
Binary files /dev/null and b/masks_animatediff/dog2/012.png differ
diff --git a/masks_animatediff/dog2/013.png b/masks_animatediff/dog2/013.png
new file mode 100644
index 0000000000000000000000000000000000000000..8e6cb0113ab0cbb5b5e43a7d5a9f50be2afa0d1d
Binary files /dev/null and b/masks_animatediff/dog2/013.png differ
diff --git a/masks_animatediff/dog2/014.png b/masks_animatediff/dog2/014.png
new file mode 100644
index 0000000000000000000000000000000000000000..34812fd45ed604394a05fc0f995d5da0a80ca438
Binary files /dev/null and b/masks_animatediff/dog2/014.png differ
diff --git a/masks_animatediff/dog2/015.png b/masks_animatediff/dog2/015.png
new file mode 100644
index 0000000000000000000000000000000000000000..338844e52152285820b3ebbf377170c0019339fe
Binary files /dev/null and b/masks_animatediff/dog2/015.png differ
diff --git a/masks_animatediff/duck/0.png b/masks_animatediff/duck/0.png
new file mode 100644
index 0000000000000000000000000000000000000000..688594611b557e1bf9c75029d82aadcde32c785a
Binary files /dev/null and b/masks_animatediff/duck/0.png differ
diff --git a/masks_animatediff/duck/1.png b/masks_animatediff/duck/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..d50ceb1de041d12540bab125b49492510e9e7b15
Binary files /dev/null and b/masks_animatediff/duck/1.png differ
diff --git a/masks_animatediff/duck/10.png b/masks_animatediff/duck/10.png
new file mode 100644
index 0000000000000000000000000000000000000000..f45dbcd93a116f12a94d1b6cda6aa91c426f9fd1
Binary files /dev/null and b/masks_animatediff/duck/10.png differ
diff --git a/masks_animatediff/duck/11.png b/masks_animatediff/duck/11.png
new file mode 100644
index 0000000000000000000000000000000000000000..ad86abf2526d625d75a8f319b1178a6802d52188
Binary files /dev/null and b/masks_animatediff/duck/11.png differ
diff --git a/masks_animatediff/duck/12.png b/masks_animatediff/duck/12.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d6279bc1d74f7284c16932c271a1cdfb5d085ca
Binary files /dev/null and b/masks_animatediff/duck/12.png differ
diff --git a/masks_animatediff/duck/13.png b/masks_animatediff/duck/13.png
new file mode 100644
index 0000000000000000000000000000000000000000..1d0e72bb13bf269d7655d818da5484cff344ca33
Binary files /dev/null and b/masks_animatediff/duck/13.png differ
diff --git a/masks_animatediff/duck/14.png b/masks_animatediff/duck/14.png
new file mode 100644
index 0000000000000000000000000000000000000000..7cb1e1d91c97add9ac97ab1206ecb3b01d7e6db5
Binary files /dev/null and b/masks_animatediff/duck/14.png differ
diff --git a/masks_animatediff/duck/15.png b/masks_animatediff/duck/15.png
new file mode 100644
index 0000000000000000000000000000000000000000..23a50224cc99f246216cef33cc620b780e5d31c9
Binary files /dev/null and b/masks_animatediff/duck/15.png differ
diff --git a/masks_animatediff/duck/2.png b/masks_animatediff/duck/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..c3a63a023ae0b0cc4a14647a35b36c0466675017
Binary files /dev/null and b/masks_animatediff/duck/2.png differ
diff --git a/masks_animatediff/duck/3.png b/masks_animatediff/duck/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..122532c2ba63eaaa673e8ea83d32914fda5405d6
Binary files /dev/null and b/masks_animatediff/duck/3.png differ
diff --git a/masks_animatediff/duck/4.png b/masks_animatediff/duck/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..62a53760baf8f1471937176626695cd5e8958967
Binary files /dev/null and b/masks_animatediff/duck/4.png differ
diff --git a/masks_animatediff/duck/5.png b/masks_animatediff/duck/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..99a3eabbbcdf59f4284060db4ddefa5eccce3f56
Binary files /dev/null and b/masks_animatediff/duck/5.png differ
diff --git a/masks_animatediff/duck/6.png b/masks_animatediff/duck/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..f7aac608455e1b10008c6f55b6f95772fb490871
Binary files /dev/null and b/masks_animatediff/duck/6.png differ
diff --git a/masks_animatediff/duck/7.png b/masks_animatediff/duck/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..6ce3600b554ec2c0f41fc37c2973659bbd166c1d
Binary files /dev/null and b/masks_animatediff/duck/7.png differ
diff --git a/masks_animatediff/duck/8.png b/masks_animatediff/duck/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..816a9c7b83d584ebe18cc6135421221a22f9dcfc
Binary files /dev/null and b/masks_animatediff/duck/8.png differ
diff --git a/masks_animatediff/duck/9.png b/masks_animatediff/duck/9.png
new file mode 100644
index 0000000000000000000000000000000000000000..3782ee9555fb62c38f46215d7a58a165361e9438
Binary files /dev/null and b/masks_animatediff/duck/9.png differ
diff --git a/masks_animatediff/man3/000.png b/masks_animatediff/man3/000.png
new file mode 100644
index 0000000000000000000000000000000000000000..993d9213a49998bf9b0d455db84cb881d481f6c8
Binary files /dev/null and b/masks_animatediff/man3/000.png differ
diff --git a/masks_animatediff/man3/001.png b/masks_animatediff/man3/001.png
new file mode 100644
index 0000000000000000000000000000000000000000..169d1c632edfb763fe8e7247d06dfd8a80a74e36
Binary files /dev/null and b/masks_animatediff/man3/001.png differ
diff --git a/masks_animatediff/man3/002.png b/masks_animatediff/man3/002.png
new file mode 100644
index 0000000000000000000000000000000000000000..9c98167ade5e913a22b1246f102ffa6e5b5f0079
Binary files /dev/null and b/masks_animatediff/man3/002.png differ
diff --git a/masks_animatediff/man3/003.png b/masks_animatediff/man3/003.png
new file mode 100644
index 0000000000000000000000000000000000000000..a338d354b48f6e2c0c66298bd19a3c0a7bdf3104
Binary files /dev/null and b/masks_animatediff/man3/003.png differ
diff --git a/masks_animatediff/man3/004.png b/masks_animatediff/man3/004.png
new file mode 100644
index 0000000000000000000000000000000000000000..5b9a1a61faa22c25ba48fbea0807377fc5357553
Binary files /dev/null and b/masks_animatediff/man3/004.png differ
diff --git a/masks_animatediff/man3/005.png b/masks_animatediff/man3/005.png
new file mode 100644
index 0000000000000000000000000000000000000000..0f995cb05bf928389207e4f70440ec5b090b9add
Binary files /dev/null and b/masks_animatediff/man3/005.png differ
diff --git a/masks_animatediff/man3/006.png b/masks_animatediff/man3/006.png
new file mode 100644
index 0000000000000000000000000000000000000000..e63b608af938b67d5de36ea2c889d9707025e172
Binary files /dev/null and b/masks_animatediff/man3/006.png differ
diff --git a/masks_animatediff/man3/007.png b/masks_animatediff/man3/007.png
new file mode 100644
index 0000000000000000000000000000000000000000..caa5ea310d1a16f8ce2b043c8de066c9ad0c1535
Binary files /dev/null and b/masks_animatediff/man3/007.png differ
diff --git a/masks_animatediff/man3/008.png b/masks_animatediff/man3/008.png
new file mode 100644
index 0000000000000000000000000000000000000000..6e9cd67cb8d0cc694d3977178c6d0e41068bf753
Binary files /dev/null and b/masks_animatediff/man3/008.png differ
diff --git a/masks_animatediff/man3/009.png b/masks_animatediff/man3/009.png
new file mode 100644
index 0000000000000000000000000000000000000000..82415fdd735602caca46a89dbe0034c18c952734
Binary files /dev/null and b/masks_animatediff/man3/009.png differ
diff --git a/masks_animatediff/man3/010.png b/masks_animatediff/man3/010.png
new file mode 100644
index 0000000000000000000000000000000000000000..d9bd0fb652ff1eeddd6ece16bc0e3f6c497ce459
Binary files /dev/null and b/masks_animatediff/man3/010.png differ
diff --git a/masks_animatediff/man3/011.png b/masks_animatediff/man3/011.png
new file mode 100644
index 0000000000000000000000000000000000000000..15ee886f675c580b7879aadcc10420a070ee29c8
Binary files /dev/null and b/masks_animatediff/man3/011.png differ
diff --git a/masks_animatediff/man3/012.png b/masks_animatediff/man3/012.png
new file mode 100644
index 0000000000000000000000000000000000000000..ff98d1b20f02fec9a6b240a9e7a1a11c0b0f1dab
Binary files /dev/null and b/masks_animatediff/man3/012.png differ
diff --git a/masks_animatediff/man3/013.png b/masks_animatediff/man3/013.png
new file mode 100644
index 0000000000000000000000000000000000000000..13a051683a98597488b95599ad66cdcf6df19800
Binary files /dev/null and b/masks_animatediff/man3/013.png differ
diff --git a/masks_animatediff/man3/014.png b/masks_animatediff/man3/014.png
new file mode 100644
index 0000000000000000000000000000000000000000..bc78c6b845b4143e65a826b3d0a905f8637dd0db
Binary files /dev/null and b/masks_animatediff/man3/014.png differ
diff --git a/masks_animatediff/man3/015.png b/masks_animatediff/man3/015.png
new file mode 100644
index 0000000000000000000000000000000000000000..0383fc8765193592187882080bd1eb088f5a17db
Binary files /dev/null and b/masks_animatediff/man3/015.png differ
diff --git a/masks_animatediff/water/000.png b/masks_animatediff/water/000.png
new file mode 100644
index 0000000000000000000000000000000000000000..40ad27d7d9ecb17e36d65db40e117a57a617adcb
Binary files /dev/null and b/masks_animatediff/water/000.png differ
diff --git a/masks_animatediff/water/001.png b/masks_animatediff/water/001.png
new file mode 100644
index 0000000000000000000000000000000000000000..7e021f2d241175c36b43a6187092395a944e9e8b
Binary files /dev/null and b/masks_animatediff/water/001.png differ
diff --git a/masks_animatediff/water/002.png b/masks_animatediff/water/002.png
new file mode 100644
index 0000000000000000000000000000000000000000..64e9690aac875a62b5d3d1a65b8a047584b5a5aa
Binary files /dev/null and b/masks_animatediff/water/002.png differ
diff --git a/masks_animatediff/water/003.png b/masks_animatediff/water/003.png
new file mode 100644
index 0000000000000000000000000000000000000000..015f20383db3e026a59f919ce5d1536bae837d8e
Binary files /dev/null and b/masks_animatediff/water/003.png differ
diff --git a/masks_animatediff/water/004.png b/masks_animatediff/water/004.png
new file mode 100644
index 0000000000000000000000000000000000000000..06ff1b0dbcf05a18d9e75e33bfb03f6c6af16fa2
Binary files /dev/null and b/masks_animatediff/water/004.png differ
diff --git a/masks_animatediff/water/005.png b/masks_animatediff/water/005.png
new file mode 100644
index 0000000000000000000000000000000000000000..8dcb166ea3002740ee0c7b86aab142c908cb0487
Binary files /dev/null and b/masks_animatediff/water/005.png differ
diff --git a/masks_animatediff/water/006.png b/masks_animatediff/water/006.png
new file mode 100644
index 0000000000000000000000000000000000000000..91df7769681cb557722c98b1d2fc419cc3e574b3
Binary files /dev/null and b/masks_animatediff/water/006.png differ
diff --git a/masks_animatediff/water/007.png b/masks_animatediff/water/007.png
new file mode 100644
index 0000000000000000000000000000000000000000..8130786eb0466784b4fcf5998437677c83496774
Binary files /dev/null and b/masks_animatediff/water/007.png differ
diff --git a/masks_animatediff/water/008.png b/masks_animatediff/water/008.png
new file mode 100644
index 0000000000000000000000000000000000000000..f7c58444693684dd908888a4ad2b8aad5c2b1907
Binary files /dev/null and b/masks_animatediff/water/008.png differ
diff --git a/masks_animatediff/water/009.png b/masks_animatediff/water/009.png
new file mode 100644
index 0000000000000000000000000000000000000000..e86be34a96d2dcd08c78e6b149c5304f08e1eac1
Binary files /dev/null and b/masks_animatediff/water/009.png differ
diff --git a/masks_animatediff/water/010.png b/masks_animatediff/water/010.png
new file mode 100644
index 0000000000000000000000000000000000000000..29c868d22c98b91f9e82e4c92a02353800108308
Binary files /dev/null and b/masks_animatediff/water/010.png differ
diff --git a/masks_animatediff/water/011.png b/masks_animatediff/water/011.png
new file mode 100644
index 0000000000000000000000000000000000000000..d738f8ac78fd015eab72127eb549d9808ee26ed8
Binary files /dev/null and b/masks_animatediff/water/011.png differ
diff --git a/masks_animatediff/water/012.png b/masks_animatediff/water/012.png
new file mode 100644
index 0000000000000000000000000000000000000000..031f8adc9bbdf04f94273e815c578e460f67627a
Binary files /dev/null and b/masks_animatediff/water/012.png differ
diff --git a/masks_animatediff/water/013.png b/masks_animatediff/water/013.png
new file mode 100644
index 0000000000000000000000000000000000000000..98d229ce88963f0abff190a26924455742cae9cb
Binary files /dev/null and b/masks_animatediff/water/013.png differ
diff --git a/masks_animatediff/water/014.png b/masks_animatediff/water/014.png
new file mode 100644
index 0000000000000000000000000000000000000000..b3385af69d0a6837c9b79a72aa00a4b34478680e
Binary files /dev/null and b/masks_animatediff/water/014.png differ
diff --git a/masks_animatediff/water/015.png b/masks_animatediff/water/015.png
new file mode 100644
index 0000000000000000000000000000000000000000..a094ee766b87544b3bfdd0841950bfc7b91a710d
Binary files /dev/null and b/masks_animatediff/water/015.png differ
diff --git a/masks_animatediff/wolf2/000.png b/masks_animatediff/wolf2/000.png
new file mode 100644
index 0000000000000000000000000000000000000000..843861b00d0079bc8c5288e4d67726eb4bafc1d4
Binary files /dev/null and b/masks_animatediff/wolf2/000.png differ
diff --git a/masks_animatediff/wolf2/001.png b/masks_animatediff/wolf2/001.png
new file mode 100644
index 0000000000000000000000000000000000000000..c7a2ba0ce887071e484a06d7af023f8f8302331c
Binary files /dev/null and b/masks_animatediff/wolf2/001.png differ
diff --git a/masks_animatediff/wolf2/002.png b/masks_animatediff/wolf2/002.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c7ef9624f711bbe2ce874283180cbaeaf701610
Binary files /dev/null and b/masks_animatediff/wolf2/002.png differ
diff --git a/masks_animatediff/wolf2/003.png b/masks_animatediff/wolf2/003.png
new file mode 100644
index 0000000000000000000000000000000000000000..99c32b1c3a6cd0d85ce25c4251ed80c56246d3a7
Binary files /dev/null and b/masks_animatediff/wolf2/003.png differ
diff --git a/masks_animatediff/wolf2/004.png b/masks_animatediff/wolf2/004.png
new file mode 100644
index 0000000000000000000000000000000000000000..c04e7ad6f09ddabec21b19ba6d694af6f200ba66
Binary files /dev/null and b/masks_animatediff/wolf2/004.png differ
diff --git a/masks_animatediff/wolf2/005.png b/masks_animatediff/wolf2/005.png
new file mode 100644
index 0000000000000000000000000000000000000000..408b1f3b43d736243796216c72785d1821fe95bd
Binary files /dev/null and b/masks_animatediff/wolf2/005.png differ
diff --git a/masks_animatediff/wolf2/006.png b/masks_animatediff/wolf2/006.png
new file mode 100644
index 0000000000000000000000000000000000000000..fb4f48a7b333138d331a27dec0af746452fb4e7b
Binary files /dev/null and b/masks_animatediff/wolf2/006.png differ
diff --git a/masks_animatediff/wolf2/007.png b/masks_animatediff/wolf2/007.png
new file mode 100644
index 0000000000000000000000000000000000000000..1243e6c83466fc97cd4284300c4c898128106d49
Binary files /dev/null and b/masks_animatediff/wolf2/007.png differ
diff --git a/masks_animatediff/wolf2/008.png b/masks_animatediff/wolf2/008.png
new file mode 100644
index 0000000000000000000000000000000000000000..2a266a02f695df11ba6e00f33e78c4a053985dbb
Binary files /dev/null and b/masks_animatediff/wolf2/008.png differ
diff --git a/masks_animatediff/wolf2/009.png b/masks_animatediff/wolf2/009.png
new file mode 100644
index 0000000000000000000000000000000000000000..6c809d1f24d05da4522f4170042a2af64e5f2440
Binary files /dev/null and b/masks_animatediff/wolf2/009.png differ
diff --git a/masks_animatediff/wolf2/010.png b/masks_animatediff/wolf2/010.png
new file mode 100644
index 0000000000000000000000000000000000000000..845548a1a0d0f0dd51d4014a2850073e55009ae2
Binary files /dev/null and b/masks_animatediff/wolf2/010.png differ
diff --git a/masks_animatediff/wolf2/011.png b/masks_animatediff/wolf2/011.png
new file mode 100644
index 0000000000000000000000000000000000000000..f2fb0cdb9eff47a87d84895865b4fdcf55b439b7
Binary files /dev/null and b/masks_animatediff/wolf2/011.png differ
diff --git a/masks_animatediff/wolf2/012.png b/masks_animatediff/wolf2/012.png
new file mode 100644
index 0000000000000000000000000000000000000000..6a35f748dd53dcd256954fde030febf4b5769782
Binary files /dev/null and b/masks_animatediff/wolf2/012.png differ
diff --git a/masks_animatediff/wolf2/013.png b/masks_animatediff/wolf2/013.png
new file mode 100644
index 0000000000000000000000000000000000000000..9e97db6ae9f6db49393d8a2033ce072f5be4e625
Binary files /dev/null and b/masks_animatediff/wolf2/013.png differ
diff --git a/masks_animatediff/wolf2/014.png b/masks_animatediff/wolf2/014.png
new file mode 100644
index 0000000000000000000000000000000000000000..6fa60006114c29d65c6b5235762fb9f6031301c9
Binary files /dev/null and b/masks_animatediff/wolf2/014.png differ
diff --git a/masks_animatediff/wolf2/015.png b/masks_animatediff/wolf2/015.png
new file mode 100644
index 0000000000000000000000000000000000000000..8faf41738f983644e4bfca7105c8a9fdc344c2c7
Binary files /dev/null and b/masks_animatediff/wolf2/015.png differ
diff --git a/models/model_download_here b/models/model_download_here
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/output/output_mp4 b/output/output_mp4
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..17a06a8c2bace2372ef6903cd78ae615643dd0aa
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,15 @@
+diffusers==0.32.1
+transformers==4.48.0
+opencv-python
+safetensors
+pillow==10.3.0
+einops
+peft
+imageio
+omegaconf
+ultralytics
+tqdm==4.67.1
+protobuf==3.20.2
+torch==2.3.0
+torchvision==0.18.0
+moviepy==1.0.3
\ No newline at end of file
diff --git a/sam2.py b/sam2.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33429ad99f9935dc1983b632df731b89a76c719
--- /dev/null
+++ b/sam2.py
@@ -0,0 +1,35 @@
+import os
+import torch
+import argparse
+import numpy as np
+from PIL import Image
+from ultralytics.models.sam import SAM2VideoPredictor
+
+
+def main(args):
+
+ # Create SAM2VideoPredictor
+ overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
+ predictor = SAM2VideoPredictor(overrides=overrides)
+
+ video_name = args.video_name
+ results = predictor(source=f"input_animatediff/{video_name}.mp4",points=[args.x, args.y],labels=[1])
+
+ for i in range(len(results)):
+ mask = (results[i].masks.data).squeeze().to(torch.float16)
+ mask = (mask * 255).cpu().numpy().astype(np.uint8)
+ mask_image = Image.fromarray(mask)
+ mask_dir = f'masks_animatediff/{video_name}'
+ if not os.path.exists(mask_dir):
+ os.makedirs(mask_dir)
+ mask_image.save(mask_dir + f'/{str(i).zfill(3)}.png')
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description="Process a video and generate masks using SAM2VideoPredictor.")
+ parser.add_argument("--video_name", type=str, required=True, help="Name of the video file (without extension).")
+ parser.add_argument("--x", type=int, default=255, help="X coordinate of the point.")
+ parser.add_argument("--y", type=int, default=255, help="Y coordinate of the point.")
+
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/src/animatediff_eul.py b/src/animatediff_eul.py
new file mode 100644
index 0000000000000000000000000000000000000000..70dadbb9a5de5d190098c0978d41e63264d2272a
--- /dev/null
+++ b/src/animatediff_eul.py
@@ -0,0 +1,94 @@
+import torch
+from typing import List, Optional, Tuple, Union
+
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ BaseOutput,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+def eul_step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ fusion_latent,
+ pipe,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ sigma = self.sigmas[self.step_index]
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon": ## True, 计算x_0
+ pred_original_sample = sample - sigma * model_output
+ elif self.config.prediction_type == "v_prediction":
+ # * c_out + input * c_skip
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+ elif self.config.prediction_type == "sample":
+ raise NotImplementedError("prediction_type not implemented yet: sample")
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+ )
+
+ ## fusion latent
+ pred_original_sample = fusion_latent
+
+ sigma_from = self.sigmas[self.step_index]
+ sigma_to = self.sigmas[self.step_index + 1]
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma
+ dt = sigma_down - sigma
+
+ prev_sample = sample + derivative * dt
+
+ device = model_output.device
+ noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
+
+ prev_sample = prev_sample + noise * sigma_up
+
+ # Cast sample back to model compatible dtype
+ prev_sample = prev_sample.to(model_output.dtype)
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return EulerAncestralDiscreteSchedulerOutput(
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
+ )
\ No newline at end of file
diff --git a/src/animatediff_inpaint_pipe.py b/src/animatediff_inpaint_pipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..e43d501aa5e0560fa95ce728ba818e2179a987ab
--- /dev/null
+++ b/src/animatediff_inpaint_pipe.py
@@ -0,0 +1,1077 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.image_processor import PipelineImageInput
+from diffusers.loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.unets.unet_motion_model import MotionAdapter
+from diffusers.schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from diffusers.pipelines.free_init_utils import FreeInitMixin
+from diffusers.pipelines.free_noise_utils import AnimateDiffFreeNoiseMixin
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
+
+from src.ic_light import Relighter
+from einops import rearrange
+from diffusers.utils import export_to_gif
+from src.animatediff_eul import eul_step
+import torch.nn.functional as F
+import numpy as np
+from utils.tools import numpy2pytorch
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import imageio
+ >>> import requests
+ >>> import torch
+ >>> from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
+ >>> from diffusers.utils import export_to_gif
+ >>> from io import BytesIO
+ >>> from PIL import Image
+
+ >>> adapter = MotionAdapter.from_pretrained(
+ ... "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
+ ... )
+ >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
+ ... "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter
+ ... ).to("cuda")
+ >>> pipe.scheduler = DDIMScheduler(
+ ... beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace"
+ ... )
+
+
+ >>> def load_video(file_path: str):
+ ... images = []
+
+ ... if file_path.startswith(("http://", "https://")):
+ ... # If the file_path is a URL
+ ... response = requests.get(file_path)
+ ... response.raise_for_status()
+ ... content = BytesIO(response.content)
+ ... vid = imageio.get_reader(content)
+ ... else:
+ ... # Assuming it's a local file path
+ ... vid = imageio.get_reader(file_path)
+
+ ... for frame in vid:
+ ... pil_image = Image.fromarray(frame)
+ ... images.append(pil_image)
+
+ ... return images
+
+
+ >>> video = load_video(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
+ ... )
+ >>> output = pipe(
+ ... video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5
+ ... )
+ >>> frames = output.frames[0]
+ >>> export_to_gif(frames, "animation.gif")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class AnimateDiffVideoToVideoPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+ FreeInitMixin,
+ AnimateDiffFreeNoiseMixin,
+):
+ r"""
+ Pipeline for video-to-video generation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer (`CLIPTokenizer`):
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
+ motion_adapter ([`MotionAdapter`]):
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ motion_adapter: MotionAdapter,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ feature_extractor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ ):
+ super().__init__()
+ if isinstance(unet, UNet2DConditionModel):
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ motion_adapter=motion_adapter,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode `image` for IP-Adapter conditioning with the CLIP image encoder.

    Args:
        image: A PIL image / array accepted by `self.feature_extractor`, or an
            already-preprocessed `torch.Tensor` of pixel values.
        device: Target device for the returned embeddings.
        num_images_per_prompt (`int`): Each embedding is repeated this many times
            along the batch dimension.
        output_hidden_states (`bool`, *optional*): When truthy, return the
            penultimate hidden state instead of the pooled `image_embeds`.

    Returns:
        Tuple `(conditional, unconditional)`. The unconditional branch is either
        the encoding of an all-zero image (hidden-state mode) or simply a zero
        tensor (pooled mode).
    """
    # The image encoder's parameter dtype decides the compute dtype.
    dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)
    if output_hidden_states:
        # hidden_states[-2]: penultimate layer, as expected by IP-Adapter "plus" projections.
        image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
        image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        # Unconditional branch: encode an all-zero image of the same shape.
        uncond_image_enc_hidden_states = self.image_encoder(
            torch.zeros_like(image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
            num_images_per_prompt, dim=0
        )
        return image_enc_hidden_states, uncond_image_enc_hidden_states
    else:
        image_embeds = self.image_encoder(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        # Pooled mode: the unconditional embedding is just zeros.
        uncond_image_embeds = torch.zeros_like(image_embeds)

        return image_embeds, uncond_image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build the per-adapter image embeddings used for IP-Adapter conditioning.

    Either encodes `ip_adapter_image` (one entry per installed IP-Adapter) or
    reuses precomputed `ip_adapter_image_embeds`. When classifier-free guidance
    is on, each precomputed entry is expected to hold the negative embedding
    stacked before the positive one along dim 0.

    Returns:
        List of tensors, one per IP-Adapter, repeated `num_images_per_prompt`
        times and, under CFG, concatenated as [negative, positive] on the batch
        dimension and moved to `device`.
    """
    image_embeds = []
    if do_classifier_free_guidance:
        negative_image_embeds = []
    if ip_adapter_image_embeds is None:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            # Plain `ImageProjection` layers take pooled embeds; every other
            # projection type expects penultimate hidden states.
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )

            image_embeds.append(single_image_embeds[None, :])
            if do_classifier_free_guidance:
                negative_image_embeds.append(single_negative_image_embeds[None, :])
    else:
        for single_image_embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                # Precomputed embeds store [negative, positive] halves along dim 0.
                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                negative_image_embeds.append(single_negative_image_embeds)
            image_embeds.append(single_image_embeds)

    ip_adapter_image_embeds = []
    for i, single_image_embeds in enumerate(image_embeds):
        # Duplicate for each generation per prompt.
        single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

        single_image_embeds = single_image_embeds.to(device=device)
        ip_adapter_image_embeds.append(single_image_embeds)

    return ip_adapter_image_embeds
+
def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor:
    """VAE-encode a stack of frames, `decode_chunk_size` frames at a time.

    Chunking bounds peak encoder memory; the per-chunk latents are concatenated
    back into a single tensor along the frame dimension.
    """
    chunk_latents = [
        retrieve_latents(self.vae.encode(video[start : start + decode_chunk_size]), generator=generator)
        for start in range(0, len(video), decode_chunk_size)
    ]
    return torch.cat(chunk_latents)
+
# Adapted from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
def decode_latents(self, latents, decode_chunk_size: int = 16):
    """Decode video latents [B, C, F, h, w] into pixel frames [B, 3, F, H, W].

    The VAE is 2D, so frames are folded into the batch axis and decoded in
    chunks of `decode_chunk_size` to cap peak memory. Note: unlike the upstream
    AnimateDiff pipeline, the final cast to float32 is deliberately disabled
    here — frames keep the VAE's output dtype.
    """
    # Undo the encoder-side scaling.
    scaled = 1 / self.vae.config.scaling_factor * latents

    batch_size, channels, num_frames, height, width = scaled.shape
    # [B, C, F, h, w] -> [B*F, C, h, w] so the image VAE can decode frame-wise.
    flat = scaled.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

    decoded_chunks = [
        self.vae.decode(flat[start : start + decode_chunk_size]).sample
        for start in range(0, flat.shape[0], decode_chunk_size)
    ]
    frames = torch.cat(decoded_chunks)

    # Restore the [B, C, F, H, W] video layout.
    video = frames[None, :].reshape((batch_size, num_frames, -1) + frames.shape[2:]).permute(0, 2, 1, 3, 4)
    # video = video.float()  # disabled on purpose; see docstring
    return video
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Assemble the optional kwargs supported by `self.scheduler.step`.

    Not every scheduler accepts `eta` or `generator`, so the step signature is
    inspected and only supported arguments are forwarded.
    """
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # check if the scheduler accepts generator
    accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
    if accepts_generator:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
+
def check_inputs(
    self,
    prompt,
    strength,
    height,
    width,
    video=None,
    latents=None,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the combination of arguments passed to `__call__`.

    Raises:
        ValueError: on any out-of-range, mutually exclusive, or ill-typed input.
    """
    # Denoising strength must be a fraction of the schedule.
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    # The VAE downsamples by 8, so spatial dims must be multiples of 8.
    if height % 8 or width % 8:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and any(
        k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Exactly one of `prompt` / `prompt_embeds` must be supplied.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # When both embedding tensors are passed directly they must agree in shape.
    if (
        prompt_embeds is not None
        and negative_prompt_embeds is not None
        and prompt_embeds.shape != negative_prompt_embeds.shape
    ):
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if video is not None and latents is not None:
        raise ValueError("Only one of `video` or `latents` should be provided")

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        if ip_adapter_image_embeds[0].ndim not in (3, 4):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
            )
+
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
    """Trim the schedule for img2img-style denoising.

    Keeps only the final `strength` fraction of `timesteps` and returns the
    trimmed schedule together with the number of remaining inference steps.
    """
    # How many steps will actually be executed for this strength.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    # Index of the first kept step; scaled by `order` because higher-order
    # schedulers store `order` entries per inference step.
    t_start = max(num_inference_steps - init_timestep, 0)
    remaining = timesteps[t_start * self.scheduler.order :]
    return remaining, num_inference_steps - t_start
+
def prepare_latents(
    self,
    video,
    height,
    width,
    num_channels_latents,
    batch_size,
    timestep,
    dtype,
    device,
    generator,
    latents=None,
    decode_chunk_size: int = 16,
):
    """Create the initial noisy latents for img2img-style video denoising.

    When `latents` is None the input `video` is VAE-encoded per sample and
    noised to `timestep`; otherwise the caller-provided latents are validated
    against the expected shape and moved to `device`/`dtype`.

    Returns:
        Latents of shape [B, C, F, H // vae_scale_factor, W // vae_scale_factor].
    """
    if latents is None:
        # `video` is [B, F, C, H, W] here (frames before channels) — see `__call__`.
        num_frames = video.shape[1]
    else:
        num_frames = latents.shape[2]

    # Expected latent shape; only used to validate caller-provided `latents`.
    shape = (
        batch_size,
        num_channels_latents,
        num_frames,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        # make sure the VAE is in float32 mode, as it overflows in float16
        if self.vae.config.force_upcast:
            video = video.float()
            self.vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            # NOTE(review): this length check duplicates the one above and can
            # never fire here; kept for parity with the original code.
            if len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            # One generator per sample: encode each video with its own generator.
            init_latents = [
                self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
                for i in range(batch_size)
            ]
        else:
            # Each `vid` is [F, C, H, W]; encoding yields [F, C_lat, h, w].
            init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video]

        init_latents = torch.cat(init_latents, dim=0)

        # restore vae to original dtype
        if self.vae.config.force_upcast:
            self.vae.to(dtype)

        init_latents = init_latents.to(dtype)
        init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # NOTE(review): unlike the upstream diffusers template, a divisible
            # prompt/image mismatch is not expanded by repetition — it is rejected.
            error_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
            )
            raise ValueError(error_message)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            # No-op concatenation kept from the upstream template.
            init_latents = torch.cat([init_latents], dim=0)

        # Noise to the starting timestep, then move frames after channels:
        # [B, F, C, h, w] -> [B, C, F, h, w].
        noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
        latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
    else:
        if shape != latents.shape:
            # [B, C, F, H, W]
            raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
        latents = latents.to(device, dtype=dtype)

    return latents
+
@property
def guidance_scale(self):
    """Classifier-free guidance scale set by the current/last `__call__`."""
    return self._guidance_scale
+
@property
def clip_skip(self):
    """Number of CLIP layers to skip, as set by the current/last `__call__`."""
    return self._clip_skip
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active (`guidance_scale > 1`)."""
    return self._guidance_scale > 1
+
@property
def cross_attention_kwargs(self):
    """Extra kwargs forwarded to the attention processors, from `__call__`."""
    return self._cross_attention_kwargs
+
@property
def num_timesteps(self):
    """Total number of timesteps of the last denoising run."""
    return self._num_timesteps
+
@torch.no_grad()
def __call__(
    self,
    ic_light_pipe=None,
    relight_prompt=None,
    bg_source=None,
    mask=None,
    vdm_init_latent=None,
    video: List[List[PipelineImageInput]] = None,
    prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 7.5,
    strength: float = 0.8,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    decode_chunk_size: int = 16,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        ic_light_pipe (*optional*):
            IC-Light pipeline wrapped by the internal `Relighter` during the relighting
            phase of the denoising loop (semantics defined in `src.ic_light`).
        relight_prompt (`str`, *optional*):
            Text prompt describing the target illumination, forwarded to the relighter.
        bg_source (*optional*):
            Background / light-source specifier consumed by `Relighter` — see `src.ic_light`.
        mask (*optional*):
            Sequence of per-frame foreground masks with values in [0, 255]. After bilinear
            resizing to latent resolution, only pixels that remain exactly 1.0 are treated
            as foreground.
        vdm_init_latent (`torch.Tensor`, *optional*):
            Pixel-space background frames; their VAE latents, noised to the first
            timestep, initialize the denoising trajectory.
        video (`List[PipelineImageInput]`):
            The input video to condition the generation on. Must be a list of images/frames of the video.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated video.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        strength (`float`, *optional*, defaults to 0.8):
            Higher strength leads to more differences between original video and generated video.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
            `(batch_size, num_channel, num_frames, height, width)`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        decode_chunk_size (`int`, defaults to `16`):
            The number of frames to decode at a time when calling `decode_latents` method.

    Examples:

    Returns:
        [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
            returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
    """

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # NOTE(review): the `num_videos_per_prompt` argument is ignored — it is forced
    # back to 1 here regardless of what the caller passed.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        strength=strength,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        video=video,
        latents=latents,
        ip_adapter_image=ip_adapter_image,
        ip_adapter_image_embeds=ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,  ## animatediff outpaint prompt
        device,
        num_videos_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )
    # Keep the full schedule: the loop below iterates over *all* timesteps, while
    # the strength-trimmed `timesteps` only marks where the relighting phase starts.
    original_timesteps = timesteps
    org_latent_timestep = original_timesteps[:1].repeat(batch_size * num_videos_per_prompt)

    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
    # NOTE(review): `latent_timestep` is currently unused.
    latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)

    # 5. Prepare latent variables
    if latents is None:
        video = self.video_processor.preprocess_video(video, height=height, width=width)
        # Move the number of frames before the number of channels.
        video = video.permute(0, 2, 1, 3, 4)
        video = video.to(device=device, dtype=prompt_embeds.dtype)

        # Deterministic (distribution-mode) VAE encoding of the first video in the batch.
        video_latent = self.vae.encode(video[0]).latent_dist.mode() * self.vae.config.scaling_factor
        num_frames = video_latent.shape[0]
        # Per-frame prompt embeddings (frames travel along the batch axis of the UNet).
        prompt_embeds_wo_negative = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

        # NOTE(review): this random `latents` tensor only serves as a shape/dtype
        # template for `fg_noise`; it is overwritten with `bg_latents` below.
        latents = randn_tensor(video_latent.permute(1, 0, 2, 3).unsqueeze(0).shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
        fg_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)

        ## foreground mask
        mask_tensor = torch.from_numpy(np.stack(mask, axis=0)).float() / 255
        mask_tensor = mask_tensor.movedim(-1, 1).to(latents.device, dtype=latents.dtype)
        mask_latent = F.interpolate(mask_tensor, size=(latents.shape[-2], latents.shape[-1]), mode='bilinear')  # [F, 3, h, w]
        # Binarize: only pixels that stayed exactly 1.0 after resizing count as foreground.
        mask_latent[mask_latent != 1.0] = 0
        mask_latent = mask_latent[:, :1]
        # Broadcast to the 4 latent channels and reshape to [1, 4, F, h, w].
        mask_latent = mask_latent.repeat(1, 4, 1, 1).to(latents.dtype).permute(1, 0, 2, 3).unsqueeze(0)
        mask_latent = mask_latent.to(device)

        ## Init Relighter
        relighter = Relighter(
            pipeline=ic_light_pipe,
            relight_prompt=relight_prompt,
            bg_source=bg_source,
            generator=generator,
            num_frames=num_frames,
        )

        # Start the trajectory from the background frames, noised to the first timestep.
        bg_latents = self.vae.encode(vdm_init_latent).latent_dist.mode() * self.vae.config.scaling_factor
        bg_latents = self.scheduler.add_noise(bg_latents.permute(1, 0, 2, 3).unsqueeze(0), fg_noise, org_latent_timestep)
        latents = bg_latents

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    self._num_timesteps = len(original_timesteps)
    # NOTE(review): mixes the full schedule length with the strength-trimmed
    # `num_inference_steps`; confirm the intended warmup accounting.
    num_warmup_steps = len(original_timesteps) - num_inference_steps * self.scheduler.order

    # 8. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(original_timesteps):

            if t > timesteps[0]:  ## outpaint
                # Phase 1 (the steps that `strength` would normally skip): keep the
                # foreground region pinned to a freshly-noised copy of the input
                # video and denoise the composite with classifier-free guidance.
                fg_latents = self.scheduler.add_noise(video_latent.permute(1, 0, 2, 3).unsqueeze(0), fg_noise, t[None, ...])
                latents = latents * (1 - mask_latent) + mask_latent * fg_latents
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(  ## [1, 4, F, h, w]
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                ).sample
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                output = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
            else:
                # Phase 2: prompt-only (no CFG) prediction plus progressive light fusion.
                latent_model_input = latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(  ## [1, 4, F, h, w]
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds_wo_negative,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                ).sample

                ## progressive light fusion
                # Blend weight decays with the timestep: later steps relight less.
                lbd = t / timesteps[0]
                if lbd > 0.15:
                    ## lbd
                    print(f"relight lbd = {lbd}")

                    ## get pred_x
                    # NOTE(review): assumes an Euler-style scheduler exposing
                    # `sigmas`/`step_index` so that x0 ≈ x_t − σ_t·ε — confirm for
                    # other scheduler types.
                    sigma = self.scheduler.sigmas[self.scheduler.step_index]
                    pred_x0_latent = latents - sigma * noise_pred  ## [1, 4, F, h, w]

                    ## consistent target
                    consist_target = self.decode_latents(pred_x0_latent)  ## [1, 3, F, H, W]
                    consist_target = rearrange(consist_target, "1 c f h w -> f c h w")

                    ## add diff
                    # On the first relight step, re-inject the foreground detail lost
                    # by the VAE round trip.
                    if t == timesteps[0]:
                        org_target = video[0].to(device=consist_target.device, dtype=consist_target.dtype)
                        detail_diff = org_target - consist_target
                        consist_target = consist_target + lbd * (mask_tensor * detail_diff)

                    ## relight target
                    relight_target = relighter(consist_target)  ## [F, 3, H, W]
                    fusion_target = (1 - lbd) * consist_target + lbd * relight_target  ## [F, 3, H, W]

                    ## fusion_target -> pixel level
                    fusion_latent = self.vae.encode(fusion_target).latent_dist.mode() * self.vae.config.scaling_factor
                    fusion_latent = fusion_latent.to(consist_target.dtype)
                    fusion_latent = rearrange(fusion_latent, "f c h w -> 1 c f h w")

                    # Custom Euler step that steers the update toward the fused target.
                    output = eul_step(self.scheduler, noise_pred, t, latents, fusion_latent, self, **extra_step_kwargs)
                else:
                    output = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)

            latents = output[0]

            # call the callback, if provided
            # NOTE(review): `callback_on_step_end` is accepted but never invoked in
            # this loop — only the progress bar is updated. Confirm whether callbacks
            # should be wired up as in upstream AnimateDiff.
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    # 9. Post-processing
    if output_type == "latent":
        video = latents
    else:
        video_tensor = self.decode_latents(latents, decode_chunk_size)
        video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

    # 10. Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return AnimateDiffPipelineOutput(frames=video)
\ No newline at end of file
diff --git a/src/animatediff_pipe.py b/src/animatediff_pipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..41f31e7d29925d29032a8b339f561b355ebcca2c
--- /dev/null
+++ b/src/animatediff_pipe.py
@@ -0,0 +1,1063 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.image_processor import PipelineImageInput
+from diffusers.loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.unets.unet_motion_model import MotionAdapter
+from diffusers.schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from diffusers.pipelines.free_init_utils import FreeInitMixin
+from diffusers.pipelines.free_noise_utils import AnimateDiffFreeNoiseMixin
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
+
+from src.ic_light import Relighter
+from einops import rearrange
+from diffusers.utils import export_to_gif
+from src.animatediff_eul import eul_step
+import math
+from utils.tools import vis_video
+
# Module-level logger; note `logging` here is `diffusers.utils.logging`, not the stdlib module.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
# Usage example surfaced in the pipeline docs. Fix: the original example passed
# `timespace_spacing="linspace"` to DDIMScheduler — the actual parameter name is
# `timestep_spacing`, so the example as written would not configure the scheduler.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import imageio
        >>> import requests
        >>> import torch
        >>> from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
        >>> from diffusers.utils import export_to_gif
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> adapter = MotionAdapter.from_pretrained(
        ...     "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
        ... )
        >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
        ...     "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter
        ... ).to("cuda")
        >>> pipe.scheduler = DDIMScheduler(
        ...     beta_schedule="linear", steps_offset=1, clip_sample=False, timestep_spacing="linspace"
        ... )


        >>> def load_video(file_path: str):
        ...     images = []

        ...     if file_path.startswith(("http://", "https://")):
        ...         # If the file_path is a URL
        ...         response = requests.get(file_path)
        ...         response.raise_for_status()
        ...         content = BytesIO(response.content)
        ...         vid = imageio.get_reader(content)
        ...     else:
        ...         # Assuming it's a local file path
        ...         vid = imageio.get_reader(file_path)

        ...     for frame in vid:
        ...         pil_image = Image.fromarray(frame)
        ...         images.append(pil_image)

        ...     return images


        >>> video = load_video(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
        ... )
        >>> output = pipe(
        ...     video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5
        ... )
        >>> frames = output.frames[0]
        >>> export_to_gif(frames, "animation.gif")
        ```
"""
+
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull a latent tensor out of a VAE encoder output.

    Handles `AutoencoderKLOutput`-style objects via `latent_dist` (sampled when
    `sample_mode == "sample"`, distribution mode when `sample_mode == "argmax"`)
    as well as objects that expose a plain `latents` attribute. Raises
    `AttributeError` when none of these are available.
    """
    dist = getattr(encoder_output, "latent_dist", None)
    if dist is not None and sample_mode == "sample":
        return dist.sample(generator)
    if dist is not None and sample_mode == "argmax":
        return dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
+
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure `scheduler` via `set_timesteps` and return its timestep schedule.

    Exactly one of `timesteps` / `sigmas` may override the scheduler's default
    spacing; otherwise `num_inference_steps` drives the schedule. Custom
    `timesteps`/`sigmas` are only forwarded when the scheduler's
    `set_timesteps` signature accepts them. Any extra `kwargs` are passed
    through to `scheduler.set_timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the number of
        inference steps it contains.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # Custom integer timestep schedule, gated on the scheduler's signature.
        if "timesteps" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        return scheduler.timesteps, len(scheduler.timesteps)

    if sigmas is not None:
        # Custom sigma schedule, likewise gated on the scheduler's signature.
        if "sigmas" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        return scheduler.timesteps, len(scheduler.timesteps)

    # Default path: the scheduler derives the schedule from the step count, and
    # the caller-supplied `num_inference_steps` is returned unchanged.
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps, num_inference_steps
+
+
+class AnimateDiffVideoToVideoPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+ FreeInitMixin,
+ AnimateDiffFreeNoiseMixin,
+):
+ r"""
+ Pipeline for video-to-video generation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer (`CLIPTokenizer`):
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
+ motion_adapter ([`MotionAdapter`]):
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
    # Order in which sub-models are moved between CPU and GPU during offloading.
    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    # Components the pipeline can be constructed without (may be None).
    _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
    # Tensor names accepted by `callback_on_step_end_tensor_inputs` (validated in `check_inputs`).
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ motion_adapter: MotionAdapter,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ feature_extractor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ ):
+ super().__init__()
+ if isinstance(unet, UNet2DConditionModel):
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ motion_adapter=motion_adapter,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # Infer the batch size from whichever prompt representation was supplied.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Warn (rather than fail) when the prompt exceeds CLIP's token limit.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # Resolve the target dtype from whichever model is available.
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        """Produce one image-embedding tensor per loaded IP-Adapter.

        Either encodes `ip_adapter_image` with the image encoder (one entry per
        projection layer in `self.unet.encoder_hid_proj.image_projection_layers`)
        or reuses precomputed `ip_adapter_image_embeds`. Under classifier-free
        guidance the negative embeddings are concatenated in front of the
        positive ones along dim 0.
        """
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # Plain `ImageProjection` layers consume pooled embeds; other projection
                # layers consume penultimate hidden states (see `encode_image`).
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    # Precomputed embeds are assumed to stack [negative, positive] along dim 0.
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            # Repeat per prompt, then prepend the negative half for CFG.
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds
+
+ def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor:
+ latents = []
+ for i in range(0, len(video), decode_chunk_size):
+ batch_video = video[i : i + decode_chunk_size]
+ batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
+ latents.append(batch_video)
+ return torch.cat(latents)
+
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+ def decode_latents(self, latents, decode_chunk_size: int = 16):
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ batch_size, channels, num_frames, height, width = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+
+ video = []
+ for i in range(0, latents.shape[0], decode_chunk_size):
+ batch_latents = latents[i : i + decode_chunk_size]
+ batch_latents = self.vae.decode(batch_latents).sample
+ video.append(batch_latents)
+
+ video = torch.cat(video)
+ video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ # video = video.float()
+ return video
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
    def check_inputs(
        self,
        prompt,
        strength,
        height,
        width,
        video=None,
        latents=None,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """Validate `__call__` arguments, raising `ValueError` on the first inconsistency found."""
        # `strength` controls how much of the denoising schedule is kept (see `get_timesteps`).
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        # Spatial dims must be divisible by 8 (the latent downsampling granularity).
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # Callback tensor names must be a subset of the class-level allow-list.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one of `prompt` / `prompt_embeds` must be supplied, and `prompt`
        # must be a string or list when given.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # `negative_prompt` and `negative_prompt_embeds` are mutually exclusive.
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Paired embeddings must agree in shape so they can be concatenated for CFG.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        # The conditioning source is either raw frames or pre-encoded latents, never both.
        if video is not None and latents is not None:
            raise ValueError("Only one of `video` or `latents` should be provided")

        # IP-Adapter conditioning: image and precomputed embeds are mutually exclusive,
        # and embeds must be a list of 3D/4D tensors.
        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
            )

        if ip_adapter_image_embeds is not None:
            if not isinstance(ip_adapter_image_embeds, list):
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                )
            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )
+
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
    def prepare_latents(
        self,
        video,
        height,
        width,
        num_channels_latents,
        batch_size,
        timestep,
        dtype,
        device,
        generator,
        latents=None,
        decode_chunk_size: int = 16,
    ):
        """Encode the input video to latents and noise them to `timestep`, or validate caller-provided `latents`.

        Returns a `[B, C, F, h, w]` latent tensor ready for the denoising loop.
        """
        # Frame count comes from the video (assumed `[B, F, ...]` — dim 1) or from
        # pre-made latents (`[B, C, F, h, w]` — dim 2).
        if latents is None:
            num_frames = video.shape[1]
        else:
            num_frames = latents.shape[2]

        # Expected latent shape: [B, C, F, H/vae_scale, W/vae_scale].
        shape = (
            batch_size,
            num_channels_latents,
            num_frames,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            # make sure the VAE is in float32 mode, as it overflows in float16
            if self.vae.config.force_upcast:
                video = video.float()
                self.vae.to(dtype=torch.float32)

            if isinstance(generator, list):
                if len(generator) != batch_size:
                    raise ValueError(
                        f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                        f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                    )

                # One generator per batch element for reproducible per-sample encoding.
                init_latents = [
                    self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
                    for i in range(batch_size)
                ]
            else:
                ## torch.Size([1, 16, 3, 512, 512])
                init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video]

            init_latents = torch.cat(init_latents, dim=0)

            # restore vae to original dtype
            if self.vae.config.force_upcast:
                self.vae.to(dtype)

            init_latents = init_latents.to(dtype)
            init_latents = self.vae.config.scaling_factor * init_latents

            # Prompt/video batch mismatch is rejected outright (no silent duplication).
            if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
                # expand init_latents for batch_size
                error_message = (
                    f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                    " images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
                )
                raise ValueError(error_message)
            elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
                )
            else:
                init_latents = torch.cat([init_latents], dim=0)

            # Diffuse the clean latents to the starting timestep, then permute the
            # frame and channel axes into the [B, C, F, h, w] layout the UNet expects.
            noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
            latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
        else:
            # Caller-supplied latents: shape-check only, then move to target device/dtype.
            if shape != latents.shape:
                # [B, C, F, H, W]
                raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
            latents = latents.to(device, dtype=dtype)

        return latents
+
    @property
    def guidance_scale(self):
        # Read-only view of the internal `_guidance_scale` value.
        return self._guidance_scale

    @property
    def clip_skip(self):
        # Read-only view of the internal `_clip_skip` value.
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # CFG is active only for guidance scales strictly above 1.
        return self._guidance_scale > 1

    @property
    def cross_attention_kwargs(self):
        # Read-only view of the internal `_cross_attention_kwargs` value.
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        # Read-only view of the internal `_num_timesteps` value.
        return self._num_timesteps
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ ic_light_pipe=None,
+ relight_prompt=None,
+ bg_source=None,
+ video: List[List[PipelineImageInput]] = None,
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 7.5,
+ strength: float = 0.8,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ decode_chunk_size: int = 16,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ video (`List[PipelineImageInput]`):
+ The input video to condition the generation on. Must be a list of images/frames of the video.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ strength (`float`, *optional*, defaults to 0.8):
+ Higher strength leads to more differences between original video and generated video.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
+ `(batch_size, num_channel, num_frames, height, width)`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ decode_chunk_size (`int`, defaults to `16`):
+ The number of frames to decode at a time when calling `decode_latents` method.
+
+ Examples:
+
+ Returns:
+ [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
+ """
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ strength=strength,
+ height=height,
+ width=width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ video=video,
+ latents=latents,
+ ip_adapter_image=ip_adapter_image,
+ ip_adapter_image_embeds=ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+
+ # 5. Prepare latent variables
+ if latents is None:
+ video = self.video_processor.preprocess_video(video, height=height, width=width)
+ # Move the number of frames before the number of channels.
+ video = video.permute(0, 2, 1, 3, 4)
+ video = video.to(device=device, dtype=prompt_embeds.dtype) ## torch.Size([1, 16, 3, 512, 512])
+ org_target = rearrange(video, "1 f c h w -> 1 c f h w")
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents( ## torch.Size([1, 4, 16, 64, 64])
+ video=video,
+ height=height,
+ width=width,
+ num_channels_latents=num_channels_latents,
+ batch_size=batch_size * num_videos_per_prompt,
+ timestep=latent_timestep,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ decode_chunk_size=decode_chunk_size,
+ )
+
+ num_frames = video.shape[1]
+ prompt_embeds = prompt_embeds.repeat(num_frames, 1, 1)
+
+ ## Init Relighter
+ relighter = Relighter(
+ pipeline=ic_light_pipe,
+ relight_prompt=relight_prompt,
+ bg_source=bg_source,
+ generator=generator,
+ num_frames=num_frames,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": image_embeds}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+ else None
+ )
+
+ num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+ for free_init_iter in range(num_free_init_iters):
+ if self.free_init_enabled:
+ latents, timesteps = self._apply_free_init(
+ latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+ )
+ num_inference_steps = len(timesteps)
+ # make sure to readjust timesteps based on strength
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+
+ self._num_timesteps = len(timesteps)
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # 8. Denoising loop
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ ).sample
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ lbd = 1 - i/(num_inference_steps-1)
+
+ if lbd>0.15:
+ ## get pred_x
+ sigma = self.scheduler.sigmas[self.scheduler.step_index]
+ pred_x0_latent = latents - sigma * noise_pred
+
+ ## consistent target
+ consist_target = self.decode_latents(pred_x0_latent)
+
+ if i == 0:
+ detail_diff = org_target - consist_target
+
+ consist_target = consist_target + lbd * detail_diff
+ consist_target = rearrange(consist_target, "1 c f h w -> f c h w")
+
+ ## relight target
+ relight_target = relighter(consist_target)
+
+ print(f"relight lbd = {lbd}")
+ fusion_target = (1 - lbd) * consist_target + lbd * relight_target
+
+ ## fusion_target -> pixel level
+ fusion_latent = self.vae.encode(fusion_target).latent_dist.mode() * self.vae.config.scaling_factor
+ fusion_latent = fusion_latent.to(consist_target.dtype)
+ fusion_latent = rearrange(fusion_latent, "f c h w -> 1 c f h w")
+
+ output = eul_step(self.scheduler, noise_pred, t, latents, fusion_latent, self, **extra_step_kwargs)
+
+ else:
+ output = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
+
+ latents = output[0]
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # 9. Post-processing
+ if output_type == "latent":
+ video = latents
+ else:
+ video_tensor = self.decode_latents(latents, decode_chunk_size)
+ video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
+
+ # 10. Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return AnimateDiffPipelineOutput(frames=video)
\ No newline at end of file
diff --git a/src/ic_light.py b/src/ic_light.py
new file mode 100644
index 0000000000000000000000000000000000000000..29695bcae6c6bcb154bb814366d66aa84ca75621
--- /dev/null
+++ b/src/ic_light.py
@@ -0,0 +1,152 @@
+import torch
+import numpy as np
+from enum import Enum
+import math
+
+import torch.nn.functional as F
+from utils.tools import resize_and_center_crop, numpy2pytorch, pad, decode_latents, encode_video
+
class BGSource(Enum):
    """Direction of the synthetic gradient used as the relighting background.

    The string values double as human-readable labels. `Relighter.create_background`
    maps each member to a 255-to-0 luminance ramp (NONE yields no image).
    """
    NONE = "None"            # no background light image
    LEFT = "Left Light"      # bright at the left edge, dark at the right
    RIGHT = "Right Light"    # bright at the right edge, dark at the left
    TOP = "Top Light"        # bright at the top edge, dark at the bottom
    BOTTOM = "Bottom Light"  # bright at the bottom edge, dark at the top
+
class Relighter:
    """Relights video frames with an IC-Light image-to-image pipeline.

    The directional gradient background (the light source) and the prompt
    embeddings are computed once in ``__init__`` and repeated across all
    ``num_frames`` frames, so each ``__call__`` only pays for the denoising
    passes on the supplied frames.
    """

    def __init__(self,
                 pipeline,
                 relight_prompt="",
                 num_frames=16,
                 image_width=512,
                 image_height=512,
                 num_samples=1,
                 steps=15,
                 cfg=2,
                 lowres_denoise=0.9,
                 bg_source=BGSource.RIGHT,
                 generator=None,
                 ):
        """Precompute the light-source latent and prompt embeddings.

        Args:
            pipeline: IC-Light img2img diffusion pipeline; provides the
                tokenizer/text_encoder/vae used below and is invoked in
                ``__call__``.
            relight_prompt: Text description of the desired lighting.
            num_frames: Number of frames conditioned per call.
            image_width: Working width in pixels.
            image_height: Working height in pixels.
            num_samples: Images per prompt requested from the pipeline.
            steps: Base step count; the pipeline receives
                ``round(steps / lowres_denoise)`` steps so roughly ``steps``
                of them actually run at the default strength.
            cfg: Classifier-free guidance scale.
            lowres_denoise: Default img2img strength (see ``__call__``).
            bg_source: Light direction. NOTE(review): ``BGSource.NONE`` makes
                ``create_background`` return None, which would fail in the VAE
                encode below — presumably unsupported here; confirm.
            generator: Optional torch generator for reproducibility.
        """
        self.pipeline = pipeline
        self.image_width = image_width
        self.image_height = image_height
        self.num_samples = num_samples
        self.steps = steps
        self.cfg = cfg
        self.lowres_denoise = lowres_denoise
        self.bg_source = bg_source
        self.generator = generator
        self.device = pipeline.device
        self.num_frames = num_frames
        self.vae = self.pipeline.vae

        # Fixed quality boosters appended to / used as the positive / negative prompt.
        self.a_prompt = "best quality"
        self.n_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
        positive_prompt = relight_prompt + ', ' + self.a_prompt
        negative_prompt = self.n_prompt
        tokenizer = self.pipeline.tokenizer
        device = self.pipeline.device
        vae = self.vae

        conds, unconds = self.encode_prompt_pair(tokenizer, device, positive_prompt, negative_prompt)

        # Encode the gradient background once; it acts as the fixed light source.
        input_bg = self.create_background()
        bg = resize_and_center_crop(input_bg, self.image_width, self.image_height)
        bg_latent = numpy2pytorch([bg], device, vae.dtype)
        bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor

        # Repeat per frame so tensors line up with the frame batch dimension.
        self.bg_latent = bg_latent.repeat(self.num_frames, 1, 1, 1)  # fixed light source
        self.conds = conds.repeat(self.num_frames, 1, 1)
        self.unconds = unconds.repeat(self.num_frames, 1, 1)

    def encode_prompt_inner(self, tokenizer, txt):
        """Tokenize ``txt`` into BOS/EOS-delimited chunks and encode each chunk.

        Long prompts are split into chunks of ``model_max_length - 2`` tokens
        so arbitrary lengths survive CLIP's fixed context window; EOS doubles
        as the pad token. Returns the text-encoder hidden states, one row per
        chunk.
        """
        max_length = tokenizer.model_max_length
        chunk_length = tokenizer.model_max_length - 2
        id_start = tokenizer.bos_token_id
        id_end = tokenizer.eos_token_id
        id_pad = id_end  # pad with EOS rather than a dedicated pad token

        tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
        chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
        chunks = [pad(ck, id_pad, max_length) for ck in chunks]

        token_ids = torch.tensor(chunks).to(device=self.device, dtype=torch.int64)
        conds = self.pipeline.text_encoder(token_ids).last_hidden_state
        return conds

    def encode_prompt_pair(self, tokenizer, device, positive_prompt, negative_prompt):
        """Encode positive/negative prompts and equalize their chunk counts.

        The shorter embedding is tiled so both sides end up with the same
        number of chunks, then each side is flattened to shape
        ``(1, chunks * seq, dim)`` for use as a classifier-free guidance pair.
        """
        c = self.encode_prompt_inner(tokenizer, positive_prompt)
        uc = self.encode_prompt_inner(tokenizer, negative_prompt)

        c_len = float(len(c))
        uc_len = float(len(uc))
        max_count = max(c_len, uc_len)
        c_repeat = int(math.ceil(max_count / c_len))
        uc_repeat = int(math.ceil(max_count / uc_len))
        max_chunk = max(len(c), len(uc))

        c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
        uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]

        # (chunks, seq, dim) -> (1, chunks * seq, dim)
        c = torch.cat([p[None, ...] for p in c], dim=1)
        uc = torch.cat([p[None, ...] for p in uc], dim=1)

        return c.to(device), uc.to(device)

    def create_background(self):
        """Build a uint8 HxWx3 gradient image for the configured light direction.

        Returns None for ``BGSource.NONE``; raises ValueError for anything
        that is not a known ``BGSource`` member.
        """
        max_pix = 255
        min_pix = 0

        print(f"max light pix:{max_pix}, min light pix:{min_pix}")

        if self.bg_source == BGSource.NONE:
            return None
        elif self.bg_source == BGSource.LEFT:
            # bright at the left edge, fading towards the right
            gradient = np.linspace(max_pix, min_pix, self.image_width)
            image = np.tile(gradient, (self.image_height, 1))
            return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        elif self.bg_source == BGSource.RIGHT:
            gradient = np.linspace(min_pix, max_pix, self.image_width)
            image = np.tile(gradient, (self.image_height, 1))
            return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        elif self.bg_source == BGSource.TOP:
            # column vector so the ramp runs top-to-bottom
            gradient = np.linspace(max_pix, min_pix, self.image_height)[:, None]
            image = np.tile(gradient, (1, self.image_width))
            return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        elif self.bg_source == BGSource.BOTTOM:
            gradient = np.linspace(min_pix, max_pix, self.image_height)[:, None]
            image = np.tile(gradient, (1, self.image_width))
            return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        else:
            raise ValueError('Wrong initial latent!')

    @torch.no_grad()
    def __call__(self, input_video, init_latent=None, input_strength=None):
        """Relight ``input_video`` and return the decoded frames.

        Args:
            input_video: Frames consumed by ``encode_video`` (pixel space);
                they condition the pipeline through ``concat_conds``.
            init_latent: Optional starting latent; defaults to the precomputed
                background (light-source) latent.
            input_strength: Optional img2img strength override; falls back to
                ``self.lowres_denoise`` when None.
        """
        input_latent = encode_video(self.vae, input_video) * self.vae.config.scaling_factor

        # FIX: explicit None check — a caller-supplied strength of 0.0 is
        # falsy but still deliberate, and must not be silently replaced.
        if input_strength is not None:
            light_strength = input_strength
        else:
            light_strength = self.lowres_denoise

        # FIX: explicit None check — `not tensor` raises "Boolean value of
        # Tensor is ambiguous" for any multi-element latent passed in.
        if init_latent is None:
            init_latent = self.bg_latent

        latents = self.pipeline(
            image=init_latent,
            strength=light_strength,
            prompt_embeds=self.conds,
            negative_prompt_embeds=self.unconds,
            width=self.image_width,
            height=self.image_height,
            num_inference_steps=int(round(self.steps / self.lowres_denoise)),
            num_images_per_prompt=self.num_samples,
            generator=self.generator,
            output_type='latent',
            guidance_scale=self.cfg,
            cross_attention_kwargs={'concat_conds': input_latent},
        ).images.to(self.pipeline.vae.dtype)

        relight_video = decode_latents(self.vae, latents)
        return relight_video
\ No newline at end of file
diff --git a/src/ic_light_pipe.py b/src/ic_light_pipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..4530bcf6de8905cfdb9a4a6ec69a10a013e54183
--- /dev/null
+++ b/src/ic_light_pipe.py
@@ -0,0 +1,1122 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import FusedAttnProcessor2_0
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+
+
# Module-level logger via diffusers' logging wrapper.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Runnable usage example; presumably injected into the pipeline's __call__
# docstring via `replace_example_docstring` (imported above) — confirm at the
# decorator site. The string content is user-facing documentation.
EXAMPLE_DOC_STRING = """
    Examples:
    ```py
    >>> import requests
    >>> import torch
    >>> from PIL import Image
    >>> from io import BytesIO

    >>> from diffusers import StableDiffusionImg2ImgPipeline

    >>> device = "cuda"
    >>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
    >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
    >>> pipe = pipe.to(device)

    >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

    >>> response = requests.get(url)
    >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
    >>> init_image = init_image.resize((768, 512))

    >>> prompt = "A fantasy landscape, trending on artstation"

    >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
    >>> images[0].save("fantasy_landscape.png")
    ```
"""
+
+
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull a latent tensor out of a VAE encoder output.

    Handles outputs exposing a `latent_dist` (stochastic `"sample"` draw or
    deterministic `"argmax"` mode) as well as outputs that carry precomputed
    `latents` directly.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
+
+
def preprocess(image):
    """Deprecated: convert PIL image(s) or tensor(s) to a normalized batch tensor.

    Tensors pass through untouched; PIL images are resized down to a multiple
    of 8 (taken from the first image), stacked, scaled to [-1, 1], and returned
    as an NCHW float tensor. Lists of tensors are concatenated on the batch dim.
    """
    deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
    deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)

    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        # snap both dimensions down to the nearest multiple of 8
        w -= w % 8
        h -= h % 8

        arrays = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        batch = np.concatenate(arrays, axis=0).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        image = torch.from_numpy(2.0 * batch - 1.0)  # [0, 1] -> [-1, 1]
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """Configure `scheduler` and return the resulting timestep schedule.

    Either `num_inference_steps` or an explicit `timesteps` list drives the
    scheduler's `set_timesteps`; any extra kwargs are forwarded unchanged.

    Args:
        scheduler (`SchedulerMixin`): scheduler providing `set_timesteps`.
        num_inference_steps (`int`, *optional*): number of diffusion steps;
            ignored when `timesteps` is given.
        device (`str` or `torch.device`, *optional*): device for the schedule;
            `None` leaves it wherever the scheduler places it.
        timesteps (`List[int]`, *optional*): custom schedule; only valid for
            schedulers whose `set_timesteps` accepts a `timesteps` argument.

    Returns:
        `Tuple[torch.Tensor, int]`: the schedule and its number of steps.
    """
    if timesteps is None:
        # default path: let the scheduler derive the schedule from the count
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
+
+
+class StableDiffusionImg2ImgPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-guided image-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            A `(prompt_embeds, negative_prompt_embeds)` tuple; `negative_prompt_embeds` is only
            computed when `do_classifier_free_guidance` is set and none were passed in.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # Infer the batch size from whichever prompt representation was supplied.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Warn when the prompt was longer than the model's context and got truncated.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # Pick the dtype the embeddings should be cast to (text encoder, else unet, else as-is).
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad/truncate the unconditional input to the same length as the conditional one.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ strength,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+
+ ## 重写了初始化图构造方式, 让每一帧的加噪均相同
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+
+ # original add noise
+ # noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ ## add same noise
+ frame_shape = init_latents[:1].shape
+ noise = randn_tensor(frame_shape, device=device, dtype=dtype)
+ noise = noise.repeat(shape[0],1,1,1)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
+ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ Args:
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+ """
+ self.fusing_unet = False
+ self.fusing_vae = False
+
+ if unet:
+ self.fusing_unet = True
+ self.unet.fuse_qkv_projections()
+ self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+ if vae:
+ if not isinstance(self.vae, AutoencoderKL):
+ raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+ self.fusing_vae = True
+ self.vae.fuse_qkv_projections()
+ self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
+ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """Disable QKV projection fusion if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ Args:
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+ """
+ if unet:
+ if not self.fusing_unet:
+ logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.unet.unfuse_qkv_projections()
+ self.fusing_unet = False
+
+ if vae:
+ if not self.fusing_vae:
+ logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.vae.unfuse_qkv_projections()
+ self.fusing_vae = False
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
    @property
    def guidance_scale(self):
        # Classifier-free guidance weight captured at the start of `__call__`.
        return self._guidance_scale

    @property
    def clip_skip(self):
        # Number of CLIP layers skipped when encoding the prompt (None = use last layer).
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # CFG is also disabled when the UNet accepts a timestep conditioning projection
        # (`time_cond_proj_dim` set), in which case guidance is fed via `timestep_cond`.
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def cross_attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors during the UNet forward pass.
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        # Length of the timestep schedule used by the current/most recent call.
        return self._num_timesteps

    @property
    def interrupt(self):
        # Cooperative cancellation flag checked at the top of each denoising step.
        return self._interrupt
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        timesteps: List[int] = None,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: int = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
                or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
            strength (`float`, *optional*, defaults to 0.8):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
                essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        # Legacy callback arguments, superseded by `callback_on_step_end`.
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            strength,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        # Stash per-call state behind the read-only properties defined above.
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None:
            # `ImageProjection` consumes pooled embeds; other projections take hidden states.
            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
            image_embeds, negative_image_embeds = self.encode_image(
                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
            )
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)

        # 5. set timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        # Trim the schedule by `strength`; the first remaining timestep is where the
        # noised image latents start.
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables (every frame gets identical noise — see prepare_latents)
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Add image embeds for IP-Adapter
        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

        # 7.2 Optionally get Guidance Scale Embedding (guidance-distilled UNets only)
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,  # batch latents, duplicated along dim 0 under CFG
                    t,
                    encoder_hidden_states=prompt_embeds,  # batch-matched to latent_model_input
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may replace any of the tracked tensors.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    # NOTE(review): if `callback` is passed without `callback_steps`,
                    # the `%` below raises TypeError — legacy path, confirm callers
                    # always supply both (matches upstream behavior).
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/utils/__pycache__/tools.cpython-310.pyc b/utils/__pycache__/tools.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9bef5811f513be37ddde5f66369893733bf5fe37
Binary files /dev/null and b/utils/__pycache__/tools.cpython-310.pyc differ
diff --git a/utils/tools.py b/utils/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b1e0b46622ba1ab2ee61fcd68c67fac2f4701a8
--- /dev/null
+++ b/utils/tools.py
@@ -0,0 +1,148 @@
+from PIL import Image,ImageSequence
+import numpy as np
+import torch
+from moviepy.editor import VideoFileClip
+import os
+import imageio
+import random
+from diffusers.utils import export_to_video
+
def resize_and_center_crop(image, target_width, target_height):
    """Resize a numpy image so it fully covers the target size, then center-crop.

    Args:
        image: numpy array (H, W, C) convertible via ``Image.fromarray``.
        target_width: desired output width in pixels.
        target_height: desired output height in pixels.

    Returns:
        numpy array with spatial size exactly (target_height, target_width).
    """
    pil_image = Image.fromarray(image)
    original_width, original_height = pil_image.size
    # "Cover" fit: scale so both dimensions are at least the target size.
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)

    # Use an integer crop box: float coordinates would let PIL round each
    # edge independently, which can yield a crop one pixel off the target.
    left = (resized_width - target_width) // 2
    top = (resized_height - target_height) // 2
    cropped_image = resized_image.crop((left, top, left + target_width, top + target_height))
    return np.array(cropped_image)
+
def numpy2pytorch(imgs, device, dtype):
    """Stack HxWxC frames into an NCHW tensor scaled to roughly [-1, 1].

    NOTE: the divisor is 127 (not 127.5), matching the file's convention,
    so 255 maps slightly above 1.0.
    """
    stacked = np.stack(imgs, axis=0)
    tensor = torch.from_numpy(stacked).float() / 127.0 - 1.0
    tensor = tensor.movedim(-1, 1)  # channels-last -> channels-first
    return tensor.to(device=device, dtype=dtype)
+
def get_fg_video(video_list, mask_list, device, dtype):
    """Keep foreground pixels (mask == 255), set the rest to neutral gray (127),
    and return the result as a normalized NCHW float tensor in ~[-1, 1]."""
    frames = np.stack(video_list, axis=0)
    masks = np.stack(mask_list, axis=0)
    # Background pixels become 127, which normalizes to exactly 0.
    foreground = np.where(masks == 255, frames, 127)

    tensor = torch.from_numpy(foreground).float() / 127.0 - 1.0
    tensor = tensor.movedim(-1, 1)
    return tensor.to(device=device, dtype=dtype)
+
+
def pad(x, p, i):
    """Return list x truncated or right-padded with p to exactly length i."""
    if len(x) >= i:
        return x[:i]
    return x + [p] * (i - len(x))
+
def gif_to_mp4(gif_path, mp4_path):
    """Re-encode a GIF file as an MP4 video file via moviepy."""
    VideoFileClip(gif_path).write_videofile(mp4_path)
+
def generate_light_sequence(light_tensor, num_frames=16, direction="r"):
    """Linearly interpolate from a light map to its 90-degree rotation.

    Args:
        light_tensor: tensor of shape (1, C, H, W) — rotated over dims (2, 3).
        num_frames: number of interpolation steps (>= 1).
        direction: "r" rotates clockwise, "l" counter-clockwise.

    Returns:
        Tensor of shape (num_frames, C, H, W); frame 0 is the source map and
        the last frame is the rotated target.

    Raises:
        ValueError: if direction is not "r" or "l".
    """
    # Use equality, not substring membership: `direction in "l"` would also
    # accept the empty string (every string contains "").
    if direction == "l":
        target_tensor = torch.rot90(light_tensor, k=1, dims=(2, 3))
    elif direction == "r":
        target_tensor = torch.rot90(light_tensor, k=-1, dims=(2, 3))
    else:
        raise ValueError("direction must be either 'r' for right or 'l' for left")

    out_list = []
    for frame_idx in range(num_frames):
        # Guard num_frames == 1 against division by zero: single frame = source.
        t = frame_idx / (num_frames - 1) if num_frames > 1 else 0.0
        interpolated_matrix = (1 - t) * light_tensor + t * target_tensor
        out_list.append(interpolated_matrix)

    # Stack frames, then drop the singleton batch dim of each frame.
    out_tensor = torch.stack(out_list, dim=0).squeeze(1)

    return out_tensor
+
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
    """Post-process a batched video tensor one clip at a time.

    Args:
        video: tensor of shape (batch, channels, frames, height, width).
        processor: object exposing ``postprocess(frames, output_type)``.
        output_type: forwarded to the processor (default "np").

    Returns:
        List with one post-processed clip per batch element.
    """
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        # Reorder each clip to frames-first (F, C, H, W) for the processor.
        frames_first = video[batch_idx].permute(1, 0, 2, 3)
        outputs.append(processor.postprocess(frames_first, output_type))
    return outputs
+
def read_video(video_path: str, image_width, image_height):
    """Load a GIF or MP4 file as a list of resized, center-cropped PIL frames.

    Args:
        video_path: path to a ``.gif`` or ``.mp4`` file.
        image_width: target frame width in pixels.
        image_height: target frame height in pixels.

    Returns:
        Tuple of (list of PIL RGB frames, basename of the video file).

    Raises:
        ValueError: if the file extension is neither "gif" nor "mp4".
    """
    extension = video_path.split('.')[-1].lower()
    video_name = os.path.basename(video_path)
    video_list = []

    # Compare with ==, not `in`: `extension in "mp4"` is a substring test and
    # would wrongly accept "m", "p4", or the empty string.
    if extension == "gif":
        ## input from gif
        video = Image.open(video_path)
        for frame in ImageSequence.Iterator(video):
            frame = np.array(frame.convert("RGB"))
            frame = resize_and_center_crop(frame, image_width, image_height)
            video_list.append(frame)
    elif extension == "mp4":
        ## input from mp4
        reader = imageio.get_reader(video_path)
        for frame in reader:
            frame = resize_and_center_crop(frame, image_width, image_height)
            video_list.append(frame)
    else:
        raise ValueError('Wrong input type')

    video_list = [Image.fromarray(frame) for frame in video_list]

    return video_list, video_name
+
def read_mask(mask_folder: str):
    """Load every file in mask_folder, sorted by filename, as an RGB PIL image."""
    masks = []
    for fname in sorted(os.listdir(mask_folder)):
        mask_path = os.path.join(mask_folder, fname)
        masks.append(Image.open(mask_path).convert('RGB'))
    return masks
+
def decode_latents(vae, latents, decode_chunk_size: int = 16):
    """Decode latents to pixel space in chunks to bound peak memory.

    Args:
        vae: VAE exposing ``config.scaling_factor`` and ``decode(...).sample``.
        latents: tensor whose leading dim is chunked by decode_chunk_size.
        decode_chunk_size: frames decoded per VAE call.

    Returns:
        Concatenated decoded tensor over the leading dimension.
    """
    # Undo the latent scaling applied at encode time.
    scaled = 1 / vae.config.scaling_factor * latents
    chunks = [
        vae.decode(scaled[start : start + decode_chunk_size]).sample
        for start in range(0, scaled.shape[0], decode_chunk_size)
    ]
    return torch.cat(chunks)
+
def encode_video(vae, video, decode_chunk_size: int = 16) -> torch.Tensor:
    """Encode video frames into VAE latents chunk-by-chunk (distribution mode).

    Args:
        vae: VAE exposing ``encode(...).latent_dist.mode()``.
        video: tensor/sequence whose leading dim is chunked.
        decode_chunk_size: frames encoded per VAE call.

    Returns:
        Concatenated latent tensor over the leading dimension.
    """
    chunks = []
    for start in range(0, len(video), decode_chunk_size):
        chunk = video[start : start + decode_chunk_size]
        chunks.append(vae.encode(chunk).latent_dist.mode())
    return torch.cat(chunks)
+
def vis_video(input_video, video_processor, save_path):
    """Post-process a (1, c, f, h, w) video tensor to PIL frames and write it
    to save_path as a video file."""
    frames = video_processor.postprocess_video(video=input_video, output_type="pil")
    export_to_video(frames[0], save_path)
+
def set_all_seed(seed):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) RNGs and force
    deterministic cuDNN kernels for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trades speed for run-to-run reproducibility on GPU.
    torch.backends.cudnn.deterministic = True
\ No newline at end of file