rashidvyro committed
Commit 3e442e7 · 1 Parent(s): 20542a8

Upload stable_diffusion_custom_v4_1.py

Files changed (1)
  1. stable_diffusion_custom_v4_1.py +795 -0
stable_diffusion_custom_v4_1.py ADDED
@@ -0,0 +1,795 @@
import random
from diffusers import StableDiffusionPipeline
# from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput, AutoencoderKL, CLIPTextModel, CLIPTokenizer, UNet2DConditionModel, KarrasDiffusionSchedulers, StableDiffusionSafetyChecker, CLIPImageProcessor
from compel import Compel
from onediff.utils.tokenizer import TextualInversionLoaderMixin, MultiTokenCLIPTokenizer
import torch
from typing import Any, Callable, Dict, List, Optional, Union
from dynamicprompts.generators import RandomPromptGenerator
import time
from onediff.utils.prompt_parser import ScheduledPromptConditioning
from onediff.utils.prompt_parser import get_learned_conditioning_prompt_schedules
import tqdm
from cachetools import LRUCache
from onediff.utils.image_processor import VaeImageProcessor
from diffusers.utils import logging

# module logger used by the prompt-truncation warnings below
logger = logging.get_logger(__name__)


class CustomStableDiffusionPipeline4_1(TextualInversionLoaderMixin, StableDiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
        prompt_cache_size: int = 1024,
        prompt_cache_ttl: int = 60 * 2,
    ) -> None:
        super().__init__(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler,
                         safety_checker=safety_checker, feature_extractor=feature_extractor, requires_safety_checker=requires_safety_checker)

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

        self.compel = Compel(tokenizer=self.tokenizer,
                             text_encoder=self.text_encoder, truncate_long_prompts=False)
        self.cache = LRUCache(maxsize=prompt_cache_size)

        self.cached_uc = [None, None]
        self.cached_c = [None, None]

        self.prompt_handler = None

    def build_scheduled_cond(self, prompt, steps, key):
        prompt_schedule = get_learned_conditioning_prompt_schedules([prompt], steps)[0]

        cached = self.cache.get(key, None)
        if cached is not None:
            return cached

        texts = [x[1] for x in prompt_schedule]
        conds = [self.compel.build_conditioning_tensor(text).to('cpu') for text in texts]

        cond_schedule = []
        for i, s in enumerate(prompt_schedule):
            cond_schedule.append(ScheduledPromptConditioning(s[0], conds[i]))

        self.cache[key] = cond_schedule
        return cond_schedule

    def initialize_magic_prompt_cache(self, pos_prompt_template: str, plain_prompt_template: str, neg_prompt_template: str, num_to_generate: int, steps: int):
        r"""
        Initializes the magic prompt cache for the forward pass.
        Must be called immediately after Compel is loaded and embeds are initialized.
        """
        rpg = RandomPromptGenerator(ignore_whitespace=True, seed=555)
        positive_prompts = rpg.generate(
            template=pos_prompt_template, num_images=num_to_generate)
        scheduled_conds = []
        with torch.no_grad():
            for i in tqdm.tqdm(range(len(positive_prompts))):
                # cache entries are keyed by the prompt text (a dict is not hashable and cannot serve as a key)
                scheduled_conds.append(self.build_scheduled_cond(
                    positive_prompts[i], steps, positive_prompts[i]))

            plain_scheduled_cond = self.build_scheduled_cond(
                plain_prompt_template, steps, plain_prompt_template)

            scheduled_uncond = self.build_scheduled_cond(
                neg_prompt_template, steps, neg_prompt_template)

        self.scheduled_conds = scheduled_conds
        self.plain_scheduled_cond = plain_scheduled_cond
        self.scheduled_uncond = scheduled_uncond

    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list(int)`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        text_input_ids = text_inputs.input_ids
        text_input_ids = torch.from_numpy(text_input_ids)
        untruncated_ids = self.tokenizer(
            prompt, padding="max_length", return_tensors="np").input_ids
        untruncated_ids = torch.from_numpy(untruncated_ids)

        if (
            text_input_ids.shape == untruncated_ids.shape
            and text_input_ids.numel() == untruncated_ids.numel()
            and not torch.equal(text_input_ids, untruncated_ids)
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1: -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device), attention_mask=attention_mask)
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(
            bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="np",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = torch.from_numpy(
                    uncond_input.attention_mask).to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                torch.from_numpy(uncond_input.input_ids).to(device), attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(
                1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(
                batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def _encode_promptv2(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(
                prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(
            dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(
                1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1)

            negative_prompt_embeds, prompt_embeds = self.compel.pad_conditioning_tensors_to_same_length(
                [negative_prompt_embeds, prompt_embeds])
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def _pyramid_noise_like(self, noise, device, seed, iterations=6, discount=0.4):
        gen = torch.manual_seed(seed)
        # EDIT: w and h get over-written, rename for a different variant!
        b, c, w, h = noise.shape
        u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
        for i in range(iterations):
            r = random.random() * 2 + 2  # Rather than always going 2x,
            wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
            noise += u(torch.randn(b, c, wn, hn,
                       generator=gen).to(device)) * discount**i
            if wn == 1 or hn == 1:
                break  # Lowest resolution is 1x1
        return noise / noise.std()  # Scaled back to roughly unit variance

    @torch.no_grad()
    def inferV4(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[
            int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        compile_unet: bool = True,
        compile_vae: bool = True,
        compile_tenc: bool = True,
        max_tokens=0,
        seed=-1,
        flags=[],
        og_prompt=None,
        og_neg_prompt=None,
        disc=0.4,
        iter=6,
        pyramid=0,  # disabled by default unless specified
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)
        if negative_prompt is None:
            negative_prompt = ['']

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Cache key for flags
        plain = "plain" in flags
        flair = None
        for flag in flags:
            if "flair" in flag:
                flair = flag
                break

        with torch.no_grad():
            c_time = time.time()
            user_cond = self.build_scheduled_cond(
                prompt[0], num_inference_steps, ('pos', og_prompt, seed, plain, flair))
            c_time = time.time()
            user_uncond = self.build_scheduled_cond(
                negative_prompt[0], num_inference_steps, ('neg', negative_prompt[0], 0))

        c = []
        c.extend(user_cond)
        uc = []
        uc.extend(user_uncond)
        max_token_count = 0

        for cond in uc:
            if cond.cond.shape[1] > max_token_count:
                max_token_count = cond.cond.shape[1]
        for cond in c:
            if cond.cond.shape[1] > max_token_count:
                max_token_count = cond.cond.shape[1]

        def pad_tensor(conditionings: List[ScheduledPromptConditioning], max_token_count: int) -> List[ScheduledPromptConditioning]:

            c0_shape = conditionings[0].cond.shape
            if not all([len(c.cond.shape) == len(c0_shape) for c in conditionings]):
                raise ValueError(
                    "Conditioning tensors must all have either 2 dimensions (unbatched) or 3 dimensions (batched)")

            if len(c0_shape) == 2:
                # need to be unsqueezed
                for c in conditionings:
                    c.cond = c.cond.unsqueeze(0)
                c0_shape = conditionings[0].cond.shape
            if len(c0_shape) != 3:
                raise ValueError(
                    "All conditioning tensors must have the same number of dimensions (2 or 3)")

            if not all([c.cond.shape[0] == c0_shape[0] and c.cond.shape[2] == c0_shape[2] for c in conditionings]):
                raise ValueError(
                    f"All conditioning tensors must have the same batch size ({c0_shape[0]}) and number of embeddings per token ({c0_shape[1]})")

            # if necessary, pad shorter tensors out with an empty-string tensor
            empty_z = torch.cat(
                [self.compel.build_conditioning_tensor("")] * c0_shape[0])
            for i, c in enumerate(conditionings):
                cond = c.cond.to(self.device)
                while cond.shape[1] < max_token_count:
                    cond = torch.cat([cond, empty_z], dim=1)
                conditionings[i] = ScheduledPromptConditioning(
                    c.end_at_step, cond)
            return conditionings

        uc = pad_tensor(uc, max_token_count)
        c = pad_tensor(c, max_token_count)

        # 4. Build the per-step embedding schedule
        next_uc = uc.pop(0)
        next_c = c.pop(0)
        prompt_embeds = None
        new_embeds = True
        embed_per_step = []
        for i in range(len(timesteps)):
            if i > next_uc.end_at_step:
                next_uc = uc.pop(0)
                new_embeds = True
            if i > next_c.end_at_step:
                next_c = c.pop(0)
                new_embeds = True

            if new_embeds:
                negative_prompt_embeds, prompt_embeds = self.compel.pad_conditioning_tensors_to_same_length([
                    next_uc.cond, next_c.cond])
                prompt_embeds = torch.cat(
                    [negative_prompt_embeds, prompt_embeds])
                new_embeds = False

            embed_per_step.append(prompt_embeds)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - \
            num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)

                prompt_embeds = embed_per_step[i]

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                if i < pyramid * num_inference_steps:
                    noise_pred = self._pyramid_noise_like(
                        noise_pred, device, seed, iterations=iter, discount=disc)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    @torch.no_grad()
    def inferPipe(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[
            int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - \
            num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, text_embeddings.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
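
For context, the following is a minimal, untested usage sketch of how this pipeline might be driven. It is not part of the uploaded file: the checkpoint id, device, and output filename below are placeholders, it assumes the `onediff` utility modules imported above are on the Python path, and it passes list-valued prompts because `inferV4` indexes `prompt[0]` internally.

    # Usage sketch (assumptions: example checkpoint id, CUDA device available,
    # onediff utilities importable; none of these are prescribed by the file above).
    import torch
    from stable_diffusion_custom_v4_1 import CustomStableDiffusionPipeline4_1

    pipe = CustomStableDiffusionPipeline4_1.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # example checkpoint, not specified by this commit
        torch_dtype=torch.float16,
    )
    pipe = pipe.to("cuda")

    generator = torch.Generator(device="cuda").manual_seed(42)
    result = pipe.inferV4(
        prompt=["a watercolor painting of a lighthouse at dusk"],  # list: inferV4 reads prompt[0]
        negative_prompt=["low quality, blurry"],
        num_inference_steps=30,
        guidance_scale=7.5,
        generator=generator,
        seed=42,      # used only by the pyramid-noise path
        pyramid=0,    # pyramid noise disabled, matching the default
    )
    result.images[0].save("output.png")

Because `inferV4` decodes through the VAE and runs the safety checker just like the stock `StableDiffusionPipeline.__call__`, the returned `StableDiffusionPipelineOutput` can be consumed the same way (`result.images`, `result.nsfw_content_detected`).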