|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
from transformers import (
|
|
CLIPTextModel,
|
|
CLIPTextModelWithProjection,
|
|
CLIPTokenizer,
|
|
)
|
|
|
|
from diffusers.image_processor import VaeImageProcessor
|
|
from diffusers.loaders import (
|
|
FromSingleFileMixin,
|
|
StableDiffusionXLLoraLoaderMixin,
|
|
TextualInversionLoaderMixin,
|
|
)
|
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
|
from diffusers.models.attention_processor import (
|
|
AttnProcessor2_0,
|
|
FusedAttnProcessor2_0,
|
|
XFormersAttnProcessor,
|
|
)
|
|
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
|
from diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
|
|
from diffusers.utils import (
|
|
USE_PEFT_BACKEND,
|
|
is_invisible_watermark_available,
|
|
is_torch_xla_available,
|
|
logging,
|
|
replace_example_docstring,
|
|
scale_lora_layers,
|
|
unscale_lora_layers,
|
|
)
|
|
from diffusers.utils.torch_utils import randn_tensor
|
|
|
|
|
|
try:
|
|
from ligo.segments import segment
|
|
except ImportError:
|
|
raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline")
|
|
|
|
if is_invisible_watermark_available():
|
|
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
|
|
|
if is_torch_xla_available():
|
|
import torch_xla.core.xla_model as xm
|
|
|
|
XLA_AVAILABLE = True
|
|
else:
|
|
XLA_AVAILABLE = False
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
# Usage example injected into `__call__`'s docstring by `replace_example_docstring`.
# It demonstrates this tiling pipeline: `prompt` is a grid (list of rows) with one prompt per tile.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import DiffusionPipeline

        >>> pipe = DiffusionPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-xl-base-1.0",
        ...     custom_pipeline="mixture_tiling_sdxl",
        ...     torch_dtype=torch.float16,
        ... )
        >>> pipe = pipe.to("cuda")

        >>> # One row of three horizontally overlapping tiles, each with its own prompt.
        >>> image = pipe(
        ...     prompt=[
        ...         [
        ...             "A charming house in the countryside, sunset lighting, highly detailed",
        ...             "A dirt road in the countryside crossing pastures, sunset lighting, highly detailed",
        ...             "An old rusty giant robot lying on a dirt road, sunset lighting, highly detailed",
        ...         ]
        ...     ],
        ...     tile_height=1024,
        ...     tile_width=1280,
        ...     tile_row_overlap=0,
        ...     tile_col_overlap=256,
        ...     guidance_scale_tiles=[[7, 7, 7]],
        ...     height=1024,
        ...     width=3328,
        ... ).images[0]
        ```
"""
|
|
|
|
|
|
def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
|
|
"""Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image
|
|
|
|
Returns a tuple with:
|
|
- Starting coordinates of rows in pixel space
|
|
- Ending coordinates of rows in pixel space
|
|
- Starting coordinates of columns in pixel space
|
|
- Ending coordinates of columns in pixel space
|
|
"""
|
|
px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap)
|
|
px_row_end = px_row_init + tile_height
|
|
px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap)
|
|
px_col_end = px_col_init + tile_width
|
|
return px_row_init, px_row_end, px_col_init, px_col_end
|
|
|
|
|
|
def _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end):
|
|
"""Translates coordinates in pixel space to coordinates in latent space"""
|
|
return px_row_init // 8, px_row_end // 8, px_col_init // 8, px_col_end // 8
|
|
|
|
|
|
def _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
|
|
"""Given a tile row and column numbers returns the range of latents affected by that tiles in the overall image
|
|
|
|
Returns a tuple with:
|
|
- Starting coordinates of rows in latent space
|
|
- Ending coordinates of rows in latent space
|
|
- Starting coordinates of columns in latent space
|
|
- Ending coordinates of columns in latent space
|
|
"""
|
|
px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
|
|
tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
|
|
)
|
|
return _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end)
|
|
|
|
|
|
def _tile2latent_exclusive_indices(
    tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, rows, columns
):
    """Given a tile row and column numbers returns the range of latents affected only by that tile in the overall image

    The exclusive latent rows are the tile's latent rows minus the latent rows of every *other grid row*;
    the exclusive latent columns are the tile's latent columns minus those of every *other grid column*.
    The two axes are clipped independently, so single-row or single-column grids are handled as well
    (previously only diagonal neighbours were subtracted, which left 1xN / Nx1 grids unclipped).

    Returns a tuple with:
    - Starting coordinates of rows in latent space
    - Ending coordinates of rows in latent space
    - Starting coordinates of columns in latent space
    - Ending coordinates of columns in latent space
    """
    row_init, row_end, col_init, col_end = _tile2latent_indices(
        tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
    )
    row_segment = segment(row_init, row_end)
    col_segment = segment(col_init, col_end)

    # Row ranges depend only on the grid row, so probing with the tile's own column is sufficient.
    for row in range(rows):
        if row == tile_row:
            continue
        other_row_init, other_row_end, _, _ = _tile2latent_indices(
            row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
        )
        row_segment = row_segment - segment(other_row_init, other_row_end)

    # Column ranges depend only on the grid column, so probing with the tile's own row is sufficient.
    for column in range(columns):
        if column == tile_col:
            continue
        _, _, other_col_init, other_col_end = _tile2latent_indices(
            tile_row, column, tile_width, tile_height, tile_row_overlap, tile_col_overlap
        )
        col_segment = col_segment - segment(other_col_init, other_col_end)

    return row_segment[0], row_segment[1], col_segment[0], col_segment[1]
|
|
|
|
|
|
def _get_crops_coords_list(num_rows, num_cols, output_width):
|
|
"""
|
|
Generates a list of lists of `crops_coords_top_left` tuples for focusing on
|
|
different horizontal parts of an image, and repeats this list for the specified
|
|
number of rows in the output structure.
|
|
|
|
This function calculates `crops_coords_top_left` tuples to create horizontal
|
|
focus variations (like left, center, right focus) based on `output_width`
|
|
and `num_cols` (which represents the number of horizontal focus points/columns).
|
|
It then repeats the *list* of these horizontal focus tuples `num_rows` times to
|
|
create the final list of lists output structure.
|
|
|
|
Args:
|
|
num_rows (int): The desired number of rows in the output list of lists.
|
|
This determines how many times the list of horizontal
|
|
focus variations will be repeated.
|
|
num_cols (int): The number of horizontal focus points (columns) to generate.
|
|
This determines how many horizontal focus variations are
|
|
created based on dividing the `output_width`.
|
|
output_width (int): The desired width of the output image.
|
|
|
|
Returns:
|
|
list[list[tuple[int, int]]]: A list of lists of tuples. Each inner list
|
|
contains `num_cols` tuples of `(ctop, cleft)`,
|
|
representing horizontal focus points. The outer list
|
|
contains `num_rows` such inner lists.
|
|
"""
|
|
crops_coords_list = []
|
|
if num_cols <= 0:
|
|
crops_coords_list = []
|
|
elif num_cols == 1:
|
|
crops_coords_list = [(0, 0)]
|
|
else:
|
|
section_width = output_width / num_cols
|
|
for i in range(num_cols):
|
|
cleft = int(round(i * section_width))
|
|
crops_coords_list.append((0, cleft))
|
|
|
|
result_list = []
|
|
for _ in range(num_rows):
|
|
result_list.append(list(crops_coords_list))
|
|
|
|
return result_list
|
|
|
|
|
|
|
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Blend the guided noise prediction with a copy rescaled to the text-conditioned prediction's spread.

    This follows Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    1. Compute the per-sample standard deviation of both tensors over every dimension except the first.
    2. Scale ``noise_cfg`` by ``std(noise_pred_text) / std(noise_cfg)``.
    3. Return ``guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg``.

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Interpolation weight between the rescaled and the original prediction.

    Returns:
        `torch.Tensor`: The blended noise prediction.
    """

    def _per_sample_std(tensor):
        # Reduce over all non-batch dimensions, keeping them for broadcasting.
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    rescaled = noise_cfg * (_per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg))
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
|
|
|
|
|
|
|
|
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure ``scheduler`` via ``set_timesteps`` and return its resulting schedule.

    Exactly one scheduling source is used:

    - custom ``timesteps``;
    - custom ``sigmas``;
    - otherwise ``num_inference_steps``.

    Custom schedules are only forwarded when the scheduler's ``set_timesteps`` signature declares the
    matching parameter. Extra ``kwargs`` are passed through to ``set_timesteps``.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps, used only when neither ``timesteps`` nor ``sigmas`` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to ``set_timesteps`` as ``device``.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule. Mutually exclusive with ``sigmas``.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule. Mutually exclusive with ``timesteps``.

    Returns:
        `Tuple[torch.Tensor, int]`: ``(scheduler.timesteps, number_of_steps)``; for custom schedules the step
        count is ``len(scheduler.timesteps)``.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are given, or the scheduler does not accept the
            requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Default path: the scheduler derives its own schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    custom_schedule = scheduler.timesteps
    return custom_schedule, len(custom_schedule)
|
|
|
|
|
|
class StableDiffusionXLTilingPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    FromSingleFileMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
):
    r"""
    Pipeline for tiled text-to-image generation using Stable Diffusion XL.

    The output image is described as a grid of overlapping tiles. `prompt` is a list of rows, each row a list
    of per-tile prompts. Tile size and overlap are controlled by `tile_height`, `tile_width`,
    `tile_row_overlap` and `tile_col_overlap`. A gaussian weight mask (see `_gaussian_weights`) is used for
    tile contributions.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion XL uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_encoder_2 ([`CLIPTextModelWithProjection`]):
            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
            specifically the
            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
            variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer_2 (`CLIPTokenizer`):
            Second Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
            `stabilityai/stable-diffusion-xl-base-1-0`.
        add_watermarker (`bool`, *optional*):
            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
            watermark output images. If not defined, it will default to True if the package is installed, otherwise no
            watermarker will be used.
    """

    # Component order presumably consumed by `enable_model_cpu_offload` (DiffusionPipeline) — confirm there.
    # NOTE(review): "image_encoder" is listed but never registered in `__init__`; presumably skipped when absent.
    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
    # Components that may be None; `encode_prompt` branches on `self.tokenizer` / `self.text_encoder` being None.
    _optional_components = [
        "tokenizer",
        "tokenizer_2",
        "text_encoder",
        "text_encoder_2",
    ]
|
|
|
|
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
):
    """Register the SDXL components and derive helper settings.

    Sets ``vae_scale_factor`` (from the VAE block count, falling back to 8), an image processor, the
    default UNet sample size (falling back to 128), and an optional output watermarker.

    Args:
        force_zeros_for_empty_prompt (`bool`, defaults to `True`):
            Stored in the pipeline config; consulted by ``encode_prompt``.
        add_watermarker (`bool`, *optional*):
            Whether to watermark outputs. ``None`` means "use it if ``invisible_watermark`` is installed".

    Raises:
        ImportError: If ``add_watermarker=True`` is requested but the ``invisible_watermark`` package is
            not installed.
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=scheduler,
    )
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    self.default_sample_size = (
        self.unet.config.sample_size
        if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
        else 128
    )

    watermark_available = is_invisible_watermark_available()
    add_watermarker = add_watermarker if add_watermarker is not None else watermark_available

    if add_watermarker and not watermark_available:
        # `StableDiffusionXLWatermarker` is only imported at module level when the package is present, so an
        # explicit request without it used to surface as a confusing NameError.
        raise ImportError(
            "`add_watermarker=True` requires the `invisible-watermark` package. "
            "Install it or pass `add_watermarker=False`."
        )

    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None
|
|
|
|
class SeedTilesMode(Enum):
    """Accepted ways to re-seed the latents of an individual tile.

    The string values (``"full"``, ``"exclusive"``) are what callers pass as ``seed_tiles_mode``.
    ``EXCLUSIVE`` presumably restricts re-seeding to the region returned by
    ``_tile2latent_exclusive_indices`` (latents touched only by that tile); confirm in ``__call__``.
    """

    # Member order is significant if callers enumerate the modes; keep FULL first.
    FULL = "full"
    EXCLUSIVE = "exclusive"
|
|
|
|
def encode_prompt(
    self,
    prompt: Optional[Union[str, List[str]]],
    prompt_2: Optional[Union[str, List[str]]] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
            NOTE(review): pooled embeddings are only computed when `prompt_embeds` is None, so callers passing
            `prompt_embeds` must also pass this; otherwise `.repeat` near the end is called on None.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)`,
        each repeated `num_images_per_prompt` times along the batch dimension. The negative entries stay `None`
        when guidance is disabled and none were supplied.

    Raises:
        TypeError: If `negative_prompt` and the (processed) prompt have different types.
        ValueError: If `negative_prompt`'s batch size differs from the prompt's.
    """
    device = device or self._execution_device

    # Record the LoRA scale and apply it to both text encoders for the duration of this call
    # (undone via `unscale_lora_layers` at the end when using the PEFT backend).
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale

        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    # Normalize a single string to a one-element batch.
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # When the first tokenizer/encoder is absent, only the second pair is used.
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        # NOTE(review): the loop variable rebinds `prompt`; the later `type(prompt)` check against
        # `negative_prompt` therefore sees the last processed prompt, not the caller's argument.
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Warn about the tokens dropped by truncation to `model_max_length`.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

            # Take the first 2-D output (presumably the projected/pooled embedding of `text_encoder_2`)
            # as the pooled embedding; 3-D outputs are skipped.
            if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
                pooled_prompt_embeds = prompt_embeds[0]

            if clip_skip is None:
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # Offset by 2: with no skip the penultimate hidden state ([-2]) is used.
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

            prompt_embeds_list.append(prompt_embeds)

        # Concatenate both encoders' hidden states along the feature dimension.
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # Unconditional embeddings for classifier-free guidance: either zeros (config-controlled) or encoded
    # from `negative_prompt` / `negative_prompt_2`.
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # Broadcast a single negative prompt string to the whole batch.
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

            # Pad/truncate to the positive embeddings' sequence length so the two can be batched together.
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )

            # Same pooled-output selection as the positive branch.
            if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            # NOTE(review): always the penultimate layer; `clip_skip` is not applied to negatives.
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    # Duplicate embeddings for each image per prompt (repeat along the sequence axis, then reshape).
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if do_classifier_free_guidance:
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    # Undo the LoRA scaling applied at the start (PEFT backend only).
    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
|
|
|
|
|
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect optional keyword arguments for ``self.scheduler.step``.

    ``eta`` and ``generator`` are included only if the scheduler's ``step`` signature declares a
    parameter of that name, so schedulers without them are not passed unexpected keywords.

    Returns:
        dict: A subset of ``{"eta": eta, "generator": generator}`` (in that order).
    """
    step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
