gokaygokay committed
Commit 1d20a91 · 1 Parent(s): 3880b98
Files changed (2)
  1. app.py +35 -30
  2. pipelines.py +1417 -0
app.py CHANGED
@@ -3,37 +3,41 @@ import numpy as np
import random
import spaces
import torch
- from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
- from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
- from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
+ from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from huggingface_hub import hf_hub_download
from optimum.quanto import freeze, qfloat8, quantize
+ from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
import os

+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 2048
+ # Set up environment variables and device
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Load VAE models
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
- good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype, token=huggingface_token).to(device)
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1, token=huggingface_token).to(device)
- torch.cuda.empty_cache()
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 2048
+ good_vae = AutoencoderKL.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="vae",
+ torch_dtype=dtype,
+ token=huggingface_token
+ ).to(device)

- pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
+ # Initialize FluxPipeline instead of DiffusionPipeline
+ from pipelines import FluxPipeline

- # Load base model first (before quantization)
- pipe = DiffusionPipeline.from_pretrained(
- "black-forest-labs/FLUX.1-dev",
- torch_dtype=dtype,
- vae=taef1,
+ pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.float32, # Load in full precision initially
+ vae=taef1,
token=huggingface_token
- )
+ ).to(device)

# Load and fuse LoRA BEFORE quantizing
- print('Loading and fusing lora, please wait...')
+ print('Loading and fusing LoRA, please wait...')
lora_path = hf_hub_download("gokaygokay/Flux-Game-Assets-LoRA-v2", "game_asst.safetensors")
pipe.load_lora_weights(lora_path)
pipe.fuse_lora(lora_scale=0.125)
@@ -43,12 +47,14 @@ pipe.unload_lora_weights()
print("Quantizing transformer")
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)
- pipe.transformer.to(device)

- # Quantize T5 encoder
- print("Quantizing T5")
+ # Quantize the T5 text encoder
+ print("Quantizing T5 text encoder")
quantize(pipe.text_encoder_2, weights=qfloat8)
freeze(pipe.text_encoder_2)
+
+ # Move quantized components to device (if not already)
+ pipe.transformer.to(device)
pipe.text_encoder_2.to(device)

# Move other components to device
@@ -72,14 +78,14 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidan
good_vae=good_vae,
):
yield img, seed
-
+
examples = [
"wbgmsst, a cat, white background",
"wbgmsst, a warrior, white background",
"wbgmsst, an anime girl, white background",
]

- css="""
+ css = """
#col-container {
margin: 0 auto;
max-width: 520px;
@@ -139,7 +145,6 @@ with gr.Blocks(css=css) as demo:
)

with gr.Row():
-
guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=1,
@@ -157,18 +162,18 @@
)

gr.Examples(
- examples = examples,
- fn = infer,
- inputs = [prompt],
- outputs = [result, seed],
+ examples=examples,
+ fn=infer,
+ inputs=[prompt],
+ outputs=[result, seed],
cache_examples="lazy"
)

gr.on(
triggers=[run_button.click, prompt.submit],
- fn = infer,
- inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
- outputs = [result, seed]
+ fn=infer,
+ inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+ outputs=[result, seed]
)

demo.launch()
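Note on the ordering in this diff: LoRA weights can only be fused while the transformer's linear layers are still plain floating-point modules, which is why the commit fuses (and unloads) the LoRA before handing the transformer and the T5 encoder to optimum.quanto. Below is a minimal sketch of that fuse-then-quantize pattern, assuming the same FLUX.1-dev checkpoint and LoRA file as the diff; it is illustrative only and not part of the commit.

import torch
from diffusers import FluxPipeline
from huggingface_hub import hf_hub_download
from optimum.quanto import freeze, qfloat8, quantize

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# 1. Fuse the LoRA while the weights are still ordinary nn.Linear layers.
lora_path = hf_hub_download("gokaygokay/Flux-Game-Assets-LoRA-v2", "game_asst.safetensors")
pipe.load_lora_weights(lora_path)
pipe.fuse_lora(lora_scale=0.125)
pipe.unload_lora_weights()

# 2. Only then quantize and freeze, so the fused weights are what gets converted to qfloat8.
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)
quantize(pipe.text_encoder_2, weights=qfloat8)
freeze(pipe.text_encoder_2)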
pipelines.py ADDED
@@ -0,0 +1,1417 @@
1
+ import importlib
2
+ import inspect
3
+ from typing import Union, List, Optional, Dict, Any, Tuple, Callable
4
+
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline
8
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
9
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
10
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
11
+ # from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
12
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
13
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
14
+ from diffusers.utils import is_torch_xla_available
15
+ from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
16
+ from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
17
+
18
+
19
+ if is_torch_xla_available():
20
+ import torch_xla.core.xla_model as xm
21
+
22
+ XLA_AVAILABLE = True
23
+ else:
24
+ XLA_AVAILABLE = False
25
+
26
+ class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
27
+
28
+ def __init__(
29
+ self,
30
+ vae: 'AutoencoderKL',
31
+ text_encoder: 'CLIPTextModel',
32
+ text_encoder_2: 'CLIPTextModelWithProjection',
33
+ tokenizer: 'CLIPTokenizer',
34
+ tokenizer_2: 'CLIPTokenizer',
35
+ unet: 'UNet2DConditionModel',
36
+ scheduler: 'KarrasDiffusionSchedulers',
37
+ force_zeros_for_empty_prompt: bool = True,
38
+ add_watermarker: Optional[bool] = None,
39
+ ):
40
+ super().__init__(
41
+ vae=vae,
42
+ text_encoder=text_encoder,
43
+ text_encoder_2=text_encoder_2,
44
+ tokenizer=tokenizer,
45
+ tokenizer_2=tokenizer_2,
46
+ unet=unet,
47
+ scheduler=scheduler,
48
+ )
49
+ raise NotImplementedError("This pipeline is not implemented yet")
50
+ # self.sampler = None
51
+ # scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
52
+ # model = ModelWrapper(unet, scheduler.alphas_cumprod)
53
+ # if scheduler.config.prediction_type == "v_prediction":
54
+ # self.k_diffusion_model = CompVisVDenoiser(model)
55
+ # else:
56
+ # self.k_diffusion_model = CompVisDenoiser(model)
57
+
58
+ def set_scheduler(self, scheduler_type: str):
59
+ library = importlib.import_module("k_diffusion")
60
+ sampling = getattr(library, "sampling")
61
+ self.sampler = getattr(sampling, scheduler_type)
62
+
63
+ @torch.no_grad()
64
+ def __call__(
65
+ self,
66
+ prompt: Union[str, List[str]] = None,
67
+ prompt_2: Optional[Union[str, List[str]]] = None,
68
+ height: Optional[int] = None,
69
+ width: Optional[int] = None,
70
+ num_inference_steps: int = 50,
71
+ denoising_end: Optional[float] = None,
72
+ guidance_scale: float = 5.0,
73
+ negative_prompt: Optional[Union[str, List[str]]] = None,
74
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
75
+ num_images_per_prompt: Optional[int] = 1,
76
+ eta: float = 0.0,
77
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
78
+ latents: Optional[torch.FloatTensor] = None,
79
+ prompt_embeds: Optional[torch.FloatTensor] = None,
80
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
81
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
82
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
83
+ output_type: Optional[str] = "pil",
84
+ return_dict: bool = True,
85
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
86
+ callback_steps: int = 1,
87
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
88
+ guidance_rescale: float = 0.0,
89
+ original_size: Optional[Tuple[int, int]] = None,
90
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
91
+ target_size: Optional[Tuple[int, int]] = None,
92
+ use_karras_sigmas: bool = False,
93
+ ):
94
+
95
+ # 0. Default height and width to unet
96
+ height = height or self.default_sample_size * self.vae_scale_factor
97
+ width = width or self.default_sample_size * self.vae_scale_factor
98
+
99
+ original_size = original_size or (height, width)
100
+ target_size = target_size or (height, width)
101
+
102
+ # 1. Check inputs. Raise error if not correct
103
+ self.check_inputs(
104
+ prompt,
105
+ prompt_2,
106
+ height,
107
+ width,
108
+ callback_steps,
109
+ negative_prompt,
110
+ negative_prompt_2,
111
+ prompt_embeds,
112
+ negative_prompt_embeds,
113
+ pooled_prompt_embeds,
114
+ negative_pooled_prompt_embeds,
115
+ )
116
+
117
+ # 2. Define call parameters
118
+ if prompt is not None and isinstance(prompt, str):
119
+ batch_size = 1
120
+ elif prompt is not None and isinstance(prompt, list):
121
+ batch_size = len(prompt)
122
+ else:
123
+ batch_size = prompt_embeds.shape[0]
124
+
125
+ device = self._execution_device
126
+
127
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
128
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
129
+ # corresponds to doing no classifier free guidance.
130
+ do_classifier_free_guidance = guidance_scale > 1.0
131
+
132
+ # 3. Encode input prompt
133
+ text_encoder_lora_scale = (
134
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
135
+ )
136
+ (
137
+ prompt_embeds,
138
+ negative_prompt_embeds,
139
+ pooled_prompt_embeds,
140
+ negative_pooled_prompt_embeds,
141
+ ) = self.encode_prompt(
142
+ prompt=prompt,
143
+ prompt_2=prompt_2,
144
+ device=device,
145
+ num_images_per_prompt=num_images_per_prompt,
146
+ do_classifier_free_guidance=do_classifier_free_guidance,
147
+ negative_prompt=negative_prompt,
148
+ negative_prompt_2=negative_prompt_2,
149
+ prompt_embeds=prompt_embeds,
150
+ negative_prompt_embeds=negative_prompt_embeds,
151
+ pooled_prompt_embeds=pooled_prompt_embeds,
152
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
153
+ lora_scale=text_encoder_lora_scale,
154
+ )
155
+
156
+ # 4. Prepare timesteps
157
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
158
+
159
+ timesteps = self.scheduler.timesteps
160
+
161
+ # 5. Prepare latent variables
162
+ num_channels_latents = self.unet.config.in_channels
163
+ latents = self.prepare_latents(
164
+ batch_size * num_images_per_prompt,
165
+ num_channels_latents,
166
+ height,
167
+ width,
168
+ prompt_embeds.dtype,
169
+ device,
170
+ generator,
171
+ latents,
172
+ )
173
+
174
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
175
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
176
+
177
+ # 7. Prepare added time ids & embeddings
178
+ add_text_embeds = pooled_prompt_embeds
179
+ add_time_ids = self._get_add_time_ids(
180
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
181
+ )
182
+
183
+ if do_classifier_free_guidance:
184
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
185
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
186
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
187
+
188
+ prompt_embeds = prompt_embeds.to(device)
189
+ add_text_embeds = add_text_embeds.to(device)
190
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
191
+
192
+ # 8. Denoising loop
193
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
194
+
195
+ # 7.1 Apply denoising_end
196
+ if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
197
+ discrete_timestep_cutoff = int(
198
+ round(
199
+ self.scheduler.config.num_train_timesteps
200
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
201
+ )
202
+ )
203
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
204
+ timesteps = timesteps[:num_inference_steps]
205
+
206
+ # 5. Prepare sigmas
207
+ if use_karras_sigmas:
208
+ sigma_min: float = self.k_diffusion_model.sigmas[0].item()
209
+ sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
210
+ sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
211
+ sigmas = sigmas.to(device)
212
+ else:
213
+ sigmas = self.scheduler.sigmas
214
+ sigmas = sigmas.to(prompt_embeds.dtype)
215
+
216
+ # 5. Prepare latent variables
217
+ num_channels_latents = self.unet.config.in_channels
218
+ latents = self.prepare_latents(
219
+ batch_size * num_images_per_prompt,
220
+ num_channels_latents,
221
+ height,
222
+ width,
223
+ prompt_embeds.dtype,
224
+ device,
225
+ generator,
226
+ latents,
227
+ )
228
+
229
+ latents = latents * sigmas[0]
230
+ self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
231
+ self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
232
+
233
+ # 7. Define model function
234
+ def model_fn(x, t):
235
+ latent_model_input = torch.cat([x] * 2)
236
+ t = torch.cat([t] * 2)
237
+
238
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
239
+ # noise_pred = self.unet(
240
+ # latent_model_input,
241
+ # t,
242
+ # encoder_hidden_states=prompt_embeds,
243
+ # cross_attention_kwargs=cross_attention_kwargs,
244
+ # added_cond_kwargs=added_cond_kwargs,
245
+ # return_dict=False,
246
+ # )[0]
247
+
248
+ noise_pred = self.k_diffusion_model(
249
+ latent_model_input,
250
+ t,
251
+ encoder_hidden_states=prompt_embeds,
252
+ cross_attention_kwargs=cross_attention_kwargs,
253
+ added_cond_kwargs=added_cond_kwargs,
254
+ return_dict=False,)[0]
255
+
256
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
257
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
258
+ return noise_pred
259
+
260
+
261
+ # 8. Run k-diffusion solver
262
+ sampler_kwargs = {}
263
+ # should work without it
264
+ noise_sampler_seed = None
265
+
266
+
267
+ if "noise_sampler" in inspect.signature(self.sampler).parameters:
268
+ min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
269
+ noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
270
+ sampler_kwargs["noise_sampler"] = noise_sampler
271
+
272
+ latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
273
+
274
+ if not output_type == "latent":
275
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
276
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
277
+ else:
278
+ image = latents
279
+ has_nsfw_concept = None
280
+
281
+ if has_nsfw_concept is None:
282
+ do_denormalize = [True] * image.shape[0]
283
+ else:
284
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
285
+
286
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
287
+
288
+ # Offload last model to CPU
289
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
290
+ self.final_offload_hook.offload()
291
+
292
+ if not return_dict:
293
+ return (image,)
294
+
295
+ return StableDiffusionXLPipelineOutput(images=image)
296
+
297
+
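Reviewer note: the class above raises NotImplementedError in __init__, but the sampling contract it sketches is the standard k-diffusion one: set_scheduler looks a sampler up by name in k_diffusion.sampling, and the sampler drives a model_fn(x, sigma) callable over a sigma schedule. A self-contained toy illustration of that contract follows; the denoiser below is a stand-in, not the SDXL UNet or CompVisDenoiser.

import torch
from k_diffusion.sampling import get_sigmas_karras, sample_euler

def toy_denoiser(x, sigma):
    # stand-in for model_fn: a real pipeline would run the UNet here and apply CFG
    return x / (1.0 + sigma.reshape(-1, 1, 1, 1))

sigmas = get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=14.6)  # Karras schedule, as when use_karras_sigmas=True
x = torch.randn(1, 4, 64, 64) * sigmas[0]                         # latents are scaled by the first sigma, as above
denoised = sample_euler(toy_denoiser, x, sigmas)                  # mirrors latents = self.sampler(model_fn, latents, sigmas)

Samplers whose signature accepts a noise_sampler argument (the SDE variants) additionally get a BrownianTreeNoiseSampler, which is what the inspect.signature check in the class above handles.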
298
+ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
299
+
300
+ def predict_noise(
301
+ self,
302
+ prompt: Union[str, List[str]] = None,
303
+ prompt_2: Optional[Union[str, List[str]]] = None,
304
+ num_inference_steps: int = 50,
305
+ guidance_scale: float = 5.0,
306
+ negative_prompt: Optional[Union[str, List[str]]] = None,
307
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
308
+ num_images_per_prompt: Optional[int] = 1,
309
+ eta: float = 0.0,
310
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
311
+ latents: Optional[torch.FloatTensor] = None,
312
+ prompt_embeds: Optional[torch.FloatTensor] = None,
313
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
314
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
315
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
316
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
317
+ guidance_rescale: float = 0.0,
318
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
319
+ timestep: Optional[int] = None,
320
+ ):
321
+ r"""
322
+ Function invoked when calling the pipeline for generation.
323
+
324
+ Args:
325
+ prompt (`str` or `List[str]`, *optional*):
326
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
327
+ instead.
328
+ prompt_2 (`str` or `List[str]`, *optional*):
329
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
330
+ used in both text-encoders
331
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
332
+ The height in pixels of the generated image.
333
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
334
+ The width in pixels of the generated image.
335
+ num_inference_steps (`int`, *optional*, defaults to 50):
336
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
337
+ expense of slower inference.
338
+ denoising_end (`float`, *optional*):
339
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
340
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
341
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
342
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
343
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
344
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
345
+ guidance_scale (`float`, *optional*, defaults to 7.5):
346
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
347
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
348
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
349
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
350
+ usually at the expense of lower image quality.
351
+ negative_prompt (`str` or `List[str]`, *optional*):
352
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
353
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
354
+ less than `1`).
355
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
356
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
357
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
358
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
359
+ The number of images to generate per prompt.
360
+ eta (`float`, *optional*, defaults to 0.0):
361
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
362
+ [`schedulers.DDIMScheduler`], will be ignored for others.
363
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
364
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
365
+ to make generation deterministic.
366
+ latents (`torch.FloatTensor`, *optional*):
367
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
368
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
369
+ tensor will be generated by sampling using the supplied random `generator`.
370
+ prompt_embeds (`torch.FloatTensor`, *optional*):
371
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
372
+ provided, text embeddings will be generated from `prompt` input argument.
373
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
374
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
375
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
376
+ argument.
377
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
378
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
379
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
380
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
381
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
382
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
383
+ input argument.
384
+ output_type (`str`, *optional*, defaults to `"pil"`):
385
+ The output format of the generated image. Choose between
386
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
387
+ return_dict (`bool`, *optional*, defaults to `True`):
388
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
389
+ of a plain tuple.
390
+ callback (`Callable`, *optional*):
391
+ A function that will be called every `callback_steps` steps during inference. The function will be
392
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
393
+ callback_steps (`int`, *optional*, defaults to 1):
394
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
395
+ called at every step.
396
+ cross_attention_kwargs (`dict`, *optional*):
397
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
398
+ `self.processor` in
399
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
400
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
401
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
402
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
403
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
404
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
405
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
406
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
407
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
408
+ explained in section 2.2 of
409
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
410
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
411
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
412
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
413
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
414
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
415
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
416
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
417
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
418
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
419
+
420
+ Examples:
421
+
422
+ Returns:
423
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
424
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
425
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
426
+ """
427
+ # if not predict_noise:
428
+ # # call parent
429
+ # return super().__call__(
430
+ # prompt=prompt,
431
+ # prompt_2=prompt_2,
432
+ # height=height,
433
+ # width=width,
434
+ # num_inference_steps=num_inference_steps,
435
+ # denoising_end=denoising_end,
436
+ # guidance_scale=guidance_scale,
437
+ # negative_prompt=negative_prompt,
438
+ # negative_prompt_2=negative_prompt_2,
439
+ # num_images_per_prompt=num_images_per_prompt,
440
+ # eta=eta,
441
+ # generator=generator,
442
+ # latents=latents,
443
+ # prompt_embeds=prompt_embeds,
444
+ # negative_prompt_embeds=negative_prompt_embeds,
445
+ # pooled_prompt_embeds=pooled_prompt_embeds,
446
+ # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
447
+ # output_type=output_type,
448
+ # return_dict=return_dict,
449
+ # callback=callback,
450
+ # callback_steps=callback_steps,
451
+ # cross_attention_kwargs=cross_attention_kwargs,
452
+ # guidance_rescale=guidance_rescale,
453
+ # original_size=original_size,
454
+ # crops_coords_top_left=crops_coords_top_left,
455
+ # target_size=target_size,
456
+ # )
457
+
458
+ # 0. Default height and width to unet
459
+ height = self.default_sample_size * self.vae_scale_factor
460
+ width = self.default_sample_size * self.vae_scale_factor
461
+
462
+ original_size = (height, width)
463
+ target_size = (height, width)
464
+
465
+ # 2. Define call parameters
466
+ if prompt is not None and isinstance(prompt, str):
467
+ batch_size = 1
468
+ elif prompt is not None and isinstance(prompt, list):
469
+ batch_size = len(prompt)
470
+ else:
471
+ batch_size = prompt_embeds.shape[0]
472
+
473
+ device = self._execution_device
474
+
475
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
476
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
477
+ # corresponds to doing no classifier free guidance.
478
+ do_classifier_free_guidance = guidance_scale > 1.0
479
+
480
+ # 3. Encode input prompt
481
+ text_encoder_lora_scale = (
482
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
483
+ )
484
+ (
485
+ prompt_embeds,
486
+ negative_prompt_embeds,
487
+ pooled_prompt_embeds,
488
+ negative_pooled_prompt_embeds,
489
+ ) = self.encode_prompt(
490
+ prompt=prompt,
491
+ prompt_2=prompt_2,
492
+ device=device,
493
+ num_images_per_prompt=num_images_per_prompt,
494
+ do_classifier_free_guidance=do_classifier_free_guidance,
495
+ negative_prompt=negative_prompt,
496
+ negative_prompt_2=negative_prompt_2,
497
+ prompt_embeds=prompt_embeds,
498
+ negative_prompt_embeds=negative_prompt_embeds,
499
+ pooled_prompt_embeds=pooled_prompt_embeds,
500
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
501
+ lora_scale=text_encoder_lora_scale,
502
+ )
503
+
504
+ # 4. Prepare timesteps
505
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
506
+
507
+ # 5. Prepare latent variables
508
+ num_channels_latents = self.unet.config.in_channels
509
+ latents = self.prepare_latents(
510
+ batch_size * num_images_per_prompt,
511
+ num_channels_latents,
512
+ height,
513
+ width,
514
+ prompt_embeds.dtype,
515
+ device,
516
+ generator,
517
+ latents,
518
+ )
519
+
520
+ # 7. Prepare added time ids & embeddings
521
+ add_text_embeds = pooled_prompt_embeds
522
+ add_time_ids = self._get_add_time_ids(
523
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
524
+ ).to(device) # TODO DOES NOT CAST ORIGINALLY
525
+
526
+ if do_classifier_free_guidance:
527
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
528
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
529
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
530
+
531
+ prompt_embeds = prompt_embeds.to(device)
532
+ add_text_embeds = add_text_embeds.to(device)
533
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
534
+
535
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
536
+
537
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
538
+
539
+ # predict the noise residual
540
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
541
+ noise_pred = self.unet(
542
+ latent_model_input,
543
+ timestep,
544
+ encoder_hidden_states=prompt_embeds,
545
+ cross_attention_kwargs=cross_attention_kwargs,
546
+ added_cond_kwargs=added_cond_kwargs,
547
+ return_dict=False,
548
+ )[0]
549
+
550
+ # perform guidance
551
+ if do_classifier_free_guidance:
552
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
553
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
554
+
555
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
556
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
557
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
558
+
559
+ return noise_pred
560
+
561
+ def enable_model_cpu_offload(self, gpu_id=0):
562
+ print('Called cpu offload', gpu_id)
563
+ # intentionally a no-op: keep all components on the GPU
564
+ pass
565
+
566
+
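The predict_noise method above ends with the two guidance steps used throughout this file: the classifier-free-guidance combination of the unconditional and text-conditioned predictions, and the optional std rescale from "Common Diffusion Noise Schedules and Sample Steps are Flawed" (Sec. 3.4). A minimal standalone re-derivation of what those lines and diffusers' rescale_noise_cfg compute (illustrative, not a call into the pipeline):

import torch

def combine_cfg(noise_pred: torch.Tensor, guidance_scale: float, guidance_rescale: float = 0.0):
    # noise_pred stacks [unconditional, text-conditioned] predictions along dim 0
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    if guidance_rescale > 0.0:
        # rescale so the guided prediction keeps the std of the text-conditioned one
        std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
        std_guided = guided.std(dim=list(range(1, guided.ndim)), keepdim=True)
        rescaled = guided * (std_text / std_guided)
        guided = guidance_rescale * rescaled + (1 - guidance_rescale) * guided
    return guided

noise = torch.randn(2, 4, 128, 128)  # stacked uncond / cond predictions
out = combine_cfg(noise, guidance_scale=5.0, guidance_rescale=0.7)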
567
+ class CustomStableDiffusionPipeline(StableDiffusionPipeline):
568
+
569
+ # override __call__ to match the SDXL call signature so the same driver code can be shared, and to allow stopping early
570
+ def __call__(
571
+ self,
572
+ prompt: Union[str, List[str]] = None,
573
+ prompt_2: Optional[Union[str, List[str]]] = None,
574
+ height: Optional[int] = None,
575
+ width: Optional[int] = None,
576
+ num_inference_steps: int = 50,
577
+ denoising_end: Optional[float] = None,
578
+ guidance_scale: float = 5.0,
579
+ negative_prompt: Optional[Union[str, List[str]]] = None,
580
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
581
+ num_images_per_prompt: Optional[int] = 1,
582
+ eta: float = 0.0,
583
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
584
+ latents: Optional[torch.FloatTensor] = None,
585
+ prompt_embeds: Optional[torch.FloatTensor] = None,
586
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
587
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
588
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
589
+ output_type: Optional[str] = "pil",
590
+ return_dict: bool = True,
591
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
592
+ callback_steps: int = 1,
593
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
594
+ guidance_rescale: float = 0.0,
595
+ original_size: Optional[Tuple[int, int]] = None,
596
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
597
+ target_size: Optional[Tuple[int, int]] = None,
598
+ ):
599
+ # 0. Default height and width to unet
600
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
601
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
602
+
603
+ # 1. Check inputs. Raise error if not correct
604
+ self.check_inputs(
605
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
606
+ )
607
+
608
+ # 2. Define call parameters
609
+ if prompt is not None and isinstance(prompt, str):
610
+ batch_size = 1
611
+ elif prompt is not None and isinstance(prompt, list):
612
+ batch_size = len(prompt)
613
+ else:
614
+ batch_size = prompt_embeds.shape[0]
615
+
616
+ device = self._execution_device
617
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
618
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
619
+ # corresponds to doing no classifier free guidance.
620
+ do_classifier_free_guidance = guidance_scale > 1.0
621
+
622
+ # 3. Encode input prompt
623
+ text_encoder_lora_scale = (
624
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
625
+ )
626
+ prompt_embeds = self._encode_prompt(
627
+ prompt,
628
+ device,
629
+ num_images_per_prompt,
630
+ do_classifier_free_guidance,
631
+ negative_prompt,
632
+ prompt_embeds=prompt_embeds,
633
+ negative_prompt_embeds=negative_prompt_embeds,
634
+ lora_scale=text_encoder_lora_scale,
635
+ )
636
+
637
+ # 4. Prepare timesteps
638
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
639
+ timesteps = self.scheduler.timesteps
640
+
641
+ # 5. Prepare latent variables
642
+ num_channels_latents = self.unet.config.in_channels
643
+ latents = self.prepare_latents(
644
+ batch_size * num_images_per_prompt,
645
+ num_channels_latents,
646
+ height,
647
+ width,
648
+ prompt_embeds.dtype,
649
+ device,
650
+ generator,
651
+ latents,
652
+ )
653
+
654
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
655
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
656
+
657
+ # 7. Denoising loop
658
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
659
+
660
+ # 7.1 Apply denoising_end
661
+ if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
662
+ discrete_timestep_cutoff = int(
663
+ round(
664
+ self.scheduler.config.num_train_timesteps
665
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
666
+ )
667
+ )
668
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
669
+ timesteps = timesteps[:num_inference_steps]
670
+
671
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
672
+ for i, t in enumerate(timesteps):
673
+ # expand the latents if we are doing classifier free guidance
674
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
675
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
676
+
677
+ # predict the noise residual
678
+ noise_pred = self.unet(
679
+ latent_model_input,
680
+ t,
681
+ encoder_hidden_states=prompt_embeds,
682
+ cross_attention_kwargs=cross_attention_kwargs,
683
+ return_dict=False,
684
+ )[0]
685
+
686
+ # perform guidance
687
+ if do_classifier_free_guidance:
688
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
689
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
690
+
691
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
692
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
693
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
694
+
695
+ # compute the previous noisy sample x_t -> x_t-1
696
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
697
+
698
+ # call the callback, if provided
699
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
700
+ progress_bar.update()
701
+ if callback is not None and i % callback_steps == 0:
702
+ callback(i, t, latents)
703
+
704
+ if not output_type == "latent":
705
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
706
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
707
+ else:
708
+ image = latents
709
+ has_nsfw_concept = None
710
+
711
+ if has_nsfw_concept is None:
712
+ do_denormalize = [True] * image.shape[0]
713
+ else:
714
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
715
+
716
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
717
+
718
+ # Offload last model to CPU
719
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
720
+ self.final_offload_hook.offload()
721
+
722
+ if not return_dict:
723
+ return (image, has_nsfw_concept)
724
+
725
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
726
+
727
+ # some of the inputs are only here to keep the signature compatible with SDXL
728
+ def predict_noise(
729
+ self,
730
+ prompt: Union[str, List[str]] = None,
731
+ prompt_2: Optional[Union[str, List[str]]] = None,
732
+ num_inference_steps: int = 50,
733
+ guidance_scale: float = 5.0,
734
+ negative_prompt: Optional[Union[str, List[str]]] = None,
735
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
736
+ num_images_per_prompt: Optional[int] = 1,
737
+ eta: float = 0.0,
738
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
739
+ latents: Optional[torch.FloatTensor] = None,
740
+ prompt_embeds: Optional[torch.FloatTensor] = None,
741
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
742
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
743
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
744
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
745
+ guidance_rescale: float = 0.0,
746
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
747
+ timestep: Optional[int] = None,
748
+ ):
749
+
750
+ # 0. Default height and width to unet
751
+ height = self.unet.config.sample_size * self.vae_scale_factor
752
+ width = self.unet.config.sample_size * self.vae_scale_factor
753
+
754
+ # 2. Define call parameters
755
+ if prompt is not None and isinstance(prompt, str):
756
+ batch_size = 1
757
+ elif prompt is not None and isinstance(prompt, list):
758
+ batch_size = len(prompt)
759
+ else:
760
+ batch_size = prompt_embeds.shape[0]
761
+
762
+ device = self._execution_device
763
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
764
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
765
+ # corresponds to doing no classifier free guidance.
766
+ do_classifier_free_guidance = guidance_scale > 1.0
767
+
768
+ # 3. Encode input prompt
769
+ text_encoder_lora_scale = (
770
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
771
+ )
772
+ prompt_embeds = self._encode_prompt(
773
+ prompt,
774
+ device,
775
+ num_images_per_prompt,
776
+ do_classifier_free_guidance,
777
+ negative_prompt,
778
+ prompt_embeds=prompt_embeds,
779
+ negative_prompt_embeds=negative_prompt_embeds,
780
+ lora_scale=text_encoder_lora_scale,
781
+ )
782
+
783
+ # 4. Prepare timesteps
784
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
785
+
786
+ # 5. Prepare latent variables
787
+ num_channels_latents = self.unet.config.in_channels
788
+ latents = self.prepare_latents(
789
+ batch_size * num_images_per_prompt,
790
+ num_channels_latents,
791
+ height,
792
+ width,
793
+ prompt_embeds.dtype,
794
+ device,
795
+ generator,
796
+ latents,
797
+ )
798
+
799
+ # expand the latents if we are doing classifier free guidance
800
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
801
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
802
+
803
+ # predict the noise residual
804
+ noise_pred = self.unet(
805
+ latent_model_input,
806
+ timestep,
807
+ encoder_hidden_states=prompt_embeds,
808
+ cross_attention_kwargs=cross_attention_kwargs,
809
+ return_dict=False,
810
+ )[0]
811
+
812
+ # perform guidance
813
+ if do_classifier_free_guidance:
814
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
815
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
816
+
817
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
818
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
819
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
820
+
821
+ return noise_pred
822
+
823
+
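Both __call__ implementations above stop early via the same denoising_end arithmetic: the fraction is converted to a discrete timestep cutoff and only timesteps at or above it are kept. A small worked sketch of that computation, assuming the usual 1000 training timesteps (values are illustrative):

import torch

num_train_timesteps = 1000                       # scheduler.config.num_train_timesteps for SD/SDXL
denoising_end = 0.8                              # stop after 80% of the denoising trajectory
timesteps = torch.linspace(999, 0, 50).long()    # stand-in for scheduler.timesteps

cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))  # -> 200
kept = [t for t in timesteps.tolist() if t >= cutoff]
num_inference_steps = len(kept)                  # roughly 80% of the 50 steps remain
timesteps = timesteps[:num_inference_steps]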
824
+ class StableDiffusionXLRefinerPipeline(StableDiffusionXLPipeline):
825
+
826
+ @torch.no_grad()
827
+ def __call__(
828
+ self,
829
+ prompt: Union[str, List[str]] = None,
830
+ prompt_2: Optional[Union[str, List[str]]] = None,
831
+ height: Optional[int] = None,
832
+ width: Optional[int] = None,
833
+ num_inference_steps: int = 50,
834
+ denoising_end: Optional[float] = None,
835
+ denoising_start: Optional[float] = None,
836
+ guidance_scale: float = 5.0,
837
+ negative_prompt: Optional[Union[str, List[str]]] = None,
838
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
839
+ num_images_per_prompt: Optional[int] = 1,
840
+ eta: float = 0.0,
841
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
842
+ latents: Optional[torch.FloatTensor] = None,
843
+ prompt_embeds: Optional[torch.FloatTensor] = None,
844
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
845
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
846
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
847
+ output_type: Optional[str] = "pil",
848
+ return_dict: bool = True,
849
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
850
+ callback_steps: int = 1,
851
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
852
+ guidance_rescale: float = 0.0,
853
+ original_size: Optional[Tuple[int, int]] = None,
854
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
855
+ target_size: Optional[Tuple[int, int]] = None,
856
+ negative_original_size: Optional[Tuple[int, int]] = None,
857
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
858
+ negative_target_size: Optional[Tuple[int, int]] = None,
859
+ clip_skip: Optional[int] = None,
860
+ ):
861
+ r"""
862
+ Function invoked when calling the pipeline for generation.
863
+
864
+ Args:
865
+ prompt (`str` or `List[str]`, *optional*):
866
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
867
+ instead.
868
+ prompt_2 (`str` or `List[str]`, *optional*):
869
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
870
+ used in both text-encoders
871
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
872
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
873
+ Anything below 512 pixels won't work well for
874
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
875
+ and checkpoints that are not specifically fine-tuned on low resolutions.
876
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
877
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
878
+ Anything below 512 pixels won't work well for
879
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
880
+ and checkpoints that are not specifically fine-tuned on low resolutions.
881
+ num_inference_steps (`int`, *optional*, defaults to 50):
882
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
883
+ expense of slower inference.
884
+ denoising_end (`float`, *optional*):
885
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
886
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
887
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
888
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
889
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
890
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
891
+ denoising_start (`float`, *optional*):
892
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
893
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
894
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
895
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
896
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
897
+ Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
898
+ guidance_scale (`float`, *optional*, defaults to 5.0):
899
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
900
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
901
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
902
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
903
+ usually at the expense of lower image quality.
904
+ negative_prompt (`str` or `List[str]`, *optional*):
905
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
906
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
907
+ less than `1`).
908
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
909
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
910
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
911
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
912
+ The number of images to generate per prompt.
913
+ eta (`float`, *optional*, defaults to 0.0):
914
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
915
+ [`schedulers.DDIMScheduler`], will be ignored for others.
916
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
917
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
918
+ to make generation deterministic.
919
+ latents (`torch.FloatTensor`, *optional*):
920
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
921
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
922
+ tensor will be generated by sampling using the supplied random `generator`.
923
+ prompt_embeds (`torch.FloatTensor`, *optional*):
924
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
925
+ provided, text embeddings will be generated from `prompt` input argument.
926
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
927
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
928
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
929
+ argument.
930
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
931
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
932
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
933
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
934
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
935
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
936
+ input argument.
937
+ output_type (`str`, *optional*, defaults to `"pil"`):
938
+ The output format of the generated image. Choose between
939
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
940
+ return_dict (`bool`, *optional*, defaults to `True`):
941
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
942
+ of a plain tuple.
943
+ callback (`Callable`, *optional*):
944
+ A function that will be called every `callback_steps` steps during inference. The function will be
945
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
946
+ callback_steps (`int`, *optional*, defaults to 1):
947
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
948
+ called at every step.
949
+ cross_attention_kwargs (`dict`, *optional*):
950
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
951
+ `self.processor` in
952
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
953
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
954
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
955
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
956
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
957
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
958
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
959
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
960
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
961
+ explained in section 2.2 of
962
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
963
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
964
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
965
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
966
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
967
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
968
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
969
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
970
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
971
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
972
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
973
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
974
+ micro-conditioning as explained in section 2.2 of
975
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
976
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
977
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
978
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
979
+ micro-conditioning as explained in section 2.2 of
980
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
981
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
982
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
983
+ To negatively condition the generation process based on a target image resolution. It should be the same
984
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
985
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
986
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
987
+
988
+ Examples:
989
+
990
+ Returns:
991
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
992
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
993
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
994
+ """
995
+ # 0. Default height and width to unet
996
+ height = height or self.default_sample_size * self.vae_scale_factor
997
+ width = width or self.default_sample_size * self.vae_scale_factor
998
+
999
+ original_size = original_size or (height, width)
1000
+ target_size = target_size or (height, width)
1001
+
1002
+ # 1. Check inputs. Raise error if not correct
1003
+ self.check_inputs(
1004
+ prompt,
1005
+ prompt_2,
1006
+ height,
1007
+ width,
1008
+ callback_steps,
1009
+ negative_prompt,
1010
+ negative_prompt_2,
1011
+ prompt_embeds,
1012
+ negative_prompt_embeds,
1013
+ pooled_prompt_embeds,
1014
+ negative_pooled_prompt_embeds,
1015
+ )
1016
+
1017
+ # 2. Define call parameters
1018
+ if prompt is not None and isinstance(prompt, str):
1019
+ batch_size = 1
1020
+ elif prompt is not None and isinstance(prompt, list):
1021
+ batch_size = len(prompt)
1022
+ else:
1023
+ batch_size = prompt_embeds.shape[0]
1024
+
1025
+ device = self._execution_device
1026
+
1027
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1028
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1029
+ # corresponds to doing no classifier free guidance.
1030
+ do_classifier_free_guidance = guidance_scale > 1.0
1031
+
1032
+ # 3. Encode input prompt
1033
+ lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
1034
+
1035
+ (
1036
+ prompt_embeds,
1037
+ negative_prompt_embeds,
1038
+ pooled_prompt_embeds,
1039
+ negative_pooled_prompt_embeds,
1040
+ ) = self.encode_prompt(
1041
+ prompt=prompt,
1042
+ prompt_2=prompt_2,
1043
+ device=device,
1044
+ num_images_per_prompt=num_images_per_prompt,
1045
+ do_classifier_free_guidance=do_classifier_free_guidance,
1046
+ negative_prompt=negative_prompt,
1047
+ negative_prompt_2=negative_prompt_2,
1048
+ prompt_embeds=prompt_embeds,
1049
+ negative_prompt_embeds=negative_prompt_embeds,
1050
+ pooled_prompt_embeds=pooled_prompt_embeds,
1051
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1052
+ lora_scale=lora_scale,
1053
+ clip_skip=clip_skip,
1054
+ )
1055
+
1056
+ # 4. Prepare timesteps
1057
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1058
+
1059
+ timesteps = self.scheduler.timesteps
1060
+
1061
+ # 5. Prepare latent variables
1062
+ num_channels_latents = self.unet.config.in_channels
1063
+ latents = self.prepare_latents(
1064
+ batch_size * num_images_per_prompt,
1065
+ num_channels_latents,
1066
+ height,
1067
+ width,
1068
+ prompt_embeds.dtype,
1069
+ device,
1070
+ generator,
1071
+ latents,
1072
+ )
1073
+
1074
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1075
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1076
+
1077
+ # 7. Prepare added time ids & embeddings
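+ # SDXL micro-conditioning: original_size, crop coordinates and target_size are packed into `add_time_ids`
+ # and fed to the UNet together with the pooled text embeddings through `added_cond_kwargs`.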
1078
+ add_text_embeds = pooled_prompt_embeds
1079
+ if self.text_encoder_2 is None:
1080
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1081
+ else:
1082
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1083
+
1084
+ add_time_ids = self._get_add_time_ids(
1085
+ original_size,
1086
+ crops_coords_top_left,
1087
+ target_size,
1088
+ dtype=prompt_embeds.dtype,
1089
+ text_encoder_projection_dim=text_encoder_projection_dim,
1090
+ )
1091
+ if negative_original_size is not None and negative_target_size is not None:
1092
+ negative_add_time_ids = self._get_add_time_ids(
1093
+ negative_original_size,
1094
+ negative_crops_coords_top_left,
1095
+ negative_target_size,
1096
+ dtype=prompt_embeds.dtype,
1097
+ text_encoder_projection_dim=text_encoder_projection_dim,
1098
+ )
1099
+ else:
1100
+ negative_add_time_ids = add_time_ids
1101
+
1102
+ if do_classifier_free_guidance:
1103
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1104
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1105
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1106
+
1107
+ prompt_embeds = prompt_embeds.to(device)
1108
+ add_text_embeds = add_text_embeds.to(device)
1109
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1110
+
1111
+ # 8. Denoising loop
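+ # `num_warmup_steps` keeps the progress bar and callback cadence in sync with schedulers whose
+ # order is greater than 1 and that therefore consume several timesteps per output step.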
1112
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1113
+
1114
+ # 8.1 Apply denoising_end
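+ # When `denoising_end` is a fraction in (0, 1), the timestep schedule is truncated so that only the first
+ # `denoising_end` portion of the steps runs here (e.g. leaving the remaining steps to a refiner pipeline).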
1115
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
1116
+ discrete_timestep_cutoff = int(
1117
+ round(
1118
+ self.scheduler.config.num_train_timesteps
1119
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
1120
+ )
1121
+ )
1122
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1123
+ timesteps = timesteps[:num_inference_steps]
1124
+
1125
+ # 8.2 Determine denoising_start
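+ # Conversely, `denoising_start` skips the first portion of the schedule, e.g. when the incoming latents
+ # were already partially denoised by another pipeline.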
1126
+ denoising_start_index = 0
1127
+ if denoising_start is not None and isinstance(denoising_start, float) and denoising_start > 0 and denoising_start < 1:
1128
+ discrete_timestep_start = int(
1129
+ round(
1130
+ self.scheduler.config.num_train_timesteps
1131
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
1132
+ )
1133
+ )
1134
+ denoising_start_index = len(timesteps) - len(list(filter(lambda ts: ts < discrete_timestep_start, timesteps)))
1135
+
1136
+
1137
+ with self.progress_bar(total=num_inference_steps - denoising_start_index) as progress_bar:
1138
+ for i, t in enumerate(timesteps[denoising_start_index:], start=denoising_start_index):
1139
+ # expand the latents if we are doing classifier free guidance
1140
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1141
+
1142
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1143
+
1144
+ # predict the noise residual
1145
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1146
+ noise_pred = self.unet(
1147
+ latent_model_input,
1148
+ t,
1149
+ encoder_hidden_states=prompt_embeds,
1150
+ cross_attention_kwargs=cross_attention_kwargs,
1151
+ added_cond_kwargs=added_cond_kwargs,
1152
+ return_dict=False,
1153
+ )[0]
1154
+
1155
+ # perform guidance
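+ # classifier-free guidance: push the prediction away from the unconditional output and towards the
+ # text-conditioned output, weighted by `guidance_scale`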
1156
+ if do_classifier_free_guidance:
1157
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1158
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1159
+
1160
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
1161
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1162
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1163
+
1164
+ # compute the previous noisy sample x_t -> x_t-1
1165
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1166
+
1167
+ # call the callback, if provided
1168
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1169
+ progress_bar.update()
1170
+ if callback is not None and i % callback_steps == 0:
1171
+ step_idx = i // getattr(self.scheduler, "order", 1)
1172
+ callback(step_idx, t, latents)
1173
+
1174
+ if XLA_AVAILABLE:
1175
+ xm.mark_step()
1176
+
1177
+ if not output_type == "latent":
1178
+ # make sure the VAE is in float32 mode, as it overflows in float16
1179
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1180
+
1181
+ if needs_upcasting:
1182
+ self.upcast_vae()
1183
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1184
+
1185
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1186
+
1187
+ # cast back to fp16 if needed
1188
+ if needs_upcasting:
1189
+ self.vae.to(dtype=torch.float16)
1190
+ else:
1191
+ image = latents
1192
+
1193
+ if not output_type == "latent":
1194
+ # apply watermark if available
1195
+ if self.watermark is not None:
1196
+ image = self.watermark.apply_watermark(image)
1197
+
1198
+ image = self.image_processor.postprocess(image, output_type=output_type)
1199
+
1200
+ # Offload all models
1201
+ self.maybe_free_model_hooks()
1202
+
1203
+ if not return_dict:
1204
+ return (image,)
1205
+
1206
+ return StableDiffusionXLPipelineOutput(images=image)
1207
+
1208
+
1209
+
1210
+
1211
+ # TODO this is rough. Need to properly stack unconditional
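+ # FLUX.1-dev is guidance-distilled and the stock FluxPipeline does not perform classifier-free guidance.
+ # This subclass adds explicit CFG by encoding a negative prompt and running a second, unconditional
+ # transformer pass per step; the TODO above refers to batching the two passes into a single forward call.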
1212
+ class FluxWithCFGPipeline(FluxPipeline):
1213
+ def __call__(
1214
+ self,
1215
+ prompt: Union[str, List[str]] = None,
1216
+ prompt_2: Optional[Union[str, List[str]]] = None,
1217
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1218
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1219
+ height: Optional[int] = None,
1220
+ width: Optional[int] = None,
1221
+ num_inference_steps: int = 28,
1222
+ timesteps: List[int] = None,
1223
+ guidance_scale: float = 7.0,
1224
+ num_images_per_prompt: Optional[int] = 1,
1225
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1226
+ latents: Optional[torch.FloatTensor] = None,
1227
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1228
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1229
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1230
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1231
+ output_type: Optional[str] = "pil",
1232
+ return_dict: bool = True,
1233
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1234
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1235
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1236
+ max_sequence_length: int = 512,
1237
+ ):
1238
+
1239
+ height = height or self.default_sample_size * self.vae_scale_factor
1240
+ width = width or self.default_sample_size * self.vae_scale_factor
1241
+
1242
+ # 1. Check inputs. Raise error if not correct
1243
+ self.check_inputs(
1244
+ prompt,
1245
+ prompt_2,
1246
+ height,
1247
+ width,
1248
+ prompt_embeds=prompt_embeds,
1249
+ pooled_prompt_embeds=pooled_prompt_embeds,
1250
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1251
+ max_sequence_length=max_sequence_length,
1252
+ )
1253
+
1254
+ self._guidance_scale = guidance_scale
1255
+ self._joint_attention_kwargs = joint_attention_kwargs
1256
+ self._interrupt = False
1257
+
1258
+ # 2. Define call parameters
1259
+ if prompt is not None and isinstance(prompt, str):
1260
+ batch_size = 1
1261
+ elif prompt is not None and isinstance(prompt, list):
1262
+ batch_size = len(prompt)
1263
+ else:
1264
+ batch_size = prompt_embeds.shape[0]
1265
+
1266
+ device = self._execution_device
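+ # 3. Encode prompts: the positive and negative prompts are encoded separately; the negative embeddings
+ # drive the unconditional branch of the CFG combination in the denoising loop.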
1267
+
1268
+ lora_scale = (
1269
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1270
+ )
1271
+ (
1272
+ prompt_embeds,
1273
+ pooled_prompt_embeds,
1274
+ text_ids,
1275
+ ) = self.encode_prompt(
1276
+ prompt=prompt,
1277
+ prompt_2=prompt_2,
1278
+ prompt_embeds=prompt_embeds,
1279
+ pooled_prompt_embeds=pooled_prompt_embeds,
1280
+ device=device,
1281
+ num_images_per_prompt=num_images_per_prompt,
1282
+ max_sequence_length=max_sequence_length,
1283
+ lora_scale=lora_scale,
1284
+ )
1285
+ (
1286
+ negative_prompt_embeds,
1287
+ negative_pooled_prompt_embeds,
1288
+ negative_text_ids,
1289
+ ) = self.encode_prompt(
1290
+ prompt=negative_prompt,
1291
+ prompt_2=negative_prompt_2,
1292
+ prompt_embeds=negative_prompt_embeds,
1293
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
1294
+ device=device,
1295
+ num_images_per_prompt=num_images_per_prompt,
1296
+ max_sequence_length=max_sequence_length,
1297
+ lora_scale=lora_scale,
1298
+ )
1299
+
1300
+ # 4. Prepare latent variables
1301
+ num_channels_latents = self.transformer.config.in_channels // 4
1302
+ latents, latent_image_ids = self.prepare_latents(
1303
+ batch_size * num_images_per_prompt,
1304
+ num_channels_latents,
1305
+ height,
1306
+ width,
1307
+ prompt_embeds.dtype,
1308
+ device,
1309
+ generator,
1310
+ latents,
1311
+ )
1312
+
1313
+ # 5. Prepare timesteps
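+ # `mu` is the dynamic shift for the flow-matching scheduler, interpolated from the latent sequence
+ # length (larger images receive a larger shift).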
1314
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
1315
+ image_seq_len = latents.shape[1]
1316
+ mu = calculate_shift(
1317
+ image_seq_len,
1318
+ self.scheduler.config.base_image_seq_len,
1319
+ self.scheduler.config.max_image_seq_len,
1320
+ self.scheduler.config.base_shift,
1321
+ self.scheduler.config.max_shift,
1322
+ )
1323
+ timesteps, num_inference_steps = retrieve_timesteps(
1324
+ self.scheduler,
1325
+ num_inference_steps,
1326
+ device,
1327
+ timesteps,
1328
+ sigmas,
1329
+ mu=mu,
1330
+ )
1331
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1332
+ self._num_timesteps = len(timesteps)
1333
+
1334
+ # 6. Denoising loop
1335
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1336
+ for i, t in enumerate(timesteps):
1337
+ if self.interrupt:
1338
+ continue
1339
+
1340
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1341
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1342
+
1343
+ # handle guidance
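+ # FLUX.1-dev embeds a (distilled) guidance value directly into the model input; this is independent of
+ # the explicit CFG combination of the two predictions below.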
1344
+ if self.transformer.config.guidance_embeds:
1345
+ guidance = torch.tensor([guidance_scale], device=device)
1346
+ guidance = guidance.expand(latents.shape[0])
1347
+ else:
1348
+ guidance = None
1349
+
1350
+ noise_pred_text = self.transformer(
1351
+ hidden_states=latents,
1352
+ timestep=timestep / 1000,
1353
+ guidance=guidance,
1354
+ pooled_projections=pooled_prompt_embeds,
1355
+ encoder_hidden_states=prompt_embeds,
1356
+ txt_ids=text_ids,
1357
+ img_ids=latent_image_ids,
1358
+ joint_attention_kwargs=self.joint_attention_kwargs,
1359
+ return_dict=False,
1360
+ )[0]
1361
+
1362
+ # todo combine these
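+ # (the unconditional pass could be batched with the conditional one by concatenating the embeddings
+ # along the batch dimension, trading memory for a single forward call)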
1363
+ noise_pred_uncond = self.transformer(
1364
+ hidden_states=latents,
1365
+ timestep=timestep / 1000,
1366
+ guidance=guidance,
1367
+ pooled_projections=negative_pooled_prompt_embeds,
1368
+ encoder_hidden_states=negative_prompt_embeds,
1369
+ txt_ids=negative_text_ids,
1370
+ img_ids=latent_image_ids,
1371
+ joint_attention_kwargs=self.joint_attention_kwargs,
1372
+ return_dict=False,
1373
+ )[0]
1374
+
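+ # explicit classifier-free guidance on the flow prediction, weighted by `self.guidance_scale`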
1375
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1376
+
1377
+ # compute the previous noisy sample x_t -> x_t-1
1378
+ latents_dtype = latents.dtype
1379
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1380
+
1381
+ if latents.dtype != latents_dtype:
1382
+ if torch.backends.mps.is_available():
1383
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1384
+ latents = latents.to(latents_dtype)
1385
+
1386
+ if callback_on_step_end is not None:
1387
+ callback_kwargs = {}
1388
+ for k in callback_on_step_end_tensor_inputs:
1389
+ callback_kwargs[k] = locals()[k]
1390
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1391
+
1392
+ latents = callback_outputs.pop("latents", latents)
1393
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1394
+
1395
+ # call the callback, if provided
1396
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1397
+ progress_bar.update()
1398
+
1399
+ if XLA_AVAILABLE:
1400
+ xm.mark_step()
1401
+
1402
+ if output_type == "latent":
1403
+ image = latents
1404
+
1405
+ else:
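+ # Flux latents are packed into 2x2 patches; unpack them and undo the VAE normalization
+ # (`scaling_factor` / `shift_factor`) before decoding.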
1406
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1407
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1408
+ image = self.vae.decode(latents, return_dict=False)[0]
1409
+ image = self.image_processor.postprocess(image, output_type=output_type)
1410
+
1411
+ # Offload all models
1412
+ self.maybe_free_model_hooks()
1413
+
1414
+ if not return_dict:
1415
+ return (image,)
1416
+
1417
+ return FluxPipelineOutput(images=image)
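+
+
+ # Minimal usage sketch (illustrative only, not exercised elsewhere in this repo): assumes the public
+ # "black-forest-labs/FLUX.1-dev" checkpoint, a CUDA device and bfloat16 weights.
+ #
+ # import torch
+ #
+ # pipe = FluxWithCFGPipeline.from_pretrained(
+ #     "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+ # ).to("cuda")
+ # image = pipe(
+ #     prompt="a wooden treasure chest, white background",
+ #     negative_prompt="blurry, low quality",
+ #     guidance_scale=3.5,
+ #     num_inference_steps=28,
+ # ).images[0]
+ # image.save("chest.png")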