gokaygokay commited on
Commit
181bfb3
·
verified ·
1 Parent(s): b8eac85

Create flux_8bit_lora.py

Browse files
Files changed (1) hide show
  1. flux_8bit_lora.py +655 -0
flux_8bit_lora.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
7
+
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.loaders import FluxLoraLoaderMixin
10
+ from diffusers.models.autoencoders import AutoencoderKL
11
+ from diffusers.models.transformers import FluxTransformer2DModel
12
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
13
+ from diffusers.utils import (
14
+ USE_PEFT_BACKEND,
15
+ make_image_grid,
16
+ scale_lora_layers,
17
+ unscale_lora_layers,
18
+ )
19
+ from diffusers.utils.torch_utils import randn_tensor
20
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
21
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
22
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
23
+
24
+ from optimum.quanto import quantize, qfloat8, freeze
25
+
26
+
27
+ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
28
+ r"""
29
+ The Flux pipeline for text-to-image generation.
30
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
31
+ Args:
32
+ transformer ([`FluxTransformer2DModel`]):
33
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
34
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
35
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
36
+ vae ([`AutoencoderKL`]):
37
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
38
+ text_encoder ([`CLIPTextModel`]):
39
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
40
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
41
+ text_encoder_2 ([`T5EncoderModel`]):
42
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
43
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
44
+ tokenizer (`CLIPTokenizer`):
45
+ Tokenizer of class
46
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
47
+ tokenizer_2 (`T5TokenizerFast`):
48
+ Second Tokenizer of class
49
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
50
+ """
51
+
52
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
53
+ _optional_components = []
54
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
55
+
56
+ def __init__(
57
+ self,
58
+ scheduler: FlowMatchEulerDiscreteScheduler,
59
+ vae: AutoencoderKL,
60
+ text_encoder: CLIPTextModel,
61
+ tokenizer: CLIPTokenizer,
62
+ text_encoder_2: T5EncoderModel,
63
+ tokenizer_2: T5TokenizerFast,
64
+ transformer: FluxTransformer2DModel,
65
+ ):
66
+ super().__init__()
67
+
68
+ self.register_modules(
69
+ vae=vae,
70
+ text_encoder=text_encoder,
71
+ text_encoder_2=text_encoder_2,
72
+ tokenizer=tokenizer,
73
+ tokenizer_2=tokenizer_2,
74
+ transformer=transformer,
75
+ scheduler=scheduler,
76
+ )
77
+ self.vae_scale_factor = (
78
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
79
+ )
80
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
81
+ self.tokenizer_max_length = (
82
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
83
+ )
84
+ self.default_sample_size = 64
85
+
86
+ def _get_t5_prompt_embeds(
87
+ self,
88
+ prompt: Union[str, List[str]] = None,
89
+ num_images_per_prompt: int = 1,
90
+ max_sequence_length: int = 512,
91
+ device: Optional[torch.device] = None,
92
+ dtype: Optional[torch.dtype] = None,
93
+ ):
94
+ device = device or self._execution_device
95
+ dtype = dtype or self.text_encoder.dtype
96
+
97
+ prompt = [prompt] if isinstance(prompt, str) else prompt
98
+ batch_size = len(prompt)
99
+
100
+ text_inputs = self.tokenizer_2(
101
+ prompt,
102
+ padding="max_length",
103
+ max_length=max_sequence_length,
104
+ truncation=True,
105
+ return_length=False,
106
+ return_overflowing_tokens=False,
107
+ return_tensors="pt",
108
+ )
109
+ text_input_ids = text_inputs.input_ids
110
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
111
+
112
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
113
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
114
+
115
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
116
+
117
+ dtype = self.text_encoder_2.dtype
118
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
119
+
120
+ _, seq_len, _ = prompt_embeds.shape
121
+
122
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
123
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
124
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
125
+
126
+ return prompt_embeds
127
+
128
+ def _get_clip_prompt_embeds(
129
+ self,
130
+ prompt: Union[str, List[str]],
131
+ num_images_per_prompt: int = 1,
132
+ device: Optional[torch.device] = None,
133
+ ):
134
+ device = device or self._execution_device
135
+
136
+ prompt = [prompt] if isinstance(prompt, str) else prompt
137
+ batch_size = len(prompt)
138
+
139
+ text_inputs = self.tokenizer(
140
+ prompt,
141
+ padding="max_length",
142
+ max_length=self.tokenizer_max_length,
143
+ truncation=True,
144
+ return_overflowing_tokens=False,
145
+ return_length=False,
146
+ return_tensors="pt",
147
+ )
148
+
149
+ text_input_ids = text_inputs.input_ids
150
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
151
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
152
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
153
+
154
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
155
+
156
+ # Use pooled output of CLIPTextModel
157
+ prompt_embeds = prompt_embeds.pooler_output
158
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
159
+
160
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
161
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
162
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
163
+
164
+ return prompt_embeds
165
+
166
+ def encode_prompt(
167
+ self,
168
+ prompt: Union[str, List[str]],
169
+ prompt_2: Union[str, List[str]],
170
+ device: Optional[torch.device] = None,
171
+ num_images_per_prompt: int = 1,
172
+ prompt_embeds: Optional[torch.FloatTensor] = None,
173
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
174
+ max_sequence_length: int = 512,
175
+ lora_scale: Optional[float] = None,
176
+ ):
177
+ r"""
178
+ Args:
179
+ prompt (`str` or `List[str]`, *optional*):
180
+ prompt to be encoded
181
+ prompt_2 (`str` or `List[str]`, *optional*):
182
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
183
+ used in all text-encoders
184
+ device: (`torch.device`):
185
+ torch device
186
+ num_images_per_prompt (`int`):
187
+ number of images that should be generated per prompt
188
+ prompt_embeds (`torch.FloatTensor`, *optional*):
189
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
190
+ provided, text embeddings will be generated from `prompt` input argument.
191
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
192
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
193
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
194
+ lora_scale (`float`, *optional*):
195
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
196
+ """
197
+ device = device or self._execution_device
198
+
199
+ # set lora scale so that monkey patched LoRA
200
+ # function of text encoder can correctly access it
201
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
202
+ self._lora_scale = lora_scale
203
+
204
+ # dynamically adjust the LoRA scale
205
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
206
+ scale_lora_layers(self.text_encoder, lora_scale)
207
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
208
+ scale_lora_layers(self.text_encoder_2, lora_scale)
209
+
210
+ prompt = [prompt] if isinstance(prompt, str) else prompt
211
+ if prompt is not None:
212
+ batch_size = len(prompt)
213
+ else:
214
+ batch_size = prompt_embeds.shape[0]
215
+
216
+ if prompt_embeds is None:
217
+ prompt_2 = prompt_2 or prompt
218
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
219
+
220
+ # We only use the pooled prompt output from the CLIPTextModel
221
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
222
+ prompt=prompt,
223
+ device=device,
224
+ num_images_per_prompt=num_images_per_prompt,
225
+ )
226
+ prompt_embeds = self._get_t5_prompt_embeds(
227
+ prompt=prompt_2,
228
+ num_images_per_prompt=num_images_per_prompt,
229
+ max_sequence_length=max_sequence_length,
230
+ device=device,
231
+ )
232
+
233
+ if self.text_encoder is not None:
234
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
235
+ # Retrieve the original scale by scaling back the LoRA layers
236
+ unscale_lora_layers(self.text_encoder, lora_scale)
237
+
238
+ if self.text_encoder_2 is not None:
239
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
240
+ # Retrieve the original scale by scaling back the LoRA layers
241
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
242
+
243
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
244
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
245
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
246
+
247
+ return prompt_embeds, pooled_prompt_embeds, text_ids
248
+
249
+ def check_inputs(
250
+ self,
251
+ prompt,
252
+ prompt_2,
253
+ height,
254
+ width,
255
+ prompt_embeds=None,
256
+ pooled_prompt_embeds=None,
257
+ callback_on_step_end_tensor_inputs=None,
258
+ max_sequence_length=None,
259
+ ):
260
+ if height % 8 != 0 or width % 8 != 0:
261
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
262
+
263
+ if callback_on_step_end_tensor_inputs is not None and not all(
264
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
265
+ ):
266
+ raise ValueError(
267
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
268
+ )
269
+
270
+ if prompt is not None and prompt_embeds is not None:
271
+ raise ValueError(
272
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
273
+ " only forward one of the two."
274
+ )
275
+ elif prompt_2 is not None and prompt_embeds is not None:
276
+ raise ValueError(
277
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
278
+ " only forward one of the two."
279
+ )
280
+ elif prompt is None and prompt_embeds is None:
281
+ raise ValueError(
282
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
283
+ )
284
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
285
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
286
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
287
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
288
+
289
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
290
+ raise ValueError(
291
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
292
+ )
293
+
294
+ if max_sequence_length is not None and max_sequence_length > 512:
295
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
296
+
297
+ @staticmethod
298
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
299
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
300
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
301
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
302
+
303
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
304
+
305
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
306
+ latent_image_ids = latent_image_ids.reshape(
307
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
308
+ )
309
+
310
+ return latent_image_ids.to(device=device, dtype=dtype)
311
+
312
+ @staticmethod
313
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
314
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
315
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
316
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
317
+
318
+ return latents
319
+
320
+ @staticmethod
321
+ def _unpack_latents(latents, height, width, vae_scale_factor):
322
+ batch_size, num_patches, channels = latents.shape
323
+
324
+ height = height // vae_scale_factor
325
+ width = width // vae_scale_factor
326
+
327
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
328
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
329
+
330
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
331
+
332
+ return latents
333
+
334
+ def prepare_latents(
335
+ self,
336
+ batch_size,
337
+ num_channels_latents,
338
+ height,
339
+ width,
340
+ dtype,
341
+ device,
342
+ generator,
343
+ latents=None,
344
+ ):
345
+ height = 2 * (int(height) // self.vae_scale_factor)
346
+ width = 2 * (int(width) // self.vae_scale_factor)
347
+
348
+ shape = (batch_size, num_channels_latents, height, width)
349
+
350
+ if latents is not None:
351
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
352
+ return latents.to(device=device, dtype=dtype), latent_image_ids
353
+
354
+ if isinstance(generator, list) and len(generator) != batch_size:
355
+ raise ValueError(
356
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
357
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
358
+ )
359
+
360
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
361
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
362
+
363
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
364
+
365
+ return latents, latent_image_ids
366
+
367
+ @property
368
+ def guidance_scale(self):
369
+ return self._guidance_scale
370
+
371
+ @property
372
+ def joint_attention_kwargs(self):
373
+ return self._joint_attention_kwargs
374
+
375
+ @property
376
+ def num_timesteps(self):
377
+ return self._num_timesteps
378
+
379
+ @property
380
+ def interrupt(self):
381
+ return self._interrupt
382
+
383
+ @torch.no_grad()
384
+ def __call__(
385
+ self,
386
+ prompt: Union[str, List[str]] = None,
387
+ prompt_2: Optional[Union[str, List[str]]] = None,
388
+ height: Optional[int] = None,
389
+ width: Optional[int] = None,
390
+ num_inference_steps: int = 28,
391
+ timesteps: List[int] = None,
392
+ guidance_scale: float = 3.5,
393
+ num_images_per_prompt: Optional[int] = 1,
394
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
395
+ latents: Optional[torch.FloatTensor] = None,
396
+ prompt_embeds: Optional[torch.FloatTensor] = None,
397
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
398
+ output_type: Optional[str] = "pil",
399
+ return_dict: bool = True,
400
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
401
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
402
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
403
+ max_sequence_length: int = 512,
404
+ fake_guidance_scale: float = 3.5,
405
+ timestep_to_start_cfg: int = 0,
406
+ ):
407
+ r"""
408
+ Function invoked when calling the pipeline for generation.
409
+ Args:
410
+ prompt (`str` or `List[str]`, *optional*):
411
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
412
+ instead.
413
+ prompt_2 (`str` or `List[str]`, *optional*):
414
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
415
+ will be used instead
416
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
417
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
418
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
419
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
420
+ num_inference_steps (`int`, *optional*, defaults to 50):
421
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
422
+ expense of slower inference.
423
+ timesteps (`List[int]`, *optional*):
424
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
425
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
426
+ passed will be used. Must be in descending order.
427
+ guidance_scale (`float`, *optional*, defaults to 7.0):
428
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
429
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
430
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
431
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
432
+ usually at the expense of lower image quality.
433
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
434
+ The number of images to generate per prompt.
435
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
436
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
437
+ to make generation deterministic.
438
+ latents (`torch.FloatTensor`, *optional*):
439
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
440
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
441
+ tensor will ge generated by sampling using the supplied random `generator`.
442
+ prompt_embeds (`torch.FloatTensor`, *optional*):
443
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
444
+ provided, text embeddings will be generated from `prompt` input argument.
445
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
446
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
447
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
448
+ output_type (`str`, *optional*, defaults to `"pil"`):
449
+ The output format of the generate image. Choose between
450
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
451
+ return_dict (`bool`, *optional*, defaults to `True`):
452
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
453
+ joint_attention_kwargs (`dict`, *optional*):
454
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
455
+ `self.processor` in
456
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
457
+ callback_on_step_end (`Callable`, *optional*):
458
+ A function that calls at the end of each denoising steps during the inference. The function is called
459
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
460
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
461
+ `callback_on_step_end_tensor_inputs`.
462
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
463
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
464
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
465
+ `._callback_tensor_inputs` attribute of your pipeline class.
466
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
467
+ Examples:
468
+ Returns:
469
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
470
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
471
+ images.
472
+ """
473
+
474
+ height = height or self.default_sample_size * self.vae_scale_factor
475
+ width = width or self.default_sample_size * self.vae_scale_factor
476
+
477
+ # 1. Check inputs. Raise error if not correct
478
+ self.check_inputs(
479
+ prompt,
480
+ prompt_2,
481
+ height,
482
+ width,
483
+ prompt_embeds=prompt_embeds,
484
+ pooled_prompt_embeds=pooled_prompt_embeds,
485
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
486
+ max_sequence_length=max_sequence_length,
487
+ )
488
+
489
+ self._guidance_scale = guidance_scale
490
+ self._joint_attention_kwargs = joint_attention_kwargs
491
+ self._interrupt = False
492
+
493
+ # 2. Define call parameters
494
+ if prompt is not None and isinstance(prompt, str):
495
+ batch_size = 1
496
+ elif prompt is not None and isinstance(prompt, list):
497
+ batch_size = len(prompt)
498
+ else:
499
+ batch_size = prompt_embeds.shape[0]
500
+
501
+ device = self._execution_device
502
+
503
+ lora_scale = (
504
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
505
+ )
506
+ (
507
+ prompt_embeds,
508
+ pooled_prompt_embeds,
509
+ text_ids,
510
+ ) = self.encode_prompt(
511
+ prompt=prompt,
512
+ prompt_2=prompt_2,
513
+ prompt_embeds=prompt_embeds,
514
+ pooled_prompt_embeds=pooled_prompt_embeds,
515
+ device=device,
516
+ num_images_per_prompt=num_images_per_prompt,
517
+ max_sequence_length=max_sequence_length,
518
+ lora_scale=lora_scale,
519
+ )
520
+
521
+ negative_prompt_embeds = None
522
+ negative_pooled_prompt_embeds = None
523
+ negative_text_ids = None
524
+ (
525
+ negative_prompt_embeds,
526
+ negative_pooled_prompt_embeds,
527
+ negative_text_ids,
528
+ ) = self.encode_prompt(
529
+ prompt="",
530
+ prompt_2="",
531
+ prompt_embeds=None,
532
+ pooled_prompt_embeds=None,
533
+ device=device,
534
+ num_images_per_prompt=num_images_per_prompt,
535
+ max_sequence_length=max_sequence_length,
536
+ lora_scale=lora_scale,
537
+ )
538
+
539
+ # 4. Prepare latent variables
540
+ num_channels_latents = self.transformer.config.in_channels // 4
541
+ latents, latent_image_ids = self.prepare_latents(
542
+ batch_size * num_images_per_prompt,
543
+ num_channels_latents,
544
+ height,
545
+ width,
546
+ prompt_embeds.dtype,
547
+ device,
548
+ generator,
549
+ latents,
550
+ )
551
+
552
+ # 5. Prepare timesteps
553
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
554
+ image_seq_len = latents.shape[1]
555
+ mu = calculate_shift(
556
+ image_seq_len,
557
+ self.scheduler.config.base_image_seq_len,
558
+ self.scheduler.config.max_image_seq_len,
559
+ self.scheduler.config.base_shift,
560
+ self.scheduler.config.max_shift,
561
+ )
562
+ timesteps, num_inference_steps = retrieve_timesteps(
563
+ self.scheduler,
564
+ num_inference_steps,
565
+ device,
566
+ timesteps,
567
+ sigmas,
568
+ mu=mu,
569
+ )
570
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
571
+ self._num_timesteps = len(timesteps)
572
+
573
+ # 6. Denoising loop
574
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
575
+ for i, t in enumerate(timesteps):
576
+ if self.interrupt:
577
+ continue
578
+
579
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
580
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
581
+
582
+ # handle guidance
583
+ if self.transformer.config.guidance_embeds:
584
+ guidance = torch.tensor([fake_guidance_scale], device=device)
585
+ guidance = guidance.expand(latents.shape[0])
586
+ else:
587
+ guidance = None
588
+
589
+ noise_pred = self.transformer(
590
+ hidden_states=latents,
591
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
592
+ timestep=timestep / 1000,
593
+ guidance=guidance,
594
+ pooled_projections=pooled_prompt_embeds,
595
+ encoder_hidden_states=prompt_embeds,
596
+ txt_ids=text_ids,
597
+ img_ids=latent_image_ids,
598
+ joint_attention_kwargs=self.joint_attention_kwargs,
599
+ return_dict=False,
600
+ )[0]
601
+
602
+ if i >= timestep_to_start_cfg:
603
+ noise_pred_uncond = self.transformer(
604
+ hidden_states=latents,
605
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
606
+ timestep=timestep / 1000,
607
+ guidance=guidance,
608
+ pooled_projections=negative_pooled_prompt_embeds,
609
+ encoder_hidden_states=negative_prompt_embeds,
610
+ txt_ids=negative_text_ids,
611
+ img_ids=latent_image_ids,
612
+ joint_attention_kwargs=self.joint_attention_kwargs,
613
+ return_dict=False,
614
+ )[0]
615
+
616
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred - noise_pred_uncond)
617
+
618
+ # compute the previous noisy sample x_t -> x_t-1
619
+ latents_dtype = latents.dtype
620
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
621
+
622
+ if latents.dtype != latents_dtype:
623
+ if torch.backends.mps.is_available():
624
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
625
+ latents = latents.to(latents_dtype)
626
+
627
+ if callback_on_step_end is not None:
628
+ callback_kwargs = {}
629
+ for k in callback_on_step_end_tensor_inputs:
630
+ callback_kwargs[k] = locals()[k]
631
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
632
+
633
+ latents = callback_outputs.pop("latents", latents)
634
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
635
+
636
+ # call the callback, if provided
637
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
638
+ progress_bar.update()
639
+
640
+ if output_type == "latent":
641
+ image = latents
642
+
643
+ else:
644
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
645
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
646
+ image = self.vae.decode(latents, return_dict=False)[0]
647
+ image = self.image_processor.postprocess(image, output_type=output_type)
648
+
649
+ # Offload all models
650
+ self.maybe_free_model_hooks()
651
+
652
+ if not return_dict:
653
+ return (image,)
654
+
655
+ return FluxPipelineOutput(images=image)