|
|
|
|
def check_inputs(self, prompt, height, width, grid_cols, seed_tiles_mode, tiles_mode):
|
|
if height % 8 != 0 or width % 8 != 0:
|
|
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
|
|
|
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
|
|
|
if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt):
|
|
raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}")
|
|
|
|
if not all(len(row) == grid_cols for row in prompt):
|
|
raise ValueError("All prompt rows must have the same number of prompt columns")
|
|
|
|
if not isinstance(seed_tiles_mode, str) and (
|
|
not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)
|
|
):
|
|
raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
|
|
|
|
if any(mode not in tiles_mode for row in seed_tiles_mode for mode in row):
|
|
raise ValueError(f"Seed tiles mode must be one of {tiles_mode}")
|
|
|
|
def _get_add_time_ids(
|
|
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
|
):
|
|
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
|
|
|
passed_add_embed_dim = (
|
|
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
|
)
|
|
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
|
|
|
if expected_add_embed_dim != passed_add_embed_dim:
|
|
raise ValueError(
|
|
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
|
)
|
|
|
|
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
|
return add_time_ids
|
|
|
|
def _gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype):
|
|
"""Generates a gaussian mask of weights for tile contributions"""
|
|
import numpy as np
|
|
from numpy import exp, pi, sqrt
|
|
|
|
latent_width = tile_width // 8
|
|
latent_height = tile_height // 8
|
|
|
|
var = 0.01
|
|
midpoint = (latent_width - 1) / 2
|
|
x_probs = [
|
|
exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)
|
|
for x in range(latent_width)
|
|
]
|
|
midpoint = latent_height / 2
|
|
y_probs = [
|
|
exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)
|
|
for y in range(latent_height)
|
|
]
|
|
|
|
weights_np = np.outer(y_probs, x_probs)
|
|
weights_torch = torch.tensor(weights_np, device=device)
|
|
weights_torch = weights_torch.to(dtype)
|
|
return torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))
|
|
|
|
def upcast_vae(self):
    """Move the VAE to float32, optionally keeping a few decoder parts in the previous dtype.

    The VAE's dtype is remembered, the whole VAE is cast to ``torch.float32``, and then — only if the
    first mid-block attention processor is an ``AttnProcessor2_0``, ``XFormersAttnProcessor`` or
    ``FusedAttnProcessor2_0`` — ``post_quant_conv``, ``decoder.conv_in`` and ``decoder.mid_block`` are
    cast back to the remembered dtype.
    """
    previous_dtype = self.vae.dtype
    self.vae.to(dtype=torch.float32)

    efficient_processors = (
        AttnProcessor2_0,
        XFormersAttnProcessor,
        FusedAttnProcessor2_0,
    )
    mid_processor = self.vae.decoder.mid_block.attentions[0].processor
    if not isinstance(mid_processor, efficient_processors):
        return

    for module in (self.vae.post_quant_conv, self.vae.decoder.conv_in, self.vae.decoder.mid_block):
        module.to(previous_dtype)
