Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
a5c3b38
1
Parent(s):
296d048
small fixes pipeline
Browse files- mixture_tiling_sdxl.py +60 -73
mixture_tiling_sdxl.py
CHANGED
@@ -12,23 +12,20 @@
|
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
|
15 |
-
from enum import Enum
|
16 |
import inspect
|
|
|
17 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
18 |
|
19 |
import torch
|
20 |
from transformers import (
|
21 |
-
CLIPImageProcessor,
|
22 |
CLIPTextModel,
|
23 |
CLIPTextModelWithProjection,
|
24 |
CLIPTokenizer,
|
25 |
-
CLIPVisionModelWithProjection,
|
26 |
)
|
27 |
|
28 |
from diffusers.image_processor import VaeImageProcessor
|
29 |
from diffusers.loaders import (
|
30 |
FromSingleFileMixin,
|
31 |
-
IPAdapterMixin,
|
32 |
StableDiffusionXLLoraLoaderMixin,
|
33 |
TextualInversionLoaderMixin,
|
34 |
)
|
@@ -39,6 +36,8 @@ from diffusers.models.attention_processor import (
|
|
39 |
XFormersAttnProcessor,
|
40 |
)
|
41 |
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
|
|
|
|
42 |
from diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
|
43 |
from diffusers.utils import (
|
44 |
USE_PEFT_BACKEND,
|
@@ -50,11 +49,10 @@ from diffusers.utils import (
|
|
50 |
unscale_lora_layers,
|
51 |
)
|
52 |
from diffusers.utils.torch_utils import randn_tensor
|
53 |
-
|
54 |
-
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
55 |
|
56 |
try:
|
57 |
-
from ligo.segments import segment
|
58 |
except ImportError:
|
59 |
raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline")
|
60 |
|
@@ -87,6 +85,7 @@ EXAMPLE_DOC_STRING = """
|
|
87 |
```
|
88 |
"""
|
89 |
|
|
|
90 |
def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
|
91 |
"""Given a tile row and column number, returns the range of pixels affected by that tile in the overall image
|
92 |
|
@@ -151,6 +150,7 @@ def _tile2latent_exclusive_indices(
|
|
151 |
# return row_init, row_end, col_init, col_end
|
152 |
return row_segment[0], row_segment[1], col_segment[0], col_segment[1]
|
153 |
|
|
|
154 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
155 |
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
156 |
r"""
|
@@ -210,7 +210,7 @@ def retrieve_timesteps(
|
|
210 |
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
211 |
second element is the number of inference steps.
|
212 |
"""
|
213 |
-
|
214 |
if timesteps is not None and sigmas is not None:
|
215 |
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
216 |
if timesteps is not None:
|
@@ -245,7 +245,6 @@ class StableDiffusionXLTilingPipeline(
|
|
245 |
FromSingleFileMixin,
|
246 |
StableDiffusionXLLoraLoaderMixin,
|
247 |
TextualInversionLoaderMixin,
|
248 |
-
IPAdapterMixin,
|
249 |
):
|
250 |
r"""
|
251 |
Pipeline for text-to-image generation using Stable Diffusion XL.
|
@@ -258,7 +257,6 @@ class StableDiffusionXLTilingPipeline(
|
|
258 |
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
259 |
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
260 |
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
261 |
-
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
262 |
|
263 |
Args:
|
264 |
vae ([`AutoencoderKL`]):
|
@@ -298,10 +296,8 @@ class StableDiffusionXLTilingPipeline(
|
|
298 |
"tokenizer_2",
|
299 |
"text_encoder",
|
300 |
"text_encoder_2",
|
301 |
-
"image_encoder",
|
302 |
-
"feature_extractor",
|
303 |
]
|
304 |
-
|
305 |
def __init__(
|
306 |
self,
|
307 |
vae: AutoencoderKL,
|
@@ -311,8 +307,6 @@ class StableDiffusionXLTilingPipeline(
|
|
311 |
tokenizer_2: CLIPTokenizer,
|
312 |
unet: UNet2DConditionModel,
|
313 |
scheduler: KarrasDiffusionSchedulers,
|
314 |
-
image_encoder: CLIPVisionModelWithProjection = None,
|
315 |
-
feature_extractor: CLIPImageProcessor = None,
|
316 |
force_zeros_for_empty_prompt: bool = True,
|
317 |
add_watermarker: Optional[bool] = None,
|
318 |
):
|
@@ -326,8 +320,6 @@ class StableDiffusionXLTilingPipeline(
|
|
326 |
tokenizer_2=tokenizer_2,
|
327 |
unet=unet,
|
328 |
scheduler=scheduler,
|
329 |
-
image_encoder=image_encoder,
|
330 |
-
feature_extractor=feature_extractor,
|
331 |
)
|
332 |
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
333 |
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
@@ -351,7 +343,7 @@ class StableDiffusionXLTilingPipeline(
|
|
351 |
|
352 |
FULL = "full"
|
353 |
EXCLUSIVE = "exclusive"
|
354 |
-
|
355 |
def encode_prompt(
|
356 |
self,
|
357 |
prompt: str,
|
@@ -589,14 +581,14 @@ class StableDiffusionXLTilingPipeline(
|
|
589 |
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
590 |
|
591 |
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
592 |
-
|
593 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
594 |
def prepare_extra_step_kwargs(self, generator, eta):
|
595 |
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
596 |
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
597 |
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
598 |
# and should be between [0, 1]
|
599 |
-
|
600 |
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
601 |
extra_step_kwargs = {}
|
602 |
if accepts_eta:
|
@@ -608,15 +600,7 @@ class StableDiffusionXLTilingPipeline(
|
|
608 |
extra_step_kwargs["generator"] = generator
|
609 |
return extra_step_kwargs
|
610 |
|
611 |
-
def check_inputs(
|
612 |
-
self,
|
613 |
-
prompt,
|
614 |
-
height,
|
615 |
-
width,
|
616 |
-
grid_cols,
|
617 |
-
seed_tiles_mode,
|
618 |
-
tiles_mode
|
619 |
-
):
|
620 |
if height % 8 != 0 or width % 8 != 0:
|
621 |
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
622 |
|
@@ -625,18 +609,18 @@ class StableDiffusionXLTilingPipeline(
|
|
625 |
|
626 |
if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt):
|
627 |
raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}")
|
628 |
-
|
629 |
if not all(len(row) == grid_cols for row in prompt):
|
630 |
raise ValueError("All prompt rows must have the same number of prompt columns")
|
631 |
-
|
632 |
if not isinstance(seed_tiles_mode, str) and (
|
633 |
not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)
|
634 |
):
|
635 |
raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
|
636 |
-
|
637 |
if any(mode not in tiles_mode for row in seed_tiles_mode for mode in row):
|
638 |
raise ValueError(f"Seed tiles mode must be one of {tiles_mode}")
|
639 |
-
|
640 |
def _get_add_time_ids(
|
641 |
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
642 |
):
|
@@ -678,8 +662,8 @@ class StableDiffusionXLTilingPipeline(
|
|
678 |
weights_np = np.outer(y_probs, x_probs)
|
679 |
weights_torch = torch.tensor(weights_np, device=device)
|
680 |
weights_torch = weights_torch.to(dtype)
|
681 |
-
return torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))
|
682 |
-
|
683 |
def upcast_vae(self):
|
684 |
dtype = self.vae.dtype
|
685 |
self.vae.to(dtype=torch.float32)
|
@@ -760,25 +744,25 @@ class StableDiffusionXLTilingPipeline(
|
|
760 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
761 |
def __call__(
|
762 |
self,
|
763 |
-
prompt: Union[str, List[str]] = None,
|
764 |
height: Optional[int] = None,
|
765 |
width: Optional[int] = None,
|
766 |
-
num_inference_steps: int = 50,
|
767 |
guidance_scale: float = 5.0,
|
768 |
-
negative_prompt: Optional[Union[str, List[str]]] = None,
|
769 |
num_images_per_prompt: Optional[int] = 1,
|
770 |
eta: float = 0.0,
|
771 |
-
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
772 |
output_type: Optional[str] = "pil",
|
773 |
return_dict: bool = True,
|
774 |
-
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
775 |
original_size: Optional[Tuple[int, int]] = None,
|
776 |
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
777 |
target_size: Optional[Tuple[int, int]] = None,
|
778 |
negative_original_size: Optional[Tuple[int, int]] = None,
|
779 |
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
780 |
negative_target_size: Optional[Tuple[int, int]] = None,
|
781 |
-
clip_skip: Optional[int] = None,
|
782 |
tile_height: Optional[int] = 1024,
|
783 |
tile_width: Optional[int] = 1024,
|
784 |
tile_row_overlap: Optional[int] = 128,
|
@@ -786,7 +770,7 @@ class StableDiffusionXLTilingPipeline(
|
|
786 |
guidance_scale_tiles: Optional[List[List[float]]] = None,
|
787 |
seed_tiles: Optional[List[List[int]]] = None,
|
788 |
seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full",
|
789 |
-
seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,
|
790 |
**kwargs,
|
791 |
):
|
792 |
r"""
|
@@ -795,7 +779,7 @@ class StableDiffusionXLTilingPipeline(
|
|
795 |
Args:
|
796 |
prompt (`str` or `List[str]`, *optional*):
|
797 |
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
798 |
-
instead.
|
799 |
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
800 |
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
801 |
Anything below 512 pixels won't work well for
|
@@ -808,7 +792,7 @@ class StableDiffusionXLTilingPipeline(
|
|
808 |
and checkpoints that are not specifically fine-tuned on low resolutions.
|
809 |
num_inference_steps (`int`, *optional*, defaults to 50):
|
810 |
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
811 |
-
expense of slower inference.
|
812 |
guidance_scale (`float`, *optional*, defaults to 5.0):
|
813 |
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
814 |
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
@@ -818,7 +802,7 @@ class StableDiffusionXLTilingPipeline(
|
|
818 |
negative_prompt (`str` or `List[str]`, *optional*):
|
819 |
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
820 |
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
821 |
-
less than `1`).
|
822 |
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
823 |
The number of images to generate per prompt.
|
824 |
eta (`float`, *optional*, defaults to 0.0):
|
@@ -826,7 +810,7 @@ class StableDiffusionXLTilingPipeline(
|
|
826 |
[`schedulers.DDIMScheduler`], will be ignored for others.
|
827 |
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
828 |
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
829 |
-
to make generation deterministic.
|
830 |
output_type (`str`, *optional*, defaults to `"pil"`):
|
831 |
The output format of the generated image. Choose between
|
832 |
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
@@ -836,7 +820,7 @@ class StableDiffusionXLTilingPipeline(
|
|
836 |
cross_attention_kwargs (`dict`, *optional*):
|
837 |
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
838 |
`self.processor` in
|
839 |
-
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
840 |
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
841 |
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
842 |
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
@@ -881,10 +865,10 @@ class StableDiffusionXLTilingPipeline(
|
|
881 |
seed_tiles_mode (`Union[str, List[List[str]]]`, *optional*, defaults to `"full"`):
|
882 |
Mode for seeding tiles, can be `"full"` or `"exclusive"`. If `"full"`, all the latents affected by the tile will be overridden. If `"exclusive"`, only the latents that are exclusively affected by this tile (and no other tiles) will be overridden.
|
883 |
seed_reroll_regions (`List[Tuple[int, int, int, int, int]]`, *optional*):
|
884 |
-
A list of tuples in the form of `(start_row, end_row, start_column, end_column, seed)` defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over `seed_tiles`.
|
885 |
**kwargs (`Dict[str, Any]`, *optional*):
|
886 |
Additional optional keyword arguments to be passed to the `unet.__call__` and `scheduler.step` functions.
|
887 |
-
|
888 |
Examples:
|
889 |
|
890 |
Returns:
|
@@ -902,7 +886,7 @@ class StableDiffusionXLTilingPipeline(
|
|
902 |
|
903 |
self._guidance_scale = guidance_scale
|
904 |
self._clip_skip = clip_skip
|
905 |
-
self._cross_attention_kwargs = cross_attention_kwargs
|
906 |
self._interrupt = False
|
907 |
|
908 |
grid_rows = len(prompt)
|
@@ -912,12 +896,12 @@ class StableDiffusionXLTilingPipeline(
|
|
912 |
|
913 |
if isinstance(seed_tiles_mode, str):
|
914 |
seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
|
915 |
-
|
916 |
# 1. Check inputs. Raise error if not correct
|
917 |
self.check_inputs(
|
918 |
-
prompt,
|
919 |
height,
|
920 |
-
width,
|
921 |
grid_cols,
|
922 |
seed_tiles_mode,
|
923 |
tiles_mode,
|
@@ -933,7 +917,7 @@ class StableDiffusionXLTilingPipeline(
|
|
933 |
# update height and width tile size and tile overlap size
|
934 |
height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)
|
935 |
width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)
|
936 |
-
|
937 |
# 3. Encode input prompt
|
938 |
lora_scale = (
|
939 |
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
@@ -941,11 +925,11 @@ class StableDiffusionXLTilingPipeline(
|
|
941 |
text_embeddings = [
|
942 |
[
|
943 |
self.encode_prompt(
|
944 |
-
prompt=col,
|
945 |
device=device,
|
946 |
num_images_per_prompt=num_images_per_prompt,
|
947 |
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
948 |
-
negative_prompt=negative_prompt,
|
949 |
prompt_embeds=None,
|
950 |
negative_prompt_embeds=None,
|
951 |
pooled_prompt_embeds=None,
|
@@ -956,10 +940,10 @@ class StableDiffusionXLTilingPipeline(
|
|
956 |
for col in row
|
957 |
]
|
958 |
for row in prompt
|
959 |
-
]
|
960 |
|
961 |
# 3. Prepare latents
|
962 |
-
latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
|
963 |
dtype = text_embeddings[0][0][0].dtype
|
964 |
latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
|
965 |
|
@@ -1008,7 +992,7 @@ class StableDiffusionXLTilingPipeline(
|
|
1008 |
extra_set_kwargs["offset"] = 1
|
1009 |
timesteps, num_inference_steps = retrieve_timesteps(
|
1010 |
self.scheduler, num_inference_steps, device, None, None, **extra_set_kwargs
|
1011 |
-
)
|
1012 |
|
1013 |
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
|
1014 |
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
@@ -1023,7 +1007,7 @@ class StableDiffusionXLTilingPipeline(
|
|
1023 |
for row in range(grid_rows):
|
1024 |
addition_embed_type_row = []
|
1025 |
for col in range(grid_cols):
|
1026 |
-
#extract generated values
|
1027 |
prompt_embeds = text_embeddings[row][col][0]
|
1028 |
negative_prompt_embeds = text_embeddings[row][col][1]
|
1029 |
pooled_prompt_embeds = text_embeddings[row][col][2]
|
@@ -1051,7 +1035,7 @@ class StableDiffusionXLTilingPipeline(
|
|
1051 |
)
|
1052 |
else:
|
1053 |
negative_add_time_ids = add_time_ids
|
1054 |
-
|
1055 |
if self.do_classifier_free_guidance:
|
1056 |
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1057 |
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
@@ -1062,14 +1046,14 @@ class StableDiffusionXLTilingPipeline(
|
|
1062 |
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
1063 |
addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
|
1064 |
embeddings_and_added_time.append(addition_embed_type_row)
|
1065 |
-
|
1066 |
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
1067 |
|
1068 |
# 7. Mask for tile weights strength
|
1069 |
tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size, device, torch.float32)
|
1070 |
|
1071 |
# 8. Denoising loop
|
1072 |
-
self._num_timesteps = len(timesteps)
|
1073 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1074 |
for i, t in enumerate(timesteps):
|
1075 |
# Diffuse each tile
|
@@ -1084,16 +1068,21 @@ class StableDiffusionXLTilingPipeline(
|
|
1084 |
)
|
1085 |
tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
|
1086 |
# expand the latents if we are doing classifier free guidance
|
1087 |
-
latent_model_input =
|
|
|
|
|
1088 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1089 |
|
1090 |
# predict the noise residual
|
1091 |
-
added_cond_kwargs = {
|
1092 |
-
|
|
|
|
|
|
|
1093 |
noise_pred = self.unet(
|
1094 |
latent_model_input,
|
1095 |
t,
|
1096 |
-
encoder_hidden_states=embeddings_and_added_time[row][col][0],
|
1097 |
cross_attention_kwargs=self.cross_attention_kwargs,
|
1098 |
added_cond_kwargs=added_cond_kwargs,
|
1099 |
return_dict=False,
|
@@ -1110,7 +1099,7 @@ class StableDiffusionXLTilingPipeline(
|
|
1110 |
noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
|
1111 |
noise_preds_row.append(noise_pred_tile)
|
1112 |
noise_preds.append(noise_preds_row)
|
1113 |
-
|
1114 |
# Stitch noise predictions for all tiles
|
1115 |
noise_pred = torch.zeros(latents.shape, device=device)
|
1116 |
contributors = torch.zeros(latents.shape, device=device)
|
@@ -1140,7 +1129,7 @@ class StableDiffusionXLTilingPipeline(
|
|
1140 |
|
1141 |
# update progress bar
|
1142 |
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1143 |
-
progress_bar.update()
|
1144 |
|
1145 |
if XLA_AVAILABLE:
|
1146 |
xm.mark_step()
|
@@ -1173,7 +1162,7 @@ class StableDiffusionXLTilingPipeline(
|
|
1173 |
latents = latents / self.vae.config.scaling_factor
|
1174 |
|
1175 |
image = self.vae.decode(latents, return_dict=False)[0]
|
1176 |
-
|
1177 |
# cast back to fp16 if needed
|
1178 |
if needs_upcasting:
|
1179 |
self.vae.to(dtype=torch.float16)
|
@@ -1184,7 +1173,7 @@ class StableDiffusionXLTilingPipeline(
|
|
1184 |
# apply watermark if available
|
1185 |
if self.watermark is not None:
|
1186 |
image = self.watermark.apply_watermark(image)
|
1187 |
-
|
1188 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
1189 |
|
1190 |
# Offload all models
|
@@ -1194,5 +1183,3 @@ class StableDiffusionXLTilingPipeline(
|
|
1194 |
return (image,)
|
1195 |
|
1196 |
return StableDiffusionXLPipelineOutput(images=image)
|
1197 |
-
|
1198 |
-
|
|
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
|
|
|
15 |
import inspect
|
16 |
+
from enum import Enum
|
17 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
18 |
|
19 |
import torch
|
20 |
from transformers import (
|
|
|
21 |
CLIPTextModel,
|
22 |
CLIPTextModelWithProjection,
|
23 |
CLIPTokenizer,
|
|
|
24 |
)
|
25 |
|
26 |
from diffusers.image_processor import VaeImageProcessor
|
27 |
from diffusers.loaders import (
|
28 |
FromSingleFileMixin,
|
|
|
29 |
StableDiffusionXLLoraLoaderMixin,
|
30 |
TextualInversionLoaderMixin,
|
31 |
)
|
|
|
36 |
XFormersAttnProcessor,
|
37 |
)
|
38 |
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
39 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
40 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
41 |
from diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
|
42 |
from diffusers.utils import (
|
43 |
USE_PEFT_BACKEND,
|
|
|
49 |
unscale_lora_layers,
|
50 |
)
|
51 |
from diffusers.utils.torch_utils import randn_tensor
|
52 |
+
|
|
|
53 |
|
54 |
try:
|
55 |
+
from ligo.segments import segment
|
56 |
except ImportError:
|
57 |
raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline")
|
58 |
|
|
|
85 |
```
|
86 |
"""
|
87 |
|
88 |
+
|
89 |
def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
|
90 |
"""Given a tile row and column number, returns the range of pixels affected by that tile in the overall image
|
91 |
|
|
|
150 |
# return row_init, row_end, col_init, col_end
|
151 |
return row_segment[0], row_segment[1], col_segment[0], col_segment[1]
|
152 |
|
153 |
+
|
154 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
155 |
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
156 |
r"""
|
|
|
210 |
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
211 |
second element is the number of inference steps.
|
212 |
"""
|
213 |
+
|
214 |
if timesteps is not None and sigmas is not None:
|
215 |
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
216 |
if timesteps is not None:
|
|
|
245 |
FromSingleFileMixin,
|
246 |
StableDiffusionXLLoraLoaderMixin,
|
247 |
TextualInversionLoaderMixin,
|
|
|
248 |
):
|
249 |
r"""
|
250 |
Pipeline for text-to-image generation using Stable Diffusion XL.
|
|
|
257 |
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
258 |
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
259 |
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
|
|
260 |
|
261 |
Args:
|
262 |
vae ([`AutoencoderKL`]):
|
|
|
296 |
"tokenizer_2",
|
297 |
"text_encoder",
|
298 |
"text_encoder_2",
|
|
|
|
|
299 |
]
|
300 |
+
|
301 |
def __init__(
|
302 |
self,
|
303 |
vae: AutoencoderKL,
|
|
|
307 |
tokenizer_2: CLIPTokenizer,
|
308 |
unet: UNet2DConditionModel,
|
309 |
scheduler: KarrasDiffusionSchedulers,
|
|
|
|
|
310 |
force_zeros_for_empty_prompt: bool = True,
|
311 |
add_watermarker: Optional[bool] = None,
|
312 |
):
|
|
|
320 |
tokenizer_2=tokenizer_2,
|
321 |
unet=unet,
|
322 |
scheduler=scheduler,
|
|
|
|
|
323 |
)
|
324 |
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
325 |
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
|
|
343 |
|
344 |
FULL = "full"
|
345 |
EXCLUSIVE = "exclusive"
|
346 |
+
|
347 |
def encode_prompt(
|
348 |
self,
|
349 |
prompt: str,
|
|
|
581 |
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
582 |
|
583 |
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
584 |
+
|
585 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
586 |
def prepare_extra_step_kwargs(self, generator, eta):
|
587 |
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
588 |
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
589 |
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
590 |
# and should be between [0, 1]
|
591 |
+
|
592 |
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
593 |
extra_step_kwargs = {}
|
594 |
if accepts_eta:
|
|
|
600 |
extra_step_kwargs["generator"] = generator
|
601 |
return extra_step_kwargs
|
602 |
|
603 |
+
def check_inputs(self, prompt, height, width, grid_cols, seed_tiles_mode, tiles_mode):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
604 |
if height % 8 != 0 or width % 8 != 0:
|
605 |
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
606 |
|
|
|
609 |
|
610 |
if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt):
|
611 |
raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}")
|
612 |
+
|
613 |
if not all(len(row) == grid_cols for row in prompt):
|
614 |
raise ValueError("All prompt rows must have the same number of prompt columns")
|
615 |
+
|
616 |
if not isinstance(seed_tiles_mode, str) and (
|
617 |
not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)
|
618 |
):
|
619 |
raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
|
620 |
+
|
621 |
if any(mode not in tiles_mode for row in seed_tiles_mode for mode in row):
|
622 |
raise ValueError(f"Seed tiles mode must be one of {tiles_mode}")
|
623 |
+
|
624 |
def _get_add_time_ids(
|
625 |
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
626 |
):
|
|
|
662 |
weights_np = np.outer(y_probs, x_probs)
|
663 |
weights_torch = torch.tensor(weights_np, device=device)
|
664 |
weights_torch = weights_torch.to(dtype)
|
665 |
+
return torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))
|
666 |
+
|
667 |
def upcast_vae(self):
|
668 |
dtype = self.vae.dtype
|
669 |
self.vae.to(dtype=torch.float32)
|
|
|
744 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
745 |
def __call__(
|
746 |
self,
|
747 |
+
prompt: Union[str, List[str]] = None,
|
748 |
height: Optional[int] = None,
|
749 |
width: Optional[int] = None,
|
750 |
+
num_inference_steps: int = 50,
|
751 |
guidance_scale: float = 5.0,
|
752 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
753 |
num_images_per_prompt: Optional[int] = 1,
|
754 |
eta: float = 0.0,
|
755 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
756 |
output_type: Optional[str] = "pil",
|
757 |
return_dict: bool = True,
|
758 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
759 |
original_size: Optional[Tuple[int, int]] = None,
|
760 |
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
761 |
target_size: Optional[Tuple[int, int]] = None,
|
762 |
negative_original_size: Optional[Tuple[int, int]] = None,
|
763 |
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
764 |
negative_target_size: Optional[Tuple[int, int]] = None,
|
765 |
+
clip_skip: Optional[int] = None,
|
766 |
tile_height: Optional[int] = 1024,
|
767 |
tile_width: Optional[int] = 1024,
|
768 |
tile_row_overlap: Optional[int] = 128,
|
|
|
770 |
guidance_scale_tiles: Optional[List[List[float]]] = None,
|
771 |
seed_tiles: Optional[List[List[int]]] = None,
|
772 |
seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full",
|
773 |
+
seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,
|
774 |
**kwargs,
|
775 |
):
|
776 |
r"""
|
|
|
779 |
Args:
|
780 |
prompt (`List[List[str]]`):
|
781 |
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
782 |
+
instead.
|
783 |
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
784 |
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
785 |
Anything below 512 pixels won't work well for
|
|
|
792 |
and checkpoints that are not specifically fine-tuned on low resolutions.
|
793 |
num_inference_steps (`int`, *optional*, defaults to 50):
|
794 |
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
795 |
+
expense of slower inference.
|
796 |
guidance_scale (`float`, *optional*, defaults to 5.0):
|
797 |
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
798 |
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
|
|
802 |
negative_prompt (`str` or `List[str]`, *optional*):
|
803 |
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
804 |
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
805 |
+
less than `1`).
|
806 |
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
807 |
The number of images to generate per prompt.
|
808 |
eta (`float`, *optional*, defaults to 0.0):
|
|
|
810 |
[`schedulers.DDIMScheduler`], will be ignored for others.
|
811 |
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
812 |
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
813 |
+
to make generation deterministic.
|
814 |
output_type (`str`, *optional*, defaults to `"pil"`):
|
815 |
The output format of the generated image. Choose between
|
816 |
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
|
|
820 |
cross_attention_kwargs (`dict`, *optional*):
|
821 |
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
822 |
`self.processor` in
|
823 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
824 |
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
825 |
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
826 |
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
|
|
865 |
seed_tiles_mode (`Union[str, List[List[str]]]`, *optional*, defaults to `"full"`):
|
866 |
Mode for seeding tiles, can be `"full"` or `"exclusive"`. If `"full"`, all the latents affected by the tile will be overridden. If `"exclusive"`, only the latents that are exclusively affected by this tile (and no other tiles) will be overridden.
|
867 |
seed_reroll_regions (`List[Tuple[int, int, int, int, int]]`, *optional*):
|
868 |
+
A list of tuples in the form of `(start_row, end_row, start_column, end_column, seed)` defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over `seed_tiles`.
|
869 |
**kwargs (`Dict[str, Any]`, *optional*):
|
870 |
Additional optional keyword arguments to be passed to the `unet.__call__` and `scheduler.step` functions.
|
871 |
+
|
872 |
Examples:
|
873 |
|
874 |
Returns:
|
|
|
886 |
|
887 |
self._guidance_scale = guidance_scale
|
888 |
self._clip_skip = clip_skip
|
889 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
890 |
self._interrupt = False
|
891 |
|
892 |
grid_rows = len(prompt)
|
|
|
896 |
|
897 |
if isinstance(seed_tiles_mode, str):
|
898 |
seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
|
899 |
+
|
900 |
# 1. Check inputs. Raise error if not correct
|
901 |
self.check_inputs(
|
902 |
+
prompt,
|
903 |
height,
|
904 |
+
width,
|
905 |
grid_cols,
|
906 |
seed_tiles_mode,
|
907 |
tiles_mode,
|
|
|
917 |
# update height and width tile size and tile overlap size
|
918 |
height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)
|
919 |
width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)
|
920 |
+
|
921 |
# 3. Encode input prompt
|
922 |
lora_scale = (
|
923 |
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
|
|
925 |
text_embeddings = [
|
926 |
[
|
927 |
self.encode_prompt(
|
928 |
+
prompt=col,
|
929 |
device=device,
|
930 |
num_images_per_prompt=num_images_per_prompt,
|
931 |
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
932 |
+
negative_prompt=negative_prompt,
|
933 |
prompt_embeds=None,
|
934 |
negative_prompt_embeds=None,
|
935 |
pooled_prompt_embeds=None,
|
|
|
940 |
for col in row
|
941 |
]
|
942 |
for row in prompt
|
943 |
+
]
|
944 |
|
945 |
# 3. Prepare latents
|
946 |
+
latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
|
947 |
dtype = text_embeddings[0][0][0].dtype
|
948 |
latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
|
949 |
|
|
|
992 |
extra_set_kwargs["offset"] = 1
|
993 |
timesteps, num_inference_steps = retrieve_timesteps(
|
994 |
self.scheduler, num_inference_steps, device, None, None, **extra_set_kwargs
|
995 |
+
)
|
996 |
|
997 |
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
|
998 |
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
|
|
1007 |
for row in range(grid_rows):
|
1008 |
addition_embed_type_row = []
|
1009 |
for col in range(grid_cols):
|
1010 |
+
# extract generated values
|
1011 |
prompt_embeds = text_embeddings[row][col][0]
|
1012 |
negative_prompt_embeds = text_embeddings[row][col][1]
|
1013 |
pooled_prompt_embeds = text_embeddings[row][col][2]
|
|
|
1035 |
)
|
1036 |
else:
|
1037 |
negative_add_time_ids = add_time_ids
|
1038 |
+
|
1039 |
if self.do_classifier_free_guidance:
|
1040 |
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1041 |
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
|
|
1046 |
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
1047 |
addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
|
1048 |
embeddings_and_added_time.append(addition_embed_type_row)
|
1049 |
+
|
1050 |
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
1051 |
|
1052 |
# 7. Mask for tile weights strength
|
1053 |
tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size, device, torch.float32)
|
1054 |
|
1055 |
# 8. Denoising loop
|
1056 |
+
self._num_timesteps = len(timesteps)
|
1057 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1058 |
for i, t in enumerate(timesteps):
|
1059 |
# Diffuse each tile
|
|
|
1068 |
)
|
1069 |
tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
|
1070 |
# expand the latents if we are doing classifier free guidance
|
1071 |
+
latent_model_input = (
|
1072 |
+
torch.cat([tile_latents] * 2) if self.do_classifier_free_guidance else tile_latents
|
1073 |
+
)
|
1074 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1075 |
|
1076 |
# predict the noise residual
|
1077 |
+
added_cond_kwargs = {
|
1078 |
+
"text_embeds": embeddings_and_added_time[row][col][1],
|
1079 |
+
"time_ids": embeddings_and_added_time[row][col][2],
|
1080 |
+
}
|
1081 |
+
with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):
|
1082 |
noise_pred = self.unet(
|
1083 |
latent_model_input,
|
1084 |
t,
|
1085 |
+
encoder_hidden_states=embeddings_and_added_time[row][col][0],
|
1086 |
cross_attention_kwargs=self.cross_attention_kwargs,
|
1087 |
added_cond_kwargs=added_cond_kwargs,
|
1088 |
return_dict=False,
|
|
|
1099 |
noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
|
1100 |
noise_preds_row.append(noise_pred_tile)
|
1101 |
noise_preds.append(noise_preds_row)
|
1102 |
+
|
1103 |
# Stitch noise predictions for all tiles
|
1104 |
noise_pred = torch.zeros(latents.shape, device=device)
|
1105 |
contributors = torch.zeros(latents.shape, device=device)
|
|
|
1129 |
|
1130 |
# update progress bar
|
1131 |
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1132 |
+
progress_bar.update()
|
1133 |
|
1134 |
if XLA_AVAILABLE:
|
1135 |
xm.mark_step()
|
|
|
1162 |
latents = latents / self.vae.config.scaling_factor
|
1163 |
|
1164 |
image = self.vae.decode(latents, return_dict=False)[0]
|
1165 |
+
|
1166 |
# cast back to fp16 if needed
|
1167 |
if needs_upcasting:
|
1168 |
self.vae.to(dtype=torch.float16)
|
|
|
1173 |
# apply watermark if available
|
1174 |
if self.watermark is not None:
|
1175 |
image = self.watermark.apply_watermark(image)
|
1176 |
+
|
1177 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
1178 |
|
1179 |
# Offload all models
|
|
|
1183 |
return (image,)
|
1184 |
|
1185 |
return StableDiffusionXLPipelineOutput(images=image)
|
|
|
|