diff --git "a/v0.33.0/pipeline_faithdiff_stable_diffusion_xl.py" "b/v0.33.0/pipeline_faithdiff_stable_diffusion_xl.py" new file mode 100644--- /dev/null +++ "b/v0.33.0/pipeline_faithdiff_stable_diffusion_xl.py" @@ -0,0 +1,2269 @@ +# Copyright 2025 Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab Team +# and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +import PIL.Image +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + PeftAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + UNet2DConditionLoadersMixin, +) +from diffusers.models import AutoencoderKL +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + FusedAttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from diffusers.schedulers import DDPMScheduler, KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_version, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.outputs import BaseOutput +from diffusers.utils.torch_utils import randn_tensor + + +if is_invisible_watermark_available(): + from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import random + >>> import numpy as np + >>> import torch + >>> from diffusers import DiffusionPipeline, AutoencoderKL, UniPCMultistepScheduler + >>> from huggingface_hub import hf_hub_download + >>> from diffusers.utils import load_image + >>> from PIL import Image + >>> + >>> device = "cuda" + >>> dtype = torch.float16 + >>> MAX_SEED = np.iinfo(np.int32).max + >>> + >>> # Download weights for additional unet layers + >>> model_file = hf_hub_download( + ... "jychen9811/FaithDiff", + ... filename="FaithDiff.bin", local_dir="./proc_data/faithdiff", local_dir_use_symlinks=False + ... ) + >>> + >>> # Initialize the models and pipeline + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype) + >>> + >>> model_id = "SG161222/RealVisXL_V4.0" + >>> pipe = DiffusionPipeline.from_pretrained( + ... model_id, + ... torch_dtype=dtype, + ... vae=vae, + ... unet=None, #<- Do not load with original model. + ... custom_pipeline="mixture_tiling_sdxl", + ... use_safetensors=True, + ... variant="fp16", + ... ).to(device) + >>> + >>> # Here we need use pipeline internal unet model + >>> pipe.unet = pipe.unet_model.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True) + >>> + >>> # Load aditional layers to the model + >>> pipe.unet.load_additional_layers(weight_path="proc_data/faithdiff/FaithDiff.bin", dtype=dtype) + >>> + >>> # Enable vae tiling + >>> pipe.set_encoder_tile_settings() + >>> pipe.enable_vae_tiling() + >>> + >>> # Optimization + >>> pipe.enable_model_cpu_offload() + >>> + >>> # Set selected scheduler + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> + >>> #input params + >>> prompt = "The image features a woman in her 55s with blonde hair and a white shirt, smiling at the camera. She appears to be in a good mood and is wearing a white scarf around her neck. " + >>> upscale = 2 # scale here + >>> start_point = "lr" # or "noise" + >>> latent_tiled_overlap = 0.5 + >>> latent_tiled_size = 1024 + >>> + >>> # Load image + >>> lq_image = load_image("https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/woman.png") + >>> original_height = lq_image.height + >>> original_width = lq_image.width + >>> print(f"Current resolution: H:{original_height} x W:{original_width}") + >>> + >>> width = original_width * int(upscale) + >>> height = original_height * int(upscale) + >>> print(f"Final resolution: H:{height} x W:{width}") + >>> + >>> # Restoration + >>> image = lq_image.resize((width, height), Image.LANCZOS) + >>> input_image, width_init, height_init, width_now, height_now = pipe.check_image_size(image) + >>> + >>> generator = torch.Generator(device=device).manual_seed(random.randint(0, MAX_SEED)) + >>> gen_image = pipe(lr_img=input_image, + ... prompt = prompt, + ... num_inference_steps=20, + ... guidance_scale=5, + ... generator=generator, + ... start_point=start_point, + ... height = height_now, + ... width=width_now, + ... overlap=latent_tiled_overlap, + ... target_size=(latent_tiled_size, latent_tiled_size) + ... ).images[0] + >>> + >>> cropped_image = gen_image.crop((0, 0, width_init, height_init)) + >>> cropped_image.save("data/result.png") + ``` +""" + + +def zero_module(module): + """Zero out the parameters of a module and return it.""" + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class Encoder(nn.Module): + """Encoder layer of a variational autoencoder that encodes input into a latent representation.""" + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 4, + down_block_types: Tuple[str, ...] = ( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention: bool = True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + self.use_rgb = False + self.down_block_type = down_block_types + self.block_out_channels = block_out_channels + + self.tile_sample_min_size = 1024 + self.tile_latent_min_size = int(self.tile_sample_min_size / 8) + self.tile_overlap_factor = 0.25 + self.use_tiling = False + + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + ) + + self.gradient_checkpointing = False + + def to_rgb_init(self): + """Initialize layers to convert features to RGB.""" + self.to_rgbs = nn.ModuleList([]) + self.use_rgb = True + for i, down_block_type in enumerate(self.down_block_type): + output_channel = self.block_out_channels[i] + self.to_rgbs.append(nn.Conv2d(output_channel, 3, kernel_size=3, padding=1)) + + def enable_tiling(self): + """Enable tiling for large inputs.""" + self.use_tiling = True + + def encode(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """Encode the input tensor into a latent representation.""" + sample = self.conv_in(sample) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, use_reentrant=False + ) + else: + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + return sample + else: + for down_block in self.down_blocks: + sample = down_block(sample) + sample = self.mid_block(sample) + return sample + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + """Blend two tensors vertically with a smooth transition.""" + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + """Blend two tensors horizontally with a smooth transition.""" + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor: + """Encode the input tensor using tiling for large inputs.""" + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encode(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + return moments + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """Forward pass of the encoder, using tiling if enabled for large inputs.""" + if self.use_tiling and ( + sample.shape[-1] > self.tile_latent_min_size or sample.shape[-2] > self.tile_latent_min_size + ): + return self.tiled_encode(sample) + return self.encode(sample) + + +class ControlNetConditioningEmbedding(nn.Module): + """A small network to preprocess conditioning inputs, inspired by ControlNet.""" + + def __init__(self, conditioning_embedding_channels: int, conditioning_channels: int = 4): + super().__init__() + self.conv_in = nn.Conv2d(conditioning_channels, conditioning_channels, kernel_size=3, padding=1) + self.norm_in = nn.GroupNorm(num_channels=conditioning_channels, num_groups=32, eps=1e-6) + self.conv_out = zero_module( + nn.Conv2d(conditioning_channels, conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + """Process the conditioning input through the network.""" + conditioning = self.norm_in(conditioning) + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + embedding = self.conv_out(embedding) + return embedding + + +class QuickGELU(nn.Module): + """A fast approximation of the GELU activation function.""" + + def forward(self, x: torch.Tensor): + """Apply the QuickGELU activation to the input tensor.""" + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + """Apply LayerNorm and preserve the input dtype.""" + orig_type = x.dtype + ret = super().forward(x) + return ret.type(orig_type) + + +class ResidualAttentionBlock(nn.Module): + """A transformer-style block with self-attention and an MLP.""" + + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 2)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 2, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + """Apply self-attention to the input tensor.""" + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + """Forward pass through the residual attention block.""" + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """The output of UnifiedUNet2DConditionModel.""" + + sample: torch.FloatTensor = None + + +class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + """A unified 2D UNet model extending OriginalUNet2DConditionModel with custom functionality.""" + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads: int = 64, + ): + """Initialize the UnifiedUNet2DConditionModel.""" + super().__init__( + sample_size=sample_size, + in_channels=in_channels, + out_channels=out_channels, + center_input_sample=center_input_sample, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + down_block_types=down_block_types, + mid_block_type=mid_block_type, + up_block_types=up_block_types, + only_cross_attention=only_cross_attention, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + downsample_padding=downsample_padding, + mid_block_scale_factor=mid_block_scale_factor, + dropout=dropout, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + class_embed_type=class_embed_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + num_class_embeds=num_class_embeds, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + time_embedding_type=time_embedding_type, + time_embedding_dim=time_embedding_dim, + time_embedding_act_fn=time_embedding_act_fn, + timestep_post_act=timestep_post_act, + time_cond_proj_dim=time_cond_proj_dim, + conv_in_kernel=conv_in_kernel, + conv_out_kernel=conv_out_kernel, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + attention_type=attention_type, + class_embeddings_concat=class_embeddings_concat, + mid_block_only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + addition_embed_type_num_heads=addition_embed_type_num_heads, + ) + + # Additional attributes + self.denoise_encoder = None + self.information_transformer_layes = None + self.condition_embedding = None + self.agg_net = None + self.spatial_ch_projs = None + + def init_vae_encoder(self, dtype): + self.denoise_encoder = Encoder() + if dtype is not None: + self.denoise_encoder.dtype = dtype + + def init_information_transformer_layes(self): + num_trans_channel = 640 + num_trans_head = 8 + num_trans_layer = 2 + num_proj_channel = 320 + self.information_transformer_layes = nn.Sequential( + *[ResidualAttentionBlock(num_trans_channel, num_trans_head) for _ in range(num_trans_layer)] + ) + self.spatial_ch_projs = zero_module(nn.Linear(num_trans_channel, num_proj_channel)) + + def init_ControlNetConditioningEmbedding(self, channel=512): + self.condition_embedding = ControlNetConditioningEmbedding(320, channel) + + def init_extra_weights(self): + self.agg_net = nn.ModuleList() + + def load_additional_layers( + self, dtype: Optional[torch.dtype] = torch.float16, channel: int = 512, weight_path: Optional[str] = None + ): + """Load additional layers and weights from a file. + + Args: + weight_path (str): Path to the weight file. + dtype (torch.dtype, optional): Data type for the loaded weights. Defaults to torch.float16. + channel (int): Conditioning embedding channel out size. Defaults 512. + """ + if self.denoise_encoder is None: + self.init_vae_encoder(dtype) + + if self.information_transformer_layes is None: + self.init_information_transformer_layes() + + if self.condition_embedding is None: + self.init_ControlNetConditioningEmbedding(channel) + + if self.agg_net is None: + self.init_extra_weights() + + # Load weights if provided + if weight_path is not None: + state_dict = torch.load(weight_path, weights_only=False) + self.load_state_dict(state_dict, strict=True) + + # Move all modules to the same device and dtype as the model + device = next(self.parameters()).device + if dtype is not None or device is not None: + self.to(device=device, dtype=dtype or next(self.parameters()).dtype) + + def to(self, *args, **kwargs): + """Override to() to move all additional modules to the same device and dtype.""" + super().to(*args, **kwargs) + for module in [ + self.denoise_encoder, + self.information_transformer_layes, + self.condition_embedding, + self.agg_net, + self.spatial_ch_projs, + ]: + if module is not None: + module.to(*args, **kwargs) + return self + + def load_state_dict(self, state_dict, strict=True): + """Load state dictionary into the model. + + Args: + state_dict (dict): State dictionary to load. + strict (bool, optional): Whether to strictly enforce that all keys match. Defaults to True. + """ + core_dict = {} + additional_dicts = { + "denoise_encoder": {}, + "information_transformer_layes": {}, + "condition_embedding": {}, + "agg_net": {}, + "spatial_ch_projs": {}, + } + + for key, value in state_dict.items(): + if key.startswith("denoise_encoder."): + additional_dicts["denoise_encoder"][key[len("denoise_encoder.") :]] = value + elif key.startswith("information_transformer_layes."): + additional_dicts["information_transformer_layes"][key[len("information_transformer_layes.") :]] = value + elif key.startswith("condition_embedding."): + additional_dicts["condition_embedding"][key[len("condition_embedding.") :]] = value + elif key.startswith("agg_net."): + additional_dicts["agg_net"][key[len("agg_net.") :]] = value + elif key.startswith("spatial_ch_projs."): + additional_dicts["spatial_ch_projs"][key[len("spatial_ch_projs.") :]] = value + else: + core_dict[key] = value + + super().load_state_dict(core_dict, strict=False) + for module_name, module_dict in additional_dicts.items(): + module = getattr(self, module_name, None) + if module is not None and module_dict: + module.load_state_dict(module_dict, strict=strict) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + input_embedding: Optional[torch.Tensor] = None, + add_sample: bool = True, + return_dict: bool = True, + use_condition_embedding: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + """Forward pass prioritizing the original modified implementation. + + Args: + sample (torch.FloatTensor): The noisy input tensor with shape `(batch, channel, height, width)`. + timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input. + encoder_hidden_states (torch.Tensor): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (torch.Tensor, optional): Optional class labels for conditioning. + timestep_cond (torch.Tensor, optional): Conditional embeddings for timestep. + attention_mask (torch.Tensor, optional): An attention mask of shape `(batch, key_tokens)`. + cross_attention_kwargs (Dict[str, Any], optional): A kwargs dictionary for the AttentionProcessor. + added_cond_kwargs (Dict[str, torch.Tensor], optional): Additional embeddings to add to the UNet blocks. + down_block_additional_residuals (Tuple[torch.Tensor], optional): Residuals for down UNet blocks. + mid_block_additional_residual (torch.Tensor, optional): Residual for the middle UNet block. + down_intrablock_additional_residuals (Tuple[torch.Tensor], optional): Additional residuals within down blocks. + encoder_attention_mask (torch.Tensor, optional): A cross-attention mask of shape `(batch, sequence_length)`. + input_embedding (torch.Tensor, optional): Additional input embedding for preprocessing. + add_sample (bool): Whether to add the sample to the processed embedding. Defaults to True. + return_dict (bool): Whether to return a UNet2DConditionOutput. Defaults to True. + use_condition_embedding (bool): Whether to use the condition embedding. Defaults to True. + + Returns: + Union[UNet2DConditionOutput, Tuple]: The processed sample tensor, either as a UNet2DConditionOutput or tuple. + """ + default_overall_up_factor = 2**self.num_upsamplers + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + forward_upsample_size = True + break + + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + t_emb = self.get_time_embed(sample=sample, timestep=timestep) + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) + if class_emb is not None: + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + aug_emb = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": + aug_emb, hint = aug_emb + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + # 2. pre-process (following the original modified logic) + sample = self.conv_in(sample) # [B, 4, H, W] -> [B, 320, H, W] + if ( + input_embedding is not None + and self.condition_embedding is not None + and self.information_transformer_layes is not None + ): + if use_condition_embedding: + input_embedding = self.condition_embedding(input_embedding) # [B, 320, H, W] + batch_size, channel, height, width = input_embedding.shape + concat_feat = ( + torch.cat([sample, input_embedding], dim=1) + .view(batch_size, 2 * channel, height * width) + .transpose(1, 2) + ) + concat_feat = self.information_transformer_layes(concat_feat) + feat_alpha = self.spatial_ch_projs(concat_feat).transpose(1, 2).view(batch_size, channel, height, width) + sample = sample + feat_alpha if add_sample else feat_alpha # Update sample as in the original version + + # 2.5 GLIGEN position net (kept from the original version) + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down (continues the standard flow) + if cross_attention_kwargs is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + lora_scale = cross_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = down_intrablock_additional_residuals is not None + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + return UNet2DConditionOutput(sample=sample) + + +class LocalAttention: + """A class to handle local attention by splitting tensors into overlapping grids for processing.""" + + def __init__(self, kernel_size=None, overlap=0.5): + """Initialize the LocalAttention module. + + Args: + kernel_size (tuple[int, int], optional): Size of the grid (height, width). Defaults to None. + overlap (float): Overlap factor between adjacent grids (0.0 to 1.0). Defaults to 0.5. + """ + super().__init__() + self.kernel_size = kernel_size + self.overlap = overlap + + def grids_list(self, x): + """Split the input tensor into a list of non-overlapping grid patches. + + Args: + x (torch.Tensor): Input tensor of shape (batch, channels, height, width). + + Returns: + list[torch.Tensor]: List of tensor patches. + """ + b, c, h, w = x.shape + self.original_size = (b, c, h, w) + assert b == 1 + k1, k2 = self.kernel_size + if h < k1: + k1 = h + if w < k2: + k2 = w + num_row = (h - 1) // k1 + 1 + num_col = (w - 1) // k2 + 1 + self.nr = num_row + self.nc = num_col + + import math + + step_j = k2 if num_col == 1 else math.ceil(k2 * self.overlap) + step_i = k1 if num_row == 1 else math.ceil(k1 * self.overlap) + parts = [] + idxes = [] + i = 0 + last_i = False + while i < h and not last_i: + j = 0 + if i + k1 >= h: + i = h - k1 + last_i = True + last_j = False + while j < w and not last_j: + if j + k2 >= w: + j = w - k2 + last_j = True + parts.append(x[:, :, i : i + k1, j : j + k2]) + idxes.append({"i": i, "j": j}) + j = j + step_j + i = i + step_i + return parts + + def grids(self, x): + """Split the input tensor into overlapping grid patches and concatenate them. + + Args: + x (torch.Tensor): Input tensor of shape (batch, channels, height, width). + + Returns: + torch.Tensor: Concatenated tensor of all grid patches. + """ + b, c, h, w = x.shape + self.original_size = (b, c, h, w) + assert b == 1 + k1, k2 = self.kernel_size + if h < k1: + k1 = h + if w < k2: + k2 = w + self.tile_weights = self._gaussian_weights(k2, k1) + num_row = (h - 1) // k1 + 1 + num_col = (w - 1) // k2 + 1 + self.nr = num_row + self.nc = num_col + + import math + + step_j = k2 if num_col == 1 else math.ceil(k2 * self.overlap) + step_i = k1 if num_row == 1 else math.ceil(k1 * self.overlap) + parts = [] + idxes = [] + i = 0 + last_i = False + while i < h and not last_i: + j = 0 + if i + k1 >= h: + i = h - k1 + last_i = True + last_j = False + while j < w and not last_j: + if j + k2 >= w: + j = w - k2 + last_j = True + parts.append(x[:, :, i : i + k1, j : j + k2]) + idxes.append({"i": i, "j": j}) + j = j + step_j + i = i + step_i + self.idxes = idxes + return torch.cat(parts, dim=0) + + def _gaussian_weights(self, tile_width, tile_height): + """Generate a Gaussian weight mask for tile contributions. + + Args: + tile_width (int): Width of the tile. + tile_height (int): Height of the tile. + + Returns: + torch.Tensor: Gaussian weight tensor of shape (channels, height, width). + """ + import numpy as np + from numpy import exp, pi, sqrt + + latent_width = tile_width + latent_height = tile_height + var = 0.01 + midpoint = (latent_width - 1) / 2 + x_probs = [ + exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var) + for x in range(latent_width) + ] + midpoint = latent_height / 2 + y_probs = [ + exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var) + for y in range(latent_height) + ] + weights = np.outer(y_probs, x_probs) + return torch.tile(torch.tensor(weights, device=torch.device("cuda")), (4, 1, 1)) + + def grids_inverse(self, outs): + """Reconstruct the original tensor from processed grid patches with overlap blending. + + Args: + outs (torch.Tensor): Processed grid patches. + + Returns: + torch.Tensor: Reconstructed tensor of original size. + """ + preds = torch.zeros(self.original_size).to(outs.device) + b, c, h, w = self.original_size + count_mt = torch.zeros((b, 4, h, w)).to(outs.device) + k1, k2 = self.kernel_size + + for cnt, each_idx in enumerate(self.idxes): + i = each_idx["i"] + j = each_idx["j"] + preds[0, :, i : i + k1, j : j + k2] += outs[cnt, :, :, :] * self.tile_weights + count_mt[0, :, i : i + k1, j : j + k2] += self.tile_weights + + del outs + torch.cuda.empty_cache() + return preds / count_mt + + def _pad(self, x): + """Pad the input tensor to align with kernel size. + + Args: + x (torch.Tensor): Input tensor of shape (batch, channels, height, width). + + Returns: + tuple: Padded tensor and padding values. + """ + b, c, h, w = x.shape + k1, k2 = self.kernel_size + mod_pad_h = (k1 - h % k1) % k1 + mod_pad_w = (k2 - w % k2) % k2 + pad = (mod_pad_w // 2, mod_pad_w - mod_pad_w // 2, mod_pad_h // 2, mod_pad_h - mod_pad_h // 2) + x = F.pad(x, pad, "reflect") + return x, pad + + def forward(self, x): + """Apply local attention by splitting into grids and reconstructing. + + Args: + x (torch.Tensor): Input tensor of shape (batch, channels, height, width). + + Returns: + torch.Tensor: Processed tensor of original size. + """ + b, c, h, w = x.shape + qkv = self.grids(x) + out = self.grids_inverse(qkv) + return out + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + + Args: + noise_cfg (torch.Tensor): Noise configuration tensor. + noise_pred_text (torch.Tensor): Predicted noise from text-conditioned model. + guidance_rescale (float): Rescaling factor for guidance. Defaults to 0.0. + + Returns: + torch.Tensor: Rescaled noise configuration. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + """Retrieve latents from an encoder output. + + Args: + encoder_output (torch.Tensor): Output from an encoder (e.g., VAE). + generator (torch.Generator, optional): Random generator for sampling. Defaults to None. + sample_mode (str): Sampling mode ("sample" or "argmax"). Defaults to "sample". + + Returns: + torch.Tensor: Retrieved latent tensor. + """ + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FaithDiffStableDiffusionXLPipeline( + DiffusionPipeline, + StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + unet_model = UNet2DConditionModel + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2", "feature_extractor", "unet"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: OriginalUNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.DDPMScheduler = DDPMScheduler.from_config(self.scheduler.config, subfolder="scheduler") + self.default_sample_size = self.unet.config.sample_size if unet is not None else 128 + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = "cuda" # device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + dtype = text_encoders[0].dtype + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + text_encoder = text_encoder.to(dtype) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_image_size(self, x, padder_size=8): + # 获取图像的宽高 + width, height = x.size + padder_size = padder_size + # 计算需要填充的高度和宽度 + mod_pad_h = (padder_size - height % padder_size) % padder_size + mod_pad_w = (padder_size - width % padder_size) % padder_size + x_np = np.array(x) + # 使用 ImageOps.expand 进行填充 + x_padded = cv2.copyMakeBorder( + x_np, top=0, bottom=mod_pad_h, left=0, right=mod_pad_w, borderType=cv2.BORDER_REPLICATE + ) + + x = PIL.Image.fromarray(x_padded) + # x = x.resize((width + mod_pad_w, height + mod_pad_h)) + + return x, width, height, width + mod_pad_w, height + mod_pad_h + + def check_inputs( + self, + lr_img, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if lr_img is None: + raise ValueError("`lr_image` must be provided!") + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.FloatTensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + def set_encoder_tile_settings( + self, + denoise_encoder_tile_sample_min_size=1024, + denoise_encoder_sample_overlap_factor=0.25, + vae_sample_size=1024, + vae_tile_overlap_factor=0.25, + ): + self.unet.denoise_encoder.tile_sample_min_size = denoise_encoder_tile_sample_min_size + self.unet.denoise_encoder.tile_overlap_factor = denoise_encoder_sample_overlap_factor + self.vae.config.sample_size = vae_sample_size + self.vae.tile_overlap_factor = vae_tile_overlap_factor + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + self.unet.denoise_encoder.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + self.unet.denoise_encoder.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + image_latents = image + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + # needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + # if needs_upcasting: + # image = image.float() + # self.upcast_vae() + self.unet.denoise_encoder.to(device=image.device, dtype=image.dtype) + image_latents = self.unet.denoise_encoder(image) + self.unet.denoise_encoder.to("cpu") + # cast back to fp16 if needed + # if needs_upcasting: + # self.vae.to(dtype=torch.float16) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + image_latents = image_latents + + if image_latents.dtype != self.vae.dtype: + image_latents = image_latents.to(dtype=self.vae.dtype) + + return image_latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + lr_img: PipelineImageInput = None, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + start_point: Optional[str] = "noise", + timesteps: List[int] = None, + denoising_end: Optional[float] = None, + overlap: float = 0.5, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + add_sample: bool = True, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + lr_img (PipelineImageInput, optional): Low-resolution input image for conditioning the generation process. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + start_point (str, *optional*): + The starting point for the generation process. Can be "noise" (random noise) or "lr" (low-resolution image). + Defaults to "noise". + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + overlap (float): + Overlap factor for local attention tiling (between 0.0 and 1.0). Controls the overlap between adjacent + grid patches during processing. Defaults to 0.5. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + add_sample (bool): + Whether to include sample conditioning (e.g., low-resolution image) in the UNet during denoising. + Defaults to True. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + lr_img, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + self.tlc_vae_latents = LocalAttention((target_size[0] // 8, target_size[1] // 8), overlap) + self.tlc_vae_img = LocalAttention((target_size[0] // 8, target_size[1] // 8), overlap) + + # 2. Define call parameters + batch_size = 1 + num_images_per_prompt = 1 + + device = torch.device("cuda") # self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + num_samples = num_images_per_prompt + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + lora_scale=lora_scale, + ) + + lr_img_list = [lr_img] + lr_img = self.image_processor.preprocess(lr_img_list, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + image_latents = self.prepare_image_latents( + lr_img, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, self.do_classifier_free_guidance + ) + + image_latents = self.tlc_vae_img.grids(image_latents) + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if start_point == "lr": + latents_condition_image = self.vae.encode(lr_img * 2 - 1).latent_dist.sample() + latents_condition_image = latents_condition_image * self.vae.config.scaling_factor + start_steps_tensor = torch.randint(999, 999 + 1, (latents.shape[0],), device=latents.device) + start_steps_tensor = start_steps_tensor.long() + latents = self.DDPMScheduler.add_noise(latents_condition_image[0:1, ...], latents, start_steps_tensor) + + latents = self.tlc_vae_latents.grids(latents) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * image_latents.shape[0] + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + sub_latents_num = latents.shape[0] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if i >= 1: + latents = self.tlc_vae_latents.grids(latents).to(dtype=latents.dtype) + if self.interrupt: + continue + concat_grid = [] + for sub_num in range(sub_latents_num): + self.scheduler.__dict__.update(views_scheduler_status[sub_num]) + sub_latents = latents[sub_num, :, :, :].unsqueeze(0) + img_sub_latents = image_latents[sub_num, :, :, :].unsqueeze(0) + latent_model_input = ( + torch.cat([sub_latents] * 2) if self.do_classifier_free_guidance else sub_latents + ) + img_sub_latents = ( + torch.cat([img_sub_latents] * 2) if self.do_classifier_free_guidance else img_sub_latents + ) + scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + pos_height = self.tlc_vae_latents.idxes[sub_num]["i"] + pos_width = self.tlc_vae_latents.idxes[sub_num]["j"] + add_time_ids = [ + torch.tensor([original_size]), + torch.tensor([[pos_height, pos_width]]), + torch.tensor([target_size]), + ] + add_time_ids = torch.cat(add_time_ids, dim=1).to( + img_sub_latents.device, dtype=img_sub_latents.dtype + ) + add_time_ids = add_time_ids.repeat(2, 1).to(dtype=img_sub_latents.dtype) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + with torch.amp.autocast( + device.type, dtype=latents.dtype, enabled=latents.dtype != self.unet.dtype + ): + noise_pred = self.unet( + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + input_embedding=img_sub_latents, + add_sample=add_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = sub_latents.dtype + sub_latents = self.scheduler.step( + noise_pred, t, sub_latents, **extra_step_kwargs, return_dict=False + )[0] + + views_scheduler_status[sub_num] = copy.deepcopy(self.scheduler.__dict__) + concat_grid.append(sub_latents) + if latents.dtype != sub_latents: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + sub_latents = sub_latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self.tlc_vae_latents.grids_inverse(torch.cat(concat_grid, dim=0)).to(sub_latents.dtype) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image)