|
|
|
|
|
|
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Sinusoidal embedding of guidance scales.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    ``w`` is multiplied by 1000 and projected onto ``embedding_dim // 2`` geometrically spaced frequencies.
    The output is ``[sin, cos]`` of the resulting angles, zero-padded by one column when ``embedding_dim``
    is odd.

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scales.
        embedding_dim (`int`, *optional*, defaults to 512):
            Width of the produced embedding.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Shape `(len(w), embedding_dim)`.
    """
    assert len(w.shape) == 1
    scaled = w * 1000.0

    half = embedding_dim // 2
    log_step = torch.log(torch.tensor(10000.0)) / (half - 1)
    freqs = torch.exp(torch.arange(half, dtype=dtype) * -log_step)
    angles = scaled.to(dtype)[:, None] * freqs[None, :]

    emb = torch.cat([angles.sin(), angles.cos()], dim=1)
    if embedding_dim % 2 == 1:
        # Odd widths get one trailing zero column.
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
|
|
|
|
@property
def guidance_scale(self):
    """Classifier-free guidance scale, read from ``self._guidance_scale``."""
    scale = self._guidance_scale
    return scale
|
|
|
|
@property
def clip_skip(self):
    """Number of CLIP layers to skip, read from ``self._clip_skip``."""
    return getattr(self, "_clip_skip")
|
|
|
|
|
|
|
|
|
|
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active.

    It is active only when ``self._guidance_scale`` exceeds 1 and the UNet config has no
    ``time_cond_proj_dim``. The UNet config is not read when the scale is at most 1.
    """
    guidance_enabled = self._guidance_scale > 1
    if not guidance_enabled:
        return guidance_enabled
    return self.unet.config.time_cond_proj_dim is None
