Diffusers
TalHach61 committed on
Commit a845211 · verified · 1 Parent(s): e3e33ef

Delete pipeline_bria.py

Files changed (1)
  1. pipeline_bria.py +0 -576
pipeline_bria.py DELETED
@@ -1,576 +0,0 @@
- from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, retrieve_timesteps, calculate_shift
- from typing import Any, Callable, Dict, List, Optional, Union
-
- import torch
-
- from transformers import (
-     T5EncoderModel,
-     T5TokenizerFast,
- )
-
- from diffusers.image_processor import VaeImageProcessor
- from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler
- from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
- from diffusers.schedulers import KarrasDiffusionSchedulers
- from diffusers.loaders import FluxLoraLoaderMixin
- from diffusers.utils import (
-     USE_PEFT_BACKEND,
-     is_torch_xla_available,
-     logging,
-     replace_example_docstring,
-     scale_lora_layers,
-     unscale_lora_layers,
- )
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
- from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
- from transformer_bria import BriaTransformer2DModel
- from bria_utils import get_t5_prompt_embeds, get_original_sigmas, is_ng_none
- import numpy as np
-
- if is_torch_xla_available():
-     import torch_xla.core.xla_model as xm
-
-     XLA_AVAILABLE = True
- else:
-     XLA_AVAILABLE = False
-
-
- logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
- EXAMPLE_DOC_STRING = """
-     Examples:
-         ```py
-         >>> import torch
-         >>> from diffusers import StableDiffusion3Pipeline
-
-         >>> pipe = StableDiffusion3Pipeline.from_pretrained(
-         ...     "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
-         ... )
-         >>> pipe.to("cuda")
-         >>> prompt = "A cat holding a sign that says hello world"
-         >>> image = pipe(prompt).images[0]
-         >>> image.save("sd3.png")
-         ```
- """
-
- T5_PRECISION = torch.float16
-
- """
- Based on FluxPipeline with several changes:
- - no pooled embeddings
- - we use zero padding for prompts
- - no guidance embedding, since this is not a distilled version
- """
- class BriaPipeline(FluxPipeline):
-     r"""
-     Args:
-         transformer ([`BriaTransformer2DModel`]):
-             Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
-         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
-             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
-         vae ([`AutoencoderKL`]):
-             Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
-         text_encoder ([`T5EncoderModel`]):
-             Frozen text-encoder
-             [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
-             [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
-         tokenizer (`T5TokenizerFast`):
-             Tokenizer of class
-             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
-     """
-
-     def __init__(
-         self,
-         transformer: BriaTransformer2DModel,
-         scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
-         vae: AutoencoderKL,
-         text_encoder: T5EncoderModel,
-         tokenizer: T5TokenizerFast,
-     ):
-         self.register_modules(
-             vae=vae,
-             text_encoder=text_encoder,
-             tokenizer=tokenizer,
-             transformer=transformer,
-             scheduler=scheduler,
-         )
-
-         # TODO - why is this different from the official Flux pipeline (which uses -1)?
-         self.vae_scale_factor = (
-             2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
-         )
-         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-         self.default_sample_size = 64  # due to patchify: 128x128 latents => 1024x1024 output resolution
-
-         # T5 is sensitive to precision, so we use the precision used for precompute and cast as needed
-         self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
-         for block in self.text_encoder.encoder.block:
-             block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
-
-     def encode_prompt(
-         self,
-         prompt: Union[str, List[str]],
-         device: Optional[torch.device] = None,
-         num_images_per_prompt: int = 1,
-         do_classifier_free_guidance: bool = True,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-         max_sequence_length: int = 128,
-         lora_scale: Optional[float] = None,
-     ):
-         r"""
-         Args:
-             prompt (`str` or `List[str]`, *optional*):
-                 prompt to be encoded
-             device: (`torch.device`):
-                 torch device
-             num_images_per_prompt (`int`):
-                 number of images that should be generated per prompt
-             do_classifier_free_guidance (`bool`):
-                 whether to use classifier-free guidance or not
-             negative_prompt (`str` or `List[str]`, *optional*):
-                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
-                 is less than `1`).
-             prompt_embeds (`torch.FloatTensor`, *optional*):
-                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
-                 not provided, text embeddings will be generated from the `prompt` input argument.
-             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
-                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                 weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
-                 argument.
-         """
-         device = device or self._execution_device
-
-         # set lora scale so that monkey-patched LoRA
-         # function of text encoder can correctly access it
-         if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
-             self._lora_scale = lora_scale
-
-             # dynamically adjust the LoRA scale
-             if self.text_encoder is not None and USE_PEFT_BACKEND:
-                 scale_lora_layers(self.text_encoder, lora_scale)
-
-         prompt = [prompt] if isinstance(prompt, str) else prompt
-         if prompt is not None:
-             batch_size = len(prompt)
-         else:
-             batch_size = prompt_embeds.shape[0]
-
-         if prompt_embeds is None:
-             prompt_embeds = get_t5_prompt_embeds(
-                 self.tokenizer,
-                 self.text_encoder,
-                 prompt=prompt,
-                 num_images_per_prompt=num_images_per_prompt,
-                 max_sequence_length=max_sequence_length,
-                 device=device,
-             ).to(dtype=self.transformer.dtype)
-
-         if do_classifier_free_guidance and negative_prompt_embeds is None:
-             if not is_ng_none(negative_prompt):
-                 negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-
-                 if prompt is not None and type(prompt) is not type(negative_prompt):
-                     raise TypeError(
-                         f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
-                         f" {type(prompt)}."
-                     )
-                 elif batch_size != len(negative_prompt):
-                     raise ValueError(
-                         f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                         f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                         " the batch size of `prompt`."
-                     )
-
-                 negative_prompt_embeds = get_t5_prompt_embeds(
-                     self.tokenizer,
-                     self.text_encoder,
-                     prompt=negative_prompt,
-                     num_images_per_prompt=num_images_per_prompt,
-                     max_sequence_length=max_sequence_length,
-                     device=device,
-                 ).to(dtype=self.transformer.dtype)
-             else:
-                 negative_prompt_embeds = torch.zeros_like(prompt_embeds)
-
-         if self.text_encoder is not None:
-             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
-                 # Retrieve the original scale by scaling back the LoRA layers
-                 unscale_lora_layers(self.text_encoder, lora_scale)
-
-         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-         text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-         text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
-
-         return prompt_embeds, negative_prompt_embeds, text_ids
-
-     @property
-     def guidance_scale(self):
-         return self._guidance_scale
-
-     # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
-     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-     # corresponds to doing no classifier-free guidance.
-     @property
-     def do_classifier_free_guidance(self):
-         return self._guidance_scale > 1
-
-     @property
-     def joint_attention_kwargs(self):
-         return self._joint_attention_kwargs
-
-     @property
-     def num_timesteps(self):
-         return self._num_timesteps
-
-     @property
-     def interrupt(self):
-         return self._interrupt
-
-     @torch.no_grad()
-     @replace_example_docstring(EXAMPLE_DOC_STRING)
-     def __call__(
-         self,
-         prompt: Union[str, List[str]] = None,
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-         num_inference_steps: int = 30,
-         timesteps: List[int] = None,
-         guidance_scale: float = 5,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         num_images_per_prompt: Optional[int] = 1,
-         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-         latents: Optional[torch.FloatTensor] = None,
-         prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-         output_type: Optional[str] = "pil",
-         return_dict: bool = True,
-         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-         max_sequence_length: int = 128,
-         clip_value: Union[None, float] = None,
-         normalize: bool = False,
-     ):
-         r"""
-         Function invoked when calling the pipeline for generation.
-
-         Args:
-             prompt (`str` or `List[str]`, *optional*):
-                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
-                 instead.
-             height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
-                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
-             width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
-                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
-             num_inference_steps (`int`, *optional*, defaults to 30):
-                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                 expense of slower inference.
-             timesteps (`List[int]`, *optional*):
-                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                 passed will be used. Must be in descending order.
-             guidance_scale (`float`, *optional*, defaults to 5.0):
-                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-                 usually at the expense of lower image quality.
-             negative_prompt (`str` or `List[str]`, *optional*):
-                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
-                 is less than `1`).
-             num_images_per_prompt (`int`, *optional*, defaults to 1):
-                 The number of images to generate per prompt.
-             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
-                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-                 to make generation deterministic.
-             latents (`torch.FloatTensor`, *optional*):
-                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-                 generation. Can be used to tweak the same generation with different prompts. If not provided, a
-                 latents tensor will be generated by sampling using the supplied random `generator`.
-             prompt_embeds (`torch.FloatTensor`, *optional*):
-                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
-                 not provided, text embeddings will be generated from the `prompt` input argument.
-             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
-                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                 weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
-                 argument.
-             output_type (`str`, *optional*, defaults to `"pil"`):
-                 The output format of the generated image. Choose between
-                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
-             return_dict (`bool`, *optional*, defaults to `True`):
-                 Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
-             joint_attention_kwargs (`dict`, *optional*):
-                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                 `self.processor` in
-                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-             callback_on_step_end (`Callable`, *optional*):
-                 A function that is called at the end of each denoising step during inference. The function is called
-                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
-                 callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
-                 `callback_on_step_end_tensor_inputs`.
-             callback_on_step_end_tensor_inputs (`List`, *optional*):
-                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
-                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                 `._callback_tensor_inputs` attribute of your pipeline class.
-             max_sequence_length (`int`, defaults to 128): Maximum sequence length to use with the `prompt`.
-
-         Examples:
-
-         Returns:
-             [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
-             is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
-             images.
-         """
-
-         height = height or self.default_sample_size * self.vae_scale_factor
-         width = width or self.default_sample_size * self.vae_scale_factor
-
-         # 1. Check inputs. Raise error if not correct
-         self.check_inputs(
-             prompt=prompt,
-             height=height,
-             width=width,
-             prompt_embeds=prompt_embeds,
-             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
-             max_sequence_length=max_sequence_length,
-         )
-
-         self._guidance_scale = guidance_scale
-         self._joint_attention_kwargs = joint_attention_kwargs
-         self._interrupt = False
-
-         # 2. Define call parameters
-         if prompt is not None and isinstance(prompt, str):
-             batch_size = 1
-         elif prompt is not None and isinstance(prompt, list):
-             batch_size = len(prompt)
-         else:
-             batch_size = prompt_embeds.shape[0]
-
-         device = self._execution_device
-
-         lora_scale = (
-             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
-         )
-
-         (
-             prompt_embeds,
-             negative_prompt_embeds,
-             text_ids,
-         ) = self.encode_prompt(
-             prompt=prompt,
-             negative_prompt=negative_prompt,
-             do_classifier_free_guidance=self.do_classifier_free_guidance,
-             prompt_embeds=prompt_embeds,
-             negative_prompt_embeds=negative_prompt_embeds,
-             device=device,
-             num_images_per_prompt=num_images_per_prompt,
-             max_sequence_length=max_sequence_length,
-             lora_scale=lora_scale,
-         )
-
-         if self.do_classifier_free_guidance:
-             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-
-         # 5. Prepare latent variables
-         num_channels_latents = self.transformer.config.in_channels // 4  # due to patch size 2, we divide by 4
-         latents, latent_image_ids = self.prepare_latents(
-             batch_size * num_images_per_prompt,
-             num_channels_latents,
-             height,
-             width,
-             prompt_embeds.dtype,
-             device,
-             generator,
-             latents,
-         )
-
-         if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler) and self.scheduler.config["use_dynamic_shifting"]:
-             sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-             image_seq_len = latents.shape[1]  # Shift by height - Why just height?
-             print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
-
-             mu = calculate_shift(
-                 image_seq_len,
-                 self.scheduler.config.base_image_seq_len,
-                 self.scheduler.config.max_image_seq_len,
-                 self.scheduler.config.base_shift,
-                 self.scheduler.config.max_shift,
-             )
-             timesteps, num_inference_steps = retrieve_timesteps(
-                 self.scheduler,
-                 num_inference_steps,
-                 device,
-                 timesteps,
-                 sigmas,
-                 mu=mu,
-             )
-         else:
-             # 4. Prepare timesteps
-             # Sample from training sigmas
-             if isinstance(self.scheduler, (DDIMScheduler, EulerAncestralDiscreteScheduler)):
-                 timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, None)
-             else:
-                 sigmas = get_original_sigmas(
-                     num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps
-                 )
-                 timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas)
-
-         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-         self._num_timesteps = len(timesteps)
-
-         # Support different diffusers versions
-         if len(latent_image_ids.shape) == 2:
-             text_ids = text_ids.squeeze()
-
-         # 6. Denoising loop
-         with self.progress_bar(total=num_inference_steps) as progress_bar:
-             for i, t in enumerate(timesteps):
-                 if self.interrupt:
-                     continue
-
-                 # expand the latents if we are doing classifier-free guidance
-                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                 if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
-                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                 timestep = t.expand(latent_model_input.shape[0])
-
-                 # This predicts "v" for flow-matching or eps for diffusion
-                 noise_pred = self.transformer(
-                     hidden_states=latent_model_input,
-                     timestep=timestep,
-                     encoder_hidden_states=prompt_embeds,
-                     joint_attention_kwargs=self.joint_attention_kwargs,
-                     return_dict=False,
-                     txt_ids=text_ids,
-                     img_ids=latent_image_ids,
-                 )[0]
-
-                 # perform guidance
-                 if self.do_classifier_free_guidance:
-                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     cfg_noise_pred_text = noise_pred_text.std()
-                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                     if normalize:
-                         noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred
-
-                 if clip_value:
-                     assert clip_value > 0
-                     noise_pred = noise_pred.clip(-clip_value, clip_value)
-
-                 # compute the previous noisy sample x_t -> x_t-1
-                 latents_dtype = latents.dtype
-                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-                 if latents.dtype != latents_dtype:
-                     if torch.backends.mps.is_available():
-                         # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                         latents = latents.to(latents_dtype)
-
-                 if callback_on_step_end is not None:
-                     callback_kwargs = {}
-                     for k in callback_on_step_end_tensor_inputs:
-                         callback_kwargs[k] = locals()[k]
-                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                     latents = callback_outputs.pop("latents", latents)
-                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
-                 # call the callback, if provided
-                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                     progress_bar.update()
-
-                 if XLA_AVAILABLE:
-                     xm.mark_step()
-
-         if output_type == "latent":
-             image = latents
-         else:
-             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-             latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-             image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
-             image = self.image_processor.postprocess(image, output_type=output_type)
-
-         # Offload all models
-         self.maybe_free_model_hooks()
-
-         if not return_dict:
-             return (image,)
-
-         return FluxPipelineOutput(images=image)
-
-     def check_inputs(
-         self,
-         prompt,
-         height,
-         width,
-         negative_prompt=None,
-         prompt_embeds=None,
-         negative_prompt_embeds=None,
-         callback_on_step_end_tensor_inputs=None,
-         max_sequence_length=None,
-     ):
-         if height % 8 != 0 or width % 8 != 0:
-             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-         if callback_on_step_end_tensor_inputs is not None and not all(
-             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
-         ):
-             raise ValueError(
-                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
-             )
-
-         if prompt is not None and prompt_embeds is not None:
-             raise ValueError(
-                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
-                 " only forward one of the two."
-             )
-         elif prompt is None and prompt_embeds is None:
-             raise ValueError(
-                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
-             )
-         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-         if negative_prompt is not None and negative_prompt_embeds is not None:
-             raise ValueError(
-                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
-                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-             )
-
-         if prompt_embeds is not None and negative_prompt_embeds is not None:
-             if prompt_embeds.shape != negative_prompt_embeds.shape:
-                 raise ValueError(
-                     "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
-                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
-                     f" {negative_prompt_embeds.shape}."
-                 )
-
-         if max_sequence_length is not None and max_sequence_length > 512:
-             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
-
-     def to(self, *args, **kwargs):
-         DiffusionPipeline.to(self, *args, **kwargs)
-         # T5 is sensitive to precision, so we use the precision used for precompute and cast as needed
-         self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
-         for block in self.text_encoder.encoder.block:
-             block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
-
-         return self
-
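For context, a minimal sketch of how the deleted `BriaPipeline` could be driven. This is an illustrative assumption, not part of the commit: the checkpoint id is a placeholder, and `pipeline_bria.py`, `transformer_bria.py`, and `bria_utils.py` are assumed to be importable from the working directory. The defaults shown (`num_inference_steps=30`, `guidance_scale=5`) mirror the `__call__` signature above, and classifier-free guidance is active whenever `guidance_scale > 1`.

```py
# Hypothetical usage sketch for the deleted pipeline_bria.py (placeholder repo id).
import torch

from pipeline_bria import BriaPipeline

pipe = BriaPipeline.from_pretrained("<bria-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="A photo of a calico cat on a windowsill",
    num_inference_steps=30,  # signature default
    guidance_scale=5.0,      # CFG is applied since guidance_scale > 1
).images[0]
image.save("bria.png")
```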