OriLib committed on
Commit d94978e · verified · 1 Parent(s): 9b8f56c

Delete replace_bg/model/replace_bg_model_pipeline_controlnet_sd_xl.py

replace_bg/model/replace_bg_model_pipeline_controlnet_sd_xl.py DELETED
@@ -1,1601 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import inspect
17
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
-
19
- import numpy as np
20
- import PIL.Image
21
- import torch
22
- import torch.nn.functional as F
23
- from transformers import (
24
- CLIPImageProcessor,
25
- CLIPTextModel,
26
- CLIPTextModelWithProjection,
27
- CLIPTokenizer,
28
- CLIPVisionModelWithProjection,
29
- )
30
-
31
- from diffusers.utils.import_utils import is_invisible_watermark_available
32
-
33
- from .image_processor import PipelineImageInput, VaeImageProcessor
34
- from diffusers.loaders import (
35
- FromSingleFileMixin,
36
- IPAdapterMixin,
37
- StableDiffusionXLLoraLoaderMixin,
38
- TextualInversionLoaderMixin,
39
- )
40
- from .controlnet import ControlNetModel
41
- # from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
42
- from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
43
- from diffusers.models.attention_processor import (
44
- AttnProcessor2_0,
45
- LoRAAttnProcessor2_0,
46
- LoRAXFormersAttnProcessor,
47
- XFormersAttnProcessor,
48
- )
49
- from diffusers.models.lora import adjust_lora_scale_text_encoder
50
- from diffusers.schedulers import KarrasDiffusionSchedulers
51
- from diffusers.utils import (
52
- USE_PEFT_BACKEND,
53
- deprecate,
54
- logging,
55
- replace_example_docstring,
56
- scale_lora_layers,
57
- unscale_lora_layers,
58
- )
59
- from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
60
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
61
- from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
62
- from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
63
-
64
- if is_invisible_watermark_available():
65
- from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
66
-
67
- from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
68
-
69
-
70
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
71
-
72
-
73
- EXAMPLE_DOC_STRING = """
74
- Examples:
75
- ```py
76
- >>> # !pip install opencv-python transformers accelerate
77
- >>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
78
- >>> from diffusers.utils import load_image
79
- >>> import numpy as np
80
- >>> import torch
81
-
82
- >>> import cv2
83
- >>> from PIL import Image
84
-
85
- >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
86
- >>> negative_prompt = "low quality, bad quality, sketches"
87
-
88
- >>> # download an image
89
- >>> image = load_image(
90
- ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
91
- ... )
92
-
93
- >>> # initialize the models and pipeline
94
- >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization
95
- >>> controlnet = ControlNetModel.from_pretrained(
96
- ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
97
- ... )
98
- >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
99
- >>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
100
- ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
101
- ... )
102
- >>> pipe.enable_model_cpu_offload()
103
-
104
- >>> # get canny image
105
- >>> image = np.array(image)
106
- >>> image = cv2.Canny(image, 100, 200)
107
- >>> image = image[:, :, None]
108
- >>> image = np.concatenate([image, image, image], axis=2)
109
- >>> canny_image = Image.fromarray(image)
110
-
111
- >>> # generate image
112
- >>> image = pipe(
113
- ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
114
- ... ).images[0]
115
- ```
116
- """
117
-
118
-
119
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
120
- def retrieve_timesteps(
121
- scheduler,
122
- num_inference_steps: Optional[int] = None,
123
- device: Optional[Union[str, torch.device]] = None,
124
- timesteps: Optional[List[int]] = None,
125
- sigmas: Optional[List[float]] = None,
126
- **kwargs,
127
- ):
128
- r"""
129
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
130
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
131
-
132
- Args:
133
- scheduler (`SchedulerMixin`):
134
- The scheduler to get timesteps from.
135
- num_inference_steps (`int`):
136
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
137
- must be `None`.
138
- device (`str` or `torch.device`, *optional*):
139
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
140
- timesteps (`List[int]`, *optional*):
141
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
142
- `num_inference_steps` and `sigmas` must be `None`.
143
- sigmas (`List[float]`, *optional*):
144
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
145
- `num_inference_steps` and `timesteps` must be `None`.
146
-
147
- Returns:
148
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
149
- second element is the number of inference steps.
150
- """
151
- if timesteps is not None and sigmas is not None:
152
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
153
- if timesteps is not None:
154
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
155
- if not accepts_timesteps:
156
- raise ValueError(
157
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
158
- f" timestep schedules. Please check whether you are using the correct scheduler."
159
- )
160
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
161
- timesteps = scheduler.timesteps
162
- num_inference_steps = len(timesteps)
163
- elif sigmas is not None:
164
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
165
- if not accept_sigmas:
166
- raise ValueError(
167
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
168
- f" sigmas schedules. Please check whether you are using the correct scheduler."
169
- )
170
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
171
- timesteps = scheduler.timesteps
172
- num_inference_steps = len(timesteps)
173
- else:
174
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
175
- timesteps = scheduler.timesteps
176
- return timesteps, num_inference_steps
177
-
178
-
179
- class StableDiffusionXLControlNetPipeline(
180
- DiffusionPipeline,
181
- StableDiffusionMixin,
182
- TextualInversionLoaderMixin,
183
- StableDiffusionXLLoraLoaderMixin,
184
- IPAdapterMixin,
185
- FromSingleFileMixin,
186
- ):
187
- r"""
188
- Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
189
-
190
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
191
- implemented for all pipelines (downloading, saving, running on a particular device, etc.).
192
-
193
- The pipeline also inherits the following loading methods:
194
- - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
195
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
196
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
197
- - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
198
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
199
-
200
- Args:
201
- vae ([`AutoencoderKL`]):
202
- Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
203
- text_encoder ([`~transformers.CLIPTextModel`]):
204
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
205
- text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
206
- Second frozen text-encoder
207
- ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
208
- tokenizer ([`~transformers.CLIPTokenizer`]):
209
- A `CLIPTokenizer` to tokenize text.
210
- tokenizer_2 ([`~transformers.CLIPTokenizer`]):
211
- A `CLIPTokenizer` to tokenize text.
212
- unet ([`UNet2DConditionModel`]):
213
- A `UNet2DConditionModel` to denoise the encoded image latents.
214
- controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
215
- Provides additional conditioning to the `unet` during the denoising process. If you set multiple
216
- ControlNets as a list, the outputs from each ControlNet are added together to create one combined
217
- additional conditioning.
218
- scheduler ([`SchedulerMixin`]):
219
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
220
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
221
- force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
222
- Whether the negative prompt embeddings should always be set to 0. Also see the config of
223
- `stabilityai/stable-diffusion-xl-base-1-0`.
224
- add_watermarker (`bool`, *optional*):
225
- Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
226
- watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
227
- watermarker is used.
228
- """
229
-
230
- # leave controlnet out on purpose because it iterates with unet
231
- model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
232
- _optional_components = [
233
- "tokenizer",
234
- "tokenizer_2",
235
- "text_encoder",
236
- "text_encoder_2",
237
- "feature_extractor",
238
- "image_encoder",
239
- ]
240
- _callback_tensor_inputs = [
241
- "latents",
242
- "prompt_embeds",
243
- "negative_prompt_embeds",
244
- "add_text_embeds",
245
- "add_time_ids",
246
- "negative_pooled_prompt_embeds",
247
- "negative_add_time_ids",
248
- "image",
249
- ]
250
-
251
- def __init__(
252
- self,
253
- vae: AutoencoderKL,
254
- text_encoder: CLIPTextModel,
255
- text_encoder_2: CLIPTextModelWithProjection,
256
- tokenizer: CLIPTokenizer,
257
- tokenizer_2: CLIPTokenizer,
258
- unet: UNet2DConditionModel,
259
- controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
260
- scheduler: KarrasDiffusionSchedulers,
261
- force_zeros_for_empty_prompt: bool = True,
262
- add_watermarker: Optional[bool] = None,
263
- feature_extractor: CLIPImageProcessor = None,
264
- image_encoder: CLIPVisionModelWithProjection = None,
265
- ):
266
- super().__init__()
267
-
268
- if isinstance(controlnet, (list, tuple)):
269
- controlnet = MultiControlNetModel(controlnet)
270
-
271
- self.register_modules(
272
- vae=vae,
273
- text_encoder=text_encoder,
274
- text_encoder_2=text_encoder_2,
275
- tokenizer=tokenizer,
276
- tokenizer_2=tokenizer_2,
277
- unet=unet,
278
- controlnet=controlnet,
279
- scheduler=scheduler,
280
- feature_extractor=feature_extractor,
281
- image_encoder=image_encoder,
282
- )
283
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
284
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
285
- self.control_image_processor = VaeImageProcessor(
286
- vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
287
- )
288
- add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
289
-
290
- if add_watermarker:
291
- self.watermark = StableDiffusionXLWatermarker()
292
- else:
293
- self.watermark = None
294
-
295
- self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
296
-
297
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
298
- def encode_prompt(
299
- self,
300
- prompt: str,
301
- prompt_2: Optional[str] = None,
302
- device: Optional[torch.device] = None,
303
- num_images_per_prompt: int = 1,
304
- do_classifier_free_guidance: bool = True,
305
- negative_prompt: Optional[str] = None,
306
- negative_prompt_2: Optional[str] = None,
307
- prompt_embeds: Optional[torch.Tensor] = None,
308
- negative_prompt_embeds: Optional[torch.Tensor] = None,
309
- pooled_prompt_embeds: Optional[torch.Tensor] = None,
310
- negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
311
- lora_scale: Optional[float] = None,
312
- clip_skip: Optional[int] = None,
313
- ):
314
- r"""
315
- Encodes the prompt into text encoder hidden states.
316
-
317
- Args:
318
- prompt (`str` or `List[str]`, *optional*):
319
- prompt to be encoded
320
- prompt_2 (`str` or `List[str]`, *optional*):
321
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
322
- used in both text-encoders
323
- device: (`torch.device`):
324
- torch device
325
- num_images_per_prompt (`int`):
326
- number of images that should be generated per prompt
327
- do_classifier_free_guidance (`bool`):
328
- whether to use classifier free guidance or not
329
- negative_prompt (`str` or `List[str]`, *optional*):
330
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
331
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
332
- less than `1`).
333
- negative_prompt_2 (`str` or `List[str]`, *optional*):
334
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
335
- `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
336
- prompt_embeds (`torch.Tensor`, *optional*):
337
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
338
- provided, text embeddings will be generated from `prompt` input argument.
339
- negative_prompt_embeds (`torch.Tensor`, *optional*):
340
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
341
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
342
- argument.
343
- pooled_prompt_embeds (`torch.Tensor`, *optional*):
344
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
345
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
346
- negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
347
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
348
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
349
- input argument.
350
- lora_scale (`float`, *optional*):
351
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
352
- clip_skip (`int`, *optional*):
353
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
354
- the output of the pre-final layer will be used for computing the prompt embeddings.
355
- """
356
- device = device or self._execution_device
357
-
358
- # set lora scale so that monkey patched LoRA
359
- # function of text encoder can correctly access it
360
- if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
361
- self._lora_scale = lora_scale
362
-
363
- # dynamically adjust the LoRA scale
364
- if self.text_encoder is not None:
365
- if not USE_PEFT_BACKEND:
366
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
367
- else:
368
- scale_lora_layers(self.text_encoder, lora_scale)
369
-
370
- if self.text_encoder_2 is not None:
371
- if not USE_PEFT_BACKEND:
372
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
373
- else:
374
- scale_lora_layers(self.text_encoder_2, lora_scale)
375
-
376
- prompt = [prompt] if isinstance(prompt, str) else prompt
377
-
378
- if prompt is not None:
379
- batch_size = len(prompt)
380
- else:
381
- batch_size = prompt_embeds.shape[0]
382
-
383
- # Define tokenizers and text encoders
384
- tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
385
- text_encoders = (
386
- [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
387
- )
388
-
389
- if prompt_embeds is None:
390
- prompt_2 = prompt_2 or prompt
391
- prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
392
-
393
- # textual inversion: process multi-vector tokens if necessary
394
- prompt_embeds_list = []
395
- prompts = [prompt, prompt_2]
396
- for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
397
- if isinstance(self, TextualInversionLoaderMixin):
398
- prompt = self.maybe_convert_prompt(prompt, tokenizer)
399
-
400
- text_inputs = tokenizer(
401
- prompt,
402
- padding="max_length",
403
- max_length=tokenizer.model_max_length,
404
- truncation=True,
405
- return_tensors="pt",
406
- )
407
-
408
- text_input_ids = text_inputs.input_ids
409
- untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
410
-
411
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
412
- text_input_ids, untruncated_ids
413
- ):
414
- removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
415
- logger.warning(
416
- "The following part of your input was truncated because CLIP can only handle sequences up to"
417
- f" {tokenizer.model_max_length} tokens: {removed_text}"
418
- )
419
-
420
- prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
421
-
422
- # We are only ALWAYS interested in the pooled output of the final text encoder
423
- pooled_prompt_embeds = prompt_embeds[0]
424
- if clip_skip is None:
425
- prompt_embeds = prompt_embeds.hidden_states[-2]
426
- else:
427
- # "2" because SDXL always indexes from the penultimate layer.
428
- prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
429
-
430
- prompt_embeds_list.append(prompt_embeds)
431
-
432
- prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
433
-
434
- # get unconditional embeddings for classifier free guidance
435
- zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
436
- if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
437
- negative_prompt_embeds = torch.zeros_like(prompt_embeds)
438
- negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
439
- elif do_classifier_free_guidance and negative_prompt_embeds is None:
440
- negative_prompt = negative_prompt or ""
441
- negative_prompt_2 = negative_prompt_2 or negative_prompt
442
-
443
- # normalize str to list
444
- negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
445
- negative_prompt_2 = (
446
- batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
447
- )
448
-
449
- uncond_tokens: List[str]
450
- if prompt is not None and type(prompt) is not type(negative_prompt):
451
- raise TypeError(
452
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
453
- f" {type(prompt)}."
454
- )
455
- elif batch_size != len(negative_prompt):
456
- raise ValueError(
457
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
458
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
459
- " the batch size of `prompt`."
460
- )
461
- else:
462
- uncond_tokens = [negative_prompt, negative_prompt_2]
463
-
464
- negative_prompt_embeds_list = []
465
- for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
466
- if isinstance(self, TextualInversionLoaderMixin):
467
- negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
468
-
469
- max_length = prompt_embeds.shape[1]
470
- uncond_input = tokenizer(
471
- negative_prompt,
472
- padding="max_length",
473
- max_length=max_length,
474
- truncation=True,
475
- return_tensors="pt",
476
- )
477
-
478
- negative_prompt_embeds = text_encoder(
479
- uncond_input.input_ids.to(device),
480
- output_hidden_states=True,
481
- )
482
- # We are only ALWAYS interested in the pooled output of the final text encoder
483
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
484
- negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
485
-
486
- negative_prompt_embeds_list.append(negative_prompt_embeds)
487
-
488
- negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
489
-
490
- if self.text_encoder_2 is not None:
491
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
492
- else:
493
- prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
494
-
495
- bs_embed, seq_len, _ = prompt_embeds.shape
496
- # duplicate text embeddings for each generation per prompt, using mps friendly method
497
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
498
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
499
-
500
- if do_classifier_free_guidance:
501
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
502
- seq_len = negative_prompt_embeds.shape[1]
503
-
504
- if self.text_encoder_2 is not None:
505
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
506
- else:
507
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
508
-
509
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
510
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
511
-
512
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
513
- bs_embed * num_images_per_prompt, -1
514
- )
515
- if do_classifier_free_guidance:
516
- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
517
- bs_embed * num_images_per_prompt, -1
518
- )
519
-
520
- if self.text_encoder is not None:
521
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
522
- # Retrieve the original scale by scaling back the LoRA layers
523
- unscale_lora_layers(self.text_encoder, lora_scale)
524
-
525
- if self.text_encoder_2 is not None:
526
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
527
- # Retrieve the original scale by scaling back the LoRA layers
528
- unscale_lora_layers(self.text_encoder_2, lora_scale)
529
-
530
- return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
531
-
532
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
533
- def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
534
- dtype = next(self.image_encoder.parameters()).dtype
535
-
536
- if not isinstance(image, torch.Tensor):
537
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
538
-
539
- image = image.to(device=device, dtype=dtype)
540
- if output_hidden_states:
541
- image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
542
- image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
543
- uncond_image_enc_hidden_states = self.image_encoder(
544
- torch.zeros_like(image), output_hidden_states=True
545
- ).hidden_states[-2]
546
- uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
547
- num_images_per_prompt, dim=0
548
- )
549
- return image_enc_hidden_states, uncond_image_enc_hidden_states
550
- else:
551
- image_embeds = self.image_encoder(image).image_embeds
552
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
553
- uncond_image_embeds = torch.zeros_like(image_embeds)
554
-
555
- return image_embeds, uncond_image_embeds
556
-
557
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
558
- def prepare_ip_adapter_image_embeds(
559
- self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
560
- ):
561
- image_embeds = []
562
- if do_classifier_free_guidance:
563
- negative_image_embeds = []
564
- if ip_adapter_image_embeds is None:
565
- if not isinstance(ip_adapter_image, list):
566
- ip_adapter_image = [ip_adapter_image]
567
-
568
- if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
569
- raise ValueError(
570
- f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
571
- )
572
-
573
- for single_ip_adapter_image, image_proj_layer in zip(
574
- ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
575
- ):
576
- output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
577
- single_image_embeds, single_negative_image_embeds = self.encode_image(
578
- single_ip_adapter_image, device, 1, output_hidden_state
579
- )
580
-
581
- image_embeds.append(single_image_embeds[None, :])
582
- if do_classifier_free_guidance:
583
- negative_image_embeds.append(single_negative_image_embeds[None, :])
584
- else:
585
- for single_image_embeds in ip_adapter_image_embeds:
586
- if do_classifier_free_guidance:
587
- single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
588
- negative_image_embeds.append(single_negative_image_embeds)
589
- image_embeds.append(single_image_embeds)
590
-
591
- ip_adapter_image_embeds = []
592
- for i, single_image_embeds in enumerate(image_embeds):
593
- single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
594
- if do_classifier_free_guidance:
595
- single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
596
- single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
597
-
598
- single_image_embeds = single_image_embeds.to(device=device)
599
- ip_adapter_image_embeds.append(single_image_embeds)
600
-
601
- return ip_adapter_image_embeds
602
-
603
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
604
- def prepare_extra_step_kwargs(self, generator, eta):
605
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
606
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
607
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
608
- # and should be between [0, 1]
609
-
610
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
611
- extra_step_kwargs = {}
612
- if accepts_eta:
613
- extra_step_kwargs["eta"] = eta
614
-
615
- # check if the scheduler accepts generator
616
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
617
- if accepts_generator:
618
- extra_step_kwargs["generator"] = generator
619
- return extra_step_kwargs
620
-
621
- def check_inputs(
622
- self,
623
- prompt,
624
- prompt_2,
625
- image,
626
- callback_steps,
627
- negative_prompt=None,
628
- negative_prompt_2=None,
629
- prompt_embeds=None,
630
- negative_prompt_embeds=None,
631
- pooled_prompt_embeds=None,
632
- ip_adapter_image=None,
633
- ip_adapter_image_embeds=None,
634
- negative_pooled_prompt_embeds=None,
635
- controlnet_conditioning_scale=1.0,
636
- control_guidance_start=0.0,
637
- control_guidance_end=1.0,
638
- callback_on_step_end_tensor_inputs=None,
639
- ):
640
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
641
- raise ValueError(
642
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
643
- f" {type(callback_steps)}."
644
- )
645
-
646
- if callback_on_step_end_tensor_inputs is not None and not all(
647
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
648
- ):
649
- raise ValueError(
650
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
651
- )
652
-
653
- if prompt is not None and prompt_embeds is not None:
654
- raise ValueError(
655
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
656
- " only forward one of the two."
657
- )
658
- elif prompt_2 is not None and prompt_embeds is not None:
659
- raise ValueError(
660
- f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
661
- " only forward one of the two."
662
- )
663
- elif prompt is None and prompt_embeds is None:
664
- raise ValueError(
665
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
666
- )
667
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
668
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
669
- elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
670
- raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
671
-
672
- if negative_prompt is not None and negative_prompt_embeds is not None:
673
- raise ValueError(
674
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
675
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
676
- )
677
- elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
678
- raise ValueError(
679
- f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
680
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
681
- )
682
-
683
- if prompt_embeds is not None and negative_prompt_embeds is not None:
684
- if prompt_embeds.shape != negative_prompt_embeds.shape:
685
- raise ValueError(
686
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
687
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
688
- f" {negative_prompt_embeds.shape}."
689
- )
690
-
691
- if prompt_embeds is not None and pooled_prompt_embeds is None:
692
- raise ValueError(
693
- "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
694
- )
695
-
696
- if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
697
- raise ValueError(
698
- "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
699
- )
700
-
701
- # `prompt` needs more sophisticated handling when there are multiple
702
- # conditionings.
703
- if isinstance(self.controlnet, MultiControlNetModel):
704
- if isinstance(prompt, list):
705
- logger.warning(
706
- f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
707
- " prompts. The conditionings will be fixed across the prompts."
708
- )
709
-
710
- # Check `image`
711
- is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
712
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
713
- )
714
- if (
715
- isinstance(self.controlnet, ControlNetModel)
716
- or is_compiled
717
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
718
- ):
719
- self.check_image(image, prompt, prompt_embeds)
720
- elif (
721
- isinstance(self.controlnet, MultiControlNetModel)
722
- or is_compiled
723
- and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
724
- ):
725
- if not isinstance(image, list):
726
- raise TypeError("For multiple controlnets: `image` must be type `list`")
727
-
728
- # When `image` is a nested list:
729
- # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
730
- elif any(isinstance(i, list) for i in image):
731
- raise ValueError("A single batch of multiple conditionings are supported at the moment.")
732
- elif len(image) != len(self.controlnet.nets):
733
- raise ValueError(
734
- f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
735
- )
736
-
737
- for image_ in image:
738
- self.check_image(image_, prompt, prompt_embeds)
739
- else:
740
- assert False
741
-
742
- # Check `controlnet_conditioning_scale`
743
- if (
744
- isinstance(self.controlnet, ControlNetModel)
745
- or is_compiled
746
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
747
- ):
748
- if not isinstance(controlnet_conditioning_scale, float):
749
- raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
750
- elif (
751
- isinstance(self.controlnet, MultiControlNetModel)
752
- or is_compiled
753
- and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
754
- ):
755
- if isinstance(controlnet_conditioning_scale, list):
756
- if any(isinstance(i, list) for i in controlnet_conditioning_scale):
757
- raise ValueError("A single batch of multiple conditionings are supported at the moment.")
758
- elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
759
- self.controlnet.nets
760
- ):
761
- raise ValueError(
762
- "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
763
- " the same length as the number of controlnets"
764
- )
765
- else:
766
- assert False
767
-
768
- if not isinstance(control_guidance_start, (tuple, list)):
769
- control_guidance_start = [control_guidance_start]
770
-
771
- if not isinstance(control_guidance_end, (tuple, list)):
772
- control_guidance_end = [control_guidance_end]
773
-
774
- if len(control_guidance_start) != len(control_guidance_end):
775
- raise ValueError(
776
- f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
777
- )
778
-
779
- if isinstance(self.controlnet, MultiControlNetModel):
780
- if len(control_guidance_start) != len(self.controlnet.nets):
781
- raise ValueError(
782
- f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
783
- )
784
-
785
- for start, end in zip(control_guidance_start, control_guidance_end):
786
- if start >= end:
787
- raise ValueError(
788
- f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
789
- )
790
- if start < 0.0:
791
- raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
792
- if end > 1.0:
793
- raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
794
-
795
- if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
796
- raise ValueError(
797
- "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
798
- )
799
-
800
- if ip_adapter_image_embeds is not None:
801
- if not isinstance(ip_adapter_image_embeds, list):
802
- raise ValueError(
803
- f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
804
- )
805
- elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
806
- raise ValueError(
807
- f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
808
- )
809
-
810
- # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
811
- def check_image(self, image, prompt, prompt_embeds):
812
- image_is_pil = isinstance(image, PIL.Image.Image)
813
- image_is_tensor = isinstance(image, torch.Tensor)
814
- image_is_np = isinstance(image, np.ndarray)
815
- image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
816
- image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
817
- image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
818
-
819
- if (
820
- not image_is_pil
821
- and not image_is_tensor
822
- and not image_is_np
823
- and not image_is_pil_list
824
- and not image_is_tensor_list
825
- and not image_is_np_list
826
- ):
827
- raise TypeError(
828
- f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
829
- )
830
-
831
- if image_is_pil:
832
- image_batch_size = 1
833
- else:
834
- image_batch_size = len(image)
835
-
836
- if prompt is not None and isinstance(prompt, str):
837
- prompt_batch_size = 1
838
- elif prompt is not None and isinstance(prompt, list):
839
- prompt_batch_size = len(prompt)
840
- elif prompt_embeds is not None:
841
- prompt_batch_size = prompt_embeds.shape[0]
842
-
843
- if image_batch_size != 1 and image_batch_size != prompt_batch_size:
844
- raise ValueError(
845
- f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
846
- )
847
-
848
- # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
849
- def prepare_image(
850
- self,
851
- image,
852
- width,
853
- height,
854
- batch_size,
855
- num_images_per_prompt,
856
- device,
857
- dtype,
858
- do_classifier_free_guidance=False,
859
- guess_mode=False,
860
- ):
861
- image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
862
- image_batch_size = image.shape[0]
863
-
864
- if image_batch_size == 1:
865
- repeat_by = batch_size
866
- else:
867
- # image batch size is the same as prompt batch size
868
- repeat_by = num_images_per_prompt
869
-
870
- image = image.repeat_interleave(repeat_by, dim=0)
871
-
872
- image = image.to(device=device, dtype=dtype)
873
-
874
- if do_classifier_free_guidance and not guess_mode:
875
- image = torch.cat([image] * 2)
876
-
877
- return image
878
-
879
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
880
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
881
- shape = (
882
- batch_size,
883
- num_channels_latents,
884
- int(height) // self.vae_scale_factor,
885
- int(width) // self.vae_scale_factor,
886
- )
887
- if isinstance(generator, list) and len(generator) != batch_size:
888
- raise ValueError(
889
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
890
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
891
- )
892
-
893
- if latents is None:
894
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
895
- else:
896
- latents = latents.to(device)
897
-
898
- # scale the initial noise by the standard deviation required by the scheduler
899
- latents = latents * self.scheduler.init_noise_sigma
900
- return latents
901
-
902
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
903
- def _get_add_time_ids(
904
- self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
905
- ):
906
- add_time_ids = list(original_size + crops_coords_top_left + target_size)
907
-
908
- passed_add_embed_dim = (
909
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
910
- )
911
- expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
912
-
913
- if expected_add_embed_dim != passed_add_embed_dim:
914
- raise ValueError(
915
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
916
- )
917
-
918
- add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
919
- return add_time_ids
920
-
921
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
922
- def upcast_vae(self):
923
- dtype = self.vae.dtype
924
- self.vae.to(dtype=torch.float32)
925
- use_torch_2_0_or_xformers = isinstance(
926
- self.vae.decoder.mid_block.attentions[0].processor,
927
- (
928
- AttnProcessor2_0,
929
- XFormersAttnProcessor,
930
- ),
931
- )
932
- # if xformers or torch_2_0 is used attention block does not need
933
- # to be in float32 which can save lots of memory
934
- if use_torch_2_0_or_xformers:
935
- self.vae.post_quant_conv.to(dtype)
936
- self.vae.decoder.conv_in.to(dtype)
937
- self.vae.decoder.mid_block.to(dtype)
938
-
939
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
940
- def get_guidance_scale_embedding(
941
- self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
942
- ) -> torch.Tensor:
943
- """
944
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
945
-
946
- Args:
947
- w (`torch.Tensor`):
948
- Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
949
- embedding_dim (`int`, *optional*, defaults to 512):
950
- Dimension of the embeddings to generate.
951
- dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
952
- Data type of the generated embeddings.
953
-
954
- Returns:
955
- `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
956
- """
957
- assert len(w.shape) == 1
958
- w = w * 1000.0
959
-
960
- half_dim = embedding_dim // 2
961
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
962
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
963
- emb = w.to(dtype)[:, None] * emb[None, :]
964
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
965
- if embedding_dim % 2 == 1: # zero pad
966
- emb = torch.nn.functional.pad(emb, (0, 1))
967
- assert emb.shape == (w.shape[0], embedding_dim)
968
- return emb
969
-
970
- @property
971
- def guidance_scale(self):
972
- return self._guidance_scale
973
-
974
- @property
975
- def clip_skip(self):
976
- return self._clip_skip
977
-
978
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
979
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
980
- # corresponds to doing no classifier free guidance.
981
- @property
982
- def do_classifier_free_guidance(self):
983
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
984
-
985
- @property
986
- def cross_attention_kwargs(self):
987
- return self._cross_attention_kwargs
988
-
989
- @property
990
- def denoising_end(self):
991
- return self._denoising_end
992
-
993
- @property
994
- def num_timesteps(self):
995
- return self._num_timesteps
996
-
997
- @property
998
- def interrupt(self):
999
- return self._interrupt
1000
-
1001
- @torch.no_grad()
1002
- @replace_example_docstring(EXAMPLE_DOC_STRING)
1003
- def __call__(
1004
- self,
1005
- prompt: Union[str, List[str]] = None,
1006
- prompt_2: Optional[Union[str, List[str]]] = None,
1007
- image: PipelineImageInput = None,
1008
- height: Optional[int] = None,
1009
- width: Optional[int] = None,
1010
- num_inference_steps: int = 50,
1011
- timesteps: List[int] = None,
1012
- sigmas: List[float] = None,
1013
- denoising_end: Optional[float] = None,
1014
- guidance_scale: float = 5.0,
1015
- negative_prompt: Optional[Union[str, List[str]]] = None,
1016
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
1017
- num_images_per_prompt: Optional[int] = 1,
1018
- eta: float = 0.0,
1019
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1020
- latents: Optional[torch.Tensor] = None,
1021
- prompt_embeds: Optional[torch.Tensor] = None,
1022
- negative_prompt_embeds: Optional[torch.Tensor] = None,
1023
- pooled_prompt_embeds: Optional[torch.Tensor] = None,
1024
- negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
1025
- ip_adapter_image: Optional[PipelineImageInput] = None,
1026
- ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1027
- output_type: Optional[str] = "pil",
1028
- return_dict: bool = True,
1029
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1030
- controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
1031
- guess_mode: bool = False,
1032
- control_guidance_start: Union[float, List[float]] = 0.0,
1033
- control_guidance_end: Union[float, List[float]] = 1.0,
1034
- original_size: Tuple[int, int] = None,
1035
- crops_coords_top_left: Tuple[int, int] = (0, 0),
1036
- target_size: Tuple[int, int] = None,
1037
- negative_original_size: Optional[Tuple[int, int]] = None,
1038
- negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1039
- negative_target_size: Optional[Tuple[int, int]] = None,
1040
- clip_skip: Optional[int] = None,
1041
- callback_on_step_end: Optional[
1042
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
1043
- ] = None,
1044
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1045
- **kwargs,
1046
- ):
1047
- r"""
1048
- The call function to the pipeline for generation.
1049
-
1050
- Args:
1051
- prompt (`str` or `List[str]`, *optional*):
1052
- The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
1053
- prompt_2 (`str` or `List[str]`, *optional*):
1054
- The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1055
- used in both text-encoders.
1056
- image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
1057
- `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
1058
- The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
1059
- specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
1060
- as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
1061
- width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
1062
- images must be passed as a list such that each element of the list can be correctly batched for input
1063
- to a single ControlNet.
1064
- height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1065
- The height in pixels of the generated image. Anything below 512 pixels won't work well for
1066
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1067
- and checkpoints that are not specifically fine-tuned on low resolutions.
1068
- width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1069
- The width in pixels of the generated image. Anything below 512 pixels won't work well for
1070
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1071
- and checkpoints that are not specifically fine-tuned on low resolutions.
1072
- num_inference_steps (`int`, *optional*, defaults to 50):
1073
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1074
- expense of slower inference.
1075
- timesteps (`List[int]`, *optional*):
1076
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1077
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1078
- passed will be used. Must be in descending order.
1079
- sigmas (`List[float]`, *optional*):
1080
- Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1081
- their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1082
- will be used.
1083
- denoising_end (`float`, *optional*):
1084
- When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1085
- completed before it is intentionally prematurely terminated. As a result, the returned sample will
1086
- still retain a substantial amount of noise as determined by the discrete timesteps selected by the
1087
- scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
1088
- "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1089
- Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
1090
- guidance_scale (`float`, *optional*, defaults to 5.0):
1091
- A higher guidance scale value encourages the model to generate images closely linked to the text
1092
- `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
1093
- negative_prompt (`str` or `List[str]`, *optional*):
1094
- The prompt or prompts to guide what to not include in image generation. If not defined, you need to
1095
- pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
1096
- negative_prompt_2 (`str` or `List[str]`, *optional*):
1097
- The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
1098
- and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
1099
- num_images_per_prompt (`int`, *optional*, defaults to 1):
1100
- The number of images to generate per prompt.
1101
- eta (`float`, *optional*, defaults to 0.0):
1102
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
1103
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1104
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1105
- A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1106
- generation deterministic.
1107
- latents (`torch.Tensor`, *optional*):
1108
- Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1109
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1110
- tensor is generated by sampling using the supplied random `generator`.
1111
- prompt_embeds (`torch.Tensor`, *optional*):
1112
- Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1113
- provided, text embeddings are generated from the `prompt` input argument.
1114
- negative_prompt_embeds (`torch.Tensor`, *optional*):
1115
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1116
- not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1117
- pooled_prompt_embeds (`torch.Tensor`, *optional*):
1118
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1119
- not provided, pooled text embeddings are generated from `prompt` input argument.
1120
- negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
1121
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
1122
- weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
1123
- argument.
1124
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1125
- ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1126
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1127
- IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1128
- contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1129
- provided, embeddings are computed from the `ip_adapter_image` input argument.
1130
- output_type (`str`, *optional*, defaults to `"pil"`):
1131
- The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1132
- return_dict (`bool`, *optional*, defaults to `True`):
1133
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1134
- plain tuple.
1135
- cross_attention_kwargs (`dict`, *optional*):
1136
- A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1137
- [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1138
- controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1139
- The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
1140
- to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
1141
- the corresponding scale as a list.
1142
- guess_mode (`bool`, *optional*, defaults to `False`):
1143
- The ControlNet encoder tries to recognize the content of the input image even if you remove all
1144
- prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
1145
- control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1146
- The percentage of total steps at which the ControlNet starts applying.
1147
- control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1148
- The percentage of total steps at which the ControlNet stops applying.
1149
- original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1150
- If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1151
- `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1152
- explained in section 2.2 of
1153
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1154
- crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1155
- `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1156
- `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1157
- `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1158
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1159
- target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1160
- For most cases, `target_size` should be set to the desired height and width of the generated image. If
1161
- not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1162
- section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1163
- negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1164
- To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1165
- micro-conditioning as explained in section 2.2 of
1166
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1167
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1168
- negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1169
- To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
1170
- micro-conditioning as explained in section 2.2 of
1171
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1172
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1173
- negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1174
- To negatively condition the generation process based on a target image resolution. It should usually be
1175
- the same as `target_size`. Part of SDXL's micro-conditioning as explained in section 2.2 of
1176
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1177
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1178
- clip_skip (`int`, *optional*):
1179
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1180
- the output of the pre-final layer will be used for computing the prompt embeddings.
1181
- callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1182
- A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1183
- each denoising step during inference, with the following arguments: `callback_on_step_end(self:
1184
- DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1185
- list of all tensors as specified by `callback_on_step_end_tensor_inputs` (see the usage sketch after this docstring).
1186
- callback_on_step_end_tensor_inputs (`List`, *optional*):
1187
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1188
- will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
1189
- `._callback_tensor_inputs` attribute of your pipeline class.
1190
-
1191
- Examples:
1192
-
1193
- Returns:
1194
- [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
1195
- If `return_dict` is `True`, [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] is returned,
1196
- otherwise a `tuple` is returned containing the output images.
1197
- """
1198
-
1199
- callback = kwargs.pop("callback", None)
1200
- callback_steps = kwargs.pop("callback_steps", None)
1201
-
1202
- if callback is not None:
1203
- deprecate(
1204
- "callback",
1205
- "1.0.0",
1206
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1207
- )
1208
- if callback_steps is not None:
1209
- deprecate(
1210
- "callback_steps",
1211
- "1.0.0",
1212
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1213
- )
1214
-
1215
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1216
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1217
-
1218
- controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1219
-
1220
- # align format for control guidance
1221
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1222
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1223
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1224
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1225
- elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1226
- mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1227
- control_guidance_start, control_guidance_end = (
1228
- mult * [control_guidance_start],
1229
- mult * [control_guidance_end],
1230
- )
1231
-
1232
- # 1. Check inputs. Raise error if not correct
1233
- self.check_inputs(
1234
- prompt,
1235
- prompt_2,
1236
- image,
1237
- callback_steps,
1238
- negative_prompt,
1239
- negative_prompt_2,
1240
- prompt_embeds,
1241
- negative_prompt_embeds,
1242
- pooled_prompt_embeds,
1243
- ip_adapter_image,
1244
- ip_adapter_image_embeds,
1245
- negative_pooled_prompt_embeds,
1246
- controlnet_conditioning_scale,
1247
- control_guidance_start,
1248
- control_guidance_end,
1249
- callback_on_step_end_tensor_inputs,
1250
- )
1251
-
1252
- self._guidance_scale = guidance_scale
1253
- self._clip_skip = clip_skip
1254
- self._cross_attention_kwargs = cross_attention_kwargs
1255
- self._denoising_end = denoising_end
1256
- self._interrupt = False
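`self._interrupt` is reset here and checked at the top of the denoising loop further down; a step-end callback can flip it to skip all remaining steps. A minimal sketch of that pattern (the helper and the commented-out call are assumptions, not part of this file):

```py
# Sketch: stop denoising early by setting the pipeline's `_interrupt` flag from a callback.
def stop_after(n_steps):
    def _callback(pipeline, step, timestep, callback_kwargs):
        if step + 1 >= n_steps:
            pipeline._interrupt = True  # the loop below then skips the remaining steps
        return callback_kwargs
    return _callback

# result = pipe(prompt="...", image=control_image, callback_on_step_end=stop_after(10))
```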
1257
-
1258
- # 2. Define call parameters
1259
- if prompt is not None and isinstance(prompt, str):
1260
- batch_size = 1
1261
- elif prompt is not None and isinstance(prompt, list):
1262
- batch_size = len(prompt)
1263
- else:
1264
- batch_size = prompt_embeds.shape[0]
1265
-
1266
- device = self._execution_device
1267
-
1268
- if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1269
- controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1270
-
1271
- global_pool_conditions = (
1272
- controlnet.config.global_pool_conditions
1273
- if isinstance(controlnet, ControlNetModel)
1274
- else controlnet.nets[0].config.global_pool_conditions
1275
- )
1276
- guess_mode = guess_mode or global_pool_conditions
1277
-
1278
- # 3.1 Encode input prompt
1279
- text_encoder_lora_scale = (
1280
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1281
- )
1282
- (
1283
- prompt_embeds,
1284
- negative_prompt_embeds,
1285
- pooled_prompt_embeds,
1286
- negative_pooled_prompt_embeds,
1287
- ) = self.encode_prompt(
1288
- prompt,
1289
- prompt_2,
1290
- device,
1291
- num_images_per_prompt,
1292
- self.do_classifier_free_guidance,
1293
- negative_prompt,
1294
- negative_prompt_2,
1295
- prompt_embeds=prompt_embeds,
1296
- negative_prompt_embeds=negative_prompt_embeds,
1297
- pooled_prompt_embeds=pooled_prompt_embeds,
1298
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1299
- lora_scale=text_encoder_lora_scale,
1300
- clip_skip=self.clip_skip,
1301
- )
1302
-
1303
- # 3.2 Encode ip_adapter_image
1304
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1305
- image_embeds = self.prepare_ip_adapter_image_embeds(
1306
- ip_adapter_image,
1307
- ip_adapter_image_embeds,
1308
- device,
1309
- batch_size * num_images_per_prompt,
1310
- self.do_classifier_free_guidance,
1311
- )
1312
-
1313
- # 4. Prepare image
1314
- if isinstance(controlnet, ControlNetModel):
1315
- image = self.prepare_image(
1316
- image=image,
1317
- width=width,
1318
- height=height,
1319
- batch_size=batch_size * num_images_per_prompt,
1320
- num_images_per_prompt=num_images_per_prompt,
1321
- device=device,
1322
- dtype=controlnet.dtype,
1323
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1324
- guess_mode=guess_mode,
1325
- )
1326
- height, width = image.shape[-2:]
1327
- height, width = height * self.vae_scale_factor, width * self.vae_scale_factor  # Bria: update for vae controlnet (conditioning is latent-sized; see the note after this block)
1328
- elif isinstance(controlnet, MultiControlNetModel):
1329
- images = []
1330
-
1331
- for image_ in image:
1332
- image_ = self.prepare_image(
1333
- image=image_,
1334
- width=width,
1335
- height=height,
1336
- batch_size=batch_size * num_images_per_prompt,
1337
- num_images_per_prompt=num_images_per_prompt,
1338
- device=device,
1339
- dtype=controlnet.dtype,
1340
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1341
- guess_mode=guess_mode,
1342
- )
1343
-
1344
- images.append(image_)
1345
-
1346
- image = images
1347
- height, width = image[0].shape[-2:]
1348
- else:
1349
- assert False, "controlnet must be a ControlNetModel or MultiControlNetModel"
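The `# Bria` line in the single-ControlNet branch above departs from the stock SDXL ControlNet pipeline: the conditioning `image` appears to already be latent-sized, so its spatial dimensions are multiplied back by `vae_scale_factor` to recover pixel-space `height`/`width`. Illustrative arithmetic only; the scale factor of 8 and the 128x128 latent size are assumptions.

```py
# Illustrative only: recovering pixel dimensions from a latent-space conditioning image.
vae_scale_factor = 8              # typically 2 ** (len(vae.config.block_out_channels) - 1)
latent_h, latent_w = 128, 128     # image.shape[-2:] after prepare_image
height, width = latent_h * vae_scale_factor, latent_w * vae_scale_factor
# -> (1024, 1024), which then drives prepare_latents and the SDXL time ids below
```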
1350
-
1351
- # 5. Prepare timesteps
1352
- timesteps, num_inference_steps = retrieve_timesteps(
1353
- self.scheduler, num_inference_steps, device, timesteps, sigmas
1354
- )
1355
- self._num_timesteps = len(timesteps)
1356
-
1357
- # 6. Prepare latent variables
1358
- num_channels_latents = self.unet.config.in_channels
1359
- latents = self.prepare_latents(
1360
- batch_size * num_images_per_prompt,
1361
- num_channels_latents,
1362
- height,
1363
- width,
1364
- prompt_embeds.dtype,
1365
- device,
1366
- generator,
1367
- latents,
1368
- )
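For orientation, this is the latent shape `prepare_latents` works with under the usual SDXL defaults. The 4 latent channels and VAE scale factor of 8 are assumptions here; the real values come from `unet.config` and the VAE in use.

```py
batch_size, num_images_per_prompt = 1, 1
num_channels_latents = 4          # unet.config.in_channels for stock SDXL
height, width = 1024, 1024
vae_scale_factor = 8
latent_shape = (
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height // vae_scale_factor,
    width // vae_scale_factor,
)
# -> (1, 4, 128, 128); filled with scaled Gaussian noise when `latents` is not passed in
```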
1369
-
1370
- # 6.5 Optionally get Guidance Scale Embedding
1371
- timestep_cond = None
1372
- if self.unet.config.time_cond_proj_dim is not None:
1373
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1374
- timestep_cond = self.get_guidance_scale_embedding(
1375
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1376
- ).to(device=device, dtype=latents.dtype)
1377
-
1378
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1379
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1380
-
1381
- # 7.1 Create tensor stating which controlnets to keep
1382
- controlnet_keep = []
1383
- for i in range(len(timesteps)):
1384
- keeps = [
1385
- 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1386
- for s, e in zip(control_guidance_start, control_guidance_end)
1387
- ]
1388
- controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
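A worked example of this keep schedule: with 10 inference steps, `control_guidance_start=0.0`, and `control_guidance_end=0.5`, the ControlNet residuals are only applied during the first half of denoising (all values are illustrative).

```py
num_steps, s, e = 10, 0.0, 0.5
keeps = [1.0 - float(i / num_steps < s or (i + 1) / num_steps > e) for i in range(num_steps)]
# -> [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
# Each entry later multiplies `controlnet_conditioning_scale`, so a 0.0 simply
# disables the ControlNet for that step.
```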
1389
-
1390
- # 7.2 Prepare added time ids & embeddings
1391
- if isinstance(image, list):
1392
- original_size = original_size or image[0].shape[-2:]
1393
- else:
1394
- original_size = original_size or image.shape[-2:]
1395
- target_size = target_size or (height, width)
1396
-
1397
- add_text_embeds = pooled_prompt_embeds
1398
- if self.text_encoder_2 is None:
1399
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1400
- else:
1401
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1402
-
1403
- add_time_ids = self._get_add_time_ids(
1404
- original_size,
1405
- crops_coords_top_left,
1406
- target_size,
1407
- dtype=prompt_embeds.dtype,
1408
- text_encoder_projection_dim=text_encoder_projection_dim,
1409
- )
1410
-
1411
- if negative_original_size is not None and negative_target_size is not None:
1412
- negative_add_time_ids = self._get_add_time_ids(
1413
- negative_original_size,
1414
- negative_crops_coords_top_left,
1415
- negative_target_size,
1416
- dtype=prompt_embeds.dtype,
1417
- text_encoder_projection_dim=text_encoder_projection_dim,
1418
- )
1419
- else:
1420
- negative_add_time_ids = add_time_ids
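For reference, SDXL's micro-conditioning vector produced by `_get_add_time_ids` is essentially the concatenation of the three size/crop tuples. The sketch below shows only the layout; the projection-dimension validation done by the real helper is omitted, and the numeric values are placeholders.

```py
original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
# -> [1024, 1024, 0, 0, 1024, 1024]; fed to the UNet via added_cond_kwargs["time_ids"],
# with `negative_add_time_ids` playing the same role for the unconditional batch.
```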
1421
-
1422
- if self.do_classifier_free_guidance:
1423
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1424
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1425
- add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1426
-
1427
- prompt_embeds = prompt_embeds.to(device)
1428
- add_text_embeds = add_text_embeds.to(device)
1429
- add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
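Classifier-free guidance doubles the batch here: the negative (unconditional) inputs are stacked first and the conditional inputs second, and the UNet output is split back in that same order inside the loop. A shape sketch follows; the 77x2048 embedding size and the guidance scale of 7.5 are the usual SDXL values and are assumptions here.

```py
import torch

negative_prompt_embeds = torch.zeros(1, 77, 2048)
prompt_embeds = torch.ones(1, 77, 2048)
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)  # (2, 77, 2048)

# Later, the doubled noise prediction is split in the same order:
noise_pred = torch.randn(2, 4, 128, 128)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
```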
1430
-
1431
- # 8. Denoising loop
1432
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1433
-
1434
- # 8.1 Apply denoising_end
1435
- if (
1436
- self.denoising_end is not None
1437
- and isinstance(self.denoising_end, float)
1438
- and self.denoising_end > 0
1439
- and self.denoising_end < 1
1440
- ):
1441
- discrete_timestep_cutoff = int(
1442
- round(
1443
- self.scheduler.config.num_train_timesteps
1444
- - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1445
- )
1446
- )
1447
- num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1448
- timesteps = timesteps[:num_inference_steps]
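A quick worked example of the cutoff arithmetic: with a 1000-step training schedule (the usual value, assumed here) and `denoising_end=0.8`, only timesteps at or above 200 are kept, i.e. roughly the first 80% of the inference steps; the remainder can be handed to a refiner.

```py
num_train_timesteps = 1000
denoising_end = 0.8
discrete_timestep_cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))
# -> 200; timesteps below 200 are dropped and num_inference_steps shrinks accordingly
```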
1449
-
1450
- is_unet_compiled = is_compiled_module(self.unet)
1451
- is_controlnet_compiled = is_compiled_module(self.controlnet)
1452
- is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1453
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1454
- for i, t in enumerate(timesteps):
1455
- if self.interrupt:
1456
- continue
1457
-
1458
- # Relevant thread:
1459
- # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1460
- if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
1461
- torch._inductor.cudagraph_mark_step_begin()
1462
- # expand the latents if we are doing classifier free guidance
1463
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1464
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1465
-
1466
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1467
-
1468
- # controlnet(s) inference
1469
- if guess_mode and self.do_classifier_free_guidance:
1470
- # Infer ControlNet only for the conditional batch.
1471
- control_model_input = latents
1472
- control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1473
- controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1474
- controlnet_added_cond_kwargs = {
1475
- "text_embeds": add_text_embeds.chunk(2)[1],
1476
- "time_ids": add_time_ids.chunk(2)[1],
1477
- }
1478
- else:
1479
- control_model_input = latent_model_input
1480
- controlnet_prompt_embeds = prompt_embeds
1481
- controlnet_added_cond_kwargs = added_cond_kwargs
1482
-
1483
- if isinstance(controlnet_keep[i], list):
1484
- cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1485
- else:
1486
- controlnet_cond_scale = controlnet_conditioning_scale
1487
- if isinstance(controlnet_cond_scale, list):
1488
- controlnet_cond_scale = controlnet_cond_scale[0]
1489
- cond_scale = controlnet_cond_scale * controlnet_keep[i]
1490
-
1491
- down_block_res_samples, mid_block_res_sample = self.controlnet(
1492
- control_model_input,
1493
- t,
1494
- encoder_hidden_states=controlnet_prompt_embeds,
1495
- controlnet_cond=image,
1496
- conditioning_scale=cond_scale,
1497
- guess_mode=guess_mode,
1498
- added_cond_kwargs=controlnet_added_cond_kwargs,
1499
- return_dict=False,
1500
- )
1501
-
1502
- if guess_mode and self.do_classifier_free_guidance:
1503
- # Inferred ControlNet only for the conditional batch.
1504
- # To apply the output of ControlNet to both the unconditional and conditional batches,
1505
- # add 0 to the unconditional batch to keep it unchanged.
1506
- down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1507
- mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1508
-
1509
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1510
- added_cond_kwargs["image_embeds"] = image_embeds
1511
-
1512
- # predict the noise residual
1513
- noise_pred = self.unet(
1514
- latent_model_input,
1515
- t,
1516
- encoder_hidden_states=prompt_embeds,
1517
- timestep_cond=timestep_cond,
1518
- cross_attention_kwargs=self.cross_attention_kwargs,
1519
- down_block_additional_residuals=down_block_res_samples,
1520
- mid_block_additional_residual=mid_block_res_sample,
1521
- added_cond_kwargs=added_cond_kwargs,
1522
- return_dict=False,
1523
- )[0]
1524
-
1525
- # perform guidance
1526
- if self.do_classifier_free_guidance:
1527
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1528
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1529
-
1530
- # compute the previous noisy sample x_t -> x_t-1
1531
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1532
-
1533
- if callback_on_step_end is not None:
1534
- callback_kwargs = {}
1535
- for k in callback_on_step_end_tensor_inputs:
1536
- callback_kwargs[k] = locals()[k]
1537
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1538
-
1539
- latents = callback_outputs.pop("latents", latents)
1540
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1541
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1542
- add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1543
- negative_pooled_prompt_embeds = callback_outputs.pop(
1544
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1545
- )
1546
- add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1547
- negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1548
- image = callback_outputs.pop("image", image)
1549
-
1550
- # call the callback, if provided
1551
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1552
- progress_bar.update()
1553
- if callback is not None and i % callback_steps == 0:
1554
- step_idx = i // getattr(self.scheduler, "order", 1)
1555
- callback(step_idx, t, latents)
1556
-
1557
- if output_type != "latent":
1558
- # make sure the VAE is in float32 mode, as it overflows in float16
1559
- needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1560
-
1561
- if needs_upcasting:
1562
- self.upcast_vae()
1563
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1564
-
1565
- # unscale/denormalize the latents
1566
- # denormalize with the mean and std if available and not None
1567
- has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1568
- has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1569
- if has_latents_mean and has_latents_std:
1570
- latents_mean = (
1571
- torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1572
- )
1573
- latents_std = (
1574
- torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1575
- )
1576
- latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1577
- else:
1578
- latents = latents / self.vae.config.scaling_factor
1579
-
1580
- image = self.vae.decode(latents, return_dict=False)[0]
1581
-
1582
- # cast back to fp16 if needed
1583
- if needs_upcasting:
1584
- self.vae.to(dtype=torch.float16)
1585
- else:
1586
- image = latents
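The two unscaling paths above, side by side with placeholder numbers. SDXL's default `scaling_factor` of 0.13025 and the zero-mean/unit-std statistics are assumptions for illustration; `latents_mean`/`latents_std` only exist for some fine-tuned VAEs.

```py
import torch

latents = torch.randn(1, 4, 128, 128)
scaling_factor = 0.13025
latents_mean = torch.zeros(1, 4, 1, 1)
latents_std = torch.ones(1, 4, 1, 1)

with_stats = latents * latents_std / scaling_factor + latents_mean   # when stats are configured
without_stats = latents / scaling_factor                             # the stock SDXL VAE case
```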
1587
-
1588
- if output_type != "latent":
1589
- # apply watermark if available
1590
- if self.watermark is not None:
1591
- image = self.watermark.apply_watermark(image)
1592
-
1593
- image = self.image_processor.postprocess(image, output_type=output_type)
1594
-
1595
- # Offload all models
1596
- self.maybe_free_model_hooks()
1597
-
1598
- if not return_dict:
1599
- return (image,)
1600
-
1601
- return StableDiffusionXLPipelineOutput(images=image)
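Finally, a sketch of consuming the return value under both settings, assuming the default `output_type="pil"`; as before, `pipe` and `control_image` are assumptions and not defined in this file.

```py
# Default: a StableDiffusionXLPipelineOutput with an .images list
result = pipe(prompt="...", image=control_image)
image = result.images[0]

# With return_dict=False: a one-element tuple containing the images
(images,) = pipe(prompt="...", image=control_image, return_dict=False)
image = images[0]
```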