davidvgilmore committed
Commit 9445c42 · verified · 1 Parent(s): 0d7e482

Upload hy3dgen/texgen/hunyuanpaint/pipeline.py with huggingface_hub
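
The commit message notes the file was pushed with huggingface_hub. A minimal sketch of that kind of upload, assuming a placeholder repo id and a local copy of the file (neither is shown on this page):

from huggingface_hub import upload_file

# Placeholder repo id; the actual destination repo is not visible on this page.
upload_file(
    path_or_fileobj="hy3dgen/texgen/hunyuanpaint/pipeline.py",
    path_in_repo="hy3dgen/texgen/hunyuanpaint/pipeline.py",
    repo_id="davidvgilmore/<target-repo>",
    commit_message="Upload hy3dgen/texgen/hunyuanpaint/pipeline.py with huggingface_hub",
)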

hy3dgen/texgen/hunyuanpaint/pipeline.py ADDED
@@ -0,0 +1,554 @@
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

from typing import Any, Callable, Dict, List, Optional, Union

import numpy
import numpy as np
import torch
import torch.distributed
import torch.utils.checkpoint
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DiffusionPipeline,
    ImagePipelineOutput
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
    retrieve_timesteps,
    rescale_noise_cfg,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import deprecate
from einops import rearrange
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from .unet.modules import UNet2p5DConditionModel


def to_rgb_image(maybe_rgba: Image.Image):
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    elif maybe_rgba.mode == 'RGBA':
        rgba = maybe_rgba
        img = numpy.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
        img = Image.fromarray(img, 'RGB')
        img.paste(rgba, mask=rgba.getchannel('A'))
        return img
    else:
        raise ValueError("Unsupported image type.", maybe_rgba.mode)


class HunyuanPaintPipeline(StableDiffusionPipeline):

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2p5DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        feature_extractor: CLIPImageProcessor,
        safety_checker=None,
        use_torch_compile=False,
    ):
        DiffusionPipeline.__init__(self)

        safety_checker = None
        self.register_modules(
            vae=torch.compile(vae) if use_torch_compile else vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=torch.compile(feature_extractor) if use_torch_compile else feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    @torch.no_grad()
    def encode_images(self, images):
        B = images.shape[0]
        images = rearrange(images, 'b n c h w -> (b n) c h w')

        dtype = next(self.vae.parameters()).dtype
        images = (images - 0.5) * 2.0
        posterior = self.vae.encode(images.to(dtype)).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor

        latents = rearrange(latents, '(b n) c h w -> b n c h w', b=B)
        return latents

    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image = None,
        prompt=None,
        negative_prompt='watermark, ugly, deformed, noisy, blurry, low contrast',
        *args,
        num_images_per_prompt: Optional[int] = 1,
        guidance_scale=2.0,
        output_type: Optional[str] = "pil",
        width=512,
        height=512,
        num_inference_steps=28,
        return_dict=True,
        **cached_condition,
    ):
        if image is None:
            raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
        assert not isinstance(image, torch.Tensor)

        image = to_rgb_image(image)

        image_vae = torch.tensor(np.array(image) / 255.0)
        image_vae = image_vae.unsqueeze(0).permute(0, 3, 1, 2).unsqueeze(0)
        image_vae = image_vae.to(device=self.vae.device, dtype=self.vae.dtype)

        batch_size = image_vae.shape[0]
        assert batch_size == 1
        assert num_images_per_prompt == 1

        ref_latents = self.encode_images(image_vae)

        def convert_pil_list_to_tensor(images):
            bg_c = [1., 1., 1.]
            images_tensor = []
            for batch_imgs in images:
                view_imgs = []
                for pil_img in batch_imgs:
                    img = numpy.asarray(pil_img, dtype=numpy.float32) / 255.
                    if img.shape[2] > 3:
                        alpha = img[:, :, 3:]
                        img = img[:, :, :3] * alpha + bg_c * (1 - alpha)

                    img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).contiguous().half().to("cuda")
                    view_imgs.append(img)
                view_imgs = torch.cat(view_imgs, dim=0)
                images_tensor.append(view_imgs.unsqueeze(0))

            images_tensor = torch.cat(images_tensor, dim=0)
            return images_tensor

        if "normal_imgs" in cached_condition:

            if isinstance(cached_condition["normal_imgs"], List):
                cached_condition["normal_imgs"] = convert_pil_list_to_tensor(cached_condition["normal_imgs"])

            cached_condition['normal_imgs'] = self.encode_images(cached_condition["normal_imgs"])

        if "position_imgs" in cached_condition:

            if isinstance(cached_condition["position_imgs"], List):
                cached_condition["position_imgs"] = convert_pil_list_to_tensor(cached_condition["position_imgs"])

            cached_condition["position_imgs"] = self.encode_images(cached_condition["position_imgs"])

        if 'camera_info_gen' in cached_condition:
            camera_info = cached_condition['camera_info_gen']  # B,N
            if isinstance(camera_info, List):
                camera_info = torch.tensor(camera_info)
            camera_info = camera_info.to(image_vae.device).to(torch.int64)
            cached_condition['camera_info_gen'] = camera_info
        if 'camera_info_ref' in cached_condition:
            camera_info = cached_condition['camera_info_ref']  # B,N
            if isinstance(camera_info, List):
                camera_info = torch.tensor(camera_info)
            camera_info = camera_info.to(image_vae.device).to(torch.int64)
            cached_condition['camera_info_ref'] = camera_info

        cached_condition['ref_latents'] = ref_latents

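        # Classifier-free guidance is prepared below as a two-sample batch: the reference latents are
        # zeroed for the unconditional half and kept for the conditional half (ref_scale 0.0 / 1.0),
        # while the per-view normal/position/camera conditions are simply duplicated across both halves.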
        if guidance_scale > 1:
            negative_ref_latents = torch.zeros_like(cached_condition['ref_latents'])
            cached_condition['ref_latents'] = torch.cat([negative_ref_latents, cached_condition['ref_latents']])
            cached_condition['ref_scale'] = torch.as_tensor([0.0, 1.0]).to(cached_condition['ref_latents'])
            if "normal_imgs" in cached_condition:
                cached_condition['normal_imgs'] = torch.cat(
                    (cached_condition['normal_imgs'], cached_condition['normal_imgs']))

            if "position_imgs" in cached_condition:
                cached_condition['position_imgs'] = torch.cat(
                    (cached_condition['position_imgs'], cached_condition['position_imgs']))

            if 'position_maps' in cached_condition:
                cached_condition['position_maps'] = torch.cat(
                    (cached_condition['position_maps'], cached_condition['position_maps']))

            if 'camera_info_gen' in cached_condition:
                cached_condition['camera_info_gen'] = torch.cat(
                    (cached_condition['camera_info_gen'], cached_condition['camera_info_gen']))
            if 'camera_info_ref' in cached_condition:
                cached_condition['camera_info_ref'] = torch.cat(
                    (cached_condition['camera_info_ref'], cached_condition['camera_info_ref']))

        prompt_embeds = self.unet.learned_text_clip_gen.repeat(num_images_per_prompt, 1, 1)
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)

        latents: torch.Tensor = self.denoise(
            None,
            *args,
            cross_attention_kwargs=None,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            num_inference_steps=num_inference_steps,
            output_type='latent',
            width=width,
            height=height,
            **cached_condition
        ).images

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

    def denoise(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during inference with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        # to deal with lora scaling and other possible forward hooks

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )
        assert num_images_per_prompt == 1
        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * kwargs['num_in_batch'],  # num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.1 Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
            else None
        )

        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

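        # The flat (b*n, c, h, w) latents pack kwargs['num_in_batch'] views per object; in the loop
        # below they are regrouped to (b, n, c, h, w) before every UNet call (presumably so the 2.5D
        # UNet can share information across views) and flattened back for the scheduler step.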
        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latents = rearrange(latents, '(b n) c h w -> b n c h w', n=kwargs['num_in_batch'])
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = rearrange(latent_model_input, 'b n c h w -> (b n) c h w')
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = rearrange(latent_model_input, '(b n) c h w -> b n c h w', n=kwargs['num_in_batch'])

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False, **kwargs
                )[0]
                latents = rearrange(latents, 'b n c h w -> (b n) c h w')
                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents[:, :num_channels_latents, :, :],
                    **extra_step_kwargs, return_dict=False,
                )[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False, generator=generator
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
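
For orientation, a hedged usage sketch of this pipeline; the checkpoint path, image filenames, camera indices, and the choice of six views are illustrative assumptions, not values taken from this file:

import torch
from PIL import Image
from hy3dgen.texgen.hunyuanpaint.pipeline import HunyuanPaintPipeline

# Hypothetical checkpoint location; it must provide the vae/text_encoder/tokenizer/unet/scheduler
# components this pipeline registers, including the custom UNet2p5DConditionModel.
pipe = HunyuanPaintPipeline.from_pretrained("path/to/hunyuanpaint-checkpoint", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

ref_image = Image.open("reference.png")                                  # RGB or RGBA reference view
normal_views = [[Image.open(f"normal_{i}.png") for i in range(6)]]       # B x N multiview normal maps
position_views = [[Image.open(f"position_{i}.png") for i in range(6)]]   # B x N multiview position maps

result = pipe(
    image=ref_image,
    guidance_scale=2.0,
    num_inference_steps=28,
    normal_imgs=normal_views,              # picked up through **cached_condition
    position_imgs=position_views,
    camera_info_gen=[[0, 1, 2, 3, 4, 5]],  # illustrative (B, N) camera indices
    camera_info_ref=[[0]],
    num_in_batch=6,                        # number of views denoised per object
)
result.images[0].save("textured_view_0.png")

Note that text prompts are effectively ignored here: __call__ substitutes the UNet's learned_text_clip_gen embedding for the prompt, so conditioning comes from the reference image and the per-view maps.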