elismasilva committed on
Commit a5c3b38 · 1 Parent(s): 296d048

small fixes pipeline

Files changed (1)
  1. mixture_tiling_sdxl.py +60 -73
mixture_tiling_sdxl.py CHANGED
@@ -12,23 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from enum import Enum
 import inspect
+from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import (
-    CLIPImageProcessor,
     CLIPTextModel,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
-    CLIPVisionModelWithProjection,
 )
 
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import (
     FromSingleFileMixin,
-    IPAdapterMixin,
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
@@ -39,6 +36,8 @@ from diffusers.models.attention_processor import (
     XFormersAttnProcessor,
 )
 from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
 from diffusers.utils import (
     USE_PEFT_BACKEND,
@@ -50,11 +49,10 @@ from diffusers.utils import (
     unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
-from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
 
 try:
-    from ligo.segments import segment
+    from ligo.segments import segment
 except ImportError:
     raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline")
 
@@ -87,6 +85,7 @@ EXAMPLE_DOC_STRING = """
 ```
 """
 
+
 def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
     """Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image
 
@@ -151,6 +150,7 @@ def _tile2latent_exclusive_indices(
     # return row_init, row_end, col_init, col_end
     return row_segment[0], row_segment[1], col_segment[0], col_segment[1]
 
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     r"""
@@ -210,7 +210,7 @@ def retrieve_timesteps(
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         second element is the number of inference steps.
     """
-
+
     if timesteps is not None and sigmas is not None:
         raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
     if timesteps is not None:
@@ -245,7 +245,6 @@ class StableDiffusionXLTilingPipeline(
     FromSingleFileMixin,
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
-    IPAdapterMixin,
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion XL.
@@ -258,7 +257,6 @@ class StableDiffusionXLTilingPipeline(
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
-        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -298,10 +296,8 @@ class StableDiffusionXLTilingPipeline(
         "tokenizer_2",
         "text_encoder",
         "text_encoder_2",
-        "image_encoder",
-        "feature_extractor",
     ]
-
+
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -311,8 +307,6 @@ class StableDiffusionXLTilingPipeline(
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
-        image_encoder: CLIPVisionModelWithProjection = None,
-        feature_extractor: CLIPImageProcessor = None,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
     ):
@@ -326,8 +320,6 @@ class StableDiffusionXLTilingPipeline(
             tokenizer_2=tokenizer_2,
             unet=unet,
             scheduler=scheduler,
-            image_encoder=image_encoder,
-            feature_extractor=feature_extractor,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
@@ -351,7 +343,7 @@ class StableDiffusionXLTilingPipeline(
 
         FULL = "full"
         EXCLUSIVE = "exclusive"
-
+
    def encode_prompt(
        self,
        prompt: str,
@@ -589,14 +581,14 @@ class StableDiffusionXLTilingPipeline(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
-
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
-
+
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
         extra_step_kwargs = {}
         if accepts_eta:
@@ -608,15 +600,7 @@ class StableDiffusionXLTilingPipeline(
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    def check_inputs(
-        self,
-        prompt,
-        height,
-        width,
-        grid_cols,
-        seed_tiles_mode,
-        tiles_mode
-    ):
+    def check_inputs(self, prompt, height, width, grid_cols, seed_tiles_mode, tiles_mode):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
@@ -625,18 +609,18 @@ class StableDiffusionXLTilingPipeline(
 
         if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt):
             raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}")
-
+
         if not all(len(row) == grid_cols for row in prompt):
             raise ValueError("All prompt rows must have the same number of prompt columns")
-
+
         if not isinstance(seed_tiles_mode, str) and (
             not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)
         ):
             raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
-
+
         if any(mode not in tiles_mode for row in seed_tiles_mode for mode in row):
             raise ValueError(f"Seed tiles mode must be one of {tiles_mode}")
-
+
     def _get_add_time_ids(
         self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
     ):
@@ -678,8 +662,8 @@ class StableDiffusionXLTilingPipeline(
         weights_np = np.outer(y_probs, x_probs)
         weights_torch = torch.tensor(weights_np, device=device)
         weights_torch = weights_torch.to(dtype)
-        return torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))
-
+        return torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))
+
     def upcast_vae(self):
         dtype = self.vae.dtype
         self.vae.to(dtype=torch.float32)
@@ -760,25 +744,25 @@ class StableDiffusionXLTilingPipeline(
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
+        prompt: Union[str, List[str]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        num_inference_steps: int = 50,
+        num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         original_size: Optional[Tuple[int, int]] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         target_size: Optional[Tuple[int, int]] = None,
         negative_original_size: Optional[Tuple[int, int]] = None,
         negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
         negative_target_size: Optional[Tuple[int, int]] = None,
-        clip_skip: Optional[int] = None,
+        clip_skip: Optional[int] = None,
         tile_height: Optional[int] = 1024,
         tile_width: Optional[int] = 1024,
         tile_row_overlap: Optional[int] = 128,
@@ -786,7 +770,7 @@ class StableDiffusionXLTilingPipeline(
         guidance_scale_tiles: Optional[List[List[float]]] = None,
         seed_tiles: Optional[List[List[int]]] = None,
         seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full",
-        seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,
+        seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,
         **kwargs,
     ):
         r"""
@@ -795,7 +779,7 @@ class StableDiffusionXLTilingPipeline(
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead.
+                instead.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
                 Anything below 512 pixels won't work well for
@@ -808,7 +792,7 @@ class StableDiffusionXLTilingPipeline(
                 and checkpoints that are not specifically fine-tuned on low resolutions.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                expense of slower inference.
+                expense of slower inference.
             guidance_scale (`float`, *optional*, defaults to 5.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -818,7 +802,7 @@ class StableDiffusionXLTilingPipeline(
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
+                less than `1`).
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
@@ -826,7 +810,7 @@ class StableDiffusionXLTilingPipeline(
                 [`schedulers.DDIMScheduler`], will be ignored for others.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-                to make generation deterministic.
+                to make generation deterministic.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -836,7 +820,7 @@ class StableDiffusionXLTilingPipeline(
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                 `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -881,10 +865,10 @@ class StableDiffusionXLTilingPipeline(
             seed_tiles_mode (`Union[str, List[List[str]]]`, *optional*, defaults to `"full"`):
                 Mode for seeding tiles, can be `"full"` or `"exclusive"`. If `"full"`, all the latents affected by the tile will be overridden. If `"exclusive"`, only the latents that are exclusively affected by this tile (and no other tiles) will be overridden.
             seed_reroll_regions (`List[Tuple[int, int, int, int, int]]`, *optional*):
-                A list of tuples in the form of `(start_row, end_row, start_column, end_column, seed)` defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over `seed_tiles`.
+                A list of tuples in the form of `(start_row, end_row, start_column, end_column, seed)` defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over `seed_tiles`.
             **kwargs (`Dict[str, Any]`, *optional*):
                 Additional optional keyword arguments to be passed to the `unet.__call__` and `scheduler.step` functions.
-
+
         Examples:
 
         Returns:
@@ -902,7 +886,7 @@ class StableDiffusionXLTilingPipeline(
 
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
-        self._cross_attention_kwargs = cross_attention_kwargs
+        self._cross_attention_kwargs = cross_attention_kwargs
         self._interrupt = False
 
         grid_rows = len(prompt)
@@ -912,12 +896,12 @@ class StableDiffusionXLTilingPipeline(
 
         if isinstance(seed_tiles_mode, str):
             seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
-
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt,
+            prompt,
             height,
-            width,
+            width,
             grid_cols,
             seed_tiles_mode,
             tiles_mode,
@@ -933,7 +917,7 @@ class StableDiffusionXLTilingPipeline(
         # update height and width tile size and tile overlap size
         height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)
         width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)
-
+
         # 3. Encode input prompt
         lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
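
(Worked example of the canvas-size arithmetic in the hunk above: with the default tile_height = tile_width = 1024 and overlaps of 128, a 2x3 prompt grid gives height = 1024 + (2 - 1) * (1024 - 128) = 1920 and width = 1024 + (3 - 1) * (1024 - 128) = 2816.)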
@@ -941,11 +925,11 @@ class StableDiffusionXLTilingPipeline(
         text_embeddings = [
             [
                 self.encode_prompt(
-                    prompt=col,
+                    prompt=col,
                     device=device,
                     num_images_per_prompt=num_images_per_prompt,
                     do_classifier_free_guidance=self.do_classifier_free_guidance,
-                    negative_prompt=negative_prompt,
+                    negative_prompt=negative_prompt,
                     prompt_embeds=None,
                     negative_prompt_embeds=None,
                     pooled_prompt_embeds=None,
@@ -956,10 +940,10 @@ class StableDiffusionXLTilingPipeline(
                 for col in row
             ]
             for row in prompt
-        ]
+        ]
 
         # 3. Prepare latents
-        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
         dtype = text_embeddings[0][0][0].dtype
         latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
 
@@ -1008,7 +992,7 @@ class StableDiffusionXLTilingPipeline(
             extra_set_kwargs["offset"] = 1
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler, num_inference_steps, device, None, None, **extra_set_kwargs
-        )
+        )
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -1023,7 +1007,7 @@ class StableDiffusionXLTilingPipeline(
         for row in range(grid_rows):
             addition_embed_type_row = []
             for col in range(grid_cols):
-                #extract generated values
+                # extract generated values
                 prompt_embeds = text_embeddings[row][col][0]
                 negative_prompt_embeds = text_embeddings[row][col][1]
                 pooled_prompt_embeds = text_embeddings[row][col][2]
@@ -1051,7 +1035,7 @@ class StableDiffusionXLTilingPipeline(
                     )
                 else:
                     negative_add_time_ids = add_time_ids
-
+
                 if self.do_classifier_free_guidance:
                     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
                     add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
@@ -1062,14 +1046,14 @@ class StableDiffusionXLTilingPipeline(
                 add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
                 addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
             embeddings_and_added_time.append(addition_embed_type_row)
-
+
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
         # 7. Mask for tile weights strength
         tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size, device, torch.float32)
 
         # 8. Denoising loop
-        self._num_timesteps = len(timesteps)
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # Diffuse each tile
@@ -1084,16 +1068,21 @@ class StableDiffusionXLTilingPipeline(
                         )
                         tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
                         # expand the latents if we are doing classifier free guidance
-                        latent_model_input = torch.cat([tile_latents] * 2) if self.do_classifier_free_guidance else latents
+                        latent_model_input = (
+                            torch.cat([tile_latents] * 2) if self.do_classifier_free_guidance else tile_latents
+                        )
                         latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                         # predict the noise residual
-                        added_cond_kwargs = {"text_embeds": embeddings_and_added_time[row][col][1], "time_ids": embeddings_and_added_time[row][col][2]}
-                        with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype!=self.unet.dtype):
+                        added_cond_kwargs = {
+                            "text_embeds": embeddings_and_added_time[row][col][1],
+                            "time_ids": embeddings_and_added_time[row][col][2],
+                        }
+                        with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):
                             noise_pred = self.unet(
                                 latent_model_input,
                                 t,
-                                encoder_hidden_states=embeddings_and_added_time[row][col][0],
+                                encoder_hidden_states=embeddings_and_added_time[row][col][0],
                                 cross_attention_kwargs=self.cross_attention_kwargs,
                                 added_cond_kwargs=added_cond_kwargs,
                                 return_dict=False,
@@ -1110,7 +1099,7 @@ class StableDiffusionXLTilingPipeline(
                             noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
                             noise_preds_row.append(noise_pred_tile)
                     noise_preds.append(noise_preds_row)
-
+
                 # Stitch noise predictions for all tiles
                 noise_pred = torch.zeros(latents.shape, device=device)
                 contributors = torch.zeros(latents.shape, device=device)
@@ -1140,7 +1129,7 @@ class StableDiffusionXLTilingPipeline(
 
                 # update progress bar
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
+                    progress_bar.update()
 
                 if XLA_AVAILABLE:
                     xm.mark_step()
@@ -1173,7 +1162,7 @@ class StableDiffusionXLTilingPipeline(
             latents = latents / self.vae.config.scaling_factor
 
            image = self.vae.decode(latents, return_dict=False)[0]
-
+
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
@@ -1184,7 +1173,7 @@ class StableDiffusionXLTilingPipeline(
             # apply watermark if available
             if self.watermark is not None:
                 image = self.watermark.apply_watermark(image)
-
+
             image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload all models
@@ -1194,5 +1183,3 @@ class StableDiffusionXLTilingPipeline(
             return (image,)
 
         return StableDiffusionXLPipelineOutput(images=image)
-
-
 
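A minimal usage sketch of the tiling call, for reading the `__call__` signature changes above in context. It assumes this file is loaded as a diffusers community pipeline via `custom_pipeline` on top of an SDXL checkpoint; the checkpoint id, device, seed, and the 1x3 prompt grid are illustrative assumptions, not values taken from this commit.

import torch
from diffusers import DiffusionPipeline

# Assumed setup: SDXL base weights plus this file loaded as a community pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="mixture_tiling_sdxl",
    torch_dtype=torch.float16,
).to("cuda")

# `prompt` is a grid: a list of rows, each row a list of per-tile prompts.
image = pipe(
    prompt=[["a snowy mountain peak", "a pine forest", "a frozen lake at dusk"]],
    tile_height=1024,
    tile_width=1024,
    tile_row_overlap=128,
    tile_col_overlap=128,
    guidance_scale=5.0,
    num_inference_steps=50,
    seed_tiles_mode="full",
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
# The final canvas size follows the tile/overlap formula shown in the diff above.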