Fabrice-TIERCELIN committed on
Commit c812274 · verified · 1 Parent(s): 7d90cc3
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. diffusers/examples/README.md +0 -72
  2. diffusers/examples/community/README.md +0 -1132
  3. diffusers/examples/community/bit_diffusion.py +0 -264
  4. diffusers/examples/community/checkpoint_merger.py +0 -286
  5. diffusers/examples/community/clip_guided_stable_diffusion.py +0 -347
  6. diffusers/examples/community/clip_guided_stable_diffusion_img2img.py +0 -496
  7. diffusers/examples/community/composable_stable_diffusion.py +0 -582
  8. diffusers/examples/community/ddim_noise_comparative_analysis.py +0 -190
  9. diffusers/examples/community/imagic_stable_diffusion.py +0 -496
  10. diffusers/examples/community/img2img_inpainting.py +0 -463
  11. diffusers/examples/community/interpolate_stable_diffusion.py +0 -524
  12. diffusers/examples/community/lpw_stable_diffusion.py +0 -1153
  13. diffusers/examples/community/lpw_stable_diffusion_onnx.py +0 -1146
  14. diffusers/examples/community/magic_mix.py +0 -152
  15. diffusers/examples/community/multilingual_stable_diffusion.py +0 -436
  16. diffusers/examples/community/one_step_unet.py +0 -24
  17. diffusers/examples/community/sd_text2img_k_diffusion.py +0 -475
  18. diffusers/examples/community/seed_resize_stable_diffusion.py +0 -366
  19. diffusers/examples/community/speech_to_image_diffusion.py +0 -261
  20. diffusers/examples/community/stable_diffusion_comparison.py +0 -405
  21. diffusers/examples/community/stable_diffusion_controlnet_img2img.py +0 -989
  22. diffusers/examples/community/stable_diffusion_controlnet_inpaint.py +0 -1076
  23. diffusers/examples/community/stable_diffusion_controlnet_inpaint_img2img.py +0 -1119
  24. diffusers/examples/community/stable_diffusion_mega.py +0 -227
  25. diffusers/examples/community/stable_unclip.py +0 -287
  26. diffusers/examples/community/text_inpainting.py +0 -302
  27. diffusers/examples/community/tiled_upscaling.py +0 -298
  28. diffusers/examples/community/unclip_image_interpolation.py +0 -493
  29. diffusers/examples/community/unclip_text_interpolation.py +0 -573
  30. diffusers/examples/community/wildcard_stable_diffusion.py +0 -418
  31. diffusers/examples/conftest.py +0 -45
  32. diffusers/examples/controlnet/README.md +0 -392
  33. diffusers/examples/controlnet/requirements.txt +0 -6
  34. diffusers/examples/controlnet/requirements_flax.txt +0 -9
  35. diffusers/examples/controlnet/train_controlnet.py +0 -1046
  36. diffusers/examples/controlnet/train_controlnet_flax.py +0 -1015
  37. diffusers/examples/dreambooth/README.md +0 -464
  38. diffusers/examples/dreambooth/requirements.txt +0 -6
  39. diffusers/examples/dreambooth/requirements_flax.txt +0 -8
  40. diffusers/examples/dreambooth/train_dreambooth.py +0 -1039
  41. diffusers/examples/dreambooth/train_dreambooth_flax.py +0 -709
  42. diffusers/examples/dreambooth/train_dreambooth_lora.py +0 -1028
  43. diffusers/examples/inference/README.md +0 -8
  44. diffusers/examples/inference/image_to_image.py +0 -9
  45. diffusers/examples/inference/inpainting.py +0 -9
  46. diffusers/examples/instruct_pix2pix/README.md +0 -166
  47. diffusers/examples/instruct_pix2pix/requirements.txt +0 -6
  48. diffusers/examples/instruct_pix2pix/train_instruct_pix2pix.py +0 -988
  49. diffusers/examples/research_projects/README.md +0 -14
  50. diffusers/examples/research_projects/colossalai/README.md +0 -111
diffusers/examples/README.md DELETED
@@ -1,72 +0,0 @@
- <!---
- Copyright 2023 The HuggingFace Team. All rights reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- -->
-
- # 🧨 Diffusers Examples
-
- Diffusers examples are a collection of scripts that demonstrate how to effectively use the `diffusers` library
- for a variety of use cases involving training or fine-tuning.
-
- **Note**: If you are looking for **official** examples of how to use `diffusers` for inference,
- please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
-
- Our examples aspire to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and **one-purpose-only**.
- More specifically, this means:
-
- **Self-contained**: An example script shall only depend on "pip-installable" Python packages that can be found in a `requirements.txt` file. Example scripts shall **not** depend on any local files. This means that one can simply download an example script, *e.g.* [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py), install the required dependencies, *e.g.* [requirements.txt](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/requirements.txt), and execute the example script.
- **Easy-to-tweak**: While we strive to present as many use cases as possible, the example scripts are just that - examples. It is expected that they won't work out of the box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs. To help you with that, most of the examples fully expose the preprocessing of the data and the training loop so that you can tweak and edit them as required.
- **Beginner-friendly**: We do not aim to provide state-of-the-art training scripts for the newest models, but rather examples that can be used as a way to better understand diffusion models and how to use them with the `diffusers` library. We often purposefully leave out certain state-of-the-art methods if we consider them too complex for beginners.
- **One-purpose-only**: Examples should show one task and one task only. Even if two tasks are very similar from a modeling point of view (*e.g.* image super-resolution and image modification tend to use the same model and training method), we want examples to showcase only one task to keep them as readable and easy-to-understand as possible.
-
- We provide **official** examples that cover the most popular tasks of diffusion models.
- *Official* examples are **actively** maintained by the `diffusers` maintainers and we try to rigorously follow our example philosophy as defined above.
- If you feel like another important example should exist, we are more than happy to welcome a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) or directly a [Pull Request](https://github.com/huggingface/diffusers/compare) from you!
-
- Training examples show how to pretrain or fine-tune diffusion models for a variety of tasks. Currently we support:
-
- | Task | 🤗 Accelerate | 🤗 Datasets | Colab |
- |---|---|:---:|:---:|
- | [**Unconditional Image Generation**](./unconditional_image_generation) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) |
- | [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ | - |
- | [**Textual Inversion**](./textual_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) |
- | [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) |
- | [**ControlNet**](./controlnet) | ✅ | ✅ | - |
- | [**InstructPix2Pix**](./instruct_pix2pix) | ✅ | ✅ | - |
- | [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffusers_locomotion.py) | - | - | coming soon |
-
- ## Community
-
- In addition, we provide **community** examples, which are examples added and maintained by our community.
- Community examples can consist of both *training* examples and *inference* pipelines.
- For such examples, we are more lenient regarding the philosophy defined above and also cannot guarantee to provide maintenance for every issue.
- Examples that are useful for the community, but are either not yet deemed popular or not yet following our above philosophy, should go into the [community examples](https://github.com/huggingface/diffusers/tree/main/examples/community) folder. The community folder therefore includes training examples and inference pipelines.
- **Note**: Community examples can be a [great first contribution](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) to show the community how you like to use `diffusers` 🪄.
-
- ## Research Projects
-
- We also provide **research_projects** examples that are maintained by the community as defined in the respective research project folders. These examples are useful and offer extended capabilities that are complementary to the official examples. You may refer to [research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) for details.
-
- ## Important note
-
- To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
- ```bash
- git clone https://github.com/huggingface/diffusers
- cd diffusers
- pip install .
- ```
- Then cd into the example folder of your choice and run
- ```bash
- pip install -r requirements.txt
- ```
diffusers/examples/community/README.md DELETED
@@ -1,1132 +0,0 @@
- # Community Examples
-
- > **For more information about community pipelines, please have a look at [this issue](https://github.com/huggingface/diffusers/issues/841).**
-
- **Community** examples consist of both inference and training examples that have been added by the community.
- Please have a look at the following table to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste-ready code example that you can try out.
- If a community pipeline doesn't work as expected, please open an issue and ping the author on it.
-
9
- | Example | Description | Code Example | Colab | Author |
10
- |:---------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------:|
11
- | CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) |
12
- | One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
13
- | Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) |
14
- | Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
15
- | Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) |
16
- | Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech)
17
- | Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) |
18
- | [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "&#124;" in prompts (as an AND condition) and weights (separated by "&#124;" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
19
- | Seed Resizing Stable Diffusion| Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
20
- | Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image| [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
21
- | Multilingual Stable Diffusion| Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) |
22
- | Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting| [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
- | Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#text-based-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) |
24
- | Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - |[Stuti R.](https://github.com/kingstut) |
25
- | K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
26
- | Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
- | Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
- | MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
29
- | Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - |[Ray Wang](https://wrong.wang) |
30
- | UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
31
- | UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
32
- | DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - |[Aengus (Duc-Anh)](https://github.com/aengusng8) |
33
- | CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
34
-
35
-
36
-
- To load a custom pipeline, pass the name of one of the files in `diffusers/examples/community` as the `custom_pipeline` argument to `DiffusionPipeline`. Feel free to send a PR with your own pipelines; we will merge them quickly.
38
- ```py
- from diffusers import DiffusionPipeline
-
- pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="filename_in_the_community_folder")
40
- ```
41
-
42
- ## Example usages
43
-
44
- ### CLIP Guided Stable Diffusion
45
-
46
- CLIP guided stable diffusion can help to generate more realistic images
47
- by guiding stable diffusion at every denoising step with an additional CLIP model.
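-
- Conceptually, the extra CLIP model scores how well the partially denoised image matches the prompt, and the latents are nudged along the gradient of that score. A minimal sketch of the idea (illustration only, not the pipeline's actual code; `clip_score_fn` is a hypothetical differentiable scoring function):
-
- ```python
- import torch
-
- def clip_guidance_step(latents, clip_score_fn, clip_guidance_scale=100.0):
-     # Differentiate the CLIP similarity score w.r.t. the latents and nudge
-     # the latents in the direction that increases that score.
-     latents = latents.detach().requires_grad_(True)
-     score = clip_score_fn(latents)
-     grad = torch.autograd.grad(score, latents)[0]
-     return (latents + clip_guidance_scale * grad).detach()
- ```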
48
-
49
- The following code requires roughly 12GB of GPU RAM.
50
-
51
- ```python
52
- from diffusers import DiffusionPipeline
53
- from transformers import CLIPImageProcessor, CLIPModel
54
- import torch
55
-
56
-
57
- feature_extractor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
58
- clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16)
59
-
60
-
61
- guided_pipeline = DiffusionPipeline.from_pretrained(
62
- "runwayml/stable-diffusion-v1-5",
63
- custom_pipeline="clip_guided_stable_diffusion",
64
- clip_model=clip_model,
65
- feature_extractor=feature_extractor,
66
-
67
- torch_dtype=torch.float16,
68
- )
69
- guided_pipeline.enable_attention_slicing()
70
- guided_pipeline = guided_pipeline.to("cuda")
71
-
72
- prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"
73
-
74
- generator = torch.Generator(device="cuda").manual_seed(0)
75
- images = []
76
- for i in range(4):
77
- image = guided_pipeline(
78
- prompt,
79
- num_inference_steps=50,
80
- guidance_scale=7.5,
81
- clip_guidance_scale=100,
82
- num_cutouts=4,
83
- use_cutouts=False,
84
- generator=generator,
85
- ).images[0]
86
- images.append(image)
87
-
88
- # save images locally
89
- for i, img in enumerate(images):
90
- img.save(f"./clip_guided_sd/image_{i}.png")
91
- ```
92
-
- The `images` list contains PIL images that can be saved locally or displayed directly in a Google Colab.
- Generated images tend to be of higher quality than those produced natively with Stable Diffusion. E.g. the above script generates the following images:
95
-
96
- ![clip_guidance](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/clip_guidance/merged_clip_guidance.jpg).
97
-
98
- ### One Step Unet
99
-
100
- The dummy "one-step-unet" can be run as follows:
101
-
102
- ```python
103
- from diffusers import DiffusionPipeline
104
-
105
- pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet")
106
- pipe()
107
- ```
108
-
109
- **Note**: This community pipeline is not useful as a feature, but rather just serves as an example of how community pipelines can be added (see https://github.com/huggingface/diffusers/issues/841).
110
-
111
- ### Stable Diffusion Interpolation
112
-
113
- The following code can be run on a GPU of at least 8GB VRAM and should take approximately 5 minutes.
114
-
115
- ```python
116
- from diffusers import DiffusionPipeline
117
- import torch
118
-
119
- pipe = DiffusionPipeline.from_pretrained(
120
- "CompVis/stable-diffusion-v1-4",
121
- revision='fp16',
122
- torch_dtype=torch.float16,
123
- safety_checker=None, # Very important for videos...lots of false positives while interpolating
124
- custom_pipeline="interpolate_stable_diffusion",
125
- ).to('cuda')
126
- pipe.enable_attention_slicing()
127
-
128
- frame_filepaths = pipe.walk(
129
- prompts=['a dog', 'a cat', 'a horse'],
130
- seeds=[42, 1337, 1234],
131
- num_interpolation_steps=16,
132
- output_dir='./dreams',
133
- batch_size=4,
134
- height=512,
135
- width=512,
136
- guidance_scale=8.5,
137
- num_inference_steps=50,
138
- )
139
- ```
140
-
- The `walk(...)` function returns a list of the file paths of the images saved under the folder defined in `output_dir`. You can use these images to create videos of stable diffusion, as sketched below.
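-
- As a minimal sketch (assuming the returned paths point to ordinary image files and that Pillow is installed), you can stitch the saved frames into an animated GIF like this:
-
- ```python
- from PIL import Image
-
- # `frame_filepaths` is the list returned by `pipe.walk(...)` above
- frames = [Image.open(path) for path in frame_filepaths]
- frames[0].save(
-     "dreams.gif",
-     save_all=True,
-     append_images=frames[1:],
-     duration=100,  # display time per frame, in milliseconds
-     loop=0,
- )
- ```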
142
-
143
- > **Please have a look at https://github.com/nateraw/stable-diffusion-videos for more in-detail information on how to create videos using stable diffusion as well as more feature-complete functionality.**
144
-
145
- ### Stable Diffusion Mega
146
-
147
- The Stable Diffusion Mega Pipeline lets you use the main use cases of the stable diffusion pipeline in a single class.
148
-
149
- ```python
150
- #!/usr/bin/env python3
151
- from diffusers import DiffusionPipeline
152
- import PIL
153
- import requests
154
- from io import BytesIO
155
- import torch
156
-
157
-
158
- def download_image(url):
159
- response = requests.get(url)
160
- return PIL.Image.open(BytesIO(response.content)).convert("RGB")
161
-
162
- pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_mega", torch_dtype=torch.float16, revision="fp16")
163
- pipe.to("cuda")
164
- pipe.enable_attention_slicing()
165
-
166
-
167
- ### Text-to-Image
168
-
169
- images = pipe.text2img("An astronaut riding a horse").images
170
-
171
- ### Image-to-Image
172
-
173
- init_image = download_image("https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg")
174
-
175
- prompt = "A fantasy landscape, trending on artstation"
176
-
177
- images = pipe.img2img(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
178
-
179
- ### Inpainting
180
-
181
- img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
182
- mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
183
- init_image = download_image(img_url).resize((512, 512))
184
- mask_image = download_image(mask_url).resize((512, 512))
185
-
186
- prompt = "a cat sitting on a bench"
187
- images = pipe.inpaint(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.75).images
188
- ```
189
-
- As shown above, this single pipeline can run "text-to-image", "image-to-image", and "inpainting".
191
-
192
- ### Long Prompt Weighting Stable Diffusion
193
- Features of this custom pipeline:
194
- - Input a prompt without the 77 token length limit.
- Includes text2img, img2img, and inpainting pipelines.
- Emphasize/weight part of your prompt with parentheses, like so: `a baby deer with (big eyes)`
- De-emphasize part of your prompt like so: `a [baby] deer with big eyes`
- Precisely weight part of your prompt like so: `a baby deer with (big eyes:1.3)`
199
-
200
- Prompt weighting equivalents:
201
- - `a baby deer with` == `(a baby deer with:1.0)`
202
- - `(big eyes)` == `(big eyes:1.1)`
203
- - `((big eyes))` == `(big eyes:1.21)`
204
- - `[big eyes]` == `(big eyes:0.91)`
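-
- The weights compose multiplicatively: each pair of parentheses multiplies the weight by 1.1 and each pair of brackets divides it by 1.1, which is where the numbers above come from:
-
- ```python
- # ((big eyes)) -> 1.1 * 1.1 = 1.21
- # [big eyes]   -> 1 / 1.1  ~= 0.91
- print(1.1 ** 2)           # 1.2100000000000002
- print(round(1 / 1.1, 2))  # 0.91
- ```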
205
-
- You can run this custom pipeline like so:
207
-
208
- #### pytorch
209
-
210
- ```python
211
- from diffusers import DiffusionPipeline
212
- import torch
213
-
214
- pipe = DiffusionPipeline.from_pretrained(
215
- 'hakurei/waifu-diffusion',
216
- custom_pipeline="lpw_stable_diffusion",
217
-
218
- torch_dtype=torch.float16
219
- )
220
- pipe=pipe.to("cuda")
221
-
222
- prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms"
223
- neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry"
224
-
225
- pipe.text2img(prompt, negative_prompt=neg_prompt, width=512,height=512,max_embeddings_multiples=3).images[0]
226
-
227
- ```
228
-
229
- #### onnxruntime
230
-
231
- ```python
232
- from diffusers import DiffusionPipeline
233
- import torch
234
-
235
- pipe = DiffusionPipeline.from_pretrained(
236
- 'CompVis/stable-diffusion-v1-4',
237
- custom_pipeline="lpw_stable_diffusion_onnx",
238
- revision="onnx",
239
- provider="CUDAExecutionProvider"
240
- )
241
-
242
- prompt = "a photo of an astronaut riding a horse on mars, best quality"
243
- neg_prompt = "lowres, bad anatomy, error body, error hair, error arm, error hands, bad hands, error fingers, bad fingers, missing fingers, error legs, bad legs, multiple legs, missing legs, error lighting, error shadow, error reflection, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
244
-
245
- pipe.text2img(prompt,negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0]
246
-
247
- ```
248
-
- If you see `Token indices sequence length is longer than the specified maximum sequence length for this model ( *** > 77 ) . Running this sequence through the model will result in indexing errors`, do not worry, it is normal.
250
-
251
- ### Speech to Image
252
-
253
- The following code can generate an image from an audio sample using pre-trained OpenAI whisper-small and Stable Diffusion.
254
-
255
- ```Python
256
- import torch
257
-
258
- import matplotlib.pyplot as plt
259
- from datasets import load_dataset
260
- from diffusers import DiffusionPipeline
261
- from transformers import (
262
- WhisperForConditionalGeneration,
263
- WhisperProcessor,
264
- )
265
-
266
-
267
- device = "cuda" if torch.cuda.is_available() else "cpu"
268
-
269
- ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
270
-
271
- audio_sample = ds[3]
272
-
273
- text = audio_sample["text"].lower()
274
- speech_data = audio_sample["audio"]["array"]
275
-
276
- model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
277
- processor = WhisperProcessor.from_pretrained("openai/whisper-small")
278
-
279
- diffuser_pipeline = DiffusionPipeline.from_pretrained(
280
- "CompVis/stable-diffusion-v1-4",
281
- custom_pipeline="speech_to_image_diffusion",
282
- speech_model=model,
283
- speech_processor=processor,
284
-
285
- torch_dtype=torch.float16,
286
- )
287
-
288
- diffuser_pipeline.enable_attention_slicing()
289
- diffuser_pipeline = diffuser_pipeline.to(device)
290
-
291
- output = diffuser_pipeline(speech_data)
292
- plt.imshow(output.images[0])
293
- ```
294
- This example produces the following image:
295
-
296
- ![image](https://user-images.githubusercontent.com/45072645/196901736-77d9c6fc-63ee-4072-90b0-dc8b903d63e3.png)
297
-
298
- ### Wildcard Stable Diffusion
299
- Following the great examples from https://github.com/jtkelm2/stable-diffusion-webui-1/blob/master/scripts/wildcards.py and https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts#wildcards, here's a minimal implementation that allows for users to add "wildcards", denoted by `__wildcard__` to prompts that are used as placeholders for randomly sampled values given by either a dictionary or a `.txt` file. For example:
300
-
301
- Say we have a prompt:
302
-
303
- ```
304
- prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
305
- ```
306
-
- We can then define possible values to be sampled for `animal`, `object`, and `clothing`. These can come from a `.txt` file with the same name as the category.
-
- The possible values can also be defined / combined by using a dictionary like: `{"animal": ["dog", "cat", "mouse"]}`.
310
-
311
- The actual pipeline works just like `StableDiffusionPipeline`, except the `__call__` method takes in:
312
-
313
- `wildcard_files`: list of file paths for wild card replacement
314
- `wildcard_option_dict`: dict with key as `wildcard` and values as a list of possible replacements
315
- `num_prompt_samples`: number of prompts to sample, uniformly sampling wildcards
316
-
317
- A full example:
318
-
319
- create `animal.txt`, with contents like:
320
-
321
- ```
322
- dog
323
- cat
324
- mouse
325
- ```
326
-
327
- create `object.txt`, with contents like:
328
-
329
- ```
330
- chair
331
- sofa
332
- bench
333
- ```
334
-
335
- ```python
336
- from diffusers import DiffusionPipeline
337
- import torch
338
-
339
- pipe = DiffusionPipeline.from_pretrained(
340
- "CompVis/stable-diffusion-v1-4",
341
- custom_pipeline="wildcard_stable_diffusion",
342
-
343
- torch_dtype=torch.float16,
344
- )
345
- prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
346
- out = pipe(
347
- prompt,
348
- wildcard_option_dict={
349
- "clothing":["hat", "shirt", "scarf", "beret"]
350
- },
351
- wildcard_files=["object.txt", "animal.txt"],
352
- num_prompt_samples=1
353
- )
354
- ```
355
-
356
- ### Composable Stable diffusion
357
-
358
- [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) proposes conjunction and negation (negative prompts) operators for compositional generation with conditional diffusion models.
359
-
360
- ```python
361
- import torch as th
362
- import numpy as np
363
- import torchvision.utils as tvu
364
-
365
- from diffusers import DiffusionPipeline
366
-
367
- import argparse
368
-
369
- parser = argparse.ArgumentParser()
370
- parser.add_argument("--prompt", type=str, default="mystical trees | A magical pond | dark",
371
- help="use '|' as the delimiter to compose separate sentences.")
372
- parser.add_argument("--steps", type=int, default=50)
373
- parser.add_argument("--scale", type=float, default=7.5)
374
- parser.add_argument("--weights", type=str, default="7.5 | 7.5 | -7.5")
375
- parser.add_argument("--seed", type=int, default=2)
376
- parser.add_argument("--model_path", type=str, default="CompVis/stable-diffusion-v1-4")
377
- parser.add_argument("--num_images", type=int, default=1)
378
- args = parser.parse_args()
379
-
380
- has_cuda = th.cuda.is_available()
381
- device = th.device('cpu' if not has_cuda else 'cuda')
382
-
383
- prompt = args.prompt
384
- scale = args.scale
385
- steps = args.steps
386
-
387
- pipe = DiffusionPipeline.from_pretrained(
388
- args.model_path,
389
- custom_pipeline="composable_stable_diffusion",
390
- ).to(device)
391
-
392
- pipe.safety_checker = None
393
-
394
- images = []
395
- generator = th.Generator("cuda").manual_seed(args.seed)
396
- for i in range(args.num_images):
397
- image = pipe(prompt, guidance_scale=scale, num_inference_steps=steps,
398
- weights=args.weights, generator=generator).images[0]
399
- images.append(th.from_numpy(np.array(image)).permute(2, 0, 1) / 255.)
400
- grid = tvu.make_grid(th.stack(images, dim=0), nrow=4, padding=0)
401
- tvu.save_image(grid, f'{prompt}_{args.weights}' + '.png')
402
-
403
- ```
404
-
405
- ### Imagic Stable Diffusion
406
- Allows you to edit an image using stable diffusion.
407
-
408
- ```python
409
- import requests
410
- from PIL import Image
411
- from io import BytesIO
412
- import torch
413
- import os
414
- from diffusers import DiffusionPipeline, DDIMScheduler
415
- has_cuda = torch.cuda.is_available()
416
- device = torch.device('cpu' if not has_cuda else 'cuda')
417
- pipe = DiffusionPipeline.from_pretrained(
418
- "CompVis/stable-diffusion-v1-4",
419
- safety_checker=None,
420
- use_auth_token=True,
421
- custom_pipeline="imagic_stable_diffusion",
422
- scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
423
- ).to(device)
424
- generator = torch.Generator("cuda").manual_seed(0)
425
- seed = 0
426
- prompt = "A photo of Barack Obama smiling with a big grin"
427
- url = 'https://www.dropbox.com/s/6tlwzr73jd1r9yk/obama.png?dl=1'
428
- response = requests.get(url)
429
- init_image = Image.open(BytesIO(response.content)).convert("RGB")
430
- init_image = init_image.resize((512, 512))
431
- res = pipe.train(
432
- prompt,
433
- image=init_image,
434
- generator=generator)
435
- res = pipe(alpha=1, guidance_scale=7.5, num_inference_steps=50)
436
- os.makedirs("imagic", exist_ok=True)
437
- image = res.images[0]
438
- image.save('./imagic/imagic_image_alpha_1.png')
439
- res = pipe(alpha=1.5, guidance_scale=7.5, num_inference_steps=50)
440
- image = res.images[0]
441
- image.save('./imagic/imagic_image_alpha_1_5.png')
442
- res = pipe(alpha=2, guidance_scale=7.5, num_inference_steps=50)
443
- image = res.images[0]
444
- image.save('./imagic/imagic_image_alpha_2.png')
445
- ```
446
-
447
- ### Seed Resizing
- Test seed resizing. First generate an image at 512 by 512, then generate an image with the same seed at 512 by 592 using seed resizing. Finally, generate a 512 by 592 image using the original stable diffusion pipeline.
449
-
450
- ```python
451
- import torch as th
452
- import numpy as np
453
- from diffusers import DiffusionPipeline
454
-
455
- has_cuda = th.cuda.is_available()
456
- device = th.device('cpu' if not has_cuda else 'cuda')
457
-
458
- pipe = DiffusionPipeline.from_pretrained(
459
- "CompVis/stable-diffusion-v1-4",
460
- use_auth_token=True,
461
- custom_pipeline="seed_resize_stable_diffusion"
462
- ).to(device)
463
-
464
- def dummy(images, **kwargs):
465
- return images, False
466
-
467
- pipe.safety_checker = dummy
468
-
469
-
470
- images = []
471
- th.manual_seed(0)
472
- generator = th.Generator("cuda").manual_seed(0)
473
-
474
- seed = 0
475
- prompt = "A painting of a futuristic cop"
476
-
477
- width = 512
478
- height = 512
479
-
480
- res = pipe(
481
- prompt,
482
- guidance_scale=7.5,
483
- num_inference_steps=50,
484
- height=height,
485
- width=width,
486
- generator=generator)
487
- image = res.images[0]
488
- image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height))
489
-
490
-
491
- th.manual_seed(0)
492
- generator = th.Generator("cuda").manual_seed(0)
493
-
494
- pipe = DiffusionPipeline.from_pretrained(
495
- "CompVis/stable-diffusion-v1-4",
496
- use_auth_token=True,
497
- custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
498
- ).to(device)
499
-
500
- width = 512
501
- height = 592
502
-
503
- res = pipe(
504
- prompt,
505
- guidance_scale=7.5,
506
- num_inference_steps=50,
507
- height=height,
508
- width=width,
509
- generator=generator)
510
- image = res.images[0]
511
- image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height))
512
-
513
- pipe_compare = DiffusionPipeline.from_pretrained(
514
- "CompVis/stable-diffusion-v1-4",
515
- use_auth_token=True,
516
- custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
517
- ).to(device)
518
-
519
- res = pipe_compare(
520
- prompt,
521
- guidance_scale=7.5,
522
- num_inference_steps=50,
523
- height=height,
524
- width=width,
525
- generator=generator
526
- )
527
-
528
- image = res.images[0]
529
- image.save('./seed_resize/seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height))
530
- ```
531
-
532
- ### Multilingual Stable Diffusion Pipeline
533
-
- The following code can generate images from texts in different languages using the pre-trained [mBART-50 many-to-one multilingual machine translation model](https://huggingface.co/facebook/mbart-large-50-many-to-one-mmt) and Stable Diffusion.
535
-
536
- ```python
537
- from PIL import Image
538
-
539
- import torch
540
-
541
- from diffusers import DiffusionPipeline
542
- from transformers import (
543
- pipeline,
544
- MBart50TokenizerFast,
545
- MBartForConditionalGeneration,
546
- )
547
- device = "cuda" if torch.cuda.is_available() else "cpu"
548
- device_dict = {"cuda": 0, "cpu": -1}
549
-
550
- # helper function taken from: https://huggingface.co/blog/stable_diffusion
551
- def image_grid(imgs, rows, cols):
552
- assert len(imgs) == rows*cols
553
-
554
- w, h = imgs[0].size
555
- grid = Image.new('RGB', size=(cols*w, rows*h))
556
- grid_w, grid_h = grid.size
557
-
558
- for i, img in enumerate(imgs):
559
- grid.paste(img, box=(i%cols*w, i//cols*h))
560
- return grid
561
-
562
- # Add language detection pipeline
563
- language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection"
564
- language_detection_pipeline = pipeline("text-classification",
565
- model=language_detection_model_ckpt,
566
- device=device_dict[device])
567
-
568
- # Add model for language translation
569
- trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
570
- trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)
571
-
572
- diffuser_pipeline = DiffusionPipeline.from_pretrained(
573
- "CompVis/stable-diffusion-v1-4",
574
- custom_pipeline="multilingual_stable_diffusion",
575
- detection_pipeline=language_detection_pipeline,
576
- translation_model=trans_model,
577
- translation_tokenizer=trans_tokenizer,
578
-
579
- torch_dtype=torch.float16,
580
- )
581
-
582
- diffuser_pipeline.enable_attention_slicing()
583
- diffuser_pipeline = diffuser_pipeline.to(device)
584
-
585
- prompt = ["a photograph of an astronaut riding a horse",
586
- "Una casa en la playa",
587
- "Ein Hund, der Orange isst",
588
- "Un restaurant parisien"]
589
-
590
- output = diffuser_pipeline(prompt)
591
-
592
- images = output.images
593
-
594
- grid = image_grid(images, rows=2, cols=2)
595
- ```
596
-
597
- This example produces the following images:
598
- ![image](https://user-images.githubusercontent.com/4313860/198328706-295824a4-9856-4ce5-8e66-278ceb42fd29.png)
599
-
600
- ### Image to Image Inpainting Stable Diffusion
601
-
602
- Similar to the standard stable diffusion inpainting example, except with the addition of an `inner_image` argument.
603
-
604
- `image`, `inner_image`, and `mask` should have the same dimensions. `inner_image` should have an alpha (transparency) channel.
605
-
606
- The aim is to overlay two images, then mask out the boundary between `image` and `inner_image` to allow stable diffusion to make the connection more seamless.
607
- For example, this could be used to place a logo on a shirt and make it blend seamlessly.
608
-
609
- ```python
610
- import PIL
611
- import torch
612
-
613
- from diffusers import DiffusionPipeline
614
-
615
- image_path = "./path-to-image.png"
616
- inner_image_path = "./path-to-inner-image.png"
617
- mask_path = "./path-to-mask.png"
618
-
619
- init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
620
- inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
621
- mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
622
-
623
- pipe = DiffusionPipeline.from_pretrained(
624
- "runwayml/stable-diffusion-inpainting",
625
- custom_pipeline="img2img_inpainting",
626
-
627
- torch_dtype=torch.float16
628
- )
629
- pipe = pipe.to("cuda")
630
-
631
- prompt = "Your prompt here!"
632
- image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
633
- ```
634
-
635
- ![2 by 2 grid demonstrating image to image inpainting.](https://user-images.githubusercontent.com/44398246/203506577-ec303be4-887e-4ebd-a773-c83fcb3dd01a.png)
636
-
637
- ### Text Based Inpainting Stable Diffusion
638
-
639
- Use a text prompt to generate the mask for the area to be inpainted.
640
- Currently uses the CLIPSeg model for mask generation, then calls the standard Stable Diffusion Inpainting pipeline to perform the inpainting.
641
-
642
- ```python
643
- from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
644
- from diffusers import DiffusionPipeline
645
-
646
- from PIL import Image
647
- import requests
648
-
649
- processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
650
- model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
651
-
652
- pipe = DiffusionPipeline.from_pretrained(
653
- "runwayml/stable-diffusion-inpainting",
654
- custom_pipeline="text_inpainting",
655
- segmentation_model=model,
656
- segmentation_processor=processor
657
- )
658
- pipe = pipe.to("cuda")
659
-
660
-
661
- url = "https://github.com/timojl/clipseg/blob/master/example_image.jpg?raw=true"
662
- image = Image.open(requests.get(url, stream=True).raw).resize((512, 512))
663
- text = "a glass" # will mask out this text
664
- prompt = "a cup" # the masked out region will be replaced with this
665
-
666
- image = pipe(image=image, text=text, prompt=prompt).images[0]
667
- ```
668
-
669
- ### Bit Diffusion
- Based on https://arxiv.org/abs/2208.04202, this is used for diffusion on discrete data - e.g. discrete image data, DNA sequence data. An unconditional discrete image can be generated like this:
671
-
672
- ```python
673
- from diffusers import DiffusionPipeline
674
- pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="bit_diffusion")
675
- image = pipe().images[0]
676
-
677
- ```
678
-
679
- ### Stable Diffusion with K Diffusion
680
-
681
- Make sure you have @crowsonkb's https://github.com/crowsonkb/k-diffusion installed:
682
-
683
- ```
684
- pip install k-diffusion
685
- ```
686
-
687
- You can use the community pipeline as follows:
688
-
689
- ```python
- import torch
-
- from diffusers import DiffusionPipeline
-
- seed = 33  # any seed works; 33 matches the comparison below
691
-
692
- pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
693
- pipe = pipe.to("cuda")
694
-
695
- prompt = "an astronaut riding a horse on mars"
696
- pipe.set_scheduler("sample_heun")
697
- generator = torch.Generator(device="cuda").manual_seed(seed)
698
- image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
699
-
700
- image.save("./astronaut_heun_k_diffusion.png")
701
- ```
702
-
703
- To make sure that K Diffusion and `diffusers` yield the same results:
704
-
705
- **Diffusers**:
706
- ```python
- import torch
-
- from diffusers import DiffusionPipeline, EulerDiscreteScheduler
-
- prompt = "an astronaut riding a horse on mars"  # same prompt as above
708
-
709
- seed = 33
710
-
711
- pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
712
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
713
- pipe = pipe.to("cuda")
714
-
715
- generator = torch.Generator(device="cuda").manual_seed(seed)
716
- image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
717
- ```
718
-
719
- ![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler.png)
720
-
721
- **K Diffusion**:
722
- ```python
- import torch
-
- from diffusers import DiffusionPipeline, EulerDiscreteScheduler
-
- prompt = "an astronaut riding a horse on mars"  # same prompt as above
724
-
725
- seed = 33
726
-
727
- pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
728
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
729
- pipe = pipe.to("cuda")
730
-
731
- pipe.set_scheduler("sample_euler")
732
- generator = torch.Generator(device="cuda").manual_seed(seed)
733
- image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
734
- ```
735
-
736
- ![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler_k_diffusion.png)
737
-
738
- ### Checkpoint Merger Pipeline
- Based on the AUTOMATIC1111/webui approach to checkpoint merging. This is a custom pipeline that merges up to 3 pretrained model checkpoints as long as they are in the HuggingFace model_index.json format.
-
- The checkpoint merging is currently memory intensive as it modifies the weights of a DiffusionPipeline object in place. Expect at least 13GB of RAM usage on Kaggle GPU kernels, and
- on Colab you might run out of the 12GB of memory even while merging two checkpoints.
-
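- Conceptually, the merge methods boil down to element-wise arithmetic on matching weight tensors, with `alpha` as the blend factor. A rough sketch of the usual formulas (an assumption for illustration only, not necessarily this pipeline's exact implementation):
-
- ```python
- import torch
-
- def blend(theta0: torch.Tensor, theta1: torch.Tensor, alpha: float) -> torch.Tensor:
-     # Two-checkpoint blend; the "sigmoid" / inverse-sigmoid options reshape how
-     # alpha is applied, but the underlying idea is the same weighted combination.
-     return (1 - alpha) * theta0 + alpha * theta1
-
- def add_difference(theta0: torch.Tensor, theta1: torch.Tensor, theta2: torch.Tensor, alpha: float) -> torch.Tensor:
-     # Three-checkpoint merge: add the (theta1 - theta2) delta onto theta0.
-     return theta0 + alpha * (theta1 - theta2)
- ```
-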
- Usage:
745
- ```python
746
- from diffusers import DiffusionPipeline
747
-
- # Returns a CheckpointMergerPipeline class that allows you to merge checkpoints.
- # The checkpoint passed here is ignored, but still pass one of the checkpoints you plan to
- # merge for convenience.
- pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger")
-
- # There are multiple possible scenarios:
- # The pipeline with the merged checkpoints is returned in all of them.
-
- # Compatible checkpoints, a.k.a. matched model_index.json files. Ignores the meta attributes in model_index.json during comparison (attrs with _ as prefix).
757
- merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4","CompVis/stable-diffusion-v1-2"], interp = "sigmoid", alpha = 0.4)
758
-
759
- #Incompatible checkpoints in model_index.json but merge might be possible. Use force = True to ignore model_index.json compatibility
760
- merged_pipe_1 = pipe.merge(["CompVis/stable-diffusion-v1-4","hakurei/waifu-diffusion"], force = True, interp = "sigmoid", alpha = 0.4)
761
-
762
- #Three checkpoint merging. Only "add_difference" method actually works on all three checkpoints. Using any other options will ignore the 3rd checkpoint.
763
- merged_pipe_2 = pipe.merge(["CompVis/stable-diffusion-v1-4","hakurei/waifu-diffusion","prompthero/openjourney"], force = True, interp = "add_difference", alpha = 0.4)
764
-
765
- prompt = "An astronaut riding a horse on Mars"
766
-
767
- image = merged_pipe(prompt).images[0]
768
-
769
- ```
770
- Some examples along with the merge details:
771
-
772
- 1. "CompVis/stable-diffusion-v1-4" + "hakurei/waifu-diffusion" ; Sigmoid interpolation; alpha = 0.8
773
-
774
- ![Stable plus Waifu Sigmoid 0.8](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/stability_v1_4_waifu_sig_0.8.png)
775
-
776
- 2. "hakurei/waifu-diffusion" + "prompthero/openjourney" ; Inverse Sigmoid interpolation; alpha = 0.8
777
-
778
- ![Stable plus Waifu Sigmoid 0.8](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/waifu_openjourney_inv_sig_0.8.png)
779
-
780
-
781
- 3. "CompVis/stable-diffusion-v1-4" + "hakurei/waifu-diffusion" + "prompthero/openjourney"; Add Difference interpolation; alpha = 0.5
782
-
783
- ![Stable plus Waifu plus openjourney add_diff 0.5](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/stable_waifu_openjourney_add_diff_0.5.png)
784
-
785
-
786
- ### Stable Diffusion Comparisons
787
-
788
- This Community Pipeline enables the comparison between the 4 checkpoints that exist for Stable Diffusion. They can be found through the following links:
789
- 1. [Stable Diffusion v1.1](https://huggingface.co/CompVis/stable-diffusion-v1-1)
790
- 2. [Stable Diffusion v1.2](https://huggingface.co/CompVis/stable-diffusion-v1-2)
791
- 3. [Stable Diffusion v1.3](https://huggingface.co/CompVis/stable-diffusion-v1-3)
792
- 4. [Stable Diffusion v1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4)
793
-
794
- ```python
795
- from diffusers import DiffusionPipeline
796
- import matplotlib.pyplot as plt
797
-
798
- pipe = DiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', custom_pipeline='suvadityamuk/StableDiffusionComparison')
799
- pipe.enable_attention_slicing()
800
- pipe = pipe.to('cuda')
801
- prompt = "an astronaut riding a horse on mars"
802
- output = pipe(prompt)
803
-
- plt.subplot(2,2,1)
- plt.imshow(output.images[0])
- plt.title('Stable Diffusion v1.1')
- plt.axis('off')
- plt.subplot(2,2,2)
- plt.imshow(output.images[1])
- plt.title('Stable Diffusion v1.2')
- plt.axis('off')
- plt.subplot(2,2,3)
- plt.imshow(output.images[2])
- plt.title('Stable Diffusion v1.3')
- plt.axis('off')
- plt.subplot(2,2,4)
- plt.imshow(output.images[3])
- plt.title('Stable Diffusion v1.4')
- plt.axis('off')
820
-
821
- plt.show()
822
- ```
823
-
- As a result, you can look at a grid of all 4 generated images shown together, which captures the difference in training progress between the 4 checkpoints.
825
-
826
- ### Magic Mix
827
-
828
- Implementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://arxiv.org/abs/2210.16056) paper. This is a Diffusion Pipeline for semantic mixing of an image and a text prompt to create a new concept while preserving the spatial layout and geometry of the subject in the image. The pipeline takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process.
829
-
- There are 3 parameters for the method:
- - `mix_factor`: the interpolation constant used in the layout generation phase. The greater the value of `mix_factor`, the greater the influence of the prompt on the layout generation process.
- - `kmax` and `kmin`: these determine the range for the layout and content generation process. A higher value of `kmax` results in the loss of more information about the layout of the original image, and a higher value of `kmin` results in more steps for the content generation process.
-
- Here is an example usage:
835
-
836
- ```python
837
- from diffusers import DiffusionPipeline, DDIMScheduler
838
- from PIL import Image
839
-
840
- pipe = DiffusionPipeline.from_pretrained(
841
- "CompVis/stable-diffusion-v1-4",
842
- custom_pipeline="magic_mix",
843
- scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
844
- ).to('cuda')
845
-
846
- img = Image.open('phone.jpg')
847
- mix_img = pipe(
848
- img,
849
- prompt = 'bed',
850
- kmin = 0.3,
851
- kmax = 0.5,
852
- mix_factor = 0.5,
853
- )
854
- mix_img.save('phone_bed_mix.jpg')
855
- ```
- The `mix_img` is a PIL image that can be saved locally or displayed directly in a Google Colab. The generated image is a mix of the layout semantics of the given image and the content semantics of the prompt.
857
-
858
- E.g. the above script generates the following image:
859
-
860
- `phone.jpg`
861
-
862
- ![206903102-34e79b9f-9ed2-4fac-bb38-82871343c655](https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg)
863
-
864
- `phone_bed_mix.jpg`
865
-
866
- ![206903104-913a671d-ef53-4ae4-919d-64c3059c8f67](https://user-images.githubusercontent.com/59410571/209578602-70f323fa-05b7-4dd6-b055-e40683e37914.jpg)
867
-
868
- For more example generations check out this [demo notebook](https://github.com/daspartho/MagicMix/blob/main/demo.ipynb).
869
-
870
-
871
- ### Stable UnCLIP
872
-
- UnCLIPPipeline("kakaobrain/karlo-v1-alpha") provides a prior model that can generate a CLIP image embedding from text.
- StableDiffusionImageVariationPipeline("lambdalabs/sd-image-variations-diffusers") provides a decoder model that can generate images from a CLIP image embedding.
875
-
876
- ```python
877
- import torch
878
- from diffusers import DiffusionPipeline
879
-
880
- device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
881
-
882
- pipeline = DiffusionPipeline.from_pretrained(
883
- "kakaobrain/karlo-v1-alpha",
884
- torch_dtype=torch.float16,
885
- custom_pipeline="stable_unclip",
886
- decoder_pipe_kwargs=dict(
887
- image_encoder=None,
888
- ),
889
- )
890
- pipeline.to(device)
891
-
892
- prompt = "a shiba inu wearing a beret and black turtleneck"
893
- random_generator = torch.Generator(device=device).manual_seed(1000)
894
- output = pipeline(
895
- prompt=prompt,
896
- width=512,
897
- height=512,
898
- generator=random_generator,
899
- prior_guidance_scale=4,
900
- prior_num_inference_steps=25,
901
- decoder_guidance_scale=8,
902
- decoder_num_inference_steps=50,
903
- )
904
-
905
- image = output.images[0]
906
- image.save("./shiba-inu.jpg")
907
-
908
- # debug
909
-
910
- # `pipeline.decoder_pipe` is a regular StableDiffusionImageVariationPipeline instance.
911
- # It is used to convert clip image embedding to latents, then fed into VAE decoder.
912
- print(pipeline.decoder_pipe.__class__)
913
- # <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline'>
914
-
915
- # this pipeline only use prior module in "kakaobrain/karlo-v1-alpha"
916
- # It is used to convert clip text embedding to clip image embedding.
917
- print(pipeline)
918
- # StableUnCLIPPipeline {
919
- # "_class_name": "StableUnCLIPPipeline",
920
- # "_diffusers_version": "0.12.0.dev0",
921
- # "prior": [
922
- # "diffusers",
923
- # "PriorTransformer"
924
- # ],
925
- # "prior_scheduler": [
926
- # "diffusers",
927
- # "UnCLIPScheduler"
928
- # ],
929
- # "text_encoder": [
930
- # "transformers",
931
- # "CLIPTextModelWithProjection"
932
- # ],
933
- # "tokenizer": [
934
- # "transformers",
935
- # "CLIPTokenizer"
936
- # ]
937
- # }
938
-
939
- # pipeline.prior_scheduler is the scheduler used for prior in UnCLIP.
940
- print(pipeline.prior_scheduler)
941
- # UnCLIPScheduler {
942
- # "_class_name": "UnCLIPScheduler",
943
- # "_diffusers_version": "0.12.0.dev0",
944
- # "clip_sample": true,
945
- # "clip_sample_range": 5.0,
946
- # "num_train_timesteps": 1000,
947
- # "prediction_type": "sample",
948
- # "variance_type": "fixed_small_log"
949
- # }
950
- ```
951
-
952
-
953
- `shiba-inu.jpg`
954
-
955
-
956
- ![shiba-inu](https://user-images.githubusercontent.com/16448529/209185639-6e5ec794-ce9d-4883-aa29-bd6852a2abad.jpg)
957
-
958
- ### UnCLIP Text Interpolation Pipeline
959
-
- This Diffusion Pipeline takes two prompts and interpolates between them using spherical interpolation (slerp). The input prompts are converted to text embeddings by the pipeline's text_encoder, and the interpolation is done on the resulting text_embeddings over the number of steps specified (defaults to 5 steps).
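-
- For intuition, slerp interpolates along the arc between two embedding vectors rather than along the straight line between them. A minimal sketch (illustration only; the pipeline ships its own implementation):
-
- ```python
- import torch
-
- def slerp(t: float, v0: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
-     # Spherical linear interpolation between two flattened embedding tensors.
-     v0_u = v0 / v0.norm()
-     v1_u = v1 / v1.norm()
-     omega = torch.acos((v0_u * v1_u).sum().clamp(-1.0, 1.0))
-     if omega.abs() < 1e-4:
-         return (1 - t) * v0 + t * v1  # nearly parallel: fall back to lerp
-     so = torch.sin(omega)
-     return (torch.sin((1 - t) * omega) / so) * v0 + (torch.sin(t * omega) / so) * v1
- ```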
961
-
962
- ```python
963
- import torch
964
- from diffusers import DiffusionPipeline
965
-
966
- device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
967
-
968
- pipe = DiffusionPipeline.from_pretrained(
969
- "kakaobrain/karlo-v1-alpha",
970
- torch_dtype=torch.float16,
971
- custom_pipeline="unclip_text_interpolation"
972
- )
973
- pipe.to(device)
974
-
975
- start_prompt = "A photograph of an adult lion"
976
- end_prompt = "A photograph of a lion cub"
977
- #For best results keep the prompts close in length to each other. Of course, feel free to try out with differing lengths.
978
- generator = torch.Generator(device=device).manual_seed(42)
979
-
980
- output = pipe(start_prompt, end_prompt, steps = 6, generator = generator, enable_sequential_cpu_offload=False)
981
-
982
- for i,image in enumerate(output.images):
983
- img.save('result%s.jpg' % i)
984
- ```
985
-
986
- The resulting images in order:-
987
-
988
- ![result_0](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_0.png)
989
- ![result_1](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_1.png)
990
- ![result_2](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_2.png)
991
- ![result_3](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_3.png)
992
- ![result_4](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_4.png)
993
- ![result_5](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_5.png)
994
-
995
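The spherical interpolation itself is a small, generic operation. The sketch below is a minimal standalone implementation of slerp between two embedding vectors, only to illustrate the math the description refers to; it is not the pipeline's internal code, and the embedding size and step count are arbitrary stand-ins.

```python
import torch

def slerp(v0: torch.Tensor, v1: torch.Tensor, t: float, eps: float = 1e-7) -> torch.Tensor:
    """Spherical interpolation between two flattened embedding vectors."""
    v0_n = v0 / v0.norm()
    v1_n = v1 / v1.norm()
    # angle between the two unit vectors
    omega = torch.acos((v0_n * v1_n).sum().clamp(-1 + eps, 1 - eps))
    so = torch.sin(omega)
    # fall back to linear interpolation when the vectors are (nearly) parallel
    if so.abs() < eps:
        return (1.0 - t) * v0 + t * v1
    return (torch.sin((1.0 - t) * omega) / so) * v0 + (torch.sin(t * omega) / so) * v1

# interpolate in 6 steps, as in the example above
steps = torch.linspace(0.0, 1.0, 6)
a, b = torch.randn(768), torch.randn(768)  # stand-ins for two text embeddings
interpolated = [slerp(a, b, t.item()) for t in steps]
```
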
### UnCLIP Image Interpolation Pipeline

This Diffusion Pipeline takes two images or an image_embeddings tensor of size 2 and interpolates between their embeddings using spherical interpolation (slerp). The input images/image_embeddings are converted to image embeddings by the pipeline's image_encoder, and the interpolation is done on the resulting image embeddings over the number of steps specified. Defaults to 5 steps.

```python
import torch
from diffusers import DiffusionPipeline
from PIL import Image

device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16

pipe = DiffusionPipeline.from_pretrained(
    "kakaobrain/karlo-v1-alpha-image-variations",
    torch_dtype=dtype,
    custom_pipeline="unclip_image_interpolation"
)
pipe.to(device)

images = [Image.open('./starry_night.jpg'), Image.open('./flowers.jpg')]
generator = torch.Generator(device=device).manual_seed(42)

output = pipe(image=images, steps=6, generator=generator)

for i, image in enumerate(output.images):
    image.save('starry_to_flowers_%s.jpg' % i)
```

The original images:

![starry](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_night.jpg)
![flowers](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/flowers.jpg)

The resulting images in order:

![result0](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_0.png)
![result1](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_1.png)
![result2](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_2.png)
![result3](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_3.png)
![result4](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_4.png)
![result5](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_5.png)

### DDIM Noise Comparative Analysis Pipeline
#### **Research question: What visual concepts do the diffusion models learn from each noise level during training?**
The [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227) paper proposed an approach to answer the above question, which is their second contribution.
The approach consists of the following steps:

1. The input is an image x0.
2. Perturb it to xt using a diffusion process q(xt|x0).
   - `strength` is a value between 0.0 and 1.0 that controls the amount of noise added to the input image. Values approaching 1.0 allow for many variations but will also produce images that are not semantically consistent with the input.
3. Reconstruct the image with the learned denoising process pθ(x̂0|xt).
4. Compare x0 and x̂0 among various t to show how each step contributes to the sample.

The authors used the [openai/guided-diffusion](https://github.com/openai/guided-diffusion) model to denoise images from the FFHQ dataset. This pipeline extends their second contribution by investigating DDIM on any input image.

```python
import numpy as np
import torch
from PIL import Image

from diffusers import DiffusionPipeline

image_path = "path/to/your/image"  # images from CelebA-HQ might be better
image_pil = Image.open(image_path)
image_name = image_path.split("/")[-1].split(".")[0]

device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
pipe = DiffusionPipeline.from_pretrained(
    "google/ddpm-ema-celebahq-256",
    custom_pipeline="ddim_noise_comparative_analysis",
)
pipe = pipe.to(device)

for strength in np.linspace(0.1, 1, 25):
    denoised_image, latent_timestep = pipe(
        image_pil, strength=strength, return_dict=False
    )
    denoised_image = denoised_image[0]
    denoised_image.save(
        f"noise_comparative_analysis_{image_name}_{latent_timestep}.png"
    )
```

Here is the result of this pipeline (which is DDIM) on the CelebA-HQ dataset.

![noise-comparative-analysis](https://user-images.githubusercontent.com/67547213/224677066-4474b2ed-56ab-4c27-87c6-de3c0255eb9c.jpeg)

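To compare the reconstructions at the different noise levels side by side, the saved frames can be stitched into a single strip with plain PIL. This is just an optional post-processing sketch; the glob pattern, tile size, and output file name are arbitrary choices for illustration, not part of the pipeline.

```python
import glob
from PIL import Image

# note: lexicographic order; rename the files or sort numerically if you need strict timestep order
frames = [Image.open(p) for p in sorted(glob.glob("noise_comparative_analysis_*.png"))]
thumb = 128  # side length of each tile in the strip

strip = Image.new("RGB", (thumb * len(frames), thumb))
for i, frame in enumerate(frames):
    strip.paste(frame.resize((thumb, thumb)), (i * thumb, 0))
strip.save("noise_comparative_analysis_strip.png")
```
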
### CLIP Guided Img2Img Stable Diffusion

CLIP guided Img2Img Stable Diffusion can help to generate more realistic images from an initial image
by guiding Stable Diffusion at every denoising step with an additional CLIP model.

The following code requires roughly 12GB of GPU RAM.

```python
from io import BytesIO

import requests
import torch
from diffusers import DiffusionPipeline
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel

feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
)
clip_model = CLIPModel.from_pretrained(
    "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
)

guided_pipeline = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="clip_guided_stable_diffusion_img2img",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
)
guided_pipeline.enable_attention_slicing()
guided_pipeline = guided_pipeline.to("cuda")

prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")

image = guided_pipeline(
    prompt=prompt,
    num_inference_steps=30,
    image=init_image,
    strength=0.75,
    guidance_scale=7.5,
    clip_guidance_scale=100,
    num_cutouts=4,
    use_cutouts=False,
).images[0]
display(image)
```

Init Image

![img2img_init_clip_guidance](https://huggingface.co/datasets/njindal/images/resolve/main/clip_guided_img2img_init.jpg)

Output Image

![img2img_clip_guidance](https://huggingface.co/datasets/njindal/images/resolve/main/clip_guided_img2img.jpg)
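
Under the hood, at each denoising step the pipeline decodes the current latents, embeds the decoded image with CLIP, and nudges the latents along the gradient of a spherical distance between the CLIP image embedding and the prompt's CLIP text embedding. The distance it minimizes is reproduced from the pipeline source below; `clip_guidance_scale` scales this loss before the gradient is taken, so larger values pull the sample more strongly toward the prompt at the cost of fidelity to the init image.

```python
import torch.nn.functional as F

def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
```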
 
diffusers/examples/community/bit_diffusion.py DELETED
@@ -1,264 +0,0 @@
1
- from typing import Optional, Tuple, Union
2
-
3
- import torch
4
- from einops import rearrange, reduce
5
-
6
- from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel
7
- from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
8
- from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
9
-
10
-
11
- BITS = 8
12
-
13
-
14
- # convert to bit representations and back taken from https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py
15
- def decimal_to_bits(x, bits=BITS):
16
- """expects image tensor ranging from 0 to 1, outputs bit tensor ranging from -1 to 1"""
17
- device = x.device
18
-
19
- x = (x * 255).int().clamp(0, 255)
20
-
21
- mask = 2 ** torch.arange(bits - 1, -1, -1, device=device)
22
- mask = rearrange(mask, "d -> d 1 1")
23
- x = rearrange(x, "b c h w -> b c 1 h w")
24
-
25
- bits = ((x & mask) != 0).float()
26
- bits = rearrange(bits, "b c d h w -> b (c d) h w")
27
- bits = bits * 2 - 1
28
- return bits
29
-
30
-
31
- def bits_to_decimal(x, bits=BITS):
32
- """expects bits from -1 to 1, outputs image tensor from 0 to 1"""
33
- device = x.device
34
-
35
- x = (x > 0).int()
36
- mask = 2 ** torch.arange(bits - 1, -1, -1, device=device, dtype=torch.int32)
37
-
38
- mask = rearrange(mask, "d -> d 1 1")
39
- x = rearrange(x, "b (c d) h w -> b c d h w", d=8)
40
- dec = reduce(x * mask, "b c d h w -> b c h w", "sum")
41
- return (dec / 255).clamp(0.0, 1.0)
42
-
43
-
44
- # modified scheduler step functions for clamping the predicted x_0 between -bit_scale and +bit_scale
45
- def ddim_bit_scheduler_step(
46
- self,
47
- model_output: torch.FloatTensor,
48
- timestep: int,
49
- sample: torch.FloatTensor,
50
- eta: float = 0.0,
51
- use_clipped_model_output: bool = True,
52
- generator=None,
53
- return_dict: bool = True,
54
- ) -> Union[DDIMSchedulerOutput, Tuple]:
55
- """
56
- Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
57
- process from the learned model outputs (most often the predicted noise).
58
- Args:
59
- model_output (`torch.FloatTensor`): direct output from learned diffusion model.
60
- timestep (`int`): current discrete timestep in the diffusion chain.
61
- sample (`torch.FloatTensor`):
62
- current instance of sample being created by diffusion process.
63
- eta (`float`): weight of noise for added noise in diffusion step.
64
- use_clipped_model_output (`bool`): TODO
65
- generator: random number generator.
66
- return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
67
- Returns:
68
- [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
69
- [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
70
- returning a tuple, the first element is the sample tensor.
71
- """
72
- if self.num_inference_steps is None:
73
- raise ValueError(
74
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
75
- )
76
-
77
- # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
78
- # Ideally, read DDIM paper in-detail understanding
79
-
80
- # Notation (<variable name> -> <name in paper>
81
- # - pred_noise_t -> e_theta(x_t, t)
82
- # - pred_original_sample -> f_theta(x_t, t) or x_0
83
- # - std_dev_t -> sigma_t
84
- # - eta -> η
85
- # - pred_sample_direction -> "direction pointing to x_t"
86
- # - pred_prev_sample -> "x_t-1"
87
-
88
- # 1. get previous step value (=t-1)
89
- prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
90
-
91
- # 2. compute alphas, betas
92
- alpha_prod_t = self.alphas_cumprod[timestep]
93
- alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
94
-
95
- beta_prod_t = 1 - alpha_prod_t
96
-
97
- # 3. compute predicted original sample from predicted noise also called
98
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
99
- pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
100
-
101
- # 4. Clip "predicted x_0"
102
- scale = self.bit_scale
103
- if self.config.clip_sample:
104
- pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)
105
-
106
- # 5. compute variance: "sigma_t(η)" -> see formula (16)
107
- # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
108
- variance = self._get_variance(timestep, prev_timestep)
109
- std_dev_t = eta * variance ** (0.5)
110
-
111
- if use_clipped_model_output:
112
- # the model_output is always re-derived from the clipped x_0 in Glide
113
- model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
114
-
115
- # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
116
- pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
117
-
118
- # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
119
- prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
120
-
121
- if eta > 0:
122
- # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
123
- device = model_output.device if torch.is_tensor(model_output) else "cpu"
124
- noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
125
- variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
126
-
127
- prev_sample = prev_sample + variance
128
-
129
- if not return_dict:
130
- return (prev_sample,)
131
-
132
- return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
133
-
134
-
135
- def ddpm_bit_scheduler_step(
136
- self,
137
- model_output: torch.FloatTensor,
138
- timestep: int,
139
- sample: torch.FloatTensor,
140
- prediction_type="epsilon",
141
- generator=None,
142
- return_dict: bool = True,
143
- ) -> Union[DDPMSchedulerOutput, Tuple]:
144
- """
145
- Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
146
- process from the learned model outputs (most often the predicted noise).
147
- Args:
148
- model_output (`torch.FloatTensor`): direct output from learned diffusion model.
149
- timestep (`int`): current discrete timestep in the diffusion chain.
150
- sample (`torch.FloatTensor`):
151
- current instance of sample being created by diffusion process.
152
- prediction_type (`str`, default `epsilon`):
153
- indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
154
- generator: random number generator.
155
- return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
156
- Returns:
157
- [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
158
- [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
159
- returning a tuple, the first element is the sample tensor.
160
- """
161
- t = timestep
162
-
163
- if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
164
- model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
165
- else:
166
- predicted_variance = None
167
-
168
- # 1. compute alphas, betas
169
- alpha_prod_t = self.alphas_cumprod[t]
170
- alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
171
- beta_prod_t = 1 - alpha_prod_t
172
- beta_prod_t_prev = 1 - alpha_prod_t_prev
173
-
174
- # 2. compute predicted original sample from predicted noise also called
175
- # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
176
- if prediction_type == "epsilon":
177
- pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
178
- elif prediction_type == "sample":
179
- pred_original_sample = model_output
180
- else:
181
- raise ValueError(f"Unsupported prediction_type {prediction_type}.")
182
-
183
- # 3. Clip "predicted x_0"
184
- scale = self.bit_scale
185
- if self.config.clip_sample:
186
- pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)
187
-
188
- # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
189
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
190
- pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
191
- current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
192
-
193
- # 5. Compute predicted previous sample µ_t
194
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
195
- pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
196
-
197
- # 6. Add noise
198
- variance = 0
199
- if t > 0:
200
- noise = torch.randn(
201
- model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
202
- ).to(model_output.device)
203
- variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
204
-
205
- pred_prev_sample = pred_prev_sample + variance
206
-
207
- if not return_dict:
208
- return (pred_prev_sample,)
209
-
210
- return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
211
-
212
-
213
- class BitDiffusion(DiffusionPipeline):
214
- def __init__(
215
- self,
216
- unet: UNet2DConditionModel,
217
- scheduler: Union[DDIMScheduler, DDPMScheduler],
218
- bit_scale: Optional[float] = 1.0,
219
- ):
220
- super().__init__()
221
- self.bit_scale = bit_scale
222
- self.scheduler.step = (
223
- ddim_bit_scheduler_step if isinstance(scheduler, DDIMScheduler) else ddpm_bit_scheduler_step
224
- )
225
-
226
- self.register_modules(unet=unet, scheduler=scheduler)
227
-
228
- @torch.no_grad()
229
- def __call__(
230
- self,
231
- height: Optional[int] = 256,
232
- width: Optional[int] = 256,
233
- num_inference_steps: Optional[int] = 50,
234
- generator: Optional[torch.Generator] = None,
235
- batch_size: Optional[int] = 1,
236
- output_type: Optional[str] = "pil",
237
- return_dict: bool = True,
238
- **kwargs,
239
- ) -> Union[Tuple, ImagePipelineOutput]:
240
- latents = torch.randn(
241
- (batch_size, self.unet.in_channels, height, width),
242
- generator=generator,
243
- )
244
- latents = decimal_to_bits(latents) * self.bit_scale
245
- latents = latents.to(self.device)
246
-
247
- self.scheduler.set_timesteps(num_inference_steps)
248
-
249
- for t in self.progress_bar(self.scheduler.timesteps):
250
- # predict the noise residual
251
- noise_pred = self.unet(latents, t).sample
252
-
253
- # compute the previous noisy sample x_t -> x_t-1
254
- latents = self.scheduler.step(noise_pred, t, latents).prev_sample
255
-
256
- image = bits_to_decimal(latents)
257
-
258
- if output_type == "pil":
259
- image = self.numpy_to_pil(image)
260
-
261
- if not return_dict:
262
- return (image,)
263
-
264
- return ImagePipelineOutput(images=image)
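
The bit encoding used by this pipeline is easy to sanity-check in isolation. The snippet below is a minimal sketch that round-trips a random image tensor through the two helpers defined at the top of this file; it assumes the file has been saved locally as `bit_diffusion.py` so the helpers can be imported (adjust the import to wherever you keep the module). The reconstruction should match the input up to 8-bit quantization, i.e. an error below 1/255.

```python
import torch
from bit_diffusion import bits_to_decimal, decimal_to_bits  # local copy of this file

x = torch.rand(1, 3, 64, 64)   # image tensor in [0, 1]
bits = decimal_to_bits(x)      # shape (1, 3 * 8, 64, 64), values in {-1, +1}
x_rec = bits_to_decimal(bits)  # back to [0, 1]

assert bits.shape == (1, 3 * 8, 64, 64)
assert (x - x_rec).abs().max() <= 1 / 255 + 1e-6
```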
 
diffusers/examples/community/checkpoint_merger.py DELETED
@@ -1,286 +0,0 @@
1
- import glob
2
- import os
3
- from typing import Dict, List, Union
4
-
5
- import torch
6
-
7
- from diffusers.utils import is_safetensors_available
8
-
9
-
10
- if is_safetensors_available():
11
- import safetensors.torch
12
-
13
- from huggingface_hub import snapshot_download
14
-
15
- from diffusers import DiffusionPipeline, __version__
16
- from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
17
- from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
18
-
19
-
20
- class CheckpointMergerPipeline(DiffusionPipeline):
21
- """
22
- A class that that supports merging diffusion models based on the discussion here:
23
- https://github.com/huggingface/diffusers/issues/877
24
-
25
- Example usage:-
26
-
27
- pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger.py")
28
-
29
- merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4","prompthero/openjourney"], interp = 'inv_sigmoid', alpha = 0.8, force = True)
30
-
31
- merged_pipe.to('cuda')
32
-
33
- prompt = "An astronaut riding a unicycle on Mars"
34
-
35
- results = merged_pipe(prompt)
36
-
37
- ## For more details, see the docstring for the merge method.
38
-
39
- """
40
-
41
- def __init__(self):
42
- self.register_to_config()
43
- super().__init__()
44
-
45
- def _compare_model_configs(self, dict0, dict1):
46
- if dict0 == dict1:
47
- return True
48
- else:
49
- config0, meta_keys0 = self._remove_meta_keys(dict0)
50
- config1, meta_keys1 = self._remove_meta_keys(dict1)
51
- if config0 == config1:
52
- print(f"Warning !: Mismatch in keys {meta_keys0} and {meta_keys1}.")
53
- return True
54
- return False
55
-
56
- def _remove_meta_keys(self, config_dict: Dict):
57
- meta_keys = []
58
- temp_dict = config_dict.copy()
59
- for key in config_dict.keys():
60
- if key.startswith("_"):
61
- temp_dict.pop(key)
62
- meta_keys.append(key)
63
- return (temp_dict, meta_keys)
64
-
65
- @torch.no_grad()
66
- def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):
67
- """
68
- Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints(weights) of the models passed
69
- in the argument 'pretrained_model_name_or_path_list' as a list.
70
-
71
- Parameters:
72
- -----------
73
- pretrained_model_name_or_path_list : A list of valid pretrained model names in the HuggingFace hub or paths to locally stored models in the HuggingFace format.
74
-
75
- **kwargs:
76
- Supports all the default DiffusionPipeline.get_config_dict kwargs viz..
77
-
78
- cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map.
79
-
80
- alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
81
- would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
82
-
83
- interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_diff" and None.
84
- Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_diff" is supported.
85
-
86
- force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
87
-
88
- """
89
- # Default kwargs from DiffusionPipeline
90
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
91
- resume_download = kwargs.pop("resume_download", False)
92
- force_download = kwargs.pop("force_download", False)
93
- proxies = kwargs.pop("proxies", None)
94
- local_files_only = kwargs.pop("local_files_only", False)
95
- use_auth_token = kwargs.pop("use_auth_token", None)
96
- revision = kwargs.pop("revision", None)
97
- torch_dtype = kwargs.pop("torch_dtype", None)
98
- device_map = kwargs.pop("device_map", None)
99
-
100
- alpha = kwargs.pop("alpha", 0.5)
101
- interp = kwargs.pop("interp", None)
102
-
103
- print("Received list", pretrained_model_name_or_path_list)
104
- print(f"Combining with alpha={alpha}, interpolation mode={interp}")
105
-
106
- checkpoint_count = len(pretrained_model_name_or_path_list)
107
- # Ignore result from model_index_json comparision of the two checkpoints
108
- force = kwargs.pop("force", False)
109
-
110
- # If less than 2 checkpoints, nothing to merge. If more than 3, not supported for now.
111
- if checkpoint_count > 3 or checkpoint_count < 2:
112
- raise ValueError(
113
- "Received incorrect number of checkpoints to merge. Ensure that either 2 or 3 checkpoints are being"
114
- " passed."
115
- )
116
-
117
- print("Received the right number of checkpoints")
118
- # chkpt0, chkpt1 = pretrained_model_name_or_path_list[0:2]
119
- # chkpt2 = pretrained_model_name_or_path_list[2] if checkpoint_count == 3 else None
120
-
121
- # Validate that the checkpoints can be merged
122
- # Step 1: Load the model config and compare the checkpoints. We'll compare the model_index.json first while ignoring the keys starting with '_'
123
- config_dicts = []
124
- for pretrained_model_name_or_path in pretrained_model_name_or_path_list:
125
- config_dict = DiffusionPipeline.load_config(
126
- pretrained_model_name_or_path,
127
- cache_dir=cache_dir,
128
- resume_download=resume_download,
129
- force_download=force_download,
130
- proxies=proxies,
131
- local_files_only=local_files_only,
132
- use_auth_token=use_auth_token,
133
- revision=revision,
134
- )
135
- config_dicts.append(config_dict)
136
-
137
- comparison_result = True
138
- for idx in range(1, len(config_dicts)):
139
- comparison_result &= self._compare_model_configs(config_dicts[idx - 1], config_dicts[idx])
140
- if not force and comparison_result is False:
141
- raise ValueError("Incompatible checkpoints. Please check model_index.json for the models.")
142
- print(config_dicts[0], config_dicts[1])
143
- print("Compatible model_index.json files found")
144
- # Step 2: Basic Validation has succeeded. Let's download the models and save them into our local files.
145
- cached_folders = []
146
- for pretrained_model_name_or_path, config_dict in zip(pretrained_model_name_or_path_list, config_dicts):
147
- folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
148
- allow_patterns = [os.path.join(k, "*") for k in folder_names]
149
- allow_patterns += [
150
- WEIGHTS_NAME,
151
- SCHEDULER_CONFIG_NAME,
152
- CONFIG_NAME,
153
- ONNX_WEIGHTS_NAME,
154
- DiffusionPipeline.config_name,
155
- ]
156
- requested_pipeline_class = config_dict.get("_class_name")
157
- user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
158
-
159
- cached_folder = (
160
- pretrained_model_name_or_path
161
- if os.path.isdir(pretrained_model_name_or_path)
162
- else snapshot_download(
163
- pretrained_model_name_or_path,
164
- cache_dir=cache_dir,
165
- resume_download=resume_download,
166
- proxies=proxies,
167
- local_files_only=local_files_only,
168
- use_auth_token=use_auth_token,
169
- revision=revision,
170
- allow_patterns=allow_patterns,
171
- user_agent=user_agent,
172
- )
173
- )
174
- print("Cached Folder", cached_folder)
175
- cached_folders.append(cached_folder)
176
-
177
- # Step 3:-
178
- # Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place
179
- final_pipe = DiffusionPipeline.from_pretrained(
180
- cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
181
- )
182
- final_pipe.to(self.device)
183
-
184
- checkpoint_path_2 = None
185
- if len(cached_folders) > 2:
186
- checkpoint_path_2 = os.path.join(cached_folders[2])
187
-
188
- if interp == "sigmoid":
189
- theta_func = CheckpointMergerPipeline.sigmoid
190
- elif interp == "inv_sigmoid":
191
- theta_func = CheckpointMergerPipeline.inv_sigmoid
192
- elif interp == "add_diff":
193
- theta_func = CheckpointMergerPipeline.add_difference
194
- else:
195
- theta_func = CheckpointMergerPipeline.weighted_sum
196
-
197
- # Find each module's state dict.
198
- for attr in final_pipe.config.keys():
199
- if not attr.startswith("_"):
200
- checkpoint_path_1 = os.path.join(cached_folders[1], attr)
201
- if os.path.exists(checkpoint_path_1):
202
- files = [
203
- *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")),
204
- *glob.glob(os.path.join(checkpoint_path_1, "*.bin")),
205
- ]
206
- checkpoint_path_1 = files[0] if len(files) > 0 else None
207
- if len(cached_folders) < 3:
208
- checkpoint_path_2 = None
209
- else:
210
- checkpoint_path_2 = os.path.join(cached_folders[2], attr)
211
- if os.path.exists(checkpoint_path_2):
212
- files = [
213
- *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
214
- *glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
215
- ]
216
- checkpoint_path_2 = files[0] if len(files) > 0 else None
217
- # For an attr if both checkpoint_path_1 and 2 are None, ignore.
218
- # If atleast one is present, deal with it according to interp method, of course only if the state_dict keys match.
219
- if checkpoint_path_1 is None and checkpoint_path_2 is None:
220
- print(f"Skipping {attr}: not present in 2nd or 3d model")
221
- continue
222
- try:
223
- module = getattr(final_pipe, attr)
224
- if isinstance(module, bool): # ignore requires_safety_checker boolean
225
- continue
226
- theta_0 = getattr(module, "state_dict")
227
- theta_0 = theta_0()
228
-
229
- update_theta_0 = getattr(module, "load_state_dict")
230
- theta_1 = (
231
- safetensors.torch.load_file(checkpoint_path_1)
232
- if (is_safetensors_available() and checkpoint_path_1.endswith(".safetensors"))
233
- else torch.load(checkpoint_path_1, map_location="cpu")
234
- )
235
- theta_2 = None
236
- if checkpoint_path_2:
237
- theta_2 = (
238
- safetensors.torch.load_file(checkpoint_path_2)
239
- if (is_safetensors_available() and checkpoint_path_2.endswith(".safetensors"))
240
- else torch.load(checkpoint_path_2, map_location="cpu")
241
- )
242
-
243
- if not theta_0.keys() == theta_1.keys():
244
- print(f"Skipping {attr}: key mismatch")
245
- continue
246
- if theta_2 and not theta_1.keys() == theta_2.keys():
247
- print(f"Skipping {attr}:y mismatch")
248
- except Exception as e:
249
- print(f"Skipping {attr} do to an unexpected error: {str(e)}")
250
- continue
251
- print(f"MERGING {attr}")
252
-
253
- for key in theta_0.keys():
254
- if theta_2:
255
- theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], alpha)
256
- else:
257
- theta_0[key] = theta_func(theta_0[key], theta_1[key], None, alpha)
258
-
259
- del theta_1
260
- del theta_2
261
- update_theta_0(theta_0)
262
-
263
- del theta_0
264
- return final_pipe
265
-
266
- @staticmethod
267
- def weighted_sum(theta0, theta1, theta2, alpha):
268
- return ((1 - alpha) * theta0) + (alpha * theta1)
269
-
270
- # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
271
- @staticmethod
272
- def sigmoid(theta0, theta1, theta2, alpha):
273
- alpha = alpha * alpha * (3 - (2 * alpha))
274
- return theta0 + ((theta1 - theta0) * alpha)
275
-
276
- # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
277
- @staticmethod
278
- def inv_sigmoid(theta0, theta1, theta2, alpha):
279
- import math
280
-
281
- alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
282
- return theta0 + ((theta1 - theta0) * alpha)
283
-
284
- @staticmethod
285
- def add_difference(theta0, theta1, theta2, alpha):
286
- return theta0 + (theta1 - theta2) * (1.0 - alpha)
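
For reference, the four merge modes implemented above reduce to simple per-tensor formulas. The toy sketch below just restates them with random tensors standing in for matching weight tensors from the checkpoints; `alpha` plays the same role as the `alpha` argument of `merge`.

```python
import math
import torch

theta0, theta1, theta2 = torch.randn(3, 4, 4)  # stand-ins for matching weight tensors
alpha = 0.8

weighted_sum = (1 - alpha) * theta0 + alpha * theta1                  # interp=None (default)
smooth = alpha * alpha * (3 - 2 * alpha)                              # interp="sigmoid" (smoothstep)
sigmoid_merge = theta0 + (theta1 - theta0) * smooth
inv_smooth = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)       # interp="inv_sigmoid"
inv_sigmoid_merge = theta0 + (theta1 - theta0) * inv_smooth
add_difference = theta0 + (theta1 - theta2) * (1.0 - alpha)           # interp="add_diff", three checkpoints
```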
 
diffusers/examples/community/clip_guided_stable_diffusion.py DELETED
@@ -1,347 +0,0 @@
1
- import inspect
2
- from typing import List, Optional, Union
3
-
4
- import torch
5
- from torch import nn
6
- from torch.nn import functional as F
7
- from torchvision import transforms
8
- from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
9
-
10
- from diffusers import (
11
- AutoencoderKL,
12
- DDIMScheduler,
13
- DiffusionPipeline,
14
- DPMSolverMultistepScheduler,
15
- LMSDiscreteScheduler,
16
- PNDMScheduler,
17
- UNet2DConditionModel,
18
- )
19
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
20
-
21
-
22
- class MakeCutouts(nn.Module):
23
- def __init__(self, cut_size, cut_power=1.0):
24
- super().__init__()
25
-
26
- self.cut_size = cut_size
27
- self.cut_power = cut_power
28
-
29
- def forward(self, pixel_values, num_cutouts):
30
- sideY, sideX = pixel_values.shape[2:4]
31
- max_size = min(sideX, sideY)
32
- min_size = min(sideX, sideY, self.cut_size)
33
- cutouts = []
34
- for _ in range(num_cutouts):
35
- size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
36
- offsetx = torch.randint(0, sideX - size + 1, ())
37
- offsety = torch.randint(0, sideY - size + 1, ())
38
- cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
39
- cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
40
- return torch.cat(cutouts)
41
-
42
-
43
- def spherical_dist_loss(x, y):
44
- x = F.normalize(x, dim=-1)
45
- y = F.normalize(y, dim=-1)
46
- return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
47
-
48
-
49
- def set_requires_grad(model, value):
50
- for param in model.parameters():
51
- param.requires_grad = value
52
-
53
-
54
- class CLIPGuidedStableDiffusion(DiffusionPipeline):
55
- """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
56
- - https://github.com/Jack000/glid-3-xl
57
- - https://github.dev/crowsonkb/k-diffusion
58
- """
59
-
60
- def __init__(
61
- self,
62
- vae: AutoencoderKL,
63
- text_encoder: CLIPTextModel,
64
- clip_model: CLIPModel,
65
- tokenizer: CLIPTokenizer,
66
- unet: UNet2DConditionModel,
67
- scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
68
- feature_extractor: CLIPImageProcessor,
69
- ):
70
- super().__init__()
71
- self.register_modules(
72
- vae=vae,
73
- text_encoder=text_encoder,
74
- clip_model=clip_model,
75
- tokenizer=tokenizer,
76
- unet=unet,
77
- scheduler=scheduler,
78
- feature_extractor=feature_extractor,
79
- )
80
-
81
- self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
82
- self.cut_out_size = (
83
- feature_extractor.size
84
- if isinstance(feature_extractor.size, int)
85
- else feature_extractor.size["shortest_edge"]
86
- )
87
- self.make_cutouts = MakeCutouts(self.cut_out_size)
88
-
89
- set_requires_grad(self.text_encoder, False)
90
- set_requires_grad(self.clip_model, False)
91
-
92
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
93
- if slice_size == "auto":
94
- # half the attention head size is usually a good trade-off between
95
- # speed and memory
96
- slice_size = self.unet.config.attention_head_dim // 2
97
- self.unet.set_attention_slice(slice_size)
98
-
99
- def disable_attention_slicing(self):
100
- self.enable_attention_slicing(None)
101
-
102
- def freeze_vae(self):
103
- set_requires_grad(self.vae, False)
104
-
105
- def unfreeze_vae(self):
106
- set_requires_grad(self.vae, True)
107
-
108
- def freeze_unet(self):
109
- set_requires_grad(self.unet, False)
110
-
111
- def unfreeze_unet(self):
112
- set_requires_grad(self.unet, True)
113
-
114
- @torch.enable_grad()
115
- def cond_fn(
116
- self,
117
- latents,
118
- timestep,
119
- index,
120
- text_embeddings,
121
- noise_pred_original,
122
- text_embeddings_clip,
123
- clip_guidance_scale,
124
- num_cutouts,
125
- use_cutouts=True,
126
- ):
127
- latents = latents.detach().requires_grad_()
128
-
129
- latent_model_input = self.scheduler.scale_model_input(latents, timestep)
130
-
131
- # predict the noise residual
132
- noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
133
-
134
- if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
135
- alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
136
- beta_prod_t = 1 - alpha_prod_t
137
- # compute predicted original sample from predicted noise also called
138
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
139
- pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
140
-
141
- fac = torch.sqrt(beta_prod_t)
142
- sample = pred_original_sample * (fac) + latents * (1 - fac)
143
- elif isinstance(self.scheduler, LMSDiscreteScheduler):
144
- sigma = self.scheduler.sigmas[index]
145
- sample = latents - sigma * noise_pred
146
- else:
147
- raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
148
-
149
- sample = 1 / self.vae.config.scaling_factor * sample
150
- image = self.vae.decode(sample).sample
151
- image = (image / 2 + 0.5).clamp(0, 1)
152
-
153
- if use_cutouts:
154
- image = self.make_cutouts(image, num_cutouts)
155
- else:
156
- image = transforms.Resize(self.cut_out_size)(image)
157
- image = self.normalize(image).to(latents.dtype)
158
-
159
- image_embeddings_clip = self.clip_model.get_image_features(image)
160
- image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
161
-
162
- if use_cutouts:
163
- dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
164
- dists = dists.view([num_cutouts, sample.shape[0], -1])
165
- loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
166
- else:
167
- loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
168
-
169
- grads = -torch.autograd.grad(loss, latents)[0]
170
-
171
- if isinstance(self.scheduler, LMSDiscreteScheduler):
172
- latents = latents.detach() + grads * (sigma**2)
173
- noise_pred = noise_pred_original
174
- else:
175
- noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
176
- return noise_pred, latents
177
-
178
- @torch.no_grad()
179
- def __call__(
180
- self,
181
- prompt: Union[str, List[str]],
182
- height: Optional[int] = 512,
183
- width: Optional[int] = 512,
184
- num_inference_steps: Optional[int] = 50,
185
- guidance_scale: Optional[float] = 7.5,
186
- num_images_per_prompt: Optional[int] = 1,
187
- eta: float = 0.0,
188
- clip_guidance_scale: Optional[float] = 100,
189
- clip_prompt: Optional[Union[str, List[str]]] = None,
190
- num_cutouts: Optional[int] = 4,
191
- use_cutouts: Optional[bool] = True,
192
- generator: Optional[torch.Generator] = None,
193
- latents: Optional[torch.FloatTensor] = None,
194
- output_type: Optional[str] = "pil",
195
- return_dict: bool = True,
196
- ):
197
- if isinstance(prompt, str):
198
- batch_size = 1
199
- elif isinstance(prompt, list):
200
- batch_size = len(prompt)
201
- else:
202
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
203
-
204
- if height % 8 != 0 or width % 8 != 0:
205
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
206
-
207
- # get prompt text embeddings
208
- text_input = self.tokenizer(
209
- prompt,
210
- padding="max_length",
211
- max_length=self.tokenizer.model_max_length,
212
- truncation=True,
213
- return_tensors="pt",
214
- )
215
- text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
216
- # duplicate text embeddings for each generation per prompt
217
- text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
218
-
219
- if clip_guidance_scale > 0:
220
- if clip_prompt is not None:
221
- clip_text_input = self.tokenizer(
222
- clip_prompt,
223
- padding="max_length",
224
- max_length=self.tokenizer.model_max_length,
225
- truncation=True,
226
- return_tensors="pt",
227
- ).input_ids.to(self.device)
228
- else:
229
- clip_text_input = text_input.input_ids.to(self.device)
230
- text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
231
- text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
232
- # duplicate text embeddings clip for each generation per prompt
233
- text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)
234
-
235
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
236
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
237
- # corresponds to doing no classifier free guidance.
238
- do_classifier_free_guidance = guidance_scale > 1.0
239
- # get unconditional embeddings for classifier free guidance
240
- if do_classifier_free_guidance:
241
- max_length = text_input.input_ids.shape[-1]
242
- uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
243
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
244
- # duplicate unconditional embeddings for each generation per prompt
245
- uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
246
-
247
- # For classifier free guidance, we need to do two forward passes.
248
- # Here we concatenate the unconditional and text embeddings into a single batch
249
- # to avoid doing two forward passes
250
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
251
-
252
- # get the initial random noise unless the user supplied it
253
-
254
- # Unlike in other pipelines, latents need to be generated in the target device
255
- # for 1-to-1 results reproducibility with the CompVis implementation.
256
- # However this currently doesn't work in `mps`.
257
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
258
- latents_dtype = text_embeddings.dtype
259
- if latents is None:
260
- if self.device.type == "mps":
261
- # randn does not work reproducibly on mps
262
- latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
263
- self.device
264
- )
265
- else:
266
- latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
267
- else:
268
- if latents.shape != latents_shape:
269
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
270
- latents = latents.to(self.device)
271
-
272
- # set timesteps
273
- accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
274
- extra_set_kwargs = {}
275
- if accepts_offset:
276
- extra_set_kwargs["offset"] = 1
277
-
278
- self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
279
-
280
- # Some schedulers like PNDM have timesteps as arrays
281
- # It's more optimized to move all timesteps to correct device beforehand
282
- timesteps_tensor = self.scheduler.timesteps.to(self.device)
283
-
284
- # scale the initial noise by the standard deviation required by the scheduler
285
- latents = latents * self.scheduler.init_noise_sigma
286
-
287
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
288
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
289
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
290
- # and should be between [0, 1]
291
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
292
- extra_step_kwargs = {}
293
- if accepts_eta:
294
- extra_step_kwargs["eta"] = eta
295
-
296
- # check if the scheduler accepts generator
297
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
298
- if accepts_generator:
299
- extra_step_kwargs["generator"] = generator
300
-
301
- for i, t in enumerate(self.progress_bar(timesteps_tensor)):
302
- # expand the latents if we are doing classifier free guidance
303
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
304
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
305
-
306
- # predict the noise residual
307
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
308
-
309
- # perform classifier free guidance
310
- if do_classifier_free_guidance:
311
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
312
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
313
-
314
- # perform clip guidance
315
- if clip_guidance_scale > 0:
316
- text_embeddings_for_guidance = (
317
- text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
318
- )
319
- noise_pred, latents = self.cond_fn(
320
- latents,
321
- t,
322
- i,
323
- text_embeddings_for_guidance,
324
- noise_pred,
325
- text_embeddings_clip,
326
- clip_guidance_scale,
327
- num_cutouts,
328
- use_cutouts,
329
- )
330
-
331
- # compute the previous noisy sample x_t -> x_t-1
332
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
333
-
334
- # scale and decode the image latents with vae
335
- latents = 1 / self.vae.config.scaling_factor * latents
336
- image = self.vae.decode(latents).sample
337
-
338
- image = (image / 2 + 0.5).clamp(0, 1)
339
- image = image.cpu().permute(0, 2, 3, 1).numpy()
340
-
341
- if output_type == "pil":
342
- image = self.numpy_to_pil(image)
343
-
344
- if not return_dict:
345
- return (image, None)
346
-
347
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
 
diffusers/examples/community/clip_guided_stable_diffusion_img2img.py DELETED
@@ -1,496 +0,0 @@
1
- import inspect
2
- from typing import List, Optional, Union
3
-
4
- import numpy as np
5
- import PIL
6
- import torch
7
- from torch import nn
8
- from torch.nn import functional as F
9
- from torchvision import transforms
10
- from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
11
-
12
- from diffusers import (
13
- AutoencoderKL,
14
- DDIMScheduler,
15
- DiffusionPipeline,
16
- DPMSolverMultistepScheduler,
17
- LMSDiscreteScheduler,
18
- PNDMScheduler,
19
- UNet2DConditionModel,
20
- )
21
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
22
- from diffusers.utils import (
23
- PIL_INTERPOLATION,
24
- deprecate,
25
- randn_tensor,
26
- )
27
-
28
-
29
- EXAMPLE_DOC_STRING = """
30
- Examples:
31
- ```
32
- from io import BytesIO
33
-
34
- import requests
35
- import torch
36
- from diffusers import DiffusionPipeline
37
- from PIL import Image
38
- from transformers import CLIPFeatureExtractor, CLIPModel
39
-
40
- feature_extractor = CLIPFeatureExtractor.from_pretrained(
41
- "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
42
- )
43
- clip_model = CLIPModel.from_pretrained(
44
- "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
45
- )
46
-
47
-
48
- guided_pipeline = DiffusionPipeline.from_pretrained(
49
- "CompVis/stable-diffusion-v1-4",
50
- # custom_pipeline="clip_guided_stable_diffusion",
51
- custom_pipeline="/home/njindal/diffusers/examples/community/clip_guided_stable_diffusion.py",
52
- clip_model=clip_model,
53
- feature_extractor=feature_extractor,
54
- torch_dtype=torch.float16,
55
- )
56
- guided_pipeline.enable_attention_slicing()
57
- guided_pipeline = guided_pipeline.to("cuda")
58
-
59
- prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"
60
-
61
- url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
62
-
63
- response = requests.get(url)
64
- init_image = Image.open(BytesIO(response.content)).convert("RGB")
65
-
66
- image = guided_pipeline(
67
- prompt=prompt,
68
- num_inference_steps=30,
69
- image=init_image,
70
- strength=0.75,
71
- guidance_scale=7.5,
72
- clip_guidance_scale=100,
73
- num_cutouts=4,
74
- use_cutouts=False,
75
- ).images[0]
76
- display(image)
77
- ```
78
- """
79
-
80
-
81
- def preprocess(image, w, h):
82
- if isinstance(image, torch.Tensor):
83
- return image
84
- elif isinstance(image, PIL.Image.Image):
85
- image = [image]
86
-
87
- if isinstance(image[0], PIL.Image.Image):
88
- image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
89
- image = np.concatenate(image, axis=0)
90
- image = np.array(image).astype(np.float32) / 255.0
91
- image = image.transpose(0, 3, 1, 2)
92
- image = 2.0 * image - 1.0
93
- image = torch.from_numpy(image)
94
- elif isinstance(image[0], torch.Tensor):
95
- image = torch.cat(image, dim=0)
96
- return image
97
-
98
-
99
- class MakeCutouts(nn.Module):
100
- def __init__(self, cut_size, cut_power=1.0):
101
- super().__init__()
102
-
103
- self.cut_size = cut_size
104
- self.cut_power = cut_power
105
-
106
- def forward(self, pixel_values, num_cutouts):
107
- sideY, sideX = pixel_values.shape[2:4]
108
- max_size = min(sideX, sideY)
109
- min_size = min(sideX, sideY, self.cut_size)
110
- cutouts = []
111
- for _ in range(num_cutouts):
112
- size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
113
- offsetx = torch.randint(0, sideX - size + 1, ())
114
- offsety = torch.randint(0, sideY - size + 1, ())
115
- cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
116
- cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
117
- return torch.cat(cutouts)
118
-
119
-
120
- def spherical_dist_loss(x, y):
121
- x = F.normalize(x, dim=-1)
122
- y = F.normalize(y, dim=-1)
123
- return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
124
-
125
-
126
- def set_requires_grad(model, value):
127
- for param in model.parameters():
128
- param.requires_grad = value
129
-
130
-
131
- class CLIPGuidedStableDiffusion(DiffusionPipeline):
132
- """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
133
- - https://github.com/Jack000/glid-3-xl
134
- - https://github.dev/crowsonkb/k-diffusion
135
- """
136
-
137
- def __init__(
138
- self,
139
- vae: AutoencoderKL,
140
- text_encoder: CLIPTextModel,
141
- clip_model: CLIPModel,
142
- tokenizer: CLIPTokenizer,
143
- unet: UNet2DConditionModel,
144
- scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
145
- feature_extractor: CLIPFeatureExtractor,
146
- ):
147
- super().__init__()
148
- self.register_modules(
149
- vae=vae,
150
- text_encoder=text_encoder,
151
- clip_model=clip_model,
152
- tokenizer=tokenizer,
153
- unet=unet,
154
- scheduler=scheduler,
155
- feature_extractor=feature_extractor,
156
- )
157
-
158
- self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
159
- self.cut_out_size = (
160
- feature_extractor.size
161
- if isinstance(feature_extractor.size, int)
162
- else feature_extractor.size["shortest_edge"]
163
- )
164
- self.make_cutouts = MakeCutouts(self.cut_out_size)
165
-
166
- set_requires_grad(self.text_encoder, False)
167
- set_requires_grad(self.clip_model, False)
168
-
169
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
170
- if slice_size == "auto":
171
- # half the attention head size is usually a good trade-off between
172
- # speed and memory
173
- slice_size = self.unet.config.attention_head_dim // 2
174
- self.unet.set_attention_slice(slice_size)
175
-
176
- def disable_attention_slicing(self):
177
- self.enable_attention_slicing(None)
178
-
179
- def freeze_vae(self):
180
- set_requires_grad(self.vae, False)
181
-
182
- def unfreeze_vae(self):
183
- set_requires_grad(self.vae, True)
184
-
185
- def freeze_unet(self):
186
- set_requires_grad(self.unet, False)
187
-
188
- def unfreeze_unet(self):
189
- set_requires_grad(self.unet, True)
190
-
191
- def get_timesteps(self, num_inference_steps, strength, device):
192
- # get the original timestep using init_timestep
193
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
194
-
195
- t_start = max(num_inference_steps - init_timestep, 0)
196
- timesteps = self.scheduler.timesteps[t_start:]
197
-
198
- return timesteps, num_inference_steps - t_start
199
-
200
- def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
201
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
202
- raise ValueError(
203
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
204
- )
205
-
206
- image = image.to(device=device, dtype=dtype)
207
-
208
- batch_size = batch_size * num_images_per_prompt
209
- if isinstance(generator, list) and len(generator) != batch_size:
210
- raise ValueError(
211
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
212
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
213
- )
214
-
215
- if isinstance(generator, list):
216
- init_latents = [
217
- self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
218
- ]
219
- init_latents = torch.cat(init_latents, dim=0)
220
- else:
221
- init_latents = self.vae.encode(image).latent_dist.sample(generator)
222
-
223
- init_latents = self.vae.config.scaling_factor * init_latents
224
-
225
- if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
226
- # expand init_latents for batch_size
227
- deprecation_message = (
228
- f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
229
- " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
230
- " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
231
- " your script to pass as many initial images as text prompts to suppress this warning."
232
- )
233
- deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
234
- additional_image_per_prompt = batch_size // init_latents.shape[0]
235
- init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
236
- elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
237
- raise ValueError(
238
- f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
239
- )
240
- else:
241
- init_latents = torch.cat([init_latents], dim=0)
242
-
243
- shape = init_latents.shape
244
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
245
-
246
- # get latents
247
- init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
248
- latents = init_latents
249
-
250
- return latents
251
-
252
- @torch.enable_grad()
253
- def cond_fn(
254
- self,
255
- latents,
256
- timestep,
257
- index,
258
- text_embeddings,
259
- noise_pred_original,
260
- text_embeddings_clip,
261
- clip_guidance_scale,
262
- num_cutouts,
263
- use_cutouts=True,
264
- ):
265
- latents = latents.detach().requires_grad_()
266
-
267
- latent_model_input = self.scheduler.scale_model_input(latents, timestep)
268
-
269
- # predict the noise residual
270
- noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
271
-
272
- if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
273
- alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
274
- beta_prod_t = 1 - alpha_prod_t
275
- # compute predicted original sample from predicted noise also called
276
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
277
- pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
278
-
279
- fac = torch.sqrt(beta_prod_t)
280
- sample = pred_original_sample * (fac) + latents * (1 - fac)
281
- elif isinstance(self.scheduler, LMSDiscreteScheduler):
282
- sigma = self.scheduler.sigmas[index]
283
- sample = latents - sigma * noise_pred
284
- else:
285
- raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
286
-
287
- sample = 1 / self.vae.config.scaling_factor * sample
288
- image = self.vae.decode(sample).sample
289
- image = (image / 2 + 0.5).clamp(0, 1)
290
-
291
- if use_cutouts:
292
- image = self.make_cutouts(image, num_cutouts)
293
- else:
294
- image = transforms.Resize(self.cut_out_size)(image)
295
- image = self.normalize(image).to(latents.dtype)
296
-
297
- image_embeddings_clip = self.clip_model.get_image_features(image)
298
- image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
299
-
300
- if use_cutouts:
301
- dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
302
- dists = dists.view([num_cutouts, sample.shape[0], -1])
303
- loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
304
- else:
305
- loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
306
-
307
- grads = -torch.autograd.grad(loss, latents)[0]
308
-
309
- if isinstance(self.scheduler, LMSDiscreteScheduler):
310
- latents = latents.detach() + grads * (sigma**2)
311
- noise_pred = noise_pred_original
312
- else:
313
- noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
314
- return noise_pred, latents
315
-
316
- @torch.no_grad()
317
- def __call__(
318
- self,
319
- prompt: Union[str, List[str]],
320
- height: Optional[int] = 512,
321
- width: Optional[int] = 512,
322
- image: Union[torch.FloatTensor, PIL.Image.Image] = None,
323
- strength: float = 0.8,
324
- num_inference_steps: Optional[int] = 50,
325
- guidance_scale: Optional[float] = 7.5,
326
- num_images_per_prompt: Optional[int] = 1,
327
- eta: float = 0.0,
328
- clip_guidance_scale: Optional[float] = 100,
329
- clip_prompt: Optional[Union[str, List[str]]] = None,
330
- num_cutouts: Optional[int] = 4,
331
- use_cutouts: Optional[bool] = True,
332
- generator: Optional[torch.Generator] = None,
333
- latents: Optional[torch.FloatTensor] = None,
334
- output_type: Optional[str] = "pil",
335
- return_dict: bool = True,
336
- ):
337
- if isinstance(prompt, str):
338
- batch_size = 1
339
- elif isinstance(prompt, list):
340
- batch_size = len(prompt)
341
- else:
342
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
343
-
344
- if height % 8 != 0 or width % 8 != 0:
345
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
346
-
347
- # get prompt text embeddings
348
- text_input = self.tokenizer(
349
- prompt,
350
- padding="max_length",
351
- max_length=self.tokenizer.model_max_length,
352
- truncation=True,
353
- return_tensors="pt",
354
- )
355
- text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
356
- # duplicate text embeddings for each generation per prompt
357
- text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
358
-
359
- # set timesteps
360
- accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
361
- extra_set_kwargs = {}
362
- if accepts_offset:
363
- extra_set_kwargs["offset"] = 1
364
-
365
- self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
366
- # Some schedulers like PNDM have timesteps as arrays
367
- # It's more optimized to move all timesteps to correct device beforehand
368
- self.scheduler.timesteps.to(self.device)
369
-
370
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device)
371
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
372
-
373
- # Preprocess image
374
- image = preprocess(image, width, height)
375
- latents = self.prepare_latents(
376
- image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, self.device, generator
377
- )
378
-
379
- if clip_guidance_scale > 0:
380
- if clip_prompt is not None:
381
- clip_text_input = self.tokenizer(
382
- clip_prompt,
383
- padding="max_length",
384
- max_length=self.tokenizer.model_max_length,
385
- truncation=True,
386
- return_tensors="pt",
387
- ).input_ids.to(self.device)
388
- else:
389
- clip_text_input = text_input.input_ids.to(self.device)
390
- text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
391
- text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
392
- # duplicate text embeddings clip for each generation per prompt
393
- text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)
394
-
395
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
396
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
397
- # corresponds to doing no classifier free guidance.
398
- do_classifier_free_guidance = guidance_scale > 1.0
399
- # get unconditional embeddings for classifier free guidance
400
- if do_classifier_free_guidance:
401
- max_length = text_input.input_ids.shape[-1]
402
- uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
403
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
404
- # duplicate unconditional embeddings for each generation per prompt
405
- uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
406
-
407
- # For classifier free guidance, we need to do two forward passes.
408
- # Here we concatenate the unconditional and text embeddings into a single batch
409
- # to avoid doing two forward passes
410
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
411
-
412
- # get the initial random noise unless the user supplied it
413
-
414
- # Unlike in other pipelines, latents need to be generated in the target device
415
- # for 1-to-1 results reproducibility with the CompVis implementation.
416
- # However this currently doesn't work in `mps`.
417
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
418
- latents_dtype = text_embeddings.dtype
419
- if latents is None:
420
- if self.device.type == "mps":
421
- # randn does not work reproducibly on mps
422
- latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
423
- self.device
424
- )
425
- else:
426
- latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
427
- else:
428
- if latents.shape != latents_shape:
429
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
430
- latents = latents.to(self.device)
431
-
432
- # scale the initial noise by the standard deviation required by the scheduler
433
- latents = latents * self.scheduler.init_noise_sigma
434
-
435
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
436
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
437
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
438
- # and should be between [0, 1]
439
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
440
- extra_step_kwargs = {}
441
- if accepts_eta:
442
- extra_step_kwargs["eta"] = eta
443
-
444
- # check if the scheduler accepts generator
445
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
446
- if accepts_generator:
447
- extra_step_kwargs["generator"] = generator
448
-
449
- with self.progress_bar(total=num_inference_steps):
450
- for i, t in enumerate(timesteps):
451
- # expand the latents if we are doing classifier free guidance
452
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
453
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
454
-
455
- # predict the noise residual
456
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
457
-
458
- # perform classifier free guidance
459
- if do_classifier_free_guidance:
460
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
461
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
462
-
463
- # perform clip guidance
464
- if clip_guidance_scale > 0:
465
- text_embeddings_for_guidance = (
466
- text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
467
- )
468
- noise_pred, latents = self.cond_fn(
469
- latents,
470
- t,
471
- i,
472
- text_embeddings_for_guidance,
473
- noise_pred,
474
- text_embeddings_clip,
475
- clip_guidance_scale,
476
- num_cutouts,
477
- use_cutouts,
478
- )
479
-
480
- # compute the previous noisy sample x_t -> x_t-1
481
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
482
-
483
- # scale and decode the image latents with vae
484
- latents = 1 / self.vae.config.scaling_factor * latents
485
- image = self.vae.decode(latents).sample
486
-
487
- image = (image / 2 + 0.5).clamp(0, 1)
488
- image = image.cpu().permute(0, 2, 3, 1).numpy()
489
-
490
- if output_type == "pil":
491
- image = self.numpy_to_pil(image)
492
-
493
- if not return_dict:
494
- return (image, None)
495
-
496
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/examples/community/composable_stable_diffusion.py DELETED
@@ -1,582 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import inspect
16
- from typing import Callable, List, Optional, Union
17
-
18
- import torch
19
- from packaging import version
20
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
21
-
22
- from diffusers import DiffusionPipeline
23
- from diffusers.configuration_utils import FrozenDict
24
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
25
- from diffusers.schedulers import (
26
- DDIMScheduler,
27
- DPMSolverMultistepScheduler,
28
- EulerAncestralDiscreteScheduler,
29
- EulerDiscreteScheduler,
30
- LMSDiscreteScheduler,
31
- PNDMScheduler,
32
- )
33
- from diffusers.utils import is_accelerate_available
34
-
35
- from ...utils import deprecate, logging
36
- from . import StableDiffusionPipelineOutput
37
- from .safety_checker import StableDiffusionSafetyChecker
38
-
39
-
40
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
-
42
-
43
- class ComposableStableDiffusionPipeline(DiffusionPipeline):
44
- r"""
45
- Pipeline for text-to-image generation using Stable Diffusion.
46
-
47
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
48
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
49
-
50
- Args:
51
- vae ([`AutoencoderKL`]):
52
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
53
- text_encoder ([`CLIPTextModel`]):
54
- Frozen text-encoder. Stable Diffusion uses the text portion of
55
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
56
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
57
- tokenizer (`CLIPTokenizer`):
58
- Tokenizer of class
59
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
60
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
61
- scheduler ([`SchedulerMixin`]):
62
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
63
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
64
- safety_checker ([`StableDiffusionSafetyChecker`]):
65
- Classification module that estimates whether generated images could be considered offensive or harmful.
66
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
67
- feature_extractor ([`CLIPImageProcessor`]):
68
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
69
- """
70
- _optional_components = ["safety_checker", "feature_extractor"]
71
-
72
- def __init__(
73
- self,
74
- vae: AutoencoderKL,
75
- text_encoder: CLIPTextModel,
76
- tokenizer: CLIPTokenizer,
77
- unet: UNet2DConditionModel,
78
- scheduler: Union[
79
- DDIMScheduler,
80
- PNDMScheduler,
81
- LMSDiscreteScheduler,
82
- EulerDiscreteScheduler,
83
- EulerAncestralDiscreteScheduler,
84
- DPMSolverMultistepScheduler,
85
- ],
86
- safety_checker: StableDiffusionSafetyChecker,
87
- feature_extractor: CLIPImageProcessor,
88
- requires_safety_checker: bool = True,
89
- ):
90
- super().__init__()
91
-
92
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
93
- deprecation_message = (
94
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
95
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
96
- "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
97
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
98
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
99
- " file"
100
- )
101
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
102
- new_config = dict(scheduler.config)
103
- new_config["steps_offset"] = 1
104
- scheduler._internal_dict = FrozenDict(new_config)
105
-
106
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
107
- deprecation_message = (
108
- f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
109
- " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
110
- " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
111
- " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
112
- " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
113
- )
114
- deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
115
- new_config = dict(scheduler.config)
116
- new_config["clip_sample"] = False
117
- scheduler._internal_dict = FrozenDict(new_config)
118
-
119
- if safety_checker is None and requires_safety_checker:
120
- logger.warning(
121
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
122
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
123
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
124
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
125
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
126
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
127
- )
128
-
129
- if safety_checker is not None and feature_extractor is None:
130
- raise ValueError(
131
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
132
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
133
- )
134
-
135
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
136
- version.parse(unet.config._diffusers_version).base_version
137
- ) < version.parse("0.9.0.dev0")
138
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
139
- if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
140
- deprecation_message = (
141
- "The configuration file of the unet has set the default `sample_size` to smaller than"
142
- " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
143
- " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
144
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
145
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
146
- " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
147
- " in the config might lead to incorrect results in future versions. If you have downloaded this"
148
- " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
149
- " the `unet/config.json` file"
150
- )
151
- deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
152
- new_config = dict(unet.config)
153
- new_config["sample_size"] = 64
154
- unet._internal_dict = FrozenDict(new_config)
155
-
156
- self.register_modules(
157
- vae=vae,
158
- text_encoder=text_encoder,
159
- tokenizer=tokenizer,
160
- unet=unet,
161
- scheduler=scheduler,
162
- safety_checker=safety_checker,
163
- feature_extractor=feature_extractor,
164
- )
165
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
166
- self.register_to_config(requires_safety_checker=requires_safety_checker)
167
-
168
- def enable_vae_slicing(self):
169
- r"""
170
- Enable sliced VAE decoding.
171
-
172
- When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
173
- steps. This is useful to save some memory and allow larger batch sizes.
174
- """
175
- self.vae.enable_slicing()
176
-
177
- def disable_vae_slicing(self):
178
- r"""
179
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
180
- computing decoding in one step.
181
- """
182
- self.vae.disable_slicing()
183
-
184
- def enable_sequential_cpu_offload(self, gpu_id=0):
185
- r"""
186
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
187
- text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
188
- `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
189
- """
190
- if is_accelerate_available():
191
- from accelerate import cpu_offload
192
- else:
193
- raise ImportError("Please install accelerate via `pip install accelerate`")
194
-
195
- device = torch.device(f"cuda:{gpu_id}")
196
-
197
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
198
- if cpu_offloaded_model is not None:
199
- cpu_offload(cpu_offloaded_model, device)
200
-
201
- if self.safety_checker is not None:
202
- # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
203
- # fix by only offloading self.safety_checker for now
204
- cpu_offload(self.safety_checker.vision_model, device)
205
-
206
- @property
207
- def _execution_device(self):
208
- r"""
209
- Returns the device on which the pipeline's models will be executed. After calling
210
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
211
- hooks.
212
- """
213
- if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
214
- return self.device
215
- for module in self.unet.modules():
216
- if (
217
- hasattr(module, "_hf_hook")
218
- and hasattr(module._hf_hook, "execution_device")
219
- and module._hf_hook.execution_device is not None
220
- ):
221
- return torch.device(module._hf_hook.execution_device)
222
- return self.device
223
-
224
- def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
225
- r"""
226
- Encodes the prompt into text encoder hidden states.
227
-
228
- Args:
229
- prompt (`str` or `list(int)`):
230
- prompt to be encoded
231
- device: (`torch.device`):
232
- torch device
233
- num_images_per_prompt (`int`):
234
- number of images that should be generated per prompt
235
- do_classifier_free_guidance (`bool`):
236
- whether to use classifier free guidance or not
237
- negative_prompt (`str` or `List[str]`):
238
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
239
- if `guidance_scale` is less than `1`).
240
- """
241
- batch_size = len(prompt) if isinstance(prompt, list) else 1
242
-
243
- text_inputs = self.tokenizer(
244
- prompt,
245
- padding="max_length",
246
- max_length=self.tokenizer.model_max_length,
247
- truncation=True,
248
- return_tensors="pt",
249
- )
250
- text_input_ids = text_inputs.input_ids
251
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
252
-
253
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
254
- removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
255
- logger.warning(
256
- "The following part of your input was truncated because CLIP can only handle sequences up to"
257
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
258
- )
259
-
260
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
261
- attention_mask = text_inputs.attention_mask.to(device)
262
- else:
263
- attention_mask = None
264
-
265
- text_embeddings = self.text_encoder(
266
- text_input_ids.to(device),
267
- attention_mask=attention_mask,
268
- )
269
- text_embeddings = text_embeddings[0]
270
-
271
- # duplicate text embeddings for each generation per prompt, using mps friendly method
272
- bs_embed, seq_len, _ = text_embeddings.shape
273
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
274
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
275
-
276
- # get unconditional embeddings for classifier free guidance
277
- if do_classifier_free_guidance:
278
- uncond_tokens: List[str]
279
- if negative_prompt is None:
280
- uncond_tokens = [""] * batch_size
281
- elif type(prompt) is not type(negative_prompt):
282
- raise TypeError(
283
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
284
- f" {type(prompt)}."
285
- )
286
- elif isinstance(negative_prompt, str):
287
- uncond_tokens = [negative_prompt]
288
- elif batch_size != len(negative_prompt):
289
- raise ValueError(
290
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
291
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
292
- " the batch size of `prompt`."
293
- )
294
- else:
295
- uncond_tokens = negative_prompt
296
-
297
- max_length = text_input_ids.shape[-1]
298
- uncond_input = self.tokenizer(
299
- uncond_tokens,
300
- padding="max_length",
301
- max_length=max_length,
302
- truncation=True,
303
- return_tensors="pt",
304
- )
305
-
306
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
307
- attention_mask = uncond_input.attention_mask.to(device)
308
- else:
309
- attention_mask = None
310
-
311
- uncond_embeddings = self.text_encoder(
312
- uncond_input.input_ids.to(device),
313
- attention_mask=attention_mask,
314
- )
315
- uncond_embeddings = uncond_embeddings[0]
316
-
317
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
318
- seq_len = uncond_embeddings.shape[1]
319
- uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
320
- uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
321
-
322
- # For classifier free guidance, we need to do two forward passes.
323
- # Here we concatenate the unconditional and text embeddings into a single batch
324
- # to avoid doing two forward passes
325
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
326
-
327
- return text_embeddings
328
-
329
- def run_safety_checker(self, image, device, dtype):
330
- if self.safety_checker is not None:
331
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
332
- image, has_nsfw_concept = self.safety_checker(
333
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
334
- )
335
- else:
336
- has_nsfw_concept = None
337
- return image, has_nsfw_concept
338
-
339
- def decode_latents(self, latents):
340
- latents = 1 / 0.18215 * latents
341
- image = self.vae.decode(latents).sample
342
- image = (image / 2 + 0.5).clamp(0, 1)
343
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
344
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
345
- return image
346
-
347
- def prepare_extra_step_kwargs(self, generator, eta):
348
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
349
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
350
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
351
- # and should be between [0, 1]
352
-
353
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
354
- extra_step_kwargs = {}
355
- if accepts_eta:
356
- extra_step_kwargs["eta"] = eta
357
-
358
- # check if the scheduler accepts generator
359
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
360
- if accepts_generator:
361
- extra_step_kwargs["generator"] = generator
362
- return extra_step_kwargs
363
-
364
- def check_inputs(self, prompt, height, width, callback_steps):
365
- if not isinstance(prompt, str) and not isinstance(prompt, list):
366
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
367
-
368
- if height % 8 != 0 or width % 8 != 0:
369
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
370
-
371
- if (callback_steps is None) or (
372
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
373
- ):
374
- raise ValueError(
375
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
376
- f" {type(callback_steps)}."
377
- )
378
-
379
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
380
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
381
- if latents is None:
382
- if device.type == "mps":
383
- # randn does not work reproducibly on mps
384
- latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
385
- else:
386
- latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
387
- else:
388
- if latents.shape != shape:
389
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
390
- latents = latents.to(device)
391
-
392
- # scale the initial noise by the standard deviation required by the scheduler
393
- latents = latents * self.scheduler.init_noise_sigma
394
- return latents
395
-
396
- @torch.no_grad()
397
- def __call__(
398
- self,
399
- prompt: Union[str, List[str]],
400
- height: Optional[int] = None,
401
- width: Optional[int] = None,
402
- num_inference_steps: int = 50,
403
- guidance_scale: float = 7.5,
404
- negative_prompt: Optional[Union[str, List[str]]] = None,
405
- num_images_per_prompt: Optional[int] = 1,
406
- eta: float = 0.0,
407
- generator: Optional[torch.Generator] = None,
408
- latents: Optional[torch.FloatTensor] = None,
409
- output_type: Optional[str] = "pil",
410
- return_dict: bool = True,
411
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
412
- callback_steps: int = 1,
413
- weights: Optional[str] = "",
414
- ):
415
- r"""
416
- Function invoked when calling the pipeline for generation.
417
-
418
- Args:
419
- prompt (`str` or `List[str]`):
420
- The prompt or prompts to guide the image generation.
421
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
422
- The height in pixels of the generated image.
423
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
424
- The width in pixels of the generated image.
425
- num_inference_steps (`int`, *optional*, defaults to 50):
426
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
427
- expense of slower inference.
428
- guidance_scale (`float`, *optional*, defaults to 7.5):
429
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
430
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
431
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
432
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
433
- usually at the expense of lower image quality.
434
- negative_prompt (`str` or `List[str]`, *optional*):
435
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
436
- if `guidance_scale` is less than `1`).
437
- num_images_per_prompt (`int`, *optional*, defaults to 1):
438
- The number of images to generate per prompt.
439
- eta (`float`, *optional*, defaults to 0.0):
440
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
441
- [`schedulers.DDIMScheduler`], will be ignored for others.
442
- generator (`torch.Generator`, *optional*):
443
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
444
- deterministic.
445
- latents (`torch.FloatTensor`, *optional*):
446
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
447
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
448
- tensor will ge generated by sampling using the supplied random `generator`.
449
- output_type (`str`, *optional*, defaults to `"pil"`):
450
- The output format of the generate image. Choose between
451
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
452
- return_dict (`bool`, *optional*, defaults to `True`):
453
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
454
- plain tuple.
455
- callback (`Callable`, *optional*):
456
- A function that will be called every `callback_steps` steps during inference. The function will be
457
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
458
- callback_steps (`int`, *optional*, defaults to 1):
459
- The frequency at which the `callback` function will be called. If not specified, the callback will be
460
- called at every step.
461
-
462
- Returns:
463
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
464
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
465
- When returning a tuple, the first element is a list with the generated images, and the second element is a
466
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
467
- (nsfw) content, according to the `safety_checker`.
468
- """
469
- # 0. Default height and width to unet
470
- height = height or self.unet.config.sample_size * self.vae_scale_factor
471
- width = width or self.unet.config.sample_size * self.vae_scale_factor
472
-
473
- # 1. Check inputs. Raise error if not correct
474
- self.check_inputs(prompt, height, width, callback_steps)
475
-
476
- # 2. Define call parameters
477
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
478
- device = self._execution_device
479
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
480
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
481
- # corresponds to doing no classifier free guidance.
482
- do_classifier_free_guidance = guidance_scale > 1.0
483
-
484
- if "|" in prompt:
485
- prompt = [x.strip() for x in prompt.split("|")]
486
- print(f"composing {prompt}...")
487
-
488
- if not weights:
489
- # specify weights for prompts (excluding the unconditional score)
490
- print("using equal positive weights (conjunction) for all prompts...")
491
- weights = torch.tensor([guidance_scale] * len(prompt), device=self.device).reshape(-1, 1, 1, 1)
492
- else:
493
- # set prompt weight for each
494
- num_prompts = len(prompt) if isinstance(prompt, list) else 1
495
- weights = [float(w.strip()) for w in weights.split("|")]
496
- # guidance scale as the default
497
- if len(weights) < num_prompts:
498
- weights.append(guidance_scale)
499
- else:
500
- weights = weights[:num_prompts]
501
- assert len(weights) == len(prompt), "weights specified are not equal to the number of prompts"
502
- weights = torch.tensor(weights, device=self.device).reshape(-1, 1, 1, 1)
503
- else:
504
- weights = guidance_scale
505
-
506
- # 3. Encode input prompt
507
- text_embeddings = self._encode_prompt(
508
- prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
509
- )
510
-
511
- # 4. Prepare timesteps
512
- self.scheduler.set_timesteps(num_inference_steps, device=device)
513
- timesteps = self.scheduler.timesteps
514
-
515
- # 5. Prepare latent variables
516
- num_channels_latents = self.unet.in_channels
517
- latents = self.prepare_latents(
518
- batch_size * num_images_per_prompt,
519
- num_channels_latents,
520
- height,
521
- width,
522
- text_embeddings.dtype,
523
- device,
524
- generator,
525
- latents,
526
- )
527
-
528
- # composable diffusion
529
- if isinstance(prompt, list) and batch_size == 1:
530
- # remove extra unconditional embedding
531
- # N = one unconditional embed + conditional embeds
532
- text_embeddings = text_embeddings[len(prompt) - 1 :]
533
-
534
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
535
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
536
-
537
- # 7. Denoising loop
538
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
539
- with self.progress_bar(total=num_inference_steps) as progress_bar:
540
- for i, t in enumerate(timesteps):
541
- # expand the latents if we are doing classifier free guidance
542
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
543
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
544
-
545
- # predict the noise residual
546
- noise_pred = []
547
- for j in range(text_embeddings.shape[0]):
548
- noise_pred.append(
549
- self.unet(latent_model_input[:1], t, encoder_hidden_states=text_embeddings[j : j + 1]).sample
550
- )
551
- noise_pred = torch.cat(noise_pred, dim=0)
552
-
553
- # perform guidance
554
- if do_classifier_free_guidance:
555
- noise_pred_uncond, noise_pred_text = noise_pred[:1], noise_pred[1:]
556
- noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(
557
- dim=0, keepdims=True
558
- )
559
-
560
- # compute the previous noisy sample x_t -> x_t-1
561
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
562
-
563
- # call the callback, if provided
564
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
565
- progress_bar.update()
566
- if callback is not None and i % callback_steps == 0:
567
- callback(i, t, latents)
568
-
569
- # 8. Post-processing
570
- image = self.decode_latents(latents)
571
-
572
- # 9. Run safety checker
573
- image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
574
-
575
- # 10. Convert to PIL
576
- if output_type == "pil":
577
- image = self.numpy_to_pil(image)
578
-
579
- if not return_dict:
580
- return (image, has_nsfw_concept)
581
-
582
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/examples/community/ddim_noise_comparative_analysis.py DELETED
@@ -1,190 +0,0 @@
1
- # Copyright 2022 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from typing import List, Optional, Tuple, Union
16
-
17
- import PIL
18
- import torch
19
- from torchvision import transforms
20
-
21
- from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
22
- from diffusers.schedulers import DDIMScheduler
23
- from diffusers.utils import randn_tensor
24
-
25
-
26
- trans = transforms.Compose(
27
- [
28
- transforms.Resize((256, 256)),
29
- transforms.ToTensor(),
30
- transforms.Normalize([0.5], [0.5]),
31
- ]
32
- )
33
-
34
-
35
- def preprocess(image):
36
- if isinstance(image, torch.Tensor):
37
- return image
38
- elif isinstance(image, PIL.Image.Image):
39
- image = [image]
40
-
41
- image = [trans(img.convert("RGB")) for img in image]
42
- image = torch.stack(image)
43
- return image
44
-
45
-
46
- class DDIMNoiseComparativeAnalysisPipeline(DiffusionPipeline):
47
- r"""
48
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
49
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
50
-
51
- Parameters:
52
- unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
53
- scheduler ([`SchedulerMixin`]):
54
- A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
55
- [`DDPMScheduler`], or [`DDIMScheduler`].
56
- """
57
-
58
- def __init__(self, unet, scheduler):
59
- super().__init__()
60
-
61
- # make sure scheduler can always be converted to DDIM
62
- scheduler = DDIMScheduler.from_config(scheduler.config)
63
-
64
- self.register_modules(unet=unet, scheduler=scheduler)
65
-
66
- def check_inputs(self, strength):
67
- if strength < 0 or strength > 1:
68
- raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
69
-
70
- def get_timesteps(self, num_inference_steps, strength, device):
71
- # get the original timestep using init_timestep
72
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
73
-
74
- t_start = max(num_inference_steps - init_timestep, 0)
75
- timesteps = self.scheduler.timesteps[t_start:]
76
-
77
- return timesteps, num_inference_steps - t_start
78
-
79
- def prepare_latents(self, image, timestep, batch_size, dtype, device, generator=None):
80
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
81
- raise ValueError(
82
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
83
- )
84
-
85
- init_latents = image.to(device=device, dtype=dtype)
86
-
87
- if isinstance(generator, list) and len(generator) != batch_size:
88
- raise ValueError(
89
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
90
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
91
- )
92
-
93
- shape = init_latents.shape
94
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
95
-
96
- # get latents
97
- print("add noise to latents at timestep", timestep)
98
- init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
99
- latents = init_latents
100
-
101
- return latents
102
-
103
- @torch.no_grad()
104
- def __call__(
105
- self,
106
- image: Union[torch.FloatTensor, PIL.Image.Image] = None,
107
- strength: float = 0.8,
108
- batch_size: int = 1,
109
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
110
- eta: float = 0.0,
111
- num_inference_steps: int = 50,
112
- use_clipped_model_output: Optional[bool] = None,
113
- output_type: Optional[str] = "pil",
114
- return_dict: bool = True,
115
- ) -> Union[ImagePipelineOutput, Tuple]:
116
- r"""
117
- Args:
118
- image (`torch.FloatTensor` or `PIL.Image.Image`):
119
- `Image`, or tensor representing an image batch, that will be used as the starting point for the
120
- process.
121
- strength (`float`, *optional*, defaults to 0.8):
122
- Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
123
- will be used as a starting point, adding more noise to it the larger the `strength`. The number of
124
- denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
125
- be maximum and the denoising process will run for the full number of iterations specified in
126
- `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
127
- batch_size (`int`, *optional*, defaults to 1):
128
- The number of images to generate.
129
- generator (`torch.Generator`, *optional*):
130
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
131
- to make generation deterministic.
132
- eta (`float`, *optional*, defaults to 0.0):
133
- The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
134
- num_inference_steps (`int`, *optional*, defaults to 50):
135
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
136
- expense of slower inference.
137
- use_clipped_model_output (`bool`, *optional*, defaults to `None`):
138
- if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed
139
- downstream to the scheduler. So use `None` for schedulers which don't support this argument.
140
- output_type (`str`, *optional*, defaults to `"pil"`):
141
- The output format of the generate image. Choose between
142
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
143
- return_dict (`bool`, *optional*, defaults to `True`):
144
- Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
145
-
146
- Returns:
147
- [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
148
- True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
149
- """
150
- # 1. Check inputs. Raise error if not correct
151
- self.check_inputs(strength)
152
-
153
- # 2. Preprocess image
154
- image = preprocess(image)
155
-
156
- # 3. set timesteps
157
- self.scheduler.set_timesteps(num_inference_steps, device=self.device)
158
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device)
159
- latent_timestep = timesteps[:1].repeat(batch_size)
160
-
161
- # 4. Prepare latent variables
162
- latents = self.prepare_latents(image, latent_timestep, batch_size, self.unet.dtype, self.device, generator)
163
- image = latents
164
-
165
- # 5. Denoising loop
166
- for t in self.progress_bar(timesteps):
167
- # 1. predict noise model_output
168
- model_output = self.unet(image, t).sample
169
-
170
- # 2. predict previous mean of image x_t-1 and add variance depending on eta
171
- # eta corresponds to η in paper and should be between [0, 1]
172
- # do x_t -> x_t-1
173
- image = self.scheduler.step(
174
- model_output,
175
- t,
176
- image,
177
- eta=eta,
178
- use_clipped_model_output=use_clipped_model_output,
179
- generator=generator,
180
- ).prev_sample
181
-
182
- image = (image / 2 + 0.5).clamp(0, 1)
183
- image = image.cpu().permute(0, 2, 3, 1).numpy()
184
- if output_type == "pil":
185
- image = self.numpy_to_pil(image)
186
-
187
- if not return_dict:
188
- return (image, latent_timestep.item())
189
-
190
- return ImagePipelineOutput(images=image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/examples/community/imagic_stable_diffusion.py DELETED
@@ -1,496 +0,0 @@
1
- """
2
- modeled after the textual_inversion.py / train_dreambooth.py and the work
3
- of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
4
- """
5
- import inspect
6
- import warnings
7
- from typing import List, Optional, Union
8
-
9
- import numpy as np
10
- import PIL
11
- import torch
12
- import torch.nn.functional as F
13
- from accelerate import Accelerator
14
-
15
- # TODO: remove and import from diffusers.utils when the new version of diffusers is released
16
- from packaging import version
17
- from tqdm.auto import tqdm
18
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
19
-
20
- from diffusers import DiffusionPipeline
21
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
22
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
23
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
24
- from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
25
- from diffusers.utils import logging
26
-
27
-
28
- if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
29
- PIL_INTERPOLATION = {
30
- "linear": PIL.Image.Resampling.BILINEAR,
31
- "bilinear": PIL.Image.Resampling.BILINEAR,
32
- "bicubic": PIL.Image.Resampling.BICUBIC,
33
- "lanczos": PIL.Image.Resampling.LANCZOS,
34
- "nearest": PIL.Image.Resampling.NEAREST,
35
- }
36
- else:
37
- PIL_INTERPOLATION = {
38
- "linear": PIL.Image.LINEAR,
39
- "bilinear": PIL.Image.BILINEAR,
40
- "bicubic": PIL.Image.BICUBIC,
41
- "lanczos": PIL.Image.LANCZOS,
42
- "nearest": PIL.Image.NEAREST,
43
- }
44
- # ------------------------------------------------------------------------------
45
-
46
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
-
48
-
49
- def preprocess(image):
50
- w, h = image.size
51
- w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
52
- image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
53
- image = np.array(image).astype(np.float32) / 255.0
54
- image = image[None].transpose(0, 3, 1, 2)
55
- image = torch.from_numpy(image)
56
- return 2.0 * image - 1.0
57
-
58
-
59
- class ImagicStableDiffusionPipeline(DiffusionPipeline):
60
- r"""
61
- Pipeline for imagic image editing.
62
- See paper here: https://arxiv.org/pdf/2210.09276.pdf
63
-
64
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
65
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
66
- Args:
67
- vae ([`AutoencoderKL`]):
68
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
69
- text_encoder ([`CLIPTextModel`]):
70
- Frozen text-encoder. Stable Diffusion uses the text portion of
71
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
72
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
73
- tokenizer (`CLIPTokenizer`):
74
- Tokenizer of class
75
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
76
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
77
- scheduler ([`SchedulerMixin`]):
78
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
79
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
80
- safety_checker ([`StableDiffusionSafetyChecker`]):
81
- Classification module that estimates whether generated images could be considered offsensive or harmful.
82
- Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
83
- feature_extractor ([`CLIPImageProcessor`]):
84
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
85
- """
86
-
87
- def __init__(
88
- self,
89
- vae: AutoencoderKL,
90
- text_encoder: CLIPTextModel,
91
- tokenizer: CLIPTokenizer,
92
- unet: UNet2DConditionModel,
93
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
94
- safety_checker: StableDiffusionSafetyChecker,
95
- feature_extractor: CLIPImageProcessor,
96
- ):
97
- super().__init__()
98
- self.register_modules(
99
- vae=vae,
100
- text_encoder=text_encoder,
101
- tokenizer=tokenizer,
102
- unet=unet,
103
- scheduler=scheduler,
104
- safety_checker=safety_checker,
105
- feature_extractor=feature_extractor,
106
- )
107
-
108
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
109
- r"""
110
- Enable sliced attention computation.
111
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
112
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
113
- Args:
114
- slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
115
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
116
- a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
117
- `attention_head_dim` must be a multiple of `slice_size`.
118
- """
119
- if slice_size == "auto":
120
- # half the attention head size is usually a good trade-off between
121
- # speed and memory
122
- slice_size = self.unet.config.attention_head_dim // 2
123
- self.unet.set_attention_slice(slice_size)
124
-
125
- def disable_attention_slicing(self):
126
- r"""
127
- Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
128
- back to computing attention in one step.
129
- """
130
- # set slice_size = `None` to disable `attention slicing`
131
- self.enable_attention_slicing(None)
132
-
133
- def train(
134
- self,
135
- prompt: Union[str, List[str]],
136
- image: Union[torch.FloatTensor, PIL.Image.Image],
137
- height: Optional[int] = 512,
138
- width: Optional[int] = 512,
139
- generator: Optional[torch.Generator] = None,
140
- embedding_learning_rate: float = 0.001,
141
- diffusion_model_learning_rate: float = 2e-6,
142
- text_embedding_optimization_steps: int = 500,
143
- model_fine_tuning_optimization_steps: int = 1000,
144
- **kwargs,
145
- ):
146
- r"""
147
- Fine-tunes the text embedding and then the diffusion model so that the input `image` is better reconstructed for the given `prompt`.
148
- Args:
149
- prompt (`str` or `List[str]`):
150
- The prompt or prompts to guide the image generation.
151
- height (`int`, *optional*, defaults to 512):
152
- The height in pixels of the generated image.
153
- width (`int`, *optional*, defaults to 512):
154
- The width in pixels of the generated image.
155
- num_inference_steps (`int`, *optional*, defaults to 50):
156
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
157
- expense of slower inference.
158
- guidance_scale (`float`, *optional*, defaults to 7.5):
159
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
160
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
161
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
162
- 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
163
- usually at the expense of lower image quality.
164
- eta (`float`, *optional*, defaults to 0.0):
165
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
166
- [`schedulers.DDIMScheduler`], will be ignored for others.
167
- generator (`torch.Generator`, *optional*):
168
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
169
- deterministic.
170
- latents (`torch.FloatTensor`, *optional*):
171
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
172
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
173
- tensor will be generated by sampling using the supplied random `generator`.
174
- output_type (`str`, *optional*, defaults to `"pil"`):
175
- The output format of the generated image. Choose between
176
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
177
- return_dict (`bool`, *optional*, defaults to `True`):
178
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
179
- plain tuple.
180
- Returns:
181
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
182
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
183
- When returning a tuple, the first element is a list with the generated images, and the second element is a
184
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
185
- (nsfw) content, according to the `safety_checker`.
186
- """
187
- accelerator = Accelerator(
188
- gradient_accumulation_steps=1,
189
- mixed_precision="fp16",
190
- )
191
-
192
- if "torch_device" in kwargs:
193
- device = kwargs.pop("torch_device")
194
- warnings.warn(
195
- "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
196
- " Consider using `pipe.to(torch_device)` instead."
197
- )
198
-
199
- if device is None:
200
- device = "cuda" if torch.cuda.is_available() else "cpu"
201
- self.to(device)
202
-
203
- if height % 8 != 0 or width % 8 != 0:
204
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
205
-
206
- # Freeze vae and unet
207
- self.vae.requires_grad_(False)
208
- self.unet.requires_grad_(False)
209
- self.text_encoder.requires_grad_(False)
210
- self.unet.eval()
211
- self.vae.eval()
212
- self.text_encoder.eval()
213
-
214
- if accelerator.is_main_process:
215
- accelerator.init_trackers(
216
- "imagic",
217
- config={
218
- "embedding_learning_rate": embedding_learning_rate,
219
- "text_embedding_optimization_steps": text_embedding_optimization_steps,
220
- },
221
- )
222
-
223
- # get text embeddings for prompt
224
- text_input = self.tokenizer(
225
- prompt,
226
- padding="max_length",
227
- max_length=self.tokenizer.model_max_length,
228
- truncation=True,
229
- return_tensors="pt",
230
- )
231
- text_embeddings = torch.nn.Parameter(
232
- self.text_encoder(text_input.input_ids.to(self.device))[0], requires_grad=True
233
- )
234
- text_embeddings = text_embeddings.detach()
235
- text_embeddings.requires_grad_()
236
- text_embeddings_orig = text_embeddings.clone()
237
-
238
- # Initialize the optimizer
239
- optimizer = torch.optim.Adam(
240
- [text_embeddings], # only optimize the embeddings
241
- lr=embedding_learning_rate,
242
- )
243
-
244
- if isinstance(image, PIL.Image.Image):
245
- image = preprocess(image)
246
-
247
- latents_dtype = text_embeddings.dtype
248
- image = image.to(device=self.device, dtype=latents_dtype)
249
- init_latent_image_dist = self.vae.encode(image).latent_dist
250
- image_latents = init_latent_image_dist.sample(generator=generator)
251
- image_latents = 0.18215 * image_latents
252
-
253
- progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
254
- progress_bar.set_description("Steps")
255
-
256
- global_step = 0
257
-
258
- logger.info("First optimizing the text embedding to better reconstruct the init image")
259
- for _ in range(text_embedding_optimization_steps):
260
- with accelerator.accumulate(text_embeddings):
261
- # Sample noise that we'll add to the latents
262
- noise = torch.randn(image_latents.shape).to(image_latents.device)
263
- timesteps = torch.randint(1000, (1,), device=image_latents.device)
264
-
265
- # Add noise to the latents according to the noise magnitude at each timestep
266
- # (this is the forward diffusion process)
267
- noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)
268
-
269
- # Predict the noise residual
270
- noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
271
-
272
- loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
273
- accelerator.backward(loss)
274
-
275
- optimizer.step()
276
- optimizer.zero_grad()
277
-
278
- # Checks if the accelerator has performed an optimization step behind the scenes
279
- if accelerator.sync_gradients:
280
- progress_bar.update(1)
281
- global_step += 1
282
-
283
- logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
284
- progress_bar.set_postfix(**logs)
285
- accelerator.log(logs, step=global_step)
286
-
287
- accelerator.wait_for_everyone()
288
-
289
- text_embeddings.requires_grad_(False)
290
-
291
- # Now we fine tune the unet to better reconstruct the image
292
- self.unet.requires_grad_(True)
293
- self.unet.train()
294
- optimizer = torch.optim.Adam(
295
- self.unet.parameters(), # only optimize unet
296
- lr=diffusion_model_learning_rate,
297
- )
298
- progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)
299
-
300
- logger.info("Next fine tuning the entire model to better reconstruct the init image")
301
- for _ in range(model_fine_tuning_optimization_steps):
302
- with accelerator.accumulate(self.unet.parameters()):
303
- # Sample noise that we'll add to the latents
304
- noise = torch.randn(image_latents.shape).to(image_latents.device)
305
- timesteps = torch.randint(1000, (1,), device=image_latents.device)
306
-
307
- # Add noise to the latents according to the noise magnitude at each timestep
308
- # (this is the forward diffusion process)
309
- noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)
310
-
311
- # Predict the noise residual
312
- noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
313
-
314
- loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
315
- accelerator.backward(loss)
316
-
317
- optimizer.step()
318
- optimizer.zero_grad()
319
-
320
- # Checks if the accelerator has performed an optimization step behind the scenes
321
- if accelerator.sync_gradients:
322
- progress_bar.update(1)
323
- global_step += 1
324
-
325
- logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
326
- progress_bar.set_postfix(**logs)
327
- accelerator.log(logs, step=global_step)
328
-
329
- accelerator.wait_for_everyone()
330
- self.text_embeddings_orig = text_embeddings_orig
331
- self.text_embeddings = text_embeddings
332
-
333
- @torch.no_grad()
334
- def __call__(
335
- self,
336
- alpha: float = 1.2,
337
- height: Optional[int] = 512,
338
- width: Optional[int] = 512,
339
- num_inference_steps: Optional[int] = 50,
340
- generator: Optional[torch.Generator] = None,
341
- output_type: Optional[str] = "pil",
342
- return_dict: bool = True,
343
- guidance_scale: float = 7.5,
344
- eta: float = 0.0,
345
- ):
346
- r"""
347
- Function invoked when calling the pipeline for generation.
348
- Args:
349
- prompt (`str` or `List[str]`):
350
- The prompt or prompts to guide the image generation.
351
- height (`int`, *optional*, defaults to 512):
352
- The height in pixels of the generated image.
353
- width (`int`, *optional*, defaults to 512):
354
- The width in pixels of the generated image.
355
- num_inference_steps (`int`, *optional*, defaults to 50):
356
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
357
- expense of slower inference.
358
- guidance_scale (`float`, *optional*, defaults to 7.5):
359
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
360
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
361
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
362
- 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
363
- usually at the expense of lower image quality.
364
- eta (`float`, *optional*, defaults to 0.0):
365
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
366
- [`schedulers.DDIMScheduler`], will be ignored for others.
367
- generator (`torch.Generator`, *optional*):
368
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
369
- deterministic.
370
- latents (`torch.FloatTensor`, *optional*):
371
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
372
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
373
- tensor will be generated by sampling using the supplied random `generator`.
374
- output_type (`str`, *optional*, defaults to `"pil"`):
375
- The output format of the generated image. Choose between
376
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
377
- return_dict (`bool`, *optional*, defaults to `True`):
378
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
379
- plain tuple.
380
- Returns:
381
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
382
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
383
- When returning a tuple, the first element is a list with the generated images, and the second element is a
384
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
385
- (nsfw) content, according to the `safety_checker`.
386
- """
387
- if height % 8 != 0 or width % 8 != 0:
388
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
389
- if self.text_embeddings is None:
390
- raise ValueError("Please run the pipe.train() before trying to generate an image.")
391
- if self.text_embeddings_orig is None:
392
- raise ValueError("Please run the pipe.train() before trying to generate an image.")
393
-
394
- text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings
395
-
396
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
397
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
398
- # corresponds to doing no classifier free guidance.
399
- do_classifier_free_guidance = guidance_scale > 1.0
400
- # get unconditional embeddings for classifier free guidance
401
- if do_classifier_free_guidance:
402
- uncond_tokens = [""]
403
- max_length = self.tokenizer.model_max_length
404
- uncond_input = self.tokenizer(
405
- uncond_tokens,
406
- padding="max_length",
407
- max_length=max_length,
408
- truncation=True,
409
- return_tensors="pt",
410
- )
411
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
412
-
413
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
414
- seq_len = uncond_embeddings.shape[1]
415
- uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)
416
-
417
- # For classifier free guidance, we need to do two forward passes.
418
- # Here we concatenate the unconditional and text embeddings into a single batch
419
- # to avoid doing two forward passes
420
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
421
-
422
- # get the initial random noise unless the user supplied it
423
-
424
- # Unlike in other pipelines, latents need to be generated in the target device
425
- # for 1-to-1 results reproducibility with the CompVis implementation.
426
- # However this currently doesn't work in `mps`.
427
- latents_shape = (1, self.unet.in_channels, height // 8, width // 8)
428
- latents_dtype = text_embeddings.dtype
429
- if self.device.type == "mps":
430
- # randn does not exist on mps
431
- latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
432
- self.device
433
- )
434
- else:
435
- latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
436
-
437
- # set timesteps
438
- self.scheduler.set_timesteps(num_inference_steps)
439
-
440
- # Some schedulers like PNDM have timesteps as arrays
441
- # It's more optimized to move all timesteps to correct device beforehand
442
- timesteps_tensor = self.scheduler.timesteps.to(self.device)
443
-
444
- # scale the initial noise by the standard deviation required by the scheduler
445
- latents = latents * self.scheduler.init_noise_sigma
446
-
447
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
448
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
449
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
450
- # and should be between [0, 1]
451
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
452
- extra_step_kwargs = {}
453
- if accepts_eta:
454
- extra_step_kwargs["eta"] = eta
455
-
456
- for i, t in enumerate(self.progress_bar(timesteps_tensor)):
457
- # expand the latents if we are doing classifier free guidance
458
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
459
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
460
-
461
- # predict the noise residual
462
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
463
-
464
- # perform guidance
465
- if do_classifier_free_guidance:
466
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
467
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
468
-
469
- # compute the previous noisy sample x_t -> x_t-1
470
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
471
-
472
- latents = 1 / 0.18215 * latents
473
- image = self.vae.decode(latents).sample
474
-
475
- image = (image / 2 + 0.5).clamp(0, 1)
476
-
477
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
478
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
479
-
480
- if self.safety_checker is not None:
481
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
482
- self.device
483
- )
484
- image, has_nsfw_concept = self.safety_checker(
485
- images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
486
- )
487
- else:
488
- has_nsfw_concept = None
489
-
490
- if output_type == "pil":
491
- image = self.numpy_to_pil(image)
492
-
493
- if not return_dict:
494
- return (image, has_nsfw_concept)
495
-
496
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
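
The removed Imagic pipeline was typically used through the `custom_pipeline` loading mechanism. Below is a minimal usage sketch, not the canonical example: the custom pipeline id `imagic_stable_diffusion`, the base checkpoint `CompVis/stable-diffusion-v1-4`, the local `input.png`, the prompt, and the seed are all illustrative, and a CUDA GPU is effectively required because `train()` runs an fp16 fine-tuning pass.

```python
import torch
from PIL import Image
from diffusers import DDIMScheduler, DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the community pipeline by name (assumes it is still resolvable on the Hub).
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="imagic_stable_diffusion",
    scheduler=DDIMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
        clip_sample=False, set_alpha_to_one=False,
    ),
    safety_checker=None,
).to(device)

init_image = Image.open("input.png").convert("RGB").resize((512, 512))  # hypothetical input image

# Stage 1 optimizes the text embedding, stage 2 fine-tunes the UNet on the input image.
pipe.train(
    "A photo of a bird spreading its wings",  # illustrative edit prompt
    image=init_image,
    generator=torch.Generator(device).manual_seed(0),
)

# Interpolate between the original and the optimized embedding (alpha) and decode.
edited = pipe(alpha=1.0, guidance_scale=7.5, num_inference_steps=50).images[0]
edited.save("imagic_edit.png")
```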
 
diffusers/examples/community/img2img_inpainting.py DELETED
@@ -1,463 +0,0 @@
1
- import inspect
2
- from typing import Callable, List, Optional, Tuple, Union
3
-
4
- import numpy as np
5
- import PIL
6
- import torch
7
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
8
-
9
- from diffusers import DiffusionPipeline
10
- from diffusers.configuration_utils import FrozenDict
11
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
12
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
13
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
14
- from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
15
- from diffusers.utils import deprecate, logging
16
-
17
-
18
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
-
20
-
21
- def prepare_mask_and_masked_image(image, mask):
22
- image = np.array(image.convert("RGB"))
23
- image = image[None].transpose(0, 3, 1, 2)
24
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
25
-
26
- mask = np.array(mask.convert("L"))
27
- mask = mask.astype(np.float32) / 255.0
28
- mask = mask[None, None]
29
- mask[mask < 0.5] = 0
30
- mask[mask >= 0.5] = 1
31
- mask = torch.from_numpy(mask)
32
-
33
- masked_image = image * (mask < 0.5)
34
-
35
- return mask, masked_image
36
-
37
-
38
- def check_size(image, height, width):
39
- if isinstance(image, PIL.Image.Image):
40
- w, h = image.size
41
- elif isinstance(image, torch.Tensor):
42
- *_, h, w = image.shape
43
-
44
- if h != height or w != width:
45
- raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
46
-
47
-
48
- def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
49
- inner_image = inner_image.convert("RGBA")
50
- image = image.convert("RGB")
51
-
52
- image.paste(inner_image, paste_offset, inner_image)
53
- image = image.convert("RGB")
54
-
55
- return image
56
-
57
-
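
As a quick illustration of what the helpers above return, here is a toy sketch (hypothetical inputs; it assumes `prepare_mask_and_masked_image` from this module is in scope): the mask is binarized to {0, 1} and the masked (white) region of the `[-1, 1]`-scaled image is zeroed out.

```python
import numpy as np
from PIL import Image

# Toy inputs: an 8x8 all-white image and a mask whose right half is white (the region to repaint).
image = Image.new("RGB", (8, 8), color=(255, 255, 255))
mask_array = np.zeros((8, 8), dtype=np.uint8)
mask_array[:, 4:] = 255
mask = Image.fromarray(mask_array, mode="L")

mask_t, masked_image = prepare_mask_and_masked_image(image, mask)
print(mask_t.shape)        # torch.Size([1, 1, 8, 8]); 1.0 where the mask is white
print(masked_image.shape)  # torch.Size([1, 3, 8, 8])
print(masked_image[0, 0])  # left half is 1.0 (white scaled to [-1, 1]), right half is zeroed by the mask
```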
58
- class ImageToImageInpaintingPipeline(DiffusionPipeline):
59
- r"""
60
- Pipeline for text-guided image-to-image inpainting using Stable Diffusion. *This is an experimental feature*.
61
-
62
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
63
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
64
-
65
- Args:
66
- vae ([`AutoencoderKL`]):
67
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
68
- text_encoder ([`CLIPTextModel`]):
69
- Frozen text-encoder. Stable Diffusion uses the text portion of
70
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
71
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
72
- tokenizer (`CLIPTokenizer`):
73
- Tokenizer of class
74
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
75
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
76
- scheduler ([`SchedulerMixin`]):
77
- A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
78
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
79
- safety_checker ([`StableDiffusionSafetyChecker`]):
80
- Classification module that estimates whether generated images could be considered offensive or harmful.
81
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
82
- feature_extractor ([`CLIPImageProcessor`]):
83
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
84
- """
85
-
86
- def __init__(
87
- self,
88
- vae: AutoencoderKL,
89
- text_encoder: CLIPTextModel,
90
- tokenizer: CLIPTokenizer,
91
- unet: UNet2DConditionModel,
92
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
93
- safety_checker: StableDiffusionSafetyChecker,
94
- feature_extractor: CLIPImageProcessor,
95
- ):
96
- super().__init__()
97
-
98
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
99
- deprecation_message = (
100
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
101
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
102
- "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
103
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
104
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
105
- " file"
106
- )
107
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
108
- new_config = dict(scheduler.config)
109
- new_config["steps_offset"] = 1
110
- scheduler._internal_dict = FrozenDict(new_config)
111
-
112
- if safety_checker is None:
113
- logger.warning(
114
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
115
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
116
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
117
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
118
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
119
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
120
- )
121
-
122
- self.register_modules(
123
- vae=vae,
124
- text_encoder=text_encoder,
125
- tokenizer=tokenizer,
126
- unet=unet,
127
- scheduler=scheduler,
128
- safety_checker=safety_checker,
129
- feature_extractor=feature_extractor,
130
- )
131
-
132
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
133
- r"""
134
- Enable sliced attention computation.
135
-
136
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
137
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
138
-
139
- Args:
140
- slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
141
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
142
- a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
143
- `attention_head_dim` must be a multiple of `slice_size`.
144
- """
145
- if slice_size == "auto":
146
- # half the attention head size is usually a good trade-off between
147
- # speed and memory
148
- slice_size = self.unet.config.attention_head_dim // 2
149
- self.unet.set_attention_slice(slice_size)
150
-
151
- def disable_attention_slicing(self):
152
- r"""
153
- Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
154
- back to computing attention in one step.
155
- """
156
- # set slice_size = `None` to disable `attention slicing`
157
- self.enable_attention_slicing(None)
158
-
159
- @torch.no_grad()
160
- def __call__(
161
- self,
162
- prompt: Union[str, List[str]],
163
- image: Union[torch.FloatTensor, PIL.Image.Image],
164
- inner_image: Union[torch.FloatTensor, PIL.Image.Image],
165
- mask_image: Union[torch.FloatTensor, PIL.Image.Image],
166
- height: int = 512,
167
- width: int = 512,
168
- num_inference_steps: int = 50,
169
- guidance_scale: float = 7.5,
170
- negative_prompt: Optional[Union[str, List[str]]] = None,
171
- num_images_per_prompt: Optional[int] = 1,
172
- eta: float = 0.0,
173
- generator: Optional[torch.Generator] = None,
174
- latents: Optional[torch.FloatTensor] = None,
175
- output_type: Optional[str] = "pil",
176
- return_dict: bool = True,
177
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
178
- callback_steps: int = 1,
179
- **kwargs,
180
- ):
181
- r"""
182
- Function invoked when calling the pipeline for generation.
183
-
184
- Args:
185
- prompt (`str` or `List[str]`):
186
- The prompt or prompts to guide the image generation.
187
- image (`torch.Tensor` or `PIL.Image.Image`):
188
- `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
189
- be masked out with `mask_image` and repainted according to `prompt`.
190
- inner_image (`torch.Tensor` or `PIL.Image.Image`):
191
- `Image`, or tensor representing an image batch which will be overlayed onto `image`. Non-transparent
192
- regions of `inner_image` must fit inside white pixels in `mask_image`. Expects four channels, with
193
- the last channel representing the alpha channel, which will be used to blend `inner_image` with
194
- `image`. If not provided, it will be forcibly cast to RGBA.
195
- mask_image (`PIL.Image.Image`):
196
- `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
197
- repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
198
- to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
199
- instead of 3, so the expected shape would be `(B, H, W, 1)`.
200
- height (`int`, *optional*, defaults to 512):
201
- The height in pixels of the generated image.
202
- width (`int`, *optional*, defaults to 512):
203
- The width in pixels of the generated image.
204
- num_inference_steps (`int`, *optional*, defaults to 50):
205
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
206
- expense of slower inference.
207
- guidance_scale (`float`, *optional*, defaults to 7.5):
208
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
209
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
210
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
211
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
212
- usually at the expense of lower image quality.
213
- negative_prompt (`str` or `List[str]`, *optional*):
214
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
215
- if `guidance_scale` is less than `1`).
216
- num_images_per_prompt (`int`, *optional*, defaults to 1):
217
- The number of images to generate per prompt.
218
- eta (`float`, *optional*, defaults to 0.0):
219
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
220
- [`schedulers.DDIMScheduler`], will be ignored for others.
221
- generator (`torch.Generator`, *optional*):
222
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
223
- deterministic.
224
- latents (`torch.FloatTensor`, *optional*):
225
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
226
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
227
- tensor will ge generated by sampling using the supplied random `generator`.
228
- output_type (`str`, *optional*, defaults to `"pil"`):
229
- The output format of the generate image. Choose between
230
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
231
- return_dict (`bool`, *optional*, defaults to `True`):
232
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
233
- plain tuple.
234
- callback (`Callable`, *optional*):
235
- A function that will be called every `callback_steps` steps during inference. The function will be
236
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
237
- callback_steps (`int`, *optional*, defaults to 1):
238
- The frequency at which the `callback` function will be called. If not specified, the callback will be
239
- called at every step.
240
-
241
- Returns:
242
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
243
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
244
- When returning a tuple, the first element is a list with the generated images, and the second element is a
245
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
246
- (nsfw) content, according to the `safety_checker`.
247
- """
248
-
249
- if isinstance(prompt, str):
250
- batch_size = 1
251
- elif isinstance(prompt, list):
252
- batch_size = len(prompt)
253
- else:
254
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
255
-
256
- if height % 8 != 0 or width % 8 != 0:
257
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
258
-
259
- if (callback_steps is None) or (
260
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
261
- ):
262
- raise ValueError(
263
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
264
- f" {type(callback_steps)}."
265
- )
266
-
267
- # check if input sizes are correct
268
- check_size(image, height, width)
269
- check_size(inner_image, height, width)
270
- check_size(mask_image, height, width)
271
-
272
- # get prompt text embeddings
273
- text_inputs = self.tokenizer(
274
- prompt,
275
- padding="max_length",
276
- max_length=self.tokenizer.model_max_length,
277
- return_tensors="pt",
278
- )
279
- text_input_ids = text_inputs.input_ids
280
-
281
- if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
282
- removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
283
- logger.warning(
284
- "The following part of your input was truncated because CLIP can only handle sequences up to"
285
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
286
- )
287
- text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
288
- text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
289
-
290
- # duplicate text embeddings for each generation per prompt, using mps friendly method
291
- bs_embed, seq_len, _ = text_embeddings.shape
292
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
293
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
294
-
295
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
296
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
297
- # corresponds to doing no classifier free guidance.
298
- do_classifier_free_guidance = guidance_scale > 1.0
299
- # get unconditional embeddings for classifier free guidance
300
- if do_classifier_free_guidance:
301
- uncond_tokens: List[str]
302
- if negative_prompt is None:
303
- uncond_tokens = [""]
304
- elif type(prompt) is not type(negative_prompt):
305
- raise TypeError(
306
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
307
- f" {type(prompt)}."
308
- )
309
- elif isinstance(negative_prompt, str):
310
- uncond_tokens = [negative_prompt]
311
- elif batch_size != len(negative_prompt):
312
- raise ValueError(
313
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
314
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
315
- " the batch size of `prompt`."
316
- )
317
- else:
318
- uncond_tokens = negative_prompt
319
-
320
- max_length = text_input_ids.shape[-1]
321
- uncond_input = self.tokenizer(
322
- uncond_tokens,
323
- padding="max_length",
324
- max_length=max_length,
325
- truncation=True,
326
- return_tensors="pt",
327
- )
328
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
329
-
330
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
331
- seq_len = uncond_embeddings.shape[1]
332
- uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
333
- uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
334
-
335
- # For classifier free guidance, we need to do two forward passes.
336
- # Here we concatenate the unconditional and text embeddings into a single batch
337
- # to avoid doing two forward passes
338
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
339
-
340
- # get the initial random noise unless the user supplied it
341
- # Unlike in other pipelines, latents need to be generated in the target device
342
- # for 1-to-1 results reproducibility with the CompVis implementation.
343
- # However this currently doesn't work in `mps`.
344
- num_channels_latents = self.vae.config.latent_channels
345
- latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
346
- latents_dtype = text_embeddings.dtype
347
- if latents is None:
348
- if self.device.type == "mps":
349
- # randn does not exist on mps
350
- latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
351
- self.device
352
- )
353
- else:
354
- latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
355
- else:
356
- if latents.shape != latents_shape:
357
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
358
- latents = latents.to(self.device)
359
-
360
- # overlay the inner image
361
- image = overlay_inner_image(image, inner_image)
362
-
363
- # prepare mask and masked_image
364
- mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
365
- mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
366
- masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
367
-
368
- # resize the mask to latents shape as we concatenate the mask to the latents
369
- mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
370
-
371
- # encode the mask image into latents space so we can concatenate it to the latents
372
- masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
373
- masked_image_latents = 0.18215 * masked_image_latents
374
-
375
- # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
376
- mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
377
- masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
378
-
379
- mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
380
- masked_image_latents = (
381
- torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
382
- )
383
-
384
- num_channels_mask = mask.shape[1]
385
- num_channels_masked_image = masked_image_latents.shape[1]
386
-
387
- if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
388
- raise ValueError(
389
- f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
390
- f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
391
- f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
392
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
393
- " `pipeline.unet` or your `mask_image` or `image` input."
394
- )
395
-
396
- # set timesteps
397
- self.scheduler.set_timesteps(num_inference_steps)
398
-
399
- # Some schedulers like PNDM have timesteps as arrays
400
- # It's more optimized to move all timesteps to correct device beforehand
401
- timesteps_tensor = self.scheduler.timesteps.to(self.device)
402
-
403
- # scale the initial noise by the standard deviation required by the scheduler
404
- latents = latents * self.scheduler.init_noise_sigma
405
-
406
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
407
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
408
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
409
- # and should be between [0, 1]
410
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
411
- extra_step_kwargs = {}
412
- if accepts_eta:
413
- extra_step_kwargs["eta"] = eta
414
-
415
- for i, t in enumerate(self.progress_bar(timesteps_tensor)):
416
- # expand the latents if we are doing classifier free guidance
417
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
418
-
419
- # concat latents, mask, masked_image_latents in the channel dimension
420
- latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
421
-
422
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
423
-
424
- # predict the noise residual
425
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
426
-
427
- # perform guidance
428
- if do_classifier_free_guidance:
429
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
430
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
431
-
432
- # compute the previous noisy sample x_t -> x_t-1
433
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
434
-
435
- # call the callback, if provided
436
- if callback is not None and i % callback_steps == 0:
437
- callback(i, t, latents)
438
-
439
- latents = 1 / 0.18215 * latents
440
- image = self.vae.decode(latents).sample
441
-
442
- image = (image / 2 + 0.5).clamp(0, 1)
443
-
444
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
445
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
446
-
447
- if self.safety_checker is not None:
448
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
449
- self.device
450
- )
451
- image, has_nsfw_concept = self.safety_checker(
452
- images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
453
- )
454
- else:
455
- has_nsfw_concept = None
456
-
457
- if output_type == "pil":
458
- image = self.numpy_to_pil(image)
459
-
460
- if not return_dict:
461
- return (image, has_nsfw_concept)
462
-
463
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
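
A minimal usage sketch for the removed pipeline, assuming it is loaded as `custom_pipeline="img2img_inpainting"` on top of an inpainting checkpoint with a 9-channel UNet (e.g. `runwayml/stable-diffusion-inpainting`); all file paths and the prompt are placeholders.

```python
import torch
from PIL import Image
from diffusers import DiffusionPipeline

# Placeholder inputs: base image, an RGBA overlay that fits inside the masked region, and the mask.
init_image = Image.open("image.png").convert("RGB").resize((512, 512))
inner_image = Image.open("inner_image.png").convert("RGBA").resize((512, 512))
mask_image = Image.open("mask.png").convert("L").resize((512, 512))

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",  # assumed 4 + 1 + 4 = 9 input-channel inpainting UNet
    custom_pipeline="img2img_inpainting",
    torch_dtype=torch.float16,
).to("cuda")

result = pipe(
    prompt="a mecha robot sitting on a bench",  # illustrative prompt
    image=init_image,
    inner_image=inner_image,
    mask_image=mask_image,
).images[0]
result.save("inpainted.png")
```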
 
diffusers/examples/community/interpolate_stable_diffusion.py DELETED
@@ -1,524 +0,0 @@
1
- import inspect
2
- import time
3
- from pathlib import Path
4
- from typing import Callable, List, Optional, Union
5
-
6
- import numpy as np
7
- import torch
8
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
9
-
10
- from diffusers import DiffusionPipeline
11
- from diffusers.configuration_utils import FrozenDict
12
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
13
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
14
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
15
- from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
16
- from diffusers.utils import deprecate, logging
17
-
18
-
19
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20
-
21
-
22
- def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
23
- """helper function to spherically interpolate two arrays v1 v2"""
24
-
25
- if not isinstance(v0, np.ndarray):
26
- inputs_are_torch = True
27
- input_device = v0.device
28
- v0 = v0.cpu().numpy()
29
- v1 = v1.cpu().numpy()
30
-
31
- dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
32
- if np.abs(dot) > DOT_THRESHOLD:
33
- v2 = (1 - t) * v0 + t * v1
34
- else:
35
- theta_0 = np.arccos(dot)
36
- sin_theta_0 = np.sin(theta_0)
37
- theta_t = theta_0 * t
38
- sin_theta_t = np.sin(theta_t)
39
- s0 = np.sin(theta_0 - theta_t) / sin_theta_0
40
- s1 = sin_theta_t / sin_theta_0
41
- v2 = s0 * v0 + s1 * v1
42
-
43
- if inputs_are_torch:
44
- v2 = torch.from_numpy(v2).to(input_device)
45
-
46
- return v2
47
-
48
-
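
A quick sketch of what the `slerp` helper above does on two random latent tensors (illustrative only; it assumes `slerp` from this module is in scope): `t=0` recovers the first tensor, `t=1` recovers the second, and intermediate values walk along the great circle between them.

```python
import torch

generator = torch.Generator().manual_seed(0)
v0 = torch.randn(1, 4, 64, 64, generator=generator)  # e.g. latents for prompt A
v1 = torch.randn(1, 4, 64, 64, generator=generator)  # e.g. latents for prompt B

# Sweep the interpolation parameter as the walk pipeline would between keyframes.
frames = [slerp(t, v0, v1) for t in torch.linspace(0.0, 1.0, 5).tolist()]

assert torch.allclose(frames[0], v0) and torch.allclose(frames[-1], v1)
print([tuple(f.shape) for f in frames])  # five (1, 4, 64, 64) tensors
```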
49
- class StableDiffusionWalkPipeline(DiffusionPipeline):
50
- r"""
51
- Pipeline for text-to-image generation using Stable Diffusion.
52
-
53
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
54
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
55
-
56
- Args:
57
- vae ([`AutoencoderKL`]):
58
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
59
- text_encoder ([`CLIPTextModel`]):
60
- Frozen text-encoder. Stable Diffusion uses the text portion of
61
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
62
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
63
- tokenizer (`CLIPTokenizer`):
64
- Tokenizer of class
65
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
66
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
67
- scheduler ([`SchedulerMixin`]):
68
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
69
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
70
- safety_checker ([`StableDiffusionSafetyChecker`]):
71
- Classification module that estimates whether generated images could be considered offensive or harmful.
72
- Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
73
- feature_extractor ([`CLIPImageProcessor`]):
74
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
75
- """
76
-
77
- def __init__(
78
- self,
79
- vae: AutoencoderKL,
80
- text_encoder: CLIPTextModel,
81
- tokenizer: CLIPTokenizer,
82
- unet: UNet2DConditionModel,
83
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
84
- safety_checker: StableDiffusionSafetyChecker,
85
- feature_extractor: CLIPImageProcessor,
86
- ):
87
- super().__init__()
88
-
89
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
90
- deprecation_message = (
91
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
92
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
93
- "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
94
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
95
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
96
- " file"
97
- )
98
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
99
- new_config = dict(scheduler.config)
100
- new_config["steps_offset"] = 1
101
- scheduler._internal_dict = FrozenDict(new_config)
102
-
103
- if safety_checker is None:
104
- logger.warning(
105
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
106
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
107
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
108
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
109
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
110
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
111
- )
112
-
113
- self.register_modules(
114
- vae=vae,
115
- text_encoder=text_encoder,
116
- tokenizer=tokenizer,
117
- unet=unet,
118
- scheduler=scheduler,
119
- safety_checker=safety_checker,
120
- feature_extractor=feature_extractor,
121
- )
122
-
123
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
124
- r"""
125
- Enable sliced attention computation.
126
-
127
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
128
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
129
-
130
- Args:
131
- slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
132
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
133
- a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
134
- `attention_head_dim` must be a multiple of `slice_size`.
135
- """
136
- if slice_size == "auto":
137
- # half the attention head size is usually a good trade-off between
138
- # speed and memory
139
- slice_size = self.unet.config.attention_head_dim // 2
140
- self.unet.set_attention_slice(slice_size)
141
-
142
- def disable_attention_slicing(self):
143
- r"""
144
- Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
145
- back to computing attention in one step.
146
- """
147
- # set slice_size = `None` to disable `attention slicing`
148
- self.enable_attention_slicing(None)
149
-
150
- @torch.no_grad()
151
- def __call__(
152
- self,
153
- prompt: Optional[Union[str, List[str]]] = None,
154
- height: int = 512,
155
- width: int = 512,
156
- num_inference_steps: int = 50,
157
- guidance_scale: float = 7.5,
158
- negative_prompt: Optional[Union[str, List[str]]] = None,
159
- num_images_per_prompt: Optional[int] = 1,
160
- eta: float = 0.0,
161
- generator: Optional[torch.Generator] = None,
162
- latents: Optional[torch.FloatTensor] = None,
163
- output_type: Optional[str] = "pil",
164
- return_dict: bool = True,
165
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
166
- callback_steps: int = 1,
167
- text_embeddings: Optional[torch.FloatTensor] = None,
168
- **kwargs,
169
- ):
170
- r"""
171
- Function invoked when calling the pipeline for generation.
172
-
173
- Args:
174
- prompt (`str` or `List[str]`, *optional*, defaults to `None`):
175
- The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
176
- height (`int`, *optional*, defaults to 512):
177
- The height in pixels of the generated image.
178
- width (`int`, *optional*, defaults to 512):
179
- The width in pixels of the generated image.
180
- num_inference_steps (`int`, *optional*, defaults to 50):
181
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
182
- expense of slower inference.
183
- guidance_scale (`float`, *optional*, defaults to 7.5):
184
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
185
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
186
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
187
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
188
- usually at the expense of lower image quality.
189
- negative_prompt (`str` or `List[str]`, *optional*):
190
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
191
- if `guidance_scale` is less than `1`).
192
- num_images_per_prompt (`int`, *optional*, defaults to 1):
193
- The number of images to generate per prompt.
194
- eta (`float`, *optional*, defaults to 0.0):
195
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
196
- [`schedulers.DDIMScheduler`], will be ignored for others.
197
- generator (`torch.Generator`, *optional*):
198
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
199
- deterministic.
200
- latents (`torch.FloatTensor`, *optional*):
201
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
202
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
203
- tensor will ge generated by sampling using the supplied random `generator`.
204
- output_type (`str`, *optional*, defaults to `"pil"`):
205
- The output format of the generate image. Choose between
206
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
207
- return_dict (`bool`, *optional*, defaults to `True`):
208
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
209
- plain tuple.
210
- callback (`Callable`, *optional*):
211
- A function that will be called every `callback_steps` steps during inference. The function will be
212
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
213
- callback_steps (`int`, *optional*, defaults to 1):
214
- The frequency at which the `callback` function will be called. If not specified, the callback will be
215
- called at every step.
216
- text_embeddings (`torch.FloatTensor`, *optional*, defaults to `None`):
217
- Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
218
- `prompt` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
219
- the supplied `prompt`.
220
-
221
- Returns:
222
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
223
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
224
- When returning a tuple, the first element is a list with the generated images, and the second element is a
225
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
226
- (nsfw) content, according to the `safety_checker`.
227
- """
228
-
229
- if height % 8 != 0 or width % 8 != 0:
230
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
231
-
232
- if (callback_steps is None) or (
233
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
234
- ):
235
- raise ValueError(
236
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
237
- f" {type(callback_steps)}."
238
- )
239
-
240
- if text_embeddings is None:
241
- if isinstance(prompt, str):
242
- batch_size = 1
243
- elif isinstance(prompt, list):
244
- batch_size = len(prompt)
245
- else:
246
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
247
-
248
- # get prompt text embeddings
249
- text_inputs = self.tokenizer(
250
- prompt,
251
- padding="max_length",
252
- max_length=self.tokenizer.model_max_length,
253
- return_tensors="pt",
254
- )
255
- text_input_ids = text_inputs.input_ids
256
-
257
- if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
258
- removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
259
- print(
260
- "The following part of your input was truncated because CLIP can only handle sequences up to"
261
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
262
- )
263
- text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
264
- text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
265
- else:
266
- batch_size = text_embeddings.shape[0]
267
-
268
- # duplicate text embeddings for each generation per prompt, using mps friendly method
269
- bs_embed, seq_len, _ = text_embeddings.shape
270
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
271
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
272
-
273
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
274
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
275
- # corresponds to doing no classifier free guidance.
276
- do_classifier_free_guidance = guidance_scale > 1.0
277
- # get unconditional embeddings for classifier free guidance
278
- if do_classifier_free_guidance:
279
- uncond_tokens: List[str]
280
- if negative_prompt is None:
281
- uncond_tokens = [""] * batch_size
282
- elif type(prompt) is not type(negative_prompt):
283
- raise TypeError(
284
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
285
- f" {type(prompt)}."
286
- )
287
- elif isinstance(negative_prompt, str):
288
- uncond_tokens = [negative_prompt]
289
- elif batch_size != len(negative_prompt):
290
- raise ValueError(
291
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
292
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
293
- " the batch size of `prompt`."
294
- )
295
- else:
296
- uncond_tokens = negative_prompt
297
-
298
- max_length = self.tokenizer.model_max_length
299
- uncond_input = self.tokenizer(
300
- uncond_tokens,
301
- padding="max_length",
302
- max_length=max_length,
303
- truncation=True,
304
- return_tensors="pt",
305
- )
306
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
307
-
308
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
309
- seq_len = uncond_embeddings.shape[1]
310
- uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
311
- uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
312
-
313
- # For classifier free guidance, we need to do two forward passes.
314
- # Here we concatenate the unconditional and text embeddings into a single batch
315
- # to avoid doing two forward passes
316
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
317
-
318
- # get the initial random noise unless the user supplied it
319
-
320
- # Unlike in other pipelines, latents need to be generated in the target device
321
- # for 1-to-1 results reproducibility with the CompVis implementation.
322
- # However this currently doesn't work in `mps`.
323
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
324
- latents_dtype = text_embeddings.dtype
325
- if latents is None:
326
- if self.device.type == "mps":
327
- # randn does not work reproducibly on mps
328
- latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
329
- self.device
330
- )
331
- else:
332
- latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
333
- else:
334
- if latents.shape != latents_shape:
335
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
336
- latents = latents.to(self.device)
337
-
338
- # set timesteps
339
- self.scheduler.set_timesteps(num_inference_steps)
340
-
341
- # Some schedulers like PNDM have timesteps as arrays
342
- # It's more optimized to move all timesteps to correct device beforehand
343
- timesteps_tensor = self.scheduler.timesteps.to(self.device)
344
-
345
- # scale the initial noise by the standard deviation required by the scheduler
346
- latents = latents * self.scheduler.init_noise_sigma
347
-
348
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
349
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
350
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
351
- # and should be between [0, 1]
352
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
353
- extra_step_kwargs = {}
354
- if accepts_eta:
355
- extra_step_kwargs["eta"] = eta
356
-
357
- for i, t in enumerate(self.progress_bar(timesteps_tensor)):
358
- # expand the latents if we are doing classifier free guidance
359
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
360
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
361
-
362
- # predict the noise residual
363
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
364
-
365
- # perform guidance
366
- if do_classifier_free_guidance:
367
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
368
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
369
-
370
- # compute the previous noisy sample x_t -> x_t-1
371
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
372
-
373
- # call the callback, if provided
374
- if callback is not None and i % callback_steps == 0:
375
- callback(i, t, latents)
376
-
377
- latents = 1 / 0.18215 * latents
378
- image = self.vae.decode(latents).sample
379
-
380
- image = (image / 2 + 0.5).clamp(0, 1)
381
-
382
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
383
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
384
-
385
- if self.safety_checker is not None:
386
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
387
- self.device
388
- )
389
- image, has_nsfw_concept = self.safety_checker(
390
- images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
391
- )
392
- else:
393
- has_nsfw_concept = None
394
-
395
- if output_type == "pil":
396
- image = self.numpy_to_pil(image)
397
-
398
- if not return_dict:
399
- return (image, has_nsfw_concept)
400
-
401
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
402
-
403
- def embed_text(self, text):
404
- """takes in text and turns it into text embeddings"""
405
- text_input = self.tokenizer(
406
- text,
407
- padding="max_length",
408
- max_length=self.tokenizer.model_max_length,
409
- truncation=True,
410
- return_tensors="pt",
411
- )
412
- with torch.no_grad():
413
- embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
414
- return embed
415
-
416
- def get_noise(self, seed, dtype=torch.float32, height=512, width=512):
417
- """Takes in random seed and returns corresponding noise vector"""
418
- return torch.randn(
419
- (1, self.unet.in_channels, height // 8, width // 8),
420
- generator=torch.Generator(device=self.device).manual_seed(seed),
421
- device=self.device,
422
- dtype=dtype,
423
- )
424
-
425
- def walk(
426
- self,
427
- prompts: List[str],
428
- seeds: List[int],
429
- num_interpolation_steps: Optional[int] = 6,
430
- output_dir: Optional[str] = "./dreams",
431
- name: Optional[str] = None,
432
- batch_size: Optional[int] = 1,
433
- height: Optional[int] = 512,
434
- width: Optional[int] = 512,
435
- guidance_scale: Optional[float] = 7.5,
436
- num_inference_steps: Optional[int] = 50,
437
- eta: Optional[float] = 0.0,
438
- ) -> List[str]:
439
- """
440
- Walks through a series of prompts and seeds, interpolating between them and saving the results to disk.
441
-
442
- Args:
443
- prompts (`List[str]`):
444
- List of prompts to generate images for.
445
- seeds (`List[int]`):
446
- List of seeds corresponding to provided prompts. Must be the same length as prompts.
447
- num_interpolation_steps (`int`, *optional*, defaults to 6):
448
- Number of interpolation steps to take between prompts.
449
- output_dir (`str`, *optional*, defaults to `./dreams`):
450
- Directory to save the generated images to.
451
- name (`str`, *optional*, defaults to `None`):
452
- Subdirectory of `output_dir` to save the generated images to. If `None`, the name will
453
- be the current time.
454
- batch_size (`int`, *optional*, defaults to 1):
455
- Number of images to generate at once.
456
- height (`int`, *optional*, defaults to 512):
457
- Height of the generated images.
458
- width (`int`, *optional*, defaults to 512):
459
- Width of the generated images.
460
- guidance_scale (`float`, *optional*, defaults to 7.5):
461
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
462
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
463
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
464
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
465
- usually at the expense of lower image quality.
466
- num_inference_steps (`int`, *optional*, defaults to 50):
467
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
468
- expense of slower inference.
469
- eta (`float`, *optional*, defaults to 0.0):
470
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
471
- [`schedulers.DDIMScheduler`], will be ignored for others.
472
-
473
- Returns:
474
- `List[str]`: List of paths to the generated images.
475
- """
476
- if len(prompts) != len(seeds):
477
- raise ValueError(
478
- f"Number of prompts and seeds must be equalGot {len(prompts)} prompts and {len(seeds)} seeds"
479
- )
480
-
481
- name = name or time.strftime("%Y%m%d-%H%M%S")
482
- save_path = Path(output_dir) / name
483
- save_path.mkdir(exist_ok=True, parents=True)
484
-
485
- frame_idx = 0
486
- frame_filepaths = []
487
- for prompt_a, prompt_b, seed_a, seed_b in zip(prompts, prompts[1:], seeds, seeds[1:]):
488
- # Embed Text
489
- embed_a = self.embed_text(prompt_a)
490
- embed_b = self.embed_text(prompt_b)
491
-
492
- # Get Noise
493
- noise_dtype = embed_a.dtype
494
- noise_a = self.get_noise(seed_a, noise_dtype, height, width)
495
- noise_b = self.get_noise(seed_b, noise_dtype, height, width)
496
-
497
- noise_batch, embeds_batch = None, None
498
- T = np.linspace(0.0, 1.0, num_interpolation_steps)
499
- for i, t in enumerate(T):
500
- noise = slerp(float(t), noise_a, noise_b)
501
- embed = torch.lerp(embed_a, embed_b, t)
502
-
503
- noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise], dim=0)
504
- embeds_batch = embed if embeds_batch is None else torch.cat([embeds_batch, embed], dim=0)
505
-
506
- batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == T.shape[0]
507
- if batch_is_ready:
508
- outputs = self(
509
- latents=noise_batch,
510
- text_embeddings=embeds_batch,
511
- height=height,
512
- width=width,
513
- guidance_scale=guidance_scale,
514
- eta=eta,
515
- num_inference_steps=num_inference_steps,
516
- )
517
- noise_batch, embeds_batch = None, None
518
-
519
- for image in outputs["images"]:
520
- frame_filepath = str(save_path / f"frame_{frame_idx:06d}.png")
521
- image.save(frame_filepath)
522
- frame_filepaths.append(frame_filepath)
523
- frame_idx += 1
524
- return frame_filepaths
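A minimal usage sketch for the walk() helper above. It assumes the community-pipeline loading convention (custom_pipeline matching this file's name) and a CUDA device; the model id and prompts are only illustrative.

import torch
from diffusers import DiffusionPipeline

# Load the interpolation pipeline defined above (name assumed to match the file).
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="interpolate_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# Interpolate between two prompt/seed pairs; frames are written under ./dreams/<name>/.
frame_paths = pipe.walk(
    prompts=["a photo of a corgi", "a photo of a lion"],
    seeds=[42, 1337],
    num_interpolation_steps=8,
    guidance_scale=7.5,
    num_inference_steps=50,
)
print(frame_paths[0])  # e.g. dreams/<timestamp>/frame_000000.png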
 
diffusers/examples/community/lpw_stable_diffusion.py DELETED
@@ -1,1153 +0,0 @@
1
- import inspect
2
- import re
3
- from typing import Callable, List, Optional, Union
4
-
5
- import numpy as np
6
- import PIL
7
- import torch
8
- from packaging import version
9
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
10
-
11
- import diffusers
12
- from diffusers import SchedulerMixin, StableDiffusionPipeline
13
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
14
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
15
- from diffusers.utils import logging
16
-
17
-
18
- try:
19
- from diffusers.utils import PIL_INTERPOLATION
20
- except ImportError:
21
- if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
22
- PIL_INTERPOLATION = {
23
- "linear": PIL.Image.Resampling.BILINEAR,
24
- "bilinear": PIL.Image.Resampling.BILINEAR,
25
- "bicubic": PIL.Image.Resampling.BICUBIC,
26
- "lanczos": PIL.Image.Resampling.LANCZOS,
27
- "nearest": PIL.Image.Resampling.NEAREST,
28
- }
29
- else:
30
- PIL_INTERPOLATION = {
31
- "linear": PIL.Image.LINEAR,
32
- "bilinear": PIL.Image.BILINEAR,
33
- "bicubic": PIL.Image.BICUBIC,
34
- "lanczos": PIL.Image.LANCZOS,
35
- "nearest": PIL.Image.NEAREST,
36
- }
37
- # ------------------------------------------------------------------------------
38
-
39
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
-
41
- re_attention = re.compile(
42
- r"""
43
- \\\(|
44
- \\\)|
45
- \\\[|
46
- \\]|
47
- \\\\|
48
- \\|
49
- \(|
50
- \[|
51
- :([+-]?[.\d]+)\)|
52
- \)|
53
- ]|
54
- [^\\()\[\]:]+|
55
- :
56
- """,
57
- re.X,
58
- )
59
-
60
-
61
- def parse_prompt_attention(text):
62
- """
63
- Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
64
- Accepted tokens are:
65
- (abc) - increases attention to abc by a multiplier of 1.1
66
- (abc:3.12) - increases attention to abc by a multiplier of 3.12
67
- [abc] - decreases attention to abc by a multiplier of 1.1
68
- \( - literal character '('
69
- \[ - literal character '['
70
- \) - literal character ')'
71
- \] - literal character ']'
72
- \\ - literal character '\'
73
- anything else - just text
74
- >>> parse_prompt_attention('normal text')
75
- [['normal text', 1.0]]
76
- >>> parse_prompt_attention('an (important) word')
77
- [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
78
- >>> parse_prompt_attention('(unbalanced')
79
- [['unbalanced', 1.1]]
80
- >>> parse_prompt_attention('\(literal\]')
81
- [['(literal]', 1.0]]
82
- >>> parse_prompt_attention('(unnecessary)(parens)')
83
- [['unnecessaryparens', 1.1]]
84
- >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
85
- [['a ', 1.0],
86
- ['house', 1.5730000000000004],
87
- [' ', 1.1],
88
- ['on', 1.0],
89
- [' a ', 1.1],
90
- ['hill', 0.55],
91
- [', sun, ', 1.1],
92
- ['sky', 1.4641000000000006],
93
- ['.', 1.1]]
94
- """
95
-
96
- res = []
97
- round_brackets = []
98
- square_brackets = []
99
-
100
- round_bracket_multiplier = 1.1
101
- square_bracket_multiplier = 1 / 1.1
102
-
103
- def multiply_range(start_position, multiplier):
104
- for p in range(start_position, len(res)):
105
- res[p][1] *= multiplier
106
-
107
- for m in re_attention.finditer(text):
108
- text = m.group(0)
109
- weight = m.group(1)
110
-
111
- if text.startswith("\\"):
112
- res.append([text[1:], 1.0])
113
- elif text == "(":
114
- round_brackets.append(len(res))
115
- elif text == "[":
116
- square_brackets.append(len(res))
117
- elif weight is not None and len(round_brackets) > 0:
118
- multiply_range(round_brackets.pop(), float(weight))
119
- elif text == ")" and len(round_brackets) > 0:
120
- multiply_range(round_brackets.pop(), round_bracket_multiplier)
121
- elif text == "]" and len(square_brackets) > 0:
122
- multiply_range(square_brackets.pop(), square_bracket_multiplier)
123
- else:
124
- res.append([text, 1.0])
125
-
126
- for pos in round_brackets:
127
- multiply_range(pos, round_bracket_multiplier)
128
-
129
- for pos in square_brackets:
130
- multiply_range(pos, square_bracket_multiplier)
131
-
132
- if len(res) == 0:
133
- res = [["", 1.0]]
134
-
135
- # merge runs of identical weights
136
- i = 0
137
- while i + 1 < len(res):
138
- if res[i][1] == res[i + 1][1]:
139
- res[i][0] += res[i + 1][0]
140
- res.pop(i + 1)
141
- else:
142
- i += 1
143
-
144
- return res
145
-
146
-
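A quick illustration of the attention parsing above (runs with just this function in scope; the prompt is an arbitrary example).

# Round brackets multiply the weight by 1.1 (or by an explicit ":<number>"),
# square brackets divide it by 1.1; plain text keeps weight 1.0.
print(parse_prompt_attention("a (red:1.2) ball on [grass]"))
# [['a ', 1.0], ['red', 1.2], [' ball on ', 1.0], ['grass', 0.9090909090909091]]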
147
- def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
148
- r"""
149
- Tokenize a list of prompts and return its tokens with weights of each token.
150
-
151
- No padding, starting or ending token is included.
152
- """
153
- tokens = []
154
- weights = []
155
- truncated = False
156
- for text in prompt:
157
- texts_and_weights = parse_prompt_attention(text)
158
- text_token = []
159
- text_weight = []
160
- for word, weight in texts_and_weights:
161
- # tokenize and discard the starting and the ending token
162
- token = pipe.tokenizer(word).input_ids[1:-1]
163
- text_token += token
164
- # copy the weight by length of token
165
- text_weight += [weight] * len(token)
166
- # stop if the text is too long (longer than truncation limit)
167
- if len(text_token) > max_length:
168
- truncated = True
169
- break
170
- # truncate
171
- if len(text_token) > max_length:
172
- truncated = True
173
- text_token = text_token[:max_length]
174
- text_weight = text_weight[:max_length]
175
- tokens.append(text_token)
176
- weights.append(text_weight)
177
- if truncated:
178
- logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
179
- return tokens, weights
180
-
181
-
182
- def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
183
- r"""
184
- Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
185
- """
186
- max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
187
- weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
188
- for i in range(len(tokens)):
189
- tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
190
- if no_boseos_middle:
191
- weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
192
- else:
193
- w = []
194
- if len(weights[i]) == 0:
195
- w = [1.0] * weights_length
196
- else:
197
- for j in range(max_embeddings_multiples):
198
- w.append(1.0) # weight for starting token in this chunk
199
- w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
200
- w.append(1.0) # weight for ending token in this chunk
201
- w += [1.0] * (weights_length - len(w))
202
- weights[i] = w[:]
203
-
204
- return tokens, weights
205
-
206
-
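A shape sketch for the padding helper above, using CLIP's special-token ids (49406/49407; pad equalling eos is an assumption carried over from the Stable Diffusion tokenizer).

bos, eos = 49406, 49407
pad = eos  # the SD CLIP tokenizer pads with the end-of-text token
tokens = [list(range(1000, 1090))]  # 90 dummy prompt token ids
weights = [[1.0] * 90]
# With chunk_length=77 the padded length is 2 * 75 + 2 = 152.
tokens, weights = pad_tokens_and_weights(
    tokens, weights, max_length=152, bos=bos, eos=eos, pad=pad,
    no_boseos_middle=True, chunk_length=77,
)
print(len(tokens[0]), len(weights[0]))  # 152 152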
207
- def get_unweighted_text_embeddings(
208
- pipe: StableDiffusionPipeline,
209
- text_input: torch.Tensor,
210
- chunk_length: int,
211
- no_boseos_middle: Optional[bool] = True,
212
- ):
213
- """
214
- When the length of tokens exceeds the capacity of the text encoder, the input
215
- is split into chunks and each chunk is sent to the text encoder individually.
216
- """
217
- max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
218
- if max_embeddings_multiples > 1:
219
- text_embeddings = []
220
- for i in range(max_embeddings_multiples):
221
- # extract the i-th chunk
222
- text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
223
-
224
- # cover the head and the tail of each chunk with the starting and the ending tokens
225
- text_input_chunk[:, 0] = text_input[0, 0]
226
- text_input_chunk[:, -1] = text_input[0, -1]
227
- text_embedding = pipe.text_encoder(text_input_chunk)[0]
228
-
229
- if no_boseos_middle:
230
- if i == 0:
231
- # discard the ending token
232
- text_embedding = text_embedding[:, :-1]
233
- elif i == max_embeddings_multiples - 1:
234
- # discard the starting token
235
- text_embedding = text_embedding[:, 1:]
236
- else:
237
- # discard both starting and ending tokens
238
- text_embedding = text_embedding[:, 1:-1]
239
-
240
- text_embeddings.append(text_embedding)
241
- text_embeddings = torch.concat(text_embeddings, axis=1)
242
- else:
243
- text_embeddings = pipe.text_encoder(text_input)[0]
244
- return text_embeddings
245
-
246
-
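A small arithmetic sketch of how the chunked encoding above scales (plain Python, illustrative values).

# With CLIP's 77-token window, each chunk contributes 75 "payload" tokens because
# bos/eos are re-attached per chunk, so a prompt padded to 152 tokens needs
# (152 - 2) // (77 - 2) = 2 passes through the text encoder.
chunk_length, padded_length = 77, 152
print((padded_length - 2) // (chunk_length - 2))  # 2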
247
- def get_weighted_text_embeddings(
248
- pipe: StableDiffusionPipeline,
249
- prompt: Union[str, List[str]],
250
- uncond_prompt: Optional[Union[str, List[str]]] = None,
251
- max_embeddings_multiples: Optional[int] = 3,
252
- no_boseos_middle: Optional[bool] = False,
253
- skip_parsing: Optional[bool] = False,
254
- skip_weighting: Optional[bool] = False,
255
- ):
256
- r"""
257
- Prompts can be assigned local weights using brackets. For example, the
258
- prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
259
- and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
260
-
261
- Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
262
-
263
- Args:
264
- pipe (`StableDiffusionPipeline`):
265
- Pipe to provide access to the tokenizer and the text encoder.
266
- prompt (`str` or `List[str]`):
267
- The prompt or prompts to guide the image generation.
268
- uncond_prompt (`str` or `List[str]`):
269
- The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
270
- is provided, the embeddings of prompt and uncond_prompt are concatenated.
271
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
272
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
273
- no_boseos_middle (`bool`, *optional*, defaults to `False`):
274
- If the length of the text tokens exceeds the capacity of the text encoder, whether to reserve the starting and
275
- ending tokens in each of the chunks in the middle.
276
- skip_parsing (`bool`, *optional*, defaults to `False`):
277
- Skip the parsing of brackets.
278
- skip_weighting (`bool`, *optional*, defaults to `False`):
279
- Skip the weighting. When the parsing is skipped, it is forced True.
280
- """
281
- max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
282
- if isinstance(prompt, str):
283
- prompt = [prompt]
284
-
285
- if not skip_parsing:
286
- prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
287
- if uncond_prompt is not None:
288
- if isinstance(uncond_prompt, str):
289
- uncond_prompt = [uncond_prompt]
290
- uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
291
- else:
292
- prompt_tokens = [
293
- token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
294
- ]
295
- prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
296
- if uncond_prompt is not None:
297
- if isinstance(uncond_prompt, str):
298
- uncond_prompt = [uncond_prompt]
299
- uncond_tokens = [
300
- token[1:-1]
301
- for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
302
- ]
303
- uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
304
-
305
- # round up the longest length of tokens to a multiple of (model_max_length - 2)
306
- max_length = max([len(token) for token in prompt_tokens])
307
- if uncond_prompt is not None:
308
- max_length = max(max_length, max([len(token) for token in uncond_tokens]))
309
-
310
- max_embeddings_multiples = min(
311
- max_embeddings_multiples,
312
- (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
313
- )
314
- max_embeddings_multiples = max(1, max_embeddings_multiples)
315
- max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
316
-
317
- # pad the length of tokens and weights
318
- bos = pipe.tokenizer.bos_token_id
319
- eos = pipe.tokenizer.eos_token_id
320
- pad = getattr(pipe.tokenizer, "pad_token_id", eos)
321
- prompt_tokens, prompt_weights = pad_tokens_and_weights(
322
- prompt_tokens,
323
- prompt_weights,
324
- max_length,
325
- bos,
326
- eos,
327
- pad,
328
- no_boseos_middle=no_boseos_middle,
329
- chunk_length=pipe.tokenizer.model_max_length,
330
- )
331
- prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
332
- if uncond_prompt is not None:
333
- uncond_tokens, uncond_weights = pad_tokens_and_weights(
334
- uncond_tokens,
335
- uncond_weights,
336
- max_length,
337
- bos,
338
- eos,
339
- pad,
340
- no_boseos_middle=no_boseos_middle,
341
- chunk_length=pipe.tokenizer.model_max_length,
342
- )
343
- uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
344
-
345
- # get the embeddings
346
- text_embeddings = get_unweighted_text_embeddings(
347
- pipe,
348
- prompt_tokens,
349
- pipe.tokenizer.model_max_length,
350
- no_boseos_middle=no_boseos_middle,
351
- )
352
- prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
353
- if uncond_prompt is not None:
354
- uncond_embeddings = get_unweighted_text_embeddings(
355
- pipe,
356
- uncond_tokens,
357
- pipe.tokenizer.model_max_length,
358
- no_boseos_middle=no_boseos_middle,
359
- )
360
- uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
361
-
362
- # assign weights to the prompts and normalize in the sense of mean
363
- # TODO: should we normalize by chunk or in a whole (current implementation)?
364
- if (not skip_parsing) and (not skip_weighting):
365
- previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
366
- text_embeddings *= prompt_weights.unsqueeze(-1)
367
- current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
368
- text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
369
- if uncond_prompt is not None:
370
- previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
371
- uncond_embeddings *= uncond_weights.unsqueeze(-1)
372
- current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
373
- uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
374
-
375
- if uncond_prompt is not None:
376
- return text_embeddings, uncond_embeddings
377
- return text_embeddings, None
378
-
379
-
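The weighting branch above preserves the per-prompt embedding mean; a tiny self-contained check of that rescaling with toy tensors (no pipeline required).

import torch

emb = torch.arange(24, dtype=torch.float32).reshape(1, 6, 4) / 10.0  # (batch, tokens, dim)
w = torch.tensor([[1.0, 1.2, 1.2, 1.0, 0.9, 1.0]])                   # per-token weights

previous_mean = emb.mean(dim=[-2, -1])
weighted = emb * w.unsqueeze(-1)
current_mean = weighted.mean(dim=[-2, -1])
weighted = weighted * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

# After rescaling, the mean over tokens and channels matches the unweighted mean.
assert torch.allclose(weighted.mean(dim=[-2, -1]), previous_mean)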
380
- def preprocess_image(image):
381
- w, h = image.size
382
- w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
383
- image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
384
- image = np.array(image).astype(np.float32) / 255.0
385
- image = image[None].transpose(0, 3, 1, 2)
386
- image = torch.from_numpy(image)
387
- return 2.0 * image - 1.0
388
-
389
-
390
- def preprocess_mask(mask, scale_factor=8):
391
- mask = mask.convert("L")
392
- w, h = mask.size
393
- w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
394
- mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
395
- mask = np.array(mask).astype(np.float32) / 255.0
396
- mask = np.tile(mask, (4, 1, 1))
397
- mask = mask[None].transpose(0, 1, 2, 3) # add a batch dimension; this identity transpose is a no-op
398
- mask = 1 - mask # repaint white, keep black
399
- mask = torch.from_numpy(mask)
400
- return mask
401
-
402
-
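A shape check for the mask preprocessing above (a plain white mask, which after the `1 - mask` inversion marks the whole image for repainting); assumes preprocess_mask is in scope.

from PIL import Image

white = Image.new("L", (512, 512), color=255)
mask = preprocess_mask(white, scale_factor=8)
print(mask.shape)         # torch.Size([1, 4, 64, 64])
print(float(mask.max()))  # 0.0 -> everything is repainted, nothing is kept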
403
- class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
404
- r"""
405
- Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for parsing
406
- weighting in the prompt.
407
-
408
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
409
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
410
-
411
- Args:
412
- vae ([`AutoencoderKL`]):
413
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
414
- text_encoder ([`CLIPTextModel`]):
415
- Frozen text-encoder. Stable Diffusion uses the text portion of
416
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
417
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
418
- tokenizer (`CLIPTokenizer`):
419
- Tokenizer of class
420
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
421
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
422
- scheduler ([`SchedulerMixin`]):
423
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
424
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
425
- safety_checker ([`StableDiffusionSafetyChecker`]):
426
- Classification module that estimates whether generated images could be considered offensive or harmful.
427
- Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
428
- feature_extractor ([`CLIPImageProcessor`]):
429
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
430
- """
431
-
432
- if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
433
-
434
- def __init__(
435
- self,
436
- vae: AutoencoderKL,
437
- text_encoder: CLIPTextModel,
438
- tokenizer: CLIPTokenizer,
439
- unet: UNet2DConditionModel,
440
- scheduler: SchedulerMixin,
441
- safety_checker: StableDiffusionSafetyChecker,
442
- feature_extractor: CLIPImageProcessor,
443
- requires_safety_checker: bool = True,
444
- ):
445
- super().__init__(
446
- vae=vae,
447
- text_encoder=text_encoder,
448
- tokenizer=tokenizer,
449
- unet=unet,
450
- scheduler=scheduler,
451
- safety_checker=safety_checker,
452
- feature_extractor=feature_extractor,
453
- requires_safety_checker=requires_safety_checker,
454
- )
455
- self.__init__additional__()
456
-
457
- else:
458
-
459
- def __init__(
460
- self,
461
- vae: AutoencoderKL,
462
- text_encoder: CLIPTextModel,
463
- tokenizer: CLIPTokenizer,
464
- unet: UNet2DConditionModel,
465
- scheduler: SchedulerMixin,
466
- safety_checker: StableDiffusionSafetyChecker,
467
- feature_extractor: CLIPImageProcessor,
468
- ):
469
- super().__init__(
470
- vae=vae,
471
- text_encoder=text_encoder,
472
- tokenizer=tokenizer,
473
- unet=unet,
474
- scheduler=scheduler,
475
- safety_checker=safety_checker,
476
- feature_extractor=feature_extractor,
477
- )
478
- self.__init__additional__()
479
-
480
- def __init__additional__(self):
481
- if not hasattr(self, "vae_scale_factor"):
482
- setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
483
-
484
- @property
485
- def _execution_device(self):
486
- r"""
487
- Returns the device on which the pipeline's models will be executed. After calling
488
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
489
- hooks.
490
- """
491
- if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
492
- return self.device
493
- for module in self.unet.modules():
494
- if (
495
- hasattr(module, "_hf_hook")
496
- and hasattr(module._hf_hook, "execution_device")
497
- and module._hf_hook.execution_device is not None
498
- ):
499
- return torch.device(module._hf_hook.execution_device)
500
- return self.device
501
-
502
- def _encode_prompt(
503
- self,
504
- prompt,
505
- device,
506
- num_images_per_prompt,
507
- do_classifier_free_guidance,
508
- negative_prompt,
509
- max_embeddings_multiples,
510
- ):
511
- r"""
512
- Encodes the prompt into text encoder hidden states.
513
-
514
- Args:
515
- prompt (`str` or `list(int)`):
516
- prompt to be encoded
517
- device: (`torch.device`):
518
- torch device
519
- num_images_per_prompt (`int`):
520
- number of images that should be generated per prompt
521
- do_classifier_free_guidance (`bool`):
522
- whether to use classifier free guidance or not
523
- negative_prompt (`str` or `List[str]`):
524
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
525
- if `guidance_scale` is less than `1`).
526
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
527
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
528
- """
529
- batch_size = len(prompt) if isinstance(prompt, list) else 1
530
-
531
- if negative_prompt is None:
532
- negative_prompt = [""] * batch_size
533
- elif isinstance(negative_prompt, str):
534
- negative_prompt = [negative_prompt] * batch_size
535
- if batch_size != len(negative_prompt):
536
- raise ValueError(
537
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
538
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
539
- " the batch size of `prompt`."
540
- )
541
-
542
- text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
543
- pipe=self,
544
- prompt=prompt,
545
- uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
546
- max_embeddings_multiples=max_embeddings_multiples,
547
- )
548
- bs_embed, seq_len, _ = text_embeddings.shape
549
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
550
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
551
-
552
- if do_classifier_free_guidance:
553
- bs_embed, seq_len, _ = uncond_embeddings.shape
554
- uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
555
- uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
556
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
557
-
558
- return text_embeddings
559
-
560
- def check_inputs(self, prompt, height, width, strength, callback_steps):
561
- if not isinstance(prompt, str) and not isinstance(prompt, list):
562
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
563
-
564
- if strength < 0 or strength > 1:
565
- raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
566
-
567
- if height % 8 != 0 or width % 8 != 0:
568
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
569
-
570
- if (callback_steps is None) or (
571
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
572
- ):
573
- raise ValueError(
574
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
575
- f" {type(callback_steps)}."
576
- )
577
-
578
- def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
579
- if is_text2img:
580
- return self.scheduler.timesteps.to(device), num_inference_steps
581
- else:
582
- # get the original timestep using init_timestep
583
- offset = self.scheduler.config.get("steps_offset", 0)
584
- init_timestep = int(num_inference_steps * strength) + offset
585
- init_timestep = min(init_timestep, num_inference_steps)
586
-
587
- t_start = max(num_inference_steps - init_timestep + offset, 0)
588
- timesteps = self.scheduler.timesteps[t_start:].to(device)
589
- return timesteps, num_inference_steps - t_start
590
-
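How `strength` maps onto the schedule in the img2img branch above (plain arithmetic, offset assumed to be 0).

num_inference_steps, strength, offset = 50, 0.8, 0
init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
# Denoising skips the first 10 scheduler steps and runs the remaining 40.
print(init_timestep, t_start)  # 40 10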
591
- def run_safety_checker(self, image, device, dtype):
592
- if self.safety_checker is not None:
593
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
594
- image, has_nsfw_concept = self.safety_checker(
595
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
596
- )
597
- else:
598
- has_nsfw_concept = None
599
- return image, has_nsfw_concept
600
-
601
- def decode_latents(self, latents):
602
- latents = 1 / 0.18215 * latents
603
- image = self.vae.decode(latents).sample
604
- image = (image / 2 + 0.5).clamp(0, 1)
605
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
606
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
607
- return image
608
-
609
- def prepare_extra_step_kwargs(self, generator, eta):
610
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
611
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
612
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
613
- # and should be between [0, 1]
614
-
615
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
616
- extra_step_kwargs = {}
617
- if accepts_eta:
618
- extra_step_kwargs["eta"] = eta
619
-
620
- # check if the scheduler accepts generator
621
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
622
- if accepts_generator:
623
- extra_step_kwargs["generator"] = generator
624
- return extra_step_kwargs
625
-
626
- def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
627
- if image is None:
628
- shape = (
629
- batch_size,
630
- self.unet.in_channels,
631
- height // self.vae_scale_factor,
632
- width // self.vae_scale_factor,
633
- )
634
-
635
- if latents is None:
636
- if device.type == "mps":
637
- # randn does not work reproducibly on mps
638
- latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
639
- else:
640
- latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
641
- else:
642
- if latents.shape != shape:
643
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
644
- latents = latents.to(device)
645
-
646
- # scale the initial noise by the standard deviation required by the scheduler
647
- latents = latents * self.scheduler.init_noise_sigma
648
- return latents, None, None
649
- else:
650
- init_latent_dist = self.vae.encode(image).latent_dist
651
- init_latents = init_latent_dist.sample(generator=generator)
652
- init_latents = 0.18215 * init_latents
653
- init_latents = torch.cat([init_latents] * batch_size, dim=0)
654
- init_latents_orig = init_latents
655
- shape = init_latents.shape
656
-
657
- # add noise to latents using the timesteps
658
- if device.type == "mps":
659
- noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
660
- else:
661
- noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
662
- latents = self.scheduler.add_noise(init_latents, noise, timestep)
663
- return latents, init_latents_orig, noise
664
-
665
- @torch.no_grad()
666
- def __call__(
667
- self,
668
- prompt: Union[str, List[str]],
669
- negative_prompt: Optional[Union[str, List[str]]] = None,
670
- image: Union[torch.FloatTensor, PIL.Image.Image] = None,
671
- mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
672
- height: int = 512,
673
- width: int = 512,
674
- num_inference_steps: int = 50,
675
- guidance_scale: float = 7.5,
676
- strength: float = 0.8,
677
- num_images_per_prompt: Optional[int] = 1,
678
- eta: float = 0.0,
679
- generator: Optional[torch.Generator] = None,
680
- latents: Optional[torch.FloatTensor] = None,
681
- max_embeddings_multiples: Optional[int] = 3,
682
- output_type: Optional[str] = "pil",
683
- return_dict: bool = True,
684
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
685
- is_cancelled_callback: Optional[Callable[[], bool]] = None,
686
- callback_steps: int = 1,
687
- ):
688
- r"""
689
- Function invoked when calling the pipeline for generation.
690
-
691
- Args:
692
- prompt (`str` or `List[str]`):
693
- The prompt or prompts to guide the image generation.
694
- negative_prompt (`str` or `List[str]`, *optional*):
695
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
696
- if `guidance_scale` is less than `1`).
697
- image (`torch.FloatTensor` or `PIL.Image.Image`):
698
- `Image`, or tensor representing an image batch, that will be used as the starting point for the
699
- process.
700
- mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
701
- `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
702
- replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
703
- PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
704
- contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
705
- height (`int`, *optional*, defaults to 512):
706
- The height in pixels of the generated image.
707
- width (`int`, *optional*, defaults to 512):
708
- The width in pixels of the generated image.
709
- num_inference_steps (`int`, *optional*, defaults to 50):
710
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
711
- expense of slower inference.
712
- guidance_scale (`float`, *optional*, defaults to 7.5):
713
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
714
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
715
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
716
- 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
717
- usually at the expense of lower image quality.
718
- strength (`float`, *optional*, defaults to 0.8):
719
- Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
720
- `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
721
- number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
722
- noise will be maximum and the denoising process will run for the full number of iterations specified in
723
- `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
724
- num_images_per_prompt (`int`, *optional*, defaults to 1):
725
- The number of images to generate per prompt.
726
- eta (`float`, *optional*, defaults to 0.0):
727
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
728
- [`schedulers.DDIMScheduler`], will be ignored for others.
729
- generator (`torch.Generator`, *optional*):
730
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
731
- deterministic.
732
- latents (`torch.FloatTensor`, *optional*):
733
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
734
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
735
- tensor will be generated by sampling using the supplied random `generator`.
736
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
737
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
738
- output_type (`str`, *optional*, defaults to `"pil"`):
739
- The output format of the generated image. Choose between
740
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
741
- return_dict (`bool`, *optional*, defaults to `True`):
742
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
743
- plain tuple.
744
- callback (`Callable`, *optional*):
745
- A function that will be called every `callback_steps` steps during inference. The function will be
746
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
747
- is_cancelled_callback (`Callable`, *optional*):
748
- A function that will be called every `callback_steps` steps during inference. If the function returns
749
- `True`, the inference will be cancelled.
750
- callback_steps (`int`, *optional*, defaults to 1):
751
- The frequency at which the `callback` function will be called. If not specified, the callback will be
752
- called at every step.
753
-
754
- Returns:
755
- `None` if cancelled by `is_cancelled_callback`,
756
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
757
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
758
- When returning a tuple, the first element is a list with the generated images, and the second element is a
759
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
760
- (nsfw) content, according to the `safety_checker`.
761
- """
762
- # 0. Default height and width to unet
763
- height = height or self.unet.config.sample_size * self.vae_scale_factor
764
- width = width or self.unet.config.sample_size * self.vae_scale_factor
765
-
766
- # 1. Check inputs. Raise error if not correct
767
- self.check_inputs(prompt, height, width, strength, callback_steps)
768
-
769
- # 2. Define call parameters
770
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
771
- device = self._execution_device
772
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
773
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
774
- # corresponds to doing no classifier free guidance.
775
- do_classifier_free_guidance = guidance_scale > 1.0
776
-
777
- # 3. Encode input prompt
778
- text_embeddings = self._encode_prompt(
779
- prompt,
780
- device,
781
- num_images_per_prompt,
782
- do_classifier_free_guidance,
783
- negative_prompt,
784
- max_embeddings_multiples,
785
- )
786
- dtype = text_embeddings.dtype
787
-
788
- # 4. Preprocess image and mask
789
- if isinstance(image, PIL.Image.Image):
790
- image = preprocess_image(image)
791
- if image is not None:
792
- image = image.to(device=self.device, dtype=dtype)
793
- if isinstance(mask_image, PIL.Image.Image):
794
- mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
795
- if mask_image is not None:
796
- mask = mask_image.to(device=self.device, dtype=dtype)
797
- mask = torch.cat([mask] * batch_size * num_images_per_prompt)
798
- else:
799
- mask = None
800
-
801
- # 5. set timesteps
802
- self.scheduler.set_timesteps(num_inference_steps, device=device)
803
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
804
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
805
-
806
- # 6. Prepare latent variables
807
- latents, init_latents_orig, noise = self.prepare_latents(
808
- image,
809
- latent_timestep,
810
- batch_size * num_images_per_prompt,
811
- height,
812
- width,
813
- dtype,
814
- device,
815
- generator,
816
- latents,
817
- )
818
-
819
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
820
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
821
-
822
- # 8. Denoising loop
823
- for i, t in enumerate(self.progress_bar(timesteps)):
824
- # expand the latents if we are doing classifier free guidance
825
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
826
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
827
-
828
- # predict the noise residual
829
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
830
-
831
- # perform guidance
832
- if do_classifier_free_guidance:
833
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
834
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
835
-
836
- # compute the previous noisy sample x_t -> x_t-1
837
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
838
-
839
- if mask is not None:
840
- # masking
841
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
842
- latents = (init_latents_proper * mask) + (latents * (1 - mask))
843
-
844
- # call the callback, if provided
845
- if i % callback_steps == 0:
846
- if callback is not None:
847
- callback(i, t, latents)
848
- if is_cancelled_callback is not None and is_cancelled_callback():
849
- return None
850
-
851
- # 9. Post-processing
852
- image = self.decode_latents(latents)
853
-
854
- # 10. Run safety checker
855
- image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
856
-
857
- # 11. Convert to PIL
858
- if output_type == "pil":
859
- image = self.numpy_to_pil(image)
860
-
861
- if not return_dict:
862
- return image, has_nsfw_concept
863
-
864
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
865
-
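A minimal usage sketch for the pipeline defined above, assuming it is loaded through the community-pipeline mechanism by file name (custom_pipeline="lpw_stable_diffusion"); the model id and prompts are illustrative.

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# Weighted, longer-than-77-token prompts are handled by the helpers above.
result = pipe(
    prompt="a photo of an (astronaut:1.3) riding a horse on mars, (best quality), highly detailed",
    negative_prompt="lowres, (bad anatomy:1.2), blurry",
    num_inference_steps=30,
    guidance_scale=7.5,
    max_embeddings_multiples=3,
)
result.images[0].save("astronaut.png")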
866
- def text2img(
867
- self,
868
- prompt: Union[str, List[str]],
869
- negative_prompt: Optional[Union[str, List[str]]] = None,
870
- height: int = 512,
871
- width: int = 512,
872
- num_inference_steps: int = 50,
873
- guidance_scale: float = 7.5,
874
- num_images_per_prompt: Optional[int] = 1,
875
- eta: float = 0.0,
876
- generator: Optional[torch.Generator] = None,
877
- latents: Optional[torch.FloatTensor] = None,
878
- max_embeddings_multiples: Optional[int] = 3,
879
- output_type: Optional[str] = "pil",
880
- return_dict: bool = True,
881
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
882
- is_cancelled_callback: Optional[Callable[[], bool]] = None,
883
- callback_steps: int = 1,
884
- ):
885
- r"""
886
- Function for text-to-image generation.
887
- Args:
888
- prompt (`str` or `List[str]`):
889
- The prompt or prompts to guide the image generation.
890
- negative_prompt (`str` or `List[str]`, *optional*):
891
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
892
- if `guidance_scale` is less than `1`).
893
- height (`int`, *optional*, defaults to 512):
894
- The height in pixels of the generated image.
895
- width (`int`, *optional*, defaults to 512):
896
- The width in pixels of the generated image.
897
- num_inference_steps (`int`, *optional*, defaults to 50):
898
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
899
- expense of slower inference.
900
- guidance_scale (`float`, *optional*, defaults to 7.5):
901
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
902
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
903
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
904
- 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
905
- usually at the expense of lower image quality.
906
- num_images_per_prompt (`int`, *optional*, defaults to 1):
907
- The number of images to generate per prompt.
908
- eta (`float`, *optional*, defaults to 0.0):
909
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
910
- [`schedulers.DDIMScheduler`], will be ignored for others.
911
- generator (`torch.Generator`, *optional*):
912
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
913
- deterministic.
914
- latents (`torch.FloatTensor`, *optional*):
915
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
916
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
917
- tensor will be generated by sampling using the supplied random `generator`.
918
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
919
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
920
- output_type (`str`, *optional*, defaults to `"pil"`):
921
- The output format of the generated image. Choose between
922
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
923
- return_dict (`bool`, *optional*, defaults to `True`):
924
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
925
- plain tuple.
926
- callback (`Callable`, *optional*):
927
- A function that will be called every `callback_steps` steps during inference. The function will be
928
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
929
- is_cancelled_callback (`Callable`, *optional*):
930
- A function that will be called every `callback_steps` steps during inference. If the function returns
931
- `True`, the inference will be cancelled.
932
- callback_steps (`int`, *optional*, defaults to 1):
933
- The frequency at which the `callback` function will be called. If not specified, the callback will be
934
- called at every step.
935
- Returns:
936
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
937
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
938
- When returning a tuple, the first element is a list with the generated images, and the second element is a
939
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
940
- (nsfw) content, according to the `safety_checker`.
941
- """
942
- return self.__call__(
943
- prompt=prompt,
944
- negative_prompt=negative_prompt,
945
- height=height,
946
- width=width,
947
- num_inference_steps=num_inference_steps,
948
- guidance_scale=guidance_scale,
949
- num_images_per_prompt=num_images_per_prompt,
950
- eta=eta,
951
- generator=generator,
952
- latents=latents,
953
- max_embeddings_multiples=max_embeddings_multiples,
954
- output_type=output_type,
955
- return_dict=return_dict,
956
- callback=callback,
957
- is_cancelled_callback=is_cancelled_callback,
958
- callback_steps=callback_steps,
959
- )
960
-
961
- def img2img(
962
- self,
963
- image: Union[torch.FloatTensor, PIL.Image.Image],
964
- prompt: Union[str, List[str]],
965
- negative_prompt: Optional[Union[str, List[str]]] = None,
966
- strength: float = 0.8,
967
- num_inference_steps: Optional[int] = 50,
968
- guidance_scale: Optional[float] = 7.5,
969
- num_images_per_prompt: Optional[int] = 1,
970
- eta: Optional[float] = 0.0,
971
- generator: Optional[torch.Generator] = None,
972
- max_embeddings_multiples: Optional[int] = 3,
973
- output_type: Optional[str] = "pil",
974
- return_dict: bool = True,
975
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
976
- is_cancelled_callback: Optional[Callable[[], bool]] = None,
977
- callback_steps: int = 1,
978
- ):
979
- r"""
980
- Function for image-to-image generation.
981
- Args:
982
- image (`torch.FloatTensor` or `PIL.Image.Image`):
983
- `Image`, or tensor representing an image batch, that will be used as the starting point for the
984
- process.
985
- prompt (`str` or `List[str]`):
986
- The prompt or prompts to guide the image generation.
987
- negative_prompt (`str` or `List[str]`, *optional*):
988
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
989
- if `guidance_scale` is less than `1`).
990
- strength (`float`, *optional*, defaults to 0.8):
991
- Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
992
- `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
993
- number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
994
- noise will be maximum and the denoising process will run for the full number of iterations specified in
995
- `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
996
- num_inference_steps (`int`, *optional*, defaults to 50):
997
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
998
- expense of slower inference. This parameter will be modulated by `strength`.
999
- guidance_scale (`float`, *optional*, defaults to 7.5):
1000
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1001
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
1002
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1003
- 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1004
- usually at the expense of lower image quality.
1005
- num_images_per_prompt (`int`, *optional*, defaults to 1):
1006
- The number of images to generate per prompt.
1007
- eta (`float`, *optional*, defaults to 0.0):
1008
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1009
- [`schedulers.DDIMScheduler`], will be ignored for others.
1010
- generator (`torch.Generator`, *optional*):
1011
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1012
- deterministic.
1013
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1014
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
1015
- output_type (`str`, *optional*, defaults to `"pil"`):
1016
- The output format of the generated image. Choose between
1017
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1018
- return_dict (`bool`, *optional*, defaults to `True`):
1019
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1020
- plain tuple.
1021
- callback (`Callable`, *optional*):
1022
- A function that will be called every `callback_steps` steps during inference. The function will be
1023
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1024
- is_cancelled_callback (`Callable`, *optional*):
1025
- A function that will be called every `callback_steps` steps during inference. If the function returns
1026
- `True`, the inference will be cancelled.
1027
- callback_steps (`int`, *optional*, defaults to 1):
1028
- The frequency at which the `callback` function will be called. If not specified, the callback will be
1029
- called at every step.
1030
- Returns:
1031
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1032
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1033
- When returning a tuple, the first element is a list with the generated images, and the second element is a
1034
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1035
- (nsfw) content, according to the `safety_checker`.
1036
- """
1037
- return self.__call__(
1038
- prompt=prompt,
1039
- negative_prompt=negative_prompt,
1040
- image=image,
1041
- num_inference_steps=num_inference_steps,
1042
- guidance_scale=guidance_scale,
1043
- strength=strength,
1044
- num_images_per_prompt=num_images_per_prompt,
1045
- eta=eta,
1046
- generator=generator,
1047
- max_embeddings_multiples=max_embeddings_multiples,
1048
- output_type=output_type,
1049
- return_dict=return_dict,
1050
- callback=callback,
1051
- is_cancelled_callback=is_cancelled_callback,
1052
- callback_steps=callback_steps,
1053
- )
1054
-
1055
- def inpaint(
1056
- self,
1057
- image: Union[torch.FloatTensor, PIL.Image.Image],
1058
- mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1059
- prompt: Union[str, List[str]],
1060
- negative_prompt: Optional[Union[str, List[str]]] = None,
1061
- strength: float = 0.8,
1062
- num_inference_steps: Optional[int] = 50,
1063
- guidance_scale: Optional[float] = 7.5,
1064
- num_images_per_prompt: Optional[int] = 1,
1065
- eta: Optional[float] = 0.0,
1066
- generator: Optional[torch.Generator] = None,
1067
- max_embeddings_multiples: Optional[int] = 3,
1068
- output_type: Optional[str] = "pil",
1069
- return_dict: bool = True,
1070
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1071
- is_cancelled_callback: Optional[Callable[[], bool]] = None,
1072
- callback_steps: int = 1,
1073
- ):
1074
- r"""
1075
- Function for inpainting.
1076
- Args:
1077
- image (`torch.FloatTensor` or `PIL.Image.Image`):
1078
- `Image`, or tensor representing an image batch, that will be used as the starting point for the
1079
- process. This is the image whose masked region will be inpainted.
1080
- mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1081
- `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1082
- replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1083
- PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1084
- contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1085
- prompt (`str` or `List[str]`):
1086
- The prompt or prompts to guide the image generation.
1087
- negative_prompt (`str` or `List[str]`, *optional*):
1088
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1089
- if `guidance_scale` is less than `1`).
1090
- strength (`float`, *optional*, defaults to 0.8):
1091
- Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1092
- is 1, the denoising process will be run on the masked area for the full number of iterations specified
1093
- in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1094
- noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1095
- num_inference_steps (`int`, *optional*, defaults to 50):
1096
- The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1097
- the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1098
- guidance_scale (`float`, *optional*, defaults to 7.5):
1099
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1100
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
1101
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1102
- 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1103
- usually at the expense of lower image quality.
1104
- num_images_per_prompt (`int`, *optional*, defaults to 1):
1105
- The number of images to generate per prompt.
1106
- eta (`float`, *optional*, defaults to 0.0):
1107
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1108
- [`schedulers.DDIMScheduler`], will be ignored for others.
1109
- generator (`torch.Generator`, *optional*):
1110
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1111
- deterministic.
1112
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1113
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
1114
- output_type (`str`, *optional*, defaults to `"pil"`):
1115
- The output format of the generated image. Choose between
1116
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1117
- return_dict (`bool`, *optional*, defaults to `True`):
1118
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1119
- plain tuple.
1120
- callback (`Callable`, *optional*):
1121
- A function that will be called every `callback_steps` steps during inference. The function will be
1122
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1123
- is_cancelled_callback (`Callable`, *optional*):
1124
- A function that will be called every `callback_steps` steps during inference. If the function returns
1125
- `True`, the inference will be cancelled.
1126
- callback_steps (`int`, *optional*, defaults to 1):
1127
- The frequency at which the `callback` function will be called. If not specified, the callback will be
1128
- called at every step.
1129
- Returns:
1130
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1131
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1132
- When returning a tuple, the first element is a list with the generated images, and the second element is a
1133
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1134
- (nsfw) content, according to the `safety_checker`.
1135
- """
1136
- return self.__call__(
1137
- prompt=prompt,
1138
- negative_prompt=negative_prompt,
1139
- image=image,
1140
- mask_image=mask_image,
1141
- num_inference_steps=num_inference_steps,
1142
- guidance_scale=guidance_scale,
1143
- strength=strength,
1144
- num_images_per_prompt=num_images_per_prompt,
1145
- eta=eta,
1146
- generator=generator,
1147
- max_embeddings_multiples=max_embeddings_multiples,
1148
- output_type=output_type,
1149
- return_dict=return_dict,
1150
- callback=callback,
1151
- is_cancelled_callback=is_cancelled_callback,
1152
- callback_steps=callback_steps,
1153
- )
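For orientation, a minimal, hedged usage sketch of these convenience methods. The loading pattern assumes this file is used as the community pipeline "lpw_stable_diffusion"; the model id below is only an example.

import torch
from diffusers import DiffusionPipeline

# Assumed loading pattern for the community pipeline; the model id is illustrative.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# Weighted prompt syntax handled by this pipeline:
# (word) boosts attention by 1.1, (word:1.5) by 1.5, [word] reduces it by 1.1.
prompt = "a photo of a (red:1.4) vintage car, (highly detailed), [blurry]"
image = pipe.text2img(prompt, width=512, height=512, num_inference_steps=30).images[0]

# img2img and inpaint are thin wrappers around the same __call__ shown above.
result = pipe.img2img(image=image, prompt="the same car at (sunset:1.3)", strength=0.6).images[0]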
diffusers/examples/community/lpw_stable_diffusion_onnx.py DELETED
@@ -1,1146 +0,0 @@
1
- import inspect
2
- import re
3
- from typing import Callable, List, Optional, Union
4
-
5
- import numpy as np
6
- import PIL
7
- import torch
8
- from packaging import version
9
- from transformers import CLIPImageProcessor, CLIPTokenizer
10
-
11
- import diffusers
12
- from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
13
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
14
- from diffusers.utils import logging
15
-
16
-
17
- try:
18
- from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
19
- except ImportError:
20
- ORT_TO_NP_TYPE = {
21
- "tensor(bool)": np.bool_,
22
- "tensor(int8)": np.int8,
23
- "tensor(uint8)": np.uint8,
24
- "tensor(int16)": np.int16,
25
- "tensor(uint16)": np.uint16,
26
- "tensor(int32)": np.int32,
27
- "tensor(uint32)": np.uint32,
28
- "tensor(int64)": np.int64,
29
- "tensor(uint64)": np.uint64,
30
- "tensor(float16)": np.float16,
31
- "tensor(float)": np.float32,
32
- "tensor(double)": np.float64,
33
- }
34
-
35
- try:
36
- from diffusers.utils import PIL_INTERPOLATION
37
- except ImportError:
38
- if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
39
- PIL_INTERPOLATION = {
40
- "linear": PIL.Image.Resampling.BILINEAR,
41
- "bilinear": PIL.Image.Resampling.BILINEAR,
42
- "bicubic": PIL.Image.Resampling.BICUBIC,
43
- "lanczos": PIL.Image.Resampling.LANCZOS,
44
- "nearest": PIL.Image.Resampling.NEAREST,
45
- }
46
- else:
47
- PIL_INTERPOLATION = {
48
- "linear": PIL.Image.LINEAR,
49
- "bilinear": PIL.Image.BILINEAR,
50
- "bicubic": PIL.Image.BICUBIC,
51
- "lanczos": PIL.Image.LANCZOS,
52
- "nearest": PIL.Image.NEAREST,
53
- }
54
- # ------------------------------------------------------------------------------
55
-
56
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
-
58
- re_attention = re.compile(
59
- r"""
60
- \\\(|
61
- \\\)|
62
- \\\[|
63
- \\]|
64
- \\\\|
65
- \\|
66
- \(|
67
- \[|
68
- :([+-]?[.\d]+)\)|
69
- \)|
70
- ]|
71
- [^\\()\[\]:]+|
72
- :
73
- """,
74
- re.X,
75
- )
76
-
77
-
78
- def parse_prompt_attention(text):
79
- """
80
- Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
81
- Accepted tokens are:
82
- (abc) - increases attention to abc by a multiplier of 1.1
83
- (abc:3.12) - increases attention to abc by a multiplier of 3.12
84
- [abc] - decreases attention to abc by a multiplier of 1.1
85
- \( - literal character '('
86
- \[ - literal character '['
87
- \) - literal character ')'
88
- \] - literal character ']'
89
- \\ - literal character '\'
90
- anything else - just text
91
- >>> parse_prompt_attention('normal text')
92
- [['normal text', 1.0]]
93
- >>> parse_prompt_attention('an (important) word')
94
- [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
95
- >>> parse_prompt_attention('(unbalanced')
96
- [['unbalanced', 1.1]]
97
- >>> parse_prompt_attention('\(literal\]')
98
- [['(literal]', 1.0]]
99
- >>> parse_prompt_attention('(unnecessary)(parens)')
100
- [['unnecessaryparens', 1.1]]
101
- >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
102
- [['a ', 1.0],
103
- ['house', 1.5730000000000004],
104
- [' ', 1.1],
105
- ['on', 1.0],
106
- [' a ', 1.1],
107
- ['hill', 0.55],
108
- [', sun, ', 1.1],
109
- ['sky', 1.4641000000000006],
110
- ['.', 1.1]]
111
- """
112
-
113
- res = []
114
- round_brackets = []
115
- square_brackets = []
116
-
117
- round_bracket_multiplier = 1.1
118
- square_bracket_multiplier = 1 / 1.1
119
-
120
- def multiply_range(start_position, multiplier):
121
- for p in range(start_position, len(res)):
122
- res[p][1] *= multiplier
123
-
124
- for m in re_attention.finditer(text):
125
- text = m.group(0)
126
- weight = m.group(1)
127
-
128
- if text.startswith("\\"):
129
- res.append([text[1:], 1.0])
130
- elif text == "(":
131
- round_brackets.append(len(res))
132
- elif text == "[":
133
- square_brackets.append(len(res))
134
- elif weight is not None and len(round_brackets) > 0:
135
- multiply_range(round_brackets.pop(), float(weight))
136
- elif text == ")" and len(round_brackets) > 0:
137
- multiply_range(round_brackets.pop(), round_bracket_multiplier)
138
- elif text == "]" and len(square_brackets) > 0:
139
- multiply_range(square_brackets.pop(), square_bracket_multiplier)
140
- else:
141
- res.append([text, 1.0])
142
-
143
- for pos in round_brackets:
144
- multiply_range(pos, round_bracket_multiplier)
145
-
146
- for pos in square_brackets:
147
- multiply_range(pos, square_bracket_multiplier)
148
-
149
- if len(res) == 0:
150
- res = [["", 1.0]]
151
-
152
- # merge runs of identical weights
153
- i = 0
154
- while i + 1 < len(res):
155
- if res[i][1] == res[i + 1][1]:
156
- res[i][0] += res[i + 1][0]
157
- res.pop(i + 1)
158
- else:
159
- i += 1
160
-
161
- return res
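For illustration, nested brackets compound multiplicatively; a couple of extra examples beyond the doctests above (floats shown approximately):

parse_prompt_attention("a ((red)) car")
# -> [['a ', 1.0], ['red', 1.21], [' car', 1.0]]          # 1.1 * 1.1, up to float rounding

parse_prompt_attention("a [dim] (bright:1.5) scene")
# -> [['a ', 1.0], ['dim', 0.909], [' ', 1.0], ['bright', 1.5], [' scene', 1.0]]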
162
-
163
-
164
- def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
165
- r"""
166
- Tokenize a list of prompts and return its tokens with weights of each token.
167
-
168
- No padding, starting or ending token is included.
169
- """
170
- tokens = []
171
- weights = []
172
- truncated = False
173
- for text in prompt:
174
- texts_and_weights = parse_prompt_attention(text)
175
- text_token = []
176
- text_weight = []
177
- for word, weight in texts_and_weights:
178
- # tokenize and discard the starting and the ending token
179
- token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
180
- text_token += list(token)
181
- # copy the weight by length of token
182
- text_weight += [weight] * len(token)
183
- # stop if the text is too long (longer than truncation limit)
184
- if len(text_token) > max_length:
185
- truncated = True
186
- break
187
- # truncate
188
- if len(text_token) > max_length:
189
- truncated = True
190
- text_token = text_token[:max_length]
191
- text_weight = text_weight[:max_length]
192
- tokens.append(text_token)
193
- weights.append(text_weight)
194
- if truncated:
195
- logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
196
- return tokens, weights
197
-
198
-
199
- def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
200
- r"""
201
- Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
202
- """
203
- max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
204
- weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
205
- for i in range(len(tokens)):
206
- tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
207
- if no_boseos_middle:
208
- weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
209
- else:
210
- w = []
211
- if len(weights[i]) == 0:
212
- w = [1.0] * weights_length
213
- else:
214
- for j in range(max_embeddings_multiples):
215
- w.append(1.0) # weight for starting token in this chunk
216
- w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
217
- w.append(1.0) # weight for ending token in this chunk
218
- w += [1.0] * (weights_length - len(w))
219
- weights[i] = w[:]
220
-
221
- return tokens, weights
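A worked sketch of the resulting layout, assuming CLIP's chunk_length of 77 and max_embeddings_multiples = 3:

# max_length = (77 - 2) * 3 + 2 = 227
# tokens[i]  = [bos] + <up to 225 prompt token ids> + [pad] * remainder + [eos]   -> length 227
# weights[i] = [1.0] + <per-token weights>          + [1.0] * remainder           -> length 227
# With no_boseos_middle=False, weights are instead laid out per 77-token chunk,
# with an explicit 1.0 for the bos/eos position of every chunk.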
222
-
223
-
224
- def get_unweighted_text_embeddings(
225
- pipe,
226
- text_input: np.array,
227
- chunk_length: int,
228
- no_boseos_middle: Optional[bool] = True,
229
- ):
230
- """
231
- When the length of tokens exceeds the capacity of the text encoder,
232
- the input should be split into chunks and each chunk sent to the text encoder individually.
233
- """
234
- max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
235
- if max_embeddings_multiples > 1:
236
- text_embeddings = []
237
- for i in range(max_embeddings_multiples):
238
- # extract the i-th chunk
239
- text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].copy()
240
-
241
- # cover the head and the tail by the starting and the ending tokens
242
- text_input_chunk[:, 0] = text_input[0, 0]
243
- text_input_chunk[:, -1] = text_input[0, -1]
244
-
245
- text_embedding = pipe.text_encoder(input_ids=text_input_chunk)[0]
246
-
247
- if no_boseos_middle:
248
- if i == 0:
249
- # discard the ending token
250
- text_embedding = text_embedding[:, :-1]
251
- elif i == max_embeddings_multiples - 1:
252
- # discard the starting token
253
- text_embedding = text_embedding[:, 1:]
254
- else:
255
- # discard both starting and ending tokens
256
- text_embedding = text_embedding[:, 1:-1]
257
-
258
- text_embeddings.append(text_embedding)
259
- text_embeddings = np.concatenate(text_embeddings, axis=1)
260
- else:
261
- text_embeddings = pipe.text_encoder(input_ids=text_input)[0]
262
- return text_embeddings
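A worked example of the chunking arithmetic above, for a 227-token padded input and a 77-token encoder:

# max_embeddings_multiples = (227 - 2) // (77 - 2) = 3
# chunk i covers columns [i * 75 : i * 75 + 77]; its first and last positions are
# overwritten with the global bos/eos ids before encoding.
# With no_boseos_middle=True the duplicated bos/eos embeddings are dropped, so the
# concatenation has 76 + 75 + 76 = 227 positions per prompt.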
263
-
264
-
265
- def get_weighted_text_embeddings(
266
- pipe,
267
- prompt: Union[str, List[str]],
268
- uncond_prompt: Optional[Union[str, List[str]]] = None,
269
- max_embeddings_multiples: Optional[int] = 4,
270
- no_boseos_middle: Optional[bool] = False,
271
- skip_parsing: Optional[bool] = False,
272
- skip_weighting: Optional[bool] = False,
273
- **kwargs,
274
- ):
275
- r"""
276
- Prompts can be assigned with local weights using brackets. For example,
277
- prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
278
- and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
279
-
280
- Also, to regularize the embedding, the weighted embedding is rescaled to preserve the original mean.
281
-
282
- Args:
283
- pipe (`OnnxStableDiffusionPipeline`):
284
- Pipe to provide access to the tokenizer and the text encoder.
285
- prompt (`str` or `List[str]`):
286
- The prompt or prompts to guide the image generation.
287
- uncond_prompt (`str` or `List[str]`):
288
- The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
289
- is provided, the embeddings of prompt and uncond_prompt are concatenated.
290
- max_embeddings_multiples (`int`, *optional*, defaults to `4`):
291
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
292
- no_boseos_middle (`bool`, *optional*, defaults to `False`):
293
- If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep the starting and
294
- ending tokens of each chunk in the middle.
295
- skip_parsing (`bool`, *optional*, defaults to `False`):
296
- Skip the parsing of brackets.
297
- skip_weighting (`bool`, *optional*, defaults to `False`):
298
- Skip the weighting. When the parsing is skipped, it is forced True.
299
- """
300
- max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
301
- if isinstance(prompt, str):
302
- prompt = [prompt]
303
-
304
- if not skip_parsing:
305
- prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
306
- if uncond_prompt is not None:
307
- if isinstance(uncond_prompt, str):
308
- uncond_prompt = [uncond_prompt]
309
- uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
310
- else:
311
- prompt_tokens = [
312
- token[1:-1]
313
- for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="np").input_ids
314
- ]
315
- prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
316
- if uncond_prompt is not None:
317
- if isinstance(uncond_prompt, str):
318
- uncond_prompt = [uncond_prompt]
319
- uncond_tokens = [
320
- token[1:-1]
321
- for token in pipe.tokenizer(
322
- uncond_prompt,
323
- max_length=max_length,
324
- truncation=True,
325
- return_tensors="np",
326
- ).input_ids
327
- ]
328
- uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
329
-
330
- # round up the longest length of tokens to a multiple of (model_max_length - 2)
331
- max_length = max([len(token) for token in prompt_tokens])
332
- if uncond_prompt is not None:
333
- max_length = max(max_length, max([len(token) for token in uncond_tokens]))
334
-
335
- max_embeddings_multiples = min(
336
- max_embeddings_multiples,
337
- (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
338
- )
339
- max_embeddings_multiples = max(1, max_embeddings_multiples)
340
- max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
341
-
342
- # pad the length of tokens and weights
343
- bos = pipe.tokenizer.bos_token_id
344
- eos = pipe.tokenizer.eos_token_id
345
- pad = getattr(pipe.tokenizer, "pad_token_id", eos)
346
- prompt_tokens, prompt_weights = pad_tokens_and_weights(
347
- prompt_tokens,
348
- prompt_weights,
349
- max_length,
350
- bos,
351
- eos,
352
- pad,
353
- no_boseos_middle=no_boseos_middle,
354
- chunk_length=pipe.tokenizer.model_max_length,
355
- )
356
- prompt_tokens = np.array(prompt_tokens, dtype=np.int32)
357
- if uncond_prompt is not None:
358
- uncond_tokens, uncond_weights = pad_tokens_and_weights(
359
- uncond_tokens,
360
- uncond_weights,
361
- max_length,
362
- bos,
363
- eos,
364
- pad,
365
- no_boseos_middle=no_boseos_middle,
366
- chunk_length=pipe.tokenizer.model_max_length,
367
- )
368
- uncond_tokens = np.array(uncond_tokens, dtype=np.int32)
369
-
370
- # get the embeddings
371
- text_embeddings = get_unweighted_text_embeddings(
372
- pipe,
373
- prompt_tokens,
374
- pipe.tokenizer.model_max_length,
375
- no_boseos_middle=no_boseos_middle,
376
- )
377
- prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
378
- if uncond_prompt is not None:
379
- uncond_embeddings = get_unweighted_text_embeddings(
380
- pipe,
381
- uncond_tokens,
382
- pipe.tokenizer.model_max_length,
383
- no_boseos_middle=no_boseos_middle,
384
- )
385
- uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
386
-
387
- # assign weights to the prompts and normalize in the sense of mean
388
- # TODO: should we normalize by chunk or in a whole (current implementation)?
389
- if (not skip_parsing) and (not skip_weighting):
390
- previous_mean = text_embeddings.mean(axis=(-2, -1))
391
- text_embeddings *= prompt_weights[:, :, None]
392
- text_embeddings *= (previous_mean / text_embeddings.mean(axis=(-2, -1)))[:, None, None]
393
- if uncond_prompt is not None:
394
- previous_mean = uncond_embeddings.mean(axis=(-2, -1))
395
- uncond_embeddings *= uncond_weights[:, :, None]
396
- uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=(-2, -1)))[:, None, None]
397
-
398
- # For classifier free guidance, we need to do two forward passes.
399
- # Here we concatenate the unconditional and text embeddings into a single batch
400
- # to avoid doing two forward passes
401
- if uncond_prompt is not None:
402
- return text_embeddings, uncond_embeddings
403
-
404
- return text_embeddings
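A hedged sketch of calling this helper directly; `pipe` is assumed to be an already-loaded pipeline exposing `.tokenizer` and `.text_encoder` as used above:

cond, uncond = get_weighted_text_embeddings(
    pipe,
    prompt="a (very beautiful:1.3) landscape",
    uncond_prompt="lowres, (bad anatomy:1.2)",
    max_embeddings_multiples=3,
)
# Both arrays share the same shape: (batch, padded_token_length, hidden_dim).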
405
-
406
-
407
- def preprocess_image(image):
408
- w, h = image.size
409
- w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
410
- image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
411
- image = np.array(image).astype(np.float32) / 255.0
412
- image = image[None].transpose(0, 3, 1, 2)
413
- return 2.0 * image - 1.0
414
-
415
-
416
- def preprocess_mask(mask, scale_factor=8):
417
- mask = mask.convert("L")
418
- w, h = mask.size
419
- w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
420
- mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
421
- mask = np.array(mask).astype(np.float32) / 255.0
422
- mask = np.tile(mask, (4, 1, 1))
423
- mask = mask[None].transpose(0, 1, 2, 3) # add a batch dimension; the transpose itself is a no-op
424
- mask = 1 - mask # repaint white, keep black
425
- return mask
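A shape sketch for the mask path, assuming a 512x512 PIL mask and the default scale_factor of 8:

# 512x512 "L" mask -> resized to 64x64 (latent resolution)
#                  -> tiled over the 4 latent channels: (4, 64, 64)
#                  -> batch dimension added:            (1, 4, 64, 64)
#                  -> inverted, so white (255) areas become 0 (repainted)
#                     and black areas become 1 (kept)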
426
-
427
-
428
- class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline):
429
- r"""
430
- Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for parsing
431
- weights in the prompt.
432
-
433
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
434
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
435
- """
436
- if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
437
-
438
- def __init__(
439
- self,
440
- vae_encoder: OnnxRuntimeModel,
441
- vae_decoder: OnnxRuntimeModel,
442
- text_encoder: OnnxRuntimeModel,
443
- tokenizer: CLIPTokenizer,
444
- unet: OnnxRuntimeModel,
445
- scheduler: SchedulerMixin,
446
- safety_checker: OnnxRuntimeModel,
447
- feature_extractor: CLIPImageProcessor,
448
- requires_safety_checker: bool = True,
449
- ):
450
- super().__init__(
451
- vae_encoder=vae_encoder,
452
- vae_decoder=vae_decoder,
453
- text_encoder=text_encoder,
454
- tokenizer=tokenizer,
455
- unet=unet,
456
- scheduler=scheduler,
457
- safety_checker=safety_checker,
458
- feature_extractor=feature_extractor,
459
- requires_safety_checker=requires_safety_checker,
460
- )
461
- self.__init__additional__()
462
-
463
- else:
464
-
465
- def __init__(
466
- self,
467
- vae_encoder: OnnxRuntimeModel,
468
- vae_decoder: OnnxRuntimeModel,
469
- text_encoder: OnnxRuntimeModel,
470
- tokenizer: CLIPTokenizer,
471
- unet: OnnxRuntimeModel,
472
- scheduler: SchedulerMixin,
473
- safety_checker: OnnxRuntimeModel,
474
- feature_extractor: CLIPImageProcessor,
475
- ):
476
- super().__init__(
477
- vae_encoder=vae_encoder,
478
- vae_decoder=vae_decoder,
479
- text_encoder=text_encoder,
480
- tokenizer=tokenizer,
481
- unet=unet,
482
- scheduler=scheduler,
483
- safety_checker=safety_checker,
484
- feature_extractor=feature_extractor,
485
- )
486
- self.__init__additional__()
487
-
488
- def __init__additional__(self):
489
- self.unet_in_channels = 4
490
- self.vae_scale_factor = 8
491
-
492
- def _encode_prompt(
493
- self,
494
- prompt,
495
- num_images_per_prompt,
496
- do_classifier_free_guidance,
497
- negative_prompt,
498
- max_embeddings_multiples,
499
- ):
500
- r"""
501
- Encodes the prompt into text encoder hidden states.
502
-
503
- Args:
504
- prompt (`str` or `list(int)`):
505
- prompt to be encoded
506
- num_images_per_prompt (`int`):
507
- number of images that should be generated per prompt
508
- do_classifier_free_guidance (`bool`):
509
- whether to use classifier free guidance or not
510
- negative_prompt (`str` or `List[str]`):
511
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
512
- if `guidance_scale` is less than `1`).
513
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
514
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
515
- """
516
- batch_size = len(prompt) if isinstance(prompt, list) else 1
517
-
518
- if negative_prompt is None:
519
- negative_prompt = [""] * batch_size
520
- elif isinstance(negative_prompt, str):
521
- negative_prompt = [negative_prompt] * batch_size
522
- if batch_size != len(negative_prompt):
523
- raise ValueError(
524
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
525
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
526
- " the batch size of `prompt`."
527
- )
528
-
529
- text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
530
- pipe=self,
531
- prompt=prompt,
532
- uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
533
- max_embeddings_multiples=max_embeddings_multiples,
534
- )
535
-
536
- text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
537
- if do_classifier_free_guidance:
538
- uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
539
- text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
540
-
541
- return text_embeddings
542
-
543
- def check_inputs(self, prompt, height, width, strength, callback_steps):
544
- if not isinstance(prompt, str) and not isinstance(prompt, list):
545
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
546
-
547
- if strength < 0 or strength > 1:
548
- raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
549
-
550
- if height % 8 != 0 or width % 8 != 0:
551
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
552
-
553
- if (callback_steps is None) or (
554
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
555
- ):
556
- raise ValueError(
557
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
558
- f" {type(callback_steps)}."
559
- )
560
-
561
- def get_timesteps(self, num_inference_steps, strength, is_text2img):
562
- if is_text2img:
563
- return self.scheduler.timesteps, num_inference_steps
564
- else:
565
- # get the original timestep using init_timestep
566
- offset = self.scheduler.config.get("steps_offset", 0)
567
- init_timestep = int(num_inference_steps * strength) + offset
568
- init_timestep = min(init_timestep, num_inference_steps)
569
-
570
- t_start = max(num_inference_steps - init_timestep + offset, 0)
571
- timesteps = self.scheduler.timesteps[t_start:]
572
- return timesteps, num_inference_steps - t_start
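A worked example of how `strength` modulates the schedule in the img2img/inpaint path (ignoring any scheduler `steps_offset`):

# num_inference_steps = 50, strength = 0.8, offset = 0
# init_timestep = int(50 * 0.8) = 40
# t_start       = 50 - 40 = 10
# -> the denoising loop runs over timesteps[10:], i.e. 40 steps;
#    strength = 1.0 runs all 50 steps, strength close to 0 runs almost none.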
573
-
574
- def run_safety_checker(self, image):
575
- if self.safety_checker is not None:
576
- safety_checker_input = self.feature_extractor(
577
- self.numpy_to_pil(image), return_tensors="np"
578
- ).pixel_values.astype(image.dtype)
579
- # Calling the safety_checker directly with batch size > 1 raises an error, so run it one image at a time
580
- images, has_nsfw_concept = [], []
581
- for i in range(image.shape[0]):
582
- image_i, has_nsfw_concept_i = self.safety_checker(
583
- clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
584
- )
585
- images.append(image_i)
586
- has_nsfw_concept.append(has_nsfw_concept_i[0])
587
- image = np.concatenate(images)
588
- else:
589
- has_nsfw_concept = None
590
- return image, has_nsfw_concept
591
-
592
- def decode_latents(self, latents):
593
- latents = 1 / 0.18215 * latents
594
- # image = self.vae_decoder(latent_sample=latents)[0]
595
- # the half-precision VAE decoder seems to produce incorrect results when batch size > 1, so decode one latent at a time
596
- image = np.concatenate(
597
- [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
598
- )
599
- image = np.clip(image / 2 + 0.5, 0, 1)
600
- image = image.transpose((0, 2, 3, 1))
601
- return image
602
-
603
- def prepare_extra_step_kwargs(self, generator, eta):
604
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
605
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
606
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
607
- # and should be between [0, 1]
608
-
609
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
610
- extra_step_kwargs = {}
611
- if accepts_eta:
612
- extra_step_kwargs["eta"] = eta
613
-
614
- # check if the scheduler accepts generator
615
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
616
- if accepts_generator:
617
- extra_step_kwargs["generator"] = generator
618
- return extra_step_kwargs
619
-
620
- def prepare_latents(self, image, timestep, batch_size, height, width, dtype, generator, latents=None):
621
- if image is None:
622
- shape = (
623
- batch_size,
624
- self.unet_in_channels,
625
- height // self.vae_scale_factor,
626
- width // self.vae_scale_factor,
627
- )
628
-
629
- if latents is None:
630
- latents = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
631
- else:
632
- if latents.shape != shape:
633
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
634
-
635
- # scale the initial noise by the standard deviation required by the scheduler
636
- latents = (torch.from_numpy(latents) * self.scheduler.init_noise_sigma).numpy()
637
- return latents, None, None
638
- else:
639
- init_latents = self.vae_encoder(sample=image)[0]
640
- init_latents = 0.18215 * init_latents
641
- init_latents = np.concatenate([init_latents] * batch_size, axis=0)
642
- init_latents_orig = init_latents
643
- shape = init_latents.shape
644
-
645
- # add noise to latents using the timesteps
646
- noise = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
647
- latents = self.scheduler.add_noise(
648
- torch.from_numpy(init_latents), torch.from_numpy(noise), timestep
649
- ).numpy()
650
- return latents, init_latents_orig, noise
651
-
652
- @torch.no_grad()
653
- def __call__(
654
- self,
655
- prompt: Union[str, List[str]],
656
- negative_prompt: Optional[Union[str, List[str]]] = None,
657
- image: Union[np.ndarray, PIL.Image.Image] = None,
658
- mask_image: Union[np.ndarray, PIL.Image.Image] = None,
659
- height: int = 512,
660
- width: int = 512,
661
- num_inference_steps: int = 50,
662
- guidance_scale: float = 7.5,
663
- strength: float = 0.8,
664
- num_images_per_prompt: Optional[int] = 1,
665
- eta: float = 0.0,
666
- generator: Optional[torch.Generator] = None,
667
- latents: Optional[np.ndarray] = None,
668
- max_embeddings_multiples: Optional[int] = 3,
669
- output_type: Optional[str] = "pil",
670
- return_dict: bool = True,
671
- callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
672
- is_cancelled_callback: Optional[Callable[[], bool]] = None,
673
- callback_steps: int = 1,
674
- **kwargs,
675
- ):
676
- r"""
677
- Function invoked when calling the pipeline for generation.
678
-
679
- Args:
680
- prompt (`str` or `List[str]`):
681
- The prompt or prompts to guide the image generation.
682
- negative_prompt (`str` or `List[str]`, *optional*):
683
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
684
- if `guidance_scale` is less than `1`).
685
- image (`np.ndarray` or `PIL.Image.Image`):
686
- `Image`, or tensor representing an image batch, that will be used as the starting point for the
687
- process.
688
- mask_image (`np.ndarray` or `PIL.Image.Image`):
689
- `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
690
- replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
691
- PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
692
- contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
693
- height (`int`, *optional*, defaults to 512):
694
- The height in pixels of the generated image.
695
- width (`int`, *optional*, defaults to 512):
696
- The width in pixels of the generated image.
697
- num_inference_steps (`int`, *optional*, defaults to 50):
698
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
699
- expense of slower inference.
700
- guidance_scale (`float`, *optional*, defaults to 7.5):
701
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
702
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
703
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
704
- 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
705
- usually at the expense of lower image quality.
706
- strength (`float`, *optional*, defaults to 0.8):
707
- Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
708
- `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
709
- number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
710
- noise will be maximum and the denoising process will run for the full number of iterations specified in
711
- `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
712
- num_images_per_prompt (`int`, *optional*, defaults to 1):
713
- The number of images to generate per prompt.
714
- eta (`float`, *optional*, defaults to 0.0):
715
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
716
- [`schedulers.DDIMScheduler`], will be ignored for others.
717
- generator (`torch.Generator`, *optional*):
718
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
719
- deterministic.
720
- latents (`np.ndarray`, *optional*):
721
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
722
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
723
- tensor will be generated by sampling using the supplied random `generator`.
724
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
725
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
726
- output_type (`str`, *optional*, defaults to `"pil"`):
727
- The output format of the generated image. Choose between
728
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
729
- return_dict (`bool`, *optional*, defaults to `True`):
730
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
731
- plain tuple.
732
- callback (`Callable`, *optional*):
733
- A function that will be called every `callback_steps` steps during inference. The function will be
734
- called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
735
- is_cancelled_callback (`Callable`, *optional*):
736
- A function that will be called every `callback_steps` steps during inference. If the function returns
737
- `True`, the inference will be cancelled.
738
- callback_steps (`int`, *optional*, defaults to 1):
739
- The frequency at which the `callback` function will be called. If not specified, the callback will be
740
- called at every step.
741
-
742
- Returns:
743
- `None` if cancelled by `is_cancelled_callback`,
744
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
745
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
746
- When returning a tuple, the first element is a list with the generated images, and the second element is a
747
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
748
- (nsfw) content, according to the `safety_checker`.
749
- """
750
- # 0. Default height and width to unet
751
- height = height or self.unet.config.sample_size * self.vae_scale_factor
752
- width = width or self.unet.config.sample_size * self.vae_scale_factor
753
-
754
- # 1. Check inputs. Raise error if not correct
755
- self.check_inputs(prompt, height, width, strength, callback_steps)
756
-
757
- # 2. Define call parameters
758
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
759
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
760
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
761
- # corresponds to doing no classifier free guidance.
762
- do_classifier_free_guidance = guidance_scale > 1.0
763
-
764
- # 3. Encode input prompt
765
- text_embeddings = self._encode_prompt(
766
- prompt,
767
- num_images_per_prompt,
768
- do_classifier_free_guidance,
769
- negative_prompt,
770
- max_embeddings_multiples,
771
- )
772
- dtype = text_embeddings.dtype
773
-
774
- # 4. Preprocess image and mask
775
- if isinstance(image, PIL.Image.Image):
776
- image = preprocess_image(image)
777
- if image is not None:
778
- image = image.astype(dtype)
779
- if isinstance(mask_image, PIL.Image.Image):
780
- mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
781
- if mask_image is not None:
782
- mask = mask_image.astype(dtype)
783
- mask = np.concatenate([mask] * batch_size * num_images_per_prompt)
784
- else:
785
- mask = None
786
-
787
- # 5. set timesteps
788
- self.scheduler.set_timesteps(num_inference_steps)
789
- timestep_dtype = next(
790
- (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
791
- )
792
- timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
793
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, image is None)
794
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
795
-
796
- # 6. Prepare latent variables
797
- latents, init_latents_orig, noise = self.prepare_latents(
798
- image,
799
- latent_timestep,
800
- batch_size * num_images_per_prompt,
801
- height,
802
- width,
803
- dtype,
804
- generator,
805
- latents,
806
- )
807
-
808
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
809
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
810
-
811
- # 8. Denoising loop
812
- for i, t in enumerate(self.progress_bar(timesteps)):
813
- # expand the latents if we are doing classifier free guidance
814
- latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
815
- latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
816
- latent_model_input = latent_model_input.numpy()
817
-
818
- # predict the noise residual
819
- noise_pred = self.unet(
820
- sample=latent_model_input,
821
- timestep=np.array([t], dtype=timestep_dtype),
822
- encoder_hidden_states=text_embeddings,
823
- )
824
- noise_pred = noise_pred[0]
825
-
826
- # perform guidance
827
- if do_classifier_free_guidance:
828
- noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
829
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
830
-
831
- # compute the previous noisy sample x_t -> x_t-1
832
- scheduler_output = self.scheduler.step(
833
- torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
834
- )
835
- latents = scheduler_output.prev_sample.numpy()
836
-
837
- if mask is not None:
838
- # masking
839
- init_latents_proper = self.scheduler.add_noise(
840
- torch.from_numpy(init_latents_orig),
841
- torch.from_numpy(noise),
842
- t,
843
- ).numpy()
844
- latents = (init_latents_proper * mask) + (latents * (1 - mask))
845
-
846
- # call the callback, if provided
847
- if i % callback_steps == 0:
848
- if callback is not None:
849
- callback(i, t, latents)
850
- if is_cancelled_callback is not None and is_cancelled_callback():
851
- return None
852
-
853
- # 9. Post-processing
854
- image = self.decode_latents(latents)
855
-
856
- # 10. Run safety checker
857
- image, has_nsfw_concept = self.run_safety_checker(image)
858
-
859
- # 11. Convert to PIL
860
- if output_type == "pil":
861
- image = self.numpy_to_pil(image)
862
-
863
- if not return_dict:
864
- return image, has_nsfw_concept
865
-
866
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
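For context, a minimal usage sketch of this ONNX variant, assuming it is loaded as the community pipeline "lpw_stable_diffusion_onnx"; the model id, revision and execution provider below are illustrative assumptions:

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",      # illustrative model id with ONNX weights
    revision="onnx",
    provider="CPUExecutionProvider",      # or e.g. "CUDAExecutionProvider"
    custom_pipeline="lpw_stable_diffusion_onnx",
)

prompt = "a (photorealistic:1.3) portrait, (sharp focus), [overexposed]"
image = pipe.text2img(prompt, num_inference_steps=25, guidance_scale=7.5).images[0]
image.save("out.png")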
867
-
868
- def text2img(
869
- self,
870
- prompt: Union[str, List[str]],
871
- negative_prompt: Optional[Union[str, List[str]]] = None,
872
- height: int = 512,
873
- width: int = 512,
874
- num_inference_steps: int = 50,
875
- guidance_scale: float = 7.5,
876
- num_images_per_prompt: Optional[int] = 1,
877
- eta: float = 0.0,
878
- generator: Optional[torch.Generator] = None,
879
- latents: Optional[np.ndarray] = None,
880
- max_embeddings_multiples: Optional[int] = 3,
881
- output_type: Optional[str] = "pil",
882
- return_dict: bool = True,
883
- callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
884
- callback_steps: int = 1,
885
- **kwargs,
886
- ):
887
- r"""
888
- Function for text-to-image generation.
889
- Args:
890
- prompt (`str` or `List[str]`):
891
- The prompt or prompts to guide the image generation.
892
- negative_prompt (`str` or `List[str]`, *optional*):
893
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
894
- if `guidance_scale` is less than `1`).
895
- height (`int`, *optional*, defaults to 512):
896
- The height in pixels of the generated image.
897
- width (`int`, *optional*, defaults to 512):
898
- The width in pixels of the generated image.
899
- num_inference_steps (`int`, *optional*, defaults to 50):
900
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
901
- expense of slower inference.
902
- guidance_scale (`float`, *optional*, defaults to 7.5):
903
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
904
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
905
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
906
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
907
- usually at the expense of lower image quality.
908
- num_images_per_prompt (`int`, *optional*, defaults to 1):
909
- The number of images to generate per prompt.
910
- eta (`float`, *optional*, defaults to 0.0):
911
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
912
- [`schedulers.DDIMScheduler`], will be ignored for others.
913
- generator (`torch.Generator`, *optional*):
914
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
915
- deterministic.
916
- latents (`np.ndarray`, *optional*):
917
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
918
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
919
- tensor will be generated by sampling using the supplied random `generator`.
920
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
921
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
922
- output_type (`str`, *optional*, defaults to `"pil"`):
923
- The output format of the generated image. Choose between
924
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
925
- return_dict (`bool`, *optional*, defaults to `True`):
926
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
927
- plain tuple.
928
- callback (`Callable`, *optional*):
929
- A function that will be called every `callback_steps` steps during inference. The function will be
930
- called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
931
- callback_steps (`int`, *optional*, defaults to 1):
932
- The frequency at which the `callback` function will be called. If not specified, the callback will be
933
- called at every step.
934
- Returns:
935
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
936
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
937
- When returning a tuple, the first element is a list with the generated images, and the second element is a
938
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
939
- (nsfw) content, according to the `safety_checker`.
940
- """
941
- return self.__call__(
942
- prompt=prompt,
943
- negative_prompt=negative_prompt,
944
- height=height,
945
- width=width,
946
- num_inference_steps=num_inference_steps,
947
- guidance_scale=guidance_scale,
948
- num_images_per_prompt=num_images_per_prompt,
949
- eta=eta,
950
- generator=generator,
951
- latents=latents,
952
- max_embeddings_multiples=max_embeddings_multiples,
953
- output_type=output_type,
954
- return_dict=return_dict,
955
- callback=callback,
956
- callback_steps=callback_steps,
957
- **kwargs,
958
- )
959
-
960
- def img2img(
961
- self,
962
- image: Union[np.ndarray, PIL.Image.Image],
963
- prompt: Union[str, List[str]],
964
- negative_prompt: Optional[Union[str, List[str]]] = None,
965
- strength: float = 0.8,
966
- num_inference_steps: Optional[int] = 50,
967
- guidance_scale: Optional[float] = 7.5,
968
- num_images_per_prompt: Optional[int] = 1,
969
- eta: Optional[float] = 0.0,
970
- generator: Optional[torch.Generator] = None,
971
- max_embeddings_multiples: Optional[int] = 3,
972
- output_type: Optional[str] = "pil",
973
- return_dict: bool = True,
974
- callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
975
- callback_steps: int = 1,
976
- **kwargs,
977
- ):
978
- r"""
979
- Function for image-to-image generation.
980
- Args:
981
- image (`np.ndarray` or `PIL.Image.Image`):
982
- `Image`, or ndarray representing an image batch, that will be used as the starting point for the
983
- process.
984
- prompt (`str` or `List[str]`):
985
- The prompt or prompts to guide the image generation.
986
- negative_prompt (`str` or `List[str]`, *optional*):
987
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
988
- if `guidance_scale` is less than `1`).
989
- strength (`float`, *optional*, defaults to 0.8):
990
- Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
991
- `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
992
- number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
993
- noise will be maximum and the denoising process will run for the full number of iterations specified in
994
- `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
995
- num_inference_steps (`int`, *optional*, defaults to 50):
996
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
997
- expense of slower inference. This parameter will be modulated by `strength`.
998
- guidance_scale (`float`, *optional*, defaults to 7.5):
999
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1000
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
1001
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1002
- 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1003
- usually at the expense of lower image quality.
1004
- num_images_per_prompt (`int`, *optional*, defaults to 1):
1005
- The number of images to generate per prompt.
1006
- eta (`float`, *optional*, defaults to 0.0):
1007
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1008
- [`schedulers.DDIMScheduler`], will be ignored for others.
1009
- generator (`torch.Generator`, *optional*):
1010
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1011
- deterministic.
1012
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1013
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
1014
- output_type (`str`, *optional*, defaults to `"pil"`):
1015
- The output format of the generated image. Choose between
1016
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1017
- return_dict (`bool`, *optional*, defaults to `True`):
1018
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1019
- plain tuple.
1020
- callback (`Callable`, *optional*):
1021
- A function that will be called every `callback_steps` steps during inference. The function will be
1022
- called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
1023
- callback_steps (`int`, *optional*, defaults to 1):
1024
- The frequency at which the `callback` function will be called. If not specified, the callback will be
1025
- called at every step.
1026
- Returns:
1027
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1028
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1029
- When returning a tuple, the first element is a list with the generated images, and the second element is a
1030
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1031
- (nsfw) content, according to the `safety_checker`.
1032
- """
1033
- return self.__call__(
1034
- prompt=prompt,
1035
- negative_prompt=negative_prompt,
1036
- image=image,
1037
- num_inference_steps=num_inference_steps,
1038
- guidance_scale=guidance_scale,
1039
- strength=strength,
1040
- num_images_per_prompt=num_images_per_prompt,
1041
- eta=eta,
1042
- generator=generator,
1043
- max_embeddings_multiples=max_embeddings_multiples,
1044
- output_type=output_type,
1045
- return_dict=return_dict,
1046
- callback=callback,
1047
- callback_steps=callback_steps,
1048
- **kwargs,
1049
- )
1050
-
1051
- def inpaint(
1052
- self,
1053
- image: Union[np.ndarray, PIL.Image.Image],
1054
- mask_image: Union[np.ndarray, PIL.Image.Image],
1055
- prompt: Union[str, List[str]],
1056
- negative_prompt: Optional[Union[str, List[str]]] = None,
1057
- strength: float = 0.8,
1058
- num_inference_steps: Optional[int] = 50,
1059
- guidance_scale: Optional[float] = 7.5,
1060
- num_images_per_prompt: Optional[int] = 1,
1061
- eta: Optional[float] = 0.0,
1062
- generator: Optional[torch.Generator] = None,
1063
- max_embeddings_multiples: Optional[int] = 3,
1064
- output_type: Optional[str] = "pil",
1065
- return_dict: bool = True,
1066
- callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
1067
- callback_steps: int = 1,
1068
- **kwargs,
1069
- ):
1070
- r"""
1071
- Function for inpainting.
1072
- Args:
1073
- image (`np.ndarray` or `PIL.Image.Image`):
1074
- `Image`, or tensor representing an image batch, that will be used as the starting point for the
1075
- process. This is the image whose masked region will be inpainted.
1076
- mask_image (`np.ndarray` or `PIL.Image.Image`):
1077
- `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1078
- replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1079
- PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1080
- contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1081
- prompt (`str` or `List[str]`):
1082
- The prompt or prompts to guide the image generation.
1083
- negative_prompt (`str` or `List[str]`, *optional*):
1084
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1085
- if `guidance_scale` is less than `1`).
1086
- strength (`float`, *optional*, defaults to 0.8):
1087
- Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1088
- is 1, the denoising process will be run on the masked area for the full number of iterations specified
1089
- in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1090
- noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1091
- num_inference_steps (`int`, *optional*, defaults to 50):
1092
- The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1093
- the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1094
- guidance_scale (`float`, *optional*, defaults to 7.5):
1095
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1096
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
1097
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1098
- 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1099
- usually at the expense of lower image quality.
1100
- num_images_per_prompt (`int`, *optional*, defaults to 1):
1101
- The number of images to generate per prompt.
1102
- eta (`float`, *optional*, defaults to 0.0):
1103
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1104
- [`schedulers.DDIMScheduler`], will be ignored for others.
1105
- generator (`torch.Generator`, *optional*):
1106
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1107
- deterministic.
1108
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1109
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
1110
- output_type (`str`, *optional*, defaults to `"pil"`):
1111
- The output format of the generated image. Choose between
1112
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1113
- return_dict (`bool`, *optional*, defaults to `True`):
1114
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1115
- plain tuple.
1116
- callback (`Callable`, *optional*):
1117
- A function that will be called every `callback_steps` steps during inference. The function will be
1118
- called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
1119
- callback_steps (`int`, *optional*, defaults to 1):
1120
- The frequency at which the `callback` function will be called. If not specified, the callback will be
1121
- called at every step.
1122
- Returns:
1123
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1124
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1125
- When returning a tuple, the first element is a list with the generated images, and the second element is a
1126
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1127
- (nsfw) content, according to the `safety_checker`.
1128
- """
1129
- return self.__call__(
1130
- prompt=prompt,
1131
- negative_prompt=negative_prompt,
1132
- image=image,
1133
- mask_image=mask_image,
1134
- num_inference_steps=num_inference_steps,
1135
- guidance_scale=guidance_scale,
1136
- strength=strength,
1137
- num_images_per_prompt=num_images_per_prompt,
1138
- eta=eta,
1139
- generator=generator,
1140
- max_embeddings_multiples=max_embeddings_multiples,
1141
- output_type=output_type,
1142
- return_dict=return_dict,
1143
- callback=callback,
1144
- callback_steps=callback_steps,
1145
- **kwargs,
1146
- )
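For orientation, the `img2img` and `inpaint` helpers above are thin wrappers that forward their arguments to `__call__`. A minimal usage sketch for the inpainting path, assuming this file is still loadable as the "lpw_stable_diffusion_onnx" community pipeline; the model id, ONNX revision, execution provider and image paths below are illustrative assumptions, not values recorded in this diff:

import PIL.Image
from diffusers import DiffusionPipeline

# Hedged sketch: loads the community file above via custom_pipeline; ids and paths are placeholders.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="onnx",
    provider="CPUExecutionProvider",
    custom_pipeline="lpw_stable_diffusion_onnx",
)
init_image = PIL.Image.open("dog.png").convert("RGB").resize((512, 512))
mask_image = PIL.Image.open("dog_mask.png").convert("L").resize((512, 512))
result = pipe.inpaint(image=init_image, mask_image=mask_image, prompt="a cute dog", strength=0.75)
result.images[0].save("inpainted.png")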
diffusers/examples/community/magic_mix.py DELETED
@@ -1,152 +0,0 @@
1
- from typing import Union
2
-
3
- import torch
4
- from PIL import Image
5
- from torchvision import transforms as tfms
6
- from tqdm.auto import tqdm
7
- from transformers import CLIPTextModel, CLIPTokenizer
8
-
9
- from diffusers import (
10
- AutoencoderKL,
11
- DDIMScheduler,
12
- DiffusionPipeline,
13
- LMSDiscreteScheduler,
14
- PNDMScheduler,
15
- UNet2DConditionModel,
16
- )
17
-
18
-
19
- class MagicMixPipeline(DiffusionPipeline):
20
- def __init__(
21
- self,
22
- vae: AutoencoderKL,
23
- text_encoder: CLIPTextModel,
24
- tokenizer: CLIPTokenizer,
25
- unet: UNet2DConditionModel,
26
- scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
27
- ):
28
- super().__init__()
29
-
30
- self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
31
-
32
- # convert PIL image to latents
33
- def encode(self, img):
34
- with torch.no_grad():
35
- latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1)
36
- latent = 0.18215 * latent.latent_dist.sample()
37
- return latent
38
-
39
- # convert latents to PIL image
40
- def decode(self, latent):
41
- latent = (1 / 0.18215) * latent
42
- with torch.no_grad():
43
- img = self.vae.decode(latent).sample
44
- img = (img / 2 + 0.5).clamp(0, 1)
45
- img = img.detach().cpu().permute(0, 2, 3, 1).numpy()
46
- img = (img * 255).round().astype("uint8")
47
- return Image.fromarray(img[0])
48
-
49
- # convert prompt into text embeddings, also unconditional embeddings
50
- def prep_text(self, prompt):
51
- text_input = self.tokenizer(
52
- prompt,
53
- padding="max_length",
54
- max_length=self.tokenizer.model_max_length,
55
- truncation=True,
56
- return_tensors="pt",
57
- )
58
-
59
- text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0]
60
-
61
- uncond_input = self.tokenizer(
62
- "",
63
- padding="max_length",
64
- max_length=self.tokenizer.model_max_length,
65
- truncation=True,
66
- return_tensors="pt",
67
- )
68
-
69
- uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
70
-
71
- return torch.cat([uncond_embedding, text_embedding])
72
-
73
- def __call__(
74
- self,
75
- img: Image.Image,
76
- prompt: str,
77
- kmin: float = 0.3,
78
- kmax: float = 0.6,
79
- mix_factor: float = 0.5,
80
- seed: int = 42,
81
- steps: int = 50,
82
- guidance_scale: float = 7.5,
83
- ) -> Image.Image:
84
- tmin = steps - int(kmin * steps)
85
- tmax = steps - int(kmax * steps)
86
-
87
- text_embeddings = self.prep_text(prompt)
88
-
89
- self.scheduler.set_timesteps(steps)
90
-
91
- width, height = img.size
92
- encoded = self.encode(img)
93
-
94
- torch.manual_seed(seed)
95
- noise = torch.randn(
96
- (1, self.unet.in_channels, height // 8, width // 8),
97
- ).to(self.device)
98
-
99
- latents = self.scheduler.add_noise(
100
- encoded,
101
- noise,
102
- timesteps=self.scheduler.timesteps[tmax],
103
- )
104
-
105
- input = torch.cat([latents] * 2)
106
-
107
- input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax])
108
-
109
- with torch.no_grad():
110
- pred = self.unet(
111
- input,
112
- self.scheduler.timesteps[tmax],
113
- encoder_hidden_states=text_embeddings,
114
- ).sample
115
-
116
- pred_uncond, pred_text = pred.chunk(2)
117
- pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
118
-
119
- latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample
120
-
121
- for i, t in enumerate(tqdm(self.scheduler.timesteps)):
122
- if i > tmax:
123
- if i < tmin: # layout generation phase
124
- orig_latents = self.scheduler.add_noise(
125
- encoded,
126
- noise,
127
- timesteps=t,
128
- )
129
-
130
- input = (mix_factor * latents) + (
131
- 1 - mix_factor
132
- ) * orig_latents # interpolating between layout noise and conditionally generated noise to preserve layout semantics
133
- input = torch.cat([input] * 2)
134
-
135
- else: # content generation phase
136
- input = torch.cat([latents] * 2)
137
-
138
- input = self.scheduler.scale_model_input(input, t)
139
-
140
- with torch.no_grad():
141
- pred = self.unet(
142
- input,
143
- t,
144
- encoder_hidden_states=text_embeddings,
145
- ).sample
146
-
147
- pred_uncond, pred_text = pred.chunk(2)
148
- pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
149
-
150
- latents = self.scheduler.step(pred, t, latents).prev_sample
151
-
152
- return self.decode(latents)
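The MagicMix loop above noises the encoded input image to the `kmax` timestep, then, until the `kmin` timestep, keeps blending the denoised latents with a re-noised copy of the input (weighted by `mix_factor`) so the original layout survives while the prompt drives the content. A minimal usage sketch, assuming the file is still available as the "magic_mix" community pipeline; the model id and image path are placeholders:

from PIL import Image
from diffusers import DiffusionPipeline

# Hedged sketch: loads the community file above via custom_pipeline; ids and paths are placeholders.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="magic_mix",
).to("cuda")

layout_img = Image.open("phone.jpg").convert("RGB").resize((512, 512))
# __call__ returns a single PIL.Image, not a pipeline output object.
mixed = pipe(layout_img, prompt="bed", kmin=0.3, kmax=0.6, mix_factor=0.5, steps=50)
mixed.save("magic_mix.png")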
diffusers/examples/community/multilingual_stable_diffusion.py DELETED
@@ -1,436 +0,0 @@
1
- import inspect
2
- from typing import Callable, List, Optional, Union
3
-
4
- import torch
5
- from transformers import (
6
- CLIPImageProcessor,
7
- CLIPTextModel,
8
- CLIPTokenizer,
9
- MBart50TokenizerFast,
10
- MBartForConditionalGeneration,
11
- pipeline,
12
- )
13
-
14
- from diffusers import DiffusionPipeline
15
- from diffusers.configuration_utils import FrozenDict
16
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
17
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
18
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
19
- from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
20
- from diffusers.utils import deprecate, logging
21
-
22
-
23
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
-
25
-
26
- def detect_language(pipe, prompt, batch_size):
27
- """helper function to detect language(s) of prompt"""
28
-
29
- if batch_size == 1:
30
- preds = pipe(prompt, top_k=1, truncation=True, max_length=128)
31
- return preds[0]["label"]
32
- else:
33
- detected_languages = []
34
- for p in prompt:
35
- preds = pipe(p, top_k=1, truncation=True, max_length=128)
36
- detected_languages.append(preds[0]["label"])
37
-
38
- return detected_languages
39
-
40
-
41
- def translate_prompt(prompt, translation_tokenizer, translation_model, device):
42
- """helper function to translate prompt to English"""
43
-
44
- encoded_prompt = translation_tokenizer(prompt, return_tensors="pt").to(device)
45
- generated_tokens = translation_model.generate(**encoded_prompt, max_new_tokens=1000)
46
- en_trans = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
47
-
48
- return en_trans[0]
49
-
50
-
51
- class MultilingualStableDiffusion(DiffusionPipeline):
52
- r"""
53
- Pipeline for text-to-image generation using Stable Diffusion in different languages.
54
-
55
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
56
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
57
-
58
- Args:
59
- detection_pipeline ([`pipeline`]):
60
- Transformers pipeline to detect prompt's language.
61
- translation_model ([`MBartForConditionalGeneration`]):
62
- Model to translate prompt to English, if necessary. Please refer to the
63
- [model card](https://huggingface.co/docs/transformers/model_doc/mbart) for details.
64
- translation_tokenizer ([`MBart50TokenizerFast`]):
65
- Tokenizer of the translation model.
66
- vae ([`AutoencoderKL`]):
67
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
68
- text_encoder ([`CLIPTextModel`]):
69
- Frozen text-encoder. Stable Diffusion uses the text portion of
70
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
71
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
72
- tokenizer (`CLIPTokenizer`):
73
- Tokenizer of class
74
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
75
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
76
- scheduler ([`SchedulerMixin`]):
77
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
78
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
79
- safety_checker ([`StableDiffusionSafetyChecker`]):
80
- Classification module that estimates whether generated images could be considered offensive or harmful.
81
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
82
- feature_extractor ([`CLIPImageProcessor`]):
83
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
84
- """
85
-
86
- def __init__(
87
- self,
88
- detection_pipeline: pipeline,
89
- translation_model: MBartForConditionalGeneration,
90
- translation_tokenizer: MBart50TokenizerFast,
91
- vae: AutoencoderKL,
92
- text_encoder: CLIPTextModel,
93
- tokenizer: CLIPTokenizer,
94
- unet: UNet2DConditionModel,
95
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
96
- safety_checker: StableDiffusionSafetyChecker,
97
- feature_extractor: CLIPImageProcessor,
98
- ):
99
- super().__init__()
100
-
101
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
102
- deprecation_message = (
103
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
104
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
105
- "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
106
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
107
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
108
- " file"
109
- )
110
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
111
- new_config = dict(scheduler.config)
112
- new_config["steps_offset"] = 1
113
- scheduler._internal_dict = FrozenDict(new_config)
114
-
115
- if safety_checker is None:
116
- logger.warning(
117
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
118
- " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
119
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
120
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
121
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
122
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
123
- )
124
-
125
- self.register_modules(
126
- detection_pipeline=detection_pipeline,
127
- translation_model=translation_model,
128
- translation_tokenizer=translation_tokenizer,
129
- vae=vae,
130
- text_encoder=text_encoder,
131
- tokenizer=tokenizer,
132
- unet=unet,
133
- scheduler=scheduler,
134
- safety_checker=safety_checker,
135
- feature_extractor=feature_extractor,
136
- )
137
-
138
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
139
- r"""
140
- Enable sliced attention computation.
141
-
142
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
143
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
144
-
145
- Args:
146
- slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
147
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
148
- a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
149
- `attention_head_dim` must be a multiple of `slice_size`.
150
- """
151
- if slice_size == "auto":
152
- # half the attention head size is usually a good trade-off between
153
- # speed and memory
154
- slice_size = self.unet.config.attention_head_dim // 2
155
- self.unet.set_attention_slice(slice_size)
156
-
157
- def disable_attention_slicing(self):
158
- r"""
159
- Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
160
- back to computing attention in one step.
161
- """
162
- # set slice_size = `None` to disable `attention slicing`
163
- self.enable_attention_slicing(None)
164
-
165
- @torch.no_grad()
166
- def __call__(
167
- self,
168
- prompt: Union[str, List[str]],
169
- height: int = 512,
170
- width: int = 512,
171
- num_inference_steps: int = 50,
172
- guidance_scale: float = 7.5,
173
- negative_prompt: Optional[Union[str, List[str]]] = None,
174
- num_images_per_prompt: Optional[int] = 1,
175
- eta: float = 0.0,
176
- generator: Optional[torch.Generator] = None,
177
- latents: Optional[torch.FloatTensor] = None,
178
- output_type: Optional[str] = "pil",
179
- return_dict: bool = True,
180
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
181
- callback_steps: int = 1,
182
- **kwargs,
183
- ):
184
- r"""
185
- Function invoked when calling the pipeline for generation.
186
-
187
- Args:
188
- prompt (`str` or `List[str]`):
189
- The prompt or prompts to guide the image generation. Can be in different languages.
190
- height (`int`, *optional*, defaults to 512):
191
- The height in pixels of the generated image.
192
- width (`int`, *optional*, defaults to 512):
193
- The width in pixels of the generated image.
194
- num_inference_steps (`int`, *optional*, defaults to 50):
195
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
196
- expense of slower inference.
197
- guidance_scale (`float`, *optional*, defaults to 7.5):
198
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
199
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
200
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
201
- 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
202
- usually at the expense of lower image quality.
203
- negative_prompt (`str` or `List[str]`, *optional*):
204
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
205
- if `guidance_scale` is less than `1`).
206
- num_images_per_prompt (`int`, *optional*, defaults to 1):
207
- The number of images to generate per prompt.
208
- eta (`float`, *optional*, defaults to 0.0):
209
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
210
- [`schedulers.DDIMScheduler`], will be ignored for others.
211
- generator (`torch.Generator`, *optional*):
212
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
213
- deterministic.
214
- latents (`torch.FloatTensor`, *optional*):
215
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
216
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
217
- tensor will be generated by sampling using the supplied random `generator`.
218
- output_type (`str`, *optional*, defaults to `"pil"`):
219
- The output format of the generated image. Choose between
220
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
221
- return_dict (`bool`, *optional*, defaults to `True`):
222
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
223
- plain tuple.
224
- callback (`Callable`, *optional*):
225
- A function that will be called every `callback_steps` steps during inference. The function will be
226
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
227
- callback_steps (`int`, *optional*, defaults to 1):
228
- The frequency at which the `callback` function will be called. If not specified, the callback will be
229
- called at every step.
230
-
231
- Returns:
232
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
233
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
234
- When returning a tuple, the first element is a list with the generated images, and the second element is a
235
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
236
- (nsfw) content, according to the `safety_checker`.
237
- """
238
- if isinstance(prompt, str):
239
- batch_size = 1
240
- elif isinstance(prompt, list):
241
- batch_size = len(prompt)
242
- else:
243
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
244
-
245
- if height % 8 != 0 or width % 8 != 0:
246
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
247
-
248
- if (callback_steps is None) or (
249
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
250
- ):
251
- raise ValueError(
252
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
253
- f" {type(callback_steps)}."
254
- )
255
-
256
- # detect language and translate if necessary
257
- prompt_language = detect_language(self.detection_pipeline, prompt, batch_size)
258
- if batch_size == 1 and prompt_language != "en":
259
- prompt = translate_prompt(prompt, self.translation_tokenizer, self.translation_model, self.device)
260
-
261
- if isinstance(prompt, list):
262
- for index in range(batch_size):
263
- if prompt_language[index] != "en":
264
- p = translate_prompt(
265
- prompt[index], self.translation_tokenizer, self.translation_model, self.device
266
- )
267
- prompt[index] = p
268
-
269
- # get prompt text embeddings
270
- text_inputs = self.tokenizer(
271
- prompt,
272
- padding="max_length",
273
- max_length=self.tokenizer.model_max_length,
274
- return_tensors="pt",
275
- )
276
- text_input_ids = text_inputs.input_ids
277
-
278
- if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
279
- removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
280
- logger.warning(
281
- "The following part of your input was truncated because CLIP can only handle sequences up to"
282
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
283
- )
284
- text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
285
- text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
286
-
287
- # duplicate text embeddings for each generation per prompt, using mps friendly method
288
- bs_embed, seq_len, _ = text_embeddings.shape
289
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
290
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
291
-
292
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
293
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
294
- # corresponds to doing no classifier free guidance.
295
- do_classifier_free_guidance = guidance_scale > 1.0
296
- # get unconditional embeddings for classifier free guidance
297
- if do_classifier_free_guidance:
298
- uncond_tokens: List[str]
299
- if negative_prompt is None:
300
- uncond_tokens = [""] * batch_size
301
- elif type(prompt) is not type(negative_prompt):
302
- raise TypeError(
303
- f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
304
- f" {type(prompt)}."
305
- )
306
- elif isinstance(negative_prompt, str):
307
- # detect language and translate it if necessary
308
- negative_prompt_language = detect_language(self.detection_pipeline, negative_prompt, batch_size)
309
- if negative_prompt_language != "en":
310
- negative_prompt = translate_prompt(
311
- negative_prompt, self.translation_tokenizer, self.translation_model, self.device
312
- )
313
- if isinstance(negative_prompt, str):
314
- uncond_tokens = [negative_prompt]
315
- elif batch_size != len(negative_prompt):
316
- raise ValueError(
317
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
318
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
319
- " the batch size of `prompt`."
320
- )
321
- else:
322
- # detect language and translate it if necessary
323
- if isinstance(negative_prompt, list):
324
- negative_prompt_languages = detect_language(self.detection_pipeline, negative_prompt, batch_size)
325
- for index in range(batch_size):
326
- if negative_prompt_languages[index] != "en":
327
- p = translate_prompt(
328
- negative_prompt[index], self.translation_tokenizer, self.translation_model, self.device
329
- )
330
- negative_prompt[index] = p
331
- uncond_tokens = negative_prompt
332
-
333
- max_length = text_input_ids.shape[-1]
334
- uncond_input = self.tokenizer(
335
- uncond_tokens,
336
- padding="max_length",
337
- max_length=max_length,
338
- truncation=True,
339
- return_tensors="pt",
340
- )
341
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
342
-
343
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
344
- seq_len = uncond_embeddings.shape[1]
345
- uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
346
- uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
347
-
348
- # For classifier free guidance, we need to do two forward passes.
349
- # Here we concatenate the unconditional and text embeddings into a single batch
350
- # to avoid doing two forward passes
351
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
352
-
353
- # get the initial random noise unless the user supplied it
354
-
355
- # Unlike in other pipelines, latents need to be generated in the target device
356
- # for 1-to-1 results reproducibility with the CompVis implementation.
357
- # However this currently doesn't work in `mps`.
358
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
359
- latents_dtype = text_embeddings.dtype
360
- if latents is None:
361
- if self.device.type == "mps":
362
- # randn does not work reproducibly on mps
363
- latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
364
- self.device
365
- )
366
- else:
367
- latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
368
- else:
369
- if latents.shape != latents_shape:
370
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
371
- latents = latents.to(self.device)
372
-
373
- # set timesteps
374
- self.scheduler.set_timesteps(num_inference_steps)
375
-
376
- # Some schedulers like PNDM have timesteps as arrays
377
- # It's more optimized to move all timesteps to correct device beforehand
378
- timesteps_tensor = self.scheduler.timesteps.to(self.device)
379
-
380
- # scale the initial noise by the standard deviation required by the scheduler
381
- latents = latents * self.scheduler.init_noise_sigma
382
-
383
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
384
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
385
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
386
- # and should be between [0, 1]
387
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
388
- extra_step_kwargs = {}
389
- if accepts_eta:
390
- extra_step_kwargs["eta"] = eta
391
-
392
- for i, t in enumerate(self.progress_bar(timesteps_tensor)):
393
- # expand the latents if we are doing classifier free guidance
394
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
395
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
396
-
397
- # predict the noise residual
398
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
399
-
400
- # perform guidance
401
- if do_classifier_free_guidance:
402
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
403
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
404
-
405
- # compute the previous noisy sample x_t -> x_t-1
406
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
407
-
408
- # call the callback, if provided
409
- if callback is not None and i % callback_steps == 0:
410
- callback(i, t, latents)
411
-
412
- latents = 1 / 0.18215 * latents
413
- image = self.vae.decode(latents).sample
414
-
415
- image = (image / 2 + 0.5).clamp(0, 1)
416
-
417
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
418
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
419
-
420
- if self.safety_checker is not None:
421
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
422
- self.device
423
- )
424
- image, has_nsfw_concept = self.safety_checker(
425
- images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
426
- )
427
- else:
428
- has_nsfw_concept = None
429
-
430
- if output_type == "pil":
431
- image = self.numpy_to_pil(image)
432
-
433
- if not return_dict:
434
- return (image, has_nsfw_concept)
435
-
436
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
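The pipeline above detects the prompt language, translates non-English prompts to English with mBART, and then runs the standard Stable Diffusion denoising loop. A minimal usage sketch, assuming the file is still published as the "multilingual_stable_diffusion" community pipeline; all model ids below are assumptions, not values recorded in this diff:

from transformers import MBart50TokenizerFast, MBartForConditionalGeneration, pipeline
from diffusers import DiffusionPipeline

# Hedged sketch: language-detection and translation checkpoints are assumptions/placeholders.
detector = pipeline("text-classification", model="papluca/xlm-roberta-base-language-detection")
trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="multilingual_stable_diffusion",
    detection_pipeline=detector,
    translation_model=trans_model,
    translation_tokenizer=trans_tokenizer,
).to("cuda")

# French prompt: "An astronaut riding a horse on Mars"
image = pipe("Une astronaute à cheval sur Mars").images[0]
image.save("multilingual.png")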
diffusers/examples/community/one_step_unet.py DELETED
@@ -1,24 +0,0 @@
1
- #!/usr/bin/env python3
2
- import torch
3
-
4
- from diffusers import DiffusionPipeline
5
-
6
-
7
- class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
8
- def __init__(self, unet, scheduler):
9
- super().__init__()
10
-
11
- self.register_modules(unet=unet, scheduler=scheduler)
12
-
13
- def __call__(self):
14
- image = torch.randn(
15
- (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
16
- )
17
- timestep = 1
18
-
19
- model_output = self.unet(image, timestep).sample
20
- scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
21
-
22
- result = scheduler_output - scheduler_output + torch.ones_like(scheduler_output)
23
-
24
- return result
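UnetSchedulerOneForwardPipeline is a smoke-test pipeline: it runs one UNet forward pass and one scheduler step on random noise, then returns a tensor of ones with the same shape. A minimal sketch of exercising it as a custom pipeline (the checkpoint id is an assumption):

from diffusers import DiffusionPipeline

# Hedged sketch: any small unconditional UNet checkpoint should work; the id is a placeholder.
pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet")
output = pipe()  # tensor of ones with shape (1, C, sample_size, sample_size)
print(output.shape)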
diffusers/examples/community/sd_text2img_k_diffusion.py DELETED
@@ -1,475 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import importlib
16
- import warnings
17
- from typing import Callable, List, Optional, Union
18
-
19
- import torch
20
- from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
21
-
22
- from diffusers import DiffusionPipeline, LMSDiscreteScheduler
23
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
24
- from diffusers.utils import is_accelerate_available, logging
25
-
26
-
27
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
-
29
-
30
- class ModelWrapper:
31
- def __init__(self, model, alphas_cumprod):
32
- self.model = model
33
- self.alphas_cumprod = alphas_cumprod
34
-
35
- def apply_model(self, *args, **kwargs):
36
- if len(args) == 3:
37
- encoder_hidden_states = args[-1]
38
- args = args[:2]
39
- if kwargs.get("cond", None) is not None:
40
- encoder_hidden_states = kwargs.pop("cond")
41
- return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
42
-
43
-
44
- class StableDiffusionPipeline(DiffusionPipeline):
45
- r"""
46
- Pipeline for text-to-image generation using Stable Diffusion.
47
-
48
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
49
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
50
-
51
- Args:
52
- vae ([`AutoencoderKL`]):
53
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
54
- text_encoder ([`CLIPTextModel`]):
55
- Frozen text-encoder. Stable Diffusion uses the text portion of
56
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
57
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
58
- tokenizer (`CLIPTokenizer`):
59
- Tokenizer of class
60
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
61
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
62
- scheduler ([`SchedulerMixin`]):
63
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
64
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
65
- safety_checker ([`StableDiffusionSafetyChecker`]):
66
- Classification module that estimates whether generated images could be considered offensive or harmful.
67
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
68
- feature_extractor ([`CLIPImageProcessor`]):
69
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
70
- """
71
- _optional_components = ["safety_checker", "feature_extractor"]
72
-
73
- def __init__(
74
- self,
75
- vae,
76
- text_encoder,
77
- tokenizer,
78
- unet,
79
- scheduler,
80
- safety_checker,
81
- feature_extractor,
82
- ):
83
- super().__init__()
84
-
85
- if safety_checker is None:
86
- logger.warning(
87
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
88
- " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
89
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
90
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
91
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
92
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
93
- )
94
-
95
- # get correct sigmas from LMS
96
- scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
97
- self.register_modules(
98
- vae=vae,
99
- text_encoder=text_encoder,
100
- tokenizer=tokenizer,
101
- unet=unet,
102
- scheduler=scheduler,
103
- safety_checker=safety_checker,
104
- feature_extractor=feature_extractor,
105
- )
106
-
107
- model = ModelWrapper(unet, scheduler.alphas_cumprod)
108
- if scheduler.prediction_type == "v_prediction":
109
- self.k_diffusion_model = CompVisVDenoiser(model)
110
- else:
111
- self.k_diffusion_model = CompVisDenoiser(model)
112
-
113
- def set_sampler(self, scheduler_type: str):
114
- warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")
115
- return self.set_scheduler(scheduler_type)
116
-
117
- def set_scheduler(self, scheduler_type: str):
118
- library = importlib.import_module("k_diffusion")
119
- sampling = getattr(library, "sampling")
120
- self.sampler = getattr(sampling, scheduler_type)
121
-
122
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
123
- r"""
124
- Enable sliced attention computation.
125
-
126
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
127
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
128
-
129
- Args:
130
- slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
131
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
132
- a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
133
- `attention_head_dim` must be a multiple of `slice_size`.
134
- """
135
- if slice_size == "auto":
136
- # half the attention head size is usually a good trade-off between
137
- # speed and memory
138
- slice_size = self.unet.config.attention_head_dim // 2
139
- self.unet.set_attention_slice(slice_size)
140
-
141
- def disable_attention_slicing(self):
142
- r"""
143
- Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
144
- back to computing attention in one step.
145
- """
146
- # set slice_size = `None` to disable `attention slicing`
147
- self.enable_attention_slicing(None)
148
-
149
- def enable_sequential_cpu_offload(self, gpu_id=0):
150
- r"""
151
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
152
- text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
153
- `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
154
- """
155
- if is_accelerate_available():
156
- from accelerate import cpu_offload
157
- else:
158
- raise ImportError("Please install accelerate via `pip install accelerate`")
159
-
160
- device = torch.device(f"cuda:{gpu_id}")
161
-
162
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
163
- if cpu_offloaded_model is not None:
164
- cpu_offload(cpu_offloaded_model, device)
165
-
166
- @property
167
- def _execution_device(self):
168
- r"""
169
- Returns the device on which the pipeline's models will be executed. After calling
170
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
171
- hooks.
172
- """
173
- if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
174
- return self.device
175
- for module in self.unet.modules():
176
- if (
177
- hasattr(module, "_hf_hook")
178
- and hasattr(module._hf_hook, "execution_device")
179
- and module._hf_hook.execution_device is not None
180
- ):
181
- return torch.device(module._hf_hook.execution_device)
182
- return self.device
183
-
184
- def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
185
- r"""
186
- Encodes the prompt into text encoder hidden states.
187
-
188
- Args:
189
- prompt (`str` or `list(int)`):
190
- prompt to be encoded
191
- device: (`torch.device`):
192
- torch device
193
- num_images_per_prompt (`int`):
194
- number of images that should be generated per prompt
195
- do_classifier_free_guidance (`bool`):
196
- whether to use classifier free guidance or not
197
- negative_prompt (`str` or `List[str]`):
198
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
199
- if `guidance_scale` is less than `1`).
200
- """
201
- batch_size = len(prompt) if isinstance(prompt, list) else 1
202
-
203
- text_inputs = self.tokenizer(
204
- prompt,
205
- padding="max_length",
206
- max_length=self.tokenizer.model_max_length,
207
- truncation=True,
208
- return_tensors="pt",
209
- )
210
- text_input_ids = text_inputs.input_ids
211
- untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
212
-
213
- if not torch.equal(text_input_ids, untruncated_ids):
214
- removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
215
- logger.warning(
216
- "The following part of your input was truncated because CLIP can only handle sequences up to"
217
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
218
- )
219
-
220
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
221
- attention_mask = text_inputs.attention_mask.to(device)
222
- else:
223
- attention_mask = None
224
-
225
- text_embeddings = self.text_encoder(
226
- text_input_ids.to(device),
227
- attention_mask=attention_mask,
228
- )
229
- text_embeddings = text_embeddings[0]
230
-
231
- # duplicate text embeddings for each generation per prompt, using mps friendly method
232
- bs_embed, seq_len, _ = text_embeddings.shape
233
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
234
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
235
-
236
- # get unconditional embeddings for classifier free guidance
237
- if do_classifier_free_guidance:
238
- uncond_tokens: List[str]
239
- if negative_prompt is None:
240
- uncond_tokens = [""] * batch_size
241
- elif type(prompt) is not type(negative_prompt):
242
- raise TypeError(
243
- f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
244
- f" {type(prompt)}."
245
- )
246
- elif isinstance(negative_prompt, str):
247
- uncond_tokens = [negative_prompt]
248
- elif batch_size != len(negative_prompt):
249
- raise ValueError(
250
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
251
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
252
- " the batch size of `prompt`."
253
- )
254
- else:
255
- uncond_tokens = negative_prompt
256
-
257
- max_length = text_input_ids.shape[-1]
258
- uncond_input = self.tokenizer(
259
- uncond_tokens,
260
- padding="max_length",
261
- max_length=max_length,
262
- truncation=True,
263
- return_tensors="pt",
264
- )
265
-
266
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
267
- attention_mask = uncond_input.attention_mask.to(device)
268
- else:
269
- attention_mask = None
270
-
271
- uncond_embeddings = self.text_encoder(
272
- uncond_input.input_ids.to(device),
273
- attention_mask=attention_mask,
274
- )
275
- uncond_embeddings = uncond_embeddings[0]
276
-
277
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
278
- seq_len = uncond_embeddings.shape[1]
279
- uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
280
- uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
281
-
282
- # For classifier free guidance, we need to do two forward passes.
283
- # Here we concatenate the unconditional and text embeddings into a single batch
284
- # to avoid doing two forward passes
285
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
286
-
287
- return text_embeddings
288
-
289
- def run_safety_checker(self, image, device, dtype):
290
- if self.safety_checker is not None:
291
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
292
- image, has_nsfw_concept = self.safety_checker(
293
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
294
- )
295
- else:
296
- has_nsfw_concept = None
297
- return image, has_nsfw_concept
298
-
299
- def decode_latents(self, latents):
300
- latents = 1 / 0.18215 * latents
301
- image = self.vae.decode(latents).sample
302
- image = (image / 2 + 0.5).clamp(0, 1)
303
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
304
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
305
- return image
306
-
307
- def check_inputs(self, prompt, height, width, callback_steps):
308
- if not isinstance(prompt, str) and not isinstance(prompt, list):
309
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
310
-
311
- if height % 8 != 0 or width % 8 != 0:
312
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
313
-
314
- if (callback_steps is None) or (
315
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
316
- ):
317
- raise ValueError(
318
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
319
- f" {type(callback_steps)}."
320
- )
321
-
322
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
323
- shape = (batch_size, num_channels_latents, height // 8, width // 8)
324
- if latents is None:
325
- if device.type == "mps":
326
- # randn does not work reproducibly on mps
327
- latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
328
- else:
329
- latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
330
- else:
331
- if latents.shape != shape:
332
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
333
- latents = latents.to(device)
334
-
335
- # scale the initial noise by the standard deviation required by the scheduler
336
- return latents
337
-
338
- @torch.no_grad()
339
- def __call__(
340
- self,
341
- prompt: Union[str, List[str]],
342
- height: int = 512,
343
- width: int = 512,
344
- num_inference_steps: int = 50,
345
- guidance_scale: float = 7.5,
346
- negative_prompt: Optional[Union[str, List[str]]] = None,
347
- num_images_per_prompt: Optional[int] = 1,
348
- eta: float = 0.0,
349
- generator: Optional[torch.Generator] = None,
350
- latents: Optional[torch.FloatTensor] = None,
351
- output_type: Optional[str] = "pil",
352
- return_dict: bool = True,
353
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
354
- callback_steps: int = 1,
355
- **kwargs,
356
- ):
357
- r"""
358
- Function invoked when calling the pipeline for generation.
359
-
360
- Args:
361
- prompt (`str` or `List[str]`):
362
- The prompt or prompts to guide the image generation.
363
- height (`int`, *optional*, defaults to 512):
364
- The height in pixels of the generated image.
365
- width (`int`, *optional*, defaults to 512):
366
- The width in pixels of the generated image.
367
- num_inference_steps (`int`, *optional*, defaults to 50):
368
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
369
- expense of slower inference.
370
- guidance_scale (`float`, *optional*, defaults to 7.5):
371
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
372
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
373
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
374
- 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
375
- usually at the expense of lower image quality.
376
- negative_prompt (`str` or `List[str]`, *optional*):
377
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
378
- if `guidance_scale` is less than `1`).
379
- num_images_per_prompt (`int`, *optional*, defaults to 1):
380
- The number of images to generate per prompt.
381
- eta (`float`, *optional*, defaults to 0.0):
382
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
383
- [`schedulers.DDIMScheduler`], will be ignored for others.
384
- generator (`torch.Generator`, *optional*):
385
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
386
- deterministic.
387
- latents (`torch.FloatTensor`, *optional*):
388
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
389
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
390
- tensor will be generated by sampling using the supplied random `generator`.
391
- output_type (`str`, *optional*, defaults to `"pil"`):
392
- The output format of the generated image. Choose between
393
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
394
- return_dict (`bool`, *optional*, defaults to `True`):
395
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
396
- plain tuple.
397
- callback (`Callable`, *optional*):
398
- A function that will be called every `callback_steps` steps during inference. The function will be
399
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
400
- callback_steps (`int`, *optional*, defaults to 1):
401
- The frequency at which the `callback` function will be called. If not specified, the callback will be
402
- called at every step.
403
-
404
- Returns:
405
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
406
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
407
- When returning a tuple, the first element is a list with the generated images, and the second element is a
408
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
409
- (nsfw) content, according to the `safety_checker`.
410
- """
411
-
412
- # 1. Check inputs. Raise error if not correct
413
- self.check_inputs(prompt, height, width, callback_steps)
414
-
415
- # 2. Define call parameters
416
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
417
- device = self._execution_device
418
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
419
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
420
- # corresponds to doing no classifier free guidance.
421
- do_classifier_free_guidance = True
422
- if guidance_scale <= 1.0:
423
- raise ValueError("has to use guidance_scale")
424
-
425
- # 3. Encode input prompt
426
- text_embeddings = self._encode_prompt(
427
- prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
428
- )
429
-
430
- # 4. Prepare timesteps
431
- self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
432
- sigmas = self.scheduler.sigmas
433
- sigmas = sigmas.to(text_embeddings.dtype)
434
-
435
- # 5. Prepare latent variables
436
- num_channels_latents = self.unet.in_channels
437
- latents = self.prepare_latents(
438
- batch_size * num_images_per_prompt,
439
- num_channels_latents,
440
- height,
441
- width,
442
- text_embeddings.dtype,
443
- device,
444
- generator,
445
- latents,
446
- )
447
- latents = latents * sigmas[0]
448
- self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
449
- self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
450
-
451
- def model_fn(x, t):
452
- latent_model_input = torch.cat([x] * 2)
453
-
454
- noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)
455
-
456
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
457
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
458
- return noise_pred
459
-
460
- latents = self.sampler(model_fn, latents, sigmas)
461
-
462
- # 8. Post-processing
463
- image = self.decode_latents(latents)
464
-
465
- # 9. Run safety checker
466
- image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
467
-
468
- # 10. Convert to PIL
469
- if output_type == "pil":
470
- image = self.numpy_to_pil(image)
471
-
472
- if not return_dict:
473
- return (image, has_nsfw_concept)
474
-
475
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
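This pipeline swaps the diffusers scheduler loop for a k-diffusion sampler: the UNet is wrapped in a CompVis-style denoiser and `set_scheduler` picks any function from `k_diffusion.sampling` by name. A minimal usage sketch, assuming the file is still published as the "sd_text2img_k_diffusion" community pipeline and the `k-diffusion` package is installed; the model id and sampler name are illustrative:

import torch
from diffusers import DiffusionPipeline

# Hedged sketch: ids and sampler name are assumptions, not values recorded in this diff.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="sd_text2img_k_diffusion",
    torch_dtype=torch.float16,
).to("cuda")
pipe.set_scheduler("sample_heun")  # any callable exposed by k_diffusion.sampling
image = pipe("an astronaut riding a horse on mars", num_inference_steps=25).images[0]
image.save("k_diffusion_heun.png")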
diffusers/examples/community/seed_resize_stable_diffusion.py DELETED
@@ -1,366 +0,0 @@
1
- """
2
- modified based on diffusion library from Huggingface: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
3
- """
4
- import inspect
5
- from typing import Callable, List, Optional, Union
6
-
7
- import torch
8
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
9
-
10
- from diffusers import DiffusionPipeline
11
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
12
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
13
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
14
- from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
15
- from diffusers.utils import logging
16
-
17
-
18
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
-
20
-
21
- class SeedResizeStableDiffusionPipeline(DiffusionPipeline):
22
- r"""
23
- Pipeline for text-to-image generation using Stable Diffusion.
24
-
25
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
26
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
27
-
28
- Args:
29
- vae ([`AutoencoderKL`]):
30
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
31
- text_encoder ([`CLIPTextModel`]):
32
- Frozen text-encoder. Stable Diffusion uses the text portion of
33
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
34
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
35
- tokenizer (`CLIPTokenizer`):
36
- Tokenizer of class
37
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
38
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
39
- scheduler ([`SchedulerMixin`]):
40
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
41
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
42
- safety_checker ([`StableDiffusionSafetyChecker`]):
43
- Classification module that estimates whether generated images could be considered offensive or harmful.
44
- Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
45
- feature_extractor ([`CLIPImageProcessor`]):
46
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
47
- """
48
-
49
- def __init__(
50
- self,
51
- vae: AutoencoderKL,
52
- text_encoder: CLIPTextModel,
53
- tokenizer: CLIPTokenizer,
54
- unet: UNet2DConditionModel,
55
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
56
- safety_checker: StableDiffusionSafetyChecker,
57
- feature_extractor: CLIPImageProcessor,
58
- ):
59
- super().__init__()
60
- self.register_modules(
61
- vae=vae,
62
- text_encoder=text_encoder,
63
- tokenizer=tokenizer,
64
- unet=unet,
65
- scheduler=scheduler,
66
- safety_checker=safety_checker,
67
- feature_extractor=feature_extractor,
68
- )
69
-
70
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
71
- r"""
72
- Enable sliced attention computation.
73
-
74
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
75
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
76
-
77
- Args:
78
- slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
79
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
80
- a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
81
- `attention_head_dim` must be a multiple of `slice_size`.
82
- """
83
- if slice_size == "auto":
84
- # half the attention head size is usually a good trade-off between
85
- # speed and memory
86
- slice_size = self.unet.config.attention_head_dim // 2
87
- self.unet.set_attention_slice(slice_size)
88
-
89
- def disable_attention_slicing(self):
90
- r"""
91
- Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
92
- back to computing attention in one step.
93
- """
94
- # set slice_size = `None` to disable `attention slicing`
95
- self.enable_attention_slicing(None)
96
-
97
- @torch.no_grad()
98
- def __call__(
99
- self,
100
- prompt: Union[str, List[str]],
101
- height: int = 512,
102
- width: int = 512,
103
- num_inference_steps: int = 50,
104
- guidance_scale: float = 7.5,
105
- negative_prompt: Optional[Union[str, List[str]]] = None,
106
- num_images_per_prompt: Optional[int] = 1,
107
- eta: float = 0.0,
108
- generator: Optional[torch.Generator] = None,
109
- latents: Optional[torch.FloatTensor] = None,
110
- output_type: Optional[str] = "pil",
111
- return_dict: bool = True,
112
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
113
- callback_steps: int = 1,
114
- text_embeddings: Optional[torch.FloatTensor] = None,
115
- **kwargs,
116
- ):
117
- r"""
118
- Function invoked when calling the pipeline for generation.
119
-
120
- Args:
121
- prompt (`str` or `List[str]`):
122
- The prompt or prompts to guide the image generation.
123
- height (`int`, *optional*, defaults to 512):
124
- The height in pixels of the generated image.
125
- width (`int`, *optional*, defaults to 512):
126
- The width in pixels of the generated image.
127
- num_inference_steps (`int`, *optional*, defaults to 50):
128
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
129
- expense of slower inference.
130
- guidance_scale (`float`, *optional*, defaults to 7.5):
131
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
132
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
133
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
134
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
135
- usually at the expense of lower image quality.
136
- negative_prompt (`str` or `List[str]`, *optional*):
137
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
138
- if `guidance_scale` is less than `1`).
139
- num_images_per_prompt (`int`, *optional*, defaults to 1):
140
- The number of images to generate per prompt.
141
- eta (`float`, *optional*, defaults to 0.0):
142
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
143
- [`schedulers.DDIMScheduler`], will be ignored for others.
144
- generator (`torch.Generator`, *optional*):
145
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
146
- deterministic.
147
- latents (`torch.FloatTensor`, *optional*):
148
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
149
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
150
-                 tensor will be generated by sampling using the supplied random `generator`.
151
- output_type (`str`, *optional*, defaults to `"pil"`):
152
- The output format of the generate image. Choose between
153
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
154
- return_dict (`bool`, *optional*, defaults to `True`):
155
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
156
- plain tuple.
157
- callback (`Callable`, *optional*):
158
- A function that will be called every `callback_steps` steps during inference. The function will be
159
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
160
- callback_steps (`int`, *optional*, defaults to 1):
161
- The frequency at which the `callback` function will be called. If not specified, the callback will be
162
- called at every step.
163
-
164
- Returns:
165
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
166
-             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
167
- When returning a tuple, the first element is a list with the generated images, and the second element is a
168
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
169
- (nsfw) content, according to the `safety_checker`.
170
- """
171
-
172
- if isinstance(prompt, str):
173
- batch_size = 1
174
- elif isinstance(prompt, list):
175
- batch_size = len(prompt)
176
- else:
177
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
178
-
179
- if height % 8 != 0 or width % 8 != 0:
180
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
181
-
182
- if (callback_steps is None) or (
183
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
184
- ):
185
- raise ValueError(
186
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
187
- f" {type(callback_steps)}."
188
- )
189
-
190
- # get prompt text embeddings
191
- text_inputs = self.tokenizer(
192
- prompt,
193
- padding="max_length",
194
- max_length=self.tokenizer.model_max_length,
195
- return_tensors="pt",
196
- )
197
- text_input_ids = text_inputs.input_ids
198
-
199
- if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
200
- removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
201
- logger.warning(
202
- "The following part of your input was truncated because CLIP can only handle sequences up to"
203
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
204
- )
205
- text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
206
-
207
- if text_embeddings is None:
208
- text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
209
-
210
- # duplicate text embeddings for each generation per prompt, using mps friendly method
211
- bs_embed, seq_len, _ = text_embeddings.shape
212
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
213
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
214
-
215
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
216
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
217
- # corresponds to doing no classifier free guidance.
218
- do_classifier_free_guidance = guidance_scale > 1.0
219
- # get unconditional embeddings for classifier free guidance
220
- if do_classifier_free_guidance:
221
- uncond_tokens: List[str]
222
- if negative_prompt is None:
223
- uncond_tokens = [""]
224
- elif type(prompt) is not type(negative_prompt):
225
- raise TypeError(
226
-                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
227
- f" {type(prompt)}."
228
- )
229
- elif isinstance(negative_prompt, str):
230
- uncond_tokens = [negative_prompt]
231
- elif batch_size != len(negative_prompt):
232
- raise ValueError(
233
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
234
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
235
- " the batch size of `prompt`."
236
- )
237
- else:
238
- uncond_tokens = negative_prompt
239
-
240
- max_length = text_input_ids.shape[-1]
241
- uncond_input = self.tokenizer(
242
- uncond_tokens,
243
- padding="max_length",
244
- max_length=max_length,
245
- truncation=True,
246
- return_tensors="pt",
247
- )
248
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
249
-
250
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
251
- seq_len = uncond_embeddings.shape[1]
252
- uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
253
- uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
254
-
255
- # For classifier free guidance, we need to do two forward passes.
256
- # Here we concatenate the unconditional and text embeddings into a single batch
257
- # to avoid doing two forward passes
258
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
259
-
260
- # get the initial random noise unless the user supplied it
261
-
262
- # Unlike in other pipelines, latents need to be generated in the target device
263
- # for 1-to-1 results reproducibility with the CompVis implementation.
264
- # However this currently doesn't work in `mps`.
265
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
266
- latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.in_channels, 64, 64)
267
- latents_dtype = text_embeddings.dtype
268
- if latents is None:
269
- if self.device.type == "mps":
270
- # randn does not exist on mps
271
- latents_reference = torch.randn(
272
- latents_shape_reference, generator=generator, device="cpu", dtype=latents_dtype
273
- ).to(self.device)
274
- latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
275
- self.device
276
- )
277
- else:
278
- latents_reference = torch.randn(
279
- latents_shape_reference, generator=generator, device=self.device, dtype=latents_dtype
280
- )
281
- latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
282
- else:
283
- if latents_reference.shape != latents_shape:
284
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
285
- latents_reference = latents_reference.to(self.device)
286
- latents = latents.to(self.device)
287
-
288
- # This is the key part of the pipeline where we
289
- # try to ensure that the generated images w/ the same seed
290
- # but different sizes actually result in similar images
291
- dx = (latents_shape[3] - latents_shape_reference[3]) // 2
292
- dy = (latents_shape[2] - latents_shape_reference[2]) // 2
293
- w = latents_shape_reference[3] if dx >= 0 else latents_shape_reference[3] + 2 * dx
294
- h = latents_shape_reference[2] if dy >= 0 else latents_shape_reference[2] + 2 * dy
295
- tx = 0 if dx < 0 else dx
296
- ty = 0 if dy < 0 else dy
297
- dx = max(-dx, 0)
298
- dy = max(-dy, 0)
299
- # import pdb
300
- # pdb.set_trace()
301
- latents[:, :, ty : ty + h, tx : tx + w] = latents_reference[:, :, dy : dy + h, dx : dx + w]
302
-
303
- # set timesteps
304
- self.scheduler.set_timesteps(num_inference_steps)
305
-
306
- # Some schedulers like PNDM have timesteps as arrays
307
- # It's more optimized to move all timesteps to correct device beforehand
308
- timesteps_tensor = self.scheduler.timesteps.to(self.device)
309
-
310
- # scale the initial noise by the standard deviation required by the scheduler
311
- latents = latents * self.scheduler.init_noise_sigma
312
-
313
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
314
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
315
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
316
- # and should be between [0, 1]
317
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
318
- extra_step_kwargs = {}
319
- if accepts_eta:
320
- extra_step_kwargs["eta"] = eta
321
-
322
- for i, t in enumerate(self.progress_bar(timesteps_tensor)):
323
- # expand the latents if we are doing classifier free guidance
324
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
325
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
326
-
327
- # predict the noise residual
328
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
329
-
330
- # perform guidance
331
- if do_classifier_free_guidance:
332
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
333
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
334
-
335
- # compute the previous noisy sample x_t -> x_t-1
336
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
337
-
338
- # call the callback, if provided
339
- if callback is not None and i % callback_steps == 0:
340
- callback(i, t, latents)
341
-
342
- latents = 1 / 0.18215 * latents
343
- image = self.vae.decode(latents).sample
344
-
345
- image = (image / 2 + 0.5).clamp(0, 1)
346
-
347
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
348
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
349
-
350
- if self.safety_checker is not None:
351
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
352
- self.device
353
- )
354
- image, has_nsfw_concept = self.safety_checker(
355
- images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
356
- )
357
- else:
358
- has_nsfw_concept = None
359
-
360
- if output_type == "pil":
361
- image = self.numpy_to_pil(image)
362
-
363
- if not return_dict:
364
- return (image, has_nsfw_concept)
365
-
366
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
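For reference, community pipelines such as this one were loaded through the `custom_pipeline` argument of `DiffusionPipeline.from_pretrained`. A hedged usage sketch (the pipeline name and model id are assumptions based on the file name, a CUDA GPU is assumed, and with the file deleted this only resolves against a diffusers revision that still ships it):

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="seed_resize_stable_diffusion",  # old community file name, assumed
    torch_dtype=torch.float16,
).to("cuda")

prompt = "a photograph of an astronaut riding a horse"

# Same seed at the base 512x512 resolution...
generator = torch.Generator(device="cuda").manual_seed(0)
base = pipe(prompt, height=512, width=512, generator=generator).images[0]

# ...and again at a wider resolution: the pipeline copies a shared 64x64
# reference latent into the centre of the larger latent, so the composition
# stays close to the 512x512 result.
generator = torch.Generator(device="cuda").manual_seed(0)
wide = pipe(prompt, height=512, width=768, generator=generator).images[0]
```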
 
diffusers/examples/community/speech_to_image_diffusion.py DELETED
@@ -1,261 +0,0 @@
1
- import inspect
2
- from typing import Callable, List, Optional, Union
3
-
4
- import torch
5
- from transformers import (
6
- CLIPImageProcessor,
7
- CLIPTextModel,
8
- CLIPTokenizer,
9
- WhisperForConditionalGeneration,
10
- WhisperProcessor,
11
- )
12
-
13
- from diffusers import (
14
- AutoencoderKL,
15
- DDIMScheduler,
16
- DiffusionPipeline,
17
- LMSDiscreteScheduler,
18
- PNDMScheduler,
19
- UNet2DConditionModel,
20
- )
21
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
22
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
23
- from diffusers.utils import logging
24
-
25
-
26
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
-
28
-
29
- class SpeechToImagePipeline(DiffusionPipeline):
30
- def __init__(
31
- self,
32
- speech_model: WhisperForConditionalGeneration,
33
- speech_processor: WhisperProcessor,
34
- vae: AutoencoderKL,
35
- text_encoder: CLIPTextModel,
36
- tokenizer: CLIPTokenizer,
37
- unet: UNet2DConditionModel,
38
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
39
- safety_checker: StableDiffusionSafetyChecker,
40
- feature_extractor: CLIPImageProcessor,
41
- ):
42
- super().__init__()
43
-
44
- if safety_checker is None:
45
- logger.warning(
46
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
47
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
48
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
49
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
50
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
51
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
52
- )
53
-
54
- self.register_modules(
55
- speech_model=speech_model,
56
- speech_processor=speech_processor,
57
- vae=vae,
58
- text_encoder=text_encoder,
59
- tokenizer=tokenizer,
60
- unet=unet,
61
- scheduler=scheduler,
62
- feature_extractor=feature_extractor,
63
- )
64
-
65
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
66
- if slice_size == "auto":
67
- slice_size = self.unet.config.attention_head_dim // 2
68
- self.unet.set_attention_slice(slice_size)
69
-
70
- def disable_attention_slicing(self):
71
- self.enable_attention_slicing(None)
72
-
73
- @torch.no_grad()
74
- def __call__(
75
- self,
76
- audio,
77
- sampling_rate=16_000,
78
- height: int = 512,
79
- width: int = 512,
80
- num_inference_steps: int = 50,
81
- guidance_scale: float = 7.5,
82
- negative_prompt: Optional[Union[str, List[str]]] = None,
83
- num_images_per_prompt: Optional[int] = 1,
84
- eta: float = 0.0,
85
- generator: Optional[torch.Generator] = None,
86
- latents: Optional[torch.FloatTensor] = None,
87
- output_type: Optional[str] = "pil",
88
- return_dict: bool = True,
89
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
90
- callback_steps: int = 1,
91
- **kwargs,
92
- ):
93
- inputs = self.speech_processor.feature_extractor(
94
- audio, return_tensors="pt", sampling_rate=sampling_rate
95
- ).input_features.to(self.device)
96
- predicted_ids = self.speech_model.generate(inputs, max_length=480_000)
97
-
98
- prompt = self.speech_processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[
99
- 0
100
- ]
101
-
102
- if isinstance(prompt, str):
103
- batch_size = 1
104
- elif isinstance(prompt, list):
105
- batch_size = len(prompt)
106
- else:
107
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
108
-
109
- if height % 8 != 0 or width % 8 != 0:
110
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
111
-
112
- if (callback_steps is None) or (
113
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
114
- ):
115
- raise ValueError(
116
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
117
- f" {type(callback_steps)}."
118
- )
119
-
120
- # get prompt text embeddings
121
- text_inputs = self.tokenizer(
122
- prompt,
123
- padding="max_length",
124
- max_length=self.tokenizer.model_max_length,
125
- return_tensors="pt",
126
- )
127
- text_input_ids = text_inputs.input_ids
128
-
129
- if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
130
- removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
131
- logger.warning(
132
- "The following part of your input was truncated because CLIP can only handle sequences up to"
133
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
134
- )
135
- text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
136
- text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
137
-
138
- # duplicate text embeddings for each generation per prompt, using mps friendly method
139
- bs_embed, seq_len, _ = text_embeddings.shape
140
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
141
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
142
-
143
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
144
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
145
- # corresponds to doing no classifier free guidance.
146
- do_classifier_free_guidance = guidance_scale > 1.0
147
- # get unconditional embeddings for classifier free guidance
148
- if do_classifier_free_guidance:
149
- uncond_tokens: List[str]
150
- if negative_prompt is None:
151
- uncond_tokens = [""] * batch_size
152
- elif type(prompt) is not type(negative_prompt):
153
- raise TypeError(
154
-                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
155
- f" {type(prompt)}."
156
- )
157
- elif isinstance(negative_prompt, str):
158
- uncond_tokens = [negative_prompt]
159
- elif batch_size != len(negative_prompt):
160
- raise ValueError(
161
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
162
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
163
- " the batch size of `prompt`."
164
- )
165
- else:
166
- uncond_tokens = negative_prompt
167
-
168
- max_length = text_input_ids.shape[-1]
169
- uncond_input = self.tokenizer(
170
- uncond_tokens,
171
- padding="max_length",
172
- max_length=max_length,
173
- truncation=True,
174
- return_tensors="pt",
175
- )
176
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
177
-
178
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
179
- seq_len = uncond_embeddings.shape[1]
180
- uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
181
- uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
182
-
183
- # For classifier free guidance, we need to do two forward passes.
184
- # Here we concatenate the unconditional and text embeddings into a single batch
185
- # to avoid doing two forward passes
186
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
187
-
188
- # get the initial random noise unless the user supplied it
189
-
190
- # Unlike in other pipelines, latents need to be generated in the target device
191
- # for 1-to-1 results reproducibility with the CompVis implementation.
192
- # However this currently doesn't work in `mps`.
193
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
194
- latents_dtype = text_embeddings.dtype
195
- if latents is None:
196
- if self.device.type == "mps":
197
- # randn does not exist on mps
198
- latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
199
- self.device
200
- )
201
- else:
202
- latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
203
- else:
204
- if latents.shape != latents_shape:
205
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
206
- latents = latents.to(self.device)
207
-
208
- # set timesteps
209
- self.scheduler.set_timesteps(num_inference_steps)
210
-
211
- # Some schedulers like PNDM have timesteps as arrays
212
- # It's more optimized to move all timesteps to correct device beforehand
213
- timesteps_tensor = self.scheduler.timesteps.to(self.device)
214
-
215
- # scale the initial noise by the standard deviation required by the scheduler
216
- latents = latents * self.scheduler.init_noise_sigma
217
-
218
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
219
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
220
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
221
- # and should be between [0, 1]
222
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
223
- extra_step_kwargs = {}
224
- if accepts_eta:
225
- extra_step_kwargs["eta"] = eta
226
-
227
- for i, t in enumerate(self.progress_bar(timesteps_tensor)):
228
- # expand the latents if we are doing classifier free guidance
229
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
230
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
231
-
232
- # predict the noise residual
233
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
234
-
235
- # perform guidance
236
- if do_classifier_free_guidance:
237
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
238
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
239
-
240
- # compute the previous noisy sample x_t -> x_t-1
241
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
242
-
243
- # call the callback, if provided
244
- if callback is not None and i % callback_steps == 0:
245
- callback(i, t, latents)
246
-
247
- latents = 1 / 0.18215 * latents
248
- image = self.vae.decode(latents).sample
249
-
250
- image = (image / 2 + 0.5).clamp(0, 1)
251
-
252
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
253
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
254
-
255
- if output_type == "pil":
256
- image = self.numpy_to_pil(image)
257
-
258
- if not return_dict:
259
- return image
260
-
261
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
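For context, this pipeline transcribes the audio with Whisper and feeds the transcript to Stable Diffusion as the prompt. A hedged sketch of how it was typically wired up before removal (the model ids, the `custom_pipeline` name, and the sample dataset are assumptions based on the file name and common usage; a CUDA GPU is assumed for the fp16 weights):

```py
import torch
from datasets import load_dataset
from diffusers import DiffusionPipeline
from transformers import WhisperForConditionalGeneration, WhisperProcessor

device = "cuda"

# Whisper produces the transcript that becomes the Stable Diffusion prompt.
speech_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
speech_processor = WhisperProcessor.from_pretrained("openai/whisper-small")

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="speech_to_image_diffusion",  # old community file name, assumed
    speech_model=speech_model,
    speech_processor=speech_processor,
    torch_dtype=torch.float16,
).to(device)

# A small 16 kHz speech sample; any mono waveform at the matching sampling rate works.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[3]["audio"]

image = pipe(audio=sample["array"], sampling_rate=sample["sampling_rate"]).images[0]
image.save("speech_to_image.png")
```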
 
diffusers/examples/community/stable_diffusion_comparison.py DELETED
@@ -1,405 +0,0 @@
1
- from typing import Any, Callable, Dict, List, Optional, Union
2
-
3
- import torch
4
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
5
-
6
- from diffusers import (
7
- AutoencoderKL,
8
- DDIMScheduler,
9
- DiffusionPipeline,
10
- LMSDiscreteScheduler,
11
- PNDMScheduler,
12
- StableDiffusionPipeline,
13
- UNet2DConditionModel,
14
- )
15
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
16
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
17
-
18
-
19
- pipe1_model_id = "CompVis/stable-diffusion-v1-1"
20
- pipe2_model_id = "CompVis/stable-diffusion-v1-2"
21
- pipe3_model_id = "CompVis/stable-diffusion-v1-3"
22
- pipe4_model_id = "CompVis/stable-diffusion-v1-4"
23
-
24
-
25
- class StableDiffusionComparisonPipeline(DiffusionPipeline):
26
- r"""
27
- Pipeline for parallel comparison of Stable Diffusion v1-v4
28
- This pipeline inherits from DiffusionPipeline and depends on the use of an Auth Token for
29
- downloading pre-trained checkpoints from Hugging Face Hub.
30
- If using Hugging Face Hub, pass the Model ID for Stable Diffusion v1.4 as the previous 3 checkpoints will be loaded
31
- automatically.
32
- Args:
33
- vae ([`AutoencoderKL`]):
34
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
35
- text_encoder ([`CLIPTextModel`]):
36
- Frozen text-encoder. Stable Diffusion uses the text portion of
37
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
38
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
39
- tokenizer (`CLIPTokenizer`):
40
- Tokenizer of class
41
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
42
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
43
- scheduler ([`SchedulerMixin`]):
44
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
45
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
46
- safety_checker ([`StableDiffusionMegaSafetyChecker`]):
47
- Classification module that estimates whether generated images could be considered offensive or harmful.
48
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
49
- feature_extractor ([`CLIPImageProcessor`]):
50
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
51
- """
52
-
53
- def __init__(
54
- self,
55
- vae: AutoencoderKL,
56
- text_encoder: CLIPTextModel,
57
- tokenizer: CLIPTokenizer,
58
- unet: UNet2DConditionModel,
59
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
60
- safety_checker: StableDiffusionSafetyChecker,
61
- feature_extractor: CLIPImageProcessor,
62
- requires_safety_checker: bool = True,
63
- ):
64
-         super().__init__()
65
-
66
- self.pipe1 = StableDiffusionPipeline.from_pretrained(pipe1_model_id)
67
- self.pipe2 = StableDiffusionPipeline.from_pretrained(pipe2_model_id)
68
- self.pipe3 = StableDiffusionPipeline.from_pretrained(pipe3_model_id)
69
- self.pipe4 = StableDiffusionPipeline(
70
- vae=vae,
71
- text_encoder=text_encoder,
72
- tokenizer=tokenizer,
73
- unet=unet,
74
- scheduler=scheduler,
75
- safety_checker=safety_checker,
76
- feature_extractor=feature_extractor,
77
- requires_safety_checker=requires_safety_checker,
78
- )
79
-
80
- self.register_modules(pipeline1=self.pipe1, pipeline2=self.pipe2, pipeline3=self.pipe3, pipeline4=self.pipe4)
81
-
82
- @property
83
- def layers(self) -> Dict[str, Any]:
84
- return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
85
-
86
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
87
- r"""
88
- Enable sliced attention computation.
89
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
90
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
91
- Args:
92
- slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
93
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
94
- a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
95
- `attention_head_dim` must be a multiple of `slice_size`.
96
- """
97
- if slice_size == "auto":
98
- # half the attention head size is usually a good trade-off between
99
- # speed and memory
100
- slice_size = self.unet.config.attention_head_dim // 2
101
- self.unet.set_attention_slice(slice_size)
102
-
103
- def disable_attention_slicing(self):
104
- r"""
105
- Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
106
- back to computing attention in one step.
107
- """
108
- # set slice_size = `None` to disable `attention slicing`
109
- self.enable_attention_slicing(None)
110
-
111
- @torch.no_grad()
112
- def text2img_sd1_1(
113
- self,
114
- prompt: Union[str, List[str]],
115
- height: int = 512,
116
- width: int = 512,
117
- num_inference_steps: int = 50,
118
- guidance_scale: float = 7.5,
119
- negative_prompt: Optional[Union[str, List[str]]] = None,
120
- num_images_per_prompt: Optional[int] = 1,
121
- eta: float = 0.0,
122
- generator: Optional[torch.Generator] = None,
123
- latents: Optional[torch.FloatTensor] = None,
124
- output_type: Optional[str] = "pil",
125
- return_dict: bool = True,
126
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
127
- callback_steps: int = 1,
128
- **kwargs,
129
- ):
130
- return self.pipe1(
131
- prompt=prompt,
132
- height=height,
133
- width=width,
134
- num_inference_steps=num_inference_steps,
135
- guidance_scale=guidance_scale,
136
- negative_prompt=negative_prompt,
137
- num_images_per_prompt=num_images_per_prompt,
138
- eta=eta,
139
- generator=generator,
140
- latents=latents,
141
- output_type=output_type,
142
- return_dict=return_dict,
143
- callback=callback,
144
- callback_steps=callback_steps,
145
- **kwargs,
146
- )
147
-
148
- @torch.no_grad()
149
- def text2img_sd1_2(
150
- self,
151
- prompt: Union[str, List[str]],
152
- height: int = 512,
153
- width: int = 512,
154
- num_inference_steps: int = 50,
155
- guidance_scale: float = 7.5,
156
- negative_prompt: Optional[Union[str, List[str]]] = None,
157
- num_images_per_prompt: Optional[int] = 1,
158
- eta: float = 0.0,
159
- generator: Optional[torch.Generator] = None,
160
- latents: Optional[torch.FloatTensor] = None,
161
- output_type: Optional[str] = "pil",
162
- return_dict: bool = True,
163
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
164
- callback_steps: int = 1,
165
- **kwargs,
166
- ):
167
- return self.pipe2(
168
- prompt=prompt,
169
- height=height,
170
- width=width,
171
- num_inference_steps=num_inference_steps,
172
- guidance_scale=guidance_scale,
173
- negative_prompt=negative_prompt,
174
- num_images_per_prompt=num_images_per_prompt,
175
- eta=eta,
176
- generator=generator,
177
- latents=latents,
178
- output_type=output_type,
179
- return_dict=return_dict,
180
- callback=callback,
181
- callback_steps=callback_steps,
182
- **kwargs,
183
- )
184
-
185
- @torch.no_grad()
186
- def text2img_sd1_3(
187
- self,
188
- prompt: Union[str, List[str]],
189
- height: int = 512,
190
- width: int = 512,
191
- num_inference_steps: int = 50,
192
- guidance_scale: float = 7.5,
193
- negative_prompt: Optional[Union[str, List[str]]] = None,
194
- num_images_per_prompt: Optional[int] = 1,
195
- eta: float = 0.0,
196
- generator: Optional[torch.Generator] = None,
197
- latents: Optional[torch.FloatTensor] = None,
198
- output_type: Optional[str] = "pil",
199
- return_dict: bool = True,
200
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
201
- callback_steps: int = 1,
202
- **kwargs,
203
- ):
204
- return self.pipe3(
205
- prompt=prompt,
206
- height=height,
207
- width=width,
208
- num_inference_steps=num_inference_steps,
209
- guidance_scale=guidance_scale,
210
- negative_prompt=negative_prompt,
211
- num_images_per_prompt=num_images_per_prompt,
212
- eta=eta,
213
- generator=generator,
214
- latents=latents,
215
- output_type=output_type,
216
- return_dict=return_dict,
217
- callback=callback,
218
- callback_steps=callback_steps,
219
- **kwargs,
220
- )
221
-
222
- @torch.no_grad()
223
- def text2img_sd1_4(
224
- self,
225
- prompt: Union[str, List[str]],
226
- height: int = 512,
227
- width: int = 512,
228
- num_inference_steps: int = 50,
229
- guidance_scale: float = 7.5,
230
- negative_prompt: Optional[Union[str, List[str]]] = None,
231
- num_images_per_prompt: Optional[int] = 1,
232
- eta: float = 0.0,
233
- generator: Optional[torch.Generator] = None,
234
- latents: Optional[torch.FloatTensor] = None,
235
- output_type: Optional[str] = "pil",
236
- return_dict: bool = True,
237
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
238
- callback_steps: int = 1,
239
- **kwargs,
240
- ):
241
- return self.pipe4(
242
- prompt=prompt,
243
- height=height,
244
- width=width,
245
- num_inference_steps=num_inference_steps,
246
- guidance_scale=guidance_scale,
247
- negative_prompt=negative_prompt,
248
- num_images_per_prompt=num_images_per_prompt,
249
- eta=eta,
250
- generator=generator,
251
- latents=latents,
252
- output_type=output_type,
253
- return_dict=return_dict,
254
- callback=callback,
255
- callback_steps=callback_steps,
256
- **kwargs,
257
- )
258
-
259
- @torch.no_grad()
260
-     def __call__(
261
- self,
262
- prompt: Union[str, List[str]],
263
- height: int = 512,
264
- width: int = 512,
265
- num_inference_steps: int = 50,
266
- guidance_scale: float = 7.5,
267
- negative_prompt: Optional[Union[str, List[str]]] = None,
268
- num_images_per_prompt: Optional[int] = 1,
269
- eta: float = 0.0,
270
- generator: Optional[torch.Generator] = None,
271
- latents: Optional[torch.FloatTensor] = None,
272
- output_type: Optional[str] = "pil",
273
- return_dict: bool = True,
274
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
275
- callback_steps: int = 1,
276
- **kwargs,
277
- ):
278
- r"""
279
- Function invoked when calling the pipeline for generation. This function will generate 4 results as part
280
- of running all the 4 pipelines for SD1.1-1.4 together in a serial-processing, parallel-invocation fashion.
281
- Args:
282
- prompt (`str` or `List[str]`):
283
- The prompt or prompts to guide the image generation.
284
- height (`int`, optional, defaults to 512):
285
- The height in pixels of the generated image.
286
- width (`int`, optional, defaults to 512):
287
- The width in pixels of the generated image.
288
- num_inference_steps (`int`, optional, defaults to 50):
289
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
290
- expense of slower inference.
291
- guidance_scale (`float`, optional, defaults to 7.5):
292
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
293
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
294
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
295
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
296
- usually at the expense of lower image quality.
297
- eta (`float`, optional, defaults to 0.0):
298
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
299
- [`schedulers.DDIMScheduler`], will be ignored for others.
300
- generator (`torch.Generator`, optional):
301
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
302
- deterministic.
303
- latents (`torch.FloatTensor`, optional):
304
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
305
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
306
- tensor will ge generated by sampling using the supplied random `generator`.
307
- output_type (`str`, optional, defaults to `"pil"`):
308
- The output format of the generate image. Choose between
309
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
310
- return_dict (`bool`, optional, defaults to `True`):
311
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
312
- plain tuple.
313
- Returns:
314
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
315
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
316
- When returning a tuple, the first element is a list with the generated images, and the second element is a
317
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
318
- (nsfw) content, according to the `safety_checker`.
319
- """
320
-
321
- device = "cuda" if torch.cuda.is_available() else "cpu"
322
- self.to(device)
323
-
324
- # Checks if the height and width are divisible by 8 or not
325
- if height % 8 != 0 or width % 8 != 0:
326
- raise ValueError(f"`height` and `width` must be divisible by 8 but are {height} and {width}.")
327
-
328
- # Get first result from Stable Diffusion Checkpoint v1.1
329
- res1 = self.text2img_sd1_1(
330
- prompt=prompt,
331
- height=height,
332
- width=width,
333
- num_inference_steps=num_inference_steps,
334
- guidance_scale=guidance_scale,
335
- negative_prompt=negative_prompt,
336
- num_images_per_prompt=num_images_per_prompt,
337
- eta=eta,
338
- generator=generator,
339
- latents=latents,
340
- output_type=output_type,
341
- return_dict=return_dict,
342
- callback=callback,
343
- callback_steps=callback_steps,
344
- **kwargs,
345
- )
346
-
347
- # Get first result from Stable Diffusion Checkpoint v1.2
348
- res2 = self.text2img_sd1_2(
349
- prompt=prompt,
350
- height=height,
351
- width=width,
352
- num_inference_steps=num_inference_steps,
353
- guidance_scale=guidance_scale,
354
- negative_prompt=negative_prompt,
355
- num_images_per_prompt=num_images_per_prompt,
356
- eta=eta,
357
- generator=generator,
358
- latents=latents,
359
- output_type=output_type,
360
- return_dict=return_dict,
361
- callback=callback,
362
- callback_steps=callback_steps,
363
- **kwargs,
364
- )
365
-
366
- # Get first result from Stable Diffusion Checkpoint v1.3
367
- res3 = self.text2img_sd1_3(
368
- prompt=prompt,
369
- height=height,
370
- width=width,
371
- num_inference_steps=num_inference_steps,
372
- guidance_scale=guidance_scale,
373
- negative_prompt=negative_prompt,
374
- num_images_per_prompt=num_images_per_prompt,
375
- eta=eta,
376
- generator=generator,
377
- latents=latents,
378
- output_type=output_type,
379
- return_dict=return_dict,
380
- callback=callback,
381
- callback_steps=callback_steps,
382
- **kwargs,
383
- )
384
-
385
- # Get first result from Stable Diffusion Checkpoint v1.4
386
- res4 = self.text2img_sd1_4(
387
- prompt=prompt,
388
- height=height,
389
- width=width,
390
- num_inference_steps=num_inference_steps,
391
- guidance_scale=guidance_scale,
392
- negative_prompt=negative_prompt,
393
- num_images_per_prompt=num_images_per_prompt,
394
- eta=eta,
395
- generator=generator,
396
- latents=latents,
397
- output_type=output_type,
398
- return_dict=return_dict,
399
- callback=callback,
400
- callback_steps=callback_steps,
401
- **kwargs,
402
- )
403
-
404
- # Get all result images into a single list and pass it via StableDiffusionPipelineOutput for final result
405
-         return StableDiffusionPipelineOutput(images=[res1[0], res2[0], res3[0], res4[0]], nsfw_content_detected=None)
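A hedged sketch of how the comparison pipeline above was invoked (the pipeline name and model id are assumptions based on the file; constructing it also downloads the v1-1, v1-2 and v1-3 checkpoints on top of the model you pass, so expect a sizeable download):

```py
from diffusers import DiffusionPipeline

# Assumes the community pipeline is still resolvable under its old file name.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="stable_diffusion_comparison",
)
pipe.enable_attention_slicing()

# One call runs the same prompt through the v1.1-v1.4 checkpoints in turn
# (the pipeline moves itself to CUDA internally when a GPU is available).
result = pipe(prompt="an astronaut riding a horse on mars", num_inference_steps=25)

# result.images holds one entry per checkpoint, in v1.1 -> v1.4 order.
for version, entry in zip(("1-1", "1-2", "1-3", "1-4"), result.images):
    image = entry[0] if isinstance(entry, list) else entry
    image.save(f"sd_v{version}.png")
```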
 
diffusers/examples/community/stable_diffusion_controlnet_img2img.py DELETED
@@ -1,989 +0,0 @@
1
- # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
2
-
3
- import inspect
4
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
-
6
- import numpy as np
7
- import PIL.Image
8
- import torch
9
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
10
-
11
- from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
12
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
13
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
14
- from diffusers.schedulers import KarrasDiffusionSchedulers
15
- from diffusers.utils import (
16
- PIL_INTERPOLATION,
17
- is_accelerate_available,
18
- is_accelerate_version,
19
- randn_tensor,
20
- replace_example_docstring,
21
- )
22
-
23
-
24
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
-
26
- EXAMPLE_DOC_STRING = """
27
- Examples:
28
- ```py
29
- >>> import numpy as np
30
- >>> import torch
31
- >>> from PIL import Image
32
- >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
33
- >>> from diffusers.utils import load_image
34
-
35
- >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
36
-
37
- >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
38
-
39
- >>> pipe_controlnet = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
40
- "runwayml/stable-diffusion-v1-5",
41
- controlnet=controlnet,
42
- safety_checker=None,
43
- torch_dtype=torch.float16
44
- )
45
-
46
- >>> pipe_controlnet.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config)
47
- >>> pipe_controlnet.enable_xformers_memory_efficient_attention()
48
- >>> pipe_controlnet.enable_model_cpu_offload()
49
-
50
- # using image with edges for our canny controlnet
51
- >>> control_image = load_image(
52
- "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_canny_edged.png")
53
-
54
-
55
- >>> result_img = pipe_controlnet(controlnet_conditioning_image=control_image,
56
- image=input_image,
57
-         prompt="an android robot, cyberpunk, digital art masterpiece",
58
- num_inference_steps=20).images[0]
59
-
60
- >>> result_img.show()
61
- ```
62
- """
63
-
64
-
65
- def prepare_image(image):
66
- if isinstance(image, torch.Tensor):
67
- # Batch single image
68
- if image.ndim == 3:
69
- image = image.unsqueeze(0)
70
-
71
- image = image.to(dtype=torch.float32)
72
- else:
73
- # preprocess image
74
- if isinstance(image, (PIL.Image.Image, np.ndarray)):
75
- image = [image]
76
-
77
- if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
78
- image = [np.array(i.convert("RGB"))[None, :] for i in image]
79
- image = np.concatenate(image, axis=0)
80
- elif isinstance(image, list) and isinstance(image[0], np.ndarray):
81
- image = np.concatenate([i[None, :] for i in image], axis=0)
82
-
83
- image = image.transpose(0, 3, 1, 2)
84
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
85
-
86
- return image
87
-
88
-
89
- def prepare_controlnet_conditioning_image(
90
- controlnet_conditioning_image,
91
- width,
92
- height,
93
- batch_size,
94
- num_images_per_prompt,
95
- device,
96
- dtype,
97
- do_classifier_free_guidance,
98
- ):
99
- if not isinstance(controlnet_conditioning_image, torch.Tensor):
100
- if isinstance(controlnet_conditioning_image, PIL.Image.Image):
101
- controlnet_conditioning_image = [controlnet_conditioning_image]
102
-
103
- if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
104
- controlnet_conditioning_image = [
105
- np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
106
- for i in controlnet_conditioning_image
107
- ]
108
- controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
109
- controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
110
- controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
111
- controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
112
- elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
113
- controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)
114
-
115
- image_batch_size = controlnet_conditioning_image.shape[0]
116
-
117
- if image_batch_size == 1:
118
- repeat_by = batch_size
119
- else:
120
- # image batch size is the same as prompt batch size
121
- repeat_by = num_images_per_prompt
122
-
123
- controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)
124
-
125
- controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
126
-
127
- if do_classifier_free_guidance:
128
- controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
129
-
130
- return controlnet_conditioning_image
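Note that the two helpers above use different value ranges: `prepare_image` maps the init image into [-1, 1] for the VAE encoder, while `prepare_controlnet_conditioning_image` keeps the conditioning image in [0, 1] for the ControlNet. A small self-contained check of that convention (illustrative only; it re-does the arithmetic inline rather than importing the deleted helpers):

```py
import numpy as np
import PIL.Image
import torch

# A solid black test image: every pixel value is 0.
img = PIL.Image.new("RGB", (64, 64), color=(0, 0, 0))
arr = np.array(img)[None, :].transpose(0, 3, 1, 2).astype(np.float32)

init_image = torch.from_numpy(arr) / 127.5 - 1.0  # VAE range: black maps to -1.0
cond_image = torch.from_numpy(arr) / 255.0         # ControlNet range: black maps to 0.0

print(init_image.min().item(), init_image.max().item())  # -1.0 -1.0
print(cond_image.min().item(), cond_image.max().item())  # 0.0 0.0
```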
131
-
132
-
133
- class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
134
- """
135
- Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
136
- """
137
-
138
- _optional_components = ["safety_checker", "feature_extractor"]
139
-
140
- def __init__(
141
- self,
142
- vae: AutoencoderKL,
143
- text_encoder: CLIPTextModel,
144
- tokenizer: CLIPTokenizer,
145
- unet: UNet2DConditionModel,
146
- controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
147
- scheduler: KarrasDiffusionSchedulers,
148
- safety_checker: StableDiffusionSafetyChecker,
149
- feature_extractor: CLIPImageProcessor,
150
- requires_safety_checker: bool = True,
151
- ):
152
- super().__init__()
153
-
154
- if safety_checker is None and requires_safety_checker:
155
- logger.warning(
156
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
157
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
158
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
159
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
160
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
161
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
162
- )
163
-
164
- if safety_checker is not None and feature_extractor is None:
165
- raise ValueError(
166
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
167
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
168
- )
169
-
170
- if isinstance(controlnet, (list, tuple)):
171
- controlnet = MultiControlNetModel(controlnet)
172
-
173
- self.register_modules(
174
- vae=vae,
175
- text_encoder=text_encoder,
176
- tokenizer=tokenizer,
177
- unet=unet,
178
- controlnet=controlnet,
179
- scheduler=scheduler,
180
- safety_checker=safety_checker,
181
- feature_extractor=feature_extractor,
182
- )
183
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
184
- self.register_to_config(requires_safety_checker=requires_safety_checker)
185
-
186
- def enable_vae_slicing(self):
187
- r"""
188
- Enable sliced VAE decoding.
189
-
190
- When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
191
- steps. This is useful to save some memory and allow larger batch sizes.
192
- """
193
- self.vae.enable_slicing()
194
-
195
- def disable_vae_slicing(self):
196
- r"""
197
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
198
- computing decoding in one step.
199
- """
200
- self.vae.disable_slicing()
201
-
202
- def enable_sequential_cpu_offload(self, gpu_id=0):
203
- r"""
204
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
205
- text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
206
- `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
207
- Note that offloading happens on a submodule basis. Memory savings are higher than with
208
- `enable_model_cpu_offload`, but performance is lower.
209
- """
210
- if is_accelerate_available():
211
- from accelerate import cpu_offload
212
- else:
213
- raise ImportError("Please install accelerate via `pip install accelerate`")
214
-
215
- device = torch.device(f"cuda:{gpu_id}")
216
-
217
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
218
- cpu_offload(cpu_offloaded_model, device)
219
-
220
- if self.safety_checker is not None:
221
- cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
222
-
223
- def enable_model_cpu_offload(self, gpu_id=0):
224
- r"""
225
- Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
226
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
227
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
228
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
229
- """
230
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
231
- from accelerate import cpu_offload_with_hook
232
- else:
233
- raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
234
-
235
- device = torch.device(f"cuda:{gpu_id}")
236
-
237
- hook = None
238
- for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
239
- _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
240
-
241
- if self.safety_checker is not None:
242
- # the safety checker can offload the vae again
243
- _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
244
-
245
- # control net hook has be manually offloaded as it alternates with unet
246
- cpu_offload_with_hook(self.controlnet, device)
247
-
248
- # We'll offload the last model manually.
249
- self.final_offload_hook = hook
250
-
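The two offloading helpers above trade memory for speed in opposite directions; the hedged sketch below shows how either mode would typically be enabled on this community pipeline. The checkpoint names and the `custom_pipeline` identifier are assumptions (they mirror the usual diffusers ControlNet examples), so substitute your own if they differ.

```py
# Hedged usage sketch: enabling one of the two offloading modes documented above.
import torch
from diffusers import ControlNetModel, DiffusionPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",                       # assumed base checkpoint
    controlnet=controlnet,
    custom_pipeline="stable_diffusion_controlnet_img2img",  # the community pipeline defined in this file
    torch_dtype=torch.float16,
)

# Whole-model offloading: sub-models hop between CPU and GPU; modest savings, small slowdown.
pipe.enable_model_cpu_offload()

# Or, for the smallest GPU footprint at a larger speed cost, use per-submodule offloading instead:
# pipe.enable_sequential_cpu_offload()
```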
251
- @property
252
- def _execution_device(self):
253
- r"""
254
- Returns the device on which the pipeline's models will be executed. After calling
255
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
256
- hooks.
257
- """
258
- if not hasattr(self.unet, "_hf_hook"):
259
- return self.device
260
- for module in self.unet.modules():
261
- if (
262
- hasattr(module, "_hf_hook")
263
- and hasattr(module._hf_hook, "execution_device")
264
- and module._hf_hook.execution_device is not None
265
- ):
266
- return torch.device(module._hf_hook.execution_device)
267
- return self.device
268
-
269
- def _encode_prompt(
270
- self,
271
- prompt,
272
- device,
273
- num_images_per_prompt,
274
- do_classifier_free_guidance,
275
- negative_prompt=None,
276
- prompt_embeds: Optional[torch.FloatTensor] = None,
277
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
278
- ):
279
- r"""
280
- Encodes the prompt into text encoder hidden states.
281
-
282
- Args:
283
- prompt (`str` or `List[str]`, *optional*):
284
- prompt to be encoded
285
- device: (`torch.device`):
286
- torch device
287
- num_images_per_prompt (`int`):
288
- number of images that should be generated per prompt
289
- do_classifier_free_guidance (`bool`):
290
- whether to use classifier free guidance or not
291
- negative_prompt (`str` or `List[str]`, *optional*):
292
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
293
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
294
- prompt_embeds (`torch.FloatTensor`, *optional*):
295
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
296
- provided, text embeddings will be generated from `prompt` input argument.
297
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
298
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
299
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
300
- argument.
301
- """
302
- if prompt is not None and isinstance(prompt, str):
303
- batch_size = 1
304
- elif prompt is not None and isinstance(prompt, list):
305
- batch_size = len(prompt)
306
- else:
307
- batch_size = prompt_embeds.shape[0]
308
-
309
- if prompt_embeds is None:
310
- text_inputs = self.tokenizer(
311
- prompt,
312
- padding="max_length",
313
- max_length=self.tokenizer.model_max_length,
314
- truncation=True,
315
- return_tensors="pt",
316
- )
317
- text_input_ids = text_inputs.input_ids
318
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
319
-
320
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
321
- text_input_ids, untruncated_ids
322
- ):
323
- removed_text = self.tokenizer.batch_decode(
324
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
325
- )
326
- logger.warning(
327
- "The following part of your input was truncated because CLIP can only handle sequences up to"
328
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
329
- )
330
-
331
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
332
- attention_mask = text_inputs.attention_mask.to(device)
333
- else:
334
- attention_mask = None
335
-
336
- prompt_embeds = self.text_encoder(
337
- text_input_ids.to(device),
338
- attention_mask=attention_mask,
339
- )
340
- prompt_embeds = prompt_embeds[0]
341
-
342
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
343
-
344
- bs_embed, seq_len, _ = prompt_embeds.shape
345
- # duplicate text embeddings for each generation per prompt, using mps friendly method
346
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
347
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
348
-
349
- # get unconditional embeddings for classifier free guidance
350
- if do_classifier_free_guidance and negative_prompt_embeds is None:
351
- uncond_tokens: List[str]
352
- if negative_prompt is None:
353
- uncond_tokens = [""] * batch_size
354
- elif type(prompt) is not type(negative_prompt):
355
- raise TypeError(
356
- f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
357
- f" {type(prompt)}."
358
- )
359
- elif isinstance(negative_prompt, str):
360
- uncond_tokens = [negative_prompt]
361
- elif batch_size != len(negative_prompt):
362
- raise ValueError(
363
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
364
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
365
- " the batch size of `prompt`."
366
- )
367
- else:
368
- uncond_tokens = negative_prompt
369
-
370
- max_length = prompt_embeds.shape[1]
371
- uncond_input = self.tokenizer(
372
- uncond_tokens,
373
- padding="max_length",
374
- max_length=max_length,
375
- truncation=True,
376
- return_tensors="pt",
377
- )
378
-
379
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
380
- attention_mask = uncond_input.attention_mask.to(device)
381
- else:
382
- attention_mask = None
383
-
384
- negative_prompt_embeds = self.text_encoder(
385
- uncond_input.input_ids.to(device),
386
- attention_mask=attention_mask,
387
- )
388
- negative_prompt_embeds = negative_prompt_embeds[0]
389
-
390
- if do_classifier_free_guidance:
391
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
392
- seq_len = negative_prompt_embeds.shape[1]
393
-
394
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
395
-
396
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
397
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
398
-
399
- # For classifier free guidance, we need to do two forward passes.
400
- # Here we concatenate the unconditional and text embeddings into a single batch
401
- # to avoid doing two forward passes
402
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
403
-
404
- return prompt_embeds
405
-
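For reference, the `[negative, positive]` embedding batch returned above is what allows classifier-free guidance to run in a single UNet pass later in `__call__`; the toy sketch below (random tensors standing in for the UNet output) mirrors that split-and-recombine arithmetic.

```py
# Hedged toy sketch of the classifier-free guidance combine used in the denoising loop.
import torch

guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 64, 64)  # stand-in for unet(latent_model_input, ...) on the doubled batch

# First half corresponds to the negative/unconditional embeddings, second half to the prompt.
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 4, 64, 64])
```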
406
- def run_safety_checker(self, image, device, dtype):
407
- if self.safety_checker is not None:
408
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
409
- image, has_nsfw_concept = self.safety_checker(
410
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
411
- )
412
- else:
413
- has_nsfw_concept = None
414
- return image, has_nsfw_concept
415
-
416
- def decode_latents(self, latents):
417
- latents = 1 / self.vae.config.scaling_factor * latents
418
- image = self.vae.decode(latents).sample
419
- image = (image / 2 + 0.5).clamp(0, 1)
420
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
421
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
422
- return image
423
-
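`decode_latents` rescales by the VAE's `scaling_factor` (0.18215 for the Stable Diffusion VAE) before decoding, then maps the decoder output from [-1, 1] into [0, 1]; the tensor-only sketch below isolates that last normalization step.

```py
# Hedged sketch of the post-decode normalization in decode_latents (no VAE involved).
import torch

decoded = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])  # stand-in for vae.decode(latents).sample values
image = (decoded / 2 + 0.5).clamp(0, 1)
print(image)  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
```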
424
- def prepare_extra_step_kwargs(self, generator, eta):
425
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
426
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
427
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
428
- # and should be between [0, 1]
429
-
430
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
431
- extra_step_kwargs = {}
432
- if accepts_eta:
433
- extra_step_kwargs["eta"] = eta
434
-
435
- # check if the scheduler accepts generator
436
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
437
- if accepts_generator:
438
- extra_step_kwargs["generator"] = generator
439
- return extra_step_kwargs
440
-
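The signature introspection in `prepare_extra_step_kwargs` is what lets one pipeline drive schedulers whose `step()` methods accept different keyword arguments; a standalone illustration (default scheduler configs, no checkpoint needed):

```py
# Hedged sketch: probing optional scheduler kwargs the same way prepare_extra_step_kwargs does.
import inspect
from diffusers import DDIMScheduler, EulerDiscreteScheduler

for scheduler in (DDIMScheduler(), EulerDiscreteScheduler()):
    params = set(inspect.signature(scheduler.step).parameters.keys())
    print(type(scheduler).__name__, "eta:", "eta" in params, "| generator:", "generator" in params)
```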
441
- def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
442
- image_is_pil = isinstance(image, PIL.Image.Image)
443
- image_is_tensor = isinstance(image, torch.Tensor)
444
- image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
445
- image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
446
-
447
- if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
448
- raise TypeError(
449
- "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
450
- )
451
-
452
- if image_is_pil:
453
- image_batch_size = 1
454
- elif image_is_tensor:
455
- image_batch_size = image.shape[0]
456
- elif image_is_pil_list:
457
- image_batch_size = len(image)
458
- elif image_is_tensor_list:
459
- image_batch_size = len(image)
460
- else:
461
- raise ValueError("controlnet condition image is not valid")
462
-
463
- if prompt is not None and isinstance(prompt, str):
464
- prompt_batch_size = 1
465
- elif prompt is not None and isinstance(prompt, list):
466
- prompt_batch_size = len(prompt)
467
- elif prompt_embeds is not None:
468
- prompt_batch_size = prompt_embeds.shape[0]
469
- else:
470
- raise ValueError("prompt or prompt_embeds are not valid")
471
-
472
- if image_batch_size != 1 and image_batch_size != prompt_batch_size:
473
- raise ValueError(
474
- f"If image batch size is not 1, image batch size must be the same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
475
- )
476
-
477
- def check_inputs(
478
- self,
479
- prompt,
480
- image,
481
- controlnet_conditioning_image,
482
- height,
483
- width,
484
- callback_steps,
485
- negative_prompt=None,
486
- prompt_embeds=None,
487
- negative_prompt_embeds=None,
488
- strength=None,
489
- controlnet_guidance_start=None,
490
- controlnet_guidance_end=None,
491
- controlnet_conditioning_scale=None,
492
- ):
493
- if height % 8 != 0 or width % 8 != 0:
494
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
495
-
496
- if (callback_steps is None) or (
497
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
498
- ):
499
- raise ValueError(
500
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
501
- f" {type(callback_steps)}."
502
- )
503
-
504
- if prompt is not None and prompt_embeds is not None:
505
- raise ValueError(
506
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
507
- " only forward one of the two."
508
- )
509
- elif prompt is None and prompt_embeds is None:
510
- raise ValueError(
511
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
512
- )
513
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
514
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
515
-
516
- if negative_prompt is not None and negative_prompt_embeds is not None:
517
- raise ValueError(
518
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
519
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
520
- )
521
-
522
- if prompt_embeds is not None and negative_prompt_embeds is not None:
523
- if prompt_embeds.shape != negative_prompt_embeds.shape:
524
- raise ValueError(
525
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
526
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
527
- f" {negative_prompt_embeds.shape}."
528
- )
529
-
530
- # check controlnet condition image
531
-
532
- if isinstance(self.controlnet, ControlNetModel):
533
- self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
534
- elif isinstance(self.controlnet, MultiControlNetModel):
535
- if not isinstance(controlnet_conditioning_image, list):
536
- raise TypeError("For multiple controlnets: `image` must be type `list`")
537
-
538
- if len(controlnet_conditioning_image) != len(self.controlnet.nets):
539
- raise ValueError(
540
- "For multiple controlnets: `image` must have the same length as the number of controlnets."
541
- )
542
-
543
- for image_ in controlnet_conditioning_image:
544
- self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
545
- else:
546
- assert False
547
-
548
- # Check `controlnet_conditioning_scale`
549
-
550
- if isinstance(self.controlnet, ControlNetModel):
551
- if not isinstance(controlnet_conditioning_scale, float):
552
- raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
553
- elif isinstance(self.controlnet, MultiControlNetModel):
554
- if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
555
- self.controlnet.nets
556
- ):
557
- raise ValueError(
558
- "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
559
- " the same length as the number of controlnets"
560
- )
561
- else:
562
- assert False
563
-
564
- if isinstance(image, torch.Tensor):
565
- if image.ndim != 3 and image.ndim != 4:
566
- raise ValueError("`image` must have 3 or 4 dimensions")
567
-
568
- if image.ndim == 3:
569
- image_batch_size = 1
570
- image_channels, image_height, image_width = image.shape
571
- elif image.ndim == 4:
572
- image_batch_size, image_channels, image_height, image_width = image.shape
573
- else:
574
- assert False
575
-
576
- if image_channels != 3:
577
- raise ValueError("`image` must have 3 channels")
578
-
579
- if image.min() < -1 or image.max() > 1:
580
- raise ValueError("`image` should be in range [-1, 1]")
581
-
582
- if self.vae.config.latent_channels != self.unet.config.in_channels:
583
- raise ValueError(
584
- f"The config of `pipeline.unet` expects {self.unet.config.in_channels} input channels but received"
585
- f" latent channels: {self.vae.config.latent_channels}."
586
- f" Please verify the config of `pipeline.unet` and `pipeline.vae`."
587
- )
588
-
589
- if strength < 0 or strength > 1:
590
- raise ValueError(f"The value of `strength` should be in [0.0, 1.0] but is {strength}")
591
-
592
- if controlnet_guidance_start < 0 or controlnet_guidance_start > 1:
593
- raise ValueError(
594
- f"The value of `controlnet_guidance_start` should be in [0.0, 1.0] but is {controlnet_guidance_start}"
595
- )
596
-
597
- if controlnet_guidance_end < 0 or controlnet_guidance_end > 1:
598
- raise ValueError(
599
- f"The value of `controlnet_guidance_end` should be in [0.0, 1.0] but is {controlnet_guidance_end}"
600
- )
601
-
602
- if controlnet_guidance_start > controlnet_guidance_end:
603
- raise ValueError(
604
- "The value of `controlnet_guidance_start` should be less than `controlnet_guidance_end`, but got"
605
- f" `controlnet_guidance_start` {controlnet_guidance_start} >= `controlnet_guidance_end` {controlnet_guidance_end}"
606
- )
607
-
608
- def get_timesteps(self, num_inference_steps, strength, device):
609
- # get the original timestep using init_timestep
610
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
611
-
612
- t_start = max(num_inference_steps - init_timestep, 0)
613
- timesteps = self.scheduler.timesteps[t_start:]
614
-
615
- return timesteps, num_inference_steps - t_start
616
-
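As a concrete illustration of how `strength` interacts with `get_timesteps`: only the last `strength` fraction of the scheduler's timesteps is actually denoised. The sketch below reproduces the arithmetic with a default-configured DDIM scheduler (no checkpoint required).

```py
# Hedged sketch: how strength maps to the number of denoising steps actually run.
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()  # default config
num_inference_steps, strength = 50, 0.8

scheduler.set_timesteps(num_inference_steps)
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep, 0)                          # 10
timesteps = scheduler.timesteps[t_start:]

print(len(timesteps))  # 40 -> strength=0.8 keeps 80% of the schedule; strength=1.0 would use all 50 steps
```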
617
- def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
618
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
619
- raise ValueError(
620
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
621
- )
622
-
623
- image = image.to(device=device, dtype=dtype)
624
-
625
- batch_size = batch_size * num_images_per_prompt
626
- if isinstance(generator, list) and len(generator) != batch_size:
627
- raise ValueError(
628
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
629
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
630
- )
631
-
632
- if isinstance(generator, list):
633
- init_latents = [
634
- self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
635
- ]
636
- init_latents = torch.cat(init_latents, dim=0)
637
- else:
638
- init_latents = self.vae.encode(image).latent_dist.sample(generator)
639
-
640
- init_latents = self.vae.config.scaling_factor * init_latents
641
-
642
- if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
643
- raise ValueError(
644
- f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
645
- )
646
- else:
647
- init_latents = torch.cat([init_latents], dim=0)
648
-
649
- shape = init_latents.shape
650
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
651
-
652
- # get latents
653
- init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
654
- latents = init_latents
655
-
656
- return latents
657
-
658
- def _default_height_width(self, height, width, image):
659
- if isinstance(image, list):
660
- image = image[0]
661
-
662
- if height is None:
663
- if isinstance(image, PIL.Image.Image):
664
- height = image.height
665
- elif isinstance(image, torch.Tensor):
666
- height = image.shape[2]  # (batch, channels, height, width)
667
-
668
- height = (height // 8) * 8 # round down to nearest multiple of 8
669
-
670
- if width is None:
671
- if isinstance(image, PIL.Image.Image):
672
- width = image.width
673
- elif isinstance(image, torch.Tensor):
674
- width = image.shape[3]  # (batch, channels, height, width)
675
-
676
- width = (width // 8) * 8 # round down to nearest multiple of 8
677
-
678
- return height, width
679
-
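`_default_height_width` then snaps the inferred size down to the nearest multiple of 8, matching the stride of the VAE/UNet; for example:

```py
# Hedged sketch of the rounding applied above for a 511x769 conditioning image.
height, width = 511, 769
height = (height // 8) * 8  # 504
width = (width // 8) * 8    # 768
print(height, width)
```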
680
- @torch.no_grad()
681
- @replace_example_docstring(EXAMPLE_DOC_STRING)
682
- def __call__(
683
- self,
684
- prompt: Union[str, List[str]] = None,
685
- image: Union[torch.Tensor, PIL.Image.Image] = None,
686
- controlnet_conditioning_image: Union[
687
- torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
688
- ] = None,
689
- strength: float = 0.8,
690
- height: Optional[int] = None,
691
- width: Optional[int] = None,
692
- num_inference_steps: int = 50,
693
- guidance_scale: float = 7.5,
694
- negative_prompt: Optional[Union[str, List[str]]] = None,
695
- num_images_per_prompt: Optional[int] = 1,
696
- eta: float = 0.0,
697
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
698
- latents: Optional[torch.FloatTensor] = None,
699
- prompt_embeds: Optional[torch.FloatTensor] = None,
700
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
701
- output_type: Optional[str] = "pil",
702
- return_dict: bool = True,
703
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
704
- callback_steps: int = 1,
705
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
706
- controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
707
- controlnet_guidance_start: float = 0.0,
708
- controlnet_guidance_end: float = 1.0,
709
- ):
710
- r"""
711
- Function invoked when calling the pipeline for generation.
712
-
713
- Args:
714
- prompt (`str` or `List[str]`, *optional*):
715
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
716
- instead.
717
- image (`torch.Tensor` or `PIL.Image.Image`):
718
- `Image`, or tensor representing an image batch to be used as the starting point for the
719
- img2img process. It is noised according to `strength` and then denoised following `prompt`.
720
- controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
721
- The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
722
- the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
723
- also be accepted as an image. The control image is automatically resized to fit the output image.
724
- strength (`float`, *optional*):
725
- Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
726
- will be used as a starting point, adding more noise to it the larger the `strength`. The number of
727
- denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
728
- be maximum and the denoising process will run for the full number of iterations specified in
729
- `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
730
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
731
- The height in pixels of the generated image.
732
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
733
- The width in pixels of the generated image.
734
- num_inference_steps (`int`, *optional*, defaults to 50):
735
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
736
- expense of slower inference.
737
- guidance_scale (`float`, *optional*, defaults to 7.5):
738
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
739
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
740
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
741
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
742
- usually at the expense of lower image quality.
743
- negative_prompt (`str` or `List[str]`, *optional*):
744
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
745
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
746
- num_images_per_prompt (`int`, *optional*, defaults to 1):
747
- The number of images to generate per prompt.
748
- eta (`float`, *optional*, defaults to 0.0):
749
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
750
- [`schedulers.DDIMScheduler`], will be ignored for others.
751
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
752
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
753
- to make generation deterministic.
754
- latents (`torch.FloatTensor`, *optional*):
755
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
756
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
757
- tensor will be generated by sampling using the supplied random `generator`.
758
- prompt_embeds (`torch.FloatTensor`, *optional*):
759
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
760
- provided, text embeddings will be generated from `prompt` input argument.
761
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
762
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
763
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
764
- argument.
765
- output_type (`str`, *optional*, defaults to `"pil"`):
766
- The output format of the generated image. Choose between
767
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
768
- return_dict (`bool`, *optional*, defaults to `True`):
769
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
770
- plain tuple.
771
- callback (`Callable`, *optional*):
772
- A function that will be called every `callback_steps` steps during inference. The function will be
773
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
774
- callback_steps (`int`, *optional*, defaults to 1):
775
- The frequency at which the `callback` function will be called. If not specified, the callback will be
776
- called at every step.
777
- cross_attention_kwargs (`dict`, *optional*):
778
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
779
- `self.processor` in
780
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
781
- controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
782
- The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
783
- to the residual in the original unet.
784
- controlnet_guidance_start ('float', *optional*, defaults to 0.0):
785
- The percentage of total steps the controlnet starts applying. Must be between 0 and 1.
786
- controlnet_guidance_end ('float', *optional*, defaults to 1.0):
787
- The percentage of total steps the controlnet ends applying. Must be between 0 and 1. Must be greater
788
- than `controlnet_guidance_start`.
789
-
790
- Examples:
791
-
792
- Returns:
793
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
794
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
795
- When returning a tuple, the first element is a list with the generated images, and the second element is a
796
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
797
- (nsfw) content, according to the `safety_checker`.
798
- """
799
- # 0. Default height and width to unet
800
- height, width = self._default_height_width(height, width, controlnet_conditioning_image)
801
-
802
- # 1. Check inputs. Raise error if not correct
803
- self.check_inputs(
804
- prompt,
805
- image,
806
- controlnet_conditioning_image,
807
- height,
808
- width,
809
- callback_steps,
810
- negative_prompt,
811
- prompt_embeds,
812
- negative_prompt_embeds,
813
- strength,
814
- controlnet_guidance_start,
815
- controlnet_guidance_end,
816
- controlnet_conditioning_scale,
817
- )
818
-
819
- # 2. Define call parameters
820
- if prompt is not None and isinstance(prompt, str):
821
- batch_size = 1
822
- elif prompt is not None and isinstance(prompt, list):
823
- batch_size = len(prompt)
824
- else:
825
- batch_size = prompt_embeds.shape[0]
826
-
827
- device = self._execution_device
828
- # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
829
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
830
- # corresponds to doing no classifier free guidance.
831
- do_classifier_free_guidance = guidance_scale > 1.0
832
-
833
- if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
834
- controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
835
-
836
- # 3. Encode input prompt
837
- prompt_embeds = self._encode_prompt(
838
- prompt,
839
- device,
840
- num_images_per_prompt,
841
- do_classifier_free_guidance,
842
- negative_prompt,
843
- prompt_embeds=prompt_embeds,
844
- negative_prompt_embeds=negative_prompt_embeds,
845
- )
846
-
847
- # 4. Prepare image, and controlnet_conditioning_image
848
- image = prepare_image(image)
849
-
850
- # condition image(s)
851
- if isinstance(self.controlnet, ControlNetModel):
852
- controlnet_conditioning_image = prepare_controlnet_conditioning_image(
853
- controlnet_conditioning_image=controlnet_conditioning_image,
854
- width=width,
855
- height=height,
856
- batch_size=batch_size * num_images_per_prompt,
857
- num_images_per_prompt=num_images_per_prompt,
858
- device=device,
859
- dtype=self.controlnet.dtype,
860
- do_classifier_free_guidance=do_classifier_free_guidance,
861
- )
862
- elif isinstance(self.controlnet, MultiControlNetModel):
863
- controlnet_conditioning_images = []
864
-
865
- for image_ in controlnet_conditioning_image:
866
- image_ = prepare_controlnet_conditioning_image(
867
- controlnet_conditioning_image=image_,
868
- width=width,
869
- height=height,
870
- batch_size=batch_size * num_images_per_prompt,
871
- num_images_per_prompt=num_images_per_prompt,
872
- device=device,
873
- dtype=self.controlnet.dtype,
874
- do_classifier_free_guidance=do_classifier_free_guidance,
875
- )
876
-
877
- controlnet_conditioning_images.append(image_)
878
-
879
- controlnet_conditioning_image = controlnet_conditioning_images
880
- else:
881
- assert False
882
-
883
- # 5. Prepare timesteps
884
- self.scheduler.set_timesteps(num_inference_steps, device=device)
885
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
886
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
887
-
888
- # 6. Prepare latent variables
889
- latents = self.prepare_latents(
890
- image,
891
- latent_timestep,
892
- batch_size,
893
- num_images_per_prompt,
894
- prompt_embeds.dtype,
895
- device,
896
- generator,
897
- )
898
-
899
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
900
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
901
-
902
- # 8. Denoising loop
903
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
904
- with self.progress_bar(total=num_inference_steps) as progress_bar:
905
- for i, t in enumerate(timesteps):
906
- # expand the latents if we are doing classifier free guidance
907
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
908
-
909
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
910
-
911
- # compute the percentage of total steps we are at
912
- current_sampling_percent = i / len(timesteps)
913
-
914
- if (
915
- current_sampling_percent < controlnet_guidance_start
916
- or current_sampling_percent > controlnet_guidance_end
917
- ):
918
- # do not apply the controlnet
919
- down_block_res_samples = None
920
- mid_block_res_sample = None
921
- else:
922
- # apply the controlnet
923
- down_block_res_samples, mid_block_res_sample = self.controlnet(
924
- latent_model_input,
925
- t,
926
- encoder_hidden_states=prompt_embeds,
927
- controlnet_cond=controlnet_conditioning_image,
928
- conditioning_scale=controlnet_conditioning_scale,
929
- return_dict=False,
930
- )
931
-
932
- # predict the noise residual
933
- noise_pred = self.unet(
934
- latent_model_input,
935
- t,
936
- encoder_hidden_states=prompt_embeds,
937
- cross_attention_kwargs=cross_attention_kwargs,
938
- down_block_additional_residuals=down_block_res_samples,
939
- mid_block_additional_residual=mid_block_res_sample,
940
- ).sample
941
-
942
- # perform guidance
943
- if do_classifier_free_guidance:
944
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
945
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
946
-
947
- # compute the previous noisy sample x_t -> x_t-1
948
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
949
-
950
- # call the callback, if provided
951
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
952
- progress_bar.update()
953
- if callback is not None and i % callback_steps == 0:
954
- callback(i, t, latents)
955
-
956
- # If we do sequential model offloading, let's offload unet and controlnet
957
- # manually for max memory savings
958
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
959
- self.unet.to("cpu")
960
- self.controlnet.to("cpu")
961
- torch.cuda.empty_cache()
962
-
963
- if output_type == "latent":
964
- image = latents
965
- has_nsfw_concept = None
966
- elif output_type == "pil":
967
- # 8. Post-processing
968
- image = self.decode_latents(latents)
969
-
970
- # 9. Run safety checker
971
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
972
-
973
- # 10. Convert to PIL
974
- image = self.numpy_to_pil(image)
975
- else:
976
- # 8. Post-processing
977
- image = self.decode_latents(latents)
978
-
979
- # 9. Run safety checker
980
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
981
-
982
- # Offload last model to CPU
983
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
984
- self.final_offload_hook.offload()
985
-
986
- if not return_dict:
987
- return (image, has_nsfw_concept)
988
-
989
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
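Taken together, the `__call__` above implements ControlNet-guided img2img. A hedged end-to-end sketch follows; the checkpoints, the `custom_pipeline` name, the example image URL, and the use of `opencv-python` for the Canny condition all follow the usual diffusers ControlNet examples and are assumptions to adapt to your own setup.

```py
# Hedged usage sketch for the img2img ControlNet community pipeline defined above.
import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import ControlNetModel, DiffusionPipeline
from diffusers.utils import load_image

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird.png"
)
canny = cv2.Canny(np.array(init_image), 100, 200)              # edge map as the ControlNet condition
canny_image = Image.fromarray(np.stack([canny] * 3, axis=-1))  # replicate to 3 channels

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    custom_pipeline="stable_diffusion_controlnet_img2img",
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    prompt="a colorful bird, highly detailed, best quality",
    image=init_image,                           # starting image for img2img
    controlnet_conditioning_image=canny_image,  # ControlNet condition
    strength=0.8,                               # fraction of the schedule to denoise
    num_inference_steps=50,
    controlnet_guidance_start=0.0,
    controlnet_guidance_end=1.0,
).images[0]
image.save("controlnet_img2img.png")
```
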
diffusers/examples/community/stable_diffusion_controlnet_inpaint.py DELETED
@@ -1,1076 +0,0 @@
1
- # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
2
-
3
- import inspect
4
- from typing import Any, Callable, Dict, List, Optional, Union
5
-
6
- import numpy as np
7
- import PIL.Image
8
- import torch
9
- import torch.nn.functional as F
10
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
11
-
12
- from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
13
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
14
- from diffusers.schedulers import KarrasDiffusionSchedulers
15
- from diffusers.utils import (
16
- PIL_INTERPOLATION,
17
- is_accelerate_available,
18
- is_accelerate_version,
19
- randn_tensor,
20
- replace_example_docstring,
21
- )
22
-
23
-
24
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
-
26
- EXAMPLE_DOC_STRING = """
27
- Examples:
28
- ```py
29
- >>> import numpy as np
30
- >>> import torch
31
- >>> from PIL import Image
32
- >>> from stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
33
-
34
- >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
35
- >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
36
- >>> from diffusers.utils import load_image
37
-
38
- >>> def ade_palette():
39
- return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
40
- [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
41
- [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
42
- [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
43
- [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
44
- [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
45
- [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
46
- [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
47
- [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
48
- [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
49
- [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
50
- [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
51
- [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
52
- [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
53
- [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
54
- [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
55
- [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
56
- [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
57
- [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
58
- [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
59
- [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
60
- [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
61
- [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
62
- [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
63
- [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
64
- [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
65
- [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
66
- [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
67
- [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
68
- [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
69
- [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
70
- [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
71
- [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
72
- [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
73
- [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
74
- [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
75
- [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
76
- [102, 255, 0], [92, 0, 255]]
77
-
78
- >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
79
- >>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
80
-
81
- >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)
82
-
83
- >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
84
- "runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
85
- )
86
-
87
- >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
88
- >>> pipe.enable_xformers_memory_efficient_attention()
89
- >>> pipe.enable_model_cpu_offload()
90
-
91
- >>> def image_to_seg(image):
92
- pixel_values = image_processor(image, return_tensors="pt").pixel_values
93
- with torch.no_grad():
94
- outputs = image_segmentor(pixel_values)
95
- seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
96
- color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
97
- palette = np.array(ade_palette())
98
- for label, color in enumerate(palette):
99
- color_seg[seg == label, :] = color
100
- color_seg = color_seg.astype(np.uint8)
101
- seg_image = Image.fromarray(color_seg)
102
- return seg_image
103
-
104
- >>> image = load_image(
105
- "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
106
- )
107
-
108
- >>> mask_image = load_image(
109
- "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
110
- )
111
-
112
- >>> controlnet_conditioning_image = image_to_seg(image)
113
-
114
- >>> image = pipe(
115
- "Face of a yellow cat, high resolution, sitting on a park bench",
116
- image,
117
- mask_image,
118
- controlnet_conditioning_image,
119
- num_inference_steps=20,
120
- ).images[0]
121
-
122
- >>> image.save("out.png")
123
- ```
124
- """
125
-
126
-
127
- def prepare_image(image):
128
- if isinstance(image, torch.Tensor):
129
- # Batch single image
130
- if image.ndim == 3:
131
- image = image.unsqueeze(0)
132
-
133
- image = image.to(dtype=torch.float32)
134
- else:
135
- # preprocess image
136
- if isinstance(image, (PIL.Image.Image, np.ndarray)):
137
- image = [image]
138
-
139
- if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
140
- image = [np.array(i.convert("RGB"))[None, :] for i in image]
141
- image = np.concatenate(image, axis=0)
142
- elif isinstance(image, list) and isinstance(image[0], np.ndarray):
143
- image = np.concatenate([i[None, :] for i in image], axis=0)
144
-
145
- image = image.transpose(0, 3, 1, 2)
146
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
147
-
148
- return image
149
-
150
-
151
- def prepare_mask_image(mask_image):
152
- if isinstance(mask_image, torch.Tensor):
153
- if mask_image.ndim == 2:
154
- # Batch and add channel dim for single mask
155
- mask_image = mask_image.unsqueeze(0).unsqueeze(0)
156
- elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
157
- # Single mask, the 0'th dimension is considered to be
158
- # the existing batch size of 1
159
- mask_image = mask_image.unsqueeze(0)
160
- elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
161
- # Batch of mask, the 0'th dimension is considered to be
162
- # the batching dimension
163
- mask_image = mask_image.unsqueeze(1)
164
-
165
- # Binarize mask
166
- mask_image[mask_image < 0.5] = 0
167
- mask_image[mask_image >= 0.5] = 1
168
- else:
169
- # preprocess mask
170
- if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
171
- mask_image = [mask_image]
172
-
173
- if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
174
- mask_image = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0)
175
- mask_image = mask_image.astype(np.float32) / 255.0
176
- elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
177
- mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
178
-
179
- mask_image[mask_image < 0.5] = 0
180
- mask_image[mask_image >= 0.5] = 1
181
- mask_image = torch.from_numpy(mask_image)
182
-
183
- return mask_image
184
-
185
-
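`prepare_mask_image` above normalizes any mask input to a `(batch, 1, height, width)` float tensor and hard-thresholds it at 0.5; a small self-contained sketch of the PIL path:

```py
# Hedged sketch: binarizing a grayscale PIL mask the same way prepare_mask_image does.
import numpy as np
import torch
from PIL import Image

mask = Image.fromarray((np.random.rand(64, 64) * 255).astype(np.uint8), mode="L")
arr = np.array(mask.convert("L"))[None, None, :].astype(np.float32) / 255.0
arr[arr < 0.5] = 0
arr[arr >= 0.5] = 1
mask_tensor = torch.from_numpy(arr)
print(mask_tensor.shape, mask_tensor.unique())  # torch.Size([1, 1, 64, 64]) tensor([0., 1.])
```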
186
- def prepare_controlnet_conditioning_image(
187
- controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
188
- ):
189
- if not isinstance(controlnet_conditioning_image, torch.Tensor):
190
- if isinstance(controlnet_conditioning_image, PIL.Image.Image):
191
- controlnet_conditioning_image = [controlnet_conditioning_image]
192
-
193
- if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
194
- controlnet_conditioning_image = [
195
- np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
196
- for i in controlnet_conditioning_image
197
- ]
198
- controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
199
- controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
200
- controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
201
- controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
202
- elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
203
- controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)
204
-
205
- image_batch_size = controlnet_conditioning_image.shape[0]
206
-
207
- if image_batch_size == 1:
208
- repeat_by = batch_size
209
- else:
210
- # image batch size is the same as prompt batch size
211
- repeat_by = num_images_per_prompt
212
-
213
- controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)
214
-
215
- controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
216
-
217
- return controlnet_conditioning_image
218
-
219
-
220
- class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
221
- """
222
- Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
223
- """
224
-
225
- _optional_components = ["safety_checker", "feature_extractor"]
226
-
227
- def __init__(
228
- self,
229
- vae: AutoencoderKL,
230
- text_encoder: CLIPTextModel,
231
- tokenizer: CLIPTokenizer,
232
- unet: UNet2DConditionModel,
233
- controlnet: ControlNetModel,
234
- scheduler: KarrasDiffusionSchedulers,
235
- safety_checker: StableDiffusionSafetyChecker,
236
- feature_extractor: CLIPImageProcessor,
237
- requires_safety_checker: bool = True,
238
- ):
239
- super().__init__()
240
-
241
- if safety_checker is None and requires_safety_checker:
242
- logger.warning(
243
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
244
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
245
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
246
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
247
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
248
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
249
- )
250
-
251
- if safety_checker is not None and feature_extractor is None:
252
- f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
253
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
254
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
255
- )
256
-
257
- self.register_modules(
258
- vae=vae,
259
- text_encoder=text_encoder,
260
- tokenizer=tokenizer,
261
- unet=unet,
262
- controlnet=controlnet,
263
- scheduler=scheduler,
264
- safety_checker=safety_checker,
265
- feature_extractor=feature_extractor,
266
- )
267
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
268
- self.register_to_config(requires_safety_checker=requires_safety_checker)
269
-
270
- def enable_vae_slicing(self):
271
- r"""
272
- Enable sliced VAE decoding.
273
-
274
- When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
275
- steps. This is useful to save some memory and allow larger batch sizes.
276
- """
277
- self.vae.enable_slicing()
278
-
279
- def disable_vae_slicing(self):
280
- r"""
281
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
282
- computing decoding in one step.
283
- """
284
- self.vae.disable_slicing()
285
-
286
- def enable_sequential_cpu_offload(self, gpu_id=0):
287
- r"""
288
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
289
- text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
290
- `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
291
- Note that offloading happens on a submodule basis. Memory savings are higher than with
292
- `enable_model_cpu_offload`, but performance is lower.
293
- """
294
- if is_accelerate_available():
295
- from accelerate import cpu_offload
296
- else:
297
- raise ImportError("Please install accelerate via `pip install accelerate`")
298
-
299
- device = torch.device(f"cuda:{gpu_id}")
300
-
301
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
302
- cpu_offload(cpu_offloaded_model, device)
303
-
304
- if self.safety_checker is not None:
305
- cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
306
-
307
- def enable_model_cpu_offload(self, gpu_id=0):
308
- r"""
309
- Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
310
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
311
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
312
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
313
- """
314
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
315
- from accelerate import cpu_offload_with_hook
316
- else:
317
- raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
318
-
319
- device = torch.device(f"cuda:{gpu_id}")
320
-
321
- hook = None
322
- for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
323
- _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
324
-
325
- if self.safety_checker is not None:
326
- # the safety checker can offload the vae again
327
- _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
328
-
329
- # the controlnet hook has to be manually offloaded as it alternates with the unet
330
- cpu_offload_with_hook(self.controlnet, device)
331
-
332
- # We'll offload the last model manually.
333
- self.final_offload_hook = hook
334
-
335
- @property
336
- def _execution_device(self):
337
- r"""
338
- Returns the device on which the pipeline's models will be executed. After calling
339
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
340
- hooks.
341
- """
342
- if not hasattr(self.unet, "_hf_hook"):
343
- return self.device
344
- for module in self.unet.modules():
345
- if (
346
- hasattr(module, "_hf_hook")
347
- and hasattr(module._hf_hook, "execution_device")
348
- and module._hf_hook.execution_device is not None
349
- ):
350
- return torch.device(module._hf_hook.execution_device)
351
- return self.device
352
-
353
- def _encode_prompt(
354
- self,
355
- prompt,
356
- device,
357
- num_images_per_prompt,
358
- do_classifier_free_guidance,
359
- negative_prompt=None,
360
- prompt_embeds: Optional[torch.FloatTensor] = None,
361
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
362
- ):
363
- r"""
364
- Encodes the prompt into text encoder hidden states.
365
-
366
- Args:
367
- prompt (`str` or `List[str]`, *optional*):
368
- prompt to be encoded
369
- device: (`torch.device`):
370
- torch device
371
- num_images_per_prompt (`int`):
372
- number of images that should be generated per prompt
373
- do_classifier_free_guidance (`bool`):
374
- whether to use classifier free guidance or not
375
- negative_prompt (`str` or `List[str]`, *optional*):
376
- The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.
377
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
378
- prompt_embeds (`torch.FloatTensor`, *optional*):
379
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
380
- provided, text embeddings will be generated from `prompt` input argument.
381
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
382
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
383
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
384
- argument.
385
- """
386
- if prompt is not None and isinstance(prompt, str):
387
- batch_size = 1
388
- elif prompt is not None and isinstance(prompt, list):
389
- batch_size = len(prompt)
390
- else:
391
- batch_size = prompt_embeds.shape[0]
392
-
393
- if prompt_embeds is None:
394
- text_inputs = self.tokenizer(
395
- prompt,
396
- padding="max_length",
397
- max_length=self.tokenizer.model_max_length,
398
- truncation=True,
399
- return_tensors="pt",
400
- )
401
- text_input_ids = text_inputs.input_ids
402
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
403
-
404
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
405
- text_input_ids, untruncated_ids
406
- ):
407
- removed_text = self.tokenizer.batch_decode(
408
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
409
- )
410
- logger.warning(
411
- "The following part of your input was truncated because CLIP can only handle sequences up to"
412
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
413
- )
414
-
415
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
416
- attention_mask = text_inputs.attention_mask.to(device)
417
- else:
418
- attention_mask = None
419
-
420
- prompt_embeds = self.text_encoder(
421
- text_input_ids.to(device),
422
- attention_mask=attention_mask,
423
- )
424
- prompt_embeds = prompt_embeds[0]
425
-
426
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
427
-
428
- bs_embed, seq_len, _ = prompt_embeds.shape
429
- # duplicate text embeddings for each generation per prompt, using mps friendly method
430
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
431
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
432
-
433
- # get unconditional embeddings for classifier free guidance
434
- if do_classifier_free_guidance and negative_prompt_embeds is None:
435
- uncond_tokens: List[str]
436
- if negative_prompt is None:
437
- uncond_tokens = [""] * batch_size
438
- elif type(prompt) is not type(negative_prompt):
439
- raise TypeError(
440
- f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
441
- f" {type(prompt)}."
442
- )
443
- elif isinstance(negative_prompt, str):
444
- uncond_tokens = [negative_prompt]
445
- elif batch_size != len(negative_prompt):
446
- raise ValueError(
447
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
448
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
449
- " the batch size of `prompt`."
450
- )
451
- else:
452
- uncond_tokens = negative_prompt
453
-
454
- max_length = prompt_embeds.shape[1]
455
- uncond_input = self.tokenizer(
456
- uncond_tokens,
457
- padding="max_length",
458
- max_length=max_length,
459
- truncation=True,
460
- return_tensors="pt",
461
- )
462
-
463
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
464
- attention_mask = uncond_input.attention_mask.to(device)
465
- else:
466
- attention_mask = None
467
-
468
- negative_prompt_embeds = self.text_encoder(
469
- uncond_input.input_ids.to(device),
470
- attention_mask=attention_mask,
471
- )
472
- negative_prompt_embeds = negative_prompt_embeds[0]
473
-
474
- if do_classifier_free_guidance:
475
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
476
- seq_len = negative_prompt_embeds.shape[1]
477
-
478
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
479
-
480
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
481
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
482
-
483
- # For classifier free guidance, we need to do two forward passes.
484
- # Here we concatenate the unconditional and text embeddings into a single batch
485
- # to avoid doing two forward passes
486
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
487
-
488
- return prompt_embeds
489
-
490
- def run_safety_checker(self, image, device, dtype):
491
- if self.safety_checker is not None:
492
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
493
- image, has_nsfw_concept = self.safety_checker(
494
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
495
- )
496
- else:
497
- has_nsfw_concept = None
498
- return image, has_nsfw_concept
499
-
500
- def decode_latents(self, latents):
501
- latents = 1 / self.vae.config.scaling_factor * latents
502
- image = self.vae.decode(latents).sample
503
- image = (image / 2 + 0.5).clamp(0, 1)
504
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
505
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
506
- return image
507
-
508
- def prepare_extra_step_kwargs(self, generator, eta):
509
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
510
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
511
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
512
- # and should be between [0, 1]
513
-
514
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
515
- extra_step_kwargs = {}
516
- if accepts_eta:
517
- extra_step_kwargs["eta"] = eta
518
-
519
- # check if the scheduler accepts generator
520
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
521
- if accepts_generator:
522
- extra_step_kwargs["generator"] = generator
523
- return extra_step_kwargs
524
-
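# A small standalone sketch of the signature inspection used above: a scheduler whose step()
# does not accept `eta` simply never receives it. The function below is a stand-in, not a real
# diffusers scheduler.
import inspect

def step_without_eta(model_output, timestep, sample):
    return sample

accepts_eta = "eta" in set(inspect.signature(step_without_eta).parameters.keys())
extra_step_kwargs = {"eta": 0.0} if accepts_eta else {}
print(extra_step_kwargs)   # {} -> eta is silently dropped for this scheduler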
525
- def check_inputs(
526
- self,
527
- prompt,
528
- image,
529
- mask_image,
530
- controlnet_conditioning_image,
531
- height,
532
- width,
533
- callback_steps,
534
- negative_prompt=None,
535
- prompt_embeds=None,
536
- negative_prompt_embeds=None,
537
- ):
538
- if height % 8 != 0 or width % 8 != 0:
539
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
540
-
541
- if (callback_steps is None) or (
542
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
543
- ):
544
- raise ValueError(
545
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
546
- f" {type(callback_steps)}."
547
- )
548
-
549
- if prompt is not None and prompt_embeds is not None:
550
- raise ValueError(
551
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
552
- " only forward one of the two."
553
- )
554
- elif prompt is None and prompt_embeds is None:
555
- raise ValueError(
556
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
557
- )
558
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
559
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
560
-
561
- if negative_prompt is not None and negative_prompt_embeds is not None:
562
- raise ValueError(
563
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
564
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
565
- )
566
-
567
- if prompt_embeds is not None and negative_prompt_embeds is not None:
568
- if prompt_embeds.shape != negative_prompt_embeds.shape:
569
- raise ValueError(
570
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
571
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
572
- f" {negative_prompt_embeds.shape}."
573
- )
574
-
575
- controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
576
- controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
577
- controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
578
- controlnet_conditioning_image[0], PIL.Image.Image
579
- )
580
- controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
581
- controlnet_conditioning_image[0], torch.Tensor
582
- )
583
-
584
- if (
585
- not controlnet_cond_image_is_pil
586
- and not controlnet_cond_image_is_tensor
587
- and not controlnet_cond_image_is_pil_list
588
- and not controlnet_cond_image_is_tensor_list
589
- ):
590
- raise TypeError(
591
- "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
592
- )
593
-
594
- if controlnet_cond_image_is_pil:
595
- controlnet_cond_image_batch_size = 1
596
- elif controlnet_cond_image_is_tensor:
597
- controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
598
- elif controlnet_cond_image_is_pil_list:
599
- controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
600
- elif controlnet_cond_image_is_tensor_list:
601
- controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
602
-
603
- if prompt is not None and isinstance(prompt, str):
604
- prompt_batch_size = 1
605
- elif prompt is not None and isinstance(prompt, list):
606
- prompt_batch_size = len(prompt)
607
- elif prompt_embeds is not None:
608
- prompt_batch_size = prompt_embeds.shape[0]
609
-
610
- if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
611
- raise ValueError(
612
- f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
613
- )
614
-
615
- if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
616
- raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
617
-
618
- if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):
619
- raise TypeError("if `image` is a PIL image, `mask_image` must also be a PIL image")
620
-
621
- if isinstance(image, torch.Tensor):
622
- if image.ndim != 3 and image.ndim != 4:
623
- raise ValueError("`image` must have 3 or 4 dimensions")
624
-
625
- if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
626
- raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
627
-
628
- if image.ndim == 3:
629
- image_batch_size = 1
630
- image_channels, image_height, image_width = image.shape
631
- elif image.ndim == 4:
632
- image_batch_size, image_channels, image_height, image_width = image.shape
633
-
634
- if mask_image.ndim == 2:
635
- mask_image_batch_size = 1
636
- mask_image_channels = 1
637
- mask_image_height, mask_image_width = mask_image.shape
638
- elif mask_image.ndim == 3:
639
- mask_image_channels = 1
640
- mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape
641
- elif mask_image.ndim == 4:
642
- mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape
643
-
644
- if image_channels != 3:
645
- raise ValueError("`image` must have 3 channels")
646
-
647
- if mask_image_channels != 1:
648
- raise ValueError("`mask_image` must have 1 channel")
649
-
650
- if image_batch_size != mask_image_batch_size:
651
- raise ValueError("`image` and `mask_image` must have the same batch sizes")
652
-
653
- if image_height != mask_image_height or image_width != mask_image_width:
654
- raise ValueError("`image` and `mask_image` must have the same height and width dimensions")
655
-
656
- if image.min() < -1 or image.max() > 1:
657
- raise ValueError("`image` should be in range [-1, 1]")
658
-
659
- if mask_image.min() < 0 or mask_image.max() > 1:
660
- raise ValueError("`mask_image` should be in range [0, 1]")
661
- else:
662
- mask_image_channels = 1
663
- image_channels = 3
664
-
665
- single_image_latent_channels = self.vae.config.latent_channels
666
-
667
- total_latent_channels = single_image_latent_channels * 2 + mask_image_channels
668
-
669
- if total_latent_channels != self.unet.config.in_channels:
670
- raise ValueError(
671
- f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
672
- f" non inpainting latent channels: {single_image_latent_channels},"
673
- f" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}."
674
- f" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs."
675
- )
676
-
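# A worked instance of the channel check above, assuming the usual Stable Diffusion values
# (4 latent channels, 9-channel inpainting UNet); the numbers are assumptions, not read from a config.
single_image_latent_channels = 4       # vae.config.latent_channels
mask_image_channels = 1
total_latent_channels = single_image_latent_channels * 2 + mask_image_channels
print(total_latent_channels)           # 9, the in_channels expected by an inpainting UNet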
677
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
678
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
679
- if isinstance(generator, list) and len(generator) != batch_size:
680
- raise ValueError(
681
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
682
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
683
- )
684
-
685
- if latents is None:
686
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
687
- else:
688
- latents = latents.to(device)
689
-
690
- # scale the initial noise by the standard deviation required by the scheduler
691
- latents = latents * self.scheduler.init_noise_sigma
692
-
693
- return latents
694
-
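# A minimal sketch of the latent initialization above for a 512x512 image, assuming a scheduler
# whose init_noise_sigma is 1.0 (e.g. DDIM); Karras-style schedulers use their largest sigma instead.
import torch

shape = (1, 4, 512 // 8, 512 // 8)                  # (batch, latent_channels, H/8, W/8)
generator = torch.Generator().manual_seed(0)
latents = torch.randn(shape, generator=generator)
init_noise_sigma = 1.0                              # assumed scheduler value
latents = latents * init_noise_sigma
print(latents.shape)                                # torch.Size([1, 4, 64, 64])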
695
- def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):
696
- # resize the mask to latents shape as we concatenate the mask to the latents
697
- # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
698
- # and half precision
699
- mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))
700
- mask_image = mask_image.to(device=device, dtype=dtype)
701
-
702
- # duplicate mask for each generation per prompt, using mps friendly method
703
- if mask_image.shape[0] < batch_size:
704
- if not batch_size % mask_image.shape[0] == 0:
705
- raise ValueError(
706
- "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
707
- f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the number"
708
- " of masks that you pass is divisible by the total requested batch size."
709
- )
710
- mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)
711
-
712
- mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image
713
-
714
- mask_image_latents = mask_image
715
-
716
- return mask_image_latents
717
-
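# A short sketch of the mask preparation above: the binary mask is resized to the latent grid and
# doubled when classifier-free guidance is active so it matches the doubled latent batch.
# The 512x512 size and scale factor of 8 are assumptions.
import torch
import torch.nn.functional as F

vae_scale_factor = 8
mask = torch.ones(1, 1, 512, 512)
mask_latents = F.interpolate(mask, size=(512 // vae_scale_factor, 512 // vae_scale_factor))
mask_latents = torch.cat([mask_latents] * 2)        # do_classifier_free_guidance=True
print(mask_latents.shape)                           # torch.Size([2, 1, 64, 64])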
718
- def prepare_masked_image_latents(
719
- self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
720
- ):
721
- masked_image = masked_image.to(device=device, dtype=dtype)
722
-
723
- # encode the mask image into latents space so we can concatenate it to the latents
724
- if isinstance(generator, list):
725
- masked_image_latents = [
726
- self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
727
- for i in range(batch_size)
728
- ]
729
- masked_image_latents = torch.cat(masked_image_latents, dim=0)
730
- else:
731
- masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
732
- masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
733
-
734
- # duplicate masked_image_latents for each generation per prompt, using mps friendly method
735
- if masked_image_latents.shape[0] < batch_size:
736
- if not batch_size % masked_image_latents.shape[0] == 0:
737
- raise ValueError(
738
- "The passed images and the required batch size don't match. Images are supposed to be duplicated"
739
- f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
740
- " Make sure the number of images that you pass is divisible by the total requested batch size."
741
- )
742
- masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
743
-
744
- masked_image_latents = (
745
- torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
746
- )
747
-
748
- # aligning device to prevent device errors when concating it with the latent model input
749
- masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
750
- return masked_image_latents
751
-
752
- def _default_height_width(self, height, width, image):
753
- if isinstance(image, list):
754
- image = image[0]
755
-
756
- if height is None:
757
- if isinstance(image, PIL.Image.Image):
758
- height = image.height
759
- elif isinstance(image, torch.Tensor):
760
- height = image.shape[3]
761
-
762
- height = (height // 8) * 8 # round down to nearest multiple of 8
763
-
764
- if width is None:
765
- if isinstance(image, PIL.Image.Image):
766
- width = image.width
767
- elif isinstance(image, torch.Tensor):
768
- width = image.shape[2]
769
-
770
- width = (width // 8) * 8 # round down to nearest multiple of 8
771
-
772
- return height, width
773
-
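# A worked example of the rounding above: sizes are floored to a multiple of 8 because the VAE
# downsamples by 8 and the latent grid needs integer dimensions. The input sizes are arbitrary.
height, width = 517, 769
height = (height // 8) * 8
width = (width // 8) * 8
print(height, width)   # 512 768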
774
- @torch.no_grad()
775
- @replace_example_docstring(EXAMPLE_DOC_STRING)
776
- def __call__(
777
- self,
778
- prompt: Union[str, List[str]] = None,
779
- image: Union[torch.Tensor, PIL.Image.Image] = None,
780
- mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
781
- controlnet_conditioning_image: Union[
782
- torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
783
- ] = None,
784
- height: Optional[int] = None,
785
- width: Optional[int] = None,
786
- num_inference_steps: int = 50,
787
- guidance_scale: float = 7.5,
788
- negative_prompt: Optional[Union[str, List[str]]] = None,
789
- num_images_per_prompt: Optional[int] = 1,
790
- eta: float = 0.0,
791
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
792
- latents: Optional[torch.FloatTensor] = None,
793
- prompt_embeds: Optional[torch.FloatTensor] = None,
794
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
795
- output_type: Optional[str] = "pil",
796
- return_dict: bool = True,
797
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
798
- callback_steps: int = 1,
799
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
800
- controlnet_conditioning_scale: float = 1.0,
801
- ):
802
- r"""
803
- Function invoked when calling the pipeline for generation.
804
-
805
- Args:
806
- prompt (`str` or `List[str]`, *optional*):
807
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
808
- instead.
809
- image (`torch.Tensor` or `PIL.Image.Image`):
810
- `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
811
- be masked out with `mask_image` and repainted according to `prompt`.
812
- mask_image (`torch.Tensor` or `PIL.Image.Image`):
813
- `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
814
- repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
815
- to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
816
- instead of 3, so the expected shape would be `(B, H, W, 1)`.
817
- controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
818
- The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
819
- the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
820
- also be accepted as an image. The control image is automatically resized to fit the output image.
821
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
822
- The height in pixels of the generated image.
823
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
824
- The width in pixels of the generated image.
825
- num_inference_steps (`int`, *optional*, defaults to 50):
826
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
827
- expense of slower inference.
828
- guidance_scale (`float`, *optional*, defaults to 7.5):
829
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
830
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
831
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
832
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
833
- usually at the expense of lower image quality.
834
- negative_prompt (`str` or `List[str]`, *optional*):
835
- The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.
836
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
837
- num_images_per_prompt (`int`, *optional*, defaults to 1):
838
- The number of images to generate per prompt.
839
- eta (`float`, *optional*, defaults to 0.0):
840
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
841
- [`schedulers.DDIMScheduler`], will be ignored for others.
842
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
843
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
844
- to make generation deterministic.
845
- latents (`torch.FloatTensor`, *optional*):
846
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
847
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
848
- tensor will be generated by sampling using the supplied random `generator`.
849
- prompt_embeds (`torch.FloatTensor`, *optional*):
850
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
851
- provided, text embeddings will be generated from `prompt` input argument.
852
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
853
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
854
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
855
- argument.
856
- output_type (`str`, *optional*, defaults to `"pil"`):
857
- The output format of the generated image. Choose between
858
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
859
- return_dict (`bool`, *optional*, defaults to `True`):
860
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
861
- plain tuple.
862
- callback (`Callable`, *optional*):
863
- A function that will be called every `callback_steps` steps during inference. The function will be
864
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
865
- callback_steps (`int`, *optional*, defaults to 1):
866
- The frequency at which the `callback` function will be called. If not specified, the callback will be
867
- called at every step.
868
- cross_attention_kwargs (`dict`, *optional*):
869
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
870
- `self.processor` in
871
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
872
- controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
873
- The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
874
- to the residual in the original unet.
875
-
876
- Examples:
877
-
878
- Returns:
879
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
880
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
881
- When returning a tuple, the first element is a list with the generated images, and the second element is a
882
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
883
- (nsfw) content, according to the `safety_checker`.
884
- """
885
- # 0. Default height and width to unet
886
- height, width = self._default_height_width(height, width, controlnet_conditioning_image)
887
-
888
- # 1. Check inputs. Raise error if not correct
889
- self.check_inputs(
890
- prompt,
891
- image,
892
- mask_image,
893
- controlnet_conditioning_image,
894
- height,
895
- width,
896
- callback_steps,
897
- negative_prompt,
898
- prompt_embeds,
899
- negative_prompt_embeds,
900
- )
901
-
902
- # 2. Define call parameters
903
- if prompt is not None and isinstance(prompt, str):
904
- batch_size = 1
905
- elif prompt is not None and isinstance(prompt, list):
906
- batch_size = len(prompt)
907
- else:
908
- batch_size = prompt_embeds.shape[0]
909
-
910
- device = self._execution_device
911
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
912
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
913
- # corresponds to doing no classifier free guidance.
914
- do_classifier_free_guidance = guidance_scale > 1.0
915
-
916
- # 3. Encode input prompt
917
- prompt_embeds = self._encode_prompt(
918
- prompt,
919
- device,
920
- num_images_per_prompt,
921
- do_classifier_free_guidance,
922
- negative_prompt,
923
- prompt_embeds=prompt_embeds,
924
- negative_prompt_embeds=negative_prompt_embeds,
925
- )
926
-
927
- # 4. Prepare mask, image, and controlnet_conditioning_image
928
- image = prepare_image(image)
929
-
930
- mask_image = prepare_mask_image(mask_image)
931
-
932
- controlnet_conditioning_image = prepare_controlnet_conditioning_image(
933
- controlnet_conditioning_image,
934
- width,
935
- height,
936
- batch_size * num_images_per_prompt,
937
- num_images_per_prompt,
938
- device,
939
- self.controlnet.dtype,
940
- )
941
-
942
- masked_image = image * (mask_image < 0.5)
943
-
944
- # 5. Prepare timesteps
945
- self.scheduler.set_timesteps(num_inference_steps, device=device)
946
- timesteps = self.scheduler.timesteps
947
-
948
- # 6. Prepare latent variables
949
- num_channels_latents = self.vae.config.latent_channels
950
- latents = self.prepare_latents(
951
- batch_size * num_images_per_prompt,
952
- num_channels_latents,
953
- height,
954
- width,
955
- prompt_embeds.dtype,
956
- device,
957
- generator,
958
- latents,
959
- )
960
-
961
- mask_image_latents = self.prepare_mask_latents(
962
- mask_image,
963
- batch_size * num_images_per_prompt,
964
- height,
965
- width,
966
- prompt_embeds.dtype,
967
- device,
968
- do_classifier_free_guidance,
969
- )
970
-
971
- masked_image_latents = self.prepare_masked_image_latents(
972
- masked_image,
973
- batch_size * num_images_per_prompt,
974
- height,
975
- width,
976
- prompt_embeds.dtype,
977
- device,
978
- generator,
979
- do_classifier_free_guidance,
980
- )
981
-
982
- if do_classifier_free_guidance:
983
- controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
984
-
985
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
986
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
987
-
988
- # 8. Denoising loop
989
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
990
- with self.progress_bar(total=num_inference_steps) as progress_bar:
991
- for i, t in enumerate(timesteps):
992
- # expand the latents if we are doing classifier free guidance
993
- non_inpainting_latent_model_input = (
994
- torch.cat([latents] * 2) if do_classifier_free_guidance else latents
995
- )
996
-
997
- non_inpainting_latent_model_input = self.scheduler.scale_model_input(
998
- non_inpainting_latent_model_input, t
999
- )
1000
-
1001
- inpainting_latent_model_input = torch.cat(
1002
- [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
1003
- )
1004
-
1005
- down_block_res_samples, mid_block_res_sample = self.controlnet(
1006
- non_inpainting_latent_model_input,
1007
- t,
1008
- encoder_hidden_states=prompt_embeds,
1009
- controlnet_cond=controlnet_conditioning_image,
1010
- return_dict=False,
1011
- )
1012
-
1013
- down_block_res_samples = [
1014
- down_block_res_sample * controlnet_conditioning_scale
1015
- for down_block_res_sample in down_block_res_samples
1016
- ]
1017
- mid_block_res_sample *= controlnet_conditioning_scale
1018
-
1019
- # predict the noise residual
1020
- noise_pred = self.unet(
1021
- inpainting_latent_model_input,
1022
- t,
1023
- encoder_hidden_states=prompt_embeds,
1024
- cross_attention_kwargs=cross_attention_kwargs,
1025
- down_block_additional_residuals=down_block_res_samples,
1026
- mid_block_additional_residual=mid_block_res_sample,
1027
- ).sample
1028
-
1029
- # perform guidance
1030
- if do_classifier_free_guidance:
1031
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1032
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1033
-
1034
- # compute the previous noisy sample x_t -> x_t-1
1035
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1036
-
1037
- # call the callback, if provided
1038
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1039
- progress_bar.update()
1040
- if callback is not None and i % callback_steps == 0:
1041
- callback(i, t, latents)
1042
-
1043
- # If we do sequential model offloading, let's offload unet and controlnet
1044
- # manually for max memory savings
1045
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1046
- self.unet.to("cpu")
1047
- self.controlnet.to("cpu")
1048
- torch.cuda.empty_cache()
1049
-
1050
- if output_type == "latent":
1051
- image = latents
1052
- has_nsfw_concept = None
1053
- elif output_type == "pil":
1054
- # 8. Post-processing
1055
- image = self.decode_latents(latents)
1056
-
1057
- # 9. Run safety checker
1058
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1059
-
1060
- # 10. Convert to PIL
1061
- image = self.numpy_to_pil(image)
1062
- else:
1063
- # 8. Post-processing
1064
- image = self.decode_latents(latents)
1065
-
1066
- # 9. Run safety checker
1067
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1068
-
1069
- # Offload last model to CPU
1070
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1071
- self.final_offload_hook.offload()
1072
-
1073
- if not return_dict:
1074
- return (image, has_nsfw_concept)
1075
-
1076
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
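# A toy numeric sketch of the guidance step inside the denoising loop above; the two noise
# predictions are made-up scalars rather than real UNet outputs.
guidance_scale = 7.5
noise_pred_uncond = 0.10
noise_pred_text = 0.30
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred)   # 1.6 -> the update is pushed further in the direction favoured by the prompt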
diffusers/examples/community/stable_diffusion_controlnet_inpaint_img2img.py DELETED
@@ -1,1119 +0,0 @@
1
- # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
2
-
3
- import inspect
4
- from typing import Any, Callable, Dict, List, Optional, Union
5
-
6
- import numpy as np
7
- import PIL.Image
8
- import torch
9
- import torch.nn.functional as F
10
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
11
-
12
- from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
13
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
14
- from diffusers.schedulers import KarrasDiffusionSchedulers
15
- from diffusers.utils import (
16
- PIL_INTERPOLATION,
17
- is_accelerate_available,
18
- is_accelerate_version,
19
- randn_tensor,
20
- replace_example_docstring,
21
- )
22
-
23
-
24
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
-
26
- EXAMPLE_DOC_STRING = """
27
- Examples:
28
- ```py
29
- >>> import numpy as np
30
- >>> import torch
31
- >>> from PIL import Image
32
- >>> from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
33
-
34
- >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
35
- >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
36
- >>> from diffusers.utils import load_image
37
-
38
- >>> def ade_palette():
39
- return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
40
- [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
41
- [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
42
- [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
43
- [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
44
- [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
45
- [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
46
- [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
47
- [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
48
- [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
49
- [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
50
- [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
51
- [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
52
- [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
53
- [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
54
- [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
55
- [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
56
- [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
57
- [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
58
- [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
59
- [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
60
- [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
61
- [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
62
- [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
63
- [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
64
- [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
65
- [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
66
- [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
67
- [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
68
- [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
69
- [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
70
- [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
71
- [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
72
- [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
73
- [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
74
- [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
75
- [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
76
- [102, 255, 0], [92, 0, 255]]
77
-
78
- >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
79
- >>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
80
-
81
- >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)
82
-
83
- >>> pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
84
- "runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
85
- )
86
-
87
- >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
88
- >>> pipe.enable_xformers_memory_efficient_attention()
89
- >>> pipe.enable_model_cpu_offload()
90
-
91
- >>> def image_to_seg(image):
92
- pixel_values = image_processor(image, return_tensors="pt").pixel_values
93
- with torch.no_grad():
94
- outputs = image_segmentor(pixel_values)
95
- seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
96
- color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
97
- palette = np.array(ade_palette())
98
- for label, color in enumerate(palette):
99
- color_seg[seg == label, :] = color
100
- color_seg = color_seg.astype(np.uint8)
101
- seg_image = Image.fromarray(color_seg)
102
- return seg_image
103
-
104
- >>> image = load_image(
105
- "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
106
- )
107
-
108
- >>> mask_image = load_image(
109
- "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
110
- )
111
-
112
- >>> controlnet_conditioning_image = image_to_seg(image)
113
-
114
- >>> image = pipe(
115
- "Face of a yellow cat, high resolution, sitting on a park bench",
116
- image,
117
- mask_image,
118
- controlnet_conditioning_image,
119
- num_inference_steps=20,
120
- ).images[0]
121
-
122
- >>> image.save("out.png")
123
- ```
124
- """
125
-
126
-
127
- def prepare_image(image):
128
- if isinstance(image, torch.Tensor):
129
- # Batch single image
130
- if image.ndim == 3:
131
- image = image.unsqueeze(0)
132
-
133
- image = image.to(dtype=torch.float32)
134
- else:
135
- # preprocess image
136
- if isinstance(image, (PIL.Image.Image, np.ndarray)):
137
- image = [image]
138
-
139
- if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
140
- image = [np.array(i.convert("RGB"))[None, :] for i in image]
141
- image = np.concatenate(image, axis=0)
142
- elif isinstance(image, list) and isinstance(image[0], np.ndarray):
143
- image = np.concatenate([i[None, :] for i in image], axis=0)
144
-
145
- image = image.transpose(0, 3, 1, 2)
146
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
147
-
148
- return image
149
-
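# A worked example of the normalization above: dividing by 127.5 and subtracting 1 maps 8-bit
# pixel values in [0, 255] onto the [-1, 1] range the VAE expects.
for pixel in (0, 127.5, 255):
    print(pixel / 127.5 - 1.0)   # -1.0, 0.0, 1.0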
150
-
151
- def prepare_mask_image(mask_image):
152
- if isinstance(mask_image, torch.Tensor):
153
- if mask_image.ndim == 2:
154
- # Batch and add channel dim for single mask
155
- mask_image = mask_image.unsqueeze(0).unsqueeze(0)
156
- elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
157
- # Single mask, the 0'th dimension is considered to be
158
- # the existing batch size of 1
159
- mask_image = mask_image.unsqueeze(0)
160
- elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
161
- # Batch of mask, the 0'th dimension is considered to be
162
- # the batching dimension
163
- mask_image = mask_image.unsqueeze(1)
164
-
165
- # Binarize mask
166
- mask_image[mask_image < 0.5] = 0
167
- mask_image[mask_image >= 0.5] = 1
168
- else:
169
- # preprocess mask
170
- if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
171
- mask_image = [mask_image]
172
-
173
- if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
174
- mask_image = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0)
175
- mask_image = mask_image.astype(np.float32) / 255.0
176
- elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
177
- mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
178
-
179
- mask_image[mask_image < 0.5] = 0
180
- mask_image[mask_image >= 0.5] = 1
181
- mask_image = torch.from_numpy(mask_image)
182
-
183
- return mask_image
184
-
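# A minimal sketch of the mask binarization above: every value is snapped to 0 (keep the pixel)
# or 1 (repaint the pixel) before reaching the UNet. The mask values are arbitrary.
import torch

mask = torch.tensor([[0.1, 0.49, 0.5, 0.9]])
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
print(mask)   # tensor([[0., 0., 1., 1.]])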
185
-
186
- def prepare_controlnet_conditioning_image(
187
- controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
188
- ):
189
- if not isinstance(controlnet_conditioning_image, torch.Tensor):
190
- if isinstance(controlnet_conditioning_image, PIL.Image.Image):
191
- controlnet_conditioning_image = [controlnet_conditioning_image]
192
-
193
- if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
194
- controlnet_conditioning_image = [
195
- np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
196
- for i in controlnet_conditioning_image
197
- ]
198
- controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
199
- controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
200
- controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
201
- controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
202
- elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
203
- controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)
204
-
205
- image_batch_size = controlnet_conditioning_image.shape[0]
206
-
207
- if image_batch_size == 1:
208
- repeat_by = batch_size
209
- else:
210
- # image batch size is the same as prompt batch size
211
- repeat_by = num_images_per_prompt
212
-
213
- controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)
214
-
215
- controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
216
-
217
- return controlnet_conditioning_image
218
-
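# A small sketch of the batching rule above, assuming one conditioning image and an effective
# batch of 4 (2 prompts x 2 images per prompt): a single image is tiled to the full batch via
# repeat_interleave, whereas a per-prompt batch would only be repeated num_images_per_prompt times.
import torch

cond = torch.zeros(1, 3, 64, 64)                   # a single conditioning image
batch_size, num_images_per_prompt = 4, 2           # batch_size here is prompts * images per prompt
repeat_by = batch_size if cond.shape[0] == 1 else num_images_per_prompt
cond = cond.repeat_interleave(repeat_by, dim=0)
print(cond.shape)                                  # torch.Size([4, 3, 64, 64])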
219
-
220
- class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline):
221
- """
222
- Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
223
- """
224
-
225
- _optional_components = ["safety_checker", "feature_extractor"]
226
-
227
- def __init__(
228
- self,
229
- vae: AutoencoderKL,
230
- text_encoder: CLIPTextModel,
231
- tokenizer: CLIPTokenizer,
232
- unet: UNet2DConditionModel,
233
- controlnet: ControlNetModel,
234
- scheduler: KarrasDiffusionSchedulers,
235
- safety_checker: StableDiffusionSafetyChecker,
236
- feature_extractor: CLIPImageProcessor,
237
- requires_safety_checker: bool = True,
238
- ):
239
- super().__init__()
240
-
241
- if safety_checker is None and requires_safety_checker:
242
- logger.warning(
243
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
244
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
245
- " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
246
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
247
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
248
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
249
- )
250
-
251
- if safety_checker is not None and feature_extractor is None:
252
- raise ValueError(
253
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
254
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
255
- )
256
-
257
- self.register_modules(
258
- vae=vae,
259
- text_encoder=text_encoder,
260
- tokenizer=tokenizer,
261
- unet=unet,
262
- controlnet=controlnet,
263
- scheduler=scheduler,
264
- safety_checker=safety_checker,
265
- feature_extractor=feature_extractor,
266
- )
267
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
268
- self.register_to_config(requires_safety_checker=requires_safety_checker)
269
-
270
- def enable_vae_slicing(self):
271
- r"""
272
- Enable sliced VAE decoding.
273
-
274
- When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
275
- steps. This is useful to save some memory and allow larger batch sizes.
276
- """
277
- self.vae.enable_slicing()
278
-
279
- def disable_vae_slicing(self):
280
- r"""
281
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
282
- computing decoding in one step.
283
- """
284
- self.vae.disable_slicing()
285
-
286
- def enable_sequential_cpu_offload(self, gpu_id=0):
287
- r"""
288
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
289
- text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
290
- `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
291
- Note that offloading happens on a submodule basis. Memory savings are higher than with
292
- `enable_model_cpu_offload`, but performance is lower.
293
- """
294
- if is_accelerate_available():
295
- from accelerate import cpu_offload
296
- else:
297
- raise ImportError("Please install accelerate via `pip install accelerate`")
298
-
299
- device = torch.device(f"cuda:{gpu_id}")
300
-
301
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
302
- cpu_offload(cpu_offloaded_model, device)
303
-
304
- if self.safety_checker is not None:
305
- cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
306
-
307
- def enable_model_cpu_offload(self, gpu_id=0):
308
- r"""
309
- Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
310
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
311
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
312
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
313
- """
314
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
315
- from accelerate import cpu_offload_with_hook
316
- else:
317
- raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
318
-
319
- device = torch.device(f"cuda:{gpu_id}")
320
-
321
- hook = None
322
- for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
323
- _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
324
-
325
- if self.safety_checker is not None:
326
- # the safety checker can offload the vae again
327
- _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
328
-
329
- # control net hook has be manually offloaded as it alternates with unet
330
- cpu_offload_with_hook(self.controlnet, device)
331
-
332
- # We'll offload the last model manually.
333
- self.final_offload_hook = hook
334
-
335
- @property
336
- def _execution_device(self):
337
- r"""
338
- Returns the device on which the pipeline's models will be executed. After calling
339
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
340
- hooks.
341
- """
342
- if not hasattr(self.unet, "_hf_hook"):
343
- return self.device
344
- for module in self.unet.modules():
345
- if (
346
- hasattr(module, "_hf_hook")
347
- and hasattr(module._hf_hook, "execution_device")
348
- and module._hf_hook.execution_device is not None
349
- ):
350
- return torch.device(module._hf_hook.execution_device)
351
- return self.device
352
-
353
- def _encode_prompt(
354
- self,
355
- prompt,
356
- device,
357
- num_images_per_prompt,
358
- do_classifier_free_guidance,
359
- negative_prompt=None,
360
- prompt_embeds: Optional[torch.FloatTensor] = None,
361
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
362
- ):
363
- r"""
364
- Encodes the prompt into text encoder hidden states.
365
-
366
- Args:
367
- prompt (`str` or `List[str]`, *optional*):
368
- prompt to be encoded
369
- device: (`torch.device`):
370
- torch device
371
- num_images_per_prompt (`int`):
372
- number of images that should be generated per prompt
373
- do_classifier_free_guidance (`bool`):
374
- whether to use classifier free guidance or not
375
- negative_prompt (`str` or `List[str]`, *optional*):
376
- The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.
377
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
378
- prompt_embeds (`torch.FloatTensor`, *optional*):
379
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
380
- provided, text embeddings will be generated from `prompt` input argument.
381
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
382
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
383
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
384
- argument.
385
- """
386
- if prompt is not None and isinstance(prompt, str):
387
- batch_size = 1
388
- elif prompt is not None and isinstance(prompt, list):
389
- batch_size = len(prompt)
390
- else:
391
- batch_size = prompt_embeds.shape[0]
392
-
393
- if prompt_embeds is None:
394
- text_inputs = self.tokenizer(
395
- prompt,
396
- padding="max_length",
397
- max_length=self.tokenizer.model_max_length,
398
- truncation=True,
399
- return_tensors="pt",
400
- )
401
- text_input_ids = text_inputs.input_ids
402
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
403
-
404
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
405
- text_input_ids, untruncated_ids
406
- ):
407
- removed_text = self.tokenizer.batch_decode(
408
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
409
- )
410
- logger.warning(
411
- "The following part of your input was truncated because CLIP can only handle sequences up to"
412
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
413
- )
414
-
415
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
416
- attention_mask = text_inputs.attention_mask.to(device)
417
- else:
418
- attention_mask = None
419
-
420
- prompt_embeds = self.text_encoder(
421
- text_input_ids.to(device),
422
- attention_mask=attention_mask,
423
- )
424
- prompt_embeds = prompt_embeds[0]
425
-
426
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
427
-
428
- bs_embed, seq_len, _ = prompt_embeds.shape
429
- # duplicate text embeddings for each generation per prompt, using mps friendly method
430
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
431
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
432
-
433
- # get unconditional embeddings for classifier free guidance
434
- if do_classifier_free_guidance and negative_prompt_embeds is None:
435
- uncond_tokens: List[str]
436
- if negative_prompt is None:
437
- uncond_tokens = [""] * batch_size
438
- elif type(prompt) is not type(negative_prompt):
439
- raise TypeError(
440
- f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
441
- f" {type(prompt)}."
442
- )
443
- elif isinstance(negative_prompt, str):
444
- uncond_tokens = [negative_prompt]
445
- elif batch_size != len(negative_prompt):
446
- raise ValueError(
447
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
448
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
449
- " the batch size of `prompt`."
450
- )
451
- else:
452
- uncond_tokens = negative_prompt
453
-
454
- max_length = prompt_embeds.shape[1]
455
- uncond_input = self.tokenizer(
456
- uncond_tokens,
457
- padding="max_length",
458
- max_length=max_length,
459
- truncation=True,
460
- return_tensors="pt",
461
- )
462
-
463
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
464
- attention_mask = uncond_input.attention_mask.to(device)
465
- else:
466
- attention_mask = None
467
-
468
- negative_prompt_embeds = self.text_encoder(
469
- uncond_input.input_ids.to(device),
470
- attention_mask=attention_mask,
471
- )
472
- negative_prompt_embeds = negative_prompt_embeds[0]
473
-
474
- if do_classifier_free_guidance:
475
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
476
- seq_len = negative_prompt_embeds.shape[1]
477
-
478
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
479
-
480
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
481
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
482
-
483
- # For classifier free guidance, we need to do two forward passes.
484
- # Here we concatenate the unconditional and text embeddings into a single batch
485
- # to avoid doing two forward passes
486
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
487
-
488
- return prompt_embeds
489
-
490
- def run_safety_checker(self, image, device, dtype):
491
- if self.safety_checker is not None:
492
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
493
- image, has_nsfw_concept = self.safety_checker(
494
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
495
- )
496
- else:
497
- has_nsfw_concept = None
498
- return image, has_nsfw_concept
499
-
500
- def decode_latents(self, latents):
501
- latents = 1 / self.vae.config.scaling_factor * latents
502
- image = self.vae.decode(latents).sample
503
- image = (image / 2 + 0.5).clamp(0, 1)
504
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
505
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
506
- return image
507
-
508
- def prepare_extra_step_kwargs(self, generator, eta):
509
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
510
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
511
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
512
- # and should be between [0, 1]
513
-
514
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
515
- extra_step_kwargs = {}
516
- if accepts_eta:
517
- extra_step_kwargs["eta"] = eta
518
-
519
- # check if the scheduler accepts generator
520
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
521
- if accepts_generator:
522
- extra_step_kwargs["generator"] = generator
523
- return extra_step_kwargs
524
-
525
- def check_inputs(
526
- self,
527
- prompt,
528
- image,
529
- mask_image,
530
- controlnet_conditioning_image,
531
- height,
532
- width,
533
- callback_steps,
534
- negative_prompt=None,
535
- prompt_embeds=None,
536
- negative_prompt_embeds=None,
537
- strength=None,
538
- ):
539
- if height % 8 != 0 or width % 8 != 0:
540
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
541
-
542
- if (callback_steps is None) or (
543
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
544
- ):
545
- raise ValueError(
546
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
547
- f" {type(callback_steps)}."
548
- )
549
-
550
- if prompt is not None and prompt_embeds is not None:
551
- raise ValueError(
552
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
553
- " only forward one of the two."
554
- )
555
- elif prompt is None and prompt_embeds is None:
556
- raise ValueError(
557
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
558
- )
559
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
560
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
561
-
562
- if negative_prompt is not None and negative_prompt_embeds is not None:
563
- raise ValueError(
564
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
565
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
566
- )
567
-
568
- if prompt_embeds is not None and negative_prompt_embeds is not None:
569
- if prompt_embeds.shape != negative_prompt_embeds.shape:
570
- raise ValueError(
571
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
572
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
573
- f" {negative_prompt_embeds.shape}."
574
- )
575
-
576
- controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
577
- controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
578
- controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
579
- controlnet_conditioning_image[0], PIL.Image.Image
580
- )
581
- controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
582
- controlnet_conditioning_image[0], torch.Tensor
583
- )
584
-
585
- if (
586
- not controlnet_cond_image_is_pil
587
- and not controlnet_cond_image_is_tensor
588
- and not controlnet_cond_image_is_pil_list
589
- and not controlnet_cond_image_is_tensor_list
590
- ):
591
- raise TypeError(
592
- "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
593
- )
594
-
595
- if controlnet_cond_image_is_pil:
596
- controlnet_cond_image_batch_size = 1
597
- elif controlnet_cond_image_is_tensor:
598
- controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
599
- elif controlnet_cond_image_is_pil_list:
600
- controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
601
- elif controlnet_cond_image_is_tensor_list:
602
- controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
603
-
604
- if prompt is not None and isinstance(prompt, str):
605
- prompt_batch_size = 1
606
- elif prompt is not None and isinstance(prompt, list):
607
- prompt_batch_size = len(prompt)
608
- elif prompt_embeds is not None:
609
- prompt_batch_size = prompt_embeds.shape[0]
610
-
611
- if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
612
- raise ValueError(
613
- f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
614
- )
615
-
616
- if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
617
- raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
618
-
619
- if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):
620
- raise TypeError("if `image` is a PIL image, `mask_image` must also be a PIL image")
621
-
622
- if isinstance(image, torch.Tensor):
623
- if image.ndim != 3 and image.ndim != 4:
624
- raise ValueError("`image` must have 3 or 4 dimensions")
625
-
626
- if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
627
- raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
628
-
629
- if image.ndim == 3:
630
- image_batch_size = 1
631
- image_channels, image_height, image_width = image.shape
632
- elif image.ndim == 4:
633
- image_batch_size, image_channels, image_height, image_width = image.shape
634
-
635
- if mask_image.ndim == 2:
636
- mask_image_batch_size = 1
637
- mask_image_channels = 1
638
- mask_image_height, mask_image_width = mask_image.shape
639
- elif mask_image.ndim == 3:
640
- mask_image_channels = 1
641
- mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape
642
- elif mask_image.ndim == 4:
643
- mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape
644
-
645
- if image_channels != 3:
646
- raise ValueError("`image` must have 3 channels")
647
-
648
- if mask_image_channels != 1:
649
- raise ValueError("`mask_image` must have 1 channel")
650
-
651
- if image_batch_size != mask_image_batch_size:
652
- raise ValueError("`image` and `mask_image` must have the same batch sizes")
653
-
654
- if image_height != mask_image_height or image_width != mask_image_width:
655
- raise ValueError("`image` and `mask_image` must have the same height and width dimensions")
656
-
657
- if image.min() < -1 or image.max() > 1:
658
- raise ValueError("`image` should be in range [-1, 1]")
659
-
660
- if mask_image.min() < 0 or mask_image.max() > 1:
661
- raise ValueError("`mask_image` should be in range [0, 1]")
662
- else:
663
- mask_image_channels = 1
664
- image_channels = 3
665
-
666
- single_image_latent_channels = self.vae.config.latent_channels
667
-
668
- total_latent_channels = single_image_latent_channels * 2 + mask_image_channels
669
-
670
- if total_latent_channels != self.unet.config.in_channels:
671
- raise ValueError(
672
- f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
673
- f" non inpainting latent channels: {single_image_latent_channels},"
674
- f" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}."
675
- f" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs."
676
- )
677
-
678
- if strength < 0 or strength > 1:
679
- raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
680
-
681
- def get_timesteps(self, num_inference_steps, strength, device):
682
- # get the original timestep using init_timestep
683
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
684
-
685
- t_start = max(num_inference_steps - init_timestep, 0)
686
- timesteps = self.scheduler.timesteps[t_start:]
687
-
688
- return timesteps, num_inference_steps - t_start
689
-
690
- def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
691
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
692
- raise ValueError(
693
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
694
- )
695
-
696
- image = image.to(device=device, dtype=dtype)
697
-
698
- batch_size = batch_size * num_images_per_prompt
699
- if isinstance(generator, list) and len(generator) != batch_size:
700
- raise ValueError(
701
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
702
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
703
- )
704
-
705
- if isinstance(generator, list):
706
- init_latents = [
707
- self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
708
- ]
709
- init_latents = torch.cat(init_latents, dim=0)
710
- else:
711
- init_latents = self.vae.encode(image).latent_dist.sample(generator)
712
-
713
- init_latents = self.vae.config.scaling_factor * init_latents
714
-
715
- if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
716
- raise ValueError(
717
- f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
718
- )
719
- else:
720
- init_latents = torch.cat([init_latents] * max(batch_size // init_latents.shape[0], 1), dim=0)
721
-
722
- shape = init_latents.shape
723
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
724
-
725
- # get latents
726
- init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
727
- latents = init_latents
728
-
729
- return latents
730
-
731
- def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):
732
- # resize the mask to latents shape as we concatenate the mask to the latents
733
- # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
734
- # and half precision
735
- mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))
736
- mask_image = mask_image.to(device=device, dtype=dtype)
737
-
738
- # duplicate mask for each generation per prompt, using mps friendly method
739
- if mask_image.shape[0] < batch_size:
740
- if not batch_size % mask_image.shape[0] == 0:
741
- raise ValueError(
742
- "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
743
- f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the number"
744
- " of masks that you pass is divisible by the total requested batch size."
745
- )
746
- mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)
747
-
748
- mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image
749
-
750
- mask_image_latents = mask_image
751
-
752
- return mask_image_latents
753
-
754
- def prepare_masked_image_latents(
755
- self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
756
- ):
757
- masked_image = masked_image.to(device=device, dtype=dtype)
758
-
759
- # encode the mask image into latents space so we can concatenate it to the latents
760
- if isinstance(generator, list):
761
- masked_image_latents = [
762
- self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
763
- for i in range(batch_size)
764
- ]
765
- masked_image_latents = torch.cat(masked_image_latents, dim=0)
766
- else:
767
- masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
768
- masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
769
-
770
- # duplicate masked_image_latents for each generation per prompt, using mps friendly method
771
- if masked_image_latents.shape[0] < batch_size:
772
- if not batch_size % masked_image_latents.shape[0] == 0:
773
- raise ValueError(
774
- "The passed images and the required batch size don't match. Images are supposed to be duplicated"
775
- f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
776
- " Make sure the number of images that you pass is divisible by the total requested batch size."
777
- )
778
- masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
779
-
780
- masked_image_latents = (
781
- torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
782
- )
783
-
784
- # aligning device to prevent device errors when concating it with the latent model input
785
- masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
786
- return masked_image_latents
787
-
788
- def _default_height_width(self, height, width, image):
789
- if isinstance(image, list):
790
- image = image[0]
791
-
792
- if height is None:
793
- if isinstance(image, PIL.Image.Image):
794
- height = image.height
795
- elif isinstance(image, torch.Tensor):
796
- height = image.shape[3]
797
-
798
- height = (height // 8) * 8 # round down to nearest multiple of 8
799
-
800
- if width is None:
801
- if isinstance(image, PIL.Image.Image):
802
- width = image.width
803
- elif isinstance(image, torch.Tensor):
804
- width = image.shape[2]
805
-
806
- width = (width // 8) * 8 # round down to nearest multiple of 8
807
-
808
- return height, width
809
-
810
- @torch.no_grad()
811
- @replace_example_docstring(EXAMPLE_DOC_STRING)
812
- def __call__(
813
- self,
814
- prompt: Union[str, List[str]] = None,
815
- image: Union[torch.Tensor, PIL.Image.Image] = None,
816
- mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
817
- controlnet_conditioning_image: Union[
818
- torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
819
- ] = None,
820
- strength: float = 0.8,
821
- height: Optional[int] = None,
822
- width: Optional[int] = None,
823
- num_inference_steps: int = 50,
824
- guidance_scale: float = 7.5,
825
- negative_prompt: Optional[Union[str, List[str]]] = None,
826
- num_images_per_prompt: Optional[int] = 1,
827
- eta: float = 0.0,
828
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
829
- latents: Optional[torch.FloatTensor] = None,
830
- prompt_embeds: Optional[torch.FloatTensor] = None,
831
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
832
- output_type: Optional[str] = "pil",
833
- return_dict: bool = True,
834
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
835
- callback_steps: int = 1,
836
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
837
- controlnet_conditioning_scale: float = 1.0,
838
- ):
839
- r"""
840
- Function invoked when calling the pipeline for generation.
841
-
842
- Args:
843
- prompt (`str` or `List[str]`, *optional*):
844
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
845
- instead.
846
- image (`torch.Tensor` or `PIL.Image.Image`):
847
- `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
848
- be masked out with `mask_image` and repainted according to `prompt`.
849
- mask_image (`torch.Tensor` or `PIL.Image.Image`):
850
- `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
851
- repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
852
- to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
853
- instead of 3, so the expected shape would be `(B, 1, H, W)`.
854
- controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
855
- The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If
856
- the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
857
- also be accepted as an image. The control image is automatically resized to fit the output image.
858
- strength (`float`, *optional*):
859
- Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
860
- will be used as a starting point, adding more noise to it the larger the `strength`. The number of
861
- denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
862
- be maximum and the denoising process will run for the full number of iterations specified in
863
- `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
864
- height (`int`, *optional*, defaults to the height of `controlnet_conditioning_image`, rounded down to a multiple of 8):
865
- The height in pixels of the generated image.
866
- width (`int`, *optional*, defaults to the width of `controlnet_conditioning_image`, rounded down to a multiple of 8):
867
- The width in pixels of the generated image.
868
- num_inference_steps (`int`, *optional*, defaults to 50):
869
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
870
- expense of slower inference.
871
- guidance_scale (`float`, *optional*, defaults to 7.5):
872
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
873
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
874
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
875
- 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
876
- usually at the expense of lower image quality.
877
- negative_prompt (`str` or `List[str]`, *optional*):
878
- The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.
879
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
880
- num_images_per_prompt (`int`, *optional*, defaults to 1):
881
- The number of images to generate per prompt.
882
- eta (`float`, *optional*, defaults to 0.0):
883
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
884
- [`schedulers.DDIMScheduler`], will be ignored for others.
885
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
886
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
887
- to make generation deterministic.
888
- latents (`torch.FloatTensor`, *optional*):
889
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
890
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
891
- tensor will be generated by sampling using the supplied random `generator`.
892
- prompt_embeds (`torch.FloatTensor`, *optional*):
893
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
894
- provided, text embeddings will be generated from `prompt` input argument.
895
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
896
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
897
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
898
- argument.
899
- output_type (`str`, *optional*, defaults to `"pil"`):
900
- The output format of the generated image. Choose between
901
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
902
- return_dict (`bool`, *optional*, defaults to `True`):
903
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
904
- plain tuple.
905
- callback (`Callable`, *optional*):
906
- A function that will be called every `callback_steps` steps during inference. The function will be
907
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
908
- callback_steps (`int`, *optional*, defaults to 1):
909
- The frequency at which the `callback` function will be called. If not specified, the callback will be
910
- called at every step.
911
- cross_attention_kwargs (`dict`, *optional*):
912
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
913
- `self.processor` in
914
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
915
- controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
916
- The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
917
- to the residual in the original unet.
918
-
919
- Examples:
920
-
921
- Returns:
922
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
923
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`.
924
- When returning a tuple, the first element is a list with the generated images, and the second element is a
925
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
926
- (nsfw) content, according to the `safety_checker`.
927
- """
928
- # 0. Default height and width to unet
929
- height, width = self._default_height_width(height, width, controlnet_conditioning_image)
930
-
931
- # 1. Check inputs. Raise error if not correct
932
- self.check_inputs(
933
- prompt,
934
- image,
935
- mask_image,
936
- controlnet_conditioning_image,
937
- height,
938
- width,
939
- callback_steps,
940
- negative_prompt,
941
- prompt_embeds,
942
- negative_prompt_embeds,
943
- strength,
944
- )
945
-
946
- # 2. Define call parameters
947
- if prompt is not None and isinstance(prompt, str):
948
- batch_size = 1
949
- elif prompt is not None and isinstance(prompt, list):
950
- batch_size = len(prompt)
951
- else:
952
- batch_size = prompt_embeds.shape[0]
953
-
954
- device = self._execution_device
955
- # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
956
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
957
- # corresponds to doing no classifier free guidance.
958
- do_classifier_free_guidance = guidance_scale > 1.0
959
-
960
- # 3. Encode input prompt
961
- prompt_embeds = self._encode_prompt(
962
- prompt,
963
- device,
964
- num_images_per_prompt,
965
- do_classifier_free_guidance,
966
- negative_prompt,
967
- prompt_embeds=prompt_embeds,
968
- negative_prompt_embeds=negative_prompt_embeds,
969
- )
970
-
971
- # 4. Prepare mask, image, and controlnet_conditioning_image
972
- image = prepare_image(image)
973
-
974
- mask_image = prepare_mask_image(mask_image)
975
-
976
- controlnet_conditioning_image = prepare_controlnet_conditioning_image(
977
- controlnet_conditioning_image,
978
- width,
979
- height,
980
- batch_size * num_images_per_prompt,
981
- num_images_per_prompt,
982
- device,
983
- self.controlnet.dtype,
984
- )
985
-
986
- masked_image = image * (mask_image < 0.5)
987
-
988
- # 5. Prepare timesteps
989
- self.scheduler.set_timesteps(num_inference_steps, device=device)
990
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
991
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
992
-
993
- # 6. Prepare latent variables
994
- latents = self.prepare_latents(
995
- image,
996
- latent_timestep,
997
- batch_size,
998
- num_images_per_prompt,
999
- prompt_embeds.dtype,
1000
- device,
1001
- generator,
1002
- )
1003
-
1004
- mask_image_latents = self.prepare_mask_latents(
1005
- mask_image,
1006
- batch_size * num_images_per_prompt,
1007
- height,
1008
- width,
1009
- prompt_embeds.dtype,
1010
- device,
1011
- do_classifier_free_guidance,
1012
- )
1013
-
1014
- masked_image_latents = self.prepare_masked_image_latents(
1015
- masked_image,
1016
- batch_size * num_images_per_prompt,
1017
- height,
1018
- width,
1019
- prompt_embeds.dtype,
1020
- device,
1021
- generator,
1022
- do_classifier_free_guidance,
1023
- )
1024
-
1025
- if do_classifier_free_guidance:
1026
- controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
1027
-
1028
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1029
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1030
-
1031
- # 8. Denoising loop
1032
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1033
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1034
- for i, t in enumerate(timesteps):
1035
- # expand the latents if we are doing classifier free guidance
1036
- non_inpainting_latent_model_input = (
1037
- torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1038
- )
1039
-
1040
- non_inpainting_latent_model_input = self.scheduler.scale_model_input(
1041
- non_inpainting_latent_model_input, t
1042
- )
1043
-
1044
- inpainting_latent_model_input = torch.cat(
1045
- [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
1046
- )
1047
-
1048
- down_block_res_samples, mid_block_res_sample = self.controlnet(
1049
- non_inpainting_latent_model_input,
1050
- t,
1051
- encoder_hidden_states=prompt_embeds,
1052
- controlnet_cond=controlnet_conditioning_image,
1053
- return_dict=False,
1054
- )
1055
-
1056
- down_block_res_samples = [
1057
- down_block_res_sample * controlnet_conditioning_scale
1058
- for down_block_res_sample in down_block_res_samples
1059
- ]
1060
- mid_block_res_sample *= controlnet_conditioning_scale
1061
-
1062
- # predict the noise residual
1063
- noise_pred = self.unet(
1064
- inpainting_latent_model_input,
1065
- t,
1066
- encoder_hidden_states=prompt_embeds,
1067
- cross_attention_kwargs=cross_attention_kwargs,
1068
- down_block_additional_residuals=down_block_res_samples,
1069
- mid_block_additional_residual=mid_block_res_sample,
1070
- ).sample
1071
-
1072
- # perform guidance
1073
- if do_classifier_free_guidance:
1074
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1075
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1076
-
1077
- # compute the previous noisy sample x_t -> x_t-1
1078
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1079
-
1080
- # call the callback, if provided
1081
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1082
- progress_bar.update()
1083
- if callback is not None and i % callback_steps == 0:
1084
- callback(i, t, latents)
1085
-
1086
- # If we do sequential model offloading, let's offload unet and controlnet
1087
- # manually for max memory savings
1088
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1089
- self.unet.to("cpu")
1090
- self.controlnet.to("cpu")
1091
- torch.cuda.empty_cache()
1092
-
1093
- if output_type == "latent":
1094
- image = latents
1095
- has_nsfw_concept = None
1096
- elif output_type == "pil":
1097
- # 8. Post-processing
1098
- image = self.decode_latents(latents)
1099
-
1100
- # 9. Run safety checker
1101
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1102
-
1103
- # 10. Convert to PIL
1104
- image = self.numpy_to_pil(image)
1105
- else:
1106
- # 8. Post-processing
1107
- image = self.decode_latents(latents)
1108
-
1109
- # 9. Run safety checker
1110
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1111
-
1112
- # Offload last model to CPU
1113
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1114
- self.final_offload_hook.offload()
1115
-
1116
- if not return_dict:
1117
- return (image, has_nsfw_concept)
1118
-
1119
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
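For reference, community pipelines like the one deleted above are normally loaded through `DiffusionPipeline.from_pretrained(..., custom_pipeline=...)`. The sketch below is a minimal, non-authoritative example of driving this combined inpainting + img2img + ControlNet pipeline; the checkpoint names, input file paths, and the `custom_pipeline` identifier are assumptions for illustration, not taken from this diff.

```python
# Minimal usage sketch (hypothetical checkpoints and file names).
import torch
from PIL import Image
from diffusers import ControlNetModel, DiffusionPipeline

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",  # an inpainting checkpoint: check_inputs above expects a 9-channel UNet
    controlnet=controlnet,
    custom_pipeline="stable_diffusion_controlnet_inpaint_img2img",  # assumed identifier, inferred from the file name
    torch_dtype=torch.float16,
).to("cuda")

init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
mask_image = Image.open("mask.png").convert("L").resize((512, 512))          # white = repaint, black = keep
control_image = Image.open("canny_edges.png").convert("RGB").resize((512, 512))

result = pipe(
    prompt="a wooden bench in a park",
    image=init_image,
    mask_image=mask_image,
    controlnet_conditioning_image=control_image,
    strength=0.8,
    num_inference_steps=50,
    guidance_scale=7.5,
).images[0]
result.save("inpaint_controlnet_img2img.png")
```

An inpainting base checkpoint is assumed because the `check_inputs` method above requires the UNet's input channels to equal the image latents plus masked-image latents plus one mask channel (typically 4 + 4 + 1 = 9).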
diffusers/examples/community/stable_diffusion_mega.py DELETED
@@ -1,227 +0,0 @@
1
- from typing import Any, Callable, Dict, List, Optional, Union
2
-
3
- import PIL.Image
4
- import torch
5
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
6
-
7
- from diffusers import (
8
- AutoencoderKL,
9
- DDIMScheduler,
10
- DiffusionPipeline,
11
- LMSDiscreteScheduler,
12
- PNDMScheduler,
13
- StableDiffusionImg2ImgPipeline,
14
- StableDiffusionInpaintPipelineLegacy,
15
- StableDiffusionPipeline,
16
- UNet2DConditionModel,
17
- )
18
- from diffusers.configuration_utils import FrozenDict
19
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
20
- from diffusers.utils import deprecate, logging
21
-
22
-
23
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
-
25
-
26
- class StableDiffusionMegaPipeline(DiffusionPipeline):
27
- r"""
28
- Pipeline for text-to-image generation using Stable Diffusion.
29
-
30
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
31
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
32
-
33
- Args:
34
- vae ([`AutoencoderKL`]):
35
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
36
- text_encoder ([`CLIPTextModel`]):
37
- Frozen text-encoder. Stable Diffusion uses the text portion of
38
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
39
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
40
- tokenizer (`CLIPTokenizer`):
41
- Tokenizer of class
42
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
43
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
44
- scheduler ([`SchedulerMixin`]):
45
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
46
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
47
- safety_checker ([`StableDiffusionMegaSafetyChecker`]):
48
- Classification module that estimates whether generated images could be considered offensive or harmful.
49
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
50
- feature_extractor ([`CLIPImageProcessor`]):
51
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
52
- """
53
- _optional_components = ["safety_checker", "feature_extractor"]
54
-
55
- def __init__(
56
- self,
57
- vae: AutoencoderKL,
58
- text_encoder: CLIPTextModel,
59
- tokenizer: CLIPTokenizer,
60
- unet: UNet2DConditionModel,
61
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
62
- safety_checker: StableDiffusionSafetyChecker,
63
- feature_extractor: CLIPImageProcessor,
64
- requires_safety_checker: bool = True,
65
- ):
66
- super().__init__()
67
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
68
- deprecation_message = (
69
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
70
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
71
- "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
72
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
73
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
74
- " file"
75
- )
76
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
77
- new_config = dict(scheduler.config)
78
- new_config["steps_offset"] = 1
79
- scheduler._internal_dict = FrozenDict(new_config)
80
-
81
- self.register_modules(
82
- vae=vae,
83
- text_encoder=text_encoder,
84
- tokenizer=tokenizer,
85
- unet=unet,
86
- scheduler=scheduler,
87
- safety_checker=safety_checker,
88
- feature_extractor=feature_extractor,
89
- )
90
- self.register_to_config(requires_safety_checker=requires_safety_checker)
91
-
92
- @property
93
- def components(self) -> Dict[str, Any]:
94
- return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
95
-
96
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
97
- r"""
98
- Enable sliced attention computation.
99
-
100
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
101
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
102
-
103
- Args:
104
- slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
105
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
106
- a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
107
- `attention_head_dim` must be a multiple of `slice_size`.
108
- """
109
- if slice_size == "auto":
110
- # half the attention head size is usually a good trade-off between
111
- # speed and memory
112
- slice_size = self.unet.config.attention_head_dim // 2
113
- self.unet.set_attention_slice(slice_size)
114
-
115
- def disable_attention_slicing(self):
116
- r"""
117
- Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
118
- back to computing attention in one step.
119
- """
120
- # set slice_size = `None` to disable `attention slicing`
121
- self.enable_attention_slicing(None)
122
-
123
- @torch.no_grad()
124
- def inpaint(
125
- self,
126
- prompt: Union[str, List[str]],
127
- image: Union[torch.FloatTensor, PIL.Image.Image],
128
- mask_image: Union[torch.FloatTensor, PIL.Image.Image],
129
- strength: float = 0.8,
130
- num_inference_steps: Optional[int] = 50,
131
- guidance_scale: Optional[float] = 7.5,
132
- negative_prompt: Optional[Union[str, List[str]]] = None,
133
- num_images_per_prompt: Optional[int] = 1,
134
- eta: Optional[float] = 0.0,
135
- generator: Optional[torch.Generator] = None,
136
- output_type: Optional[str] = "pil",
137
- return_dict: bool = True,
138
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
139
- callback_steps: int = 1,
140
- ):
141
- # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
142
- return StableDiffusionInpaintPipelineLegacy(**self.components)(
143
- prompt=prompt,
144
- image=image,
145
- mask_image=mask_image,
146
- strength=strength,
147
- num_inference_steps=num_inference_steps,
148
- guidance_scale=guidance_scale,
149
- negative_prompt=negative_prompt,
150
- num_images_per_prompt=num_images_per_prompt,
151
- eta=eta,
152
- generator=generator,
153
- output_type=output_type,
154
- return_dict=return_dict,
155
- callback=callback,
156
- )
157
-
158
- @torch.no_grad()
159
- def img2img(
160
- self,
161
- prompt: Union[str, List[str]],
162
- image: Union[torch.FloatTensor, PIL.Image.Image],
163
- strength: float = 0.8,
164
- num_inference_steps: Optional[int] = 50,
165
- guidance_scale: Optional[float] = 7.5,
166
- negative_prompt: Optional[Union[str, List[str]]] = None,
167
- num_images_per_prompt: Optional[int] = 1,
168
- eta: Optional[float] = 0.0,
169
- generator: Optional[torch.Generator] = None,
170
- output_type: Optional[str] = "pil",
171
- return_dict: bool = True,
172
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
173
- callback_steps: int = 1,
174
- **kwargs,
175
- ):
176
- # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
177
- return StableDiffusionImg2ImgPipeline(**self.components)(
178
- prompt=prompt,
179
- image=image,
180
- strength=strength,
181
- num_inference_steps=num_inference_steps,
182
- guidance_scale=guidance_scale,
183
- negative_prompt=negative_prompt,
184
- num_images_per_prompt=num_images_per_prompt,
185
- eta=eta,
186
- generator=generator,
187
- output_type=output_type,
188
- return_dict=return_dict,
189
- callback=callback,
190
- callback_steps=callback_steps,
191
- )
192
-
193
- @torch.no_grad()
194
- def text2img(
195
- self,
196
- prompt: Union[str, List[str]],
197
- height: int = 512,
198
- width: int = 512,
199
- num_inference_steps: int = 50,
200
- guidance_scale: float = 7.5,
201
- negative_prompt: Optional[Union[str, List[str]]] = None,
202
- num_images_per_prompt: Optional[int] = 1,
203
- eta: float = 0.0,
204
- generator: Optional[torch.Generator] = None,
205
- latents: Optional[torch.FloatTensor] = None,
206
- output_type: Optional[str] = "pil",
207
- return_dict: bool = True,
208
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
209
- callback_steps: int = 1,
210
- ):
211
- # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline
212
- return StableDiffusionPipeline(**self.components)(
213
- prompt=prompt,
214
- height=height,
215
- width=width,
216
- num_inference_steps=num_inference_steps,
217
- guidance_scale=guidance_scale,
218
- negative_prompt=negative_prompt,
219
- num_images_per_prompt=num_images_per_prompt,
220
- eta=eta,
221
- generator=generator,
222
- latents=latents,
223
- output_type=output_type,
224
- return_dict=return_dict,
225
- callback=callback,
226
- callback_steps=callback_steps,
227
- )
 
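As a reading aid for the file just removed, here is a minimal, hedged sketch of how the mega pipeline's three entry points (`text2img`, `img2img`, `inpaint`) are typically driven. The base checkpoint and input file names are assumptions for illustration.

```python
# Minimal usage sketch (hypothetical checkpoint and file names).
import torch
from PIL import Image
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="stable_diffusion_mega",  # assumed identifier, inferred from the file name
    torch_dtype=torch.float16,
).to("cuda")
pipe.enable_attention_slicing()

# One pipeline object exposes the three entry points defined above.
text_out = pipe.text2img("An astronaut riding a horse").images[0]

init = Image.open("init.png").convert("RGB").resize((512, 512))
img2img_out = pipe.img2img(prompt="A fantasy landscape", image=init, strength=0.75).images[0]

mask = Image.open("mask.png").convert("RGB").resize((512, 512))
inpaint_out = pipe.inpaint(prompt="A cat on a park bench", image=init, mask_image=mask, strength=0.75).images[0]
```

Each call simply instantiates the corresponding standard pipeline from `self.components`, so memory is shared between the three modes rather than duplicated.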
diffusers/examples/community/stable_unclip.py DELETED
@@ -1,287 +0,0 @@
1
- import types
2
- from typing import List, Optional, Tuple, Union
3
-
4
- import torch
5
- from transformers import CLIPTextModelWithProjection, CLIPTokenizer
6
- from transformers.models.clip.modeling_clip import CLIPTextModelOutput
7
-
8
- from diffusers.models import PriorTransformer
9
- from diffusers.pipelines import DiffusionPipeline, StableDiffusionImageVariationPipeline
10
- from diffusers.schedulers import UnCLIPScheduler
11
- from diffusers.utils import logging, randn_tensor
12
-
13
-
14
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
15
-
16
-
17
- def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
18
- image = image.to(device=device)
19
- image_embeddings = image # take image as image_embeddings
20
- image_embeddings = image_embeddings.unsqueeze(1)
21
-
22
- # duplicate image embeddings for each generation per prompt, using mps friendly method
23
- bs_embed, seq_len, _ = image_embeddings.shape
24
- image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
25
- image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
26
-
27
- if do_classifier_free_guidance:
28
- uncond_embeddings = torch.zeros_like(image_embeddings)
29
-
30
- # For classifier free guidance, we need to do two forward passes.
31
- # Here we concatenate the unconditional and text embeddings into a single batch
32
- # to avoid doing two forward passes
33
- image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
34
-
35
- return image_embeddings
36
-
37
-
38
- class StableUnCLIPPipeline(DiffusionPipeline):
39
- def __init__(
40
- self,
41
- prior: PriorTransformer,
42
- tokenizer: CLIPTokenizer,
43
- text_encoder: CLIPTextModelWithProjection,
44
- prior_scheduler: UnCLIPScheduler,
45
- decoder_pipe_kwargs: Optional[dict] = None,
46
- ):
47
- super().__init__()
48
-
49
- decoder_pipe_kwargs = {"image_encoder": None} if decoder_pipe_kwargs is None else decoder_pipe_kwargs
50
-
51
- decoder_pipe_kwargs["torch_dtype"] = decoder_pipe_kwargs.get("torch_dtype", None) or prior.dtype
52
-
53
- self.decoder_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
54
- "lambdalabs/sd-image-variations-diffusers", **decoder_pipe_kwargs
55
- )
56
-
57
- # replace `_encode_image` method
58
- self.decoder_pipe._encode_image = types.MethodType(_encode_image, self.decoder_pipe)
59
-
60
- self.register_modules(
61
- prior=prior,
62
- tokenizer=tokenizer,
63
- text_encoder=text_encoder,
64
- prior_scheduler=prior_scheduler,
65
- )
66
-
67
- def _encode_prompt(
68
- self,
69
- prompt,
70
- device,
71
- num_images_per_prompt,
72
- do_classifier_free_guidance,
73
- text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
74
- text_attention_mask: Optional[torch.Tensor] = None,
75
- ):
76
- if text_model_output is None:
77
- batch_size = len(prompt) if isinstance(prompt, list) else 1
78
- # get prompt text embeddings
79
- text_inputs = self.tokenizer(
80
- prompt,
81
- padding="max_length",
82
- max_length=self.tokenizer.model_max_length,
83
- return_tensors="pt",
84
- )
85
- text_input_ids = text_inputs.input_ids
86
- text_mask = text_inputs.attention_mask.bool().to(device)
87
-
88
- if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
89
- removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
90
- logger.warning(
91
- "The following part of your input was truncated because CLIP can only handle sequences up to"
92
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
93
- )
94
- text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
95
-
96
- text_encoder_output = self.text_encoder(text_input_ids.to(device))
97
-
98
- text_embeddings = text_encoder_output.text_embeds
99
- text_encoder_hidden_states = text_encoder_output.last_hidden_state
100
-
101
- else:
102
- batch_size = text_model_output[0].shape[0]
103
- text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
104
- text_mask = text_attention_mask
105
-
106
- text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
107
- text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
108
- text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
109
-
110
- if do_classifier_free_guidance:
111
- uncond_tokens = [""] * batch_size
112
-
113
- uncond_input = self.tokenizer(
114
- uncond_tokens,
115
- padding="max_length",
116
- max_length=self.tokenizer.model_max_length,
117
- truncation=True,
118
- return_tensors="pt",
119
- )
120
- uncond_text_mask = uncond_input.attention_mask.bool().to(device)
121
- uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
122
-
123
- uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
124
- uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
125
-
126
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
127
-
128
- seq_len = uncond_embeddings.shape[1]
129
- uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
130
- uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
131
-
132
- seq_len = uncond_text_encoder_hidden_states.shape[1]
133
- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
134
- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
135
- batch_size * num_images_per_prompt, seq_len, -1
136
- )
137
- uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
138
-
139
- # done duplicates
140
-
141
- # For classifier free guidance, we need to do two forward passes.
142
- # Here we concatenate the unconditional and text embeddings into a single batch
143
- # to avoid doing two forward passes
144
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
145
- text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
146
-
147
- text_mask = torch.cat([uncond_text_mask, text_mask])
148
-
149
- return text_embeddings, text_encoder_hidden_states, text_mask
150
-
151
- @property
152
- def _execution_device(self):
153
- r"""
154
- Returns the device on which the pipeline's models will be executed. After calling
155
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
156
- hooks.
157
- """
158
- if self.device != torch.device("meta") or not hasattr(self.prior, "_hf_hook"):
159
- return self.device
160
- for module in self.prior.modules():
161
- if (
162
- hasattr(module, "_hf_hook")
163
- and hasattr(module._hf_hook, "execution_device")
164
- and module._hf_hook.execution_device is not None
165
- ):
166
- return torch.device(module._hf_hook.execution_device)
167
- return self.device
168
-
169
- def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
170
- if latents is None:
171
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
172
- else:
173
- if latents.shape != shape:
174
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
175
- latents = latents.to(device)
176
-
177
- latents = latents * scheduler.init_noise_sigma
178
- return latents
179
-
180
- def to(self, torch_device: Optional[Union[str, torch.device]] = None):
181
- self.decoder_pipe.to(torch_device)
182
- super().to(torch_device)
183
-
184
- @torch.no_grad()
185
- def __call__(
186
- self,
187
- prompt: Optional[Union[str, List[str]]] = None,
188
- height: Optional[int] = None,
189
- width: Optional[int] = None,
190
- num_images_per_prompt: int = 1,
191
- prior_num_inference_steps: int = 25,
192
- generator: Optional[torch.Generator] = None,
193
- prior_latents: Optional[torch.FloatTensor] = None,
194
- text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
195
- text_attention_mask: Optional[torch.Tensor] = None,
196
- prior_guidance_scale: float = 4.0,
197
- decoder_guidance_scale: float = 8.0,
198
- decoder_num_inference_steps: int = 50,
199
- decoder_num_images_per_prompt: Optional[int] = 1,
200
- decoder_eta: float = 0.0,
201
- output_type: Optional[str] = "pil",
202
- return_dict: bool = True,
203
- ):
204
- if prompt is not None:
205
- if isinstance(prompt, str):
206
- batch_size = 1
207
- elif isinstance(prompt, list):
208
- batch_size = len(prompt)
209
- else:
210
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
211
- else:
212
- batch_size = text_model_output[0].shape[0]
213
-
214
- device = self._execution_device
215
-
216
- batch_size = batch_size * num_images_per_prompt
217
-
218
- do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
219
-
220
- text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
221
- prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
222
- )
223
-
224
- # prior
225
-
226
- self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
227
- prior_timesteps_tensor = self.prior_scheduler.timesteps
228
-
229
- embedding_dim = self.prior.config.embedding_dim
230
-
231
- prior_latents = self.prepare_latents(
232
- (batch_size, embedding_dim),
233
- text_embeddings.dtype,
234
- device,
235
- generator,
236
- prior_latents,
237
- self.prior_scheduler,
238
- )
239
-
240
- for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
241
- # expand the latents if we are doing classifier free guidance
242
- latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
243
-
244
- predicted_image_embedding = self.prior(
245
- latent_model_input,
246
- timestep=t,
247
- proj_embedding=text_embeddings,
248
- encoder_hidden_states=text_encoder_hidden_states,
249
- attention_mask=text_mask,
250
- ).predicted_image_embedding
251
-
252
- if do_classifier_free_guidance:
253
- predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
254
- predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
255
- predicted_image_embedding_text - predicted_image_embedding_uncond
256
- )
257
-
258
- if i + 1 == prior_timesteps_tensor.shape[0]:
259
- prev_timestep = None
260
- else:
261
- prev_timestep = prior_timesteps_tensor[i + 1]
262
-
263
- prior_latents = self.prior_scheduler.step(
264
- predicted_image_embedding,
265
- timestep=t,
266
- sample=prior_latents,
267
- generator=generator,
268
- prev_timestep=prev_timestep,
269
- ).prev_sample
270
-
271
- prior_latents = self.prior.post_process_latents(prior_latents)
272
-
273
- image_embeddings = prior_latents
274
-
275
- output = self.decoder_pipe(
276
- image=image_embeddings,
277
- height=height,
278
- width=width,
279
- num_inference_steps=decoder_num_inference_steps,
280
- guidance_scale=decoder_guidance_scale,
281
- generator=generator,
282
- output_type=output_type,
283
- return_dict=return_dict,
284
- num_images_per_prompt=decoder_num_images_per_prompt,
285
- eta=decoder_eta,
286
- )
287
- return output
 
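For the `StableUnCLIPPipeline` removed above, a minimal usage sketch follows: a CLIP text prior produces an image embedding, which the bundled `StableDiffusionImageVariationPipeline` then decodes into an image. The prior checkpoint name is an assumption; any repository exposing `prior`, `tokenizer`, `text_encoder`, and `prior_scheduler` components in the expected layout should work.

```python
# Minimal usage sketch (hypothetical prior checkpoint).
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "kakaobrain/karlo-v1-alpha",          # assumed source of the unCLIP prior components
    custom_pipeline="stable_unclip",      # assumed identifier, inferred from the file name
    torch_dtype=torch.float16,
)
# Note: this pipeline's custom `to()` returns None, so do not chain it.
pipe.to("cuda")

image = pipe("a shiba inu wearing a beret and black turtleneck").images[0]
image.save("stable_unclip_result.png")
```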
diffusers/examples/community/text_inpainting.py DELETED
@@ -1,302 +0,0 @@
1
- from typing import Callable, List, Optional, Union
2
-
3
- import PIL
4
- import torch
5
- from transformers import (
6
- CLIPImageProcessor,
7
- CLIPSegForImageSegmentation,
8
- CLIPSegProcessor,
9
- CLIPTextModel,
10
- CLIPTokenizer,
11
- )
12
-
13
- from diffusers import DiffusionPipeline
14
- from diffusers.configuration_utils import FrozenDict
15
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
16
- from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
17
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
18
- from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
19
- from diffusers.utils import deprecate, is_accelerate_available, logging
20
-
21
-
22
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23
-
24
-
25
- class TextInpainting(DiffusionPipeline):
26
- r"""
27
- Pipeline for text based inpainting using Stable Diffusion.
28
- Uses CLIPSeg to get a mask from the given text, then calls the Inpainting pipeline with the generated mask
29
-
30
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
31
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
32
-
33
- Args:
34
- segmentation_model ([`CLIPSegForImageSegmentation`]):
35
- CLIPSeg Model to generate mask from the given text. Please refer to the [model card]() for details.
36
- segmentation_processor ([`CLIPSegProcessor`]):
37
- CLIPSeg processor to prepare the image and text inputs for the segmentation model. Please refer to the
38
- [model card](https://huggingface.co/docs/transformers/model_doc/clipseg) for details.
39
- vae ([`AutoencoderKL`]):
40
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
41
- text_encoder ([`CLIPTextModel`]):
42
- Frozen text-encoder. Stable Diffusion uses the text portion of
43
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
44
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
45
- tokenizer (`CLIPTokenizer`):
46
- Tokenizer of class
47
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
48
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
49
- scheduler ([`SchedulerMixin`]):
50
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
51
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
52
- safety_checker ([`StableDiffusionSafetyChecker`]):
53
- Classification module that estimates whether generated images could be considered offensive or harmful.
54
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
55
- feature_extractor ([`CLIPImageProcessor`]):
56
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
57
- """
58
-
59
- def __init__(
60
- self,
61
- segmentation_model: CLIPSegForImageSegmentation,
62
- segmentation_processor: CLIPSegProcessor,
63
- vae: AutoencoderKL,
64
- text_encoder: CLIPTextModel,
65
- tokenizer: CLIPTokenizer,
66
- unet: UNet2DConditionModel,
67
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
68
- safety_checker: StableDiffusionSafetyChecker,
69
- feature_extractor: CLIPImageProcessor,
70
- ):
71
- super().__init__()
72
-
73
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
74
- deprecation_message = (
75
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
76
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
77
- "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
78
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
79
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
80
- " file"
81
- )
82
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
83
- new_config = dict(scheduler.config)
84
- new_config["steps_offset"] = 1
85
- scheduler._internal_dict = FrozenDict(new_config)
86
-
87
- if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
88
- deprecation_message = (
89
- f"The configuration file of this scheduler: {scheduler} has not set the configuration"
90
- " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
91
- " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
92
- " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
93
- " Hub, it would be very nice if you could open a Pull request for the"
94
- " `scheduler/scheduler_config.json` file"
95
- )
96
- deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
97
- new_config = dict(scheduler.config)
98
- new_config["skip_prk_steps"] = True
99
- scheduler._internal_dict = FrozenDict(new_config)
100
-
101
- if safety_checker is None:
102
- logger.warning(
103
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
104
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
105
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
106
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
107
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
108
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
109
- )
110
-
111
- self.register_modules(
112
- segmentation_model=segmentation_model,
113
- segmentation_processor=segmentation_processor,
114
- vae=vae,
115
- text_encoder=text_encoder,
116
- tokenizer=tokenizer,
117
- unet=unet,
118
- scheduler=scheduler,
119
- safety_checker=safety_checker,
120
- feature_extractor=feature_extractor,
121
- )
122
-
123
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
124
- r"""
125
- Enable sliced attention computation.
126
-
127
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
128
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
129
-
130
- Args:
131
- slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
132
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
133
- a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
134
- `attention_head_dim` must be a multiple of `slice_size`.
135
- """
136
- if slice_size == "auto":
137
- # half the attention head size is usually a good trade-off between
138
- # speed and memory
139
- slice_size = self.unet.config.attention_head_dim // 2
140
- self.unet.set_attention_slice(slice_size)
141
-
142
- def disable_attention_slicing(self):
143
- r"""
144
- Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
145
- back to computing attention in one step.
146
- """
147
- # set slice_size = `None` to disable `attention slicing`
148
- self.enable_attention_slicing(None)
149
-
150
- def enable_sequential_cpu_offload(self):
151
- r"""
152
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
153
- text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
154
- `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
155
- """
156
- if is_accelerate_available():
157
- from accelerate import cpu_offload
158
- else:
159
- raise ImportError("Please install accelerate via `pip install accelerate`")
160
-
161
- device = torch.device("cuda")
162
-
163
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
164
- if cpu_offloaded_model is not None:
165
- cpu_offload(cpu_offloaded_model, device)
166
-
167
- @property
168
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
169
- def _execution_device(self):
170
- r"""
171
- Returns the device on which the pipeline's models will be executed. After calling
172
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
173
- hooks.
174
- """
175
- if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
176
- return self.device
177
- for module in self.unet.modules():
178
- if (
179
- hasattr(module, "_hf_hook")
180
- and hasattr(module._hf_hook, "execution_device")
181
- and module._hf_hook.execution_device is not None
182
- ):
183
- return torch.device(module._hf_hook.execution_device)
184
- return self.device
185
-
186
- @torch.no_grad()
187
- def __call__(
188
- self,
189
- prompt: Union[str, List[str]],
190
- image: Union[torch.FloatTensor, PIL.Image.Image],
191
- text: str,
192
- height: int = 512,
193
- width: int = 512,
194
- num_inference_steps: int = 50,
195
- guidance_scale: float = 7.5,
196
- negative_prompt: Optional[Union[str, List[str]]] = None,
197
- num_images_per_prompt: Optional[int] = 1,
198
- eta: float = 0.0,
199
- generator: Optional[torch.Generator] = None,
200
- latents: Optional[torch.FloatTensor] = None,
201
- output_type: Optional[str] = "pil",
202
- return_dict: bool = True,
203
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
204
- callback_steps: int = 1,
205
- **kwargs,
206
- ):
207
- r"""
208
- Function invoked when calling the pipeline for generation.
209
-
210
- Args:
211
- prompt (`str` or `List[str]`):
212
- The prompt or prompts to guide the image generation.
213
- image (`PIL.Image.Image`):
214
- `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
215
- be masked out with `mask_image` and repainted according to `prompt`.
216
- text (`str`):
217
- The text to use to generate the mask.
218
- height (`int`, *optional*, defaults to 512):
219
- The height in pixels of the generated image.
220
- width (`int`, *optional*, defaults to 512):
221
- The width in pixels of the generated image.
222
- num_inference_steps (`int`, *optional*, defaults to 50):
223
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
224
- expense of slower inference.
225
- guidance_scale (`float`, *optional*, defaults to 7.5):
226
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
227
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
228
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
229
- 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
230
- usually at the expense of lower image quality.
231
- negative_prompt (`str` or `List[str]`, *optional*):
232
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
233
- if `guidance_scale` is less than `1`).
234
- num_images_per_prompt (`int`, *optional*, defaults to 1):
235
- The number of images to generate per prompt.
236
- eta (`float`, *optional*, defaults to 0.0):
237
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
238
- [`schedulers.DDIMScheduler`], will be ignored for others.
239
- generator (`torch.Generator`, *optional*):
240
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
241
- deterministic.
242
- latents (`torch.FloatTensor`, *optional*):
243
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
244
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
245
- tensor will be generated by sampling using the supplied random `generator`.
246
- output_type (`str`, *optional*, defaults to `"pil"`):
247
- The output format of the generated image. Choose between
248
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
249
- return_dict (`bool`, *optional*, defaults to `True`):
250
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
251
- plain tuple.
252
- callback (`Callable`, *optional*):
253
- A function that will be called every `callback_steps` steps during inference. The function will be
254
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
255
- callback_steps (`int`, *optional*, defaults to 1):
256
- The frequency at which the `callback` function will be called. If not specified, the callback will be
257
- called at every step.
258
-
259
- Returns:
260
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
261
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
262
- When returning a tuple, the first element is a list with the generated images, and the second element is a
263
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
264
- (nsfw) content, according to the `safety_checker`.
265
- """
266
-
267
- # We use the input text to generate the mask
268
- inputs = self.segmentation_processor(
269
- text=[text], images=[image], padding="max_length", return_tensors="pt"
270
- ).to(self.device)
271
- outputs = self.segmentation_model(**inputs)
272
- mask = torch.sigmoid(outputs.logits).cpu().detach().unsqueeze(-1).numpy()
273
- mask_pil = self.numpy_to_pil(mask)[0].resize(image.size)
274
-
275
- # Run inpainting pipeline with the generated mask
276
- inpainting_pipeline = StableDiffusionInpaintPipeline(
277
- vae=self.vae,
278
- text_encoder=self.text_encoder,
279
- tokenizer=self.tokenizer,
280
- unet=self.unet,
281
- scheduler=self.scheduler,
282
- safety_checker=self.safety_checker,
283
- feature_extractor=self.feature_extractor,
284
- )
285
- return inpainting_pipeline(
286
- prompt=prompt,
287
- image=image,
288
- mask_image=mask_pil,
289
- height=height,
290
- width=width,
291
- num_inference_steps=num_inference_steps,
292
- guidance_scale=guidance_scale,
293
- negative_prompt=negative_prompt,
294
- num_images_per_prompt=num_images_per_prompt,
295
- eta=eta,
296
- generator=generator,
297
- latents=latents,
298
- output_type=output_type,
299
- return_dict=return_dict,
300
- callback=callback,
301
- callback_steps=callback_steps,
302
- )
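
For reference, the pipeline deleted above wires a CLIPSeg-style segmentation model into `StableDiffusionInpaintPipeline`: the `text` argument is turned into a soft mask by the `segmentation_model`/`segmentation_processor` pair, and the mask plus `prompt` are handed to the inpainting pipeline. A minimal usage sketch follows; it assumes the pipeline is loaded through diffusers' `custom_pipeline` mechanism under this file's name, and the CLIPSeg and inpainting checkpoint ids shown are illustrative assumptions, not values taken from the deleted code.

import torch
from diffusers import DiffusionPipeline
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor

# Segmentation pair that produces the mask from the `text` argument (assumed checkpoint id).
segmentation_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
segmentation_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")

# Inpainting backbone (assumed checkpoint id) combined with the community pipeline above.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    custom_pipeline="text_inpainting",
    segmentation_model=segmentation_model,
    segmentation_processor=segmentation_processor,
    torch_dtype=torch.float16,
).to("cuda")

# `text` selects the region to repaint, `prompt` describes what to paint there.
# image = PIL.Image.open("photo.png").convert("RGB").resize((512, 512))
# result = pipe(image=image, text="a glass bottle", prompt="a potted cactus").images[0]
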
diffusers/examples/community/tiled_upscaling.py DELETED
@@ -1,298 +0,0 @@
1
- # Copyright 2023 Peter Willemsen <[email protected]>. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import math
16
- from typing import Callable, List, Optional, Union
17
-
18
- import numpy as np
19
- import PIL
20
- import torch
21
- from PIL import Image
22
- from transformers import CLIPTextModel, CLIPTokenizer
23
-
24
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
25
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
26
- from diffusers.schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
27
-
28
-
29
- def make_transparency_mask(size, overlap_pixels, remove_borders=[]):
30
- size_x = size[0] - overlap_pixels * 2
31
- size_y = size[1] - overlap_pixels * 2
32
- for letter in ["l", "r"]:
33
- if letter in remove_borders:
34
- size_x += overlap_pixels
35
- for letter in ["t", "b"]:
36
- if letter in remove_borders:
37
- size_y += overlap_pixels
38
- mask = np.ones((size_y, size_x), dtype=np.uint8) * 255
39
- mask = np.pad(mask, mode="linear_ramp", pad_width=overlap_pixels, end_values=0)
40
-
41
- if "l" in remove_borders:
42
- mask = mask[:, overlap_pixels : mask.shape[1]]
43
- if "r" in remove_borders:
44
- mask = mask[:, 0 : mask.shape[1] - overlap_pixels]
45
- if "t" in remove_borders:
46
- mask = mask[overlap_pixels : mask.shape[0], :]
47
- if "b" in remove_borders:
48
- mask = mask[0 : mask.shape[0] - overlap_pixels, :]
49
- return mask
50
-
51
-
52
- def clamp(n, smallest, largest):
53
- return max(smallest, min(n, largest))
54
-
55
-
56
- def clamp_rect(rect: List[int], min: List[int], max: List[int]):
57
- return (
58
- clamp(rect[0], min[0], max[0]),
59
- clamp(rect[1], min[1], max[1]),
60
- clamp(rect[2], min[0], max[0]),
61
- clamp(rect[3], min[1], max[1]),
62
- )
63
-
64
-
65
- def add_overlap_rect(rect: List[int], overlap: int, image_size: List[int]):
66
- rect = list(rect)
67
- rect[0] -= overlap
68
- rect[1] -= overlap
69
- rect[2] += overlap
70
- rect[3] += overlap
71
- rect = clamp_rect(rect, [0, 0], [image_size[0], image_size[1]])
72
- return rect
73
-
74
-
75
- def squeeze_tile(tile, original_image, original_slice, slice_x):
76
- result = Image.new("RGB", (tile.size[0] + original_slice, tile.size[1]))
77
- result.paste(
78
- original_image.resize((tile.size[0], tile.size[1]), Image.BICUBIC).crop(
79
- (slice_x, 0, slice_x + original_slice, tile.size[1])
80
- ),
81
- (0, 0),
82
- )
83
- result.paste(tile, (original_slice, 0))
84
- return result
85
-
86
-
87
- def unsqueeze_tile(tile, original_image_slice):
88
- crop_rect = (original_image_slice * 4, 0, tile.size[0], tile.size[1])
89
- tile = tile.crop(crop_rect)
90
- return tile
91
-
92
-
93
- def next_divisible(n, d):
94
- divisor = n % d
95
- return n - divisor
96
-
97
-
98
- class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
99
- r"""
100
- Pipeline for tile-based text-guided image super-resolution using Stable Diffusion 2, trading memory for compute
101
- to create gigantic images.
102
-
103
- This model inherits from [`StableDiffusionUpscalePipeline`]. Check the superclass documentation for the generic methods the
104
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
105
-
106
- Args:
107
- vae ([`AutoencoderKL`]):
108
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
109
- text_encoder ([`CLIPTextModel`]):
110
- Frozen text-encoder. Stable Diffusion uses the text portion of
111
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
112
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
113
- tokenizer (`CLIPTokenizer`):
114
- Tokenizer of class
115
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
116
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
117
- low_res_scheduler ([`SchedulerMixin`]):
118
- A scheduler used to add initial noise to the low res conditioning image. It must be an instance of
119
- [`DDPMScheduler`].
120
- scheduler ([`SchedulerMixin`]):
121
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
122
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
123
- """
124
-
125
- def __init__(
126
- self,
127
- vae: AutoencoderKL,
128
- text_encoder: CLIPTextModel,
129
- tokenizer: CLIPTokenizer,
130
- unet: UNet2DConditionModel,
131
- low_res_scheduler: DDPMScheduler,
132
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
133
- max_noise_level: int = 350,
134
- ):
135
- super().__init__(
136
- vae=vae,
137
- text_encoder=text_encoder,
138
- tokenizer=tokenizer,
139
- unet=unet,
140
- low_res_scheduler=low_res_scheduler,
141
- scheduler=scheduler,
142
- max_noise_level=max_noise_level,
143
- )
144
-
145
- def _process_tile(self, original_image_slice, x, y, tile_size, tile_border, image, final_image, **kwargs):
146
- torch.manual_seed(0)
147
- crop_rect = (
148
- min(image.size[0] - (tile_size + original_image_slice), x * tile_size),
149
- min(image.size[1] - (tile_size + original_image_slice), y * tile_size),
150
- min(image.size[0], (x + 1) * tile_size),
151
- min(image.size[1], (y + 1) * tile_size),
152
- )
153
- crop_rect_with_overlap = add_overlap_rect(crop_rect, tile_border, image.size)
154
- tile = image.crop(crop_rect_with_overlap)
155
- translated_slice_x = ((crop_rect[0] + ((crop_rect[2] - crop_rect[0]) / 2)) / image.size[0]) * tile.size[0]
156
- translated_slice_x = translated_slice_x - (original_image_slice / 2)
157
- translated_slice_x = max(0, translated_slice_x)
158
- to_input = squeeze_tile(tile, image, original_image_slice, translated_slice_x)
159
- orig_input_size = to_input.size
160
- to_input = to_input.resize((tile_size, tile_size), Image.BICUBIC)
161
- upscaled_tile = super(StableDiffusionTiledUpscalePipeline, self).__call__(image=to_input, **kwargs).images[0]
162
- upscaled_tile = upscaled_tile.resize((orig_input_size[0] * 4, orig_input_size[1] * 4), Image.BICUBIC)
163
- upscaled_tile = unsqueeze_tile(upscaled_tile, original_image_slice)
164
- upscaled_tile = upscaled_tile.resize((tile.size[0] * 4, tile.size[1] * 4), Image.BICUBIC)
165
- remove_borders = []
166
- if x == 0:
167
- remove_borders.append("l")
168
- elif crop_rect[2] == image.size[0]:
169
- remove_borders.append("r")
170
- if y == 0:
171
- remove_borders.append("t")
172
- elif crop_rect[3] == image.size[1]:
173
- remove_borders.append("b")
174
- transparency_mask = Image.fromarray(
175
- make_transparency_mask(
176
- (upscaled_tile.size[0], upscaled_tile.size[1]), tile_border * 4, remove_borders=remove_borders
177
- ),
178
- mode="L",
179
- )
180
- final_image.paste(
181
- upscaled_tile, (crop_rect_with_overlap[0] * 4, crop_rect_with_overlap[1] * 4), transparency_mask
182
- )
183
-
184
- @torch.no_grad()
185
- def __call__(
186
- self,
187
- prompt: Union[str, List[str]],
188
- image: Union[PIL.Image.Image, List[PIL.Image.Image]],
189
- num_inference_steps: int = 75,
190
- guidance_scale: float = 9.0,
191
- noise_level: int = 50,
192
- negative_prompt: Optional[Union[str, List[str]]] = None,
193
- num_images_per_prompt: Optional[int] = 1,
194
- eta: float = 0.0,
195
- generator: Optional[torch.Generator] = None,
196
- latents: Optional[torch.FloatTensor] = None,
197
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
198
- callback_steps: int = 1,
199
- tile_size: int = 128,
200
- tile_border: int = 32,
201
- original_image_slice: int = 32,
202
- ):
203
- r"""
204
- Function invoked when calling the pipeline for generation.
205
-
206
- Args:
207
- prompt (`str` or `List[str]`):
208
- The prompt or prompts to guide the image generation.
209
- image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
210
- `Image`, or tensor representing an image batch, which will be upscaled.
211
- num_inference_steps (`int`, *optional*, defaults to 75):
212
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
213
- expense of slower inference.
214
- guidance_scale (`float`, *optional*, defaults to 9.0):
215
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
216
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
217
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
218
- 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
219
- usually at the expense of lower image quality.
220
- negative_prompt (`str` or `List[str]`, *optional*):
221
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
222
- if `guidance_scale` is less than `1`).
223
- num_images_per_prompt (`int`, *optional*, defaults to 1):
224
- The number of images to generate per prompt.
225
- eta (`float`, *optional*, defaults to 0.0):
226
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
227
- [`schedulers.DDIMScheduler`], will be ignored for others.
228
- generator (`torch.Generator`, *optional*):
229
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
230
- deterministic.
231
- latents (`torch.FloatTensor`, *optional*):
232
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
233
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
234
- tensor will be generated by sampling using the supplied random `generator`.
235
- tile_size (`int`, *optional*):
236
- The size of the tiles. Values that are too large can result in an out-of-memory (OOM) error.
237
- tile_border (`int`, *optional*):
238
- The number of pixels around a tile to consider (larger values mean fewer seams, but values that are too large can lead to an OOM error).
239
- original_image_slice (`int`, *optional*):
240
- The number of pixels of the original image to include with the current tile (larger values preserve more depth
241
- and reduce blur in the final image, but values that are too large can lead to an OOM error or a loss of detail).
242
- callback (`Callable`, *optional*):
243
- A function that will be called after every processed tile with a single argument, a dict,
244
- that contains the (partially) processed image under "image",
245
- as well as the progress (0 to 1, where 1 is completed) under "progress".
246
-
247
- Returns: A PIL.Image that is 4 times larger than the original input image.
248
-
249
- """
250
-
251
- final_image = Image.new("RGB", (image.size[0] * 4, image.size[1] * 4))
252
- tcx = math.ceil(image.size[0] / tile_size)
253
- tcy = math.ceil(image.size[1] / tile_size)
254
- total_tile_count = tcx * tcy
255
- current_count = 0
256
- for y in range(tcy):
257
- for x in range(tcx):
258
- self._process_tile(
259
- original_image_slice,
260
- x,
261
- y,
262
- tile_size,
263
- tile_border,
264
- image,
265
- final_image,
266
- prompt=prompt,
267
- num_inference_steps=num_inference_steps,
268
- guidance_scale=guidance_scale,
269
- noise_level=noise_level,
270
- negative_prompt=negative_prompt,
271
- num_images_per_prompt=num_images_per_prompt,
272
- eta=eta,
273
- generator=generator,
274
- latents=latents,
275
- )
276
- current_count += 1
277
- if callback is not None:
278
- callback({"progress": current_count / total_tile_count, "image": final_image})
279
- return final_image
280
-
281
-
282
- def main():
283
- # Run a demo
284
- model_id = "stabilityai/stable-diffusion-x4-upscaler"
285
- pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
286
- pipe = pipe.to("cuda")
287
- image = Image.open("../../docs/source/imgs/diffusers_library.jpg")
288
-
289
- def callback(obj):
290
- print(f"progress: {obj['progress']:.4f}")
291
- obj["image"].save("diffusers_library_progress.jpg")
292
-
293
- final_image = pipe(image=image, prompt="Black font, white background, vector", noise_level=40, callback=callback)
294
- final_image.save("diffusers_library.jpg")
295
-
296
-
297
- if __name__ == "__main__":
298
- main()
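
The deleted demo above drives the tiling loop directly. The loop splits the input into ceil(width / tile_size) x ceil(height / tile_size) tiles, upscales each tile by 4x, and blends neighbouring tiles with a linear-ramp alpha mask that is tile_border * 4 pixels wide in the output. A small self-contained sketch of that arithmetic, runnable without a GPU; the 300x200 input size is an arbitrary illustration and does not come from the deleted file.

import math

# Arbitrary example input size; tile_size and tile_border match the defaults of the deleted pipeline.
width, height = 300, 200
tile_size, tile_border = 128, 32

tiles_x = math.ceil(width / tile_size)   # 3 tile columns
tiles_y = math.ceil(height / tile_size)  # 2 tile rows
print(f"{tiles_x * tiles_y} tiles, blended over {tile_border * 4}px ramps in the upscaled output")
print(f"final image size: {width * 4} x {height * 4}")  # the pipeline returns a 4x larger PIL image
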
diffusers/examples/community/unclip_image_interpolation.py DELETED
@@ -1,493 +0,0 @@
1
- import inspect
2
- from typing import List, Optional, Union
3
-
4
- import PIL
5
- import torch
6
- from torch.nn import functional as F
7
- from transformers import (
8
- CLIPImageProcessor,
9
- CLIPTextModelWithProjection,
10
- CLIPTokenizer,
11
- CLIPVisionModelWithProjection,
12
- )
13
-
14
- from diffusers import (
15
- DiffusionPipeline,
16
- ImagePipelineOutput,
17
- UnCLIPScheduler,
18
- UNet2DConditionModel,
19
- UNet2DModel,
20
- )
21
- from diffusers.pipelines.unclip import UnCLIPTextProjModel
22
- from diffusers.utils import is_accelerate_available, logging, randn_tensor
23
-
24
-
25
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26
-
27
-
28
- def slerp(val, low, high):
29
- """
30
- Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic.
31
- """
32
- low_norm = low / torch.norm(low)
33
- high_norm = high / torch.norm(high)
34
- omega = torch.acos((low_norm * high_norm))
35
- so = torch.sin(omega)
36
- res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
37
- return res
38
-
39
-
40
- class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
41
- """
42
- Pipeline to interpolate between two input images (or their image embeddings) using unCLIP
43
-
44
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
45
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
46
-
47
- Args:
48
- text_encoder ([`CLIPTextModelWithProjection`]):
49
- Frozen text-encoder.
50
- tokenizer (`CLIPTokenizer`):
51
- Tokenizer of class
52
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
53
- feature_extractor ([`CLIPImageProcessor`]):
54
- Model that extracts features from generated images to be used as inputs for the `image_encoder`.
55
- image_encoder ([`CLIPVisionModelWithProjection`]):
56
- Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of
57
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
58
- specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
59
- text_proj ([`UnCLIPTextProjModel`]):
60
- Utility class to prepare and combine the embeddings before they are passed to the decoder.
61
- decoder ([`UNet2DConditionModel`]):
62
- The decoder to invert the image embedding into an image.
63
- super_res_first ([`UNet2DModel`]):
64
- Super resolution unet. Used in all but the last step of the super resolution diffusion process.
65
- super_res_last ([`UNet2DModel`]):
66
- Super resolution unet. Used in the last step of the super resolution diffusion process.
67
- decoder_scheduler ([`UnCLIPScheduler`]):
68
- Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
69
- super_res_scheduler ([`UnCLIPScheduler`]):
70
- Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
71
-
72
- """
73
-
74
- decoder: UNet2DConditionModel
75
- text_proj: UnCLIPTextProjModel
76
- text_encoder: CLIPTextModelWithProjection
77
- tokenizer: CLIPTokenizer
78
- feature_extractor: CLIPImageProcessor
79
- image_encoder: CLIPVisionModelWithProjection
80
- super_res_first: UNet2DModel
81
- super_res_last: UNet2DModel
82
-
83
- decoder_scheduler: UnCLIPScheduler
84
- super_res_scheduler: UnCLIPScheduler
85
-
86
- # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.__init__
87
- def __init__(
88
- self,
89
- decoder: UNet2DConditionModel,
90
- text_encoder: CLIPTextModelWithProjection,
91
- tokenizer: CLIPTokenizer,
92
- text_proj: UnCLIPTextProjModel,
93
- feature_extractor: CLIPImageProcessor,
94
- image_encoder: CLIPVisionModelWithProjection,
95
- super_res_first: UNet2DModel,
96
- super_res_last: UNet2DModel,
97
- decoder_scheduler: UnCLIPScheduler,
98
- super_res_scheduler: UnCLIPScheduler,
99
- ):
100
- super().__init__()
101
-
102
- self.register_modules(
103
- decoder=decoder,
104
- text_encoder=text_encoder,
105
- tokenizer=tokenizer,
106
- text_proj=text_proj,
107
- feature_extractor=feature_extractor,
108
- image_encoder=image_encoder,
109
- super_res_first=super_res_first,
110
- super_res_last=super_res_last,
111
- decoder_scheduler=decoder_scheduler,
112
- super_res_scheduler=super_res_scheduler,
113
- )
114
-
115
- # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
116
- def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
117
- if latents is None:
118
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
119
- else:
120
- if latents.shape != shape:
121
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
122
- latents = latents.to(device)
123
-
124
- latents = latents * scheduler.init_noise_sigma
125
- return latents
126
-
127
- # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_prompt
128
- def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
129
- batch_size = len(prompt) if isinstance(prompt, list) else 1
130
-
131
- # get prompt text embeddings
132
- text_inputs = self.tokenizer(
133
- prompt,
134
- padding="max_length",
135
- max_length=self.tokenizer.model_max_length,
136
- return_tensors="pt",
137
- )
138
- text_input_ids = text_inputs.input_ids
139
- text_mask = text_inputs.attention_mask.bool().to(device)
140
- text_encoder_output = self.text_encoder(text_input_ids.to(device))
141
-
142
- prompt_embeds = text_encoder_output.text_embeds
143
- text_encoder_hidden_states = text_encoder_output.last_hidden_state
144
-
145
- prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
146
- text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
147
- text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
148
-
149
- if do_classifier_free_guidance:
150
- uncond_tokens = [""] * batch_size
151
-
152
- max_length = text_input_ids.shape[-1]
153
- uncond_input = self.tokenizer(
154
- uncond_tokens,
155
- padding="max_length",
156
- max_length=max_length,
157
- truncation=True,
158
- return_tensors="pt",
159
- )
160
- uncond_text_mask = uncond_input.attention_mask.bool().to(device)
161
- negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
162
-
163
- negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
164
- uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
165
-
166
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
167
-
168
- seq_len = negative_prompt_embeds.shape[1]
169
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
170
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
171
-
172
- seq_len = uncond_text_encoder_hidden_states.shape[1]
173
- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
174
- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
175
- batch_size * num_images_per_prompt, seq_len, -1
176
- )
177
- uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
178
-
179
- # done duplicates
180
-
181
- # For classifier free guidance, we need to do two forward passes.
182
- # Here we concatenate the unconditional and text embeddings into a single batch
183
- # to avoid doing two forward passes
184
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
185
- text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
186
-
187
- text_mask = torch.cat([uncond_text_mask, text_mask])
188
-
189
- return prompt_embeds, text_encoder_hidden_states, text_mask
190
-
191
- # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_image
192
- def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
193
- dtype = next(self.image_encoder.parameters()).dtype
194
-
195
- if image_embeddings is None:
196
- if not isinstance(image, torch.Tensor):
197
- image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
198
-
199
- image = image.to(device=device, dtype=dtype)
200
- image_embeddings = self.image_encoder(image).image_embeds
201
-
202
- image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
203
-
204
- return image_embeddings
205
-
206
- # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.enable_sequential_cpu_offload
207
- def enable_sequential_cpu_offload(self, gpu_id=0):
208
- r"""
209
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
210
- models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
211
- when their specific submodule has its `forward` method called.
212
- """
213
- if is_accelerate_available():
214
- from accelerate import cpu_offload
215
- else:
216
- raise ImportError("Please install accelerate via `pip install accelerate`")
217
-
218
- device = torch.device(f"cuda:{gpu_id}")
219
-
220
- models = [
221
- self.decoder,
222
- self.text_proj,
223
- self.text_encoder,
224
- self.super_res_first,
225
- self.super_res_last,
226
- ]
227
- for cpu_offloaded_model in models:
228
- if cpu_offloaded_model is not None:
229
- cpu_offload(cpu_offloaded_model, device)
230
-
231
- @property
232
- # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
233
- def _execution_device(self):
234
- r"""
235
- Returns the device on which the pipeline's models will be executed. After calling
236
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
237
- hooks.
238
- """
239
- if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
240
- return self.device
241
- for module in self.decoder.modules():
242
- if (
243
- hasattr(module, "_hf_hook")
244
- and hasattr(module._hf_hook, "execution_device")
245
- and module._hf_hook.execution_device is not None
246
- ):
247
- return torch.device(module._hf_hook.execution_device)
248
- return self.device
249
-
250
- @torch.no_grad()
251
- def __call__(
252
- self,
253
- image: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None,
254
- steps: int = 5,
255
- decoder_num_inference_steps: int = 25,
256
- super_res_num_inference_steps: int = 7,
257
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
258
- image_embeddings: Optional[torch.Tensor] = None,
259
- decoder_latents: Optional[torch.FloatTensor] = None,
260
- super_res_latents: Optional[torch.FloatTensor] = None,
261
- decoder_guidance_scale: float = 8.0,
262
- output_type: Optional[str] = "pil",
263
- return_dict: bool = True,
264
- ):
265
- """
266
- Function invoked when calling the pipeline for generation.
267
-
268
- Args:
269
- image (`List[PIL.Image.Image]` or `torch.FloatTensor`):
270
- The images to use for the image interpolation. Only accepts a list of two PIL Images. If you provide a tensor instead, it needs to comply with the
271
- configuration of
272
- [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
273
- `CLIPImageProcessor` while still having a shape of two in the 0th dimension. Can be left as `None` only when `image_embeddings` are passed.
274
- steps (`int`, *optional*, defaults to 5):
275
- The number of interpolation images to generate.
276
- decoder_num_inference_steps (`int`, *optional*, defaults to 25):
277
- The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
278
- image at the expense of slower inference.
279
- super_res_num_inference_steps (`int`, *optional*, defaults to 7):
280
- The number of denoising steps for super resolution. More denoising steps usually lead to a higher
281
- quality image at the expense of slower inference.
282
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
283
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
284
- to make generation deterministic.
285
- image_embeddings (`torch.Tensor`, *optional*):
286
- Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
287
- can be passed for tasks like image interpolation. `image` can then be left as `None`.
288
- decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
289
- Pre-generated noisy latents to be used as inputs for the decoder.
290
- super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
291
- Pre-generated noisy latents to be used as inputs for the super resolution unet.
292
- decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
293
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
294
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
295
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
296
- 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
297
- usually at the expense of lower image quality.
298
- output_type (`str`, *optional*, defaults to `"pil"`):
299
- The output format of the generated image. Choose between
300
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
301
- return_dict (`bool`, *optional*, defaults to `True`):
302
- Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
303
- """
304
-
305
- batch_size = steps
306
-
307
- device = self._execution_device
308
-
309
- if isinstance(image, List):
310
- if len(image) != 2:
311
- raise AssertionError(
312
- f"Expected 'image' List to be of size 2, but passed 'image' length is {len(image)}"
313
- )
314
- elif not (isinstance(image[0], PIL.Image.Image) and isinstance(image[1], PIL.Image.Image)):
315
- raise AssertionError(
316
- f"Expected 'image' List to contain PIL.Image.Image, but passed 'image' contents are {type(image[0])} and {type(image[1])}"
317
- )
318
- elif isinstance(image, torch.FloatTensor):
319
- if image.shape[0] != 2:
320
- raise AssertionError(
321
- f"Expected 'image' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image' size is {image.shape[0]}"
322
- )
323
- elif isinstance(image_embeddings, torch.Tensor):
324
- if image_embeddings.shape[0] != 2:
325
- raise AssertionError(
326
- f"Expected 'image_embeddings' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image_embeddings' shape is {image_embeddings.shape[0]}"
327
- )
328
- else:
329
- raise AssertionError(
330
- f"Expected 'image' or 'image_embeddings' to be not None with types List[PIL.Image] or Torch.FloatTensor respectively. Received {type(image)} and {type(image_embeddings)} repsectively"
331
- )
332
-
333
- original_image_embeddings = self._encode_image(
334
- image=image, device=device, num_images_per_prompt=1, image_embeddings=image_embeddings
335
- )
336
-
337
- image_embeddings = []
338
-
339
- for interp_step in torch.linspace(0, 1, steps):
340
- temp_image_embeddings = slerp(
341
- interp_step, original_image_embeddings[0], original_image_embeddings[1]
342
- ).unsqueeze(0)
343
- image_embeddings.append(temp_image_embeddings)
344
-
345
- image_embeddings = torch.cat(image_embeddings).to(device)
346
-
347
- do_classifier_free_guidance = decoder_guidance_scale > 1.0
348
-
349
- prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
350
- prompt=["" for i in range(steps)],
351
- device=device,
352
- num_images_per_prompt=1,
353
- do_classifier_free_guidance=do_classifier_free_guidance,
354
- )
355
-
356
- text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
357
- image_embeddings=image_embeddings,
358
- prompt_embeds=prompt_embeds,
359
- text_encoder_hidden_states=text_encoder_hidden_states,
360
- do_classifier_free_guidance=do_classifier_free_guidance,
361
- )
362
-
363
- if device.type == "mps":
364
- # HACK: MPS: There is a panic when padding bool tensors,
365
- # so cast to int tensor for the pad and back to bool afterwards
366
- text_mask = text_mask.type(torch.int)
367
- decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
368
- decoder_text_mask = decoder_text_mask.type(torch.bool)
369
- else:
370
- decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
371
-
372
- self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
373
- decoder_timesteps_tensor = self.decoder_scheduler.timesteps
374
-
375
- num_channels_latents = self.decoder.in_channels
376
- height = self.decoder.sample_size
377
- width = self.decoder.sample_size
378
-
379
- decoder_latents = self.prepare_latents(
380
- (batch_size, num_channels_latents, height, width),
381
- text_encoder_hidden_states.dtype,
382
- device,
383
- generator,
384
- decoder_latents,
385
- self.decoder_scheduler,
386
- )
387
-
388
- for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
389
- # expand the latents if we are doing classifier free guidance
390
- latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
391
-
392
- noise_pred = self.decoder(
393
- sample=latent_model_input,
394
- timestep=t,
395
- encoder_hidden_states=text_encoder_hidden_states,
396
- class_labels=additive_clip_time_embeddings,
397
- attention_mask=decoder_text_mask,
398
- ).sample
399
-
400
- if do_classifier_free_guidance:
401
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
402
- noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
403
- noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
404
- noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
405
- noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
406
-
407
- if i + 1 == decoder_timesteps_tensor.shape[0]:
408
- prev_timestep = None
409
- else:
410
- prev_timestep = decoder_timesteps_tensor[i + 1]
411
-
412
- # compute the previous noisy sample x_t -> x_t-1
413
- decoder_latents = self.decoder_scheduler.step(
414
- noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
415
- ).prev_sample
416
-
417
- decoder_latents = decoder_latents.clamp(-1, 1)
418
-
419
- image_small = decoder_latents
420
-
421
- # done decoder
422
-
423
- # super res
424
-
425
- self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
426
- super_res_timesteps_tensor = self.super_res_scheduler.timesteps
427
-
428
- channels = self.super_res_first.in_channels // 2
429
- height = self.super_res_first.sample_size
430
- width = self.super_res_first.sample_size
431
-
432
- super_res_latents = self.prepare_latents(
433
- (batch_size, channels, height, width),
434
- image_small.dtype,
435
- device,
436
- generator,
437
- super_res_latents,
438
- self.super_res_scheduler,
439
- )
440
-
441
- if device.type == "mps":
442
- # MPS does not support many interpolations
443
- image_upscaled = F.interpolate(image_small, size=[height, width])
444
- else:
445
- interpolate_antialias = {}
446
- if "antialias" in inspect.signature(F.interpolate).parameters:
447
- interpolate_antialias["antialias"] = True
448
-
449
- image_upscaled = F.interpolate(
450
- image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
451
- )
452
-
453
- for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
454
- # no classifier free guidance
455
-
456
- if i == super_res_timesteps_tensor.shape[0] - 1:
457
- unet = self.super_res_last
458
- else:
459
- unet = self.super_res_first
460
-
461
- latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
462
-
463
- noise_pred = unet(
464
- sample=latent_model_input,
465
- timestep=t,
466
- ).sample
467
-
468
- if i + 1 == super_res_timesteps_tensor.shape[0]:
469
- prev_timestep = None
470
- else:
471
- prev_timestep = super_res_timesteps_tensor[i + 1]
472
-
473
- # compute the previous noisy sample x_t -> x_t-1
474
- super_res_latents = self.super_res_scheduler.step(
475
- noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
476
- ).prev_sample
477
-
478
- image = super_res_latents
479
- # done super res
480
-
481
- # post processing
482
-
483
- image = image * 0.5 + 0.5
484
- image = image.clamp(0, 1)
485
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
486
-
487
- if output_type == "pil":
488
- image = self.numpy_to_pil(image)
489
-
490
- if not return_dict:
491
- return (image,)
492
-
493
- return ImagePipelineOutput(images=image)
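
A minimal usage sketch for the pipeline deleted above, assuming it is loaded through diffusers' `custom_pipeline` mechanism; the karlo image-variations checkpoint id and the image file names are illustrative assumptions, not values taken from the deleted code.

import torch
from PIL import Image
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "kakaobrain/karlo-v1-alpha-image-variations",  # assumed unCLIP image-variation checkpoint
    custom_pipeline="unclip_image_interpolation",
    torch_dtype=torch.float16,
).to("cuda")

# The pipeline expects exactly two images (or two image embeddings) to interpolate between.
images = [Image.open("start.jpg"), Image.open("end.jpg")]
generator = torch.Generator(device="cuda").manual_seed(42)

# `steps` is the number of frames returned, endpoints included.
output = pipe(image=images, steps=6, generator=generator)
for i, frame in enumerate(output.images):
    frame.save(f"unclip_interpolation_{i}.png")
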
diffusers/examples/community/unclip_text_interpolation.py DELETED
@@ -1,573 +0,0 @@
1
- import inspect
2
- from typing import List, Optional, Tuple, Union
3
-
4
- import torch
5
- from torch.nn import functional as F
6
- from transformers import CLIPTextModelWithProjection, CLIPTokenizer
7
- from transformers.models.clip.modeling_clip import CLIPTextModelOutput
8
-
9
- from diffusers import (
10
- DiffusionPipeline,
11
- ImagePipelineOutput,
12
- PriorTransformer,
13
- UnCLIPScheduler,
14
- UNet2DConditionModel,
15
- UNet2DModel,
16
- )
17
- from diffusers.pipelines.unclip import UnCLIPTextProjModel
18
- from diffusers.utils import is_accelerate_available, logging, randn_tensor
19
-
20
-
21
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
-
23
-
24
- def slerp(val, low, high):
25
- """
26
- Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic.
27
- """
28
- low_norm = low / torch.norm(low)
29
- high_norm = high / torch.norm(high)
30
- omega = torch.acos((low_norm * high_norm))
31
- so = torch.sin(omega)
32
- res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
33
- return res
34
-
35
-
36
- class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
37
-
38
- """
39
- Pipeline for prompt-to-prompt interpolation on CLIP text embeddings, using unCLIP / DALL-E 2 to decode them into images.
40
-
41
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
42
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
43
-
44
- Args:
45
- text_encoder ([`CLIPTextModelWithProjection`]):
46
- Frozen text-encoder.
47
- tokenizer (`CLIPTokenizer`):
48
- Tokenizer of class
49
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
50
- prior ([`PriorTransformer`]):
51
- The canonical unCLIP prior to approximate the image embedding from the text embedding.
52
- text_proj ([`UnCLIPTextProjModel`]):
53
- Utility class to prepare and combine the embeddings before they are passed to the decoder.
54
- decoder ([`UNet2DConditionModel`]):
55
- The decoder to invert the image embedding into an image.
56
- super_res_first ([`UNet2DModel`]):
57
- Super resolution unet. Used in all but the last step of the super resolution diffusion process.
58
- super_res_last ([`UNet2DModel`]):
59
- Super resolution unet. Used in the last step of the super resolution diffusion process.
60
- prior_scheduler ([`UnCLIPScheduler`]):
61
- Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
62
- decoder_scheduler ([`UnCLIPScheduler`]):
63
- Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
64
- super_res_scheduler ([`UnCLIPScheduler`]):
65
- Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
66
-
67
- """
68
-
69
- prior: PriorTransformer
70
- decoder: UNet2DConditionModel
71
- text_proj: UnCLIPTextProjModel
72
- text_encoder: CLIPTextModelWithProjection
73
- tokenizer: CLIPTokenizer
74
- super_res_first: UNet2DModel
75
- super_res_last: UNet2DModel
76
-
77
- prior_scheduler: UnCLIPScheduler
78
- decoder_scheduler: UnCLIPScheduler
79
- super_res_scheduler: UnCLIPScheduler
80
-
81
- # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__init__
82
- def __init__(
83
- self,
84
- prior: PriorTransformer,
85
- decoder: UNet2DConditionModel,
86
- text_encoder: CLIPTextModelWithProjection,
87
- tokenizer: CLIPTokenizer,
88
- text_proj: UnCLIPTextProjModel,
89
- super_res_first: UNet2DModel,
90
- super_res_last: UNet2DModel,
91
- prior_scheduler: UnCLIPScheduler,
92
- decoder_scheduler: UnCLIPScheduler,
93
- super_res_scheduler: UnCLIPScheduler,
94
- ):
95
- super().__init__()
96
-
97
- self.register_modules(
98
- prior=prior,
99
- decoder=decoder,
100
- text_encoder=text_encoder,
101
- tokenizer=tokenizer,
102
- text_proj=text_proj,
103
- super_res_first=super_res_first,
104
- super_res_last=super_res_last,
105
- prior_scheduler=prior_scheduler,
106
- decoder_scheduler=decoder_scheduler,
107
- super_res_scheduler=super_res_scheduler,
108
- )
109
-
110
- # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
111
- def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
112
- if latents is None:
113
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
114
- else:
115
- if latents.shape != shape:
116
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
117
- latents = latents.to(device)
118
-
119
- latents = latents * scheduler.init_noise_sigma
120
- return latents
121
-
122
- # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt
123
- def _encode_prompt(
124
- self,
125
- prompt,
126
- device,
127
- num_images_per_prompt,
128
- do_classifier_free_guidance,
129
- text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
130
- text_attention_mask: Optional[torch.Tensor] = None,
131
- ):
132
- if text_model_output is None:
133
- batch_size = len(prompt) if isinstance(prompt, list) else 1
134
- # get prompt text embeddings
135
- text_inputs = self.tokenizer(
136
- prompt,
137
- padding="max_length",
138
- max_length=self.tokenizer.model_max_length,
139
- truncation=True,
140
- return_tensors="pt",
141
- )
142
- text_input_ids = text_inputs.input_ids
143
- text_mask = text_inputs.attention_mask.bool().to(device)
144
-
145
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
146
-
147
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
148
- text_input_ids, untruncated_ids
149
- ):
150
- removed_text = self.tokenizer.batch_decode(
151
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
152
- )
153
- logger.warning(
154
- "The following part of your input was truncated because CLIP can only handle sequences up to"
155
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
156
- )
157
- text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
158
-
159
- text_encoder_output = self.text_encoder(text_input_ids.to(device))
160
-
161
- prompt_embeds = text_encoder_output.text_embeds
162
- text_encoder_hidden_states = text_encoder_output.last_hidden_state
163
-
164
- else:
165
- batch_size = text_model_output[0].shape[0]
166
- prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
167
- text_mask = text_attention_mask
168
-
169
- prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
170
- text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
171
- text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
172
-
173
- if do_classifier_free_guidance:
174
- uncond_tokens = [""] * batch_size
175
-
176
- uncond_input = self.tokenizer(
177
- uncond_tokens,
178
- padding="max_length",
179
- max_length=self.tokenizer.model_max_length,
180
- truncation=True,
181
- return_tensors="pt",
182
- )
183
- uncond_text_mask = uncond_input.attention_mask.bool().to(device)
184
- negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
185
-
186
- negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
187
- uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
188
-
189
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
190
-
191
- seq_len = negative_prompt_embeds.shape[1]
192
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
193
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
194
-
195
- seq_len = uncond_text_encoder_hidden_states.shape[1]
196
- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
197
- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
198
- batch_size * num_images_per_prompt, seq_len, -1
199
- )
200
- uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
201
-
202
- # done duplicates
203
-
204
- # For classifier free guidance, we need to do two forward passes.
205
- # Here we concatenate the unconditional and text embeddings into a single batch
206
- # to avoid doing two forward passes
207
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
208
- text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
209
-
210
- text_mask = torch.cat([uncond_text_mask, text_mask])
211
-
212
- return prompt_embeds, text_encoder_hidden_states, text_mask
213
-
214
- # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.enable_sequential_cpu_offload
215
- def enable_sequential_cpu_offload(self, gpu_id=0):
216
- r"""
217
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
218
- models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
219
- when their specific submodule has its `forward` method called.
220
- """
221
- if is_accelerate_available():
222
- from accelerate import cpu_offload
223
- else:
224
- raise ImportError("Please install accelerate via `pip install accelerate`")
225
-
226
- device = torch.device(f"cuda:{gpu_id}")
227
-
228
- # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
229
- models = [
230
- self.decoder,
231
- self.text_proj,
232
- self.text_encoder,
233
- self.super_res_first,
234
- self.super_res_last,
235
- ]
236
- for cpu_offloaded_model in models:
237
- if cpu_offloaded_model is not None:
238
- cpu_offload(cpu_offloaded_model, device)
239
-
240
- @property
241
- # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
242
- def _execution_device(self):
243
- r"""
244
- Returns the device on which the pipeline's models will be executed. After calling
245
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
246
- hooks.
247
- """
248
- if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
249
- return self.device
250
- for module in self.decoder.modules():
251
- if (
252
- hasattr(module, "_hf_hook")
253
- and hasattr(module._hf_hook, "execution_device")
254
- and module._hf_hook.execution_device is not None
255
- ):
256
- return torch.device(module._hf_hook.execution_device)
257
- return self.device
258
-
259
- @torch.no_grad()
260
- def __call__(
261
- self,
262
- start_prompt: str,
263
- end_prompt: str,
264
- steps: int = 5,
265
- prior_num_inference_steps: int = 25,
266
- decoder_num_inference_steps: int = 25,
267
- super_res_num_inference_steps: int = 7,
268
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
269
- prior_guidance_scale: float = 4.0,
270
- decoder_guidance_scale: float = 8.0,
271
- enable_sequential_cpu_offload=True,
272
- gpu_id=0,
273
- output_type: Optional[str] = "pil",
274
- return_dict: bool = True,
275
- ):
276
- """
277
- Function invoked when calling the pipeline for generation.
278
-
279
- Args:
280
- start_prompt (`str`):
281
- The prompt to start the image generation interpolation from.
282
- end_prompt (`str`):
283
- The prompt to end the image generation interpolation at.
284
- steps (`int`, *optional*, defaults to 5):
285
- The number of steps over which to interpolate from start_prompt to end_prompt. The pipeline returns
286
- the same number of images as this value.
287
- prior_num_inference_steps (`int`, *optional*, defaults to 25):
288
- The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
289
- image at the expense of slower inference.
290
- decoder_num_inference_steps (`int`, *optional*, defaults to 25):
291
- The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
292
- image at the expense of slower inference.
293
- super_res_num_inference_steps (`int`, *optional*, defaults to 7):
294
- The number of denoising steps for super resolution. More denoising steps usually lead to a higher
295
- quality image at the expense of slower inference.
296
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
297
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
298
- to make generation deterministic.
299
- prior_guidance_scale (`float`, *optional*, defaults to 4.0):
300
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
301
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
302
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
303
- 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
304
- usually at the expense of lower image quality.
305
- decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
306
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
307
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
308
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
309
- 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
310
- usually at the expense of lower image quality.
311
- output_type (`str`, *optional*, defaults to `"pil"`):
312
- The output format of the generated image. Choose between
313
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
314
- enable_sequential_cpu_offload (`bool`, *optional*, defaults to `True`):
315
- If True, offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
316
- models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
317
- when their specific submodule has its `forward` method called.
318
- gpu_id (`int`, *optional*, defaults to `0`):
319
- The gpu_id to be passed to enable_sequential_cpu_offload. Only works when enable_sequential_cpu_offload is set to True.
320
- return_dict (`bool`, *optional*, defaults to `True`):
321
- Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
322
- """
323
-
324
- if not isinstance(start_prompt, str) or not isinstance(end_prompt, str):
325
- raise ValueError(
326
- f"`start_prompt` and `end_prompt` should be of type `str` but got {type(start_prompt)} and"
327
- f" {type(end_prompt)} instead"
328
- )
329
-
330
- if enable_sequential_cpu_offload:
331
- self.enable_sequential_cpu_offload(gpu_id=gpu_id)
332
-
333
- device = self._execution_device
334
-
335
- # Turn the prompts into embeddings.
336
- inputs = self.tokenizer(
337
- [start_prompt, end_prompt],
338
- padding="max_length",
339
- truncation=True,
340
- max_length=self.tokenizer.model_max_length,
341
- return_tensors="pt",
342
- )
343
- inputs.to(device)
344
- text_model_output = self.text_encoder(**inputs)
345
-
346
- text_attention_mask = torch.max(inputs.attention_mask[0], inputs.attention_mask[1])
347
- text_attention_mask = torch.cat([text_attention_mask.unsqueeze(0)] * steps).to(device)
348
-
349
- # Interpolate from the start to end prompt using slerp and add the generated images to an image output pipeline
350
- batch_text_embeds = []
351
- batch_last_hidden_state = []
352
-
353
- for interp_val in torch.linspace(0, 1, steps):
354
- text_embeds = slerp(interp_val, text_model_output.text_embeds[0], text_model_output.text_embeds[1])
355
- last_hidden_state = slerp(
356
- interp_val, text_model_output.last_hidden_state[0], text_model_output.last_hidden_state[1]
357
- )
358
- batch_text_embeds.append(text_embeds.unsqueeze(0))
359
- batch_last_hidden_state.append(last_hidden_state.unsqueeze(0))
360
-
361
- batch_text_embeds = torch.cat(batch_text_embeds)
362
- batch_last_hidden_state = torch.cat(batch_last_hidden_state)
363
-
364
- text_model_output = CLIPTextModelOutput(
365
- text_embeds=batch_text_embeds, last_hidden_state=batch_last_hidden_state
366
- )
367
-
368
- batch_size = text_model_output[0].shape[0]
369
-
370
- do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
371
-
372
- prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
373
- prompt=None,
374
- device=device,
375
- num_images_per_prompt=1,
376
- do_classifier_free_guidance=do_classifier_free_guidance,
377
- text_model_output=text_model_output,
378
- text_attention_mask=text_attention_mask,
379
- )
380
-
381
- # prior
382
-
383
- self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
384
- prior_timesteps_tensor = self.prior_scheduler.timesteps
385
-
386
- embedding_dim = self.prior.config.embedding_dim
387
-
388
- prior_latents = self.prepare_latents(
389
- (batch_size, embedding_dim),
390
- prompt_embeds.dtype,
391
- device,
392
- generator,
393
- None,
394
- self.prior_scheduler,
395
- )
396
-
397
- for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
398
- # expand the latents if we are doing classifier free guidance
399
- latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
400
-
401
- predicted_image_embedding = self.prior(
402
- latent_model_input,
403
- timestep=t,
404
- proj_embedding=prompt_embeds,
405
- encoder_hidden_states=text_encoder_hidden_states,
406
- attention_mask=text_mask,
407
- ).predicted_image_embedding
408
-
409
- if do_classifier_free_guidance:
410
- predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
411
- predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
412
- predicted_image_embedding_text - predicted_image_embedding_uncond
413
- )
414
-
415
- if i + 1 == prior_timesteps_tensor.shape[0]:
416
- prev_timestep = None
417
- else:
418
- prev_timestep = prior_timesteps_tensor[i + 1]
419
-
420
- prior_latents = self.prior_scheduler.step(
421
- predicted_image_embedding,
422
- timestep=t,
423
- sample=prior_latents,
424
- generator=generator,
425
- prev_timestep=prev_timestep,
426
- ).prev_sample
427
-
428
- prior_latents = self.prior.post_process_latents(prior_latents)
429
-
430
- image_embeddings = prior_latents
431
-
432
- # done prior
433
-
434
- # decoder
435
-
436
- text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
437
- image_embeddings=image_embeddings,
438
- prompt_embeds=prompt_embeds,
439
- text_encoder_hidden_states=text_encoder_hidden_states,
440
- do_classifier_free_guidance=do_classifier_free_guidance,
441
- )
442
-
443
- if device.type == "mps":
444
- # HACK: MPS: There is a panic when padding bool tensors,
445
- # so cast to int tensor for the pad and back to bool afterwards
446
- text_mask = text_mask.type(torch.int)
447
- decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
448
- decoder_text_mask = decoder_text_mask.type(torch.bool)
449
- else:
450
- decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
451
-
452
- self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
453
- decoder_timesteps_tensor = self.decoder_scheduler.timesteps
454
-
455
- num_channels_latents = self.decoder.in_channels
456
- height = self.decoder.sample_size
457
- width = self.decoder.sample_size
458
-
459
- decoder_latents = self.prepare_latents(
460
- (batch_size, num_channels_latents, height, width),
461
- text_encoder_hidden_states.dtype,
462
- device,
463
- generator,
464
- None,
465
- self.decoder_scheduler,
466
- )
467
-
468
- for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
469
- # expand the latents if we are doing classifier free guidance
470
- latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
471
-
472
- noise_pred = self.decoder(
473
- sample=latent_model_input,
474
- timestep=t,
475
- encoder_hidden_states=text_encoder_hidden_states,
476
- class_labels=additive_clip_time_embeddings,
477
- attention_mask=decoder_text_mask,
478
- ).sample
479
-
480
- if do_classifier_free_guidance:
481
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
482
- noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
483
- noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
484
- noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
485
- noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
486
-
487
- if i + 1 == decoder_timesteps_tensor.shape[0]:
488
- prev_timestep = None
489
- else:
490
- prev_timestep = decoder_timesteps_tensor[i + 1]
491
-
492
- # compute the previous noisy sample x_t -> x_t-1
493
- decoder_latents = self.decoder_scheduler.step(
494
- noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
495
- ).prev_sample
496
-
497
- decoder_latents = decoder_latents.clamp(-1, 1)
498
-
499
- image_small = decoder_latents
500
-
501
- # done decoder
502
-
503
- # super res
504
-
505
- self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
506
- super_res_timesteps_tensor = self.super_res_scheduler.timesteps
507
-
508
- channels = self.super_res_first.in_channels // 2
509
- height = self.super_res_first.sample_size
510
- width = self.super_res_first.sample_size
511
-
512
- super_res_latents = self.prepare_latents(
513
- (batch_size, channels, height, width),
514
- image_small.dtype,
515
- device,
516
- generator,
517
- None,
518
- self.super_res_scheduler,
519
- )
520
-
521
- if device.type == "mps":
522
- # MPS does not support many interpolations
523
- image_upscaled = F.interpolate(image_small, size=[height, width])
524
- else:
525
- interpolate_antialias = {}
526
- if "antialias" in inspect.signature(F.interpolate).parameters:
527
- interpolate_antialias["antialias"] = True
528
-
529
- image_upscaled = F.interpolate(
530
- image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
531
- )
532
-
533
- for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
534
- # no classifier free guidance
535
-
536
- if i == super_res_timesteps_tensor.shape[0] - 1:
537
- unet = self.super_res_last
538
- else:
539
- unet = self.super_res_first
540
-
541
- latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
542
-
543
- noise_pred = unet(
544
- sample=latent_model_input,
545
- timestep=t,
546
- ).sample
547
-
548
- if i + 1 == super_res_timesteps_tensor.shape[0]:
549
- prev_timestep = None
550
- else:
551
- prev_timestep = super_res_timesteps_tensor[i + 1]
552
-
553
- # compute the previous noisy sample x_t -> x_t-1
554
- super_res_latents = self.super_res_scheduler.step(
555
- noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
556
- ).prev_sample
557
-
558
- image = super_res_latents
559
- # done super res
560
-
561
- # post processing
562
-
563
- image = image * 0.5 + 0.5
564
- image = image.clamp(0, 1)
565
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
566
-
567
- if output_type == "pil":
568
- image = self.numpy_to_pil(image)
569
-
570
- if not return_dict:
571
- return (image,)
572
-
573
- return ImagePipelineOutput(images=image)
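For reference, the interpolation loop in this pipeline relies on a `slerp` helper defined earlier in the file (not shown in this hunk). A minimal sketch of spherical linear interpolation for two 1-D tensors could look like the following; the helper actually defined in the file may differ in details such as shape handling:

```py
import torch

def slerp(val: float, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    # Spherical linear interpolation between two 1-D embedding tensors.
    low_norm = low / torch.norm(low)
    high_norm = high / torch.norm(high)
    dot = (low_norm * high_norm).sum().clamp(-1.0, 1.0)
    omega = torch.acos(dot)
    so = torch.sin(omega)
    if so.abs() < 1e-6:
        # Nearly parallel vectors: fall back to plain linear interpolation.
        return (1.0 - val) * low + val * high
    return (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
```
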
 
diffusers/examples/community/wildcard_stable_diffusion.py DELETED
@@ -1,418 +0,0 @@
1
- import inspect
2
- import os
3
- import random
4
- import re
5
- from dataclasses import dataclass
6
- from typing import Callable, Dict, List, Optional, Union
7
-
8
- import torch
9
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
10
-
11
- from diffusers import DiffusionPipeline
12
- from diffusers.configuration_utils import FrozenDict
13
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
14
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
15
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
16
- from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
17
- from diffusers.utils import deprecate, logging
18
-
19
-
20
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
21
-
22
- global_re_wildcard = re.compile(r"__([^_]*)__")
23
-
24
-
25
- def get_filename(path: str):
26
- # this doesn't work on Windows
27
- return os.path.basename(path).split(".txt")[0]
28
-
29
-
30
- def read_wildcard_values(path: str):
31
- with open(path, encoding="utf8") as f:
32
- return f.read().splitlines()
33
-
34
-
35
- def grab_wildcard_values(wildcard_option_dict: Dict[str, List[str]] = {}, wildcard_files: List[str] = []):
36
- for wildcard_file in wildcard_files:
37
- filename = get_filename(wildcard_file)
38
- read_values = read_wildcard_values(wildcard_file)
39
- if filename not in wildcard_option_dict:
40
- wildcard_option_dict[filename] = []
41
- wildcard_option_dict[filename].extend(read_values)
42
- return wildcard_option_dict
43
-
44
-
45
- def replace_prompt_with_wildcards(
46
- prompt: str, wildcard_option_dict: Dict[str, List[str]] = {}, wildcard_files: List[str] = []
47
- ):
48
- new_prompt = prompt
49
-
50
- # get wildcard options
51
- wildcard_option_dict = grab_wildcard_values(wildcard_option_dict, wildcard_files)
52
-
53
- for m in global_re_wildcard.finditer(new_prompt):
54
- wildcard_value = m.group()
55
- replace_value = random.choice(wildcard_option_dict[wildcard_value.strip("__")])
56
- new_prompt = new_prompt.replace(wildcard_value, replace_value, 1)
57
-
58
- return new_prompt
59
-
60
-
61
- @dataclass
62
- class WildcardStableDiffusionOutput(StableDiffusionPipelineOutput):
63
- prompts: List[str]
64
-
65
-
66
- class WildcardStableDiffusionPipeline(DiffusionPipeline):
67
- r"""
68
- Example Usage:
69
- pipe = WildcardStableDiffusionPipeline.from_pretrained(
70
- "CompVis/stable-diffusion-v1-4",
71
-
72
- torch_dtype=torch.float16,
73
- )
74
- prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
75
- out = pipe(
76
- prompt,
77
- wildcard_option_dict={
78
- "clothing":["hat", "shirt", "scarf", "beret"]
79
- },
80
- wildcard_files=["object.txt", "animal.txt"],
81
- num_prompt_samples=1
82
- )
83
-
84
-
85
- Pipeline for text-to-image generation with wild cards using Stable Diffusion.
86
-
87
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
88
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
89
-
90
- Args:
91
- vae ([`AutoencoderKL`]):
92
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
93
- text_encoder ([`CLIPTextModel`]):
94
- Frozen text-encoder. Stable Diffusion uses the text portion of
95
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
96
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
97
- tokenizer (`CLIPTokenizer`):
98
- Tokenizer of class
99
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
100
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
101
- scheduler ([`SchedulerMixin`]):
102
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
103
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
104
- safety_checker ([`StableDiffusionSafetyChecker`]):
105
- Classification module that estimates whether generated images could be considered offensive or harmful.
106
- Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
107
- feature_extractor ([`CLIPImageProcessor`]):
108
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
109
- """
110
-
111
- def __init__(
112
- self,
113
- vae: AutoencoderKL,
114
- text_encoder: CLIPTextModel,
115
- tokenizer: CLIPTokenizer,
116
- unet: UNet2DConditionModel,
117
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
118
- safety_checker: StableDiffusionSafetyChecker,
119
- feature_extractor: CLIPImageProcessor,
120
- ):
121
- super().__init__()
122
-
123
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
124
- deprecation_message = (
125
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
126
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
127
- "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
128
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
129
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
130
- " file"
131
- )
132
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
133
- new_config = dict(scheduler.config)
134
- new_config["steps_offset"] = 1
135
- scheduler._internal_dict = FrozenDict(new_config)
136
-
137
- if safety_checker is None:
138
- logger.warning(
139
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
140
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
141
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
142
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
143
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
144
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
145
- )
146
-
147
- self.register_modules(
148
- vae=vae,
149
- text_encoder=text_encoder,
150
- tokenizer=tokenizer,
151
- unet=unet,
152
- scheduler=scheduler,
153
- safety_checker=safety_checker,
154
- feature_extractor=feature_extractor,
155
- )
156
-
157
- @torch.no_grad()
158
- def __call__(
159
- self,
160
- prompt: Union[str, List[str]],
161
- height: int = 512,
162
- width: int = 512,
163
- num_inference_steps: int = 50,
164
- guidance_scale: float = 7.5,
165
- negative_prompt: Optional[Union[str, List[str]]] = None,
166
- num_images_per_prompt: Optional[int] = 1,
167
- eta: float = 0.0,
168
- generator: Optional[torch.Generator] = None,
169
- latents: Optional[torch.FloatTensor] = None,
170
- output_type: Optional[str] = "pil",
171
- return_dict: bool = True,
172
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
173
- callback_steps: int = 1,
174
- wildcard_option_dict: Dict[str, List[str]] = {},
175
- wildcard_files: List[str] = [],
176
- num_prompt_samples: Optional[int] = 1,
177
- **kwargs,
178
- ):
179
- r"""
180
- Function invoked when calling the pipeline for generation.
181
-
182
- Args:
183
- prompt (`str` or `List[str]`):
184
- The prompt or prompts to guide the image generation.
185
- height (`int`, *optional*, defaults to 512):
186
- The height in pixels of the generated image.
187
- width (`int`, *optional*, defaults to 512):
188
- The width in pixels of the generated image.
189
- num_inference_steps (`int`, *optional*, defaults to 50):
190
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
191
- expense of slower inference.
192
- guidance_scale (`float`, *optional*, defaults to 7.5):
193
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
194
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
195
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
196
- 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
197
- usually at the expense of lower image quality.
198
- negative_prompt (`str` or `List[str]`, *optional*):
199
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
200
- if `guidance_scale` is less than `1`).
201
- num_images_per_prompt (`int`, *optional*, defaults to 1):
202
- The number of images to generate per prompt.
203
- eta (`float`, *optional*, defaults to 0.0):
204
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
205
- [`schedulers.DDIMScheduler`], will be ignored for others.
206
- generator (`torch.Generator`, *optional*):
207
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
208
- deterministic.
209
- latents (`torch.FloatTensor`, *optional*):
210
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
211
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
212
- tensor will be generated by sampling using the supplied random `generator`.
213
- output_type (`str`, *optional*, defaults to `"pil"`):
214
- The output format of the generated image. Choose between
215
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
216
- return_dict (`bool`, *optional*, defaults to `True`):
217
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
218
- plain tuple.
219
- callback (`Callable`, *optional*):
220
- A function that will be called every `callback_steps` steps during inference. The function will be
221
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
222
- callback_steps (`int`, *optional*, defaults to 1):
223
- The frequency at which the `callback` function will be called. If not specified, the callback will be
224
- called at every step.
225
- wildcard_option_dict (Dict[str, List[str]]):
226
- Dict mapping a wildcard name to a list of possible replacements. For example, for the prompt "A __animal__ sitting on a chair", `wildcard_option_dict` can provide possible values for "animal" like this: {"animal": ["dog", "cat", "fox"]}
227
- wildcard_files: (List[str])
228
- List of filenames of txt files for wildcard replacements. For example, for the prompt "A __animal__ sitting on a chair", a file can be provided as ["animal.txt"]
229
- num_prompt_samples: int
230
- Number of times to sample wildcards for each prompt provided
231
-
232
- Returns:
233
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
234
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
235
- When returning a tuple, the first element is a list with the generated images, and the second element is a
236
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
237
- (nsfw) content, according to the `safety_checker`.
238
- """
239
-
240
- if isinstance(prompt, str):
241
- prompt = [
242
- replace_prompt_with_wildcards(prompt, wildcard_option_dict, wildcard_files)
243
- for i in range(num_prompt_samples)
244
- ]
245
- batch_size = len(prompt)
246
- elif isinstance(prompt, list):
247
- prompt_list = []
248
- for p in prompt:
249
- for i in range(num_prompt_samples):
250
- prompt_list.append(replace_prompt_with_wildcards(p, wildcard_option_dict, wildcard_files))
251
- prompt = prompt_list
252
- batch_size = len(prompt)
253
- else:
254
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
255
-
256
- if height % 8 != 0 or width % 8 != 0:
257
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
258
-
259
- if (callback_steps is None) or (
260
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
261
- ):
262
- raise ValueError(
263
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
264
- f" {type(callback_steps)}."
265
- )
266
-
267
- # get prompt text embeddings
268
- text_inputs = self.tokenizer(
269
- prompt,
270
- padding="max_length",
271
- max_length=self.tokenizer.model_max_length,
272
- return_tensors="pt",
273
- )
274
- text_input_ids = text_inputs.input_ids
275
-
276
- if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
277
- removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
278
- logger.warning(
279
- "The following part of your input was truncated because CLIP can only handle sequences up to"
280
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
281
- )
282
- text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
283
- text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
284
-
285
- # duplicate text embeddings for each generation per prompt, using mps friendly method
286
- bs_embed, seq_len, _ = text_embeddings.shape
287
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
288
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
289
-
290
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
291
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
292
- # corresponds to doing no classifier free guidance.
293
- do_classifier_free_guidance = guidance_scale > 1.0
294
- # get unconditional embeddings for classifier free guidance
295
- if do_classifier_free_guidance:
296
- uncond_tokens: List[str]
297
- if negative_prompt is None:
298
- uncond_tokens = [""] * batch_size
299
- elif type(prompt) is not type(negative_prompt):
300
- raise TypeError(
301
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
302
- f" {type(prompt)}."
303
- )
304
- elif isinstance(negative_prompt, str):
305
- uncond_tokens = [negative_prompt]
306
- elif batch_size != len(negative_prompt):
307
- raise ValueError(
308
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
309
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
310
- " the batch size of `prompt`."
311
- )
312
- else:
313
- uncond_tokens = negative_prompt
314
-
315
- max_length = text_input_ids.shape[-1]
316
- uncond_input = self.tokenizer(
317
- uncond_tokens,
318
- padding="max_length",
319
- max_length=max_length,
320
- truncation=True,
321
- return_tensors="pt",
322
- )
323
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
324
-
325
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
326
- seq_len = uncond_embeddings.shape[1]
327
- uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
328
- uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
329
-
330
- # For classifier free guidance, we need to do two forward passes.
331
- # Here we concatenate the unconditional and text embeddings into a single batch
332
- # to avoid doing two forward passes
333
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
334
-
335
- # get the initial random noise unless the user supplied it
336
-
337
- # Unlike in other pipelines, latents need to be generated in the target device
338
- # for 1-to-1 results reproducibility with the CompVis implementation.
339
- # However this currently doesn't work in `mps`.
340
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
341
- latents_dtype = text_embeddings.dtype
342
- if latents is None:
343
- if self.device.type == "mps":
344
- # randn does not exist on mps
345
- latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
346
- self.device
347
- )
348
- else:
349
- latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
350
- else:
351
- if latents.shape != latents_shape:
352
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
353
- latents = latents.to(self.device)
354
-
355
- # set timesteps
356
- self.scheduler.set_timesteps(num_inference_steps)
357
-
358
- # Some schedulers like PNDM have timesteps as arrays
359
- # It's more optimized to move all timesteps to correct device beforehand
360
- timesteps_tensor = self.scheduler.timesteps.to(self.device)
361
-
362
- # scale the initial noise by the standard deviation required by the scheduler
363
- latents = latents * self.scheduler.init_noise_sigma
364
-
365
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
366
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
367
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
368
- # and should be between [0, 1]
369
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
370
- extra_step_kwargs = {}
371
- if accepts_eta:
372
- extra_step_kwargs["eta"] = eta
373
-
374
- for i, t in enumerate(self.progress_bar(timesteps_tensor)):
375
- # expand the latents if we are doing classifier free guidance
376
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
377
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
378
-
379
- # predict the noise residual
380
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
381
-
382
- # perform guidance
383
- if do_classifier_free_guidance:
384
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
385
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
386
-
387
- # compute the previous noisy sample x_t -> x_t-1
388
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
389
-
390
- # call the callback, if provided
391
- if callback is not None and i % callback_steps == 0:
392
- callback(i, t, latents)
393
-
394
- latents = 1 / 0.18215 * latents
395
- image = self.vae.decode(latents).sample
396
-
397
- image = (image / 2 + 0.5).clamp(0, 1)
398
-
399
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
400
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
401
-
402
- if self.safety_checker is not None:
403
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
404
- self.device
405
- )
406
- image, has_nsfw_concept = self.safety_checker(
407
- images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
408
- )
409
- else:
410
- has_nsfw_concept = None
411
-
412
- if output_type == "pil":
413
- image = self.numpy_to_pil(image)
414
-
415
- if not return_dict:
416
- return (image, has_nsfw_concept)
417
-
418
- return WildcardStableDiffusionOutput(images=image, nsfw_content_detected=has_nsfw_concept, prompts=prompt)
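As a quick illustration of the wildcard mechanics above (separate from running the pipeline itself), `replace_prompt_with_wildcards` can be exercised on its own. The import path below assumes this file is available on your `PYTHONPATH`; the wildcard option values are made up for the example:

```py
from wildcard_stable_diffusion import replace_prompt_with_wildcards

prompt = "a __animal__ wearing a __clothing__"
resolved = replace_prompt_with_wildcards(
    prompt,
    wildcard_option_dict={
        "animal": ["dog", "cat", "fox"],
        "clothing": ["hat", "scarf"],
    },
)
print(resolved)  # e.g. "a fox wearing a scarf" (replacements are chosen at random)
```
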
 
diffusers/examples/conftest.py DELETED
@@ -1,45 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # tests directory-specific settings - this file is run automatically
16
- # by pytest before any tests are run
17
-
18
- import sys
19
- import warnings
20
- from os.path import abspath, dirname, join
21
-
22
-
23
- # allow having multiple repository checkouts and not needing to remember to rerun
24
- # 'pip install -e .[dev]' when switching between checkouts and running tests.
25
- git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
26
- sys.path.insert(1, git_repo_path)
27
-
28
-
29
- # silence FutureWarning warnings in tests since often we can't act on them until
30
- # they become normal warnings - i.e. the tests still need to test the current functionality
31
- warnings.simplefilter(action="ignore", category=FutureWarning)
32
-
33
-
34
- def pytest_addoption(parser):
35
- from diffusers.utils.testing_utils import pytest_addoption_shared
36
-
37
- pytest_addoption_shared(parser)
38
-
39
-
40
- def pytest_terminal_summary(terminalreporter):
41
- from diffusers.utils.testing_utils import pytest_terminal_summary_main
42
-
43
- make_reports = terminalreporter.config.getoption("--make-reports")
44
- if make_reports:
45
- pytest_terminal_summary_main(terminalreporter, id=make_reports)
 
diffusers/examples/controlnet/README.md DELETED
@@ -1,392 +0,0 @@
1
- # ControlNet training example
2
-
3
- [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) by Lvmin Zhang and Maneesh Agrawala.
4
-
5
- This example is based on the [training example in the original ControlNet repository](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md). It trains a ControlNet to fill circles using a [small synthetic dataset](https://huggingface.co/datasets/fusing/fill50k).
6
-
7
- ## Installing the dependencies
8
-
9
- Before running the scripts, make sure to install the library's training dependencies:
10
-
11
- **Important**
12
-
13
- To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
14
- ```bash
15
- git clone https://github.com/huggingface/diffusers
16
- cd diffusers
17
- pip install -e .
18
- ```
19
-
20
- Then cd into the example folder and run
21
- ```bash
22
- pip install -r requirements.txt
23
- ```
24
-
25
- And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
26
-
27
- ```bash
28
- accelerate config
29
- ```
30
-
31
- Or for a default accelerate configuration without answering questions about your environment
32
-
33
- ```bash
34
- accelerate config default
35
- ```
36
-
37
- Or if your environment doesn't support an interactive shell e.g. a notebook
38
-
39
- ```python
40
- from accelerate.utils import write_basic_config
41
- write_basic_config()
42
- ```
43
-
44
- ## Circle filling dataset
45
-
46
- The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
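For a quick look at the data before training, the dataset can also be loaded directly with `datasets`. The column names below match the training script's defaults, but treat them as assumptions worth double-checking on the dataset page:

```py
from datasets import load_dataset

# Load the re-uploaded fill50k dataset from the Hub.
dataset = load_dataset("fusing/fill50k", split="train")

example = dataset[0]
print(example["text"])                      # caption describing the filled circle
print(example["image"].size)                # target image
print(example["conditioning_image"].size)   # circle-outline conditioning image
```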
47
-
48
- Our training examples use [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), as the original set of ControlNet models was trained from it. However, ControlNet can be trained to augment any Stable Diffusion compatible model, such as [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) or [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1).
49
-
50
- ## Training
51
-
52
- Our training examples use two test conditioning images. They can be downloaded by running
53
-
54
- ```sh
55
- wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
56
-
57
- wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
58
- ```
59
-
60
-
61
- ```bash
62
- export MODEL_DIR="runwayml/stable-diffusion-v1-5"
63
- export OUTPUT_DIR="path to save model"
64
-
65
- accelerate launch train_controlnet.py \
66
- --pretrained_model_name_or_path=$MODEL_DIR \
67
- --output_dir=$OUTPUT_DIR \
68
- --dataset_name=fusing/fill50k \
69
- --resolution=512 \
70
- --learning_rate=1e-5 \
71
- --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
72
- --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
73
- --train_batch_size=4
74
- ```
75
-
76
- This default configuration requires ~38GB VRAM.
77
-
78
- By default, the training script logs outputs to TensorBoard. Pass `--report_to wandb` to use Weights &
79
- Biases.
80
-
81
- Gradient accumulation with a smaller batch size can be used to reduce training requirements to ~20 GB VRAM.
82
-
83
- ```bash
84
- export MODEL_DIR="runwayml/stable-diffusion-v1-5"
85
- export OUTPUT_DIR="path to save model"
86
-
87
- accelerate launch train_controlnet.py \
88
- --pretrained_model_name_or_path=$MODEL_DIR \
89
- --output_dir=$OUTPUT_DIR \
90
- --dataset_name=fusing/fill50k \
91
- --resolution=512 \
92
- --learning_rate=1e-5 \
93
- --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
94
- --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
95
- --train_batch_size=1 \
96
- --gradient_accumulation_steps=4
97
- ```
98
-
99
- ## Example results
100
-
101
- #### After 300 steps with batch size 8
102
-
103
- | | |
104
- |-------------------|:-------------------------:|
105
- | | red circle with blue background |
106
- ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![red circle with blue background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/red_circle_with_blue_background_300_steps.png) |
107
- | | cyan circle with brown floral background |
108
- ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png) | ![cyan circle with brown floral background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/cyan_circle_with_brown_floral_background_300_steps.png) |
109
-
110
-
111
- #### After 6000 steps with batch size 8:
112
-
113
- | | |
114
- |-------------------|:-------------------------:|
115
- | | red circle with blue background |
116
- ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![red circle with blue background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/red_circle_with_blue_background_6000_steps.png) |
117
- | | cyan circle with brown floral background |
118
- ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png) | ![cyan circle with brown floral background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/cyan_circle_with_brown_floral_background_6000_steps.png) |
119
-
120
- ## Training on a 16 GB GPU
121
-
122
- Optimizations:
123
- - Gradient checkpointing
124
- - bitsandbytes' 8-bit optimizer
125
-
126
- [bitsandbytes install instructions](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
127
-
128
- ```bash
129
- export MODEL_DIR="runwayml/stable-diffusion-v1-5"
130
- export OUTPUT_DIR="path to save model"
131
-
132
- accelerate launch train_controlnet.py \
133
- --pretrained_model_name_or_path=$MODEL_DIR \
134
- --output_dir=$OUTPUT_DIR \
135
- --dataset_name=fusing/fill50k \
136
- --resolution=512 \
137
- --learning_rate=1e-5 \
138
- --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
139
- --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
140
- --train_batch_size=1 \
141
- --gradient_accumulation_steps=4 \
142
- --gradient_checkpointing \
143
- --use_8bit_adam
144
- ```
145
-
146
- ## Training on a 12 GB GPU
147
-
148
- Optimizations:
149
- - Gradient checkpointing
150
- - bitsandbytes' 8-bit optimizer
151
- - xformers
152
- - set grads to none
153
-
154
- ```bash
155
- export MODEL_DIR="runwayml/stable-diffusion-v1-5"
156
- export OUTPUT_DIR="path to save model"
157
-
158
- accelerate launch train_controlnet.py \
159
- --pretrained_model_name_or_path=$MODEL_DIR \
160
- --output_dir=$OUTPUT_DIR \
161
- --dataset_name=fusing/fill50k \
162
- --resolution=512 \
163
- --learning_rate=1e-5 \
164
- --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
165
- --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
166
- --train_batch_size=1 \
167
- --gradient_accumulation_steps=4 \
168
- --gradient_checkpointing \
169
- --use_8bit_adam \
170
- --enable_xformers_memory_efficient_attention \
171
- --set_grads_to_none
172
- ```
173
-
174
- When using `enable_xformers_memory_efficient_attention`, please make sure to install `xformers` with `pip install xformers`.
175
-
176
- ## Training on an 8 GB GPU
177
-
178
- We have not exhaustively tested DeepSpeed support for ControlNet. While the configuration does
179
- save memory, we have not confirmed that the configuration trains successfully. You will very likely
180
- have to make changes to the config to have a successful training run.
181
-
182
- Optimizations:
183
- - Gradient checkpointing
184
- - xformers
185
- - set grads to none
186
- - DeepSpeed stage 2 with parameter and optimizer offloading
187
- - fp16 mixed precision
188
-
189
- [DeepSpeed](https://www.deepspeed.ai/) can offload tensors from VRAM to either
190
- CPU or NVME. This requires significantly more RAM (about 25 GB).
191
-
192
- Use `accelerate config` to enable DeepSpeed stage 2.
193
-
194
- The relevant parts of the resulting accelerate config file are
195
-
196
- ```yaml
197
- compute_environment: LOCAL_MACHINE
198
- deepspeed_config:
199
- gradient_accumulation_steps: 4
200
- offload_optimizer_device: cpu
201
- offload_param_device: cpu
202
- zero3_init_flag: false
203
- zero_stage: 2
204
- distributed_type: DEEPSPEED
205
- ```
206
-
207
- See [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
208
-
209
- Changing the default Adam optimizer to DeepSpeed's Adam
210
- `deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup but
211
- it requires a CUDA toolchain with the same version as PyTorch. The 8-bit optimizer
212
- does not seem to be compatible with DeepSpeed at the moment.
213
-
214
- ```bash
215
- export MODEL_DIR="runwayml/stable-diffusion-v1-5"
216
- export OUTPUT_DIR="path to save model"
217
-
218
- accelerate launch train_controlnet.py \
219
- --pretrained_model_name_or_path=$MODEL_DIR \
220
- --output_dir=$OUTPUT_DIR \
221
- --dataset_name=fusing/fill50k \
222
- --resolution=512 \
223
- --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
224
- --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
225
- --train_batch_size=1 \
226
- --gradient_accumulation_steps=4 \
227
- --gradient_checkpointing \
228
- --enable_xformers_memory_efficient_attention \
229
- --set_grads_to_none \
230
- --mixed_precision fp16
231
- ```
232
-
233
- ## Performing inference with the trained ControlNet
234
-
235
- The trained model can be used with the same pipeline as the original ControlNet, simply plugging in the newly trained ControlNet.
236
- Set `base_model_path` and `controlnet_path` to the values `--pretrained_model_name_or_path` and
237
- `--output_dir` were set to, respectively, in the training script.
238
-
239
- ```py
240
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
241
- from diffusers.utils import load_image
242
- import torch
243
-
244
- base_model_path = "path to model"
245
- controlnet_path = "path to controlnet"
246
-
247
- controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
248
- pipe = StableDiffusionControlNetPipeline.from_pretrained(
249
- base_model_path, controlnet=controlnet, torch_dtype=torch.float16
250
- )
251
-
252
- # speed up diffusion process with faster scheduler and memory optimization
253
- pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
254
- # remove following line if xformers is not installed
255
- pipe.enable_xformers_memory_efficient_attention()
256
-
257
- pipe.enable_model_cpu_offload()
258
-
259
- control_image = load_image("./conditioning_image_1.png")
260
- prompt = "pale golden rod circle with old lace background"
261
-
262
- # generate image
263
- generator = torch.manual_seed(0)
264
- image = pipe(
265
- prompt, num_inference_steps=20, generator=generator, image=control_image
266
- ).images[0]
267
-
268
- image.save("./output.png")
269
- ```
270
-
271
- ## Training with Flax/JAX
272
-
273
- For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
274
-
275
- ### Running on Google Cloud TPU
276
-
277
- See below for commands to set up a TPU VM(`--accelerator-type v4-8`). For more details about how to set up and use TPUs, refer to [Cloud docs for single VM setup](https://cloud.google.com/tpu/docs/run-calculation-jax).
278
-
279
- First create a single TPUv4-8 VM and connect to it:
280
-
281
- ```
282
- ZONE=us-central2-b
283
- TPU_TYPE=v4-8
284
- VM_NAME=hg_flax
285
-
286
- gcloud alpha compute tpus tpu-vm create $VM_NAME \
287
- --zone $ZONE \
288
- --accelerator-type $TPU_TYPE \
289
- --version tpu-vm-v4-base
290
-
291
- gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
292
- ```
293
-
294
- When connected, install JAX `0.4.5`:
295
-
296
- ```
297
- pip install "jax[tpu]==0.4.5" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
298
- ```
299
-
300
- To verify that JAX was correctly installed, you can run the following in a Python shell:
301
-
302
- ```
303
- import jax
304
- jax.device_count()
305
- ```
306
-
307
- This should display the number of TPU cores, which should be 4 on a TPUv4-8 VM.
308
-
309
- Then install Diffusers and the library's training dependencies:
310
-
311
- ```bash
312
- git clone https://github.com/huggingface/diffusers
313
- cd diffusers
314
- pip install .
315
- ```
316
-
317
- Then cd into the example folder and run
318
-
319
- ```bash
320
- pip install -U -r requirements_flax.txt
321
- ```
322
-
323
- Now let's download the two conditioning images that we will use to run validation during training in order to track our progress
324
-
325
- ```
326
- wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
327
- wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
328
- ```
329
-
330
- We encourage you to store or share your model with the community. To use the Hugging Face Hub, please log in to your Hugging Face account ([create one](https://huggingface.co/docs/diffusers/main/en/training/hf.co/join) if you don’t have one already):
331
-
332
- ```
333
- huggingface-cli login
334
- ```
335
-
336
- Make sure you have the `MODEL_DIR`, `OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model on the Hub:
337
-
338
- ```bash
339
- export MODEL_DIR="runwayml/stable-diffusion-v1-5"
340
- export OUTPUT_DIR="control_out"
341
- export HUB_MODEL_ID="fill-circle-controlnet"
342
- ```
343
-
344
- And finally start the training
345
-
346
- ```bash
347
- python3 train_controlnet_flax.py \
348
- --pretrained_model_name_or_path=$MODEL_DIR \
349
- --output_dir=$OUTPUT_DIR \
350
- --dataset_name=fusing/fill50k \
351
- --resolution=512 \
352
- --learning_rate=1e-5 \
353
- --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
354
- --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
355
- --validation_steps=1000 \
356
- --train_batch_size=2 \
357
- --revision="non-ema" \
358
- --from_pt \
359
- --report_to="wandb" \
360
- --max_train_steps=10000 \
361
- --push_to_hub \
362
- --hub_model_id=$HUB_MODEL_ID
363
- ```
364
-
365
- Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your Hugging Face account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the Hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
366
-
367
- Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command:
368
-
369
- ```bash
370
- python3 train_controlnet_flax.py \
371
- --pretrained_model_name_or_path=$MODEL_DIR \
372
- --output_dir=$OUTPUT_DIR \
373
- --dataset_name=multimodalart/facesyntheticsspigacaptioned \
374
- --streaming \
375
- --conditioning_image_column=spiga_seg \
376
- --image_column=image \
377
- --caption_column=image_caption \
378
- --resolution=512 \
379
- --max_train_samples 50 \
380
- --max_train_steps 5 \
381
- --learning_rate=1e-5 \
382
- --validation_steps=2 \
383
- --train_batch_size=1 \
384
- --revision="flax" \
385
- --report_to="wandb"
386
- ```
387
-
388
- Note, however, that TPU performance might get bottlenecked because streaming with `datasets` is not optimized for images. To ensure maximum throughput, we encourage you to explore the following options (a minimal streaming sketch follows the list):
389
-
390
- * [Webdataset](https://webdataset.github.io/webdataset/)
391
- * [TorchData](https://github.com/pytorch/data)
392
- * [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
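As a rough sketch of the first option, WebDataset streams samples from sharded tar archives. The shard URL pattern and the `jpg`/`txt` keys below are assumptions about how you would package your own data, not something this repository provides:

```py
import webdataset as wds

# Hypothetical shard URLs; adjust the pattern and the sample keys to your own packaging.
urls = "https://example.com/fill50k-{000000..000009}.tar"

dataset = (
    wds.WebDataset(urls)
    .shuffle(1000)           # sample-level shuffle buffer
    .decode("pil")           # decode images to PIL.Image
    .to_tuple("jpg", "txt")  # (image, caption) pairs
)

for image, caption in dataset:
    print(image.size, caption)
    break
```
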
 
diffusers/examples/controlnet/requirements.txt DELETED
@@ -1,6 +0,0 @@
1
- accelerate
2
- torchvision
3
- transformers>=4.25.1
4
- ftfy
5
- tensorboard
6
- datasets
 
diffusers/examples/controlnet/requirements_flax.txt DELETED
@@ -1,9 +0,0 @@
1
- transformers>=4.25.1
2
- datasets
3
- flax
4
- optax
5
- torch
6
- torchvision
7
- ftfy
8
- tensorboard
9
- Jinja2
 
diffusers/examples/controlnet/train_controlnet.py DELETED
@@ -1,1046 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
-
16
- import argparse
17
- import logging
18
- import math
19
- import os
20
- import random
21
- from pathlib import Path
22
-
23
- import accelerate
24
- import numpy as np
25
- import torch
26
- import torch.nn.functional as F
27
- import torch.utils.checkpoint
28
- import transformers
29
- from accelerate import Accelerator
30
- from accelerate.logging import get_logger
31
- from accelerate.utils import ProjectConfiguration, set_seed
32
- from datasets import load_dataset
33
- from huggingface_hub import create_repo, upload_folder
34
- from packaging import version
35
- from PIL import Image
36
- from torchvision import transforms
37
- from tqdm.auto import tqdm
38
- from transformers import AutoTokenizer, PretrainedConfig
39
-
40
- import diffusers
41
- from diffusers import (
42
- AutoencoderKL,
43
- ControlNetModel,
44
- DDPMScheduler,
45
- StableDiffusionControlNetPipeline,
46
- UNet2DConditionModel,
47
- UniPCMultistepScheduler,
48
- )
49
- from diffusers.optimization import get_scheduler
50
- from diffusers.utils import check_min_version, is_wandb_available
51
- from diffusers.utils.import_utils import is_xformers_available
52
-
53
-
54
- if is_wandb_available():
55
- import wandb
56
-
57
- # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
58
- check_min_version("0.15.0.dev0")
59
-
60
- logger = get_logger(__name__)
61
-
62
-
63
- def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
64
- logger.info("Running validation... ")
65
-
66
- controlnet = accelerator.unwrap_model(controlnet)
67
-
68
- pipeline = StableDiffusionControlNetPipeline.from_pretrained(
69
- args.pretrained_model_name_or_path,
70
- vae=vae,
71
- text_encoder=text_encoder,
72
- tokenizer=tokenizer,
73
- unet=unet,
74
- controlnet=controlnet,
75
- safety_checker=None,
76
- revision=args.revision,
77
- torch_dtype=weight_dtype,
78
- )
79
- pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
80
- pipeline = pipeline.to(accelerator.device)
81
- pipeline.set_progress_bar_config(disable=True)
82
-
83
- if args.enable_xformers_memory_efficient_attention:
84
- pipeline.enable_xformers_memory_efficient_attention()
85
-
86
- if args.seed is None:
87
- generator = None
88
- else:
89
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
90
-
91
- if len(args.validation_image) == len(args.validation_prompt):
92
- validation_images = args.validation_image
93
- validation_prompts = args.validation_prompt
94
- elif len(args.validation_image) == 1:
95
- validation_images = args.validation_image * len(args.validation_prompt)
96
- validation_prompts = args.validation_prompt
97
- elif len(args.validation_prompt) == 1:
98
- validation_images = args.validation_image
99
- validation_prompts = args.validation_prompt * len(args.validation_image)
100
- else:
101
- raise ValueError(
102
- "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
103
- )
104
-
105
- image_logs = []
106
-
107
- for validation_prompt, validation_image in zip(validation_prompts, validation_images):
108
- validation_image = Image.open(validation_image).convert("RGB")
109
-
110
- images = []
111
-
112
- for _ in range(args.num_validation_images):
113
- with torch.autocast("cuda"):
114
- image = pipeline(
115
- validation_prompt, validation_image, num_inference_steps=20, generator=generator
116
- ).images[0]
117
-
118
- images.append(image)
119
-
120
- image_logs.append(
121
- {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
122
- )
123
-
124
- for tracker in accelerator.trackers:
125
- if tracker.name == "tensorboard":
126
- for log in image_logs:
127
- images = log["images"]
128
- validation_prompt = log["validation_prompt"]
129
- validation_image = log["validation_image"]
130
-
131
- formatted_images = []
132
-
133
- formatted_images.append(np.asarray(validation_image))
134
-
135
- for image in images:
136
- formatted_images.append(np.asarray(image))
137
-
138
- formatted_images = np.stack(formatted_images)
139
-
140
- tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
141
- elif tracker.name == "wandb":
142
- formatted_images = []
143
-
144
- for log in image_logs:
145
- images = log["images"]
146
- validation_prompt = log["validation_prompt"]
147
- validation_image = log["validation_image"]
148
-
149
- formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
150
-
151
- for image in images:
152
- image = wandb.Image(image, caption=validation_prompt)
153
- formatted_images.append(image)
154
-
155
- tracker.log({"validation": formatted_images})
156
- else:
157
- logger.warn(f"image logging not implemented for {tracker.name}")
158
-
159
-
160
- def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
161
- text_encoder_config = PretrainedConfig.from_pretrained(
162
- pretrained_model_name_or_path,
163
- subfolder="text_encoder",
164
- revision=revision,
165
- )
166
- model_class = text_encoder_config.architectures[0]
167
-
168
- if model_class == "CLIPTextModel":
169
- from transformers import CLIPTextModel
170
-
171
- return CLIPTextModel
172
- elif model_class == "RobertaSeriesModelWithTransformation":
173
- from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
174
-
175
- return RobertaSeriesModelWithTransformation
176
- else:
177
- raise ValueError(f"{model_class} is not supported.")
178
-
179
-
180
- def parse_args(input_args=None):
181
- parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
182
- parser.add_argument(
183
- "--pretrained_model_name_or_path",
184
- type=str,
185
- default=None,
186
- required=True,
187
- help="Path to pretrained model or model identifier from huggingface.co/models.",
188
- )
189
- parser.add_argument(
190
- "--controlnet_model_name_or_path",
191
- type=str,
192
- default=None,
193
- help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
194
- " If not specified controlnet weights are initialized from unet.",
195
- )
196
- parser.add_argument(
197
- "--revision",
198
- type=str,
199
- default=None,
200
- required=False,
201
- help=(
202
- "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
203
- " float32 precision."
204
- ),
205
- )
206
- parser.add_argument(
207
- "--tokenizer_name",
208
- type=str,
209
- default=None,
210
- help="Pretrained tokenizer name or path if not the same as model_name",
211
- )
212
- parser.add_argument(
213
- "--output_dir",
214
- type=str,
215
- default="controlnet-model",
216
- help="The output directory where the model predictions and checkpoints will be written.",
217
- )
218
- parser.add_argument(
219
- "--cache_dir",
220
- type=str,
221
- default=None,
222
- help="The directory where the downloaded models and datasets will be stored.",
223
- )
224
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
225
- parser.add_argument(
226
- "--resolution",
227
- type=int,
228
- default=512,
229
- help=(
230
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
231
- " resolution"
232
- ),
233
- )
234
- parser.add_argument(
235
- "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
236
- )
237
- parser.add_argument("--num_train_epochs", type=int, default=1)
238
- parser.add_argument(
239
- "--max_train_steps",
240
- type=int,
241
- default=None,
242
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
243
- )
244
- parser.add_argument(
245
- "--checkpointing_steps",
246
- type=int,
247
- default=500,
248
- help=(
249
- "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
250
- "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
251
- "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
252
- "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
253
- "instructions."
254
- ),
255
- )
256
- parser.add_argument(
257
- "--checkpoints_total_limit",
258
- type=int,
259
- default=None,
260
- help=(
261
- "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
262
- " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
263
- " for more details"
264
- ),
265
- )
266
- parser.add_argument(
267
- "--resume_from_checkpoint",
268
- type=str,
269
- default=None,
270
- help=(
271
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
272
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
273
- ),
274
- )
275
- parser.add_argument(
276
- "--gradient_accumulation_steps",
277
- type=int,
278
- default=1,
279
- help="Number of updates steps to accumulate before performing a backward/update pass.",
280
- )
281
- parser.add_argument(
282
- "--gradient_checkpointing",
283
- action="store_true",
284
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
285
- )
286
- parser.add_argument(
287
- "--learning_rate",
288
- type=float,
289
- default=5e-6,
290
- help="Initial learning rate (after the potential warmup period) to use.",
291
- )
292
- parser.add_argument(
293
- "--scale_lr",
294
- action="store_true",
295
- default=False,
296
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
297
- )
298
- parser.add_argument(
299
- "--lr_scheduler",
300
- type=str,
301
- default="constant",
302
- help=(
303
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
304
- ' "constant", "constant_with_warmup"]'
305
- ),
306
- )
307
- parser.add_argument(
308
- "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
309
- )
310
- parser.add_argument(
311
- "--lr_num_cycles",
312
- type=int,
313
- default=1,
314
- help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
315
- )
316
- parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
317
- parser.add_argument(
318
- "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
319
- )
320
- parser.add_argument(
321
- "--dataloader_num_workers",
322
- type=int,
323
- default=0,
324
- help=(
325
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
326
- ),
327
- )
328
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
329
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
330
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
331
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
332
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
333
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
334
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
335
- parser.add_argument(
336
- "--hub_model_id",
337
- type=str,
338
- default=None,
339
- help="The name of the repository to keep in sync with the local `output_dir`.",
340
- )
341
- parser.add_argument(
342
- "--logging_dir",
343
- type=str,
344
- default="logs",
345
- help=(
346
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
347
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
348
- ),
349
- )
350
- parser.add_argument(
351
- "--allow_tf32",
352
- action="store_true",
353
- help=(
354
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
355
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
356
- ),
357
- )
358
- parser.add_argument(
359
- "--report_to",
360
- type=str,
361
- default="tensorboard",
362
- help=(
363
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
364
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
365
- ),
366
- )
367
- parser.add_argument(
368
- "--mixed_precision",
369
- type=str,
370
- default=None,
371
- choices=["no", "fp16", "bf16"],
372
- help=(
373
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
374
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
375
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
376
- ),
377
- )
378
- parser.add_argument(
379
- "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
380
- )
381
- parser.add_argument(
382
- "--set_grads_to_none",
383
- action="store_true",
384
- help=(
385
- "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
386
- " behaviors, so disable this argument if it causes any problems. More info:"
387
- " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
388
- ),
389
- )
390
- parser.add_argument(
391
- "--dataset_name",
392
- type=str,
393
- default=None,
394
- help=(
395
- "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
396
- " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
397
- " or to a folder containing files that 🤗 Datasets can understand."
398
- ),
399
- )
400
- parser.add_argument(
401
- "--dataset_config_name",
402
- type=str,
403
- default=None,
404
- help="The config of the Dataset, leave as None if there's only one config.",
405
- )
406
- parser.add_argument(
407
- "--train_data_dir",
408
- type=str,
409
- default=None,
410
- help=(
411
- "A folder containing the training data. Folder contents must follow the structure described in"
412
- " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
413
- " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
414
- ),
415
- )
416
- parser.add_argument(
417
- "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
418
- )
419
- parser.add_argument(
420
- "--conditioning_image_column",
421
- type=str,
422
- default="conditioning_image",
423
- help="The column of the dataset containing the controlnet conditioning image.",
424
- )
425
- parser.add_argument(
426
- "--caption_column",
427
- type=str,
428
- default="text",
429
- help="The column of the dataset containing a caption or a list of captions.",
430
- )
431
- parser.add_argument(
432
- "--max_train_samples",
433
- type=int,
434
- default=None,
435
- help=(
436
- "For debugging purposes or quicker training, truncate the number of training examples to this "
437
- "value if set."
438
- ),
439
- )
440
- parser.add_argument(
441
- "--proportion_empty_prompts",
442
- type=float,
443
- default=0,
444
- help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
445
- )
446
- parser.add_argument(
447
- "--validation_prompt",
448
- type=str,
449
- default=None,
450
- nargs="+",
451
- help=(
452
- "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
453
- " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
454
- " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
455
- ),
456
- )
457
- parser.add_argument(
458
- "--validation_image",
459
- type=str,
460
- default=None,
461
- nargs="+",
462
- help=(
463
- "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
464
- " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
465
- " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
466
- " `--validation_image` that will be used with all `--validation_prompt`s."
467
- ),
468
- )
469
- parser.add_argument(
470
- "--num_validation_images",
471
- type=int,
472
- default=4,
473
- help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
474
- )
475
- parser.add_argument(
476
- "--validation_steps",
477
- type=int,
478
- default=100,
479
- help=(
480
- "Run validation every X steps. Validation consists of running the prompt"
481
- " `args.validation_prompt` multiple times: `args.num_validation_images`"
482
- " and logging the images."
483
- ),
484
- )
485
- parser.add_argument(
486
- "--tracker_project_name",
487
- type=str,
488
- default="train_controlnet",
489
- required=True,
490
- help=(
491
- "The `project_name` argument passed to Accelerator.init_trackers for"
492
- " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
493
- ),
494
- )
495
-
496
- if input_args is not None:
497
- args = parser.parse_args(input_args)
498
- else:
499
- args = parser.parse_args()
500
-
501
- if args.dataset_name is None and args.train_data_dir is None:
502
- raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
503
-
504
- if args.dataset_name is not None and args.train_data_dir is not None:
505
- raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
506
-
507
- if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
508
- raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
509
-
510
- if args.validation_prompt is not None and args.validation_image is None:
511
- raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
512
-
513
- if args.validation_prompt is None and args.validation_image is not None:
514
- raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
515
-
516
- if (
517
- args.validation_image is not None
518
- and args.validation_prompt is not None
519
- and len(args.validation_image) != 1
520
- and len(args.validation_prompt) != 1
521
- and len(args.validation_image) != len(args.validation_prompt)
522
- ):
523
- raise ValueError(
524
- "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
525
- " or the same number of `--validation_prompt`s and `--validation_image`s"
526
- )
527
-
528
- return args
529
-
530
-
531
- def make_train_dataset(args, tokenizer, accelerator):
532
- # Get the datasets: you can either provide your own training and evaluation files (see below)
533
- # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
534
-
535
- # In distributed training, the load_dataset function guarantees that only one local process can concurrently
536
- # download the dataset.
537
- if args.dataset_name is not None:
538
- # Downloading and loading a dataset from the hub.
539
- dataset = load_dataset(
540
- args.dataset_name,
541
- args.dataset_config_name,
542
- cache_dir=args.cache_dir,
543
- )
544
- else:
545
- if args.train_data_dir is not None:
546
- dataset = load_dataset(
547
- args.train_data_dir,
548
- cache_dir=args.cache_dir,
549
- )
550
- # See more about loading custom images at
551
- # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
552
-
553
- # Preprocessing the datasets.
554
- # We need to tokenize inputs and targets.
555
- column_names = dataset["train"].column_names
556
-
557
- # 6. Get the column names for input/target.
558
- if args.image_column is None:
559
- image_column = column_names[0]
560
- logger.info(f"image column defaulting to {image_column}")
561
- else:
562
- image_column = args.image_column
563
- if image_column not in column_names:
564
- raise ValueError(
565
- f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
566
- )
567
-
568
- if args.caption_column is None:
569
- caption_column = column_names[1]
570
- logger.info(f"caption column defaulting to {caption_column}")
571
- else:
572
- caption_column = args.caption_column
573
- if caption_column not in column_names:
574
- raise ValueError(
575
- f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
576
- )
577
-
578
- if args.conditioning_image_column is None:
579
- conditioning_image_column = column_names[2]
580
- logger.info(f"conditioning image column defaulting to {caption_column}")
581
- else:
582
- conditioning_image_column = args.conditioning_image_column
583
- if conditioning_image_column not in column_names:
584
- raise ValueError(
585
- f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
586
- )
587
-
588
- def tokenize_captions(examples, is_train=True):
589
- captions = []
590
- for caption in examples[caption_column]:
591
- if random.random() < args.proportion_empty_prompts:
592
- captions.append("")
593
- elif isinstance(caption, str):
594
- captions.append(caption)
595
- elif isinstance(caption, (list, np.ndarray)):
596
- # take a random caption if there are multiple
597
- captions.append(random.choice(caption) if is_train else caption[0])
598
- else:
599
- raise ValueError(
600
- f"Caption column `{caption_column}` should contain either strings or lists of strings."
601
- )
602
- inputs = tokenizer(
603
- captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
604
- )
605
- return inputs.input_ids
606
-
607
- image_transforms = transforms.Compose(
608
- [
609
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
610
- transforms.ToTensor(),
611
- transforms.Normalize([0.5], [0.5]),
612
- ]
613
- )
614
-
615
- conditioning_image_transforms = transforms.Compose(
616
- [
617
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
618
- transforms.ToTensor(),
619
- ]
620
- )
621
-
622
- def preprocess_train(examples):
623
- images = [image.convert("RGB") for image in examples[image_column]]
624
- images = [image_transforms(image) for image in images]
625
-
626
- conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
627
- conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
628
-
629
- examples["pixel_values"] = images
630
- examples["conditioning_pixel_values"] = conditioning_images
631
- examples["input_ids"] = tokenize_captions(examples)
632
-
633
- return examples
634
-
635
- with accelerator.main_process_first():
636
- if args.max_train_samples is not None:
637
- dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
638
- # Set the training transforms
639
- train_dataset = dataset["train"].with_transform(preprocess_train)
640
-
641
- return train_dataset
642
-
643
-
644
- def collate_fn(examples):
645
- pixel_values = torch.stack([example["pixel_values"] for example in examples])
646
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
647
-
648
- conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
649
- conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
650
-
651
- input_ids = torch.stack([example["input_ids"] for example in examples])
652
-
653
- return {
654
- "pixel_values": pixel_values,
655
- "conditioning_pixel_values": conditioning_pixel_values,
656
- "input_ids": input_ids,
657
- }
658
-
659
-
660
- def main(args):
661
- logging_dir = Path(args.output_dir, args.logging_dir)
662
-
663
- accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
664
-
665
- accelerator = Accelerator(
666
- gradient_accumulation_steps=args.gradient_accumulation_steps,
667
- mixed_precision=args.mixed_precision,
668
- log_with=args.report_to,
669
- logging_dir=logging_dir,
670
- project_config=accelerator_project_config,
671
- )
672
-
673
- # Make one log on every process with the configuration for debugging.
674
- logging.basicConfig(
675
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
676
- datefmt="%m/%d/%Y %H:%M:%S",
677
- level=logging.INFO,
678
- )
679
- logger.info(accelerator.state, main_process_only=False)
680
- if accelerator.is_local_main_process:
681
- transformers.utils.logging.set_verbosity_warning()
682
- diffusers.utils.logging.set_verbosity_info()
683
- else:
684
- transformers.utils.logging.set_verbosity_error()
685
- diffusers.utils.logging.set_verbosity_error()
686
-
687
- # If passed along, set the training seed now.
688
- if args.seed is not None:
689
- set_seed(args.seed)
690
-
691
- # Handle the repository creation
692
- if accelerator.is_main_process:
693
- if args.output_dir is not None:
694
- os.makedirs(args.output_dir, exist_ok=True)
695
-
696
- if args.push_to_hub:
697
- repo_id = create_repo(
698
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
699
- ).repo_id
700
-
701
- # Load the tokenizer
702
- if args.tokenizer_name:
703
- tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
704
- elif args.pretrained_model_name_or_path:
705
- tokenizer = AutoTokenizer.from_pretrained(
706
- args.pretrained_model_name_or_path,
707
- subfolder="tokenizer",
708
- revision=args.revision,
709
- use_fast=False,
710
- )
711
-
712
- # import correct text encoder class
713
- text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
714
-
715
- # Load scheduler and models
716
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
717
- text_encoder = text_encoder_cls.from_pretrained(
718
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
719
- )
720
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
721
- unet = UNet2DConditionModel.from_pretrained(
722
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
723
- )
724
-
725
- if args.controlnet_model_name_or_path:
726
- logger.info("Loading existing controlnet weights")
727
- controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
728
- else:
729
- logger.info("Initializing controlnet weights from unet")
730
- controlnet = ControlNetModel.from_unet(unet)
731
-
732
- # `accelerate` 0.16.0 will have better support for customized saving
733
- if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
734
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
735
- def save_model_hook(models, weights, output_dir):
736
- i = len(weights) - 1
737
-
738
- while len(weights) > 0:
739
- weights.pop()
740
- model = models[i]
741
-
742
- sub_dir = "controlnet"
743
- model.save_pretrained(os.path.join(output_dir, sub_dir))
744
-
745
- i -= 1
746
-
747
- def load_model_hook(models, input_dir):
748
- while len(models) > 0:
749
- # pop models so that they are not loaded again
750
- model = models.pop()
751
-
752
- # load diffusers style into model
753
- load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
754
- model.register_to_config(**load_model.config)
755
-
756
- model.load_state_dict(load_model.state_dict())
757
- del load_model
758
-
759
- accelerator.register_save_state_pre_hook(save_model_hook)
760
- accelerator.register_load_state_pre_hook(load_model_hook)
761
-
762
- vae.requires_grad_(False)
763
- unet.requires_grad_(False)
764
- text_encoder.requires_grad_(False)
765
- controlnet.train()
766
-
767
- if args.enable_xformers_memory_efficient_attention:
768
- if is_xformers_available():
769
- import xformers
770
-
771
- xformers_version = version.parse(xformers.__version__)
772
- if xformers_version == version.parse("0.0.16"):
773
- logger.warn(
774
- "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
775
- )
776
- unet.enable_xformers_memory_efficient_attention()
777
- controlnet.enable_xformers_memory_efficient_attention()
778
- else:
779
- raise ValueError("xformers is not available. Make sure it is installed correctly")
780
-
781
- if args.gradient_checkpointing:
782
- controlnet.enable_gradient_checkpointing()
783
-
784
- # Check that all trainable models are in full precision
785
- low_precision_error_string = (
786
- " Please make sure to always have all model weights in full float32 precision when starting training - even if"
787
- " doing mixed precision training, copy of the weights should still be float32."
788
- )
789
-
790
- if accelerator.unwrap_model(controlnet).dtype != torch.float32:
791
- raise ValueError(
792
- f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
793
- )
794
-
795
- # Enable TF32 for faster training on Ampere GPUs,
796
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
797
- if args.allow_tf32:
798
- torch.backends.cuda.matmul.allow_tf32 = True
799
-
800
- if args.scale_lr:
801
- args.learning_rate = (
802
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
803
- )
804
-
805
- # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
806
- if args.use_8bit_adam:
807
- try:
808
- import bitsandbytes as bnb
809
- except ImportError:
810
- raise ImportError(
811
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
812
- )
813
-
814
- optimizer_class = bnb.optim.AdamW8bit
815
- else:
816
- optimizer_class = torch.optim.AdamW
817
-
818
- # Optimizer creation
819
- params_to_optimize = controlnet.parameters()
820
- optimizer = optimizer_class(
821
- params_to_optimize,
822
- lr=args.learning_rate,
823
- betas=(args.adam_beta1, args.adam_beta2),
824
- weight_decay=args.adam_weight_decay,
825
- eps=args.adam_epsilon,
826
- )
827
-
828
- train_dataset = make_train_dataset(args, tokenizer, accelerator)
829
-
830
- train_dataloader = torch.utils.data.DataLoader(
831
- train_dataset,
832
- shuffle=True,
833
- collate_fn=collate_fn,
834
- batch_size=args.train_batch_size,
835
- num_workers=args.dataloader_num_workers,
836
- )
837
-
838
- # Scheduler and math around the number of training steps.
839
- overrode_max_train_steps = False
840
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
841
- if args.max_train_steps is None:
842
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
843
- overrode_max_train_steps = True
844
-
845
- lr_scheduler = get_scheduler(
846
- args.lr_scheduler,
847
- optimizer=optimizer,
848
- num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
849
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
850
- num_cycles=args.lr_num_cycles,
851
- power=args.lr_power,
852
- )
853
-
854
- # Prepare everything with our `accelerator`.
855
- controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
856
- controlnet, optimizer, train_dataloader, lr_scheduler
857
- )
858
-
859
- # For mixed precision training we cast the text_encoder and vae weights to half-precision
860
- # as these models are only used for inference, keeping weights in full precision is not required.
861
- weight_dtype = torch.float32
862
- if accelerator.mixed_precision == "fp16":
863
- weight_dtype = torch.float16
864
- elif accelerator.mixed_precision == "bf16":
865
- weight_dtype = torch.bfloat16
866
-
867
- # Move vae, unet and text_encoder to device and cast to weight_dtype
868
- vae.to(accelerator.device, dtype=weight_dtype)
869
- unet.to(accelerator.device, dtype=weight_dtype)
870
- text_encoder.to(accelerator.device, dtype=weight_dtype)
871
-
872
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
873
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
874
- if overrode_max_train_steps:
875
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
876
- # Afterwards we recalculate our number of training epochs
877
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
878
-
879
- # We need to initialize the trackers we use, and also store our configuration.
880
- # The trackers initialize automatically on the main process.
881
- if accelerator.is_main_process:
882
- tracker_config = dict(vars(args))
883
-
884
- # tensorboard cannot handle list types for config
885
- tracker_config.pop("validation_prompt")
886
- tracker_config.pop("validation_image")
887
-
888
- accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
889
-
890
- # Train!
891
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
892
-
893
- logger.info("***** Running training *****")
894
- logger.info(f" Num examples = {len(train_dataset)}")
895
- logger.info(f" Num batches each epoch = {len(train_dataloader)}")
896
- logger.info(f" Num Epochs = {args.num_train_epochs}")
897
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
898
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
899
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
900
- logger.info(f" Total optimization steps = {args.max_train_steps}")
901
- global_step = 0
902
- first_epoch = 0
903
-
904
- # Potentially load in the weights and states from a previous save
905
- if args.resume_from_checkpoint:
906
- if args.resume_from_checkpoint != "latest":
907
- path = os.path.basename(args.resume_from_checkpoint)
908
- else:
909
- # Get the most recent checkpoint
910
- dirs = os.listdir(args.output_dir)
911
- dirs = [d for d in dirs if d.startswith("checkpoint")]
912
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
913
- path = dirs[-1] if len(dirs) > 0 else None
914
-
915
- if path is None:
916
- accelerator.print(
917
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
918
- )
919
- args.resume_from_checkpoint = None
920
- initial_global_step = 0
921
- else:
922
- accelerator.print(f"Resuming from checkpoint {path}")
923
- accelerator.load_state(os.path.join(args.output_dir, path))
924
- global_step = int(path.split("-")[1])
925
-
926
- initial_global_step = global_step * args.gradient_accumulation_steps
927
- first_epoch = global_step // num_update_steps_per_epoch
928
- else:
929
- initial_global_step = 0
930
-
931
- progress_bar = tqdm(
932
- range(0, args.max_train_steps),
933
- initial=initial_global_step,
934
- desc="Steps",
935
- # Only show the progress bar once on each machine.
936
- disable=not accelerator.is_local_main_process,
937
- )
938
-
939
- for epoch in range(first_epoch, args.num_train_epochs):
940
- for step, batch in enumerate(train_dataloader):
941
- with accelerator.accumulate(controlnet):
942
- # Convert images to latent space
943
- latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
944
- latents = latents * vae.config.scaling_factor
945
-
946
- # Sample noise that we'll add to the latents
947
- noise = torch.randn_like(latents)
948
- bsz = latents.shape[0]
949
- # Sample a random timestep for each image
950
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
951
- timesteps = timesteps.long()
952
-
953
- # Add noise to the latents according to the noise magnitude at each timestep
954
- # (this is the forward diffusion process)
955
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
956
-
957
- # Get the text embedding for conditioning
958
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
959
-
960
- controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
961
-
962
- down_block_res_samples, mid_block_res_sample = controlnet(
963
- noisy_latents,
964
- timesteps,
965
- encoder_hidden_states=encoder_hidden_states,
966
- controlnet_cond=controlnet_image,
967
- return_dict=False,
968
- )
969
-
970
- # Predict the noise residual
971
- model_pred = unet(
972
- noisy_latents,
973
- timesteps,
974
- encoder_hidden_states=encoder_hidden_states,
975
- down_block_additional_residuals=down_block_res_samples,
976
- mid_block_additional_residual=mid_block_res_sample,
977
- ).sample
978
-
979
- # Get the target for loss depending on the prediction type
980
- if noise_scheduler.config.prediction_type == "epsilon":
981
- target = noise
982
- elif noise_scheduler.config.prediction_type == "v_prediction":
983
- target = noise_scheduler.get_velocity(latents, noise, timesteps)
984
- else:
985
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
986
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
987
-
988
- accelerator.backward(loss)
989
- if accelerator.sync_gradients:
990
- params_to_clip = controlnet.parameters()
991
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
992
- optimizer.step()
993
- lr_scheduler.step()
994
- optimizer.zero_grad(set_to_none=args.set_grads_to_none)
995
-
996
- # Checks if the accelerator has performed an optimization step behind the scenes
997
- if accelerator.sync_gradients:
998
- progress_bar.update(1)
999
- global_step += 1
1000
-
1001
- if accelerator.is_main_process:
1002
- if global_step % args.checkpointing_steps == 0:
1003
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1004
- accelerator.save_state(save_path)
1005
- logger.info(f"Saved state to {save_path}")
1006
-
1007
- if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1008
- log_validation(
1009
- vae,
1010
- text_encoder,
1011
- tokenizer,
1012
- unet,
1013
- controlnet,
1014
- args,
1015
- accelerator,
1016
- weight_dtype,
1017
- global_step,
1018
- )
1019
-
1020
- logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1021
- progress_bar.set_postfix(**logs)
1022
- accelerator.log(logs, step=global_step)
1023
-
1024
- if global_step >= args.max_train_steps:
1025
- break
1026
-
1027
- # Create the pipeline using the trained modules and save it.
1028
- accelerator.wait_for_everyone()
1029
- if accelerator.is_main_process:
1030
- controlnet = accelerator.unwrap_model(controlnet)
1031
- controlnet.save_pretrained(args.output_dir)
1032
-
1033
- if args.push_to_hub:
1034
- upload_folder(
1035
- repo_id=repo_id,
1036
- folder_path=args.output_dir,
1037
- commit_message="End of training",
1038
- ignore_patterns=["step_*", "epoch_*"],
1039
- )
1040
-
1041
- accelerator.end_training()
1042
-
1043
-
1044
- if __name__ == "__main__":
1045
- args = parse_args()
1046
- main(args)
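
The heart of the training loop deleted above is an epsilon-prediction objective in latent space. Below is a minimal sketch of just that step, using random tensors in place of the VAE latents and a placeholder callable in place of the UNet/ControlNet forward pass; both are assumptions for illustration, not code from the diff, and only `torch` plus `diffusers`' `DDPMScheduler` are required.

import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

latents = torch.randn(2, 4, 64, 64)   # stands in for vae.encode(pixel_values).latent_dist.sample() * scaling_factor
noise = torch.randn_like(latents)     # the "epsilon" the model must recover
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],)).long()
noisy_latents = scheduler.add_noise(latents, noise, timesteps)

def dummy_model(noisy, t):
    # Placeholder for unet(noisy, t, encoder_hidden_states=..., plus ControlNet residuals).sample
    return torch.zeros_like(noisy)

model_pred = dummy_model(noisy_latents, timesteps)
loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
print(f"toy epsilon-prediction loss: {loss.item():.4f}")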
 
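The `log_validation` helper removed above pairs validation prompts with conditioning images by broadcasting whichever side has a single entry. A standalone sketch of that pairing rule, with a hypothetical helper name, runnable with plain Python:

def pair_validation_inputs(validation_images, validation_prompts):
    # Broadcasting rule used for validation: equal lengths pair element-wise,
    # a single image is reused for every prompt, a single prompt for every image.
    if len(validation_images) == len(validation_prompts):
        return list(zip(validation_prompts, validation_images))
    if len(validation_images) == 1:
        return [(prompt, validation_images[0]) for prompt in validation_prompts]
    if len(validation_prompts) == 1:
        return [(validation_prompts[0], image) for image in validation_images]
    raise ValueError("Provide 1 image, 1 prompt, or the same number of both.")

# One conditioning image reused for two prompts:
print(pair_validation_inputs(["cond.png"], ["a red house", "a blue house"]))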
 
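After training, the script saves the adapter with `controlnet.save_pretrained(args.output_dir)`. A hedged sketch of reloading those weights for inference, mirroring how `log_validation` builds its pipeline; the local path, base model id, and conditioning image file are placeholders, and a CUDA device with Stable Diffusion v1.5 weights is assumed.

import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler

controlnet = ControlNetModel.from_pretrained("./controlnet-model", torch_dtype=torch.float16)  # placeholder output_dir
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",   # assumed base model, swap for the one used in training
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=torch.float16,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

conditioning = Image.open("cond.png").convert("RGB")  # placeholder conditioning image
image = pipe("a prompt describing the target image", conditioning, num_inference_steps=20).images[0]
image.save("out.png")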
diffusers/examples/controlnet/train_controlnet_flax.py DELETED
@@ -1,1015 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
-
16
- import argparse
17
- import logging
18
- import math
19
- import os
20
- import random
21
- from pathlib import Path
22
-
23
- import jax
24
- import jax.numpy as jnp
25
- import numpy as np
26
- import optax
27
- import torch
28
- import torch.utils.checkpoint
29
- import transformers
30
- from datasets import load_dataset
31
- from flax import jax_utils
32
- from flax.core.frozen_dict import unfreeze
33
- from flax.training import train_state
34
- from flax.training.common_utils import shard
35
- from huggingface_hub import create_repo, upload_folder
36
- from PIL import Image
37
- from torch.utils.data import IterableDataset
38
- from torchvision import transforms
39
- from tqdm.auto import tqdm
40
- from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
41
-
42
- from diffusers import (
43
- FlaxAutoencoderKL,
44
- FlaxControlNetModel,
45
- FlaxDDPMScheduler,
46
- FlaxStableDiffusionControlNetPipeline,
47
- FlaxUNet2DConditionModel,
48
- )
49
- from diffusers.utils import check_min_version, is_wandb_available
50
-
51
-
52
- if is_wandb_available():
53
- import wandb
54
-
55
- # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
56
- check_min_version("0.15.0.dev0")
57
-
58
- logger = logging.getLogger(__name__)
59
-
60
-
61
- def image_grid(imgs, rows, cols):
62
- assert len(imgs) == rows * cols
63
-
64
- w, h = imgs[0].size
65
- grid = Image.new("RGB", size=(cols * w, rows * h))
66
- grid_w, grid_h = grid.size
67
-
68
- for i, img in enumerate(imgs):
69
- grid.paste(img, box=(i % cols * w, i // cols * h))
70
- return grid
71
-
72
-
73
- def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype):
74
- logger.info("Running validation... ")
75
-
76
- pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
77
- args.pretrained_model_name_or_path,
78
- tokenizer=tokenizer,
79
- controlnet=controlnet,
80
- safety_checker=None,
81
- dtype=weight_dtype,
82
- revision=args.revision,
83
- from_pt=args.from_pt,
84
- )
85
- params = jax_utils.replicate(params)
86
- params["controlnet"] = controlnet_params
87
-
88
- num_samples = jax.device_count()
89
- prng_seed = jax.random.split(rng, jax.device_count())
90
-
91
- if len(args.validation_image) == len(args.validation_prompt):
92
- validation_images = args.validation_image
93
- validation_prompts = args.validation_prompt
94
- elif len(args.validation_image) == 1:
95
- validation_images = args.validation_image * len(args.validation_prompt)
96
- validation_prompts = args.validation_prompt
97
- elif len(args.validation_prompt) == 1:
98
- validation_images = args.validation_image
99
- validation_prompts = args.validation_prompt * len(args.validation_image)
100
- else:
101
- raise ValueError(
102
- "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
103
- )
104
-
105
- image_logs = []
106
-
107
- for validation_prompt, validation_image in zip(validation_prompts, validation_images):
108
- prompts = num_samples * [validation_prompt]
109
- prompt_ids = pipeline.prepare_text_inputs(prompts)
110
- prompt_ids = shard(prompt_ids)
111
-
112
- validation_image = Image.open(validation_image).convert("RGB")
113
- processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
114
- processed_image = shard(processed_image)
115
- images = pipeline(
116
- prompt_ids=prompt_ids,
117
- image=processed_image,
118
- params=params,
119
- prng_seed=prng_seed,
120
- num_inference_steps=50,
121
- jit=True,
122
- ).images
123
-
124
- images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
125
- images = pipeline.numpy_to_pil(images)
126
-
127
- image_logs.append(
128
- {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
129
- )
130
-
131
- if args.report_to == "wandb":
132
- formatted_images = []
133
- for log in image_logs:
134
- images = log["images"]
135
- validation_prompt = log["validation_prompt"]
136
- validation_image = log["validation_image"]
137
-
138
- formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
139
- for image in images:
140
- image = wandb.Image(image, caption=validation_prompt)
141
- formatted_images.append(image)
142
-
143
- wandb.log({"validation": formatted_images})
144
- else:
145
- logger.warn(f"image logging not implemented for {args.report_to}")
146
-
147
- return image_logs
148
-
149
-
150
- def save_model_card(repo_id: str, image_logs=None, base_model: str = "", repo_folder=None):
151
- img_str = ""
152
- for i, log in enumerate(image_logs):
153
- images = log["images"]
154
- validation_prompt = log["validation_prompt"]
155
- validation_image = log["validation_image"]
156
- validation_image.save(os.path.join(repo_folder, "image_control.png"))
157
- img_str += f"prompt: {validation_prompt}\n"
158
- images = [validation_image] + images
159
- image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
160
- img_str += f"![images_{i})](./images_{i}.png)\n"
161
-
162
- yaml = f"""
163
- ---
164
- license: creativeml-openrail-m
165
- base_model: {base_model}
166
- tags:
167
- - stable-diffusion
168
- - stable-diffusion-diffusers
169
- - text-to-image
170
- - diffusers
171
- - controlnet
172
- inference: true
173
- ---
174
- """
175
- model_card = f"""
176
- # controlnet-{repo_id}
177
-
178
- These are ControlNet weights trained on {base_model} with a new type of conditioning. You can find some example images below.\n
179
- {img_str}
180
- """
181
- with open(os.path.join(repo_folder, "README.md"), "w") as f:
182
- f.write(yaml + model_card)
183
-
184
-
185
- def parse_args():
186
- parser = argparse.ArgumentParser(description="Simple example of a training script.")
187
- parser.add_argument(
188
- "--pretrained_model_name_or_path",
189
- type=str,
190
- required=True,
191
- help="Path to pretrained model or model identifier from huggingface.co/models.",
192
- )
193
- parser.add_argument(
194
- "--controlnet_model_name_or_path",
195
- type=str,
196
- default=None,
197
- help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
198
- " If not specified controlnet weights are initialized from unet.",
199
- )
200
- parser.add_argument(
201
- "--revision",
202
- type=str,
203
- default=None,
204
- help="Revision of pretrained model identifier from huggingface.co/models.",
205
- )
206
- parser.add_argument(
207
- "--from_pt",
208
- action="store_true",
209
- help="Load the pretrained model from a PyTorch checkpoint.",
210
- )
211
- parser.add_argument(
212
- "--tokenizer_name",
213
- type=str,
214
- default=None,
215
- help="Pretrained tokenizer name or path if not the same as model_name",
216
- )
217
- parser.add_argument(
218
- "--output_dir",
219
- type=str,
220
- default="controlnet-model",
221
- help="The output directory where the model predictions and checkpoints will be written.",
222
- )
223
- parser.add_argument(
224
- "--cache_dir",
225
- type=str,
226
- default=None,
227
- help="The directory where the downloaded models and datasets will be stored.",
228
- )
229
- parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
230
- parser.add_argument(
231
- "--resolution",
232
- type=int,
233
- default=512,
234
- help=(
235
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
236
- " resolution"
237
- ),
238
- )
239
- parser.add_argument(
240
- "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
241
- )
242
- parser.add_argument("--num_train_epochs", type=int, default=100)
243
- parser.add_argument(
244
- "--max_train_steps",
245
- type=int,
246
- default=None,
247
- help="Total number of training steps to perform.",
248
- )
249
- parser.add_argument(
250
- "--learning_rate",
251
- type=float,
252
- default=1e-4,
253
- help="Initial learning rate (after the potential warmup period) to use.",
254
- )
255
- parser.add_argument(
256
- "--scale_lr",
257
- action="store_true",
258
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
259
- )
260
- parser.add_argument(
261
- "--lr_scheduler",
262
- type=str,
263
- default="constant",
264
- help=(
265
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
266
- ' "constant", "constant_with_warmup"]'
267
- ),
268
- )
269
- parser.add_argument(
270
- "--dataloader_num_workers",
271
- type=int,
272
- default=0,
273
- help=(
274
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
275
- ),
276
- )
277
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
278
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
279
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
280
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
281
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
282
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
283
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
284
- parser.add_argument(
285
- "--hub_model_id",
286
- type=str,
287
- default=None,
288
- help="The name of the repository to keep in sync with the local `output_dir`.",
289
- )
290
- parser.add_argument(
291
- "--logging_dir",
292
- type=str,
293
- default="logs",
294
- help=(
295
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
296
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
297
- ),
298
- )
299
- parser.add_argument(
300
- "--logging_steps",
301
- type=int,
302
- default=100,
303
- help=("log training metric every X steps to `--report_t`"),
304
- )
305
- parser.add_argument(
306
- "--report_to",
307
- type=str,
308
- default="tensorboard",
309
- help=(
310
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
311
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
312
- ),
313
- )
314
- parser.add_argument(
315
- "--mixed_precision",
316
- type=str,
317
- default="no",
318
- choices=["no", "fp16", "bf16"],
319
- help=(
320
- "Whether to use mixed precision. Choose"
321
- "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
322
- "and an Nvidia Ampere GPU."
323
- ),
324
- )
325
- parser.add_argument(
326
- "--dataset_name",
327
- type=str,
328
- default=None,
329
- help=(
330
- "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
331
- " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
332
- " or to a folder containing files that 🤗 Datasets can understand."
333
- ),
334
- )
335
- parser.add_argument("--streaming", action="store_true", help="To stream a large dataset from Hub.")
336
- parser.add_argument(
337
- "--dataset_config_name",
338
- type=str,
339
- default=None,
340
- help="The config of the Dataset, leave as None if there's only one config.",
341
- )
342
- parser.add_argument(
343
- "--train_data_dir",
344
- type=str,
345
- default=None,
346
- help=(
347
- "A folder containing the training data. Folder contents must follow the structure described in"
348
- " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
349
- " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
350
- ),
351
- )
352
- parser.add_argument(
353
- "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
354
- )
355
- parser.add_argument(
356
- "--conditioning_image_column",
357
- type=str,
358
- default="conditioning_image",
359
- help="The column of the dataset containing the controlnet conditioning image.",
360
- )
361
- parser.add_argument(
362
- "--caption_column",
363
- type=str,
364
- default="text",
365
- help="The column of the dataset containing a caption or a list of captions.",
366
- )
367
- parser.add_argument(
368
- "--max_train_samples",
369
- type=int,
370
- default=None,
371
- help=(
372
- "For debugging purposes or quicker training, truncate the number of training examples to this "
373
- "value if set. Needed if `streaming` is set to True."
374
- ),
375
- )
376
- parser.add_argument(
377
- "--proportion_empty_prompts",
378
- type=float,
379
- default=0,
380
- help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
381
- )
382
- parser.add_argument(
383
- "--validation_prompt",
384
- type=str,
385
- default=None,
386
- nargs="+",
387
- help=(
388
- "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
389
- " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
390
- " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
391
- ),
392
- )
393
- parser.add_argument(
394
- "--validation_image",
395
- type=str,
396
- default=None,
397
- nargs="+",
398
- help=(
399
- "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
400
- " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
401
- " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
402
- " `--validation_image` that will be used with all `--validation_prompt`s."
403
- ),
404
- )
405
- parser.add_argument(
406
- "--validation_steps",
407
- type=int,
408
- default=100,
409
- help=(
410
- "Run validation every X steps. Validation consists of running the prompt"
411
- " `args.validation_prompt` and logging the images."
412
- ),
413
- )
414
- parser.add_argument(
415
- "--tracker_project_name",
416
- type=str,
417
- default="train_controlnet_flax",
418
- help=("The `project` argument passed to wandb"),
419
- )
420
- parser.add_argument(
421
- "--gradient_accumulation_steps", type=int, default=1, help="Number of steps to accumulate gradients over"
422
- )
423
- parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
424
-
425
- args = parser.parse_args()
426
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
427
- if env_local_rank != -1 and env_local_rank != args.local_rank:
428
- args.local_rank = env_local_rank
429
-
430
- # Sanity checks
431
- if args.dataset_name is None and args.train_data_dir is None:
432
- raise ValueError("Need either a dataset name or a training folder.")
433
- if args.dataset_name is not None and args.train_data_dir is not None:
434
- raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
435
-
436
- if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
437
- raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
438
-
439
- if args.validation_prompt is not None and args.validation_image is None:
440
- raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
441
-
442
- if args.validation_prompt is None and args.validation_image is not None:
443
- raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
444
-
445
- if (
446
- args.validation_image is not None
447
- and args.validation_prompt is not None
448
- and len(args.validation_image) != 1
449
- and len(args.validation_prompt) != 1
450
- and len(args.validation_image) != len(args.validation_prompt)
451
- ):
452
- raise ValueError(
453
- "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
454
- " or the same number of `--validation_prompt`s and `--validation_image`s"
455
- )
456
-
457
- # This idea comes from
458
- # https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370
459
- if args.streaming and args.max_train_samples is None:
460
- raise ValueError("You must specify `max_train_samples` when using dataset streaming.")
461
-
462
- return args
463
-
464
-
465
- def make_train_dataset(args, tokenizer, batch_size=None):
466
- # Get the datasets: you can either provide your own training and evaluation files (see below)
467
- # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
468
-
469
- # In distributed training, the load_dataset function guarantees that only one local process can concurrently
470
- # download the dataset.
471
- if args.dataset_name is not None:
472
- # Downloading and loading a dataset from the hub.
473
- dataset = load_dataset(
474
- args.dataset_name,
475
- args.dataset_config_name,
476
- cache_dir=args.cache_dir,
477
- streaming=args.streaming,
478
- )
479
- else:
480
- if args.train_data_dir is not None:
481
- dataset = load_dataset(
482
- args.train_data_dir,
483
- cache_dir=args.cache_dir,
484
- )
485
- # See more about loading custom images at
486
- # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
487
-
488
- # Preprocessing the datasets.
489
- # We need to tokenize inputs and targets.
490
- if isinstance(dataset["train"], IterableDataset):
491
- column_names = next(iter(dataset["train"])).keys()
492
- else:
493
- column_names = dataset["train"].column_names
494
-
495
- # 6. Get the column names for input/target.
496
- if args.image_column is None:
497
- image_column = column_names[0]
498
- logger.info(f"image column defaulting to {image_column}")
499
- else:
500
- image_column = args.image_column
501
- if image_column not in column_names:
502
- raise ValueError(
503
- f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
504
- )
505
-
506
- if args.caption_column is None:
507
- caption_column = column_names[1]
508
- logger.info(f"caption column defaulting to {caption_column}")
509
- else:
510
- caption_column = args.caption_column
511
- if caption_column not in column_names:
512
- raise ValueError(
513
- f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
514
- )
515
-
516
- if args.conditioning_image_column is None:
517
- conditioning_image_column = column_names[2]
518
- logger.info(f"conditioning image column defaulting to {caption_column}")
519
- else:
520
- conditioning_image_column = args.conditioning_image_column
521
- if conditioning_image_column not in column_names:
522
- raise ValueError(
523
- f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
524
- )
525
-
526
- def tokenize_captions(examples, is_train=True):
527
- captions = []
528
- for caption in examples[caption_column]:
529
- if random.random() < args.proportion_empty_prompts:
530
- captions.append("")
531
- elif isinstance(caption, str):
532
- captions.append(caption)
533
- elif isinstance(caption, (list, np.ndarray)):
534
- # take a random caption if there are multiple
535
- captions.append(random.choice(caption) if is_train else caption[0])
536
- else:
537
- raise ValueError(
538
- f"Caption column `{caption_column}` should contain either strings or lists of strings."
539
- )
540
- inputs = tokenizer(
541
- captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
542
- )
543
- return inputs.input_ids
544
-
545
- image_transforms = transforms.Compose(
546
- [
547
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
548
- transforms.ToTensor(),
549
- transforms.Normalize([0.5], [0.5]),
550
- ]
551
- )
552
-
553
- conditioning_image_transforms = transforms.Compose(
554
- [
555
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
556
- transforms.ToTensor(),
557
- ]
558
- )
559
-
560
- def preprocess_train(examples):
561
- images = [image.convert("RGB") for image in examples[image_column]]
562
- images = [image_transforms(image) for image in images]
563
-
564
- conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
565
- conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
566
-
567
- examples["pixel_values"] = images
568
- examples["conditioning_pixel_values"] = conditioning_images
569
- examples["input_ids"] = tokenize_captions(examples)
570
-
571
- return examples
572
-
573
- if jax.process_index() == 0:
574
- if args.max_train_samples is not None:
575
- if args.streaming:
576
- dataset["train"] = dataset["train"].shuffle(seed=args.seed).take(args.max_train_samples)
577
- else:
578
- dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
579
- # Set the training transforms
580
- if args.streaming:
581
- train_dataset = dataset["train"].map(
582
- preprocess_train,
583
- batched=True,
584
- batch_size=batch_size,
585
- remove_columns=list(dataset["train"].features.keys()),
586
- )
587
- else:
588
- train_dataset = dataset["train"].with_transform(preprocess_train)
589
-
590
- return train_dataset
591
-
592
-
593
- def collate_fn(examples):
594
- pixel_values = torch.stack([example["pixel_values"] for example in examples])
595
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
596
-
597
- conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
598
- conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
599
-
600
- input_ids = torch.stack([example["input_ids"] for example in examples])
601
-
602
- batch = {
603
- "pixel_values": pixel_values,
604
- "conditioning_pixel_values": conditioning_pixel_values,
605
- "input_ids": input_ids,
606
- }
607
- batch = {k: v.numpy() for k, v in batch.items()}
608
- return batch
609
-
610
-
611
- def get_params_to_save(params):
612
- return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
613
-
614
-
615
- def main():
616
- args = parse_args()
617
-
618
- logging.basicConfig(
619
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
620
- datefmt="%m/%d/%Y %H:%M:%S",
621
- level=logging.INFO,
622
- )
623
- # Setup logging, we only want one process per machine to log things on the screen.
624
- logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
625
- if jax.process_index() == 0:
626
- transformers.utils.logging.set_verbosity_info()
627
- else:
628
- transformers.utils.logging.set_verbosity_error()
629
-
630
- # wandb init
631
- if jax.process_index() == 0 and args.report_to == "wandb":
632
- wandb.init(
633
- project=args.tracker_project_name,
634
- job_type="train",
635
- config=args,
636
- )
637
-
638
- if args.seed is not None:
639
- set_seed(args.seed)
640
-
641
- rng = jax.random.PRNGKey(0)
642
-
643
- # Handle the repository creation
644
- if jax.process_index() == 0:
645
- if args.output_dir is not None:
646
- os.makedirs(args.output_dir, exist_ok=True)
647
-
648
- if args.push_to_hub:
649
- repo_id = create_repo(
650
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
651
- ).repo_id
652
-
653
- # Load the tokenizer and add the placeholder token as an additional special token
654
- if args.tokenizer_name:
655
- tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
656
- elif args.pretrained_model_name_or_path:
657
- tokenizer = CLIPTokenizer.from_pretrained(
658
- args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
659
- )
660
- else:
661
- raise NotImplementedError("No tokenizer specified!")
662
-
663
- # Compute the total batch size across local devices and build the training dataset / dataloader
664
- total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps
665
- train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size)
666
-
667
- train_dataloader = torch.utils.data.DataLoader(
668
- train_dataset,
669
- shuffle=not args.streaming,
670
- collate_fn=collate_fn,
671
- batch_size=total_train_batch_size,
672
- num_workers=args.dataloader_num_workers,
673
- drop_last=True,
674
- )
675
-
676
- weight_dtype = jnp.float32
677
- if args.mixed_precision == "fp16":
678
- weight_dtype = jnp.float16
679
- elif args.mixed_precision == "bf16":
680
- weight_dtype = jnp.bfloat16
681
-
682
- # Load models and create wrapper for stable diffusion
683
- text_encoder = FlaxCLIPTextModel.from_pretrained(
684
- args.pretrained_model_name_or_path,
685
- subfolder="text_encoder",
686
- dtype=weight_dtype,
687
- revision=args.revision,
688
- from_pt=args.from_pt,
689
- )
690
- vae, vae_params = FlaxAutoencoderKL.from_pretrained(
691
- args.pretrained_model_name_or_path,
692
- revision=args.revision,
693
- subfolder="vae",
694
- dtype=weight_dtype,
695
- from_pt=args.from_pt,
696
- )
697
- unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
698
- args.pretrained_model_name_or_path,
699
- subfolder="unet",
700
- dtype=weight_dtype,
701
- revision=args.revision,
702
- from_pt=args.from_pt,
703
- )
704
-
705
- if args.controlnet_model_name_or_path:
706
- logger.info("Loading existing controlnet weights")
707
- controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
708
- args.controlnet_model_name_or_path, from_pt=True, dtype=jnp.float32
709
- )
710
- else:
711
- logger.info("Initializing controlnet weights from unet")
712
- rng, rng_params = jax.random.split(rng)
713
-
714
- controlnet = FlaxControlNetModel(
715
- in_channels=unet.config.in_channels,
716
- down_block_types=unet.config.down_block_types,
717
- only_cross_attention=unet.config.only_cross_attention,
718
- block_out_channels=unet.config.block_out_channels,
719
- layers_per_block=unet.config.layers_per_block,
720
- attention_head_dim=unet.config.attention_head_dim,
721
- cross_attention_dim=unet.config.cross_attention_dim,
722
- use_linear_projection=unet.config.use_linear_projection,
723
- flip_sin_to_cos=unet.config.flip_sin_to_cos,
724
- freq_shift=unet.config.freq_shift,
725
- )
726
- controlnet_params = controlnet.init_weights(rng=rng_params)
727
- controlnet_params = unfreeze(controlnet_params)
728
- for key in [
729
- "conv_in",
730
- "time_embedding",
731
- "down_blocks_0",
732
- "down_blocks_1",
733
- "down_blocks_2",
734
- "down_blocks_3",
735
- "mid_block",
736
- ]:
737
- controlnet_params[key] = unet_params[key]
738
-
739
- # Optimization
740
- if args.scale_lr:
741
- args.learning_rate = args.learning_rate * total_train_batch_size
742
-
743
- constant_scheduler = optax.constant_schedule(args.learning_rate)
744
-
745
- adamw = optax.adamw(
746
- learning_rate=constant_scheduler,
747
- b1=args.adam_beta1,
748
- b2=args.adam_beta2,
749
- eps=args.adam_epsilon,
750
- weight_decay=args.adam_weight_decay,
751
- )
752
-
753
- optimizer = optax.chain(
754
- optax.clip_by_global_norm(args.max_grad_norm),
755
- adamw,
756
- )
757
-
758
- state = train_state.TrainState.create(apply_fn=controlnet.__call__, params=controlnet_params, tx=optimizer)
759
-
760
- noise_scheduler, noise_scheduler_state = FlaxDDPMScheduler.from_pretrained(
761
- args.pretrained_model_name_or_path, subfolder="scheduler"
762
- )
763
-
764
- # Initialize our training
765
- validation_rng, train_rngs = jax.random.split(rng)
766
- train_rngs = jax.random.split(train_rngs, jax.local_device_count())
767
-
768
- def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng):
769
- # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1
770
- if args.gradient_accumulation_steps > 1:
771
- grad_steps = args.gradient_accumulation_steps
772
- batch = jax.tree_map(lambda x: x.reshape((grad_steps, x.shape[0] // grad_steps) + x.shape[1:]), batch)
773
-
774
- def compute_loss(params, minibatch, sample_rng):
775
- # Convert images to latent space
776
- vae_outputs = vae.apply(
777
- {"params": vae_params}, minibatch["pixel_values"], deterministic=True, method=vae.encode
778
- )
779
- latents = vae_outputs.latent_dist.sample(sample_rng)
780
- # (NHWC) -> (NCHW)
781
- latents = jnp.transpose(latents, (0, 3, 1, 2))
782
- latents = latents * vae.config.scaling_factor
783
-
784
- # Sample noise that we'll add to the latents
785
- noise_rng, timestep_rng = jax.random.split(sample_rng)
786
- noise = jax.random.normal(noise_rng, latents.shape)
787
- # Sample a random timestep for each image
788
- bsz = latents.shape[0]
789
- timesteps = jax.random.randint(
790
- timestep_rng,
791
- (bsz,),
792
- 0,
793
- noise_scheduler.config.num_train_timesteps,
794
- )
795
-
796
- # Add noise to the latents according to the noise magnitude at each timestep
797
- # (this is the forward diffusion process)
798
- noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
799
-
800
- # Get the text embedding for conditioning
801
- encoder_hidden_states = text_encoder(
802
- minibatch["input_ids"],
803
- params=text_encoder_params,
804
- train=False,
805
- )[0]
806
-
807
- controlnet_cond = minibatch["conditioning_pixel_values"]
808
-
809
- # Predict the noise residual and compute loss
810
- down_block_res_samples, mid_block_res_sample = controlnet.apply(
811
- {"params": params},
812
- noisy_latents,
813
- timesteps,
814
- encoder_hidden_states,
815
- controlnet_cond,
816
- train=True,
817
- return_dict=False,
818
- )
819
-
820
- model_pred = unet.apply(
821
- {"params": unet_params},
822
- noisy_latents,
823
- timesteps,
824
- encoder_hidden_states,
825
- down_block_additional_residuals=down_block_res_samples,
826
- mid_block_additional_residual=mid_block_res_sample,
827
- ).sample
828
-
829
- # Get the target for loss depending on the prediction type
830
- if noise_scheduler.config.prediction_type == "epsilon":
831
- target = noise
832
- elif noise_scheduler.config.prediction_type == "v_prediction":
833
- target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
834
- else:
835
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
836
-
837
- loss = (target - model_pred) ** 2
838
- loss = loss.mean()
839
-
840
- return loss
841
-
842
- grad_fn = jax.value_and_grad(compute_loss)
843
-
844
- # get a minibatch (one gradient accumulation slice)
845
- def get_minibatch(batch, grad_idx):
846
- return jax.tree_util.tree_map(
847
- lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
848
- batch,
849
- )
850
-
851
- def loss_and_grad(grad_idx, train_rng):
852
- # create minibatch for the grad step
853
- minibatch = get_minibatch(batch, grad_idx) if grad_idx is not None else batch
854
- sample_rng, train_rng = jax.random.split(train_rng, 2)
855
- loss, grad = grad_fn(state.params, minibatch, sample_rng)
856
- return loss, grad, train_rng
857
-
858
- if args.gradient_accumulation_steps == 1:
859
- loss, grad, new_train_rng = loss_and_grad(None, train_rng)
860
- else:
861
- init_loss_grad_rng = (
862
- 0.0, # initial value for cumul_loss
863
- jax.tree_map(jnp.zeros_like, state.params), # initial value for cumul_grad
864
- train_rng, # initial value for train_rng
865
- )
866
-
867
- def cumul_grad_step(grad_idx, loss_grad_rng):
868
- cumul_loss, cumul_grad, train_rng = loss_grad_rng
869
- loss, grad, new_train_rng = loss_and_grad(grad_idx, train_rng)
870
- cumul_loss, cumul_grad = jax.tree_map(jnp.add, (cumul_loss, cumul_grad), (loss, grad))
871
- return cumul_loss, cumul_grad, new_train_rng
872
-
873
- loss, grad, new_train_rng = jax.lax.fori_loop(
874
- 0,
875
- args.gradient_accumulation_steps,
876
- cumul_grad_step,
877
- init_loss_grad_rng,
878
- )
879
- loss, grad = jax.tree_map(lambda x: x / args.gradient_accumulation_steps, (loss, grad))
880
-
881
- grad = jax.lax.pmean(grad, "batch")
882
-
883
- new_state = state.apply_gradients(grads=grad)
884
-
885
- metrics = {"loss": loss}
886
- metrics = jax.lax.pmean(metrics, axis_name="batch")
887
-
888
- return new_state, metrics, new_train_rng
889
-
890
- # Create parallel version of the train step
891
- p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
892
-
893
- # Replicate the train state on each device
894
- state = jax_utils.replicate(state)
895
- unet_params = jax_utils.replicate(unet_params)
896
- text_encoder_params = jax_utils.replicate(text_encoder.params)
897
- vae_params = jax_utils.replicate(vae_params)
898
-
899
- # Train!
900
- if args.streaming:
901
- dataset_length = args.max_train_samples
902
- else:
903
- dataset_length = len(train_dataloader)
904
- num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps)
905
-
906
- # Scheduler and math around the number of training steps.
907
- if args.max_train_steps is None:
908
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
909
-
910
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
911
-
912
- logger.info("***** Running training *****")
913
- logger.info(f" Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}")
914
- logger.info(f" Num Epochs = {args.num_train_epochs}")
915
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
916
- logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
917
- logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")
918
-
919
- if jax.process_index() == 0:
920
- wandb.define_metric("*", step_metric="train/step")
921
- wandb.config.update(
922
- {
923
- "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
924
- "total_train_batch_size": total_train_batch_size,
925
- "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
926
- "num_devices": jax.device_count(),
927
- }
928
- )
929
-
930
- global_step = 0
931
- epochs = tqdm(
932
- range(args.num_train_epochs),
933
- desc="Epoch ... ",
934
- position=0,
935
- disable=jax.process_index() > 0,
936
- )
937
- for epoch in epochs:
938
- # ======================== Training ================================
939
-
940
- train_metrics = []
941
-
942
- steps_per_epoch = (
943
- args.max_train_samples // total_train_batch_size
944
- if args.streaming
945
- else len(train_dataset) // total_train_batch_size
946
- )
947
- train_step_progress_bar = tqdm(
948
- total=steps_per_epoch,
949
- desc="Training...",
950
- position=1,
951
- leave=False,
952
- disable=jax.process_index() > 0,
953
- )
954
- # train
955
- for batch in train_dataloader:
956
- batch = shard(batch)
957
- state, train_metric, train_rngs = p_train_step(
958
- state, unet_params, text_encoder_params, vae_params, batch, train_rngs
959
- )
960
- train_metrics.append(train_metric)
961
-
962
- train_step_progress_bar.update(1)
963
-
964
- global_step += 1
965
- if global_step >= args.max_train_steps:
966
- break
967
-
968
- if (
969
- args.validation_prompt is not None
970
- and global_step % args.validation_steps == 0
971
- and jax.process_index() == 0
972
- ):
973
- _ = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
974
-
975
- if global_step % args.logging_steps == 0 and jax.process_index() == 0:
976
- if args.report_to == "wandb":
977
- wandb.log(
978
- {
979
- "train/step": global_step,
980
- "train/epoch": epoch,
981
- "train/loss": jax_utils.unreplicate(train_metric)["loss"],
982
- }
983
- )
984
-
985
- train_metric = jax_utils.unreplicate(train_metric)
986
- train_step_progress_bar.close()
987
- epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
988
-
989
- # Create the pipeline using the trained modules and save it.
990
- if jax.process_index() == 0:
991
- if args.validation_prompt is not None:
992
- image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
993
-
994
- controlnet.save_pretrained(
995
- args.output_dir,
996
- params=get_params_to_save(state.params),
997
- )
998
-
999
- if args.push_to_hub:
1000
- save_model_card(
1001
- repo_id,
1002
- image_logs=image_logs,
1003
- base_model=args.pretrained_model_name_or_path,
1004
- repo_folder=args.output_dir,
1005
- )
1006
- upload_folder(
1007
- repo_id=repo_id,
1008
- folder_path=args.output_dir,
1009
- commit_message="End of training",
1010
- ignore_patterns=["step_*", "epoch_*"],
1011
- )
1012
-
1013
-
1014
- if __name__ == "__main__":
1015
- main()
diffusers/examples/dreambooth/README.md DELETED
@@ -1,464 +0,0 @@
1
- # DreamBooth training example
2
-
3
- [DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.
4
- The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion.
5
-
6
-
7
- ## Running locally with PyTorch
8
-
9
- ### Installing the dependencies
10
-
11
- Before running the scripts, make sure to install the library's training dependencies:
12
-
13
- **Important**
14
-
15
- To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
16
- ```bash
17
- git clone https://github.com/huggingface/diffusers
18
- cd diffusers
19
- pip install -e .
20
- ```
21
-
22
- Then cd into the example folder and run
23
- ```bash
24
- pip install -r requirements.txt
25
- ```
26
-
27
- And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
28
-
29
- ```bash
30
- accelerate config
31
- ```
32
-
33
- Or for a default accelerate configuration without answering questions about your environment
34
-
35
- ```bash
36
- accelerate config default
37
- ```
38
-
39
- Or if your environment doesn't support an interactive shell, e.g. a notebook
40
-
41
- ```python
42
- from accelerate.utils import write_basic_config
43
- write_basic_config()
44
- ```
45
-
46
- ### Dog toy example
47
-
48
- Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
49
-
50
- And launch the training using
51
-
52
- **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
53
-
54
- ```bash
55
- export MODEL_NAME="CompVis/stable-diffusion-v1-4"
56
- export INSTANCE_DIR="path-to-instance-images"
57
- export OUTPUT_DIR="path-to-save-model"
58
-
59
- accelerate launch train_dreambooth.py \
60
- --pretrained_model_name_or_path=$MODEL_NAME \
61
- --instance_data_dir=$INSTANCE_DIR \
62
- --output_dir=$OUTPUT_DIR \
63
- --instance_prompt="a photo of sks dog" \
64
- --resolution=512 \
65
- --train_batch_size=1 \
66
- --gradient_accumulation_steps=1 \
67
- --learning_rate=5e-6 \
68
- --lr_scheduler="constant" \
69
- --lr_warmup_steps=0 \
70
- --max_train_steps=400
71
- ```
72
-
73
- ### Training with prior-preservation loss
74
-
75
- Prior-preservation is used to avoid overfitting and language drift. Refer to the paper to learn more about it. For prior-preservation, we first generate images using the model with a class prompt and then use those during training along with our data.
76
- According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.
77
-
78
- ```bash
79
- export MODEL_NAME="CompVis/stable-diffusion-v1-4"
80
- export INSTANCE_DIR="path-to-instance-images"
81
- export CLASS_DIR="path-to-class-images"
82
- export OUTPUT_DIR="path-to-save-model"
83
-
84
- accelerate launch train_dreambooth.py \
85
- --pretrained_model_name_or_path=$MODEL_NAME \
86
- --instance_data_dir=$INSTANCE_DIR \
87
- --class_data_dir=$CLASS_DIR \
88
- --output_dir=$OUTPUT_DIR \
89
- --with_prior_preservation --prior_loss_weight=1.0 \
90
- --instance_prompt="a photo of sks dog" \
91
- --class_prompt="a photo of dog" \
92
- --resolution=512 \
93
- --train_batch_size=1 \
94
- --gradient_accumulation_steps=1 \
95
- --learning_rate=5e-6 \
96
- --lr_scheduler="constant" \
97
- --lr_warmup_steps=0 \
98
- --num_class_images=200 \
99
- --max_train_steps=800
100
- ```
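
Under the hood, each training batch contains the instance images and the generated class images stacked together, and a prior term weighted by `--prior_loss_weight` is added to the instance loss. The snippet below is a minimal, illustrative sketch of that combination (not the script's exact code); it assumes `model_pred` and `target` hold the instance half followed by the class half along the batch dimension.

```python
import torch
import torch.nn.functional as F

def dreambooth_loss(model_pred, target, prior_loss_weight=1.0):
    # Split the stacked batch into the instance half and the class (prior) half.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    return instance_loss + prior_loss_weight * prior_loss
```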
101
-
102
-
103
- ### Training on a 16GB GPU:
104
-
105
- With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to train DreamBooth on a 16GB GPU.
106
-
107
- To install `bitsandbytes`, please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
108
-
109
- ```bash
110
- export MODEL_NAME="CompVis/stable-diffusion-v1-4"
111
- export INSTANCE_DIR="path-to-instance-images"
112
- export CLASS_DIR="path-to-class-images"
113
- export OUTPUT_DIR="path-to-save-model"
114
-
115
- accelerate launch train_dreambooth.py \
116
- --pretrained_model_name_or_path=$MODEL_NAME \
117
- --instance_data_dir=$INSTANCE_DIR \
118
- --class_data_dir=$CLASS_DIR \
119
- --output_dir=$OUTPUT_DIR \
120
- --with_prior_preservation --prior_loss_weight=1.0 \
121
- --instance_prompt="a photo of sks dog" \
122
- --class_prompt="a photo of dog" \
123
- --resolution=512 \
124
- --train_batch_size=1 \
125
- --gradient_accumulation_steps=2 --gradient_checkpointing \
126
- --use_8bit_adam \
127
- --learning_rate=5e-6 \
128
- --lr_scheduler="constant" \
129
- --lr_warmup_steps=0 \
130
- --num_class_images=200 \
131
- --max_train_steps=800
132
- ```
133
-
134
-
135
- ### Training on a 12GB GPU:
136
-
137
- It is possible to run DreamBooth on a 12GB GPU by using the following optimizations:
138
- - [gradient checkpointing and the 8-bit optimizer](#training-on-a-16gb-gpu)
139
- - [xformers](#training-with-xformers)
140
- - [setting grads to none](#set-grads-to-none)
141
-
142
- ```bash
143
- export MODEL_NAME="CompVis/stable-diffusion-v1-4"
144
- export INSTANCE_DIR="path-to-instance-images"
145
- export CLASS_DIR="path-to-class-images"
146
- export OUTPUT_DIR="path-to-save-model"
147
-
148
- accelerate launch train_dreambooth.py \
149
- --pretrained_model_name_or_path=$MODEL_NAME \
150
- --instance_data_dir=$INSTANCE_DIR \
151
- --class_data_dir=$CLASS_DIR \
152
- --output_dir=$OUTPUT_DIR \
153
- --with_prior_preservation --prior_loss_weight=1.0 \
154
- --instance_prompt="a photo of sks dog" \
155
- --class_prompt="a photo of dog" \
156
- --resolution=512 \
157
- --train_batch_size=1 \
158
- --gradient_accumulation_steps=1 --gradient_checkpointing \
159
- --use_8bit_adam \
160
- --enable_xformers_memory_efficient_attention \
161
- --set_grads_to_none \
162
- --learning_rate=2e-6 \
163
- --lr_scheduler="constant" \
164
- --lr_warmup_steps=0 \
165
- --num_class_images=200 \
166
- --max_train_steps=800
167
- ```
168
-
169
-
170
- ### Training on an 8 GB GPU:
171
-
172
- By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some
173
- tensors from VRAM to either CPU or NVMe, allowing training with less VRAM.
174
-
175
- DeepSpeed needs to be enabled with `accelerate config`. During configuration
176
- answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16
177
- mixed precision, and offloading both the parameters and the optimizer state to CPU, it's
178
- possible to train with under 8 GB of VRAM, at the cost of requiring significantly
179
- more RAM (about 25 GB). See [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
180
-
181
- Changing the default Adam optimizer to DeepSpeed's special version of Adam
182
- `deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup but enabling
183
- it requires a CUDA toolchain with the same version as PyTorch. The 8-bit optimizer
184
- does not seem to be compatible with DeepSpeed at the moment.
185
-
186
- ```bash
187
- export MODEL_NAME="CompVis/stable-diffusion-v1-4"
188
- export INSTANCE_DIR="path-to-instance-images"
189
- export CLASS_DIR="path-to-class-images"
190
- export OUTPUT_DIR="path-to-save-model"
191
-
192
- accelerate launch --mixed_precision="fp16" train_dreambooth.py \
193
- --pretrained_model_name_or_path=$MODEL_NAME \
194
- --instance_data_dir=$INSTANCE_DIR \
195
- --class_data_dir=$CLASS_DIR \
196
- --output_dir=$OUTPUT_DIR \
197
- --with_prior_preservation --prior_loss_weight=1.0 \
198
- --instance_prompt="a photo of sks dog" \
199
- --class_prompt="a photo of dog" \
200
- --resolution=512 \
201
- --train_batch_size=1 \
202
- --sample_batch_size=1 \
203
- --gradient_accumulation_steps=1 --gradient_checkpointing \
204
- --learning_rate=5e-6 \
205
- --lr_scheduler="constant" \
206
- --lr_warmup_steps=0 \
207
- --num_class_images=200 \
208
- --max_train_steps=800
209
- ```
210
-
211
- ### Fine-tune text encoder with the UNet.
212
-
213
- The script also allows you to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
214
- Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`.
215
-
216
- ___Note: Training the text encoder requires more memory; with this option the training won't fit on a 16GB GPU. It needs at least 24GB of VRAM.___
217
-
218
- ```bash
219
- export MODEL_NAME="CompVis/stable-diffusion-v1-4"
220
- export INSTANCE_DIR="path-to-instance-images"
221
- export CLASS_DIR="path-to-class-images"
222
- export OUTPUT_DIR="path-to-save-model"
223
-
224
- accelerate launch train_dreambooth.py \
225
- --pretrained_model_name_or_path=$MODEL_NAME \
226
- --train_text_encoder \
227
- --instance_data_dir=$INSTANCE_DIR \
228
- --class_data_dir=$CLASS_DIR \
229
- --output_dir=$OUTPUT_DIR \
230
- --with_prior_preservation --prior_loss_weight=1.0 \
231
- --instance_prompt="a photo of sks dog" \
232
- --class_prompt="a photo of dog" \
233
- --resolution=512 \
234
- --train_batch_size=1 \
235
- --use_8bit_adam \
236
- --gradient_checkpointing \
237
- --learning_rate=2e-6 \
238
- --lr_scheduler="constant" \
239
- --lr_warmup_steps=0 \
240
- --num_class_images=200 \
241
- --max_train_steps=800
242
- ```
243
-
244
- ### Using DreamBooth for pipelines other than Stable Diffusion
245
-
246
- The [AltDiffusion pipeline](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion) also supports DreamBooth fine-tuning. The process is the same as above; all you need to do is replace the `MODEL_NAME` like this:
247
-
248
- ```
249
- export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9"
250
- or
251
- export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion"
252
- ```
253
-
254
- ### Inference
255
-
256
- Once you have trained a model using the above command, you can run inference simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `sks` in the above example) in your prompt.
257
-
258
- ```python
259
- from diffusers import StableDiffusionPipeline
260
- import torch
261
-
262
- model_id = "path-to-your-trained-model"
263
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
264
-
265
- prompt = "A photo of sks dog in a bucket"
266
- image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
267
-
268
- image.save("dog-bucket.png")
269
- ```
270
-
271
- ### Inference from a training checkpoint
272
-
273
- You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it; a minimal sketch is shown below.
274
-
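
As a rough guide, the sketch below assumes the checkpoint directory saved by `--checkpointing_steps` exposes the UNet in diffusers format under `checkpoint-<step>/unet`; the exact layout depends on the script version, so treat the paths as placeholders and follow the linked documentation for your setup.

```python
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel

model_path = "path-to-save-model"  # hypothetical --output_dir
unet = UNet2DConditionModel.from_pretrained(
    f"{model_path}/checkpoint-400/unet", torch_dtype=torch.float16
)

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", unet=unet, torch_dtype=torch.float16
).to("cuda")
image = pipe("A photo of sks dog in a bucket", num_inference_steps=50).images[0]
image.save("dog-bucket-checkpoint.png")
```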
275
- ## Training with Low-Rank Adaptation of Large Language Models (LoRA)
276
-
277
- Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*
278
-
279
- In a nutshell, LoRA adapts pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights (see the sketch after this list). This has a couple of advantages:
280
- - Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
281
- - Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
282
- - LoRA attention layers let you control how strongly the model is adapted toward new training images via a `scale` parameter.
283
-
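
To make the idea concrete, here is a toy sketch of a low-rank adapter around a single linear layer. It is illustrative only and is not the attention-processor implementation that `train_dreambooth_lora.py` actually uses; the class and parameter names are made up for this example.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)          # pretrained weights stay frozen
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.normal_(self.lora_down.weight, std=1.0 / rank)
        nn.init.zeros_(self.lora_up.weight)      # start as a no-op update
        self.scale = scale                       # the `scale` knob mentioned above

    def forward(self, x):
        # Frozen projection plus the trainable low-rank update.
        return self.base(x) + self.scale * self.lora_up(self.lora_down(x))

layer = LoRALinear(nn.Linear(768, 768), rank=4)
out = layer(torch.randn(1, 77, 768))             # only lora_down/lora_up get gradients
```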
284
- [cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
285
- the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
286
-
287
- ### Training
288
-
289
- Let's get started with a simple example. We will re-use the dog example of the [previous section](#dog-toy-example).
290
-
291
- First, you need to set up your DreamBooth training example as explained in the [installation section](#Installing-the-dependencies).
292
- Next, let's download the dog dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. Make sure to set `INSTANCE_DIR` to the name of your directory further below. This will be our training data.
293
-
294
- Now, you can launch the training. Here we will use [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
295
-
296
- **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
297
-
298
- **___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [wandb](https://docs.wandb.ai/quickstart) is a nice solution to easily see the generated images during training. All you need to do is run `pip install wandb` before training and pass `--report_to="wandb"` to automatically log images.___**
299
-
300
-
301
- ```bash
302
- export MODEL_NAME="runwayml/stable-diffusion-v1-5"
303
- export INSTANCE_DIR="path-to-instance-images"
304
- export OUTPUT_DIR="path-to-save-model"
305
- ```
306
-
307
- For this example we want to directly store the trained LoRA embeddings on the Hub, so
308
- we need to be logged in and add the `--push_to_hub` flag.
309
-
310
- ```bash
311
- huggingface-cli login
312
- ```
313
-
314
- Now we can start training!
315
-
316
- ```bash
317
- accelerate launch train_dreambooth_lora.py \
318
- --pretrained_model_name_or_path=$MODEL_NAME \
319
- --instance_data_dir=$INSTANCE_DIR \
320
- --output_dir=$OUTPUT_DIR \
321
- --instance_prompt="a photo of sks dog" \
322
- --resolution=512 \
323
- --train_batch_size=1 \
324
- --gradient_accumulation_steps=1 \
325
- --checkpointing_steps=100 \
326
- --learning_rate=1e-4 \
327
- --report_to="wandb" \
328
- --lr_scheduler="constant" \
329
- --lr_warmup_steps=0 \
330
- --max_train_steps=500 \
331
- --validation_prompt="A photo of sks dog in a bucket" \
332
- --validation_epochs=50 \
333
- --seed="0" \
334
- --push_to_hub
335
- ```
336
-
337
- **___Note: When using LoRA we can use a much higher learning rate compared to vanilla DreamBooth. Here we
338
- use *1e-4* instead of the usual *2e-6*.___**
339
-
340
- The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dreambooth_dog_example](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example). **___Note: [The final weights](https://huggingface.co/patrickvonplaten/lora/blob/main/pytorch_attn_procs.bin) are only 3 MB in size, which is orders of magnitude smaller than the original model.___**
341
-
342
- The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
343
- You can use the `Step` slider to see how the model learned the features of our subject as training progressed.
344
-
345
- ### Inference
346
-
347
- After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to
348
- load the original pipeline:
349
-
350
- ```python
351
- from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
352
- import torch
353
-
354
- pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
355
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
356
- pipe.to("cuda")
357
- ```
358
-
359
- Next, we can load the adapter layers into the UNet with the [`load_attn_procs` function](https://huggingface.co/docs/diffusers/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs).
360
-
361
- ```python
362
- pipe.unet.load_attn_procs("patrickvonplaten/lora_dreambooth_dog_example")
363
- ```
364
-
365
- Finally, we can run the model for inference.
366
-
367
- ```python
368
- image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
369
- ```
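
The `scale` parameter mentioned earlier can usually be passed at inference time through `cross_attention_kwargs` to blend between the base model (`0.0`) and the fully adapted model (`1.0`); treat the exact keyword as version-dependent, so this is a sketch rather than a guaranteed API:

```python
image = pipe(
    "A picture of a sks dog in a bucket",
    num_inference_steps=25,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```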
370
-
371
- ## Training with Flax/JAX
372
-
373
- For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
374
-
375
- ___Note: The Flax example doesn't yet support features like gradient checkpointing, gradient accumulation, etc., so to use Flax for faster training you will need a >30GB card.___
376
-
377
-
378
- Before running the scripts, make sure to install the library's training dependencies:
379
-
380
- ```bash
381
- pip install -U -r requirements_flax.txt
382
- ```
383
-
384
-
385
- ### Training without prior preservation loss
386
-
387
- ```bash
388
- export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
389
- export INSTANCE_DIR="path-to-instance-images"
390
- export OUTPUT_DIR="path-to-save-model"
391
-
392
- python train_dreambooth_flax.py \
393
- --pretrained_model_name_or_path=$MODEL_NAME \
394
- --instance_data_dir=$INSTANCE_DIR \
395
- --output_dir=$OUTPUT_DIR \
396
- --instance_prompt="a photo of sks dog" \
397
- --resolution=512 \
398
- --train_batch_size=1 \
399
- --learning_rate=5e-6 \
400
- --max_train_steps=400
401
- ```
402
-
403
-
404
- ### Training with prior preservation loss
405
-
406
- ```bash
407
- export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
408
- export INSTANCE_DIR="path-to-instance-images"
409
- export CLASS_DIR="path-to-class-images"
410
- export OUTPUT_DIR="path-to-save-model"
411
-
412
- python train_dreambooth_flax.py \
413
- --pretrained_model_name_or_path=$MODEL_NAME \
414
- --instance_data_dir=$INSTANCE_DIR \
415
- --class_data_dir=$CLASS_DIR \
416
- --output_dir=$OUTPUT_DIR \
417
- --with_prior_preservation --prior_loss_weight=1.0 \
418
- --instance_prompt="a photo of sks dog" \
419
- --class_prompt="a photo of dog" \
420
- --resolution=512 \
421
- --train_batch_size=1 \
422
- --learning_rate=5e-6 \
423
- --num_class_images=200 \
424
- --max_train_steps=800
425
- ```
426
-
427
-
428
- ### Fine-tune text encoder with the UNet.
429
-
430
- ```bash
431
- export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
432
- export INSTANCE_DIR="path-to-instance-images"
433
- export CLASS_DIR="path-to-class-images"
434
- export OUTPUT_DIR="path-to-save-model"
435
-
436
- python train_dreambooth_flax.py \
437
- --pretrained_model_name_or_path=$MODEL_NAME \
438
- --train_text_encoder \
439
- --instance_data_dir=$INSTANCE_DIR \
440
- --class_data_dir=$CLASS_DIR \
441
- --output_dir=$OUTPUT_DIR \
442
- --with_prior_preservation --prior_loss_weight=1.0 \
443
- --instance_prompt="a photo of sks dog" \
444
- --class_prompt="a photo of dog" \
445
- --resolution=512 \
446
- --train_batch_size=1 \
447
- --learning_rate=2e-6 \
448
- --num_class_images=200 \
449
- --max_train_steps=800
450
- ```
451
-
452
- ### Training with xformers:
453
- You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
454
-
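
For reference, the flag boils down to roughly the following call inside the PyTorch training script. This is a sketch, assuming xFormers is installed and using Stable Diffusion v1-4 as a stand-in model:

```python
from diffusers import UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)
if is_xformers_available():
    # Swap the default attention implementation for xFormers' memory-efficient one.
    unet.enable_xformers_memory_efficient_attention()
```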
455
- You can also use DreamBooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
456
-
457
- ### Set grads to none
458
-
459
- To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
460
-
461
- More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
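
In plain PyTorch terms, the flag corresponds to the following pattern (a minimal sketch, unrelated to the rest of the training loop):

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

loss = model(torch.randn(2, 4)).sum()
loss.backward()
optimizer.step()

# Free the gradient tensors instead of filling them with zeros.
optimizer.zero_grad(set_to_none=True)
```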
462
-
463
- ### Experimental results
464
- You can refer to [this blog post](https://huggingface.co/blog/dreambooth) that discusses some of the DreamBooth experiments in detail. Specifically, it recommends a set of DreamBooth-specific tips and tricks that we have found to work well for a variety of subjects.
diffusers/examples/dreambooth/requirements.txt DELETED
@@ -1,6 +0,0 @@
1
- accelerate
2
- torchvision
3
- transformers>=4.25.1
4
- ftfy
5
- tensorboard
6
- Jinja2
diffusers/examples/dreambooth/requirements_flax.txt DELETED
@@ -1,8 +0,0 @@
1
- transformers>=4.25.1
2
- flax
3
- optax
4
- torch
5
- torchvision
6
- ftfy
7
- tensorboard
8
- Jinja2
diffusers/examples/dreambooth/train_dreambooth.py DELETED
@@ -1,1039 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
-
16
- import argparse
17
- import hashlib
18
- import itertools
19
- import logging
20
- import math
21
- import os
22
- import warnings
23
- from pathlib import Path
24
-
25
- import accelerate
26
- import numpy as np
27
- import torch
28
- import torch.nn.functional as F
29
- import torch.utils.checkpoint
30
- import transformers
31
- from accelerate import Accelerator
32
- from accelerate.logging import get_logger
33
- from accelerate.utils import ProjectConfiguration, set_seed
34
- from huggingface_hub import create_repo, upload_folder
35
- from packaging import version
36
- from PIL import Image
37
- from torch.utils.data import Dataset
38
- from torchvision import transforms
39
- from tqdm.auto import tqdm
40
- from transformers import AutoTokenizer, PretrainedConfig
41
-
42
- import diffusers
43
- from diffusers import (
44
- AutoencoderKL,
45
- DDPMScheduler,
46
- DiffusionPipeline,
47
- DPMSolverMultistepScheduler,
48
- UNet2DConditionModel,
49
- )
50
- from diffusers.optimization import get_scheduler
51
- from diffusers.utils import check_min_version, is_wandb_available
52
- from diffusers.utils.import_utils import is_xformers_available
53
-
54
-
55
- if is_wandb_available():
56
- import wandb
57
-
58
- # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
59
- check_min_version("0.15.0.dev0")
60
-
61
- logger = get_logger(__name__)
62
-
63
-
64
- def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
65
- logger.info(
66
- f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
67
- f" {args.validation_prompt}."
68
- )
69
- # create pipeline (note: unet and vae are loaded again in float32)
70
- pipeline = DiffusionPipeline.from_pretrained(
71
- args.pretrained_model_name_or_path,
72
- text_encoder=accelerator.unwrap_model(text_encoder),
73
- tokenizer=tokenizer,
74
- unet=accelerator.unwrap_model(unet),
75
- vae=vae,
76
- revision=args.revision,
77
- torch_dtype=weight_dtype,
78
- )
79
- pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
80
- pipeline = pipeline.to(accelerator.device)
81
- pipeline.set_progress_bar_config(disable=True)
82
-
83
- # run inference
84
- generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
85
- images = []
86
- for _ in range(args.num_validation_images):
87
- with torch.autocast("cuda"):
88
- image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
89
- images.append(image)
90
-
91
- for tracker in accelerator.trackers:
92
- if tracker.name == "tensorboard":
93
- np_images = np.stack([np.asarray(img) for img in images])
94
- tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
95
- if tracker.name == "wandb":
96
- tracker.log(
97
- {
98
- "validation": [
99
- wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
100
- ]
101
- }
102
- )
103
-
104
- del pipeline
105
- torch.cuda.empty_cache()
106
-
107
-
108
- def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
109
- text_encoder_config = PretrainedConfig.from_pretrained(
110
- pretrained_model_name_or_path,
111
- subfolder="text_encoder",
112
- revision=revision,
113
- )
114
- model_class = text_encoder_config.architectures[0]
115
-
116
- if model_class == "CLIPTextModel":
117
- from transformers import CLIPTextModel
118
-
119
- return CLIPTextModel
120
- elif model_class == "RobertaSeriesModelWithTransformation":
121
- from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
122
-
123
- return RobertaSeriesModelWithTransformation
124
- else:
125
- raise ValueError(f"{model_class} is not supported.")
126
-
127
-
128
- def parse_args(input_args=None):
129
- parser = argparse.ArgumentParser(description="Simple example of a training script.")
130
- parser.add_argument(
131
- "--pretrained_model_name_or_path",
132
- type=str,
133
- default=None,
134
- required=True,
135
- help="Path to pretrained model or model identifier from huggingface.co/models.",
136
- )
137
- parser.add_argument(
138
- "--revision",
139
- type=str,
140
- default=None,
141
- required=False,
142
- help=(
143
- "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
144
- " float32 precision."
145
- ),
146
- )
147
- parser.add_argument(
148
- "--tokenizer_name",
149
- type=str,
150
- default=None,
151
- help="Pretrained tokenizer name or path if not the same as model_name",
152
- )
153
- parser.add_argument(
154
- "--instance_data_dir",
155
- type=str,
156
- default=None,
157
- required=True,
158
- help="A folder containing the training data of instance images.",
159
- )
160
- parser.add_argument(
161
- "--class_data_dir",
162
- type=str,
163
- default=None,
164
- required=False,
165
- help="A folder containing the training data of class images.",
166
- )
167
- parser.add_argument(
168
- "--instance_prompt",
169
- type=str,
170
- default=None,
171
- required=True,
172
- help="The prompt with identifier specifying the instance",
173
- )
174
- parser.add_argument(
175
- "--class_prompt",
176
- type=str,
177
- default=None,
178
- help="The prompt to specify images in the same class as provided instance images.",
179
- )
180
- parser.add_argument(
181
- "--with_prior_preservation",
182
- default=False,
183
- action="store_true",
184
- help="Flag to add prior preservation loss.",
185
- )
186
- parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
187
- parser.add_argument(
188
- "--num_class_images",
189
- type=int,
190
- default=100,
191
- help=(
192
- "Minimal class images for prior preservation loss. If there are not enough images already present in"
193
- " class_data_dir, additional images will be sampled with class_prompt."
194
- ),
195
- )
196
- parser.add_argument(
197
- "--output_dir",
198
- type=str,
199
- default="text-inversion-model",
200
- help="The output directory where the model predictions and checkpoints will be written.",
201
- )
202
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
203
- parser.add_argument(
204
- "--resolution",
205
- type=int,
206
- default=512,
207
- help=(
208
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
209
- " resolution"
210
- ),
211
- )
212
- parser.add_argument(
213
- "--center_crop",
214
- default=False,
215
- action="store_true",
216
- help=(
217
- "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
218
- " cropped. The images will be resized to the resolution first before cropping."
219
- ),
220
- )
221
- parser.add_argument(
222
- "--train_text_encoder",
223
- action="store_true",
224
- help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
225
- )
226
- parser.add_argument(
227
- "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
228
- )
229
- parser.add_argument(
230
- "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
231
- )
232
- parser.add_argument("--num_train_epochs", type=int, default=1)
233
- parser.add_argument(
234
- "--max_train_steps",
235
- type=int,
236
- default=None,
237
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
238
- )
239
- parser.add_argument(
240
- "--checkpointing_steps",
241
- type=int,
242
- default=500,
243
- help=(
244
- "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
245
- "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
246
- "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
247
- "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
248
- "instructions."
249
- ),
250
- )
251
- parser.add_argument(
252
- "--checkpoints_total_limit",
253
- type=int,
254
- default=None,
255
- help=(
256
- "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
257
- " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
258
- " for more details"
259
- ),
260
- )
261
- parser.add_argument(
262
- "--resume_from_checkpoint",
263
- type=str,
264
- default=None,
265
- help=(
266
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
267
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
268
- ),
269
- )
270
- parser.add_argument(
271
- "--gradient_accumulation_steps",
272
- type=int,
273
- default=1,
274
- help="Number of updates steps to accumulate before performing a backward/update pass.",
275
- )
276
- parser.add_argument(
277
- "--gradient_checkpointing",
278
- action="store_true",
279
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
280
- )
281
- parser.add_argument(
282
- "--learning_rate",
283
- type=float,
284
- default=5e-6,
285
- help="Initial learning rate (after the potential warmup period) to use.",
286
- )
287
- parser.add_argument(
288
- "--scale_lr",
289
- action="store_true",
290
- default=False,
291
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
292
- )
293
- parser.add_argument(
294
- "--lr_scheduler",
295
- type=str,
296
- default="constant",
297
- help=(
298
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
299
- ' "constant", "constant_with_warmup"]'
300
- ),
301
- )
302
- parser.add_argument(
303
- "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
304
- )
305
- parser.add_argument(
306
- "--lr_num_cycles",
307
- type=int,
308
- default=1,
309
- help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
310
- )
311
- parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
312
- parser.add_argument(
313
- "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
314
- )
315
- parser.add_argument(
316
- "--dataloader_num_workers",
317
- type=int,
318
- default=0,
319
- help=(
320
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
321
- ),
322
- )
323
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
324
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
325
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
326
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
327
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
328
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
329
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
330
- parser.add_argument(
331
- "--hub_model_id",
332
- type=str,
333
- default=None,
334
- help="The name of the repository to keep in sync with the local `output_dir`.",
335
- )
336
- parser.add_argument(
337
- "--logging_dir",
338
- type=str,
339
- default="logs",
340
- help=(
341
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
342
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
343
- ),
344
- )
345
- parser.add_argument(
346
- "--allow_tf32",
347
- action="store_true",
348
- help=(
349
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
350
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
351
- ),
352
- )
353
- parser.add_argument(
354
- "--report_to",
355
- type=str,
356
- default="tensorboard",
357
- help=(
358
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
359
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
360
- ),
361
- )
362
- parser.add_argument(
363
- "--validation_prompt",
364
- type=str,
365
- default=None,
366
- help="A prompt that is used during validation to verify that the model is learning.",
367
- )
368
- parser.add_argument(
369
- "--num_validation_images",
370
- type=int,
371
- default=4,
372
- help="Number of images that should be generated during validation with `validation_prompt`.",
373
- )
374
- parser.add_argument(
375
- "--validation_steps",
376
- type=int,
377
- default=100,
378
- help=(
379
- "Run validation every X steps. Validation consists of running the prompt"
380
- " `args.validation_prompt` multiple times: `args.num_validation_images`"
381
- " and logging the images."
382
- ),
383
- )
384
- parser.add_argument(
385
- "--mixed_precision",
386
- type=str,
387
- default=None,
388
- choices=["no", "fp16", "bf16"],
389
- help=(
390
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
391
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
392
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
393
- ),
394
- )
395
- parser.add_argument(
396
- "--prior_generation_precision",
397
- type=str,
398
- default=None,
399
- choices=["no", "fp32", "fp16", "bf16"],
400
- help=(
401
- "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
402
- " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
403
- ),
404
- )
405
- parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
406
- parser.add_argument(
407
- "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
408
- )
409
- parser.add_argument(
410
- "--set_grads_to_none",
411
- action="store_true",
412
- help=(
413
- "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
414
- " behaviors, so disable this argument if it causes any problems. More info:"
415
- " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
416
- ),
417
- )
418
-
419
- parser.add_argument(
420
- "--offset_noise",
421
- action="store_true",
422
- default=False,
423
- help=(
424
- "Fine-tuning against a modified noise"
425
- " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
426
- ),
427
- )
428
-
429
- if input_args is not None:
430
- args = parser.parse_args(input_args)
431
- else:
432
- args = parser.parse_args()
433
-
434
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
435
- if env_local_rank != -1 and env_local_rank != args.local_rank:
436
- args.local_rank = env_local_rank
437
-
438
- if args.with_prior_preservation:
439
- if args.class_data_dir is None:
440
- raise ValueError("You must specify a data directory for class images.")
441
- if args.class_prompt is None:
442
- raise ValueError("You must specify prompt for class images.")
443
- else:
444
- # logger is not available yet
445
- if args.class_data_dir is not None:
446
- warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
447
- if args.class_prompt is not None:
448
- warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
449
-
450
- return args
451
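Since `parse_args` accepts an optional `input_args` list, the whole script can be driven from Python as well as from the command line. A minimal sketch, with placeholder checkpoint name, data directory, prompt and step counts (none of these values come from this file):

from train_dreambooth import main, parse_args

args = parse_args(
    [
        "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint
        "--instance_data_dir", "./dog",                                       # placeholder folder of instance images
        "--instance_prompt", "a photo of sks dog",                            # placeholder prompt
        "--output_dir", "dreambooth-model",
        "--resolution", "512",
        "--train_batch_size", "1",
        "--max_train_steps", "400",
    ]
)
main(args)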
-
452
-
453
- class DreamBoothDataset(Dataset):
454
- """
455
- A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
456
- It pre-processes the images and tokenizes the prompts.
457
- """
458
-
459
- def __init__(
460
- self,
461
- instance_data_root,
462
- instance_prompt,
463
- tokenizer,
464
- class_data_root=None,
465
- class_prompt=None,
466
- class_num=None,
467
- size=512,
468
- center_crop=False,
469
- ):
470
- self.size = size
471
- self.center_crop = center_crop
472
- self.tokenizer = tokenizer
473
-
474
- self.instance_data_root = Path(instance_data_root)
475
- if not self.instance_data_root.exists():
476
- raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")
477
-
478
- self.instance_images_path = list(Path(instance_data_root).iterdir())
479
- self.num_instance_images = len(self.instance_images_path)
480
- self.instance_prompt = instance_prompt
481
- self._length = self.num_instance_images
482
-
483
- if class_data_root is not None:
484
- self.class_data_root = Path(class_data_root)
485
- self.class_data_root.mkdir(parents=True, exist_ok=True)
486
- self.class_images_path = list(self.class_data_root.iterdir())
487
- if class_num is not None:
488
- self.num_class_images = min(len(self.class_images_path), class_num)
489
- else:
490
- self.num_class_images = len(self.class_images_path)
491
- self._length = max(self.num_class_images, self.num_instance_images)
492
- self.class_prompt = class_prompt
493
- else:
494
- self.class_data_root = None
495
-
496
- self.image_transforms = transforms.Compose(
497
- [
498
- transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
499
- transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
500
- transforms.ToTensor(),
501
- transforms.Normalize([0.5], [0.5]),
502
- ]
503
- )
504
-
505
- def __len__(self):
506
- return self._length
507
-
508
- def __getitem__(self, index):
509
- example = {}
510
- instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
511
- if not instance_image.mode == "RGB":
512
- instance_image = instance_image.convert("RGB")
513
- example["instance_images"] = self.image_transforms(instance_image)
514
- example["instance_prompt_ids"] = self.tokenizer(
515
- self.instance_prompt,
516
- truncation=True,
517
- padding="max_length",
518
- max_length=self.tokenizer.model_max_length,
519
- return_tensors="pt",
520
- ).input_ids
521
-
522
- if self.class_data_root:
523
- class_image = Image.open(self.class_images_path[index % self.num_class_images])
524
- if not class_image.mode == "RGB":
525
- class_image = class_image.convert("RGB")
526
- example["class_images"] = self.image_transforms(class_image)
527
- example["class_prompt_ids"] = self.tokenizer(
528
- self.class_prompt,
529
- truncation=True,
530
- padding="max_length",
531
- max_length=self.tokenizer.model_max_length,
532
- return_tensors="pt",
533
- ).input_ids
534
-
535
- return example
536
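A quick way to sanity-check what `DreamBoothDataset` yields (a hedged sketch; the checkpoint name and image folder are placeholders):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer", use_fast=False  # placeholder checkpoint
)
dataset = DreamBoothDataset(
    instance_data_root="./dog",            # placeholder folder with a few instance photos
    instance_prompt="a photo of sks dog",  # placeholder prompt
    tokenizer=tokenizer,
    size=512,
)
example = dataset[0]
print(example["instance_images"].shape)      # torch.Size([3, 512, 512])
print(example["instance_prompt_ids"].shape)  # torch.Size([1, 77]) for the CLIP tokenizer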
-
537
-
538
- def collate_fn(examples, with_prior_preservation=False):
539
- input_ids = [example["instance_prompt_ids"] for example in examples]
540
- pixel_values = [example["instance_images"] for example in examples]
541
-
542
- # Concat class and instance examples for prior preservation.
543
- # We do this to avoid doing two forward passes.
544
- if with_prior_preservation:
545
- input_ids += [example["class_prompt_ids"] for example in examples]
546
- pixel_values += [example["class_images"] for example in examples]
547
-
548
- pixel_values = torch.stack(pixel_values)
549
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
550
-
551
- input_ids = torch.cat(input_ids, dim=0)
552
-
553
- batch = {
554
- "input_ids": input_ids,
555
- "pixel_values": pixel_values,
556
- }
557
- return batch
558
-
559
-
560
- class PromptDataset(Dataset):
561
- "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
562
-
563
- def __init__(self, prompt, num_samples):
564
- self.prompt = prompt
565
- self.num_samples = num_samples
566
-
567
- def __len__(self):
568
- return self.num_samples
569
-
570
- def __getitem__(self, index):
571
- example = {}
572
- example["prompt"] = self.prompt
573
- example["index"] = index
574
- return example
575
-
576
-
577
- def main(args):
578
- logging_dir = Path(args.output_dir, args.logging_dir)
579
-
580
- accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
581
-
582
- accelerator = Accelerator(
583
- gradient_accumulation_steps=args.gradient_accumulation_steps,
584
- mixed_precision=args.mixed_precision,
585
- log_with=args.report_to,
586
- logging_dir=logging_dir,
587
- project_config=accelerator_project_config,
588
- )
589
-
590
- if args.report_to == "wandb":
591
- if not is_wandb_available():
592
- raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
593
-
594
- # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
595
- # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
596
- # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
597
- if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
598
- raise ValueError(
599
- "Gradient accumulation is not supported when training the text encoder in distributed training. "
600
- "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
601
- )
602
-
603
- # Make one log on every process with the configuration for debugging.
604
- logging.basicConfig(
605
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
606
- datefmt="%m/%d/%Y %H:%M:%S",
607
- level=logging.INFO,
608
- )
609
- logger.info(accelerator.state, main_process_only=False)
610
- if accelerator.is_local_main_process:
611
- transformers.utils.logging.set_verbosity_warning()
612
- diffusers.utils.logging.set_verbosity_info()
613
- else:
614
- transformers.utils.logging.set_verbosity_error()
615
- diffusers.utils.logging.set_verbosity_error()
616
-
617
- # If passed along, set the training seed now.
618
- if args.seed is not None:
619
- set_seed(args.seed)
620
-
621
- # Generate class images if prior preservation is enabled.
622
- if args.with_prior_preservation:
623
- class_images_dir = Path(args.class_data_dir)
624
- if not class_images_dir.exists():
625
- class_images_dir.mkdir(parents=True)
626
- cur_class_images = len(list(class_images_dir.iterdir()))
627
-
628
- if cur_class_images < args.num_class_images:
629
- torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
630
- if args.prior_generation_precision == "fp32":
631
- torch_dtype = torch.float32
632
- elif args.prior_generation_precision == "fp16":
633
- torch_dtype = torch.float16
634
- elif args.prior_generation_precision == "bf16":
635
- torch_dtype = torch.bfloat16
636
- pipeline = DiffusionPipeline.from_pretrained(
637
- args.pretrained_model_name_or_path,
638
- torch_dtype=torch_dtype,
639
- safety_checker=None,
640
- revision=args.revision,
641
- )
642
- pipeline.set_progress_bar_config(disable=True)
643
-
644
- num_new_images = args.num_class_images - cur_class_images
645
- logger.info(f"Number of class images to sample: {num_new_images}.")
646
-
647
- sample_dataset = PromptDataset(args.class_prompt, num_new_images)
648
- sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
649
-
650
- sample_dataloader = accelerator.prepare(sample_dataloader)
651
- pipeline.to(accelerator.device)
652
-
653
- for example in tqdm(
654
- sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
655
- ):
656
- images = pipeline(example["prompt"]).images
657
-
658
- for i, image in enumerate(images):
659
- hash_image = hashlib.sha1(image.tobytes()).hexdigest()
660
- image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
661
- image.save(image_filename)
662
-
663
- del pipeline
664
- if torch.cuda.is_available():
665
- torch.cuda.empty_cache()
666
-
667
- # Handle the repository creation
668
- if accelerator.is_main_process:
669
- if args.output_dir is not None:
670
- os.makedirs(args.output_dir, exist_ok=True)
671
-
672
- if args.push_to_hub:
673
- repo_id = create_repo(
674
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
675
- ).repo_id
676
-
677
- # Load the tokenizer
678
- if args.tokenizer_name:
679
- tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
680
- elif args.pretrained_model_name_or_path:
681
- tokenizer = AutoTokenizer.from_pretrained(
682
- args.pretrained_model_name_or_path,
683
- subfolder="tokenizer",
684
- revision=args.revision,
685
- use_fast=False,
686
- )
687
-
688
- # import correct text encoder class
689
- text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
690
-
691
- # Load scheduler and models
692
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
693
- text_encoder = text_encoder_cls.from_pretrained(
694
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
695
- )
696
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
697
- unet = UNet2DConditionModel.from_pretrained(
698
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
699
- )
700
-
701
- # `accelerate` 0.16.0 will have better support for customized saving
702
- if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
703
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
704
- def save_model_hook(models, weights, output_dir):
705
- for model in models:
706
- sub_dir = "unet" if type(model) == type(unet) else "text_encoder"
707
- model.save_pretrained(os.path.join(output_dir, sub_dir))
708
-
709
- # make sure to pop the weight so that the corresponding model is not saved again
710
- weights.pop()
711
-
712
- def load_model_hook(models, input_dir):
713
- while len(models) > 0:
714
- # pop models so that they are not loaded again
715
- model = models.pop()
716
-
717
- if type(model) == type(text_encoder):
718
- # load transformers style into model
719
- load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
720
- model.config = load_model.config
721
- else:
722
- # load diffusers style into model
723
- load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
724
- model.register_to_config(**load_model.config)
725
-
726
- model.load_state_dict(load_model.state_dict())
727
- del load_model
728
-
729
- accelerator.register_save_state_pre_hook(save_model_hook)
730
- accelerator.register_load_state_pre_hook(load_model_hook)
731
-
732
- vae.requires_grad_(False)
733
- if not args.train_text_encoder:
734
- text_encoder.requires_grad_(False)
735
-
736
- if args.enable_xformers_memory_efficient_attention:
737
- if is_xformers_available():
738
- import xformers
739
-
740
- xformers_version = version.parse(xformers.__version__)
741
- if xformers_version == version.parse("0.0.16"):
742
- logger.warn(
743
- "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
744
- )
745
- unet.enable_xformers_memory_efficient_attention()
746
- else:
747
- raise ValueError("xformers is not available. Make sure it is installed correctly")
748
-
749
- if args.gradient_checkpointing:
750
- unet.enable_gradient_checkpointing()
751
- if args.train_text_encoder:
752
- text_encoder.gradient_checkpointing_enable()
753
-
754
- # Check that all trainable models are in full precision
755
- low_precision_error_string = (
756
- "Please make sure to always have all model weights in full float32 precision when starting training - even if"
757
- " doing mixed precision training. copy of the weights should still be float32."
758
- )
759
-
760
- if accelerator.unwrap_model(unet).dtype != torch.float32:
761
- raise ValueError(
762
- f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
763
- )
764
-
765
- if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
766
- raise ValueError(
767
- f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
768
- f" {low_precision_error_string}"
769
- )
770
-
771
- # Enable TF32 for faster training on Ampere GPUs,
772
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
773
- if args.allow_tf32:
774
- torch.backends.cuda.matmul.allow_tf32 = True
775
-
776
- if args.scale_lr:
777
- args.learning_rate = (
778
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
779
- )
780
-
781
- # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
782
- if args.use_8bit_adam:
783
- try:
784
- import bitsandbytes as bnb
785
- except ImportError:
786
- raise ImportError(
787
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
788
- )
789
-
790
- optimizer_class = bnb.optim.AdamW8bit
791
- else:
792
- optimizer_class = torch.optim.AdamW
793
-
794
- # Optimizer creation
795
- params_to_optimize = (
796
- itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
797
- )
798
- optimizer = optimizer_class(
799
- params_to_optimize,
800
- lr=args.learning_rate,
801
- betas=(args.adam_beta1, args.adam_beta2),
802
- weight_decay=args.adam_weight_decay,
803
- eps=args.adam_epsilon,
804
- )
805
-
806
- # Dataset and DataLoaders creation:
807
- train_dataset = DreamBoothDataset(
808
- instance_data_root=args.instance_data_dir,
809
- instance_prompt=args.instance_prompt,
810
- class_data_root=args.class_data_dir if args.with_prior_preservation else None,
811
- class_prompt=args.class_prompt,
812
- class_num=args.num_class_images,
813
- tokenizer=tokenizer,
814
- size=args.resolution,
815
- center_crop=args.center_crop,
816
- )
817
-
818
- train_dataloader = torch.utils.data.DataLoader(
819
- train_dataset,
820
- batch_size=args.train_batch_size,
821
- shuffle=True,
822
- collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
823
- num_workers=args.dataloader_num_workers,
824
- )
825
-
826
- # Scheduler and math around the number of training steps.
827
- overrode_max_train_steps = False
828
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
829
- if args.max_train_steps is None:
830
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
831
- overrode_max_train_steps = True
832
-
833
- lr_scheduler = get_scheduler(
834
- args.lr_scheduler,
835
- optimizer=optimizer,
836
- num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
837
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
838
- num_cycles=args.lr_num_cycles,
839
- power=args.lr_power,
840
- )
841
-
842
- # Prepare everything with our `accelerator`.
843
- if args.train_text_encoder:
844
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
845
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler
846
- )
847
- else:
848
- unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
849
- unet, optimizer, train_dataloader, lr_scheduler
850
- )
851
-
852
- # For mixed precision training we cast the text_encoder and vae weights to half-precision
853
- # as these models are only used for inference; keeping weights in full precision is not required.
854
- weight_dtype = torch.float32
855
- if accelerator.mixed_precision == "fp16":
856
- weight_dtype = torch.float16
857
- elif accelerator.mixed_precision == "bf16":
858
- weight_dtype = torch.bfloat16
859
-
860
- # Move vae and text_encoder to device and cast to weight_dtype
861
- vae.to(accelerator.device, dtype=weight_dtype)
862
- if not args.train_text_encoder:
863
- text_encoder.to(accelerator.device, dtype=weight_dtype)
864
-
865
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
866
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
867
- if overrode_max_train_steps:
868
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
869
- # Afterwards we recalculate our number of training epochs
870
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
871
-
872
- # We need to initialize the trackers we use, and also store our configuration.
873
- # The trackers are initialized automatically on the main process.
874
- if accelerator.is_main_process:
875
- accelerator.init_trackers("dreambooth", config=vars(args))
876
-
877
- # Train!
878
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
879
-
880
- logger.info("***** Running training *****")
881
- logger.info(f" Num examples = {len(train_dataset)}")
882
- logger.info(f" Num batches each epoch = {len(train_dataloader)}")
883
- logger.info(f" Num Epochs = {args.num_train_epochs}")
884
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
885
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
886
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
887
- logger.info(f" Total optimization steps = {args.max_train_steps}")
888
- global_step = 0
889
- first_epoch = 0
890
-
891
- # Potentially load in the weights and states from a previous save
892
- if args.resume_from_checkpoint:
893
- if args.resume_from_checkpoint != "latest":
894
- path = os.path.basename(args.resume_from_checkpoint)
895
- else:
896
- # Get the most recent checkpoint
897
- dirs = os.listdir(args.output_dir)
898
- dirs = [d for d in dirs if d.startswith("checkpoint")]
899
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
900
- path = dirs[-1] if len(dirs) > 0 else None
901
-
902
- if path is None:
903
- accelerator.print(
904
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
905
- )
906
- args.resume_from_checkpoint = None
907
- else:
908
- accelerator.print(f"Resuming from checkpoint {path}")
909
- accelerator.load_state(os.path.join(args.output_dir, path))
910
- global_step = int(path.split("-")[1])
911
-
912
- resume_global_step = global_step * args.gradient_accumulation_steps
913
- first_epoch = global_step // num_update_steps_per_epoch
914
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
915
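# Worked example of the arithmetic above (illustrative numbers only): with
# gradient_accumulation_steps=2 and num_update_steps_per_epoch=100, resuming from
# checkpoint-250 gives resume_global_step=500, first_epoch=2 and resume_step=100,
# i.e. the loop below skips the first 100 dataloader steps (50 optimizer updates) of epoch 2.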
-
916
- # Only show the progress bar once on each machine.
917
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
918
- progress_bar.set_description("Steps")
919
-
920
- for epoch in range(first_epoch, args.num_train_epochs):
921
- unet.train()
922
- if args.train_text_encoder:
923
- text_encoder.train()
924
- for step, batch in enumerate(train_dataloader):
925
- # Skip steps until we reach the resumed step
926
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
927
- if step % args.gradient_accumulation_steps == 0:
928
- progress_bar.update(1)
929
- continue
930
-
931
- with accelerator.accumulate(unet):
932
- # Convert images to latent space
933
- latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
934
- latents = latents * vae.config.scaling_factor
935
-
936
- # Sample noise that we'll add to the latents
937
- if args.offset_noise:
938
- noise = torch.randn_like(latents) + 0.1 * torch.randn(
939
- latents.shape[0], latents.shape[1], 1, 1, device=latents.device
940
- )
941
- else:
942
- noise = torch.randn_like(latents)
943
- bsz = latents.shape[0]
944
- # Sample a random timestep for each image
945
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
946
- timesteps = timesteps.long()
947
-
948
- # Add noise to the latents according to the noise magnitude at each timestep
949
- # (this is the forward diffusion process)
950
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
951
-
952
- # Get the text embedding for conditioning
953
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
954
-
955
- # Predict the noise residual
956
- model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
957
-
958
- # Get the target for loss depending on the prediction type
959
- if noise_scheduler.config.prediction_type == "epsilon":
960
- target = noise
961
- elif noise_scheduler.config.prediction_type == "v_prediction":
962
- target = noise_scheduler.get_velocity(latents, noise, timesteps)
963
- else:
964
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
965
-
966
- if args.with_prior_preservation:
967
- # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
968
- model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
969
- target, target_prior = torch.chunk(target, 2, dim=0)
970
-
971
- # Compute instance loss
972
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
973
-
974
- # Compute prior loss
975
- prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
976
-
977
- # Add the prior loss to the instance loss.
978
- loss = loss + args.prior_loss_weight * prior_loss
979
- else:
980
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
981
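# The combined objective above is
#   loss = ||model_pred - target||^2 + prior_loss_weight * ||model_pred_prior - target_prior||^2
# (both mean-squared errors); the torch.chunk split is valid because collate_fn stacked
# instance examples before class examples within the same batch.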
-
982
- accelerator.backward(loss)
983
- if accelerator.sync_gradients:
984
- params_to_clip = (
985
- itertools.chain(unet.parameters(), text_encoder.parameters())
986
- if args.train_text_encoder
987
- else unet.parameters()
988
- )
989
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
990
- optimizer.step()
991
- lr_scheduler.step()
992
- optimizer.zero_grad(set_to_none=args.set_grads_to_none)
993
-
994
- # Checks if the accelerator has performed an optimization step behind the scenes
995
- if accelerator.sync_gradients:
996
- progress_bar.update(1)
997
- global_step += 1
998
-
999
- if accelerator.is_main_process:
1000
- if global_step % args.checkpointing_steps == 0:
1001
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1002
- accelerator.save_state(save_path)
1003
- logger.info(f"Saved state to {save_path}")
1004
-
1005
- if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1006
- log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
1007
-
1008
- logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1009
- progress_bar.set_postfix(**logs)
1010
- accelerator.log(logs, step=global_step)
1011
-
1012
- if global_step >= args.max_train_steps:
1013
- break
1014
-
1015
- # Create the pipeline using the trained modules and save it.
1016
- accelerator.wait_for_everyone()
1017
- if accelerator.is_main_process:
1018
- pipeline = DiffusionPipeline.from_pretrained(
1019
- args.pretrained_model_name_or_path,
1020
- unet=accelerator.unwrap_model(unet),
1021
- text_encoder=accelerator.unwrap_model(text_encoder),
1022
- revision=args.revision,
1023
- )
1024
- pipeline.save_pretrained(args.output_dir)
1025
-
1026
- if args.push_to_hub:
1027
- upload_folder(
1028
- repo_id=repo_id,
1029
- folder_path=args.output_dir,
1030
- commit_message="End of training",
1031
- ignore_patterns=["step_*", "epoch_*"],
1032
- )
1033
-
1034
- accelerator.end_training()
1035
-
1036
-
1037
- if __name__ == "__main__":
1038
- args = parse_args()
1039
- main(args)
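The --offset_noise branch in the training loop above adds a per-sample, per-channel scalar offset to the Gaussian noise; the (B, C, 1, 1) tensor broadcasts over the spatial dimensions and shifts each channel's mean. A small self-contained illustration (the latent shape is typical for 512-pixel Stable Diffusion, not taken from this script):

import torch

latents = torch.randn(2, 4, 64, 64)  # typical SD latent shape for 512x512 images
offset = 0.1 * torch.randn(latents.shape[0], latents.shape[1], 1, 1)
noise = torch.randn_like(latents) + offset  # same expression as the --offset_noise branch
print(noise.shape)                                    # torch.Size([2, 4, 64, 64])
print(noise.mean(dim=(2, 3)) - offset[:, :, 0, 0])    # per-channel means sit close to the sampled offsets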
 
diffusers/examples/dreambooth/train_dreambooth_flax.py DELETED
@@ -1,709 +0,0 @@
1
- import argparse
2
- import hashlib
3
- import logging
4
- import math
5
- import os
6
- from pathlib import Path
7
- from typing import Optional
8
-
9
- import jax
10
- import jax.numpy as jnp
11
- import numpy as np
12
- import optax
13
- import torch
14
- import torch.utils.checkpoint
15
- import transformers
16
- from flax import jax_utils
17
- from flax.training import train_state
18
- from flax.training.common_utils import shard
19
- from huggingface_hub import HfFolder, Repository, create_repo, whoami
20
- from jax.experimental.compilation_cache import compilation_cache as cc
21
- from PIL import Image
22
- from torch.utils.data import Dataset
23
- from torchvision import transforms
24
- from tqdm.auto import tqdm
25
- from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
26
-
27
- from diffusers import (
28
- FlaxAutoencoderKL,
29
- FlaxDDPMScheduler,
30
- FlaxPNDMScheduler,
31
- FlaxStableDiffusionPipeline,
32
- FlaxUNet2DConditionModel,
33
- )
34
- from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
35
- from diffusers.utils import check_min_version
36
-
37
-
38
- # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
39
- check_min_version("0.15.0.dev0")
40
-
41
- # Cache compiled models across invocations of this script.
42
- cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
43
-
44
- logger = logging.getLogger(__name__)
45
-
46
-
47
- def parse_args():
48
- parser = argparse.ArgumentParser(description="Simple example of a training script.")
49
- parser.add_argument(
50
- "--pretrained_model_name_or_path",
51
- type=str,
52
- default=None,
53
- required=True,
54
- help="Path to pretrained model or model identifier from huggingface.co/models.",
55
- )
56
- parser.add_argument(
57
- "--pretrained_vae_name_or_path",
58
- type=str,
59
- default=None,
60
- help="Path to pretrained vae or vae identifier from huggingface.co/models.",
61
- )
62
- parser.add_argument(
63
- "--revision",
64
- type=str,
65
- default=None,
66
- required=False,
67
- help="Revision of pretrained model identifier from huggingface.co/models.",
68
- )
69
- parser.add_argument(
70
- "--tokenizer_name",
71
- type=str,
72
- default=None,
73
- help="Pretrained tokenizer name or path if not the same as model_name",
74
- )
75
- parser.add_argument(
76
- "--instance_data_dir",
77
- type=str,
78
- default=None,
79
- required=True,
80
- help="A folder containing the training data of instance images.",
81
- )
82
- parser.add_argument(
83
- "--class_data_dir",
84
- type=str,
85
- default=None,
86
- required=False,
87
- help="A folder containing the training data of class images.",
88
- )
89
- parser.add_argument(
90
- "--instance_prompt",
91
- type=str,
92
- default=None,
93
- help="The prompt with identifier specifying the instance",
94
- )
95
- parser.add_argument(
96
- "--class_prompt",
97
- type=str,
98
- default=None,
99
- help="The prompt to specify images in the same class as provided instance images.",
100
- )
101
- parser.add_argument(
102
- "--with_prior_preservation",
103
- default=False,
104
- action="store_true",
105
- help="Flag to add prior preservation loss.",
106
- )
107
- parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
108
- parser.add_argument(
109
- "--num_class_images",
110
- type=int,
111
- default=100,
112
- help=(
113
- "Minimal class images for prior preservation loss. If there are not enough images already present in"
114
- " class_data_dir, additional images will be sampled with class_prompt."
115
- ),
116
- )
117
- parser.add_argument(
118
- "--output_dir",
119
- type=str,
120
- default="text-inversion-model",
121
- help="The output directory where the model predictions and checkpoints will be written.",
122
- )
123
- parser.add_argument("--save_steps", type=int, default=None, help="Save a checkpoint every X steps.")
124
- parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
125
- parser.add_argument(
126
- "--resolution",
127
- type=int,
128
- default=512,
129
- help=(
130
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
131
- " resolution"
132
- ),
133
- )
134
- parser.add_argument(
135
- "--center_crop",
136
- default=False,
137
- action="store_true",
138
- help=(
139
- "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
140
- " cropped. The images will be resized to the resolution first before cropping."
141
- ),
142
- )
143
- parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
144
- parser.add_argument(
145
- "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
146
- )
147
- parser.add_argument(
148
- "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
149
- )
150
- parser.add_argument("--num_train_epochs", type=int, default=1)
151
- parser.add_argument(
152
- "--max_train_steps",
153
- type=int,
154
- default=None,
155
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
156
- )
157
- parser.add_argument(
158
- "--learning_rate",
159
- type=float,
160
- default=5e-6,
161
- help="Initial learning rate (after the potential warmup period) to use.",
162
- )
163
- parser.add_argument(
164
- "--scale_lr",
165
- action="store_true",
166
- default=False,
167
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
168
- )
169
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
170
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
171
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
172
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
173
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
174
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
175
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
176
- parser.add_argument(
177
- "--hub_model_id",
178
- type=str,
179
- default=None,
180
- help="The name of the repository to keep in sync with the local `output_dir`.",
181
- )
182
- parser.add_argument(
183
- "--logging_dir",
184
- type=str,
185
- default="logs",
186
- help=(
187
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
188
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
189
- ),
190
- )
191
- parser.add_argument(
192
- "--mixed_precision",
193
- type=str,
194
- default="no",
195
- choices=["no", "fp16", "bf16"],
196
- help=(
197
- "Whether to use mixed precision. Choose"
198
- "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
199
- "and an Nvidia Ampere GPU."
200
- ),
201
- )
202
- parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
203
-
204
- args = parser.parse_args()
205
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
206
- if env_local_rank != -1 and env_local_rank != args.local_rank:
207
- args.local_rank = env_local_rank
208
-
209
- if args.instance_data_dir is None:
210
- raise ValueError("You must specify a train data directory.")
211
-
212
- if args.with_prior_preservation:
213
- if args.class_data_dir is None:
214
- raise ValueError("You must specify a data directory for class images.")
215
- if args.class_prompt is None:
216
- raise ValueError("You must specify prompt for class images.")
217
-
218
- return args
219
-
220
-
221
- class DreamBoothDataset(Dataset):
222
- """
223
- A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
224
- It pre-processes the images and tokenizes the prompts.
225
- """
226
-
227
- def __init__(
228
- self,
229
- instance_data_root,
230
- instance_prompt,
231
- tokenizer,
232
- class_data_root=None,
233
- class_prompt=None,
234
- class_num=None,
235
- size=512,
236
- center_crop=False,
237
- ):
238
- self.size = size
239
- self.center_crop = center_crop
240
- self.tokenizer = tokenizer
241
-
242
- self.instance_data_root = Path(instance_data_root)
243
- if not self.instance_data_root.exists():
244
- raise ValueError("Instance images root doesn't exists.")
245
-
246
- self.instance_images_path = list(Path(instance_data_root).iterdir())
247
- self.num_instance_images = len(self.instance_images_path)
248
- self.instance_prompt = instance_prompt
249
- self._length = self.num_instance_images
250
-
251
- if class_data_root is not None:
252
- self.class_data_root = Path(class_data_root)
253
- self.class_data_root.mkdir(parents=True, exist_ok=True)
254
- self.class_images_path = list(self.class_data_root.iterdir())
255
- if class_num is not None:
256
- self.num_class_images = min(len(self.class_images_path), class_num)
257
- else:
258
- self.num_class_images = len(self.class_images_path)
259
- self._length = max(self.num_class_images, self.num_instance_images)
260
- self.class_prompt = class_prompt
261
- else:
262
- self.class_data_root = None
263
-
264
- self.image_transforms = transforms.Compose(
265
- [
266
- transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
267
- transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
268
- transforms.ToTensor(),
269
- transforms.Normalize([0.5], [0.5]),
270
- ]
271
- )
272
-
273
- def __len__(self):
274
- return self._length
275
-
276
- def __getitem__(self, index):
277
- example = {}
278
- instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
279
- if not instance_image.mode == "RGB":
280
- instance_image = instance_image.convert("RGB")
281
- example["instance_images"] = self.image_transforms(instance_image)
282
- example["instance_prompt_ids"] = self.tokenizer(
283
- self.instance_prompt,
284
- padding="do_not_pad",
285
- truncation=True,
286
- max_length=self.tokenizer.model_max_length,
287
- ).input_ids
288
-
289
- if self.class_data_root:
290
- class_image = Image.open(self.class_images_path[index % self.num_class_images])
291
- if not class_image.mode == "RGB":
292
- class_image = class_image.convert("RGB")
293
- example["class_images"] = self.image_transforms(class_image)
294
- example["class_prompt_ids"] = self.tokenizer(
295
- self.class_prompt,
296
- padding="do_not_pad",
297
- truncation=True,
298
- max_length=self.tokenizer.model_max_length,
299
- ).input_ids
300
-
301
- return example
302
-
303
-
304
- class PromptDataset(Dataset):
305
- "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
306
-
307
- def __init__(self, prompt, num_samples):
308
- self.prompt = prompt
309
- self.num_samples = num_samples
310
-
311
- def __len__(self):
312
- return self.num_samples
313
-
314
- def __getitem__(self, index):
315
- example = {}
316
- example["prompt"] = self.prompt
317
- example["index"] = index
318
- return example
319
-
320
-
321
- def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
322
- if token is None:
323
- token = HfFolder.get_token()
324
- if organization is None:
325
- username = whoami(token)["name"]
326
- return f"{username}/{model_id}"
327
- else:
328
- return f"{organization}/{model_id}"
329
-
330
-
331
- def get_params_to_save(params):
332
- return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
333
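get_params_to_save undoes flax.jax_utils.replicate: replication adds a leading device axis to every parameter leaf, and this helper takes index 0 of that axis and copies the result back to host memory before save_pretrained is called. A minimal sketch:

import jax
import jax.numpy as jnp
from flax import jax_utils

params = {"w": jnp.ones((3, 3))}
replicated = jax_utils.replicate(params)   # each leaf gains a leading axis of size jax.local_device_count()
restored = get_params_to_save(replicated)  # leading axis dropped, arrays moved back to host
print(jax.tree_util.tree_map(lambda x: x.shape, replicated))
print(jax.tree_util.tree_map(lambda x: x.shape, restored))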
-
334
-
335
- def main():
336
- args = parse_args()
337
-
338
- logging.basicConfig(
339
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
340
- datefmt="%m/%d/%Y %H:%M:%S",
341
- level=logging.INFO,
342
- )
343
- # Set up logging; we only want one process per machine to log things on the screen.
344
- logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
345
- if jax.process_index() == 0:
346
- transformers.utils.logging.set_verbosity_info()
347
- else:
348
- transformers.utils.logging.set_verbosity_error()
349
-
350
- if args.seed is not None:
351
- set_seed(args.seed)
352
-
353
- rng = jax.random.PRNGKey(args.seed)
354
-
355
- if args.with_prior_preservation:
356
- class_images_dir = Path(args.class_data_dir)
357
- if not class_images_dir.exists():
358
- class_images_dir.mkdir(parents=True)
359
- cur_class_images = len(list(class_images_dir.iterdir()))
360
-
361
- if cur_class_images < args.num_class_images:
362
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
363
- args.pretrained_model_name_or_path, safety_checker=None, revision=args.revision
364
- )
365
- pipeline.set_progress_bar_config(disable=True)
366
-
367
- num_new_images = args.num_class_images - cur_class_images
368
- logger.info(f"Number of class images to sample: {num_new_images}.")
369
-
370
- sample_dataset = PromptDataset(args.class_prompt, num_new_images)
371
- total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
372
- sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)
373
-
374
- for example in tqdm(
375
- sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
376
- ):
377
- prompt_ids = pipeline.prepare_inputs(example["prompt"])
378
- prompt_ids = shard(prompt_ids)
379
- p_params = jax_utils.replicate(params)
380
- rng = jax.random.split(rng)[0]
381
- sample_rng = jax.random.split(rng, jax.device_count())
382
- images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images
383
- images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
384
- images = pipeline.numpy_to_pil(np.array(images))
385
-
386
- for i, image in enumerate(images):
387
- hash_image = hashlib.sha1(image.tobytes()).hexdigest()
388
- image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
389
- image.save(image_filename)
390
-
391
- del pipeline
392
-
393
- # Handle the repository creation
394
- if jax.process_index() == 0:
395
- if args.push_to_hub:
396
- if args.hub_model_id is None:
397
- repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
398
- else:
399
- repo_name = args.hub_model_id
400
- create_repo(repo_name, exist_ok=True, token=args.hub_token)
401
- repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
402
-
403
- with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
404
- if "step_*" not in gitignore:
405
- gitignore.write("step_*\n")
406
- if "epoch_*" not in gitignore:
407
- gitignore.write("epoch_*\n")
408
- elif args.output_dir is not None:
409
- os.makedirs(args.output_dir, exist_ok=True)
410
-
411
- # Load the tokenizer and add the placeholder token as an additional special token
412
- if args.tokenizer_name:
413
- tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
414
- elif args.pretrained_model_name_or_path:
415
- tokenizer = CLIPTokenizer.from_pretrained(
416
- args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
417
- )
418
- else:
419
- raise NotImplementedError("No tokenizer specified!")
420
-
421
- train_dataset = DreamBoothDataset(
422
- instance_data_root=args.instance_data_dir,
423
- instance_prompt=args.instance_prompt,
424
- class_data_root=args.class_data_dir if args.with_prior_preservation else None,
425
- class_prompt=args.class_prompt,
426
- class_num=args.num_class_images,
427
- tokenizer=tokenizer,
428
- size=args.resolution,
429
- center_crop=args.center_crop,
430
- )
431
-
432
- def collate_fn(examples):
433
- input_ids = [example["instance_prompt_ids"] for example in examples]
434
- pixel_values = [example["instance_images"] for example in examples]
435
-
436
- # Concat class and instance examples for prior preservation.
437
- # We do this to avoid doing two forward passes.
438
- if args.with_prior_preservation:
439
- input_ids += [example["class_prompt_ids"] for example in examples]
440
- pixel_values += [example["class_images"] for example in examples]
441
-
442
- pixel_values = torch.stack(pixel_values)
443
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
444
-
445
- input_ids = tokenizer.pad(
446
- {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
447
- ).input_ids
448
-
449
- batch = {
450
- "input_ids": input_ids,
451
- "pixel_values": pixel_values,
452
- }
453
- batch = {k: v.numpy() for k, v in batch.items()}
454
- return batch
455
-
456
- total_train_batch_size = args.train_batch_size * jax.local_device_count()
457
- if len(train_dataset) < total_train_batch_size:
458
- raise ValueError(
459
- f"Training batch size is {total_train_batch_size}, but your dataset only contains"
460
- f" {len(train_dataset)} images. Please, use a larger dataset or reduce the effective batch size. Note that"
461
- f" there are {jax.local_device_count()} parallel devices, so your batch size can't be smaller than that."
462
- )
463
-
464
- train_dataloader = torch.utils.data.DataLoader(
465
- train_dataset, batch_size=total_train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True
466
- )
467
-
468
- weight_dtype = jnp.float32
469
- if args.mixed_precision == "fp16":
470
- weight_dtype = jnp.float16
471
- elif args.mixed_precision == "bf16":
472
- weight_dtype = jnp.bfloat16
473
-
474
- if args.pretrained_vae_name_or_path:
475
- # TODO(patil-suraj): Upload flax weights for the VAE
476
- vae_arg, vae_kwargs = (args.pretrained_vae_name_or_path, {"from_pt": True})
477
- else:
478
- vae_arg, vae_kwargs = (args.pretrained_model_name_or_path, {"subfolder": "vae", "revision": args.revision})
479
-
480
- # Load models and create wrapper for stable diffusion
481
- text_encoder = FlaxCLIPTextModel.from_pretrained(
482
- args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
483
- )
484
- vae, vae_params = FlaxAutoencoderKL.from_pretrained(
485
- vae_arg,
486
- dtype=weight_dtype,
487
- **vae_kwargs,
488
- )
489
- unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
490
- args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
491
- )
492
-
493
- # Optimization
494
- if args.scale_lr:
495
- args.learning_rate = args.learning_rate * total_train_batch_size
496
-
497
- constant_scheduler = optax.constant_schedule(args.learning_rate)
498
-
499
- adamw = optax.adamw(
500
- learning_rate=constant_scheduler,
501
- b1=args.adam_beta1,
502
- b2=args.adam_beta2,
503
- eps=args.adam_epsilon,
504
- weight_decay=args.adam_weight_decay,
505
- )
506
-
507
- optimizer = optax.chain(
508
- optax.clip_by_global_norm(args.max_grad_norm),
509
- adamw,
510
- )
511
-
512
- unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
513
- text_encoder_state = train_state.TrainState.create(
514
- apply_fn=text_encoder.__call__, params=text_encoder.params, tx=optimizer
515
- )
516
-
517
- noise_scheduler = FlaxDDPMScheduler(
518
- beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
519
- )
520
- noise_scheduler_state = noise_scheduler.create_state()
521
-
522
- # Initialize our training
523
- train_rngs = jax.random.split(rng, jax.local_device_count())
524
-
525
- def train_step(unet_state, text_encoder_state, vae_params, batch, train_rng):
526
- dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
527
-
528
- if args.train_text_encoder:
529
- params = {"text_encoder": text_encoder_state.params, "unet": unet_state.params}
530
- else:
531
- params = {"unet": unet_state.params}
532
-
533
- def compute_loss(params):
534
- # Convert images to latent space
535
- vae_outputs = vae.apply(
536
- {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
537
- )
538
- latents = vae_outputs.latent_dist.sample(sample_rng)
539
- # (NHWC) -> (NCHW)
540
- latents = jnp.transpose(latents, (0, 3, 1, 2))
541
- latents = latents * vae.config.scaling_factor
542
-
543
- # Sample noise that we'll add to the latents
544
- noise_rng, timestep_rng = jax.random.split(sample_rng)
545
- noise = jax.random.normal(noise_rng, latents.shape)
546
- # Sample a random timestep for each image
547
- bsz = latents.shape[0]
548
- timesteps = jax.random.randint(
549
- timestep_rng,
550
- (bsz,),
551
- 0,
552
- noise_scheduler.config.num_train_timesteps,
553
- )
554
-
555
- # Add noise to the latents according to the noise magnitude at each timestep
556
- # (this is the forward diffusion process)
557
- noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
558
-
559
- # Get the text embedding for conditioning
560
- if args.train_text_encoder:
561
- encoder_hidden_states = text_encoder_state.apply_fn(
562
- batch["input_ids"], params=params["text_encoder"], dropout_rng=dropout_rng, train=True
563
- )[0]
564
- else:
565
- encoder_hidden_states = text_encoder(
566
- batch["input_ids"], params=text_encoder_state.params, train=False
567
- )[0]
568
-
569
- # Predict the noise residual
570
- model_pred = unet.apply(
571
- {"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True
572
- ).sample
573
-
574
- # Get the target for loss depending on the prediction type
575
- if noise_scheduler.config.prediction_type == "epsilon":
576
- target = noise
577
- elif noise_scheduler.config.prediction_type == "v_prediction":
578
- target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
579
- else:
580
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
581
-
582
- if args.with_prior_preservation:
583
- # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
584
- model_pred, model_pred_prior = jnp.split(model_pred, 2, axis=0)
585
- target, target_prior = jnp.split(target, 2, axis=0)
586
-
587
- # Compute instance loss
588
- loss = (target - model_pred) ** 2
589
- loss = loss.mean()
590
-
591
- # Compute prior loss
592
- prior_loss = (target_prior - model_pred_prior) ** 2
593
- prior_loss = prior_loss.mean()
594
-
595
- # Add the prior loss to the instance loss.
596
- loss = loss + args.prior_loss_weight * prior_loss
597
- else:
598
- loss = (target - model_pred) ** 2
599
- loss = loss.mean()
600
-
601
- return loss
602
-
603
- grad_fn = jax.value_and_grad(compute_loss)
604
- loss, grad = grad_fn(params)
605
- grad = jax.lax.pmean(grad, "batch")
606
-
607
- new_unet_state = unet_state.apply_gradients(grads=grad["unet"])
608
- if args.train_text_encoder:
609
- new_text_encoder_state = text_encoder_state.apply_gradients(grads=grad["text_encoder"])
610
- else:
611
- new_text_encoder_state = text_encoder_state
612
-
613
- metrics = {"loss": loss}
614
- metrics = jax.lax.pmean(metrics, axis_name="batch")
615
-
616
- return new_unet_state, new_text_encoder_state, metrics, new_train_rng
617
-
618
- # Create parallel version of the train step
619
- p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))
620
-
621
- # Replicate the train state on each device
622
- unet_state = jax_utils.replicate(unet_state)
623
- text_encoder_state = jax_utils.replicate(text_encoder_state)
624
- vae_params = jax_utils.replicate(vae_params)
625
-
626
- # Train!
627
- num_update_steps_per_epoch = math.ceil(len(train_dataloader))
628
-
629
- # Scheduler and math around the number of training steps.
630
- if args.max_train_steps is None:
631
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
632
-
633
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
634
-
635
- logger.info("***** Running training *****")
636
- logger.info(f" Num examples = {len(train_dataset)}")
637
- logger.info(f" Num Epochs = {args.num_train_epochs}")
638
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
639
- logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
640
- logger.info(f" Total optimization steps = {args.max_train_steps}")
641
-
642
- def checkpoint(step=None):
643
- # Create the pipeline using the trained modules and save it.
644
- scheduler, _ = FlaxPNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
645
- safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
646
- "CompVis/stable-diffusion-safety-checker", from_pt=True
647
- )
648
- pipeline = FlaxStableDiffusionPipeline(
649
- text_encoder=text_encoder,
650
- vae=vae,
651
- unet=unet,
652
- tokenizer=tokenizer,
653
- scheduler=scheduler,
654
- safety_checker=safety_checker,
655
- feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
656
- )
657
-
658
- outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir
659
- pipeline.save_pretrained(
660
- outdir,
661
- params={
662
- "text_encoder": get_params_to_save(text_encoder_state.params),
663
- "vae": get_params_to_save(vae_params),
664
- "unet": get_params_to_save(unet_state.params),
665
- "safety_checker": safety_checker.params,
666
- },
667
- )
668
-
669
- if args.push_to_hub:
670
- message = f"checkpoint-{step}" if step is not None else "End of training"
671
- repo.push_to_hub(commit_message=message, blocking=False, auto_lfs_prune=True)
672
-
673
- global_step = 0
674
-
675
- epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
676
- for epoch in epochs:
677
- # ======================== Training ================================
678
-
679
- train_metrics = []
680
-
681
- steps_per_epoch = len(train_dataset) // total_train_batch_size
682
- train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
683
- # train
684
- for batch in train_dataloader:
685
- batch = shard(batch)
686
- unet_state, text_encoder_state, train_metric, train_rngs = p_train_step(
687
- unet_state, text_encoder_state, vae_params, batch, train_rngs
688
- )
689
- train_metrics.append(train_metric)
690
-
691
- train_step_progress_bar.update(jax.local_device_count())
692
-
693
- global_step += 1
694
- if jax.process_index() == 0 and args.save_steps and global_step % args.save_steps == 0:
695
- checkpoint(global_step)
696
- if global_step >= args.max_train_steps:
697
- break
698
-
699
- train_metric = jax_utils.unreplicate(train_metric)
700
-
701
- train_step_progress_bar.close()
702
- epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
703
-
704
- if jax.process_index() == 0:
705
- checkpoint()
706
-
707
-
708
- if __name__ == "__main__":
709
- main()
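Each collated batch above has a leading axis of train_batch_size * jax.local_device_count(); shard folds that axis into (num_devices, per_device_batch, ...), which is the layout p_train_step (a jax.pmap) expects. A hedged sketch of just that reshape:

import jax
import numpy as np
from flax.training.common_utils import shard

per_device_batch = 4
n_devices = jax.local_device_count()
batch = {"pixel_values": np.zeros((per_device_batch * n_devices, 3, 512, 512), dtype=np.float32)}
sharded = shard(batch)
print(sharded["pixel_values"].shape)  # (n_devices, 4, 3, 512, 512)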
 
diffusers/examples/dreambooth/train_dreambooth_lora.py DELETED
@@ -1,1028 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- import argparse
17
- import hashlib
18
- import logging
19
- import math
20
- import os
21
- import warnings
22
- from pathlib import Path
23
-
24
- import numpy as np
25
- import torch
26
- import torch.nn.functional as F
27
- import torch.utils.checkpoint
28
- import transformers
29
- from accelerate import Accelerator
30
- from accelerate.logging import get_logger
31
- from accelerate.utils import ProjectConfiguration, set_seed
32
- from huggingface_hub import create_repo, upload_folder
33
- from packaging import version
34
- from PIL import Image
35
- from torch.utils.data import Dataset
36
- from torchvision import transforms
37
- from tqdm.auto import tqdm
38
- from transformers import AutoTokenizer, PretrainedConfig
39
-
40
- import diffusers
41
- from diffusers import (
42
- AutoencoderKL,
43
- DDPMScheduler,
44
- DiffusionPipeline,
45
- DPMSolverMultistepScheduler,
46
- UNet2DConditionModel,
47
- )
48
- from diffusers.loaders import AttnProcsLayers
49
- from diffusers.models.attention_processor import LoRAAttnProcessor
50
- from diffusers.optimization import get_scheduler
51
- from diffusers.utils import check_min_version, is_wandb_available
52
- from diffusers.utils.import_utils import is_xformers_available
53
-
54
-
55
- # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
56
- check_min_version("0.15.0.dev0")
57
-
58
- logger = get_logger(__name__)
59
-
60
-
61
- def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None):
62
- img_str = ""
63
- for i, image in enumerate(images):
64
- image.save(os.path.join(repo_folder, f"image_{i}.png"))
65
- img_str += f"![img_{i}](./image_{i}.png)\n"
66
-
67
- yaml = f"""
68
- ---
69
- license: creativeml-openrail-m
70
- base_model: {base_model}
71
- instance_prompt: {prompt}
72
- tags:
73
- - stable-diffusion
74
- - stable-diffusion-diffusers
75
- - text-to-image
76
- - diffusers
77
- - lora
78
- inference: true
79
- ---
80
- """
81
- model_card = f"""
82
- # LoRA DreamBooth - {repo_id}
83
-
84
- These are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images below. \n
85
- {img_str}
86
- """
87
- with open(os.path.join(repo_folder, "README.md"), "w") as f:
88
- f.write(yaml + model_card)
89
-
90
-
91
- def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
92
- text_encoder_config = PretrainedConfig.from_pretrained(
93
- pretrained_model_name_or_path,
94
- subfolder="text_encoder",
95
- revision=revision,
96
- )
97
- model_class = text_encoder_config.architectures[0]
98
-
99
- if model_class == "CLIPTextModel":
100
- from transformers import CLIPTextModel
101
-
102
- return CLIPTextModel
103
- elif model_class == "RobertaSeriesModelWithTransformation":
104
- from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
105
-
106
- return RobertaSeriesModelWithTransformation
107
- else:
108
- raise ValueError(f"{model_class} is not supported.")
109
-
110
-
111
- def parse_args(input_args=None):
112
- parser = argparse.ArgumentParser(description="Simple example of a training script.")
113
- parser.add_argument(
114
- "--pretrained_model_name_or_path",
115
- type=str,
116
- default=None,
117
- required=True,
118
- help="Path to pretrained model or model identifier from huggingface.co/models.",
119
- )
120
- parser.add_argument(
121
- "--revision",
122
- type=str,
123
- default=None,
124
- required=False,
125
- help="Revision of pretrained model identifier from huggingface.co/models.",
126
- )
127
- parser.add_argument(
128
- "--tokenizer_name",
129
- type=str,
130
- default=None,
131
- help="Pretrained tokenizer name or path if not the same as model_name",
132
- )
133
- parser.add_argument(
134
- "--instance_data_dir",
135
- type=str,
136
- default=None,
137
- required=True,
138
- help="A folder containing the training data of instance images.",
139
- )
140
- parser.add_argument(
141
- "--class_data_dir",
142
- type=str,
143
- default=None,
144
- required=False,
145
- help="A folder containing the training data of class images.",
146
- )
147
- parser.add_argument(
148
- "--instance_prompt",
149
- type=str,
150
- default=None,
151
- required=True,
152
- help="The prompt with identifier specifying the instance",
153
- )
154
- parser.add_argument(
155
- "--class_prompt",
156
- type=str,
157
- default=None,
158
- help="The prompt to specify images in the same class as provided instance images.",
159
- )
160
- parser.add_argument(
161
- "--validation_prompt",
162
- type=str,
163
- default=None,
164
- help="A prompt that is used during validation to verify that the model is learning.",
165
- )
166
- parser.add_argument(
167
- "--num_validation_images",
168
- type=int,
169
- default=4,
170
- help="Number of images that should be generated during validation with `validation_prompt`.",
171
- )
172
- parser.add_argument(
173
- "--validation_epochs",
174
- type=int,
175
- default=50,
176
- help=(
177
- "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
178
- " `args.validation_prompt` multiple times: `args.num_validation_images`."
179
- ),
180
- )
181
- parser.add_argument(
182
- "--with_prior_preservation",
183
- default=False,
184
- action="store_true",
185
- help="Flag to add prior preservation loss.",
186
- )
187
- parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
188
- parser.add_argument(
189
- "--num_class_images",
190
- type=int,
191
- default=100,
192
- help=(
193
- "Minimal class images for prior preservation loss. If there are not enough images already present in"
194
- " class_data_dir, additional images will be sampled with class_prompt."
195
- ),
196
- )
197
- parser.add_argument(
198
- "--output_dir",
199
- type=str,
200
- default="lora-dreambooth-model",
201
- help="The output directory where the model predictions and checkpoints will be written.",
202
- )
203
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
204
- parser.add_argument(
205
- "--resolution",
206
- type=int,
207
- default=512,
208
- help=(
209
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
210
- " resolution"
211
- ),
212
- )
213
- parser.add_argument(
214
- "--center_crop",
215
- default=False,
216
- action="store_true",
217
- help=(
218
- "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
219
- " cropped. The images will be resized to the resolution first before cropping."
220
- ),
221
- )
222
- parser.add_argument(
223
- "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
224
- )
225
- parser.add_argument(
226
- "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
227
- )
228
- parser.add_argument("--num_train_epochs", type=int, default=1)
229
- parser.add_argument(
230
- "--max_train_steps",
231
- type=int,
232
- default=None,
233
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
234
- )
235
- parser.add_argument(
236
- "--checkpointing_steps",
237
- type=int,
238
- default=500,
239
- help=(
240
- "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
241
- " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
242
- " training using `--resume_from_checkpoint`."
243
- ),
244
- )
245
- parser.add_argument(
246
- "--checkpoints_total_limit",
247
- type=int,
248
- default=None,
249
- help=(
250
- "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
251
- " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
252
- " for more docs"
253
- ),
254
- )
255
- parser.add_argument(
256
- "--resume_from_checkpoint",
257
- type=str,
258
- default=None,
259
- help=(
260
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
261
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
262
- ),
263
- )
264
- parser.add_argument(
265
- "--gradient_accumulation_steps",
266
- type=int,
267
- default=1,
268
- help="Number of update steps to accumulate before performing a backward/update pass.",
269
- )
270
- parser.add_argument(
271
- "--gradient_checkpointing",
272
- action="store_true",
273
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
274
- )
275
- parser.add_argument(
276
- "--learning_rate",
277
- type=float,
278
- default=5e-4,
279
- help="Initial learning rate (after the potential warmup period) to use.",
280
- )
281
- parser.add_argument(
282
- "--scale_lr",
283
- action="store_true",
284
- default=False,
285
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
286
- )
287
- parser.add_argument(
288
- "--lr_scheduler",
289
- type=str,
290
- default="constant",
291
- help=(
292
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
293
- ' "constant", "constant_with_warmup"]'
294
- ),
295
- )
296
- parser.add_argument(
297
- "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
298
- )
299
- parser.add_argument(
300
- "--lr_num_cycles",
301
- type=int,
302
- default=1,
303
- help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
304
- )
305
- parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
306
- parser.add_argument(
307
- "--dataloader_num_workers",
308
- type=int,
309
- default=0,
310
- help=(
311
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
312
- ),
313
- )
314
- parser.add_argument(
315
- "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
316
- )
317
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
318
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
319
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
320
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
321
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
322
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
323
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
324
- parser.add_argument(
325
- "--hub_model_id",
326
- type=str,
327
- default=None,
328
- help="The name of the repository to keep in sync with the local `output_dir`.",
329
- )
330
- parser.add_argument(
331
- "--logging_dir",
332
- type=str,
333
- default="logs",
334
- help=(
335
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
336
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
337
- ),
338
- )
339
- parser.add_argument(
340
- "--allow_tf32",
341
- action="store_true",
342
- help=(
343
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
344
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
345
- ),
346
- )
347
- parser.add_argument(
348
- "--report_to",
349
- type=str,
350
- default="tensorboard",
351
- help=(
352
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
353
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
354
- ),
355
- )
356
- parser.add_argument(
357
- "--mixed_precision",
358
- type=str,
359
- default=None,
360
- choices=["no", "fp16", "bf16"],
361
- help=(
362
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
363
- " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
364
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
365
- ),
366
- )
367
- parser.add_argument(
368
- "--prior_generation_precision",
369
- type=str,
370
- default=None,
371
- choices=["no", "fp32", "fp16", "bf16"],
372
- help=(
373
- "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
374
- " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
375
- ),
376
- )
377
- parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
378
- parser.add_argument(
379
- "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
380
- )
381
-
382
- if input_args is not None:
383
- args = parser.parse_args(input_args)
384
- else:
385
- args = parser.parse_args()
386
-
387
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
388
- if env_local_rank != -1 and env_local_rank != args.local_rank:
389
- args.local_rank = env_local_rank
390
-
391
- if args.with_prior_preservation:
392
- if args.class_data_dir is None:
393
- raise ValueError("You must specify a data directory for class images.")
394
- if args.class_prompt is None:
395
- raise ValueError("You must specify prompt for class images.")
396
- else:
397
- # logger is not available yet
398
- if args.class_data_dir is not None:
399
- warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
400
- if args.class_prompt is not None:
401
- warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
402
-
403
- return args
404
-
405
-
406
- class DreamBoothDataset(Dataset):
407
- """
408
- A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
409
- It pre-processes the images and tokenizes the prompts.
410
- """
411
-
412
- def __init__(
413
- self,
414
- instance_data_root,
415
- instance_prompt,
416
- tokenizer,
417
- class_data_root=None,
418
- class_prompt=None,
419
- class_num=None,
420
- size=512,
421
- center_crop=False,
422
- ):
423
- self.size = size
424
- self.center_crop = center_crop
425
- self.tokenizer = tokenizer
426
-
427
- self.instance_data_root = Path(instance_data_root)
428
- if not self.instance_data_root.exists():
429
- raise ValueError("Instance images root doesn't exist.")
430
-
431
- self.instance_images_path = list(Path(instance_data_root).iterdir())
432
- self.num_instance_images = len(self.instance_images_path)
433
- self.instance_prompt = instance_prompt
434
- self._length = self.num_instance_images
435
-
436
- if class_data_root is not None:
437
- self.class_data_root = Path(class_data_root)
438
- self.class_data_root.mkdir(parents=True, exist_ok=True)
439
- self.class_images_path = list(self.class_data_root.iterdir())
440
- if class_num is not None:
441
- self.num_class_images = min(len(self.class_images_path), class_num)
442
- else:
443
- self.num_class_images = len(self.class_images_path)
444
- self._length = max(self.num_class_images, self.num_instance_images)
445
- self.class_prompt = class_prompt
446
- else:
447
- self.class_data_root = None
448
-
449
- self.image_transforms = transforms.Compose(
450
- [
451
- transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
452
- transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
453
- transforms.ToTensor(),
454
- transforms.Normalize([0.5], [0.5]),
455
- ]
456
- )
457
-
458
- def __len__(self):
459
- return self._length
460
-
461
- def __getitem__(self, index):
462
- example = {}
463
- instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
464
- if not instance_image.mode == "RGB":
465
- instance_image = instance_image.convert("RGB")
466
- example["instance_images"] = self.image_transforms(instance_image)
467
- example["instance_prompt_ids"] = self.tokenizer(
468
- self.instance_prompt,
469
- truncation=True,
470
- padding="max_length",
471
- max_length=self.tokenizer.model_max_length,
472
- return_tensors="pt",
473
- ).input_ids
474
-
475
- if self.class_data_root:
476
- class_image = Image.open(self.class_images_path[index % self.num_class_images])
477
- if not class_image.mode == "RGB":
478
- class_image = class_image.convert("RGB")
479
- example["class_images"] = self.image_transforms(class_image)
480
- example["class_prompt_ids"] = self.tokenizer(
481
- self.class_prompt,
482
- truncation=True,
483
- padding="max_length",
484
- max_length=self.tokenizer.model_max_length,
485
- return_tensors="pt",
486
- ).input_ids
487
-
488
- return example
489
-
490
-
491
- def collate_fn(examples, with_prior_preservation=False):
492
- input_ids = [example["instance_prompt_ids"] for example in examples]
493
- pixel_values = [example["instance_images"] for example in examples]
494
-
495
- # Concat class and instance examples for prior preservation.
496
- # We do this to avoid doing two forward passes.
497
- if with_prior_preservation:
498
- input_ids += [example["class_prompt_ids"] for example in examples]
499
- pixel_values += [example["class_images"] for example in examples]
500
-
501
- pixel_values = torch.stack(pixel_values)
502
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
503
-
504
- input_ids = torch.cat(input_ids, dim=0)
505
-
506
- batch = {
507
- "input_ids": input_ids,
508
- "pixel_values": pixel_values,
509
- }
510
- return batch
511
-
512
-
513
- class PromptDataset(Dataset):
514
- "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
515
-
516
- def __init__(self, prompt, num_samples):
517
- self.prompt = prompt
518
- self.num_samples = num_samples
519
-
520
- def __len__(self):
521
- return self.num_samples
522
-
523
- def __getitem__(self, index):
524
- example = {}
525
- example["prompt"] = self.prompt
526
- example["index"] = index
527
- return example
528
-
529
-
530
- def main(args):
531
- logging_dir = Path(args.output_dir, args.logging_dir)
532
-
533
- accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
534
-
535
- accelerator = Accelerator(
536
- gradient_accumulation_steps=args.gradient_accumulation_steps,
537
- mixed_precision=args.mixed_precision,
538
- log_with=args.report_to,
539
- logging_dir=logging_dir,
540
- project_config=accelerator_project_config,
541
- )
542
-
543
- if args.report_to == "wandb":
544
- if not is_wandb_available():
545
- raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
546
- import wandb
547
-
548
- # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
549
- # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
550
- # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
551
- # Make one log on every process with the configuration for debugging.
552
- logging.basicConfig(
553
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
554
- datefmt="%m/%d/%Y %H:%M:%S",
555
- level=logging.INFO,
556
- )
557
- logger.info(accelerator.state, main_process_only=False)
558
- if accelerator.is_local_main_process:
559
- transformers.utils.logging.set_verbosity_warning()
560
- diffusers.utils.logging.set_verbosity_info()
561
- else:
562
- transformers.utils.logging.set_verbosity_error()
563
- diffusers.utils.logging.set_verbosity_error()
564
-
565
- # If passed along, set the training seed now.
566
- if args.seed is not None:
567
- set_seed(args.seed)
568
-
569
- # Generate class images if prior preservation is enabled.
570
- if args.with_prior_preservation:
571
- class_images_dir = Path(args.class_data_dir)
572
- if not class_images_dir.exists():
573
- class_images_dir.mkdir(parents=True)
574
- cur_class_images = len(list(class_images_dir.iterdir()))
575
-
576
- if cur_class_images < args.num_class_images:
577
- torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
578
- if args.prior_generation_precision == "fp32":
579
- torch_dtype = torch.float32
580
- elif args.prior_generation_precision == "fp16":
581
- torch_dtype = torch.float16
582
- elif args.prior_generation_precision == "bf16":
583
- torch_dtype = torch.bfloat16
584
- pipeline = DiffusionPipeline.from_pretrained(
585
- args.pretrained_model_name_or_path,
586
- torch_dtype=torch_dtype,
587
- safety_checker=None,
588
- revision=args.revision,
589
- )
590
- pipeline.set_progress_bar_config(disable=True)
591
-
592
- num_new_images = args.num_class_images - cur_class_images
593
- logger.info(f"Number of class images to sample: {num_new_images}.")
594
-
595
- sample_dataset = PromptDataset(args.class_prompt, num_new_images)
596
- sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
597
-
598
- sample_dataloader = accelerator.prepare(sample_dataloader)
599
- pipeline.to(accelerator.device)
600
-
601
- for example in tqdm(
602
- sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
603
- ):
604
- images = pipeline(example["prompt"]).images
605
-
606
- for i, image in enumerate(images):
607
- hash_image = hashlib.sha1(image.tobytes()).hexdigest()
608
- image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
609
- image.save(image_filename)
610
-
611
- del pipeline
612
- if torch.cuda.is_available():
613
- torch.cuda.empty_cache()
614
-
615
- # Handle the repository creation
616
- if accelerator.is_main_process:
617
- if args.output_dir is not None:
618
- os.makedirs(args.output_dir, exist_ok=True)
619
-
620
- if args.push_to_hub:
621
- repo_id = create_repo(
622
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
623
- ).repo_id
624
-
625
- # Load the tokenizer
626
- if args.tokenizer_name:
627
- tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
628
- elif args.pretrained_model_name_or_path:
629
- tokenizer = AutoTokenizer.from_pretrained(
630
- args.pretrained_model_name_or_path,
631
- subfolder="tokenizer",
632
- revision=args.revision,
633
- use_fast=False,
634
- )
635
-
636
- # import correct text encoder class
637
- text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
638
-
639
- # Load scheduler and models
640
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
641
- text_encoder = text_encoder_cls.from_pretrained(
642
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
643
- )
644
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
645
- unet = UNet2DConditionModel.from_pretrained(
646
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
647
- )
648
-
649
- # We only train the additional adapter LoRA layers
650
- vae.requires_grad_(False)
651
- text_encoder.requires_grad_(False)
652
- unet.requires_grad_(False)
653
-
654
- # For mixed precision training we cast the text_encoder and vae weights to half-precision
655
- # as these models are only used for inference, keeping weights in full precision is not required.
656
- weight_dtype = torch.float32
657
- if accelerator.mixed_precision == "fp16":
658
- weight_dtype = torch.float16
659
- elif accelerator.mixed_precision == "bf16":
660
- weight_dtype = torch.bfloat16
661
-
662
- # Move unet, vae and text_encoder to device and cast to weight_dtype
663
- unet.to(accelerator.device, dtype=weight_dtype)
664
- vae.to(accelerator.device, dtype=weight_dtype)
665
- text_encoder.to(accelerator.device, dtype=weight_dtype)
666
-
667
- if args.enable_xformers_memory_efficient_attention:
668
- if is_xformers_available():
669
- import xformers
670
-
671
- xformers_version = version.parse(xformers.__version__)
672
- if xformers_version == version.parse("0.0.16"):
673
- logger.warn(
674
- "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
675
- )
676
- unet.enable_xformers_memory_efficient_attention()
677
- else:
678
- raise ValueError("xformers is not available. Make sure it is installed correctly")
679
-
680
- # now we will add new LoRA weights to the attention layers
681
- # It's important to realize here how many attention weights will be added and of which sizes
682
- # The sizes of the attention layers consist only of two different variables:
683
- # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
684
- # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
685
-
686
- # Let's first see how many attention processors we will have to set.
687
- # For Stable Diffusion, it should be equal to:
688
- # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
689
- # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
690
- # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
691
- # => 32 layers
692
-
693
- # Set correct lora layers
694
- lora_attn_procs = {}
695
- for name in unet.attn_processors.keys():
696
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
697
- if name.startswith("mid_block"):
698
- hidden_size = unet.config.block_out_channels[-1]
699
- elif name.startswith("up_blocks"):
700
- block_id = int(name[len("up_blocks.")])
701
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
702
- elif name.startswith("down_blocks"):
703
- block_id = int(name[len("down_blocks.")])
704
- hidden_size = unet.config.block_out_channels[block_id]
705
-
706
- lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
707
-
708
- unet.set_attn_processor(lora_attn_procs)
709
- lora_layers = AttnProcsLayers(unet.attn_processors)
710
-
711
- accelerator.register_for_checkpointing(lora_layers)
712
-
713
- if args.scale_lr:
714
- args.learning_rate = (
715
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
716
- )
717
-
718
- # Enable TF32 for faster training on Ampere GPUs,
719
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
720
- if args.allow_tf32:
721
- torch.backends.cuda.matmul.allow_tf32 = True
722
-
728
- # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
729
- if args.use_8bit_adam:
730
- try:
731
- import bitsandbytes as bnb
732
- except ImportError:
733
- raise ImportError(
734
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
735
- )
736
-
737
- optimizer_class = bnb.optim.AdamW8bit
738
- else:
739
- optimizer_class = torch.optim.AdamW
740
-
741
- # Optimizer creation
742
- optimizer = optimizer_class(
743
- lora_layers.parameters(),
744
- lr=args.learning_rate,
745
- betas=(args.adam_beta1, args.adam_beta2),
746
- weight_decay=args.adam_weight_decay,
747
- eps=args.adam_epsilon,
748
- )
749
-
750
- # Dataset and DataLoaders creation:
751
- train_dataset = DreamBoothDataset(
752
- instance_data_root=args.instance_data_dir,
753
- instance_prompt=args.instance_prompt,
754
- class_data_root=args.class_data_dir if args.with_prior_preservation else None,
755
- class_prompt=args.class_prompt,
756
- class_num=args.num_class_images,
757
- tokenizer=tokenizer,
758
- size=args.resolution,
759
- center_crop=args.center_crop,
760
- )
761
-
762
- train_dataloader = torch.utils.data.DataLoader(
763
- train_dataset,
764
- batch_size=args.train_batch_size,
765
- shuffle=True,
766
- collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
767
- num_workers=args.dataloader_num_workers,
768
- )
769
-
770
- # Scheduler and math around the number of training steps.
771
- overrode_max_train_steps = False
772
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
773
- if args.max_train_steps is None:
774
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
775
- overrode_max_train_steps = True
776
-
777
- lr_scheduler = get_scheduler(
778
- args.lr_scheduler,
779
- optimizer=optimizer,
780
- num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
781
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
782
- num_cycles=args.lr_num_cycles,
783
- power=args.lr_power,
784
- )
785
-
786
- # Prepare everything with our `accelerator`.
787
- lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
788
- lora_layers, optimizer, train_dataloader, lr_scheduler
789
- )
790
-
791
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
792
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
793
- if overrode_max_train_steps:
794
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
795
- # Afterwards we recalculate our number of training epochs
796
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
797
-
798
- # We need to initialize the trackers we use, and also store our configuration.
799
- # The trackers are initialized automatically on the main process.
800
- if accelerator.is_main_process:
801
- accelerator.init_trackers("dreambooth-lora", config=vars(args))
802
-
803
- # Train!
804
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
805
-
806
- logger.info("***** Running training *****")
807
- logger.info(f" Num examples = {len(train_dataset)}")
808
- logger.info(f" Num batches each epoch = {len(train_dataloader)}")
809
- logger.info(f" Num Epochs = {args.num_train_epochs}")
810
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
811
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
812
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
813
- logger.info(f" Total optimization steps = {args.max_train_steps}")
814
- global_step = 0
815
- first_epoch = 0
816
-
817
- # Potentially load in the weights and states from a previous save
818
- if args.resume_from_checkpoint:
819
- if args.resume_from_checkpoint != "latest":
820
- path = os.path.basename(args.resume_from_checkpoint)
821
- else:
822
- # Get the most recent checkpoint
823
- dirs = os.listdir(args.output_dir)
824
- dirs = [d for d in dirs if d.startswith("checkpoint")]
825
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
826
- path = dirs[-1] if len(dirs) > 0 else None
827
-
828
- if path is None:
829
- accelerator.print(
830
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
831
- )
832
- args.resume_from_checkpoint = None
833
- else:
834
- accelerator.print(f"Resuming from checkpoint {path}")
835
- accelerator.load_state(os.path.join(args.output_dir, path))
836
- global_step = int(path.split("-")[1])
837
-
838
- resume_global_step = global_step * args.gradient_accumulation_steps
839
- first_epoch = global_step // num_update_steps_per_epoch
840
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
841
-
842
- # Only show the progress bar once on each machine.
843
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
844
- progress_bar.set_description("Steps")
845
-
846
- for epoch in range(first_epoch, args.num_train_epochs):
847
- unet.train()
848
- for step, batch in enumerate(train_dataloader):
849
- # Skip steps until we reach the resumed step
850
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
851
- if step % args.gradient_accumulation_steps == 0:
852
- progress_bar.update(1)
853
- continue
854
-
855
- with accelerator.accumulate(unet):
856
- # Convert images to latent space
857
- latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
858
- latents = latents * vae.config.scaling_factor
859
-
860
- # Sample noise that we'll add to the latents
861
- noise = torch.randn_like(latents)
862
- bsz = latents.shape[0]
863
- # Sample a random timestep for each image
864
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
865
- timesteps = timesteps.long()
866
-
867
- # Add noise to the latents according to the noise magnitude at each timestep
868
- # (this is the forward diffusion process)
869
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
870
-
871
- # Get the text embedding for conditioning
872
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
873
-
874
- # Predict the noise residual
875
- model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
876
-
877
- # Get the target for loss depending on the prediction type
878
- if noise_scheduler.config.prediction_type == "epsilon":
879
- target = noise
880
- elif noise_scheduler.config.prediction_type == "v_prediction":
881
- target = noise_scheduler.get_velocity(latents, noise, timesteps)
882
- else:
883
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
884
-
885
- if args.with_prior_preservation:
886
- # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
887
- model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
888
- target, target_prior = torch.chunk(target, 2, dim=0)
889
-
890
- # Compute instance loss
891
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
892
-
893
- # Compute prior loss
894
- prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
895
-
896
- # Add the prior loss to the instance loss.
897
- loss = loss + args.prior_loss_weight * prior_loss
898
- else:
899
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
900
-
901
- accelerator.backward(loss)
902
- if accelerator.sync_gradients:
903
- params_to_clip = lora_layers.parameters()
904
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
905
- optimizer.step()
906
- lr_scheduler.step()
907
- optimizer.zero_grad()
908
-
909
- # Checks if the accelerator has performed an optimization step behind the scenes
910
- if accelerator.sync_gradients:
911
- progress_bar.update(1)
912
- global_step += 1
913
-
914
- if global_step % args.checkpointing_steps == 0:
915
- if accelerator.is_main_process:
916
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
917
- accelerator.save_state(save_path)
918
- logger.info(f"Saved state to {save_path}")
919
-
920
- logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
921
- progress_bar.set_postfix(**logs)
922
- accelerator.log(logs, step=global_step)
923
-
924
- if global_step >= args.max_train_steps:
925
- break
926
-
927
- if accelerator.is_main_process:
928
- if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
929
- logger.info(
930
- f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
931
- f" {args.validation_prompt}."
932
- )
933
- # create pipeline
934
- pipeline = DiffusionPipeline.from_pretrained(
935
- args.pretrained_model_name_or_path,
936
- unet=accelerator.unwrap_model(unet),
937
- text_encoder=accelerator.unwrap_model(text_encoder),
938
- revision=args.revision,
939
- torch_dtype=weight_dtype,
940
- )
941
- pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
942
- pipeline = pipeline.to(accelerator.device)
943
- pipeline.set_progress_bar_config(disable=True)
944
-
945
- # run inference
946
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
947
- images = [
948
- pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
949
- for _ in range(args.num_validation_images)
950
- ]
951
-
952
- for tracker in accelerator.trackers:
953
- if tracker.name == "tensorboard":
954
- np_images = np.stack([np.asarray(img) for img in images])
955
- tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
956
- if tracker.name == "wandb":
957
- tracker.log(
958
- {
959
- "validation": [
960
- wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
961
- for i, image in enumerate(images)
962
- ]
963
- }
964
- )
965
-
966
- del pipeline
967
- torch.cuda.empty_cache()
968
-
969
- # Save the lora layers
970
- accelerator.wait_for_everyone()
971
- if accelerator.is_main_process:
972
- unet = unet.to(torch.float32)
973
- unet.save_attn_procs(args.output_dir)
974
-
975
- # Final inference
976
- # Load previous pipeline
977
- pipeline = DiffusionPipeline.from_pretrained(
978
- args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
979
- )
980
- pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
981
- pipeline = pipeline.to(accelerator.device)
982
-
983
- # load attention processors
984
- pipeline.unet.load_attn_procs(args.output_dir)
985
-
986
- # run inference
987
- if args.validation_prompt and args.num_validation_images > 0:
988
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
989
- images = [
990
- pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
991
- for _ in range(args.num_validation_images)
992
- ]
993
-
994
- for tracker in accelerator.trackers:
995
- if tracker.name == "tensorboard":
996
- np_images = np.stack([np.asarray(img) for img in images])
997
- tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
998
- if tracker.name == "wandb":
999
- tracker.log(
1000
- {
1001
- "test": [
1002
- wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1003
- for i, image in enumerate(images)
1004
- ]
1005
- }
1006
- )
1007
-
1008
- if args.push_to_hub:
1009
- save_model_card(
1010
- repo_id,
1011
- images=images,
1012
- base_model=args.pretrained_model_name_or_path,
1013
- prompt=args.instance_prompt,
1014
- repo_folder=args.output_dir,
1015
- )
1016
- upload_folder(
1017
- repo_id=repo_id,
1018
- folder_path=args.output_dir,
1019
- commit_message="End of training",
1020
- ignore_patterns=["step_*", "epoch_*"],
1021
- )
1022
-
1023
- accelerator.end_training()
1024
-
1025
-
1026
- if __name__ == "__main__":
1027
- args = parse_args()
1028
- main(args)
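After training, the saved attention processors can be reused outside this script in the same way its final inference block does. A minimal sketch, assuming a finished run; the base model id, output directory, and prompt are placeholders:

```python
# Sketch only: load LoRA attention processors written by unet.save_attn_procs(args.output_dir).
# The base model id, output directory, and prompt are illustrative placeholders.
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# Attach the trained LoRA weights to the UNet's attention layers.
pipe.unet.load_attn_procs("path/to/lora-dreambooth-model")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("lora_sample.png")
```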
 
diffusers/examples/inference/README.md DELETED
@@ -1,8 +0,0 @@
1
- # Inference Examples
2
-
3
- **The inference examples folder is deprecated and will be removed in a future version**.
4
- **Officially supported inference examples can be found in the [Pipelines folder](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines)**.
5
-
6
- - For `Image-to-Image text-guided generation with Stable Diffusion`, please have a look at the official [Pipeline examples](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines#examples)
7
- - For `In-painting using Stable Diffusion`, please have a look at the official [Pipeline examples](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines#examples)
8
- - For `Tweak prompts reusing seeds and latents`, please have a look at the official [Pipeline examples](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines#examples)
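As a minimal sketch of the recommended replacement for the deprecated `image_to_image.py` script below, using the officially supported pipeline directly (the model id, input file, and prompt are illustrative placeholders):

```python
# Sketch only: image-to-image generation with the officially supported pipeline.
# The model id, input file name, and prompt are illustrative placeholders.
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("sketch-mountains-input.jpg").convert("RGB").resize((768, 512))

image = pipe(
    prompt="A fantasy landscape, trending on artstation",
    image=init_image,
    strength=0.75,
    guidance_scale=7.5,
).images[0]
image.save("fantasy_landscape.png")
```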
 
diffusers/examples/inference/image_to_image.py DELETED
@@ -1,9 +0,0 @@
1
- import warnings
2
-
3
- from diffusers import StableDiffusionImg2ImgPipeline # noqa F401
4
-
5
-
6
- warnings.warn(
7
- "The `image_to_image.py` script is outdated. Please use directly `from diffusers import"
8
- " StableDiffusionImg2ImgPipeline` instead."
9
- )
 
diffusers/examples/inference/inpainting.py DELETED
@@ -1,9 +0,0 @@
1
- import warnings
2
-
3
- from diffusers import StableDiffusionInpaintPipeline as StableDiffusionInpaintPipeline # noqa F401
4
-
5
-
6
- warnings.warn(
7
- "The `inpainting.py` script is outdated. Please use directly `from diffusers import"
8
- " StableDiffusionInpaintPipeline` instead."
9
- )
 
diffusers/examples/instruct_pix2pix/README.md DELETED
@@ -1,166 +0,0 @@
1
- # InstructPix2Pix training example
2
-
3
- [InstructPix2Pix](https://arxiv.org/abs/2211.09800) is a method to fine-tune text-conditioned diffusion models such that they can follow an edit instruction for an input image. Models fine-tuned using this method take the following as inputs:
4
-
5
- <p align="center">
6
- <img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png" alt="instructpix2pix-inputs" width=600/>
7
- </p>
8
-
9
- The output is an "edited" image that reflects the edit instruction applied to the input image:
10
-
11
- <p align="center">
12
- <img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/output-gs%407-igs%401-steps%4050.png" alt="instructpix2pix-output" width=600/>
13
- </p>
14
-
15
- The `train_instruct_pix2pix.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
16
-
17
- ***Disclaimer: Even though `train_instruct_pix2pix.py` implements the InstructPix2Pix
18
- training procedure while being faithful to the [original implementation](https://github.com/timothybrooks/instruct-pix2pix), we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples). This can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.***
19
-
20
- ## Running locally with PyTorch
21
-
22
- ### Installing the dependencies
23
-
24
- Before running the scripts, make sure to install the library's training dependencies:
25
-
26
- **Important**
27
-
28
- To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, since we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
29
- ```bash
30
- git clone https://github.com/huggingface/diffusers
31
- cd diffusers
32
- pip install -e .
33
- ```
34
-
35
- Then cd into the example folder and run
36
- ```bash
37
- pip install -r requirements.txt
38
- ```
39
-
40
- And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
41
-
42
- ```bash
43
- accelerate config
44
- ```
45
-
46
- Or for a default accelerate configuration without answering questions about your environment
47
-
48
- ```bash
49
- accelerate config default
50
- ```
51
-
52
- Or, if your environment doesn't support an interactive shell, e.g., a notebook:
53
-
54
- ```python
55
- from accelerate.utils import write_basic_config
56
- write_basic_config()
57
- ```
58
-
59
- ### Toy example
60
-
61
- As mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. The dataset
62
- is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper.
63
-
64
- Configure environment variables such as the dataset identifier and the Stable Diffusion
65
- checkpoint:
66
-
67
- ```bash
68
- export MODEL_NAME="runwayml/stable-diffusion-v1-5"
69
- export DATASET_ID="fusing/instructpix2pix-1000-samples"
70
- ```
71
-
72
- Now, we can launch training:
73
-
74
- ```bash
75
- accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
76
- --pretrained_model_name_or_path=$MODEL_NAME \
77
- --dataset_name=$DATASET_ID \
78
- --enable_xformers_memory_efficient_attention \
79
- --resolution=256 --random_flip \
80
- --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
81
- --max_train_steps=15000 \
82
- --checkpointing_steps=5000 --checkpoints_total_limit=1 \
83
- --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
84
- --conditioning_dropout_prob=0.05 \
85
- --mixed_precision=fp16 \
86
- --seed=42
87
- ```
88
-
89
- Additionally, we support performing validation inference to monitor training progress
90
- with Weights and Biases. You can enable this feature with `report_to="wandb"`:
91
-
92
- ```bash
93
- accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
94
- --pretrained_model_name_or_path=$MODEL_NAME \
95
- --dataset_name=$DATASET_ID \
96
- --enable_xformers_memory_efficient_attention \
97
- --resolution=256 --random_flip \
98
- --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
99
- --max_train_steps=15000 \
100
- --checkpointing_steps=5000 --checkpoints_total_limit=1 \
101
- --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
102
- --conditioning_dropout_prob=0.05 \
103
- --mixed_precision=fp16 \
104
- --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
105
- --validation_prompt="make the mountains snowy" \
106
- --seed=42 \
107
- --report_to=wandb
108
- ```
109
-
110
- We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`.
111
-
112
- [Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters.
113
-
114
- ***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***
115
-
116
- ## Inference
117
-
118
- Once training is complete, we can perform inference:
119
-
120
- ```python
121
- import PIL
122
- import requests
123
- import torch
124
- from diffusers import StableDiffusionInstructPix2PixPipeline
125
-
126
- model_id = "your_model_id" # <- replace this
127
- pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
128
- generator = torch.Generator("cuda").manual_seed(0)
129
-
130
- url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png"
131
-
132
-
133
- def download_image(url):
134
- image = PIL.Image.open(requests.get(url, stream=True).raw)
135
- image = PIL.ImageOps.exif_transpose(image)
136
- image = image.convert("RGB")
137
- return image
138
-
139
- image = download_image(url)
140
- prompt = "wipe out the lake"
141
- num_inference_steps = 20
142
- image_guidance_scale = 1.5
143
- guidance_scale = 10
144
-
145
- edited_image = pipe(prompt,
146
- image=image,
147
- num_inference_steps=num_inference_steps,
148
- image_guidance_scale=image_guidance_scale,
149
- guidance_scale=guidance_scale,
150
- generator=generator,
151
- ).images[0]
152
- edited_image.save("edited_image.png")
153
- ```
154
-
155
- An example model repo obtained using this training script can be found
156
- here - [sayakpaul/instruct-pix2pix](https://huggingface.co/sayakpaul/instruct-pix2pix).
157
-
158
- We encourage you to play with the following three parameters to control
159
- speed and quality during inference:
160
-
161
- * `num_inference_steps`
162
- * `image_guidance_scale`
163
- * `guidance_scale`
164
-
165
- In particular, `image_guidance_scale` and `guidance_scale` can have a profound impact
166
- on the generated ("edited") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example).
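Building on the inference snippet above, a small sketch of such a parameter sweep; the value grids are arbitrary illustrations, not recommended settings:

```python
# Sketch only: compare a few guidance settings, reusing `pipe`, `image`, and `prompt`
# from the inference snippet above. The value grids below are arbitrary.
for image_guidance_scale in (1.0, 1.5, 2.0):
    for guidance_scale in (7, 10):
        # Re-seed for each combination so the comparison is apples-to-apples.
        generator = torch.Generator("cuda").manual_seed(0)
        edited_image = pipe(
            prompt,
            image=image,
            num_inference_steps=20,
            image_guidance_scale=image_guidance_scale,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]
        edited_image.save(f"edited_igs{image_guidance_scale}_gs{guidance_scale}.png")
```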
 
diffusers/examples/instruct_pix2pix/requirements.txt DELETED
@@ -1,6 +0,0 @@
1
- accelerate
2
- torchvision
3
- transformers>=4.25.1
4
- datasets
5
- ftfy
6
- tensorboard
 
diffusers/examples/instruct_pix2pix/train_instruct_pix2pix.py DELETED
@@ -1,988 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- """Script to fine-tune Stable Diffusion for InstructPix2Pix."""
18
-
19
- import argparse
20
- import logging
21
- import math
22
- import os
23
- from pathlib import Path
24
-
25
- import accelerate
26
- import datasets
27
- import numpy as np
28
- import PIL
29
- import requests
30
- import torch
31
- import torch.nn as nn
32
- import torch.nn.functional as F
33
- import torch.utils.checkpoint
34
- import transformers
35
- from accelerate import Accelerator
36
- from accelerate.logging import get_logger
37
- from accelerate.utils import ProjectConfiguration, set_seed
38
- from datasets import load_dataset
39
- from huggingface_hub import create_repo, upload_folder
40
- from packaging import version
41
- from torchvision import transforms
42
- from tqdm.auto import tqdm
43
- from transformers import CLIPTextModel, CLIPTokenizer
44
-
45
- import diffusers
46
- from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
47
- from diffusers.optimization import get_scheduler
48
- from diffusers.training_utils import EMAModel
49
- from diffusers.utils import check_min_version, deprecate, is_wandb_available
50
- from diffusers.utils.import_utils import is_xformers_available
51
-
52
-
53
- # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
54
- check_min_version("0.15.0.dev0")
55
-
56
- logger = get_logger(__name__, log_level="INFO")
57
-
58
- DATASET_NAME_MAPPING = {
59
- "fusing/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"),
60
- }
61
- WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
62
-
63
-
64
- def parse_args():
65
- parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
66
- parser.add_argument(
67
- "--pretrained_model_name_or_path",
68
- type=str,
69
- default=None,
70
- required=True,
71
- help="Path to pretrained model or model identifier from huggingface.co/models.",
72
- )
73
- parser.add_argument(
74
- "--revision",
75
- type=str,
76
- default=None,
77
- required=False,
78
- help="Revision of pretrained model identifier from huggingface.co/models.",
79
- )
80
- parser.add_argument(
81
- "--dataset_name",
82
- type=str,
83
- default=None,
84
- help=(
85
- "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
86
- " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
87
- " or to a folder containing files that 🤗 Datasets can understand."
88
- ),
89
- )
90
- parser.add_argument(
91
- "--dataset_config_name",
92
- type=str,
93
- default=None,
94
- help="The config of the Dataset, leave as None if there's only one config.",
95
- )
96
- parser.add_argument(
97
- "--train_data_dir",
98
- type=str,
99
- default=None,
100
- help=(
101
- "A folder containing the training data. Folder contents must follow the structure described in"
102
- " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
103
- " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
104
- ),
105
- )
106
- parser.add_argument(
107
- "--original_image_column",
108
- type=str,
109
- default="input_image",
110
- help="The column of the dataset containing the original image on which edits were made.",
111
- )
112
- parser.add_argument(
113
- "--edited_image_column",
114
- type=str,
115
- default="edited_image",
116
- help="The column of the dataset containing the edited image.",
117
- )
118
- parser.add_argument(
119
- "--edit_prompt_column",
120
- type=str,
121
- default="edit_prompt",
122
- help="The column of the dataset containing the edit instruction.",
123
- )
124
- parser.add_argument(
125
- "--val_image_url",
126
- type=str,
127
- default=None,
128
- help="URL to the original image that you would like to edit (used during inference for debugging purposes).",
129
- )
130
- parser.add_argument(
131
- "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
132
- )
133
- parser.add_argument(
134
- "--num_validation_images",
135
- type=int,
136
- default=4,
137
- help="Number of images that should be generated during validation with `validation_prompt`.",
138
- )
139
- parser.add_argument(
140
- "--validation_epochs",
141
- type=int,
142
- default=1,
143
- help=(
144
- "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
145
- " `args.validation_prompt` a total of `args.num_validation_images` times."
146
- ),
147
- )
148
- parser.add_argument(
149
- "--max_train_samples",
150
- type=int,
151
- default=None,
152
- help=(
153
- "For debugging purposes or quicker training, truncate the number of training examples to this "
154
- "value if set."
155
- ),
156
- )
157
- parser.add_argument(
158
- "--output_dir",
159
- type=str,
160
- default="instruct-pix2pix-model",
161
- help="The output directory where the model predictions and checkpoints will be written.",
162
- )
163
- parser.add_argument(
164
- "--cache_dir",
165
- type=str,
166
- default=None,
167
- help="The directory where the downloaded models and datasets will be stored.",
168
- )
169
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
170
- parser.add_argument(
171
- "--resolution",
172
- type=int,
173
- default=256,
174
- help=(
175
- "The resolution for input images; all images in the train/validation dataset will be resized to this"
176
- " resolution"
177
- ),
178
- )
179
- parser.add_argument(
180
- "--center_crop",
181
- default=False,
182
- action="store_true",
183
- help=(
184
- "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
185
- " cropped. The images will be resized to the resolution first before cropping."
186
- ),
187
- )
188
- parser.add_argument(
189
- "--random_flip",
190
- action="store_true",
191
- help="Whether to randomly flip images horizontally.",
192
- )
193
- parser.add_argument(
194
- "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
195
- )
196
- parser.add_argument("--num_train_epochs", type=int, default=100)
197
- parser.add_argument(
198
- "--max_train_steps",
199
- type=int,
200
- default=None,
201
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
202
- )
203
- parser.add_argument(
204
- "--gradient_accumulation_steps",
205
- type=int,
206
- default=1,
207
- help="Number of update steps to accumulate before performing a backward/update pass.",
208
- )
209
- parser.add_argument(
210
- "--gradient_checkpointing",
211
- action="store_true",
212
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
213
- )
214
- parser.add_argument(
215
- "--learning_rate",
216
- type=float,
217
- default=1e-4,
218
- help="Initial learning rate (after the potential warmup period) to use.",
219
- )
220
- parser.add_argument(
221
- "--scale_lr",
222
- action="store_true",
223
- default=False,
224
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
225
- )
226
- parser.add_argument(
227
- "--lr_scheduler",
228
- type=str,
229
- default="constant",
230
- help=(
231
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
232
- ' "constant", "constant_with_warmup"]'
233
- ),
234
- )
235
- parser.add_argument(
236
- "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
237
- )
238
- parser.add_argument(
239
- "--conditioning_dropout_prob",
240
- type=float,
241
- default=None,
242
- help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
243
- )
244
- parser.add_argument(
245
- "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
246
- )
247
- parser.add_argument(
248
- "--allow_tf32",
249
- action="store_true",
250
- help=(
251
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
252
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
253
- ),
254
- )
255
- parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
256
- parser.add_argument(
257
- "--non_ema_revision",
258
- type=str,
259
- default=None,
260
- required=False,
261
- help=(
262
- "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
263
- " remote repository specified with --pretrained_model_name_or_path."
264
- ),
265
- )
266
- parser.add_argument(
267
- "--dataloader_num_workers",
268
- type=int,
269
- default=0,
270
- help=(
271
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
272
- ),
273
- )
274
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
275
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
276
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
277
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
278
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
279
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
280
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
281
- parser.add_argument(
282
- "--hub_model_id",
283
- type=str,
284
- default=None,
285
- help="The name of the repository to keep in sync with the local `output_dir`.",
286
- )
287
- parser.add_argument(
288
- "--logging_dir",
289
- type=str,
290
- default="logs",
291
- help=(
292
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
293
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
294
- ),
295
- )
296
- parser.add_argument(
297
- "--mixed_precision",
298
- type=str,
299
- default=None,
300
- choices=["no", "fp16", "bf16"],
301
- help=(
302
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
303
- " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
304
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
305
- ),
306
- )
307
- parser.add_argument(
308
- "--report_to",
309
- type=str,
310
- default="tensorboard",
311
- help=(
312
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
313
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
314
- ),
315
- )
316
- parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
317
- parser.add_argument(
318
- "--checkpointing_steps",
319
- type=int,
320
- default=500,
321
- help=(
322
- "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
323
- " training using `--resume_from_checkpoint`."
324
- ),
325
- )
326
- parser.add_argument(
327
- "--checkpoints_total_limit",
328
- type=int,
329
- default=None,
330
- help=(
331
- "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
332
- " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
333
- " for more details."
334
- ),
335
- )
336
- parser.add_argument(
337
- "--resume_from_checkpoint",
338
- type=str,
339
- default=None,
340
- help=(
341
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
342
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
343
- ),
344
- )
345
- parser.add_argument(
346
- "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
347
- )
348
-
349
- args = parser.parse_args()
350
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
351
- if env_local_rank != -1 and env_local_rank != args.local_rank:
352
- args.local_rank = env_local_rank
353
-
354
- # Sanity checks
355
- if args.dataset_name is None and args.train_data_dir is None:
356
- raise ValueError("Need either a dataset name or a training folder.")
357
-
358
- # default to using the same revision for the non-ema model if not specified
359
- if args.non_ema_revision is None:
360
- args.non_ema_revision = args.revision
361
-
362
- return args
363
-
364
-
365
- def convert_to_np(image, resolution):
366
- image = image.convert("RGB").resize((resolution, resolution))
367
- return np.array(image).transpose(2, 0, 1)
368
-
369
-
370
- def download_image(url):
371
- image = PIL.Image.open(requests.get(url, stream=True).raw)
372
- image = PIL.ImageOps.exif_transpose(image)
373
- image = image.convert("RGB")
374
- return image
375
-
376
-
377
- def main():
378
- args = parse_args()
379
-
380
- if args.non_ema_revision is not None:
381
- deprecate(
382
- "non_ema_revision!=None",
383
- "0.15.0",
384
- message=(
385
- "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
386
- " use `--variant=non_ema` instead."
387
- ),
388
- )
389
- logging_dir = os.path.join(args.output_dir, args.logging_dir)
390
- accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
391
- accelerator = Accelerator(
392
- gradient_accumulation_steps=args.gradient_accumulation_steps,
393
- mixed_precision=args.mixed_precision,
394
- log_with=args.report_to,
395
- logging_dir=logging_dir,
396
- project_config=accelerator_project_config,
397
- )
398
-
399
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
400
-
401
- if args.report_to == "wandb":
402
- if not is_wandb_available():
403
- raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
404
- import wandb
405
-
406
- # Make one log on every process with the configuration for debugging.
407
- logging.basicConfig(
408
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
409
- datefmt="%m/%d/%Y %H:%M:%S",
410
- level=logging.INFO,
411
- )
412
- logger.info(accelerator.state, main_process_only=False)
413
- if accelerator.is_local_main_process:
414
- datasets.utils.logging.set_verbosity_warning()
415
- transformers.utils.logging.set_verbosity_warning()
416
- diffusers.utils.logging.set_verbosity_info()
417
- else:
418
- datasets.utils.logging.set_verbosity_error()
419
- transformers.utils.logging.set_verbosity_error()
420
- diffusers.utils.logging.set_verbosity_error()
421
-
422
- # If passed along, set the training seed now.
423
- if args.seed is not None:
424
- set_seed(args.seed)
425
-
426
- # Handle the repository creation
427
- if accelerator.is_main_process:
428
- if args.output_dir is not None:
429
- os.makedirs(args.output_dir, exist_ok=True)
430
-
431
- if args.push_to_hub:
432
- repo_id = create_repo(
433
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
434
- ).repo_id
435
-
436
- # Load scheduler, tokenizer and models.
437
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
438
- tokenizer = CLIPTokenizer.from_pretrained(
439
- args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
440
- )
441
- text_encoder = CLIPTextModel.from_pretrained(
442
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
443
- )
444
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
445
- unet = UNet2DConditionModel.from_pretrained(
446
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
447
- )
448
-
449
- # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
450
- # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
451
- # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
452
- # from the pre-trained checkpoints. The extra channels added to the first layer are
453
- # initialized to zero.
454
- if accelerator.is_main_process:
455
- logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
456
- in_channels = 8
457
- out_channels = unet.conv_in.out_channels
458
- unet.register_to_config(in_channels=in_channels)
459
-
460
- with torch.no_grad():
461
- new_conv_in = nn.Conv2d(
462
- in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
463
- )
464
- new_conv_in.weight.zero_()
465
- new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
466
- unet.conv_in = new_conv_in
467
-
468
- # Freeze vae and text_encoder
469
- vae.requires_grad_(False)
470
- text_encoder.requires_grad_(False)
471
-
472
- # Create EMA for the unet.
473
- if args.use_ema:
474
- ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
475
-
476
- if args.enable_xformers_memory_efficient_attention:
477
- if is_xformers_available():
478
- import xformers
479
-
480
- xformers_version = version.parse(xformers.__version__)
481
- if xformers_version == version.parse("0.0.16"):
482
- logger.warn(
483
- "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
484
- )
485
- unet.enable_xformers_memory_efficient_attention()
486
- else:
487
- raise ValueError("xformers is not available. Make sure it is installed correctly")
488
-
489
- # `accelerate` 0.16.0 will have better support for customized saving
490
- if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
491
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
492
- def save_model_hook(models, weights, output_dir):
493
- if args.use_ema:
494
- ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
495
-
496
- for i, model in enumerate(models):
497
- model.save_pretrained(os.path.join(output_dir, "unet"))
498
-
499
- # make sure to pop weight so that corresponding model is not saved again
500
- weights.pop()
501
-
502
- def load_model_hook(models, input_dir):
503
- if args.use_ema:
504
- load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
505
- ema_unet.load_state_dict(load_model.state_dict())
506
- ema_unet.to(accelerator.device)
507
- del load_model
508
-
509
- for i in range(len(models)):
510
- # pop models so that they are not loaded again
511
- model = models.pop()
512
-
513
- # load diffusers style into model
514
- load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
515
- model.register_to_config(**load_model.config)
516
-
517
- model.load_state_dict(load_model.state_dict())
518
- del load_model
519
-
520
- accelerator.register_save_state_pre_hook(save_model_hook)
521
- accelerator.register_load_state_pre_hook(load_model_hook)
522
-
523
- if args.gradient_checkpointing:
524
- unet.enable_gradient_checkpointing()
525
-
526
- # Enable TF32 for faster training on Ampere GPUs,
527
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
528
- if args.allow_tf32:
529
- torch.backends.cuda.matmul.allow_tf32 = True
530
-
531
- if args.scale_lr:
532
- args.learning_rate = (
533
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
534
- )
535
-
536
- # Initialize the optimizer
537
- if args.use_8bit_adam:
538
- try:
539
- import bitsandbytes as bnb
540
- except ImportError:
541
- raise ImportError(
542
- "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
543
- )
544
-
545
- optimizer_cls = bnb.optim.AdamW8bit
546
- else:
547
- optimizer_cls = torch.optim.AdamW
548
-
549
- optimizer = optimizer_cls(
550
- unet.parameters(),
551
- lr=args.learning_rate,
552
- betas=(args.adam_beta1, args.adam_beta2),
553
- weight_decay=args.adam_weight_decay,
554
- eps=args.adam_epsilon,
555
- )
556
-
557
- # Get the datasets: you can either provide your own training and evaluation files (see below)
558
- # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
559
-
560
- # In distributed training, the load_dataset function guarantees that only one local process can concurrently
561
- # download the dataset.
562
- if args.dataset_name is not None:
563
- # Downloading and loading a dataset from the hub.
564
- dataset = load_dataset(
565
- args.dataset_name,
566
- args.dataset_config_name,
567
- cache_dir=args.cache_dir,
568
- )
569
- else:
570
- data_files = {}
571
- if args.train_data_dir is not None:
572
- data_files["train"] = os.path.join(args.train_data_dir, "**")
573
- dataset = load_dataset(
574
- "imagefolder",
575
- data_files=data_files,
576
- cache_dir=args.cache_dir,
577
- )
578
- # See more about loading custom images at
579
- # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder
580
-
581
- # Preprocessing the datasets.
582
- # We need to tokenize inputs and targets.
583
- column_names = dataset["train"].column_names
584
-
585
- # 6. Get the column names for input/target.
586
- dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
587
- if args.original_image_column is None:
588
- original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
589
- else:
590
- original_image_column = args.original_image_column
591
- if original_image_column not in column_names:
592
- raise ValueError(
593
- f"'--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}"
594
- )
595
- if args.edit_prompt_column is None:
596
- edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
597
- else:
598
- edit_prompt_column = args.edit_prompt_column
599
- if edit_prompt_column not in column_names:
600
- raise ValueError(
601
- f"'--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}"
602
- )
603
- if args.edited_image_column is None:
604
- edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]
605
- else:
606
- edited_image_column = args.edited_image_column
607
- if edited_image_column not in column_names:
608
- raise ValueError(
609
- f"'--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}"
610
- )
611
-
612
- # Preprocessing the datasets.
613
- # We need to tokenize input captions and transform the images.
614
- def tokenize_captions(captions):
615
- inputs = tokenizer(
616
- captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
617
- )
618
- return inputs.input_ids
619
-
620
- # Preprocessing the datasets.
621
- train_transforms = transforms.Compose(
622
- [
623
- transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
624
- transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
625
- ]
626
- )
627
-
628
- def preprocess_images(examples):
629
- original_images = np.concatenate(
630
- [convert_to_np(image, args.resolution) for image in examples[original_image_column]]
631
- )
632
- edited_images = np.concatenate(
633
- [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
634
- )
635
- # We need to ensure that the original and the edited images undergo the same
636
- # augmentation transforms.
637
- images = np.concatenate([original_images, edited_images])
638
- images = torch.tensor(images)
639
- images = 2 * (images / 255) - 1
640
- return train_transforms(images)
641
-
642
- def preprocess_train(examples):
643
- # Preprocess images.
644
- preprocessed_images = preprocess_images(examples)
645
- # Since the original and edited images were concatenated before
646
- # applying the transformations, we need to separate them and reshape
647
- # them accordingly.
648
- original_images, edited_images = preprocessed_images.chunk(2)
649
- original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
650
- edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
651
-
652
- # Collate the preprocessed images into the `examples`.
653
- examples["original_pixel_values"] = original_images
654
- examples["edited_pixel_values"] = edited_images
655
-
656
- # Preprocess the captions.
657
- captions = list(examples[edit_prompt_column])
658
- examples["input_ids"] = tokenize_captions(captions)
659
- return examples
660
-
661
- with accelerator.main_process_first():
662
- if args.max_train_samples is not None:
663
- dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
664
- # Set the training transforms
665
- train_dataset = dataset["train"].with_transform(preprocess_train)
666
-
667
- def collate_fn(examples):
668
- original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
669
- original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()
670
- edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
671
- edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
672
- input_ids = torch.stack([example["input_ids"] for example in examples])
673
- return {
674
- "original_pixel_values": original_pixel_values,
675
- "edited_pixel_values": edited_pixel_values,
676
- "input_ids": input_ids,
677
- }
678
-
679
- # DataLoaders creation:
680
- train_dataloader = torch.utils.data.DataLoader(
681
- train_dataset,
682
- shuffle=True,
683
- collate_fn=collate_fn,
684
- batch_size=args.train_batch_size,
685
- num_workers=args.dataloader_num_workers,
686
- )
687
-
688
- # Scheduler and math around the number of training steps.
689
- overrode_max_train_steps = False
690
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
691
- if args.max_train_steps is None:
692
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
693
- overrode_max_train_steps = True
694
-
695
- lr_scheduler = get_scheduler(
696
- args.lr_scheduler,
697
- optimizer=optimizer,
698
- num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
699
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
700
- )
701
-
702
- # Prepare everything with our `accelerator`.
703
- unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
704
- unet, optimizer, train_dataloader, lr_scheduler
705
- )
706
-
707
- if args.use_ema:
708
- ema_unet.to(accelerator.device)
709
-
710
- # For mixed precision training we cast the text_encoder and vae weights to half-precision
711
- # as these models are only used for inference, keeping weights in full precision is not required.
712
- weight_dtype = torch.float32
713
- if accelerator.mixed_precision == "fp16":
714
- weight_dtype = torch.float16
715
- elif accelerator.mixed_precision == "bf16":
716
- weight_dtype = torch.bfloat16
717
-
718
- # Move text_encoder and vae to the GPU and cast to weight_dtype
719
- text_encoder.to(accelerator.device, dtype=weight_dtype)
720
- vae.to(accelerator.device, dtype=weight_dtype)
721
-
722
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
723
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
724
- if overrode_max_train_steps:
725
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
726
- # Afterwards we recalculate our number of training epochs
727
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
728
-
729
- # We need to initialize the trackers we use, and also store our configuration.
730
- # The trackers are initialized automatically on the main process.
731
- if accelerator.is_main_process:
732
- accelerator.init_trackers("instruct-pix2pix", config=vars(args))
733
-
734
- # Train!
735
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
736
-
737
- logger.info("***** Running training *****")
738
- logger.info(f" Num examples = {len(train_dataset)}")
739
- logger.info(f" Num Epochs = {args.num_train_epochs}")
740
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
741
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
742
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
743
- logger.info(f" Total optimization steps = {args.max_train_steps}")
744
- global_step = 0
745
- first_epoch = 0
746
-
747
- # Potentially load in the weights and states from a previous save
748
- if args.resume_from_checkpoint:
749
- if args.resume_from_checkpoint != "latest":
750
- path = os.path.basename(args.resume_from_checkpoint)
751
- else:
752
- # Get the most recent checkpoint
753
- dirs = os.listdir(args.output_dir)
754
- dirs = [d for d in dirs if d.startswith("checkpoint")]
755
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
756
- path = dirs[-1] if len(dirs) > 0 else None
757
-
758
- if path is None:
759
- accelerator.print(
760
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
761
- )
762
- args.resume_from_checkpoint = None
763
- else:
764
- accelerator.print(f"Resuming from checkpoint {path}")
765
- accelerator.load_state(os.path.join(args.output_dir, path))
766
- global_step = int(path.split("-")[1])
767
-
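- # The checkpoint name encodes the number of optimizer steps (`global_step`), while the
- # dataloader loop below counts micro-batches. Convert between the two so that the
- # already-processed batches of the resumed epoch can be skipped.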
768
- resume_global_step = global_step * args.gradient_accumulation_steps
769
- first_epoch = global_step // num_update_steps_per_epoch
770
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
771
-
772
- # Only show the progress bar once on each machine.
773
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
774
- progress_bar.set_description("Steps")
775
-
776
- for epoch in range(first_epoch, args.num_train_epochs):
777
- unet.train()
778
- train_loss = 0.0
779
- for step, batch in enumerate(train_dataloader):
780
- # Skip steps until we reach the resumed step
781
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
782
- if step % args.gradient_accumulation_steps == 0:
783
- progress_bar.update(1)
784
- continue
785
-
786
- with accelerator.accumulate(unet):
787
- # We want to learn the denoising process w.r.t the edited images which
788
- # are conditioned on the original image (which was edited) and the edit instruction.
789
- # So, first, convert images to latent space.
790
- latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample()
791
- latents = latents * vae.config.scaling_factor
792
-
793
- # Sample noise that we'll add to the latents
794
- noise = torch.randn_like(latents)
795
- bsz = latents.shape[0]
796
- # Sample a random timestep for each image
797
- timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
798
- timesteps = timesteps.long()
799
-
800
- # Add noise to the latents according to the noise magnitude at each timestep
801
- # (this is the forward diffusion process)
802
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
803
-
804
- # Get the text embedding for conditioning.
805
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
806
-
807
- # Get the additional image embedding for conditioning.
808
- # Instead of getting a diagonal Gaussian here, we simply take the mode.
809
- original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode()
810
-
811
- # Conditioning dropout to support classifier-free guidance during inference. For more details
812
- # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
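- # With random_p ~ U(0, 1) and p = conditioning_dropout_prob, the masks below drop only
- # the text conditioning for random_p in [0, p), both conditionings for random_p in
- # [p, 2p), and only the image conditioning for random_p in [2p, 3p), so each of the
- # three dropout cases occurs with probability p.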
813
- if args.conditioning_dropout_prob is not None:
814
- random_p = torch.rand(bsz, device=latents.device, generator=generator)
815
- # Sample masks for the edit prompts.
816
- prompt_mask = random_p < 2 * args.conditioning_dropout_prob
817
- prompt_mask = prompt_mask.reshape(bsz, 1, 1)
818
- # Final text conditioning.
819
- null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0]
820
- encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
821
-
822
- # Sample masks for the original images.
823
- image_mask_dtype = original_image_embeds.dtype
824
- image_mask = 1 - (
825
- (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
826
- * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
827
- )
828
- image_mask = image_mask.reshape(bsz, 1, 1, 1)
829
- # Final image conditioning.
830
- original_image_embeds = image_mask * original_image_embeds
831
-
832
- # Concatenate the `original_image_embeds` with the `noisy_latents`.
833
- concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)
834
-
835
- # Get the target for loss depending on the prediction type
836
- if noise_scheduler.config.prediction_type == "epsilon":
837
- target = noise
838
- elif noise_scheduler.config.prediction_type == "v_prediction":
839
- target = noise_scheduler.get_velocity(latents, noise, timesteps)
840
- else:
841
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
842
-
843
- # Predict the noise residual and compute loss
844
- model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
845
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
846
-
847
- # Gather the losses across all processes for logging (if we use distributed training).
848
- avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
849
- train_loss += avg_loss.item() / args.gradient_accumulation_steps
850
-
851
- # Backpropagate
852
- accelerator.backward(loss)
853
- if accelerator.sync_gradients:
854
- accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
855
- optimizer.step()
856
- lr_scheduler.step()
857
- optimizer.zero_grad()
858
-
859
- # Checks if the accelerator has performed an optimization step behind the scenes
860
- if accelerator.sync_gradients:
861
- if args.use_ema:
862
- ema_unet.step(unet.parameters())
863
- progress_bar.update(1)
864
- global_step += 1
865
- accelerator.log({"train_loss": train_loss}, step=global_step)
866
- train_loss = 0.0
867
-
868
- if global_step % args.checkpointing_steps == 0:
869
- if accelerator.is_main_process:
870
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
871
- accelerator.save_state(save_path)
872
- logger.info(f"Saved state to {save_path}")
873
-
874
- logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
875
- progress_bar.set_postfix(**logs)
876
-
877
- if global_step >= args.max_train_steps:
878
- break
879
-
880
- if accelerator.is_main_process:
881
- if (
882
- (args.val_image_url is not None)
883
- and (args.validation_prompt is not None)
884
- and (epoch % args.validation_epochs == 0)
885
- ):
886
- logger.info(
887
- f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
888
- f" {args.validation_prompt}."
889
- )
890
- # create pipeline
891
- if args.use_ema:
892
- # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
893
- ema_unet.store(unet.parameters())
894
- ema_unet.copy_to(unet.parameters())
895
- pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
896
- args.pretrained_model_name_or_path,
897
- unet=unet,
898
- revision=args.revision,
899
- torch_dtype=weight_dtype,
900
- )
901
- pipeline = pipeline.to(accelerator.device)
902
- pipeline.set_progress_bar_config(disable=True)
903
-
904
- # run inference
905
- original_image = download_image(args.val_image_url)
906
- edited_images = []
907
- with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):
908
- for _ in range(args.num_validation_images):
909
- edited_images.append(
910
- pipeline(
911
- args.validation_prompt,
912
- image=original_image,
913
- num_inference_steps=20,
914
- image_guidance_scale=1.5,
915
- guidance_scale=7,
916
- generator=generator,
917
- ).images[0]
918
- )
919
-
920
- for tracker in accelerator.trackers:
921
- if tracker.name == "wandb":
922
- wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
923
- for edited_image in edited_images:
924
- wandb_table.add_data(
925
- wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
926
- )
927
- tracker.log({"validation": wandb_table})
928
- if args.use_ema:
929
- # Switch back to the original UNet parameters.
930
- ema_unet.restore(unet.parameters())
931
-
932
- del pipeline
933
- torch.cuda.empty_cache()
934
-
935
- # Create the pipeline using the trained modules and save it.
936
- accelerator.wait_for_everyone()
937
- if accelerator.is_main_process:
938
- unet = accelerator.unwrap_model(unet)
939
- if args.use_ema:
940
- ema_unet.copy_to(unet.parameters())
941
-
942
- pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
943
- args.pretrained_model_name_or_path,
944
- text_encoder=accelerator.unwrap_model(text_encoder),
945
- vae=accelerator.unwrap_model(vae),
946
- unet=unet,
947
- revision=args.revision,
948
- )
949
- pipeline.save_pretrained(args.output_dir)
950
-
951
- if args.push_to_hub:
952
- upload_folder(
953
- repo_id=repo_id,
954
- folder_path=args.output_dir,
955
- commit_message="End of training",
956
- ignore_patterns=["step_*", "epoch_*"],
957
- )
958
-
959
- if args.validation_prompt is not None:
960
- edited_images = []
961
- pipeline = pipeline.to(accelerator.device)
962
- with torch.autocast(str(accelerator.device)):
963
- for _ in range(args.num_validation_images):
964
- edited_images.append(
965
- pipeline(
966
- args.validation_prompt,
967
- image=original_image,
968
- num_inference_steps=20,
969
- image_guidance_scale=1.5,
970
- guidance_scale=7,
971
- generator=generator,
972
- ).images[0]
973
- )
974
-
975
- for tracker in accelerator.trackers:
976
- if tracker.name == "wandb":
977
- wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
978
- for edited_image in edited_images:
979
- wandb_table.add_data(
980
- wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
981
- )
982
- tracker.log({"test": wandb_table})
983
-
984
- accelerator.end_training()
985
-
986
-
987
- if __name__ == "__main__":
988
- main()
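Once training finishes, the script above saves a `StableDiffusionInstructPix2PixPipeline` to `--output_dir` (default `instruct-pix2pix-model`). A minimal inference sketch mirroring the script's own validation call; the output path, image URL, and prompt below are placeholders, not part of the original file:

```python
import PIL.Image
import requests
import torch

from diffusers import StableDiffusionInstructPix2PixPipeline

# Load the pipeline saved by the training script (placeholder path).
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "instruct-pix2pix-model", torch_dtype=torch.float16
).to("cuda")

# Fetch an RGB image to edit (placeholder URL).
image = PIL.Image.open(requests.get("https://example.com/input.png", stream=True).raw).convert("RGB")

# Same arguments the script uses during its validation runs.
edited = pipe(
    "make the sky stormy",  # placeholder edit instruction
    image=image,
    num_inference_steps=20,
    image_guidance_scale=1.5,
    guidance_scale=7,
).images[0]
edited.save("edited.png")
```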
 
diffusers/examples/research_projects/README.md DELETED
@@ -1,14 +0,0 @@
1
- # Research projects
2
-
3
- This folder contains various research projects using 🧨 Diffusers.
4
- They are not really maintained by the core maintainers of this library and often require a specific version of Diffusers that is indicated in the requirements file of each folder.
5
- Updating them to the most recent version of the library will require some work.
6
-
7
- To use any of them, just run the command
8
-
9
- ```
10
- pip install -r requirements.txt
11
- ```
12
- inside the folder of your choice.
13
-
14
- If you need help with any of those, please open an issue where you directly ping the author(s), as indicated at the top of the README of each folder.
 
diffusers/examples/research_projects/colossalai/README.md DELETED
@@ -1,111 +0,0 @@
1
- # [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) by [colossalai](https://github.com/hpcaitech/ColossalAI.git)
2
-
3
- [DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.
4
- The `train_dreambooth_colossalai.py` script shows how to implement the training procedure and adapt it for stable diffusion.
5
-
6
- By accommodating model data in CPU and GPU memory and moving it to the computing device only when necessary, [Gemini](https://www.colossalai.org/docs/advanced_tutorials/meet_gemini), the Heterogeneous Memory Manager of [Colossal-AI](https://github.com/hpcaitech/ColossalAI), can break through the GPU memory wall by using GPU and CPU memory (CPU DRAM or NVMe SSD) together. Moreover, the model scale can be further increased by combining heterogeneous training with other parallel approaches, such as data, tensor, and pipeline parallelism.
7
-
8
- ## Installing the dependencies
9
-
10
- Before running the scripts, make sure to install the library's training dependencies:
11
-
12
- ```bash
13
- pip install -r requirements.txt
14
- ```
15
-
16
- ## Install [ColossalAI](https://github.com/hpcaitech/ColossalAI.git)
17
-
18
- **From PyPI**
19
- ```bash
20
- pip install colossalai
21
- ```
22
-
23
- **From source**
24
-
25
- ```bash
26
- git clone https://github.com/hpcaitech/ColossalAI.git
27
- cd ColossalAI
28
-
29
- # install colossalai
30
- pip install .
31
- ```
32
-
33
- ## Dataset for Teyvat BLIP captions
34
- Dataset used to train [Teyvat characters text to image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion).
35
-
36
- BLIP-generated captions for character images from the [genshin-impact fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters) and the [biligame wiki for genshin impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2).
37
-
38
- Each row of the dataset contains `image` and `text` keys. `image` is a PIL PNG of varying size, and `text` is the accompanying text caption. Only a train split is provided.
39
-
40
- The `text` field includes the tags `Teyvat`, `Name`, `Element`, `Weapon`, `Region`, `Model type`, and `Description`; the `Description` is generated with the [pre-trained BLIP model](https://github.com/salesforce/BLIP).
41
-
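As a quick sanity check, the dataset can be loaded and inspected with 🤗 Datasets; a minimal sketch, where the dataset identifier is a placeholder for the Teyvat BLIP captions dataset you actually use:

```python
from datasets import load_dataset

# Placeholder identifier: substitute the Teyvat BLIP captions dataset you train on.
dataset = load_dataset("your-namespace/teyvat-blip-captions", split="train")

sample = dataset[0]
print(sample["text"])  # caption containing the Teyvat/Name/Element/... tags
sample["image"]        # PIL image of varying size
```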
42
- ## Training
43
-
44
- The argument `placement` can be `cpu`, `auto`, or `cuda`. With `cpu`, the required GPU RAM can be reduced to about 4GB, at the cost of slower training; with `cuda`, GPU memory is roughly halved while training remains fast; with `auto`, a more balanced trade-off between speed and memory is obtained.
45
-
46
- **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
47
-
48
- ```bash
49
- export MODEL_NAME="CompVis/stable-diffusion-v1-4"
50
- export INSTANCE_DIR="path-to-instance-images"
51
- export OUTPUT_DIR="path-to-save-model"
52
-
53
- torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
54
- --pretrained_model_name_or_path=$MODEL_NAME \
55
- --instance_data_dir=$INSTANCE_DIR \
56
- --output_dir=$OUTPUT_DIR \
57
- --instance_prompt="a photo of sks dog" \
58
- --resolution=512 \
59
- --train_batch_size=1 \
60
- --learning_rate=5e-6 \
61
- --lr_scheduler="constant" \
62
- --lr_warmup_steps=0 \
63
- --max_train_steps=400 \
64
- --placement="cuda"
65
- ```
66
-
67
-
68
- ### Training with prior-preservation loss
69
-
70
- Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
71
- According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.
72
-
73
- ```bash
74
- export MODEL_NAME="CompVis/stable-diffusion-v1-4"
75
- export INSTANCE_DIR="path-to-instance-images"
76
- export CLASS_DIR="path-to-class-images"
77
- export OUTPUT_DIR="path-to-save-model"
78
-
79
- torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
80
- --pretrained_model_name_or_path=$MODEL_NAME \
81
- --instance_data_dir=$INSTANCE_DIR \
82
- --class_data_dir=$CLASS_DIR \
83
- --output_dir=$OUTPUT_DIR \
84
- --with_prior_preservation --prior_loss_weight=1.0 \
85
- --instance_prompt="a photo of sks dog" \
86
- --class_prompt="a photo of dog" \
87
- --resolution=512 \
88
- --train_batch_size=1 \
89
- --learning_rate=5e-6 \
90
- --lr_scheduler="constant" \
91
- --lr_warmup_steps=0 \
92
- --max_train_steps=800 \
93
- --placement="cuda"
94
- ```
95
-
96
- ## Inference
97
-
98
- Once you have trained a model using the above command, inference can be done simply with the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `sks` in the above example) in your prompt.
99
-
100
- ```python
101
- from diffusers import StableDiffusionPipeline
102
- import torch
103
-
104
- model_id = "path-to-save-model"
105
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
106
-
107
- prompt = "A photo of sks dog in a bucket"
108
- image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
109
-
110
- image.save("dog-bucket.png")
111
- ```