gokaygokay commited on
Commit
74d41ea
·
1 Parent(s): f469d2f
Files changed (3) hide show
  1. live_preview_helpers.py +0 -166
  2. pipelines.py +0 -1417
  3. stable_diffusion_model.py +0 -0
live_preview_helpers.py DELETED
@@ -1,166 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from diffusers import FluxPipeline, AutoencoderTiny, FlowMatchEulerDiscreteScheduler
4
- from typing import Any, Dict, List, Optional, Union
5
-
6
- # Helper functions
7
- def calculate_shift(
8
- image_seq_len,
9
- base_seq_len: int = 256,
10
- max_seq_len: int = 4096,
11
- base_shift: float = 0.5,
12
- max_shift: float = 1.16,
13
- ):
14
- m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
15
- b = base_shift - m * base_seq_len
16
- mu = image_seq_len * m + b
17
- return mu
18
-
19
- def retrieve_timesteps(
20
- scheduler,
21
- num_inference_steps: Optional[int] = None,
22
- device: Optional[Union[str, torch.device]] = None,
23
- timesteps: Optional[List[int]] = None,
24
- sigmas: Optional[List[float]] = None,
25
- **kwargs,
26
- ):
27
- if timesteps is not None and sigmas is not None:
28
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
29
- if timesteps is not None:
30
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
31
- timesteps = scheduler.timesteps
32
- num_inference_steps = len(timesteps)
33
- elif sigmas is not None:
34
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
35
- timesteps = scheduler.timesteps
36
- num_inference_steps = len(timesteps)
37
- else:
38
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
39
- timesteps = scheduler.timesteps
40
- return timesteps, num_inference_steps
41
-
42
- # FLUX pipeline function
43
- @torch.inference_mode()
44
- def flux_pipe_call_that_returns_an_iterable_of_images(
45
- self,
46
- prompt: Union[str, List[str]] = None,
47
- prompt_2: Optional[Union[str, List[str]]] = None,
48
- height: Optional[int] = None,
49
- width: Optional[int] = None,
50
- num_inference_steps: int = 28,
51
- timesteps: List[int] = None,
52
- guidance_scale: float = 3.5,
53
- num_images_per_prompt: Optional[int] = 1,
54
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
55
- latents: Optional[torch.FloatTensor] = None,
56
- prompt_embeds: Optional[torch.FloatTensor] = None,
57
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
58
- output_type: Optional[str] = "pil",
59
- return_dict: bool = True,
60
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
61
- max_sequence_length: int = 512,
62
- good_vae: Optional[Any] = None,
63
- ):
64
- height = height or self.default_sample_size * self.vae_scale_factor
65
- width = width or self.default_sample_size * self.vae_scale_factor
66
-
67
- # 1. Check inputs
68
- self.check_inputs(
69
- prompt,
70
- prompt_2,
71
- height,
72
- width,
73
- prompt_embeds=prompt_embeds,
74
- pooled_prompt_embeds=pooled_prompt_embeds,
75
- max_sequence_length=max_sequence_length,
76
- )
77
-
78
- self._guidance_scale = guidance_scale
79
- self._joint_attention_kwargs = joint_attention_kwargs
80
- self._interrupt = False
81
-
82
- # 2. Define call parameters
83
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
84
- device = self._execution_device
85
-
86
- # 3. Encode prompt
87
- lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
88
- prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
89
- prompt=prompt,
90
- prompt_2=prompt_2,
91
- prompt_embeds=prompt_embeds,
92
- pooled_prompt_embeds=pooled_prompt_embeds,
93
- device=device,
94
- num_images_per_prompt=num_images_per_prompt,
95
- max_sequence_length=max_sequence_length,
96
- lora_scale=lora_scale,
97
- )
98
- # 4. Prepare latent variables
99
- num_channels_latents = self.transformer.config.in_channels // 4
100
- latents, latent_image_ids = self.prepare_latents(
101
- batch_size * num_images_per_prompt,
102
- num_channels_latents,
103
- height,
104
- width,
105
- prompt_embeds.dtype,
106
- device,
107
- generator,
108
- latents,
109
- )
110
- # 5. Prepare timesteps
111
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
112
- image_seq_len = latents.shape[1]
113
- mu = calculate_shift(
114
- image_seq_len,
115
- self.scheduler.config.base_image_seq_len,
116
- self.scheduler.config.max_image_seq_len,
117
- self.scheduler.config.base_shift,
118
- self.scheduler.config.max_shift,
119
- )
120
- timesteps, num_inference_steps = retrieve_timesteps(
121
- self.scheduler,
122
- num_inference_steps,
123
- device,
124
- timesteps,
125
- sigmas,
126
- mu=mu,
127
- )
128
- self._num_timesteps = len(timesteps)
129
-
130
- # Handle guidance
131
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
132
-
133
- # 6. Denoising loop
134
- for i, t in enumerate(timesteps):
135
- if self.interrupt:
136
- continue
137
-
138
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
139
-
140
- noise_pred = self.transformer(
141
- hidden_states=latents,
142
- timestep=timestep / 1000,
143
- guidance=guidance,
144
- pooled_projections=pooled_prompt_embeds,
145
- encoder_hidden_states=prompt_embeds,
146
- txt_ids=text_ids,
147
- img_ids=latent_image_ids,
148
- joint_attention_kwargs=self.joint_attention_kwargs,
149
- return_dict=False,
150
- )[0]
151
- # Yield intermediate result
152
- latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
153
- latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
154
- image = self.vae.decode(latents_for_image, return_dict=False)[0]
155
- yield self.image_processor.postprocess(image, output_type=output_type)[0]
156
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
157
- torch.cuda.empty_cache()
158
-
159
-
160
- # Final image using good_vae
161
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
162
- latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
163
- image = good_vae.decode(latents, return_dict=False)[0]
164
- self.maybe_free_model_hooks()
165
- torch.cuda.empty_cache()
166
- yield self.image_processor.postprocess(image, output_type=output_type)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pipelines.py DELETED
@@ -1,1417 +0,0 @@
1
- import importlib
2
- import inspect
3
- from typing import Union, List, Optional, Dict, Any, Tuple, Callable
4
-
5
- import numpy as np
6
- import torch
7
- from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline
8
- from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
9
- from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
10
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
11
- # from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
12
- from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
13
- from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
14
- from diffusers.utils import is_torch_xla_available
15
- from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
16
- from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
17
-
18
-
19
- if is_torch_xla_available():
20
- import torch_xla.core.xla_model as xm
21
-
22
- XLA_AVAILABLE = True
23
- else:
24
- XLA_AVAILABLE = False
25
-
26
- class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
27
-
28
- def __init__(
29
- self,
30
- vae: 'AutoencoderKL',
31
- text_encoder: 'CLIPTextModel',
32
- text_encoder_2: 'CLIPTextModelWithProjection',
33
- tokenizer: 'CLIPTokenizer',
34
- tokenizer_2: 'CLIPTokenizer',
35
- unet: 'UNet2DConditionModel',
36
- scheduler: 'KarrasDiffusionSchedulers',
37
- force_zeros_for_empty_prompt: bool = True,
38
- add_watermarker: Optional[bool] = None,
39
- ):
40
- super().__init__(
41
- vae=vae,
42
- text_encoder=text_encoder,
43
- text_encoder_2=text_encoder_2,
44
- tokenizer=tokenizer,
45
- tokenizer_2=tokenizer_2,
46
- unet=unet,
47
- scheduler=scheduler,
48
- )
49
- raise NotImplementedError("This pipeline is not implemented yet")
50
- # self.sampler = None
51
- # scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
52
- # model = ModelWrapper(unet, scheduler.alphas_cumprod)
53
- # if scheduler.config.prediction_type == "v_prediction":
54
- # self.k_diffusion_model = CompVisVDenoiser(model)
55
- # else:
56
- # self.k_diffusion_model = CompVisDenoiser(model)
57
-
58
- def set_scheduler(self, scheduler_type: str):
59
- library = importlib.import_module("k_diffusion")
60
- sampling = getattr(library, "sampling")
61
- self.sampler = getattr(sampling, scheduler_type)
62
-
63
- @torch.no_grad()
64
- def __call__(
65
- self,
66
- prompt: Union[str, List[str]] = None,
67
- prompt_2: Optional[Union[str, List[str]]] = None,
68
- height: Optional[int] = None,
69
- width: Optional[int] = None,
70
- num_inference_steps: int = 50,
71
- denoising_end: Optional[float] = None,
72
- guidance_scale: float = 5.0,
73
- negative_prompt: Optional[Union[str, List[str]]] = None,
74
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
75
- num_images_per_prompt: Optional[int] = 1,
76
- eta: float = 0.0,
77
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
78
- latents: Optional[torch.FloatTensor] = None,
79
- prompt_embeds: Optional[torch.FloatTensor] = None,
80
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
81
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
82
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
83
- output_type: Optional[str] = "pil",
84
- return_dict: bool = True,
85
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
86
- callback_steps: int = 1,
87
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
88
- guidance_rescale: float = 0.0,
89
- original_size: Optional[Tuple[int, int]] = None,
90
- crops_coords_top_left: Tuple[int, int] = (0, 0),
91
- target_size: Optional[Tuple[int, int]] = None,
92
- use_karras_sigmas: bool = False,
93
- ):
94
-
95
- # 0. Default height and width to unet
96
- height = height or self.default_sample_size * self.vae_scale_factor
97
- width = width or self.default_sample_size * self.vae_scale_factor
98
-
99
- original_size = original_size or (height, width)
100
- target_size = target_size or (height, width)
101
-
102
- # 1. Check inputs. Raise error if not correct
103
- self.check_inputs(
104
- prompt,
105
- prompt_2,
106
- height,
107
- width,
108
- callback_steps,
109
- negative_prompt,
110
- negative_prompt_2,
111
- prompt_embeds,
112
- negative_prompt_embeds,
113
- pooled_prompt_embeds,
114
- negative_pooled_prompt_embeds,
115
- )
116
-
117
- # 2. Define call parameters
118
- if prompt is not None and isinstance(prompt, str):
119
- batch_size = 1
120
- elif prompt is not None and isinstance(prompt, list):
121
- batch_size = len(prompt)
122
- else:
123
- batch_size = prompt_embeds.shape[0]
124
-
125
- device = self._execution_device
126
-
127
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
128
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
129
- # corresponds to doing no classifier free guidance.
130
- do_classifier_free_guidance = guidance_scale > 1.0
131
-
132
- # 3. Encode input prompt
133
- text_encoder_lora_scale = (
134
- cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
135
- )
136
- (
137
- prompt_embeds,
138
- negative_prompt_embeds,
139
- pooled_prompt_embeds,
140
- negative_pooled_prompt_embeds,
141
- ) = self.encode_prompt(
142
- prompt=prompt,
143
- prompt_2=prompt_2,
144
- device=device,
145
- num_images_per_prompt=num_images_per_prompt,
146
- do_classifier_free_guidance=do_classifier_free_guidance,
147
- negative_prompt=negative_prompt,
148
- negative_prompt_2=negative_prompt_2,
149
- prompt_embeds=prompt_embeds,
150
- negative_prompt_embeds=negative_prompt_embeds,
151
- pooled_prompt_embeds=pooled_prompt_embeds,
152
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
153
- lora_scale=text_encoder_lora_scale,
154
- )
155
-
156
- # 4. Prepare timesteps
157
- self.scheduler.set_timesteps(num_inference_steps, device=device)
158
-
159
- timesteps = self.scheduler.timesteps
160
-
161
- # 5. Prepare latent variables
162
- num_channels_latents = self.unet.config.in_channels
163
- latents = self.prepare_latents(
164
- batch_size * num_images_per_prompt,
165
- num_channels_latents,
166
- height,
167
- width,
168
- prompt_embeds.dtype,
169
- device,
170
- generator,
171
- latents,
172
- )
173
-
174
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
175
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
176
-
177
- # 7. Prepare added time ids & embeddings
178
- add_text_embeds = pooled_prompt_embeds
179
- add_time_ids = self._get_add_time_ids(
180
- original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
181
- )
182
-
183
- if do_classifier_free_guidance:
184
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
185
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
186
- add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
187
-
188
- prompt_embeds = prompt_embeds.to(device)
189
- add_text_embeds = add_text_embeds.to(device)
190
- add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
191
-
192
- # 8. Denoising loop
193
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
194
-
195
- # 7.1 Apply denoising_end
196
- if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
197
- discrete_timestep_cutoff = int(
198
- round(
199
- self.scheduler.config.num_train_timesteps
200
- - (denoising_end * self.scheduler.config.num_train_timesteps)
201
- )
202
- )
203
- num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
204
- timesteps = timesteps[:num_inference_steps]
205
-
206
- # 5. Prepare sigmas
207
- if use_karras_sigmas:
208
- sigma_min: float = self.k_diffusion_model.sigmas[0].item()
209
- sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
210
- sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
211
- sigmas = sigmas.to(device)
212
- else:
213
- sigmas = self.scheduler.sigmas
214
- sigmas = sigmas.to(prompt_embeds.dtype)
215
-
216
- # 5. Prepare latent variables
217
- num_channels_latents = self.unet.config.in_channels
218
- latents = self.prepare_latents(
219
- batch_size * num_images_per_prompt,
220
- num_channels_latents,
221
- height,
222
- width,
223
- prompt_embeds.dtype,
224
- device,
225
- generator,
226
- latents,
227
- )
228
-
229
- latents = latents * sigmas[0]
230
- self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
231
- self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
232
-
233
- # 7. Define model function
234
- def model_fn(x, t):
235
- latent_model_input = torch.cat([x] * 2)
236
- t = torch.cat([t] * 2)
237
-
238
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
239
- # noise_pred = self.unet(
240
- # latent_model_input,
241
- # t,
242
- # encoder_hidden_states=prompt_embeds,
243
- # cross_attention_kwargs=cross_attention_kwargs,
244
- # added_cond_kwargs=added_cond_kwargs,
245
- # return_dict=False,
246
- # )[0]
247
-
248
- noise_pred = self.k_diffusion_model(
249
- latent_model_input,
250
- t,
251
- encoder_hidden_states=prompt_embeds,
252
- cross_attention_kwargs=cross_attention_kwargs,
253
- added_cond_kwargs=added_cond_kwargs,
254
- return_dict=False,)[0]
255
-
256
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
257
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
258
- return noise_pred
259
-
260
-
261
- # 8. Run k-diffusion solver
262
- sampler_kwargs = {}
263
- # should work without it
264
- noise_sampler_seed = None
265
-
266
-
267
- if "noise_sampler" in inspect.signature(self.sampler).parameters:
268
- min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
269
- noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
270
- sampler_kwargs["noise_sampler"] = noise_sampler
271
-
272
- latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
273
-
274
- if not output_type == "latent":
275
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
276
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
277
- else:
278
- image = latents
279
- has_nsfw_concept = None
280
-
281
- if has_nsfw_concept is None:
282
- do_denormalize = [True] * image.shape[0]
283
- else:
284
- do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
285
-
286
- image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
287
-
288
- # Offload last model to CPU
289
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
290
- self.final_offload_hook.offload()
291
-
292
- if not return_dict:
293
- return (image,)
294
-
295
- return StableDiffusionXLPipelineOutput(images=image)
296
-
297
-
298
- class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
299
-
300
- def predict_noise(
301
- self,
302
- prompt: Union[str, List[str]] = None,
303
- prompt_2: Optional[Union[str, List[str]]] = None,
304
- num_inference_steps: int = 50,
305
- guidance_scale: float = 5.0,
306
- negative_prompt: Optional[Union[str, List[str]]] = None,
307
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
308
- num_images_per_prompt: Optional[int] = 1,
309
- eta: float = 0.0,
310
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
311
- latents: Optional[torch.FloatTensor] = None,
312
- prompt_embeds: Optional[torch.FloatTensor] = None,
313
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
314
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
315
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
316
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
317
- guidance_rescale: float = 0.0,
318
- crops_coords_top_left: Tuple[int, int] = (0, 0),
319
- timestep: Optional[int] = None,
320
- ):
321
- r"""
322
- Function invoked when calling the pipeline for generation.
323
-
324
- Args:
325
- prompt (`str` or `List[str]`, *optional*):
326
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
327
- instead.
328
- prompt_2 (`str` or `List[str]`, *optional*):
329
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
330
- used in both text-encoders
331
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
332
- The height in pixels of the generated image.
333
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
334
- The width in pixels of the generated image.
335
- num_inference_steps (`int`, *optional*, defaults to 50):
336
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
337
- expense of slower inference.
338
- denoising_end (`float`, *optional*):
339
- When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
340
- completed before it is intentionally prematurely terminated. As a result, the returned sample will
341
- still retain a substantial amount of noise as determined by the discrete timesteps selected by the
342
- scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
343
- "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
344
- Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
345
- guidance_scale (`float`, *optional*, defaults to 7.5):
346
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
347
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
348
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
349
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
350
- usually at the expense of lower image quality.
351
- negative_prompt (`str` or `List[str]`, *optional*):
352
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
353
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
354
- less than `1`).
355
- negative_prompt_2 (`str` or `List[str]`, *optional*):
356
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
357
- `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
358
- num_images_per_prompt (`int`, *optional*, defaults to 1):
359
- The number of images to generate per prompt.
360
- eta (`float`, *optional*, defaults to 0.0):
361
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
362
- [`schedulers.DDIMScheduler`], will be ignored for others.
363
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
364
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
365
- to make generation deterministic.
366
- latents (`torch.FloatTensor`, *optional*):
367
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
368
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
369
- tensor will ge generated by sampling using the supplied random `generator`.
370
- prompt_embeds (`torch.FloatTensor`, *optional*):
371
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
372
- provided, text embeddings will be generated from `prompt` input argument.
373
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
374
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
375
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
376
- argument.
377
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
378
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
379
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
380
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
381
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
382
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
383
- input argument.
384
- output_type (`str`, *optional*, defaults to `"pil"`):
385
- The output format of the generate image. Choose between
386
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
387
- return_dict (`bool`, *optional*, defaults to `True`):
388
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
389
- of a plain tuple.
390
- callback (`Callable`, *optional*):
391
- A function that will be called every `callback_steps` steps during inference. The function will be
392
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
393
- callback_steps (`int`, *optional*, defaults to 1):
394
- The frequency at which the `callback` function will be called. If not specified, the callback will be
395
- called at every step.
396
- cross_attention_kwargs (`dict`, *optional*):
397
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
398
- `self.processor` in
399
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
400
- guidance_rescale (`float`, *optional*, defaults to 0.7):
401
- Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
402
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
403
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
404
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
405
- original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
406
- If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
407
- `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
408
- explained in section 2.2 of
409
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
410
- crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
411
- `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
412
- `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
413
- `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
414
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
415
- target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
416
- For most cases, `target_size` should be set to the desired height and width of the generated image. If
417
- not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
418
- section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
419
-
420
- Examples:
421
-
422
- Returns:
423
- [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
424
- [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
425
- `tuple`. When returning a tuple, the first element is a list with the generated images.
426
- """
427
- # if not predict_noise:
428
- # # call parent
429
- # return super().__call__(
430
- # prompt=prompt,
431
- # prompt_2=prompt_2,
432
- # height=height,
433
- # width=width,
434
- # num_inference_steps=num_inference_steps,
435
- # denoising_end=denoising_end,
436
- # guidance_scale=guidance_scale,
437
- # negative_prompt=negative_prompt,
438
- # negative_prompt_2=negative_prompt_2,
439
- # num_images_per_prompt=num_images_per_prompt,
440
- # eta=eta,
441
- # generator=generator,
442
- # latents=latents,
443
- # prompt_embeds=prompt_embeds,
444
- # negative_prompt_embeds=negative_prompt_embeds,
445
- # pooled_prompt_embeds=pooled_prompt_embeds,
446
- # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
447
- # output_type=output_type,
448
- # return_dict=return_dict,
449
- # callback=callback,
450
- # callback_steps=callback_steps,
451
- # cross_attention_kwargs=cross_attention_kwargs,
452
- # guidance_rescale=guidance_rescale,
453
- # original_size=original_size,
454
- # crops_coords_top_left=crops_coords_top_left,
455
- # target_size=target_size,
456
- # )
457
-
458
- # 0. Default height and width to unet
459
- height = self.default_sample_size * self.vae_scale_factor
460
- width = self.default_sample_size * self.vae_scale_factor
461
-
462
- original_size = (height, width)
463
- target_size = (height, width)
464
-
465
- # 2. Define call parameters
466
- if prompt is not None and isinstance(prompt, str):
467
- batch_size = 1
468
- elif prompt is not None and isinstance(prompt, list):
469
- batch_size = len(prompt)
470
- else:
471
- batch_size = prompt_embeds.shape[0]
472
-
473
- device = self._execution_device
474
-
475
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
476
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
477
- # corresponds to doing no classifier free guidance.
478
- do_classifier_free_guidance = guidance_scale > 1.0
479
-
480
- # 3. Encode input prompt
481
- text_encoder_lora_scale = (
482
- cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
483
- )
484
- (
485
- prompt_embeds,
486
- negative_prompt_embeds,
487
- pooled_prompt_embeds,
488
- negative_pooled_prompt_embeds,
489
- ) = self.encode_prompt(
490
- prompt=prompt,
491
- prompt_2=prompt_2,
492
- device=device,
493
- num_images_per_prompt=num_images_per_prompt,
494
- do_classifier_free_guidance=do_classifier_free_guidance,
495
- negative_prompt=negative_prompt,
496
- negative_prompt_2=negative_prompt_2,
497
- prompt_embeds=prompt_embeds,
498
- negative_prompt_embeds=negative_prompt_embeds,
499
- pooled_prompt_embeds=pooled_prompt_embeds,
500
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
501
- lora_scale=text_encoder_lora_scale,
502
- )
503
-
504
- # 4. Prepare timesteps
505
- self.scheduler.set_timesteps(num_inference_steps, device=device)
506
-
507
- # 5. Prepare latent variables
508
- num_channels_latents = self.unet.config.in_channels
509
- latents = self.prepare_latents(
510
- batch_size * num_images_per_prompt,
511
- num_channels_latents,
512
- height,
513
- width,
514
- prompt_embeds.dtype,
515
- device,
516
- generator,
517
- latents,
518
- )
519
-
520
- # 7. Prepare added time ids & embeddings
521
- add_text_embeds = pooled_prompt_embeds
522
- add_time_ids = self._get_add_time_ids(
523
- original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
524
- ).to(device) # TODO DOES NOT CAST ORIGINALLY
525
-
526
- if do_classifier_free_guidance:
527
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
528
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
529
- add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
530
-
531
- prompt_embeds = prompt_embeds.to(device)
532
- add_text_embeds = add_text_embeds.to(device)
533
- add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
534
-
535
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
536
-
537
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
538
-
539
- # predict the noise residual
540
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
541
- noise_pred = self.unet(
542
- latent_model_input,
543
- timestep,
544
- encoder_hidden_states=prompt_embeds,
545
- cross_attention_kwargs=cross_attention_kwargs,
546
- added_cond_kwargs=added_cond_kwargs,
547
- return_dict=False,
548
- )[0]
549
-
550
- # perform guidance
551
- if do_classifier_free_guidance:
552
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
553
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
554
-
555
- if do_classifier_free_guidance and guidance_rescale > 0.0:
556
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
557
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
558
-
559
- return noise_pred
560
-
561
- def enable_model_cpu_offload(self, gpu_id=0):
562
- print('Called cpu offload', gpu_id)
563
- # fuck off
564
- pass
565
-
566
-
567
- class CustomStableDiffusionPipeline(StableDiffusionPipeline):
568
-
569
- # replace the call so it matches SDXL call so we can use the same code and also stop early
570
- def __call__(
571
- self,
572
- prompt: Union[str, List[str]] = None,
573
- prompt_2: Optional[Union[str, List[str]]] = None,
574
- height: Optional[int] = None,
575
- width: Optional[int] = None,
576
- num_inference_steps: int = 50,
577
- denoising_end: Optional[float] = None,
578
- guidance_scale: float = 5.0,
579
- negative_prompt: Optional[Union[str, List[str]]] = None,
580
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
581
- num_images_per_prompt: Optional[int] = 1,
582
- eta: float = 0.0,
583
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
584
- latents: Optional[torch.FloatTensor] = None,
585
- prompt_embeds: Optional[torch.FloatTensor] = None,
586
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
587
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
588
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
589
- output_type: Optional[str] = "pil",
590
- return_dict: bool = True,
591
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
592
- callback_steps: int = 1,
593
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
594
- guidance_rescale: float = 0.0,
595
- original_size: Optional[Tuple[int, int]] = None,
596
- crops_coords_top_left: Tuple[int, int] = (0, 0),
597
- target_size: Optional[Tuple[int, int]] = None,
598
- ):
599
- # 0. Default height and width to unet
600
- height = height or self.unet.config.sample_size * self.vae_scale_factor
601
- width = width or self.unet.config.sample_size * self.vae_scale_factor
602
-
603
- # 1. Check inputs. Raise error if not correct
604
- self.check_inputs(
605
- prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
606
- )
607
-
608
- # 2. Define call parameters
609
- if prompt is not None and isinstance(prompt, str):
610
- batch_size = 1
611
- elif prompt is not None and isinstance(prompt, list):
612
- batch_size = len(prompt)
613
- else:
614
- batch_size = prompt_embeds.shape[0]
615
-
616
- device = self._execution_device
617
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
618
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
619
- # corresponds to doing no classifier free guidance.
620
- do_classifier_free_guidance = guidance_scale > 1.0
621
-
622
- # 3. Encode input prompt
623
- text_encoder_lora_scale = (
624
- cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
625
- )
626
- prompt_embeds = self._encode_prompt(
627
- prompt,
628
- device,
629
- num_images_per_prompt,
630
- do_classifier_free_guidance,
631
- negative_prompt,
632
- prompt_embeds=prompt_embeds,
633
- negative_prompt_embeds=negative_prompt_embeds,
634
- lora_scale=text_encoder_lora_scale,
635
- )
636
-
637
- # 4. Prepare timesteps
638
- self.scheduler.set_timesteps(num_inference_steps, device=device)
639
- timesteps = self.scheduler.timesteps
640
-
641
- # 5. Prepare latent variables
642
- num_channels_latents = self.unet.config.in_channels
643
- latents = self.prepare_latents(
644
- batch_size * num_images_per_prompt,
645
- num_channels_latents,
646
- height,
647
- width,
648
- prompt_embeds.dtype,
649
- device,
650
- generator,
651
- latents,
652
- )
653
-
654
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
655
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
656
-
657
- # 7. Denoising loop
658
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
659
-
660
- # 7.1 Apply denoising_end
661
- if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
662
- discrete_timestep_cutoff = int(
663
- round(
664
- self.scheduler.config.num_train_timesteps
665
- - (denoising_end * self.scheduler.config.num_train_timesteps)
666
- )
667
- )
668
- num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
669
- timesteps = timesteps[:num_inference_steps]
670
-
671
- with self.progress_bar(total=num_inference_steps) as progress_bar:
672
- for i, t in enumerate(timesteps):
673
- # expand the latents if we are doing classifier free guidance
674
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
675
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
676
-
677
- # predict the noise residual
678
- noise_pred = self.unet(
679
- latent_model_input,
680
- t,
681
- encoder_hidden_states=prompt_embeds,
682
- cross_attention_kwargs=cross_attention_kwargs,
683
- return_dict=False,
684
- )[0]
685
-
686
- # perform guidance
687
- if do_classifier_free_guidance:
688
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
689
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
690
-
691
- if do_classifier_free_guidance and guidance_rescale > 0.0:
692
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
693
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
694
-
695
- # compute the previous noisy sample x_t -> x_t-1
696
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
697
-
698
- # call the callback, if provided
699
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
700
- progress_bar.update()
701
- if callback is not None and i % callback_steps == 0:
702
- callback(i, t, latents)
703
-
704
- if not output_type == "latent":
705
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
706
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
707
- else:
708
- image = latents
709
- has_nsfw_concept = None
710
-
711
- if has_nsfw_concept is None:
712
- do_denormalize = [True] * image.shape[0]
713
- else:
714
- do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
715
-
716
- image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
717
-
718
- # Offload last model to CPU
719
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
720
- self.final_offload_hook.offload()
721
-
722
- if not return_dict:
723
- return (image, has_nsfw_concept)
724
-
725
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
726
-
727
- # some of the inputs are to keep it compatible with sdx
728
- def predict_noise(
729
- self,
730
- prompt: Union[str, List[str]] = None,
731
- prompt_2: Optional[Union[str, List[str]]] = None,
732
- num_inference_steps: int = 50,
733
- guidance_scale: float = 5.0,
734
- negative_prompt: Optional[Union[str, List[str]]] = None,
735
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
736
- num_images_per_prompt: Optional[int] = 1,
737
- eta: float = 0.0,
738
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
739
- latents: Optional[torch.FloatTensor] = None,
740
- prompt_embeds: Optional[torch.FloatTensor] = None,
741
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
742
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
743
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
744
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
745
- guidance_rescale: float = 0.0,
746
- crops_coords_top_left: Tuple[int, int] = (0, 0),
747
- timestep: Optional[int] = None,
748
- ):
749
-
750
- # 0. Default height and width to unet
751
- height = self.unet.config.sample_size * self.vae_scale_factor
752
- width = self.unet.config.sample_size * self.vae_scale_factor
753
-
754
- # 2. Define call parameters
755
- if prompt is not None and isinstance(prompt, str):
756
- batch_size = 1
757
- elif prompt is not None and isinstance(prompt, list):
758
- batch_size = len(prompt)
759
- else:
760
- batch_size = prompt_embeds.shape[0]
761
-
762
- device = self._execution_device
763
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
764
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
765
- # corresponds to doing no classifier free guidance.
766
- do_classifier_free_guidance = guidance_scale > 1.0
767
-
768
- # 3. Encode input prompt
769
- text_encoder_lora_scale = (
770
- cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
771
- )
772
- prompt_embeds = self._encode_prompt(
773
- prompt,
774
- device,
775
- num_images_per_prompt,
776
- do_classifier_free_guidance,
777
- negative_prompt,
778
- prompt_embeds=prompt_embeds,
779
- negative_prompt_embeds=negative_prompt_embeds,
780
- lora_scale=text_encoder_lora_scale,
781
- )
782
-
783
- # 4. Prepare timesteps
784
- self.scheduler.set_timesteps(num_inference_steps, device=device)
785
-
786
- # 5. Prepare latent variables
787
- num_channels_latents = self.unet.config.in_channels
788
- latents = self.prepare_latents(
789
- batch_size * num_images_per_prompt,
790
- num_channels_latents,
791
- height,
792
- width,
793
- prompt_embeds.dtype,
794
- device,
795
- generator,
796
- latents,
797
- )
798
-
799
- # expand the latents if we are doing classifier free guidance
800
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
801
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
802
-
803
- # predict the noise residual
804
- noise_pred = self.unet(
805
- latent_model_input,
806
- timestep,
807
- encoder_hidden_states=prompt_embeds,
808
- cross_attention_kwargs=cross_attention_kwargs,
809
- return_dict=False,
810
- )[0]
811
-
812
- # perform guidance
813
- if do_classifier_free_guidance:
814
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
815
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
816
-
817
- if do_classifier_free_guidance and guidance_rescale > 0.0:
818
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
819
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
820
-
821
- return noise_pred
822
-
823
-
824
- class StableDiffusionXLRefinerPipeline(StableDiffusionXLPipeline):
825
-
826
- @torch.no_grad()
827
- def __call__(
828
- self,
829
- prompt: Union[str, List[str]] = None,
830
- prompt_2: Optional[Union[str, List[str]]] = None,
831
- height: Optional[int] = None,
832
- width: Optional[int] = None,
833
- num_inference_steps: int = 50,
834
- denoising_end: Optional[float] = None,
835
- denoising_start: Optional[float] = None,
836
- guidance_scale: float = 5.0,
837
- negative_prompt: Optional[Union[str, List[str]]] = None,
838
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
839
- num_images_per_prompt: Optional[int] = 1,
840
- eta: float = 0.0,
841
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
842
- latents: Optional[torch.FloatTensor] = None,
843
- prompt_embeds: Optional[torch.FloatTensor] = None,
844
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
845
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
846
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
847
- output_type: Optional[str] = "pil",
848
- return_dict: bool = True,
849
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
850
- callback_steps: int = 1,
851
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
852
- guidance_rescale: float = 0.0,
853
- original_size: Optional[Tuple[int, int]] = None,
854
- crops_coords_top_left: Tuple[int, int] = (0, 0),
855
- target_size: Optional[Tuple[int, int]] = None,
856
- negative_original_size: Optional[Tuple[int, int]] = None,
857
- negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
858
- negative_target_size: Optional[Tuple[int, int]] = None,
859
- clip_skip: Optional[int] = None,
860
- ):
861
- r"""
862
- Function invoked when calling the pipeline for generation.
863
-
864
- Args:
865
- prompt (`str` or `List[str]`, *optional*):
866
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
867
- instead.
868
- prompt_2 (`str` or `List[str]`, *optional*):
869
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
870
- used in both text-encoders
871
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
872
- The height in pixels of the generated image. This is set to 1024 by default for the best results.
873
- Anything below 512 pixels won't work well for
874
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
875
- and checkpoints that are not specifically fine-tuned on low resolutions.
876
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
877
- The width in pixels of the generated image. This is set to 1024 by default for the best results.
878
- Anything below 512 pixels won't work well for
879
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
880
- and checkpoints that are not specifically fine-tuned on low resolutions.
881
- num_inference_steps (`int`, *optional*, defaults to 50):
882
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
883
- expense of slower inference.
884
- denoising_end (`float`, *optional*):
885
- When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
886
- completed before it is intentionally prematurely terminated. As a result, the returned sample will
887
- still retain a substantial amount of noise as determined by the discrete timesteps selected by the
888
- scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
889
- "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
890
- Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
891
- denoising_start (`float`, *optional*):
892
- When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
893
- bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
894
- it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
895
- strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
896
- is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
897
- Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
898
- guidance_scale (`float`, *optional*, defaults to 5.0):
899
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
900
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
901
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
902
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
903
- usually at the expense of lower image quality.
904
- negative_prompt (`str` or `List[str]`, *optional*):
905
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
906
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
907
- less than `1`).
908
- negative_prompt_2 (`str` or `List[str]`, *optional*):
909
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
910
- `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
911
- num_images_per_prompt (`int`, *optional*, defaults to 1):
912
- The number of images to generate per prompt.
913
- eta (`float`, *optional*, defaults to 0.0):
914
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
915
- [`schedulers.DDIMScheduler`], will be ignored for others.
916
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
917
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
918
- to make generation deterministic.
919
- latents (`torch.FloatTensor`, *optional*):
920
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
921
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
922
- tensor will ge generated by sampling using the supplied random `generator`.
923
- prompt_embeds (`torch.FloatTensor`, *optional*):
924
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
925
- provided, text embeddings will be generated from `prompt` input argument.
926
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
927
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
928
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
929
- argument.
930
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
931
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
932
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
933
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
934
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
935
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
936
- input argument.
937
- output_type (`str`, *optional*, defaults to `"pil"`):
938
- The output format of the generate image. Choose between
939
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
940
- return_dict (`bool`, *optional*, defaults to `True`):
941
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
942
- of a plain tuple.
943
- callback (`Callable`, *optional*):
944
- A function that will be called every `callback_steps` steps during inference. The function will be
945
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
946
- callback_steps (`int`, *optional*, defaults to 1):
947
- The frequency at which the `callback` function will be called. If not specified, the callback will be
948
- called at every step.
949
- cross_attention_kwargs (`dict`, *optional*):
950
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
951
- `self.processor` in
952
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
953
- guidance_rescale (`float`, *optional*, defaults to 0.0):
954
- Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
955
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
956
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
957
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
958
- original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
959
- If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
960
- `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
961
- explained in section 2.2 of
962
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
963
- crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
964
- `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
965
- `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
966
- `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
967
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
968
- target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
969
- For most cases, `target_size` should be set to the desired height and width of the generated image. If
970
- not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
971
- section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
972
- negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
973
- To negatively condition the generation process based on a specific image resolution. Part of SDXL's
974
- micro-conditioning as explained in section 2.2 of
975
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
976
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
977
- negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
978
- To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
979
- micro-conditioning as explained in section 2.2 of
980
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
981
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
982
- negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
983
- To negatively condition the generation process based on a target image resolution. It should be as same
984
- as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
985
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
986
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
987
-
988
- Examples:
989
-
990
- Returns:
991
- [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
992
- [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
993
- `tuple`. When returning a tuple, the first element is a list with the generated images.
994
- """
995
- # 0. Default height and width to unet
996
- height = height or self.default_sample_size * self.vae_scale_factor
997
- width = width or self.default_sample_size * self.vae_scale_factor
998
-
999
- original_size = original_size or (height, width)
1000
- target_size = target_size or (height, width)
1001
-
1002
- # 1. Check inputs. Raise error if not correct
1003
- self.check_inputs(
1004
- prompt,
1005
- prompt_2,
1006
- height,
1007
- width,
1008
- callback_steps,
1009
- negative_prompt,
1010
- negative_prompt_2,
1011
- prompt_embeds,
1012
- negative_prompt_embeds,
1013
- pooled_prompt_embeds,
1014
- negative_pooled_prompt_embeds,
1015
- )
1016
-
1017
- # 2. Define call parameters
1018
- if prompt is not None and isinstance(prompt, str):
1019
- batch_size = 1
1020
- elif prompt is not None and isinstance(prompt, list):
1021
- batch_size = len(prompt)
1022
- else:
1023
- batch_size = prompt_embeds.shape[0]
1024
-
1025
- device = self._execution_device
1026
-
1027
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1028
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1029
- # corresponds to doing no classifier free guidance.
1030
- do_classifier_free_guidance = guidance_scale > 1.0
1031
-
1032
- # 3. Encode input prompt
1033
- lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
1034
-
1035
- (
1036
- prompt_embeds,
1037
- negative_prompt_embeds,
1038
- pooled_prompt_embeds,
1039
- negative_pooled_prompt_embeds,
1040
- ) = self.encode_prompt(
1041
- prompt=prompt,
1042
- prompt_2=prompt_2,
1043
- device=device,
1044
- num_images_per_prompt=num_images_per_prompt,
1045
- do_classifier_free_guidance=do_classifier_free_guidance,
1046
- negative_prompt=negative_prompt,
1047
- negative_prompt_2=negative_prompt_2,
1048
- prompt_embeds=prompt_embeds,
1049
- negative_prompt_embeds=negative_prompt_embeds,
1050
- pooled_prompt_embeds=pooled_prompt_embeds,
1051
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1052
- lora_scale=lora_scale,
1053
- clip_skip=clip_skip,
1054
- )
1055
-
1056
- # 4. Prepare timesteps
1057
- self.scheduler.set_timesteps(num_inference_steps, device=device)
1058
-
1059
- timesteps = self.scheduler.timesteps
1060
-
1061
- # 5. Prepare latent variables
1062
- num_channels_latents = self.unet.config.in_channels
1063
- latents = self.prepare_latents(
1064
- batch_size * num_images_per_prompt,
1065
- num_channels_latents,
1066
- height,
1067
- width,
1068
- prompt_embeds.dtype,
1069
- device,
1070
- generator,
1071
- latents,
1072
- )
1073
-
1074
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1075
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1076
-
1077
- # 7. Prepare added time ids & embeddings
1078
- add_text_embeds = pooled_prompt_embeds
1079
- if self.text_encoder_2 is None:
1080
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1081
- else:
1082
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1083
-
1084
- add_time_ids = self._get_add_time_ids(
1085
- original_size,
1086
- crops_coords_top_left,
1087
- target_size,
1088
- dtype=prompt_embeds.dtype,
1089
- text_encoder_projection_dim=text_encoder_projection_dim,
1090
- )
1091
- if negative_original_size is not None and negative_target_size is not None:
1092
- negative_add_time_ids = self._get_add_time_ids(
1093
- negative_original_size,
1094
- negative_crops_coords_top_left,
1095
- negative_target_size,
1096
- dtype=prompt_embeds.dtype,
1097
- text_encoder_projection_dim=text_encoder_projection_dim,
1098
- )
1099
- else:
1100
- negative_add_time_ids = add_time_ids
1101
-
1102
- if do_classifier_free_guidance:
1103
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1104
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1105
- add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1106
-
1107
- prompt_embeds = prompt_embeds.to(device)
1108
- add_text_embeds = add_text_embeds.to(device)
1109
- add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1110
-
1111
- # 8. Denoising loop
1112
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1113
-
1114
- # 8.1 Apply denoising_end
1115
- if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
1116
- discrete_timestep_cutoff = int(
1117
- round(
1118
- self.scheduler.config.num_train_timesteps
1119
- - (denoising_end * self.scheduler.config.num_train_timesteps)
1120
- )
1121
- )
1122
- num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1123
- timesteps = timesteps[:num_inference_steps]
1124
-
1125
- # 8.2 Determine denoising_start
1126
- denoising_start_index = 0
1127
- if denoising_start is not None and isinstance(denoising_start, float) and denoising_start > 0 and denoising_start < 1:
1128
- discrete_timestep_start = int(
1129
- round(
1130
- self.scheduler.config.num_train_timesteps
1131
- - (denoising_start * self.scheduler.config.num_train_timesteps)
1132
- )
1133
- )
1134
- denoising_start_index = len(list(filter(lambda ts: ts < discrete_timestep_start, timesteps)))
1135
-
1136
-
1137
- with self.progress_bar(total=num_inference_steps - denoising_start_index) as progress_bar:
1138
- for i, t in enumerate(timesteps, start=denoising_start_index):
1139
- # expand the latents if we are doing classifier free guidance
1140
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1141
-
1142
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1143
-
1144
- # predict the noise residual
1145
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1146
- noise_pred = self.unet(
1147
- latent_model_input,
1148
- t,
1149
- encoder_hidden_states=prompt_embeds,
1150
- cross_attention_kwargs=cross_attention_kwargs,
1151
- added_cond_kwargs=added_cond_kwargs,
1152
- return_dict=False,
1153
- )[0]
1154
-
1155
- # perform guidance
1156
- if do_classifier_free_guidance:
1157
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1158
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1159
-
1160
- if do_classifier_free_guidance and guidance_rescale > 0.0:
1161
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1162
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1163
-
1164
- # compute the previous noisy sample x_t -> x_t-1
1165
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1166
-
1167
- # call the callback, if provided
1168
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1169
- progress_bar.update()
1170
- if callback is not None and i % callback_steps == 0:
1171
- step_idx = i // getattr(self.scheduler, "order", 1)
1172
- callback(step_idx, t, latents)
1173
-
1174
- if XLA_AVAILABLE:
1175
- xm.mark_step()
1176
-
1177
- if not output_type == "latent":
1178
- # make sure the VAE is in float32 mode, as it overflows in float16
1179
- needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1180
-
1181
- if needs_upcasting:
1182
- self.upcast_vae()
1183
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1184
-
1185
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1186
-
1187
- # cast back to fp16 if needed
1188
- if needs_upcasting:
1189
- self.vae.to(dtype=torch.float16)
1190
- else:
1191
- image = latents
1192
-
1193
- if not output_type == "latent":
1194
- # apply watermark if available
1195
- if self.watermark is not None:
1196
- image = self.watermark.apply_watermark(image)
1197
-
1198
- image = self.image_processor.postprocess(image, output_type=output_type)
1199
-
1200
- # Offload all models
1201
- self.maybe_free_model_hooks()
1202
-
1203
- if not return_dict:
1204
- return (image,)
1205
-
1206
- return StableDiffusionXLPipelineOutput(images=image)
1207
-
1208
-
1209
-
1210
-
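For reference, a minimal usage sketch of the SDXL-style `__call__` above (an editorial addition, not part of the deleted file). It assumes the method is exposed through a `StableDiffusionXLPipeline`-compatible class; the model id and argument values are illustrative. It exercises the negative micro-conditioning arguments documented above plus the `denoising_end` cutoff: with the default 1000 training timesteps, `denoising_end=0.8` keeps only timesteps >= round(1000 - 0.8 * 1000) = 200, i.e. the first 80% of the schedule, so a refiner-style second pass could resume from `denoising_start=0.8`.

import torch
from diffusers import StableDiffusionXLPipeline

# Illustrative checkpoint; any SDXL-compatible checkpoint should behave the same way.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="a photo of an astronaut riding a horse",
    negative_prompt="low quality, blurry",
    num_inference_steps=40,
    guidance_scale=7.0,
    # Micro-conditioning (section 2.2 of the SDXL report): internally these become the
    # 6-vector [orig_h, orig_w, crop_top, crop_left, tgt_h, tgt_w] built by _get_add_time_ids
    # and fed to the negative branch, steering away from small / cropped-looking outputs.
    negative_original_size=(512, 512),
    negative_crops_coords_top_left=(0, 0),
    negative_target_size=(1024, 1024),
    # Stop at 80% of the schedule; a refiner pass could continue with denoising_start=0.8.
    denoising_end=0.8,
).images[0]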
1211
- # TODO this is rough. Need to properly stack unconditional
1212
- class FluxWithCFGPipeline(FluxPipeline):
1213
- def __call__(
1214
- self,
1215
- prompt: Union[str, List[str]] = None,
1216
- prompt_2: Optional[Union[str, List[str]]] = None,
1217
- negative_prompt: Optional[Union[str, List[str]]] = None,
1218
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
1219
- height: Optional[int] = None,
1220
- width: Optional[int] = None,
1221
- num_inference_steps: int = 28,
1222
- timesteps: List[int] = None,
1223
- guidance_scale: float = 7.0,
1224
- num_images_per_prompt: Optional[int] = 1,
1225
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1226
- latents: Optional[torch.FloatTensor] = None,
1227
- prompt_embeds: Optional[torch.FloatTensor] = None,
1228
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1229
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1230
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1231
- output_type: Optional[str] = "pil",
1232
- return_dict: bool = True,
1233
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1234
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1235
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1236
- max_sequence_length: int = 512,
1237
- ):
1238
-
1239
- height = height or self.default_sample_size * self.vae_scale_factor
1240
- width = width or self.default_sample_size * self.vae_scale_factor
1241
-
1242
- # 1. Check inputs. Raise error if not correct
1243
- self.check_inputs(
1244
- prompt,
1245
- prompt_2,
1246
- height,
1247
- width,
1248
- prompt_embeds=prompt_embeds,
1249
- pooled_prompt_embeds=pooled_prompt_embeds,
1250
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1251
- max_sequence_length=max_sequence_length,
1252
- )
1253
-
1254
- self._guidance_scale = guidance_scale
1255
- self._joint_attention_kwargs = joint_attention_kwargs
1256
- self._interrupt = False
1257
-
1258
- # 2. Define call parameters
1259
- if prompt is not None and isinstance(prompt, str):
1260
- batch_size = 1
1261
- elif prompt is not None and isinstance(prompt, list):
1262
- batch_size = len(prompt)
1263
- else:
1264
- batch_size = prompt_embeds.shape[0]
1265
-
1266
- device = self._execution_device
1267
-
1268
- lora_scale = (
1269
- self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1270
- )
1271
- (
1272
- prompt_embeds,
1273
- pooled_prompt_embeds,
1274
- text_ids,
1275
- ) = self.encode_prompt(
1276
- prompt=prompt,
1277
- prompt_2=prompt_2,
1278
- prompt_embeds=prompt_embeds,
1279
- pooled_prompt_embeds=pooled_prompt_embeds,
1280
- device=device,
1281
- num_images_per_prompt=num_images_per_prompt,
1282
- max_sequence_length=max_sequence_length,
1283
- lora_scale=lora_scale,
1284
- )
1285
- (
1286
- negative_prompt_embeds,
1287
- negative_pooled_prompt_embeds,
1288
- negative_text_ids,
1289
- ) = self.encode_prompt(
1290
- prompt=negative_prompt,
1291
- prompt_2=negative_prompt_2,
1292
- prompt_embeds=negative_prompt_embeds,
1293
- pooled_prompt_embeds=negative_pooled_prompt_embeds,
1294
- device=device,
1295
- num_images_per_prompt=num_images_per_prompt,
1296
- max_sequence_length=max_sequence_length,
1297
- lora_scale=lora_scale,
1298
- )
1299
-
1300
- # 4. Prepare latent variables
1301
- num_channels_latents = self.transformer.config.in_channels // 4
1302
- latents, latent_image_ids = self.prepare_latents(
1303
- batch_size * num_images_per_prompt,
1304
- num_channels_latents,
1305
- height,
1306
- width,
1307
- prompt_embeds.dtype,
1308
- device,
1309
- generator,
1310
- latents,
1311
- )
1312
-
1313
- # 5. Prepare timesteps
1314
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
1315
- image_seq_len = latents.shape[1]
1316
- mu = calculate_shift(
1317
- image_seq_len,
1318
- self.scheduler.config.base_image_seq_len,
1319
- self.scheduler.config.max_image_seq_len,
1320
- self.scheduler.config.base_shift,
1321
- self.scheduler.config.max_shift,
1322
- )
1323
- timesteps, num_inference_steps = retrieve_timesteps(
1324
- self.scheduler,
1325
- num_inference_steps,
1326
- device,
1327
- timesteps,
1328
- sigmas,
1329
- mu=mu,
1330
- )
1331
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1332
- self._num_timesteps = len(timesteps)
1333
-
1334
- # 6. Denoising loop
1335
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1336
- for i, t in enumerate(timesteps):
1337
- if self.interrupt:
1338
- continue
1339
-
1340
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1341
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
1342
-
1343
- # handle guidance
1344
- if self.transformer.config.guidance_embeds:
1345
- guidance = torch.tensor([guidance_scale], device=device)
1346
- guidance = guidance.expand(latents.shape[0])
1347
- else:
1348
- guidance = None
1349
-
1350
- noise_pred_text = self.transformer(
1351
- hidden_states=latents,
1352
- timestep=timestep / 1000,
1353
- guidance=guidance,
1354
- pooled_projections=pooled_prompt_embeds,
1355
- encoder_hidden_states=prompt_embeds,
1356
- txt_ids=text_ids,
1357
- img_ids=latent_image_ids,
1358
- joint_attention_kwargs=self.joint_attention_kwargs,
1359
- return_dict=False,
1360
- )[0]
1361
-
1362
- # todo combine these
1363
- noise_pred_uncond = self.transformer(
1364
- hidden_states=latents,
1365
- timestep=timestep / 1000,
1366
- guidance=guidance,
1367
- pooled_projections=negative_pooled_prompt_embeds,
1368
- encoder_hidden_states=negative_prompt_embeds,
1369
- txt_ids=negative_text_ids,
1370
- img_ids=latent_image_ids,
1371
- joint_attention_kwargs=self.joint_attention_kwargs,
1372
- return_dict=False,
1373
- )[0]
1374
-
1375
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1376
-
1377
- # compute the previous noisy sample x_t -> x_t-1
1378
- latents_dtype = latents.dtype
1379
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1380
-
1381
- if latents.dtype != latents_dtype:
1382
- if torch.backends.mps.is_available():
1383
- # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1384
- latents = latents.to(latents_dtype)
1385
-
1386
- if callback_on_step_end is not None:
1387
- callback_kwargs = {}
1388
- for k in callback_on_step_end_tensor_inputs:
1389
- callback_kwargs[k] = locals()[k]
1390
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1391
-
1392
- latents = callback_outputs.pop("latents", latents)
1393
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1394
-
1395
- # call the callback, if provided
1396
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1397
- progress_bar.update()
1398
-
1399
- if XLA_AVAILABLE:
1400
- xm.mark_step()
1401
-
1402
- if output_type == "latent":
1403
- image = latents
1404
-
1405
- else:
1406
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1407
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1408
- image = self.vae.decode(latents, return_dict=False)[0]
1409
- image = self.image_processor.postprocess(image, output_type=output_type)
1410
-
1411
- # Offload all models
1412
- self.maybe_free_model_hooks()
1413
-
1414
- if not return_dict:
1415
- return (image,)
1416
-
1417
- return FluxPipelineOutput(images=image)
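A similarly hedged sketch of how `FluxWithCFGPipeline` might be driven (an editorial addition, not part of the deleted file). It assumes the module is importable as `pipelines` and that the subclass is loaded through the `from_pretrained` it inherits from `FluxPipeline`; the checkpoint and settings are illustrative. Unlike the stock pipeline, the class above runs a second, negative-prompt transformer pass each step and mixes the two predictions with classical classifier-free guidance, which is why it accepts `negative_prompt`.

import torch
from pipelines import FluxWithCFGPipeline  # the class defined above (assumed module name)

# Illustrative checkpoint; any FLUX checkpoint loadable by FluxPipeline should work here.
pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="a cinematic photo of a red fox in the snow",
    negative_prompt="overexposed, washed out, text, watermark",
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,  # feeds the distilled guidance embedding and the true-CFG mix above
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]

As the `# todo combine these` comment notes, the conditional and unconditional passes currently run back to back; stacking them along the batch dimension (as the SDXL code above does with `torch.cat` before a single UNet call) would halve the number of transformer invocations per step at the cost of higher peak memory.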
 
stable_diffusion_model.py DELETED
The diff for this file is too large to render. See raw diff