|
|
|
|
@property
def cross_attention_kwargs(self):
    """Extra cross-attention keyword arguments, read from ``self._cross_attention_kwargs``."""
    kwargs = self._cross_attention_kwargs
    return kwargs
|
|
|
|
@property
def num_timesteps(self):
    """Number of timesteps, read from ``self._num_timesteps``."""
    return getattr(self, "_num_timesteps")
|
|
|
|
@property
def interrupt(self):
    """Interrupt flag, read from ``self._interrupt``."""
    flag = self._interrupt
    return flag
|
|
|
|
@torch.no_grad()
|
|
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
|
def __call__(
|
|
self,
|
|
prompt: Union[str, List[str]] = None,
|
|
height: Optional[int] = None,
|
|
width: Optional[int] = None,
|
|
num_inference_steps: int = 50,
|
|
guidance_scale: float = 5.0,
|
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
|
num_images_per_prompt: Optional[int] = 1,
|
|
eta: float = 0.0,
|
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
|
output_type: Optional[str] = "pil",
|
|
return_dict: bool = True,
|
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
|
original_size: Optional[Tuple[int, int]] = None,
|
|
crops_coords_top_left: Optional[List[List[Tuple[int, int]]]] = None,
|
|
target_size: Optional[Tuple[int, int]] = None,
|
|
negative_original_size: Optional[Tuple[int, int]] = None,
|
|
negative_crops_coords_top_left: Optional[List[List[Tuple[int, int]]]] = None,
|
|
negative_target_size: Optional[Tuple[int, int]] = None,
|
|
clip_skip: Optional[int] = None,
|
|
tile_height: Optional[int] = 1024,
|
|
tile_width: Optional[int] = 1024,
|
|
tile_row_overlap: Optional[int] = 128,
|
|
tile_col_overlap: Optional[int] = 128,
|
|
guidance_scale_tiles: Optional[List[List[float]]] = None,
|
|
seed_tiles: Optional[List[List[int]]] = None,
|
|
seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full",
|
|
seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,
|
|
**kwargs,
|
|
):
|
|
r"""
|
|
Function invoked when calling the pipeline for generation.
|
|
|
|
Args:
|
|
prompt (`str` or `List[str]`, *optional*):
|
|
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
|
instead.
|
|
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
|
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
|
Anything below 512 pixels won't work well for
|
|
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
|
and checkpoints that are not specifically fine-tuned on low resolutions.
|
|
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
|
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
|
Anything below 512 pixels won't work well for
|
|
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
|
and checkpoints that are not specifically fine-tuned on low resolutions.
|
|
num_inference_steps (`int`, *optional*, defaults to 50):
|
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
|
expense of slower inference.
|
|
guidance_scale (`float`, *optional*, defaults to 5.0):
|
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
|
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
|
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
|
usually at the expense of lower image quality.
|
|
negative_prompt (`str` or `List[str]`, *optional*):
|
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
|
less than `1`).
|
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
|
The number of images to generate per prompt.
|
|
eta (`float`, *optional*, defaults to 0.0):
|
|
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
|
[`schedulers.DDIMScheduler`], will be ignored for others.
|
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
|
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
|
to make generation deterministic.
|
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
|
The output format of the generate image. Choose between
|
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
|
return_dict (`bool`, *optional*, defaults to `True`):
|
|
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
|
of a plain tuple.
|
|
cross_attention_kwargs (`dict`, *optional*):
|
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
|
`self.processor` in
|
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
|
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
|
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
|
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
|
explained in section 2.2 of
|
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
|
crops_coords_top_left (`List[List[Tuple[int, int]]]`, *optional*, defaults to (0, 0)):
|
|
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
|
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
|
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
|
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
|
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
|
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
|
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
|
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
|
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
|
micro-conditioning as explained in section 2.2 of
|
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
|
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
|
negative_crops_coords_top_left (`List[List[Tuple[int, int]]]`, *optional*, defaults to (0, 0)):
|
|
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
|
micro-conditioning as explained in section 2.2 of
|
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
|
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
|
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
|
To negatively condition the generation process based on a target image resolution. It should be as same
|
|
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
|
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
|
tile_height (`int`, *optional*, defaults to 1024):
|
|
Height of each grid tile in pixels.
|
|
tile_width (`int`, *optional*, defaults to 1024):
|
|
Width of each grid tile in pixels.
|
|
tile_row_overlap (`int`, *optional*, defaults to 128):
|
|
Number of overlapping pixels between tiles in consecutive rows.
|
|
tile_col_overlap (`int`, *optional*, defaults to 128):
|
|
Number of overlapping pixels between tiles in consecutive columns.
|
|
guidance_scale_tiles (`List[List[float]]`, *optional*):
|
|
Specific weights for classifier-free guidance in each tile. If `None`, the value provided in `guidance_scale` will be used.
|
|
seed_tiles (`List[List[int]]`, *optional*):
|
|
Specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard `generator` parameter.
|
|
seed_tiles_mode (`Union[str, List[List[str]]]`, *optional*, defaults to `"full"`):
|
|
Mode for seeding tiles, can be `"full"` or `"exclusive"`. If `"full"`, all the latents affected by the tile will be overridden. If `"exclusive"`, only the latents that are exclusively affected by this tile (and no other tiles) will be overridden.
|
|
seed_reroll_regions (`List[Tuple[int, int, int, int, int]]`, *optional*):
|
|
A list of tuples in the form of `(start_row, end_row, start_column, end_column, seed)` defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over `seed_tiles`.
|
|
**kwargs (`Dict[str, Any]`, *optional*):
|
|
Additional optional keyword arguments to be passed to the `unet.__call__` and `scheduler.step` functions.
|
|
|
|
Examples:
|
|
|
|
Returns:
|
|
[`~pipelines.stable_diffusion_xl.StableDiffusionXLTilingPipelineOutput`] or `tuple`:
|
|
[`~pipelines.stable_diffusion_xl.StableDiffusionXLTilingPipelineOutput`] if `return_dict` is True, otherwise a
|
|
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
|
"""
|
|
|
|
|
|
height = height or self.default_sample_size * self.vae_scale_factor
|
|
width = width or self.default_sample_size * self.vae_scale_factor
|
|
|
|
original_size = original_size or (height, width)
|
|
target_size = target_size or (height, width)
|
|
negative_original_size = negative_original_size or (height, width)
|
|
negative_target_size = negative_target_size or (height, width)
|
|
|
|
self._guidance_scale = guidance_scale
|
|
self._clip_skip = clip_skip
|
|
self._cross_attention_kwargs = cross_attention_kwargs
|
|
self._interrupt = False
|
|
|
|
grid_rows = len(prompt)
|
|
grid_cols = len(prompt[0])
|
|
|
|
tiles_mode = [mode.value for mode in self.SeedTilesMode]
|
|
|
|
if isinstance(seed_tiles_mode, str):
|
|
seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
|
|
|
|
|
|
self.check_inputs(
|
|
prompt,
|
|
height,
|
|
width,
|
|
grid_cols,
|
|
seed_tiles_mode,
|
|
tiles_mode,
|
|
)
|
|
|
|
if seed_reroll_regions is None:
|
|
seed_reroll_regions = []
|
|
|
|
batch_size = 1
|
|
|
|
device = self._execution_device
|
|
|
|
|
|
crops_coords_top_left = _get_crops_coords_list(grid_rows, grid_cols, tile_width)
|
|
if negative_original_size is not None and negative_target_size is not None:
|
|
negative_crops_coords_top_left = _get_crops_coords_list(grid_rows, grid_cols, tile_width)
|
|
|
|
|
|
height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)
|
|
width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)
|
|
|
|
|
|
lora_scale = (
|
|
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
|
)
|
|
text_embeddings = [
|
|
[
|
|
self.encode_prompt(
|
|
prompt=col,
|
|
device=device,
|
|
num_images_per_prompt=num_images_per_prompt,
|
|
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
|
negative_prompt=negative_prompt,
|
|
prompt_embeds=None,
|
|
negative_prompt_embeds=None,
|
|
pooled_prompt_embeds=None,
|
|
negative_pooled_prompt_embeds=None,
|
|
lora_scale=lora_scale,
|
|
clip_skip=self.clip_skip,
|
|
)
|
|
for col in row
|
|
]
|
|
for row in prompt
|
|
]
|
|
|
|
|
|
latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
|
|
dtype = text_embeddings[0][0][0].dtype
|
|
latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
|
|
|
|
|
|
if seed_tiles is not None:
|
|
for row in range(grid_rows):
|
|
for col in range(grid_cols):
|
|
if (seed_tile := seed_tiles[row][col]) is not None:
|
|
mode = seed_tiles_mode[row][col]
|
|
if mode == self.SeedTilesMode.FULL.value:
|
|
row_init, row_end, col_init, col_end = _tile2latent_indices(
|
|
row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
|
|
)
|
|
else:
|
|
row_init, row_end, col_init, col_end = _tile2latent_exclusive_indices(
|
|
row,
|
|
col,
|
|
tile_width,
|
|
tile_height,
|
|
tile_row_overlap,
|
|
tile_col_overlap,
|
|
grid_rows,
|
|
grid_cols,
|
|
)
|
|
tile_generator = torch.Generator(device).manual_seed(seed_tile)
|
|
tile_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)
|
|
latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(
|
|
tile_shape, generator=tile_generator, device=device
|
|
)
|
|
|
|
|
|
for row_init, row_end, col_init, col_end, seed_reroll in seed_reroll_regions:
|
|
row_init, row_end, col_init, col_end = _pixel2latent_indices(
|
|
row_init, row_end, col_init, col_end
|
|
)
|
|
reroll_generator = torch.Generator(device).manual_seed(seed_reroll)
|
|
region_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)
|
|
latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(
|
|
region_shape, generator=reroll_generator, device=device
|
|
)
|
|
|
|
|
|
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
|
|
extra_set_kwargs = {}
|
|
if accepts_offset:
|
|
extra_set_kwargs["offset"] = 1
|
|
timesteps, num_inference_steps = retrieve_timesteps(
|
|
self.scheduler, num_inference_steps, device, None, None, **extra_set_kwargs
|
|
)
|
|
|
|
|
|
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
|
latents = latents * self.scheduler.sigmas[0]
|
|
|
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
|
|
|
|
|
|
|
embeddings_and_added_time = []
|
|
for row in range(grid_rows):
|
|
addition_embed_type_row = []
|
|
for col in range(grid_cols):
|
|
|
|
prompt_embeds = text_embeddings[row][col][0]
|
|
negative_prompt_embeds = text_embeddings[row][col][1]
|
|
pooled_prompt_embeds = text_embeddings[row][col][2]
|
|
negative_pooled_prompt_embeds = text_embeddings[row][col][3]
|
|
|
|
add_text_embeds = pooled_prompt_embeds
|
|
if self.text_encoder_2 is None:
|
|
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
|
else:
|
|
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
|
add_time_ids = self._get_add_time_ids(
|
|
original_size,
|
|
crops_coords_top_left[row][col],
|
|
target_size,
|
|
dtype=prompt_embeds.dtype,
|
|
text_encoder_projection_dim=text_encoder_projection_dim,
|
|
)
|
|
if negative_original_size is not None and negative_target_size is not None:
|
|
negative_add_time_ids = self._get_add_time_ids(
|
|
negative_original_size,
|
|
negative_crops_coords_top_left[row][col],
|
|
negative_target_size,
|
|
dtype=prompt_embeds.dtype,
|
|
text_encoder_projection_dim=text_encoder_projection_dim,
|
|
)
|
|
else:
|
|
negative_add_time_ids = add_time_ids
|
|
|
|
if self.do_classifier_free_guidance:
|
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
|
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
|
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
|
|
|
prompt_embeds = prompt_embeds.to(device)
|
|
add_text_embeds = add_text_embeds.to(device)
|
|
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
|
addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
|
|
embeddings_and_added_time.append(addition_embed_type_row)
|
|
|
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
|
|
|
|
|
tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size, device, torch.float32)
|
|
|
|
|
|
self._num_timesteps = len(timesteps)
|
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
|
for i, t in enumerate(timesteps):
|
|
|
|
noise_preds = []
|
|
for row in range(grid_rows):
|
|
noise_preds_row = []
|
|
for col in range(grid_cols):
|
|
if self.interrupt:
|
|
continue
|
|
px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
|
|
row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
|
|
)
|
|
tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
|
|
|
|
latent_model_input = (
|
|
torch.cat([tile_latents] * 2) if self.do_classifier_free_guidance else tile_latents
|
|
)
|
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
|
|
|
|
|
added_cond_kwargs = {
|
|
"text_embeds": embeddings_and_added_time[row][col][1],
|
|
"time_ids": embeddings_and_added_time[row][col][2],
|
|
}
|
|
with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):
|
|
noise_pred = self.unet(
|
|
latent_model_input,
|
|
t,
|
|
encoder_hidden_states=embeddings_and_added_time[row][col][0],
|
|
cross_attention_kwargs=self.cross_attention_kwargs,
|
|
added_cond_kwargs=added_cond_kwargs,
|
|
return_dict=False,
|
|
)[0]
|
|
|
|
|
|
if self.do_classifier_free_guidance:
|
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
|
guidance = (
|
|
guidance_scale
|
|
if guidance_scale_tiles is None or guidance_scale_tiles[row][col] is None
|
|
else guidance_scale_tiles[row][col]
|
|
)
|
|
noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
|
|
noise_preds_row.append(noise_pred_tile)
|
|
noise_preds.append(noise_preds_row)
|
|
|
|
|
|
noise_pred = torch.zeros(latents.shape, device=device)
|
|
contributors = torch.zeros(latents.shape, device=device)
|
|
|
|
|
|
for row in range(grid_rows):
|
|
for col in range(grid_cols):
|
|
px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
|
|
row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
|
|
)
|
|
noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += (
|
|
noise_preds[row][col] * tile_weights
|
|
)
|
|
contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights
|
|
|
|
|
|
noise_pred /= contributors
|
|
noise_pred = noise_pred.to(dtype)
|
|
|
|
|
|
latents_dtype = latents.dtype
|
|
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
|
if latents.dtype != latents_dtype:
|
|
if torch.backends.mps.is_available():
|
|
|
|
latents = latents.to(latents_dtype)
|
|
|
|
|
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
|
progress_bar.update()
|
|
|
|
if XLA_AVAILABLE:
|
|
xm.mark_step()
|
|
|
|
if not output_type == "latent":
|
|
|
|
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
|
|
|
if needs_upcasting:
|
|
self.upcast_vae()
|
|
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
|
elif latents.dtype != self.vae.dtype:
|
|
if torch.backends.mps.is_available():
|
|
|
|
self.vae = self.vae.to(latents.dtype)
|
|
|
|
|
|
|
|
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
|
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
|
if has_latents_mean and has_latents_std:
|
|
latents_mean = (
|
|
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
|
)
|
|
latents_std = (
|
|
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
|
)
|
|
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
|
else:
|
|
latents = latents / self.vae.config.scaling_factor
|
|
|
|
image = self.vae.decode(latents, return_dict=False)[0]
|
|
|
|
|
|
if needs_upcasting:
|
|
self.vae.to(dtype=torch.float16)
|
|
else:
|
|
image = latents
|
|
|
|
if not output_type == "latent":
|
|
|
|
if self.watermark is not None:
|
|
image = self.watermark.apply_watermark(image)
|
|
|
|
image = self.image_processor.postprocess(image, output_type=output_type)
|
|
|
|
|
|
self.maybe_free_model_hooks()
|
|
|
|
if not return_dict:
|
|
return (image,)
|
|
|
|
return StableDiffusionXLPipelineOutput(images=image)
|
|
|