Spaces:
Runtime error
Runtime error
| # Copyright 2024 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| # Based on [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111). | |
| # Authors: Jiatao Gu, Shuangfei Zhai, Yizhe Zhang, Josh Susskind, Navdeep Jaitly | |
| # Code: https://github.com/apple/ml-mdm with MIT license | |
| # | |
| # Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz). | |
| import gc | |
| import inspect | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint | |
| from packaging import version | |
| from PIL import Image | |
| from torch import nn | |
| from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast | |
| from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback | |
| from diffusers.configuration_utils import ConfigMixin, FrozenDict, LegacyConfigMixin, register_to_config | |
| from diffusers.image_processor import PipelineImageInput, VaeImageProcessor | |
| from diffusers.loaders import ( | |
| FromSingleFileMixin, | |
| IPAdapterMixin, | |
| PeftAdapterMixin, | |
| StableDiffusionLoraLoaderMixin, | |
| TextualInversionLoaderMixin, | |
| UNet2DConditionLoadersMixin, | |
| ) | |
| from diffusers.loaders.single_file_model import FromOriginalModelMixin | |
| from diffusers.models.activations import GELU, get_activation | |
| from diffusers.models.attention_processor import ( | |
| ADDED_KV_ATTENTION_PROCESSORS, | |
| CROSS_ATTENTION_PROCESSORS, | |
| Attention, | |
| AttentionProcessor, | |
| AttnAddedKVProcessor, | |
| AttnProcessor, | |
| FusedAttnProcessor2_0, | |
| ) | |
| from diffusers.models.downsampling import Downsample2D | |
| from diffusers.models.embeddings import ( | |
| GaussianFourierProjection, | |
| GLIGENTextBoundingboxProjection, | |
| ImageHintTimeEmbedding, | |
| ImageProjection, | |
| ImageTimeEmbedding, | |
| TextImageProjection, | |
| TextImageTimeEmbedding, | |
| TextTimeEmbedding, | |
| TimestepEmbedding, | |
| Timesteps, | |
| ) | |
| from diffusers.models.lora import adjust_lora_scale_text_encoder | |
| from diffusers.models.modeling_utils import LegacyModelMixin, ModelMixin | |
| from diffusers.models.resnet import ResnetBlock2D | |
| from diffusers.models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D | |
| from diffusers.models.upsampling import Upsample2D | |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin | |
| from diffusers.schedulers.scheduling_utils import SchedulerMixin | |
| from diffusers.utils import ( | |
| USE_PEFT_BACKEND, | |
| BaseOutput, | |
| deprecate, | |
| is_torch_xla_available, | |
| logging, | |
| replace_example_docstring, | |
| scale_lora_layers, | |
| unscale_lora_layers, | |
| ) | |
| from diffusers.utils.torch_utils import apply_freeu, randn_tensor | |
# Optional torch_xla support: import the XLA model module only when it is installed and record that in a flag.
XLA_AVAILABLE = False
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm  # type: ignore

    XLA_AVAILABLE = True

# Module-level logger.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Usage example for the Matryoshka pipeline. Presumably injected into the pipeline's `__call__` docstring via
# `replace_example_docstring` (imported above) — TODO confirm at the call site.
# NOTE(review): the snippet passes `trust_remote_code=False` while its own comment says permission is needed for
# this code to run; loading custom code from the Hub normally requires `trust_remote_code=True` — confirm intent.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers import DiffusionPipeline
>>> from diffusers.utils import make_image_grid
>>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64
>>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
... nesting_level=0,
... trust_remote_code=False, # One needs to give permission for this code to run
... ).to("cuda")
>>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
>>> prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
>>> image = pipe(prompt, num_inference_steps=50).images
>>> make_image_grid(image, rows=1, cols=len(image))
>>> # pipe.change_nesting_level(<int>) # 0, 1, or 2
>>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
```
"""
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend the classifier-free-guidance prediction with a copy rescaled to the text prediction's per-sample std.

    Follows Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://huggingface.co/papers/2305.08891): rescaling counters over-exposure, and mixing back with the raw
    guided output (weight `1 - guidance_rescale`) avoids overly plain images.

    Args:
        noise_cfg: Guided noise prediction.
        noise_pred_text: Text-conditioned noise prediction.
        guidance_rescale: Weight of the rescaled prediction in the final mix (0 returns `noise_cfg` values).

    Returns:
        The blended prediction, same shape as `noise_cfg`.
    """

    def _per_sample_std(tensor):
        # std over every dimension except the leading (batch) one
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    rescaled = noise_cfg * (_per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg))
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` via its `set_timesteps` method and return the resulting schedule.

    Exactly one way of specifying the schedule is used: explicit `timesteps`, explicit `sigmas`, or (when neither is
    given) `num_inference_steps`. Extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`): Scheduler to configure and read `timesteps` from.
        num_inference_steps (`int`, *optional*): Number of denoising steps; used only when neither `timesteps` nor
            `sigmas` is given.
        device (`str` or `torch.device`, *optional*): Device passed through to `set_timesteps`.
        timesteps (`List[int]`, *optional*): Custom timestep schedule.
        sigmas (`List[float]`, *optional*): Custom sigma schedule.

    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's timesteps and the number of inference steps. For custom schedules
        the count is the length of the resulting timesteps.

    Raises:
        ValueError: if both `timesteps` and `sigmas` are given, or the scheduler's `set_timesteps` does not accept
            the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _set_timesteps_accepts(keyword):
        return keyword in inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None:
        if not _set_timesteps_accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if not _set_timesteps_accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
# Adapted from diffusers.models.attention._chunked_feed_forward
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
    """
    Apply `ff` to `hidden_states` slice by slice along `chunk_dim` and concatenate the outputs.

    Used for "feed_forward_chunk_size" memory savings. Each slice has `chunk_size` entries along `chunk_dim`.

    Raises:
        ValueError: if the size of `chunk_dim` is not a multiple of `chunk_size`.
    """
    dim_size = hidden_states.shape[chunk_dim]
    if dim_size % chunk_size:
        raise ValueError(
            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
        )
    slices = hidden_states.chunk(dim_size // chunk_size, dim=chunk_dim)
    return torch.cat([ff(piece) for piece in slices], dim=chunk_dim)
@dataclass
class MatryoshkaDDIMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Declared as a dataclass (like the other `BaseOutput` subclasses) so the annotated fields below become real
    constructor fields.

    Args:
        prev_sample (`torch.Tensor` or `List[torch.Tensor]`):
            Computed sample `(x_{t-1})` of the previous timestep, to be used as the next model input in the denoising
            loop. A list (one tensor per nested resolution) when `step` was called with nested outputs.
        pred_original_sample (`torch.Tensor` or `List[torch.Tensor]`, *optional*):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep. Can be used
            to preview progress or for guidance. Mirrors the structure of `prev_sample`.
    """

    prev_sample: Union[torch.Tensor, List[torch.Tensor]]
    pred_original_sample: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None
# Adapted from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous cumulative-alpha curve `alpha_bar(t)`, t in [0, 1], into per-step betas.

    Step `i` covers `[i / N, (i + 1) / N]` and gets `beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i)`, capped at
    `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): Number of betas (N) to produce.
        max_beta (`float`): Upper bound for each beta; values below 1 avoid singularities near the end.
        alpha_transform_type (`str`, *optional*, defaults to `"cosine"`): `"cosine"` (squared-cosine curve) or
            `"exp"` (`exp(-12 t)`).

    Returns:
        `torch.Tensor`: float32 tensor of `num_diffusion_timesteps` betas.

    Raises:
        ValueError: for an unknown `alpha_transform_type`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    steps = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((step + 1) / steps) / alpha_bar_fn(step / steps), max_beta) for step in range(steps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
# Adapted from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescale `betas` so the final timestep has zero SNR (Algorithm 1 of https://huggingface.co/papers/2305.08891).

    `sqrt(alpha_bar)` is shifted so its last value becomes 0 and scaled so its first value is unchanged, then
    converted back to betas.

    Args:
        betas (`torch.Tensor`): The betas the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: Rescaled betas with zero terminal SNR.
    """
    sqrt_abar = torch.cumprod(1.0 - betas, dim=0).sqrt()
    first = sqrt_abar[0].clone()
    last = sqrt_abar[-1].clone()

    # shift so the last value is 0, then scale so the first value returns to its old value
    sqrt_abar = (sqrt_abar - last) * (first / (first - last))

    abar = sqrt_abar**2  # undo the sqrt
    per_step_alphas = torch.cat([abar[:1], abar[1:] / abar[:-1]])  # undo the cumprod
    return 1 - per_step_alphas
class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    `MatryoshkaDDIMScheduler` is a DDIM scheduler adapted for Matryoshka Diffusion Models. DDIM extends the denoising
    procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    On top of standard DDIM it adds:
    - a `"matryoshka_style"` timestep spacing (see `timestep_spacing` below);
    - a `step` that can also update a list of tensors, one per nested resolution. With `"matryoshka_style"` spacing,
      each entry's cumulative alpha is SNR-shifted by the matching value in `self.scales`. `self.scales` is `None`
      after construction and must be set by the caller before nested stepping.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`. With `squaredcos_cap_v2` and
            `timestep_spacing="matryoshka_style"`, a zero beta is prepended, giving `num_train_timesteps + 1` entries.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families. Only applied with
            `timestep_spacing="leading"`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function. One of:
            - `epsilon`: predicts the noise of the diffusion process;
            - `sample`: directly predicts the denoised sample `x_0`;
            - `v_prediction`: see section 2.4 of the [Imagen Video](https://imagen.research.google/video/paper.pdf)
              paper.
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Supported values:
            - `"linspace"`, `"leading"`, `"trailing"`: see Table 2 of [Common Diffusion Noise Schedules and Sample
              Steps are Flawed](https://huggingface.co/papers/2305.08891);
            - `"matryoshka_style"`: produces `num_inference_steps + 1` descending timesteps, spaced by
              `(num_train_timesteps + 1) / (num_inference_steps + 1)` and rounded.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    # scheduler order: DDIM is first-order (a single model evaluation per `step` call)
    order = 1
| def __init__( | |
| self, | |
| num_train_timesteps: int = 1000, | |
| beta_start: float = 0.0001, | |
| beta_end: float = 0.02, | |
| beta_schedule: str = "linear", | |
| trained_betas: Optional[Union[np.ndarray, List[float]]] = None, | |
| clip_sample: bool = True, | |
| set_alpha_to_one: bool = True, | |
| steps_offset: int = 0, | |
| prediction_type: str = "epsilon", | |
| thresholding: bool = False, | |
| dynamic_thresholding_ratio: float = 0.995, | |
| clip_sample_range: float = 1.0, | |
| sample_max_value: float = 1.0, | |
| timestep_spacing: str = "leading", | |
| rescale_betas_zero_snr: bool = False, | |
| ): | |
| if trained_betas is not None: | |
| self.betas = torch.tensor(trained_betas, dtype=torch.float32) | |
| elif beta_schedule == "linear": | |
| self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) | |
| elif beta_schedule == "scaled_linear": | |
| # this schedule is very specific to the latent diffusion model. | |
| self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 | |
| elif beta_schedule == "squaredcos_cap_v2": | |
| if self.config.timestep_spacing == "matryoshka_style": | |
| self.betas = torch.cat((torch.tensor([0]), betas_for_alpha_bar(num_train_timesteps))) | |
| else: | |
| # Glide cosine schedule | |
| self.betas = betas_for_alpha_bar(num_train_timesteps) | |
| else: | |
| raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") | |
| # Rescale for zero SNR | |
| if rescale_betas_zero_snr: | |
| self.betas = rescale_zero_terminal_snr(self.betas) | |
| self.alphas = 1.0 - self.betas | |
| self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | |
| # At every step in ddim, we are looking into the previous alphas_cumprod | |
| # For the final step, there is no previous alphas_cumprod because we are already at 0 | |
| # `set_alpha_to_one` decides whether we set this parameter simply to one or | |
| # whether we use the final alpha of the "non-previous" one. | |
| self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] | |
| # standard deviation of the initial noise distribution | |
| self.init_noise_sigma = 1.0 | |
| # setable values | |
| self.num_inference_steps = None | |
| self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) | |
| self.scales = None | |
| self.schedule_shifted_power = 1.0 | |
| def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: | |
| """ | |
| Ensures interchangeability with schedulers that need to scale the denoising model input depending on the | |
| current timestep. | |
| Args: | |
| sample (`torch.Tensor`): | |
| The input sample. | |
| timestep (`int`, *optional*): | |
| The current timestep in the diffusion chain. | |
| Returns: | |
| `torch.Tensor`: | |
| A scaled input sample. | |
| """ | |
| return sample | |
| def _get_variance(self, timestep, prev_timestep): | |
| alpha_prod_t = self.alphas_cumprod[timestep] | |
| alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod | |
| beta_prod_t = 1 - alpha_prod_t | |
| beta_prod_t_prev = 1 - alpha_prod_t_prev | |
| variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) | |
| return variance | |
| # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample | |
| def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: | |
| """ | |
| "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the | |
| prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by | |
| s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing | |
| pixels from saturation at each step. We find that dynamic thresholding results in significantly better | |
| photorealism as well as better image-text alignment, especially when using very large guidance weights." | |
| https://huggingface.co/papers/2205.11487 | |
| """ | |
| dtype = sample.dtype | |
| batch_size, channels, *remaining_dims = sample.shape | |
| if dtype not in (torch.float32, torch.float64): | |
| sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half | |
| # Flatten sample for doing quantile calculation along each image | |
| sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) | |
| abs_sample = sample.abs() # "a certain percentile absolute pixel value" | |
| s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) | |
| s = torch.clamp( | |
| s, min=1, max=self.config.sample_max_value | |
| ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] | |
| s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 | |
| sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" | |
| sample = sample.reshape(batch_size, channels, *remaining_dims) | |
| sample = sample.to(dtype) | |
| return sample | |
| def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): | |
| """ | |
| Sets the discrete timesteps used for the diffusion chain (to be run before inference). | |
| Args: | |
| num_inference_steps (`int`): | |
| The number of diffusion steps used when generating samples with a pre-trained model. | |
| """ | |
| if num_inference_steps > self.config.num_train_timesteps: | |
| raise ValueError( | |
| f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" | |
| f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" | |
| f" maximal {self.config.num_train_timesteps} timesteps." | |
| ) | |
| self.num_inference_steps = num_inference_steps | |
| # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 | |
| if self.config.timestep_spacing == "linspace": | |
| timesteps = ( | |
| np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) | |
| .round()[::-1] | |
| .copy() | |
| .astype(np.int64) | |
| ) | |
| elif self.config.timestep_spacing == "leading": | |
| step_ratio = self.config.num_train_timesteps // self.num_inference_steps | |
| # creates integer timesteps by multiplying by ratio | |
| # casting to int to avoid issues when num_inference_step is power of 3 | |
| timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) | |
| timesteps += self.config.steps_offset | |
| elif self.config.timestep_spacing == "trailing": | |
| step_ratio = self.config.num_train_timesteps / self.num_inference_steps | |
| # creates integer timesteps by multiplying by ratio | |
| # casting to int to avoid issues when num_inference_step is power of 3 | |
| timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) | |
| timesteps -= 1 | |
| elif self.config.timestep_spacing == "matryoshka_style": | |
| step_ratio = (self.config.num_train_timesteps + 1) / (num_inference_steps + 1) | |
| timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1].copy().astype(np.int64) | |
| else: | |
| raise ValueError( | |
| f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." | |
| ) | |
| self.timesteps = torch.from_numpy(timesteps).to(device) | |
| def get_schedule_shifted(self, alpha_prod, scale_factor=None): | |
| if (scale_factor is not None) and (scale_factor > 1): # rescale noise schedule | |
| scale_factor = scale_factor**self.schedule_shifted_power | |
| snr = alpha_prod / (1 - alpha_prod) | |
| scaled_snr = snr / scale_factor | |
| alpha_prod = 1 / (1 + 1 / scaled_snr) | |
| return alpha_prod | |
| def step( | |
| self, | |
| model_output: torch.Tensor, | |
| timestep: int, | |
| sample: torch.Tensor, | |
| eta: float = 0.0, | |
| use_clipped_model_output: bool = False, | |
| generator=None, | |
| variance_noise: Optional[torch.Tensor] = None, | |
| return_dict: bool = True, | |
| ) -> Union[MatryoshkaDDIMSchedulerOutput, Tuple]: | |
| """ | |
| Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion | |
| process from the learned model outputs (most often the predicted noise). | |
| Args: | |
| model_output (`torch.Tensor`): | |
| The direct output from learned diffusion model. | |
| timestep (`float`): | |
| The current discrete timestep in the diffusion chain. | |
| sample (`torch.Tensor`): | |
| A current instance of a sample created by the diffusion process. | |
| eta (`float`): | |
| The weight of noise for added noise in diffusion step. | |
| use_clipped_model_output (`bool`, defaults to `False`): | |
| If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary | |
| because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no | |
| clipping has happened, "corrected" `model_output` would coincide with the one provided as input and | |
| `use_clipped_model_output` has no effect. | |
| generator (`torch.Generator`, *optional*): | |
| A random number generator. | |
| variance_noise (`torch.Tensor`): | |
| Alternative to generating noise with `generator` by directly providing the noise for the variance | |
| itself. Useful for methods such as [`CycleDiffusion`]. | |
| return_dict (`bool`, *optional*, defaults to `True`): | |
| Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. | |
| Returns: | |
| [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: | |
| If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a | |
| tuple is returned where the first element is the sample tensor. | |
| """ | |
| if self.num_inference_steps is None: | |
| raise ValueError( | |
| "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" | |
| ) | |
| # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502 | |
| # Ideally, read DDIM paper in-detail understanding | |
| # Notation (<variable name> -> <name in paper> | |
| # - pred_noise_t -> e_theta(x_t, t) | |
| # - pred_original_sample -> f_theta(x_t, t) or x_0 | |
| # - std_dev_t -> sigma_t | |
| # - eta -> η | |
| # - pred_sample_direction -> "direction pointing to x_t" | |
| # - pred_prev_sample -> "x_t-1" | |
| # 1. get previous step value (=t-1) | |
| if self.config.timestep_spacing != "matryoshka_style": | |
| prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps | |
| else: | |
| prev_timestep = self.timesteps[torch.nonzero(self.timesteps == timestep).item() + 1] | |
| # 2. compute alphas, betas | |
| alpha_prod_t = self.alphas_cumprod[timestep] | |
| alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod | |
| if self.config.timestep_spacing == "matryoshka_style" and len(model_output) > 1: | |
| alpha_prod_t = torch.tensor([self.get_schedule_shifted(alpha_prod_t, s) for s in self.scales]) | |
| alpha_prod_t_prev = torch.tensor([self.get_schedule_shifted(alpha_prod_t_prev, s) for s in self.scales]) | |
| beta_prod_t = 1 - alpha_prod_t | |
| # 3. compute predicted original sample from predicted noise also called | |
| # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502 | |
| if self.config.prediction_type == "epsilon": | |
| pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) | |
| pred_epsilon = model_output | |
| elif self.config.prediction_type == "sample": | |
| pred_original_sample = model_output | |
| pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) | |
| elif self.config.prediction_type == "v_prediction": | |
| if len(model_output) > 1: | |
| pred_original_sample = [] | |
| pred_epsilon = [] | |
| for m_o, s, a_p_t, b_p_t in zip(model_output, sample, alpha_prod_t, beta_prod_t): | |
| pred_original_sample.append((a_p_t**0.5) * s - (b_p_t**0.5) * m_o) | |
| pred_epsilon.append((a_p_t**0.5) * m_o + (b_p_t**0.5) * s) | |
| else: | |
| pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output | |
| pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample | |
| else: | |
| raise ValueError( | |
| f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" | |
| " `v_prediction`" | |
| ) | |
| # 4. Clip or threshold "predicted x_0" | |
| if self.config.thresholding: | |
| if len(model_output) > 1: | |
| pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample] | |
| else: | |
| pred_original_sample = self._threshold_sample(pred_original_sample) | |
| elif self.config.clip_sample: | |
| if len(model_output) > 1: | |
| pred_original_sample = [ | |
| p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range) | |
| for p_o_s in pred_original_sample | |
| ] | |
| else: | |
| pred_original_sample = pred_original_sample.clamp( | |
| -self.config.clip_sample_range, self.config.clip_sample_range | |
| ) | |
| # 5. compute variance: "sigma_t(η)" -> see formula (16) | |
| # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) | |
| variance = self._get_variance(timestep, prev_timestep) | |
| std_dev_t = eta * variance ** (0.5) | |
| if use_clipped_model_output: | |
| # the pred_epsilon is always re-derived from the clipped x_0 in Glide | |
| if len(model_output) > 1: | |
| pred_epsilon = [] | |
| for s, a_p_t, p_o_s, b_p_t in zip(sample, alpha_prod_t, pred_original_sample, beta_prod_t): | |
| pred_epsilon.append((s - a_p_t ** (0.5) * p_o_s) / b_p_t ** (0.5)) | |
| else: | |
| pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) | |
| # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502 | |
| if len(model_output) > 1: | |
| pred_sample_direction = [] | |
| for p_e, a_p_t_p in zip(pred_epsilon, alpha_prod_t_prev): | |
| pred_sample_direction.append((1 - a_p_t_p - std_dev_t**2) ** (0.5) * p_e) | |
| else: | |
| pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon | |
| # 7. compute x_t without "random noise" of formula (12) from https://huggingface.co/papers/2010.02502 | |
| if len(model_output) > 1: | |
| prev_sample = [] | |
| for p_o_s, p_s_d, a_p_t_p in zip(pred_original_sample, pred_sample_direction, alpha_prod_t_prev): | |
| prev_sample.append(a_p_t_p ** (0.5) * p_o_s + p_s_d) | |
| else: | |
| prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction | |
| if eta > 0: | |
| if variance_noise is not None and generator is not None: | |
| raise ValueError( | |
| "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" | |
| " `variance_noise` stays `None`." | |
| ) | |
| if variance_noise is None: | |
| if len(model_output) > 1: | |
| variance_noise = [] | |
| for m_o in model_output: | |
| variance_noise.append( | |
| randn_tensor(m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype) | |
| ) | |
| else: | |
| variance_noise = randn_tensor( | |
| model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype | |
| ) | |
| if len(model_output) > 1: | |
| prev_sample = [p_s + std_dev_t * v_n for v_n, p_s in zip(variance_noise, prev_sample)] | |
| else: | |
| variance = std_dev_t * variance_noise | |
| prev_sample = prev_sample + variance | |
| if not return_dict: | |
| return (prev_sample,) | |
| return MatryoshkaDDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) | |
# Behaviourally equivalent to diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """Forward-diffuse ``original_samples`` to ``timesteps``.

    Returns ``sqrt(alpha_bar) * original_samples + sqrt(1 - alpha_bar) * noise`` with
    ``alpha_bar = self.alphas_cumprod[timesteps]``, one value per entry of ``timesteps`` laid along the leading
    axis and broadcast over every remaining axis of ``original_samples``.

    Side effect: ``self.alphas_cumprod`` is moved to ``original_samples.device`` (a no-op on later calls), which
    avoids a host-to-device copy on every call.
    """
    target_device = original_samples.device
    self.alphas_cumprod = self.alphas_cumprod.to(device=target_device)
    alpha_bar = self.alphas_cumprod.to(dtype=original_samples.dtype)[timesteps.to(target_device)]

    # Leading axis indexes the timesteps; all other axes are singletons so the scales broadcast.
    per_sample_shape = (-1,) + (1,) * (original_samples.ndim - 1)
    signal_scale = (alpha_bar**0.5).reshape(per_sample_shape)
    noise_scale = ((1 - alpha_bar) ** 0.5).reshape(per_sample_shape)
    return signal_scale * original_samples + noise_scale * noise
# Behaviourally equivalent to diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
    """Return the v-prediction target ``sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * sample``.

    ``alpha_bar = self.alphas_cumprod[timesteps]`` is laid along the leading axis and broadcast over the remaining
    axes of ``sample``. Side effect: ``self.alphas_cumprod`` is moved to ``sample.device``.
    """
    target_device = sample.device
    self.alphas_cumprod = self.alphas_cumprod.to(device=target_device)
    alpha_bar = self.alphas_cumprod.to(dtype=sample.dtype)[timesteps.to(target_device)]

    per_sample_shape = (-1,) + (1,) * (sample.ndim - 1)
    signal_scale = (alpha_bar**0.5).reshape(per_sample_shape)
    noise_scale = ((1 - alpha_bar) ** 0.5).reshape(per_sample_shape)
    return signal_scale * noise - noise_scale * sample
| def __len__(self): | |
| return self.config.num_train_timesteps | |
class CrossAttnDownBlock2D(nn.Module):
    """Encoder (down) stage of the UNet.

    Runs ``num_layers`` x (``ResnetBlock2D`` -> ``MatryoshkaTransformer2DModel``), collecting every layer output as a
    skip connection, then an optional ``Downsample2D`` (built with ``use_conv=True``).

    Accepted only for signature compatibility and not read here: ``norm_type``, ``cross_attention_norm``,
    ``dual_cross_attention``, ``use_linear_projection``, ``only_cross_attention``, ``attention_type``,
    ``attention_pre_only`` and ``attention_bias``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        norm_type: str = "layer_norm",
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        cross_attention_norm: Optional[str] = None,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
        attention_pre_only: bool = False,
        attention_bias: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # A single int means "same transformer depth for every layer".
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnet_layers = []
        attention_layers = []
        for layer_idx in range(num_layers):
            # Only the first ResNet consumes the incoming width; later ones run at `out_channels`.
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            # Creation order (ResNet, then transformer) is kept so seeded initialisation is unchanged.
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attention_layers.append(
                MatryoshkaTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_idx],
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_attention_ffn=use_attention_ffn,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.downsamplers = (
            nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
            if add_downsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        additional_residuals: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Return ``(hidden_states, skip_states)``, where ``skip_states`` holds every layer (and downsampler) output."""
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        skip_states = ()
        layer_pairs = list(zip(self.resnets, self.attentions))
        last_idx = len(layer_pairs) - 1
        for layer_idx, (resnet, attn) in enumerate(layer_pairs):
            # Only the ResNet is checkpointed; the transformer call is identical in both modes.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
            else:
                hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

            # Extra residuals (e.g. from an adapter) are added only after the final resnet/attention pair.
            if layer_idx == last_idx and additional_residuals is not None:
                hidden_states = hidden_states + additional_residuals

            skip_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            skip_states += (hidden_states,)

        return hidden_states, skip_states
class UNetMidBlock2DCrossAttn(nn.Module):
    """Bottleneck stage: ``ResnetBlock2D`` followed by ``num_layers`` x (``MatryoshkaTransformer2DModel`` ->
    ``ResnetBlock2D``).

    Accepted only for signature compatibility and not read here: ``norm_type``, ``cross_attention_norm``,
    ``dual_cross_attention``, ``use_linear_projection``, ``attention_type``, ``attention_pre_only`` and
    ``attention_bias``.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_groups_out: Optional[int] = None,
        resnet_pre_norm: bool = True,
        norm_type: str = "layer_norm",
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        cross_attention_norm: Optional[str] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
        attention_pre_only: bool = False,
        attention_bias: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()
        out_channels = out_channels or in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        # A single int means "same transformer depth for every layer".
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnet_groups_out = resnet_groups_out or resnet_groups

        # The entry ResNet always exists and is the only one that may change the channel count.
        resnet_layers = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                groups_out=resnet_groups_out,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attention_layers = []
        for layer_idx in range(num_layers):
            attention_layers.append(
                MatryoshkaTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_idx],
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_attention_ffn=use_attention_ffn,
                )
            )
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups_out,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the entry ResNet, then alternate transformer and ResNet layers."""
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            # Decided before the attention call, matching the original evaluation order.
            checkpoint_resnet = torch.is_grad_enabled() and self.gradient_checkpointing
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
            if checkpoint_resnet:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
            else:
                hidden_states = resnet(hidden_states, temb)

        return hidden_states
class CrossAttnUpBlock2D(nn.Module):
    """Decoder (up) stage of the UNet.

    Each of the ``num_layers`` layers pops one skip connection, concatenates it with the running features along
    the channel axis, and applies ``ResnetBlock2D`` -> ``MatryoshkaTransformer2DModel``. An optional ``Upsample2D``
    follows the last layer. FreeU scaling is applied when ``s1``, ``s2``, ``b1`` and ``b2`` are all set and truthy.

    Accepted only for signature compatibility and not read here: ``norm_type``, ``cross_attention_norm``,
    ``dual_cross_attention``, ``use_linear_projection``, ``only_cross_attention``, ``attention_type``,
    ``attention_pre_only`` and ``attention_bias``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        norm_type: str = "layer_norm",
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        cross_attention_norm: Optional[str] = None,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
        attention_pre_only: bool = False,
        attention_bias: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # A single int means "same transformer depth for every layer".
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnet_layers = []
        attention_layers = []
        last_layer = num_layers - 1
        for layer_idx in range(num_layers):
            # The deepest skip (consumed last) carries `in_channels`; the others carry `out_channels`.
            skip_channels = in_channels if layer_idx == last_layer else out_channels
            # The first layer receives the previous stage's output; later ones receive this stage's width.
            incoming_channels = prev_output_channel if layer_idx == 0 else out_channels
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=incoming_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attention_layers.append(
                MatryoshkaTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_idx],
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_attention_ffn=use_attention_ffn,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Consume skip connections from the end of ``res_hidden_states_tuple`` and return the upsampled features."""
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        freeu_active = all(getattr(self, factor, None) for factor in ("s1", "s2", "b1", "b2"))

        for resnet, attn in zip(self.resnets, self.attentions):
            # Skip connections are consumed last-in, first-out.
            res_hidden_states, res_hidden_states_tuple = res_hidden_states_tuple[-1], res_hidden_states_tuple[:-1]

            if freeu_active:
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_hidden_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            # Only the ResNet is checkpointed; the transformer call is identical in both modes.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
            else:
                hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
# diffusers' `BaseOutput` is designed for dataclass subclasses (its `__post_init__` walks the dataclass fields);
# without `@dataclass`, `sample` is not a declared field and positional construction is misinterpreted.
@dataclass
class MatryoshkaTransformer2DModelOutput(BaseOutput):
    """
    The output of [`MatryoshkaTransformer2DModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on the `encoder_hidden_states` input. It has the same shape as the
            `hidden_states` passed to the model.
    """

    sample: "torch.Tensor"  # noqa: F821
class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
    """
    A stack of [`MatryoshkaTransformerBlock`]s applied to a channels-first feature map.

    Parameters:
        num_attention_heads (`int`, defaults to 16): Number of attention heads per block.
        attention_head_dim (`int`, defaults to 88): Channels per head. The block width is
            `num_attention_heads * attention_head_dim`.
        in_channels (`int`, *optional*): Recorded in the config only. The width actually used is
            `num_attention_heads * attention_head_dim`, so callers must keep the two consistent.
        num_layers (`int`, defaults to 1): Number of transformer blocks.
        cross_attention_dim (`int`, *optional*): Width of `encoder_hidden_states`. `None` or `0` disables
            cross-attention in the blocks.
        upcast_attention (`bool`, defaults to `False`): Forwarded to the blocks' attention layers.
        use_attention_ffn (`bool`, defaults to `True`): Whether each block has a feed-forward layer.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["MatryoshkaTransformerBlock"]

    # `__init__` reads every hyper-parameter back from `self.config`. `register_to_config` is what populates
    # `self.config` from the constructor arguments, so it is required here.
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        cross_attention_dim: Optional[int] = None,
        upcast_attention: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()
        # The block width is derived from the head layout; the `in_channels` argument is not used for it.
        self.in_channels = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        self.transformer_blocks = nn.ModuleList(
            [
                MatryoshkaTransformerBlock(
                    self.in_channels,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    cross_attention_dim=self.config.cross_attention_dim,
                    upcast_attention=self.config.upcast_attention,
                    use_attention_ffn=self.config.use_attention_ffn,
                )
                for _ in range(self.config.num_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`MatryoshkaTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                Input feature map.
            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for the cross-attention layers of the blocks.
            timestep ( `torch.LongTensor`, *optional*):
                Forwarded to every transformer block.
            added_cond_kwargs (`Dict[str, torch.Tensor]`, *optional*):
                Accepted for interface compatibility; not used.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Forwarded to every transformer block.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                Forwarded to every transformer block. A `scale` entry is deprecated and ignored.
            attention_mask ( `torch.Tensor`, *optional*):
                Mask of shape `(batch, key_tokens)` (`1` = keep, `0` = discard) or an already-converted bias. A 2-D
                mask is converted into an additive bias of shape `(batch, 1, key_tokens)` (keep = 0,
                discard = -10000) before being forwarded to the blocks.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`MatryoshkaTransformer2DModelOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`MatryoshkaTransformer2DModelOutput`] is returned, otherwise a `tuple` whose
            first element is the sample tensor.
        """
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # A 2-D mask is `(batch, key_tokens)` with 1 = keep / 0 = discard. Convert it into an additive bias
        # (keep = +0, discard = -10000.0) with a singleton query axis: `(batch, 1, key_tokens)`.
        # Higher-rank masks are assumed to already be biases (e.g. converted by the calling UNet).
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # Blocks
        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Positional order must match MatryoshkaTransformerBlock.forward's signature.
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # Output
        output = hidden_states

        if not return_dict:
            return (output,)

        return MatryoshkaTransformer2DModelOutput(sample=output)
class MatryoshkaTransformerBlock(nn.Module):
    r"""
    Matryoshka Transformer block operating on a channels-first feature map `(batch, channels, *spatial_dims)`.

    It runs fused self-attention and, when `cross_attention_dim` is a positive int, a fused cross-attention that
    reuses the self-attention query and adds the self-attention output. The result goes through a linear output
    projection and is added back to the input, followed by an optional feed-forward residual.

    Parameters:
        dim (`int`): Number of channels of the input feature map. It is expected to equal
            `num_attention_heads * attention_head_dim`.
        num_attention_heads (`int`): Number of attention heads.
        attention_head_dim (`int`): Channels per head.
        cross_attention_dim (`int`, *optional*): Width of `encoder_hidden_states`. `None` or `0` disables
            cross-attention.
        upcast_attention (`bool`, defaults to `False`): Forwarded to the attention layers.
        use_attention_ffn (`bool`, defaults to `True`): Whether to add the feed-forward residual.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: Optional[int] = None,
        upcast_attention: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.cross_attention_dim = cross_attention_dim

        # Define 3 blocks.
        # 1. Self-Attn (GroupNorm + fused QKV projection; the separate q/k/v projections are dropped after fusing)
        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            norm_num_groups=32,
            bias=True,
            upcast_attention=upcast_attention,
            pre_only=True,
            processor=MatryoshkaFusedAttnProcessor2_0(),
        )
        self.attn1.fuse_projections()
        del self.attn1.to_q
        del self.attn1.to_k
        del self.attn1.to_v

        # 2. Cross-Attn (query comes from attn1, so only the fused KV projection is kept)
        if cross_attention_dim is not None and cross_attention_dim > 0:
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                cross_attention_norm="layer_norm",
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                bias=True,
                upcast_attention=upcast_attention,
                pre_only=True,
                processor=MatryoshkaFusedAttnProcessor2_0(),
            )
            self.attn2.fuse_projections()
            del self.attn2.to_q
            del self.attn2.to_k
            del self.attn2.to_v

        # The output projection is needed with or without cross-attention (creation order unchanged when enabled).
        self.proj_out = nn.Linear(dim, dim)

        if use_attention_ffn:
            # 3. Feed-forward
            self.ff = MatryoshkaFeedForward(dim)
        else:
            self.ff = None

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Return a tensor with the same shape as `hidden_states`.

        `attention_mask`, `timestep`, `class_labels` and `added_cond_kwargs` are accepted for interface
        compatibility but are not used. `cross_attention_kwargs` is not forwarded to the attention processors.
        """
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        batch_size, channels, *spatial_dims = hidden_states.shape

        # 1. Self-Attention: the fused processor returns the attention output and its query for reuse.
        attn_output, query = self.attn1(hidden_states)

        # 2. Cross-Attention (optional): the processor adds `self_attention_output` to the cross-attention result.
        if self.cross_attention_dim is not None and self.cross_attention_dim > 0:
            attn_output_cond = self.attn2(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                self_attention_output=attn_output,
                self_attention_query=query,
            )
        else:
            # Without conditioning, project the self-attention output alone. The processor may still return it
            # split per head `(batch, heads, tokens, head_dim)`; fold that into `(batch, tokens, dim)`.
            attn_output_cond = attn_output
            if attn_output_cond.ndim == 4:
                attn_output_cond = attn_output_cond.transpose(1, 2).flatten(2)

        attn_output_cond = self.proj_out(attn_output_cond)
        # (batch, tokens, channels) -> (batch, channels, *spatial_dims) to match the residual.
        attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
        hidden_states = hidden_states + attn_output_cond

        if self.ff is not None:
            # 3. Feed-forward
            if self._chunk_size is not None:
                # "feed_forward_chunk_size" can be used to save memory
                ff_output = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
            else:
                ff_output = self.ff(hidden_states)

            hidden_states = ff_output + hidden_states

        return hidden_states
class MatryoshkaFusedAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
    fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
    For cross-attention modules, key and value projection matrices are fused.

    It is called in two passes:

    * self-attention pass (`self_attention_output is None`): returns `(hidden_states, query)`. Both are still split
      per head `(batch, heads, tokens, head_dim)`, so the next pass can reuse the query and add the output.
    * cross-attention pass (`self_attention_output` given): optionally reuses `self_attention_query`, adds
      `self_attention_output` to its attention result, and returns one tensor of shape
      `(batch, tokens, heads * head_dim)`.

    <Tip warning={true}>

    This API is currently 🧪 experimental in nature and can change in future.

    </Tip>
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        self_attention_query: Optional[torch.Tensor] = None,
        self_attention_output: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # GroupNorm runs channels-first, i.e. before the feature map is flattened into tokens.
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states)

        # Bound for every input rank; previously it was only set for 4-D inputs.
        batch_size = hidden_states.shape[0]
        if input_ndim == 4:
            _, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2).contiguous()

        if encoder_hidden_states is None:
            qkv = attn.to_qkv(hidden_states)
            split_size = qkv.shape[-1] // 3
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
            if self_attention_query is not None:
                query = self_attention_query
            else:
                query = attn.to_q(hidden_states)

            kv = attn.to_kv(encoder_hidden_states)
            split_size = kv.shape[-1] // 2
            key, value = torch.split(kv, split_size, dim=-1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # A query reused from the self-attention pass is already split per head; fresh projections are not.
        if query.ndim == 3:
            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # qk-norm is applied once, after the head split. diffusers' `Attention` sizes `norm_q`/`norm_k` over `dim_head`;
        # the previous extra application before the split would normalise twice.
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # SDPA broadcasts `attn_mask` against (batch, heads, query_len, key_len) by right-aligning shapes. A
        # `(batch, key_len)` or `(batch, 1 | query_len, key_len)` mask would have its batch axis matched with the
        # heads axis, so insert an explicit singleton heads axis.
        if attention_mask is not None and attention_mask.ndim in (2, 3):
            attention_mask = attention_mask.reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.to(query.dtype)

        if self_attention_output is not None:
            # Both terms are per-head here; merge the heads only once they have been summed.
            hidden_states = hidden_states + self_attention_output
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states if self_attention_output is not None else (hidden_states, query)
class MatryoshkaFeedForward(nn.Module):
    r"""
    A feed-forward layer for the Matryoshka models.

    The input is group-normalized over its channel axis, then every spatial position is passed through a
    `GELU(dim, 4 * dim)` projection followed by a `Linear(4 * dim, dim)` layer. The output has the same shape as
    the input.

    Parameters:
        dim (`int`): Number of input (and output) channels.
        norm_num_groups (`int`, *optional*, defaults to 32): Number of groups of the `GroupNorm`; `dim` must be
            divisible by it.
    """

    def __init__(
        self,
        dim: int,
        norm_num_groups: int = 32,
    ):
        super().__init__()
        self.group_norm = nn.GroupNorm(norm_num_groups, dim)
        self.linear_gelu = GELU(dim, dim * 4)
        self.linear_out = nn.Linear(dim * 4, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward to `x` of shape `(batch, channels, *spatial)` and return the same shape."""
        batch_size, channels, *spatial_dims = x.shape
        x = self.group_norm(x)
        # (B, C, *spatial) -> (B, N, C): the linear layers act on the channel axis of each position.
        x = x.reshape(batch_size, channels, -1).transpose(1, 2)
        x = self.linear_out(self.linear_gelu(x))
        # (B, N, C) -> (B, C, *spatial); `reshape` copies only if the transposed tensor cannot be viewed.
        return x.transpose(1, 2).reshape(batch_size, channels, *spatial_dims)
def get_down_block(
    down_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    add_downsample: bool,
    resnet_eps: float,
    resnet_act_fn: str,
    norm_type: str = "layer_norm",
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    downsample_padding: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    attention_pre_only: bool = False,
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    use_attention_ffn: bool = True,
    downsample_type: Optional[str] = None,
    dropout: float = 0.0,
) -> nn.Module:
    """
    Build one down-sampling stage of the Matryoshka UNet.

    `down_block_type` may carry an optional `"UNetRes"` prefix, which is stripped before dispatch. Supported types
    are `"DownBlock2D"` and `"CrossAttnDownBlock2D"`.

    `resnet_skip_time_act`, `resnet_out_scale_factor`, `downsample_type` and `attention_head_dim` are accepted for
    signature compatibility but are not forwarded to the block constructors.

    Returns:
        `nn.Module`: the constructed down block.

    Raises:
        ValueError: if `"CrossAttnDownBlock2D"` is requested without `cross_attention_dim`, or if
            `down_block_type` is not a supported type.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    # Strip an optional "UNetRes" prefix (7 characters), e.g. "UNetResDownBlock2D" -> "DownBlock2D".
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        return CrossAttnDownBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            norm_type=norm_type,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            cross_attention_norm=cross_attention_norm,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            attention_pre_only=attention_pre_only,
            use_attention_ffn=use_attention_ffn,
        )
    # Previously an unknown type fell through and returned None, which was silently appended to the model.
    raise ValueError(f"{down_block_type} does not exist.")
def get_mid_block(
    mid_block_type: str,
    temb_channels: int,
    in_channels: int,
    resnet_eps: float,
    resnet_act_fn: str,
    resnet_groups: int,
    norm_type: str = "layer_norm",
    output_scale_factor: float = 1.0,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    mid_block_only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    attention_pre_only: bool = False,
    resnet_skip_time_act: bool = False,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = 1,
    dropout: float = 0.0,
) -> Optional[nn.Module]:
    """
    Build the middle block of the Matryoshka UNet.

    Only `"UNetMidBlock2DCrossAttn"` is supported. Passing `None` means "no mid block" and returns `None`.

    `mid_block_only_cross_attention`, `resnet_skip_time_act` and `attention_head_dim` are accepted for signature
    compatibility but are not forwarded to the block constructor.

    Returns:
        `nn.Module` or `None`: the constructed mid block, or `None` when `mid_block_type` is `None`.

    Raises:
        ValueError: if `mid_block_type` is neither `None` nor a supported type.
    """
    if mid_block_type == "UNetMidBlock2DCrossAttn":
        return UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            norm_type=norm_type,
            output_scale_factor=output_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            cross_attention_norm=cross_attention_norm,
            num_attention_heads=num_attention_heads,
            resnet_groups=resnet_groups,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
            attention_type=attention_type,
            attention_pre_only=attention_pre_only,
        )
    if mid_block_type is None:
        # Explicitly requested: skip the mid block.
        return None
    # Previously any other name silently returned None, i.e. the mid block was dropped without warning.
    raise ValueError(f"unknown mid_block_type : {mid_block_type}")
def get_up_block(
    up_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    prev_output_channel: int,
    temb_channels: int,
    add_upsample: bool,
    resnet_eps: float,
    resnet_act_fn: str,
    norm_type: str = "layer_norm",
    resolution_idx: Optional[int] = None,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    attention_pre_only: bool = False,
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    use_attention_ffn: bool = True,
    upsample_type: Optional[str] = None,
    dropout: float = 0.0,
) -> nn.Module:
    """
    Build one up-sampling stage of the Matryoshka UNet.

    `up_block_type` may carry an optional `"UNetRes"` prefix, which is stripped before dispatch. Supported types
    are `"UpBlock2D"` and `"CrossAttnUpBlock2D"`.

    `resnet_skip_time_act`, `resnet_out_scale_factor`, `upsample_type` and `attention_head_dim` are accepted for
    signature compatibility but are not forwarded to the block constructors.

    Returns:
        `nn.Module`: the constructed up block.

    Raises:
        ValueError: if `"CrossAttnUpBlock2D"` is requested without `cross_attention_dim`, or if `up_block_type`
            is not a supported type.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    # Strip an optional "UNetRes" prefix (7 characters), e.g. "UNetResUpBlock2D" -> "UpBlock2D".
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock2D":
        return UpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            resolution_idx=resolution_idx,
            dropout=dropout,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "CrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            resolution_idx=resolution_idx,
            dropout=dropout,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            norm_type=norm_type,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            cross_attention_norm=cross_attention_norm,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            attention_pre_only=attention_pre_only,
            use_attention_ffn=use_attention_ffn,
        )
    # Previously an unknown type fell through and returned None despite the `-> nn.Module` annotation.
    raise ValueError(f"{up_block_type} does not exist.")
class MatryoshkaCombinedTimestepTextEmbedding(nn.Module):
    r"""
    Extra conditioning embeddings for the Matryoshka UNet: a pooled text embedding and a micro-conditioning
    (scale) embedding.

    Parameters:
        addition_time_embed_dim (`int`): Width passed to the `Timesteps` projection of the micro-conditioning
            scale and as the input width of `add_timestep_embedder`.
        cross_attention_dim (`int`): Input width of the `cond_emb` projection (last dim of
            `encoder_hidden_states`).
        time_embed_dim (`int`): Output width of `cond_emb` and of `add_timestep_embedder`.
        type (`str`): `"unet"` creates the learned `cond_emb` projection; `"nested_unet"` disables it.

    Raises:
        ValueError: if `type` is neither `"unet"` nor `"nested_unet"`.
    """

    def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim, type):
        super().__init__()
        # NOTE: `type` shadows the builtin; the name is kept so keyword callers keep working.
        if type == "unet":
            self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False)
        elif type == "nested_unet":
            self.cond_emb = None
        else:
            # Previously `self.cond_emb` was left unset, failing later with an AttributeError in `forward`.
            raise ValueError(f"Unknown `type` {type!r}; expected 'unet' or 'nested_unet'.")
        self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0)
        self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim)

    def forward(self, emb, encoder_hidden_states, added_cond_kwargs):
        """
        Compute the optional micro-conditioning embedding and pooled text embedding.

        Args:
            emb (`torch.Tensor`): Only its `device` and `dtype` are used, to build the micro-conditioning input.
            encoder_hidden_states (`torch.Tensor`): Text features, pooled over dim 1 (presumably
                `(batch, seq_len, dim)` — confirm against the caller).
            added_cond_kwargs (`dict`): Reads the optional keys `"conditioning_mask"`, `"masked_cross_attention"`,
                `"from_nested"` and `"micro_conditioning_scale"`.

        Returns:
            `tuple`: `(temb_micro_conditioning, conditioning_mask, cond_emb)`, where
                - `temb_micro_conditioning` is `None` unless `"micro_conditioning_scale"` is given;
                - `conditioning_mask` is `None` unless `"masked_cross_attention"` is truthy;
                - `cond_emb` is `None` when `cond_emb` is disabled or `"from_nested"` is truthy.
        """
        conditioning_mask = added_cond_kwargs.get("conditioning_mask", None)
        masked_cross_attention = added_cond_kwargs.get("masked_cross_attention", False)

        cond_emb = None
        if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False):
            if conditioning_mask is None:
                y = encoder_hidden_states.mean(dim=1)
            else:
                # Mean over dim 1 weighted by the mask (assumed `(batch, seq_len)`), so masked tokens are ignored.
                y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum(
                    dim=1, keepdim=True
                )
            cond_emb = self.cond_emb(y)

        # The mask is only returned when masked cross-attention is enabled.
        if not masked_cross_attention:
            conditioning_mask = None

        micro = added_cond_kwargs.get("micro_conditioning_scale", None)
        if micro is None:
            return None, conditioning_mask, cond_emb
        temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype))
        temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype))
        return temb_micro_conditioning, conditioning_mask, cond_emb
@dataclass
class MatryoshkaUNet2DConditionOutput(BaseOutput):
    """
    The output of [`MatryoshkaUNet2DConditionModel`].

    `BaseOutput` subclasses must be dataclasses so that keyword fields become attributes as well as mapping keys.

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
        sample_inner (`torch.Tensor`, *optional*):
            A second output tensor, presumably the prediction of the inner (nested) UNet when `nesting` is enabled —
            confirm against the model's `forward`.
    """

    sample: Optional[torch.Tensor] = None
    sample_inner: Optional[torch.Tensor] = None
| class MatryoshkaUNet2DConditionModel( | |
| ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin | |
| ): | |
| r""" | |
| A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample | |
| shaped output. | |
| This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented | |
| for all models (such as downloading or saving). | |
| Parameters: | |
| sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): | |
| Height and width of input/output sample. | |
| in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample. | |
| out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. | |
| center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. | |
| flip_sin_to_cos (`bool`, *optional*, defaults to `True`): | |
| Whether to flip the sin to cos in the time embedding. | |
| freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. | |
| down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): | |
| The tuple of downsample blocks to use. | |
| mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): | |
| Block type for middle of UNet. Only `UNetMidBlock2DCrossAttn` is built by `get_mid_block`; if `None`, the | |
| mid block layer is skipped. | |
| up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): | |
| The tuple of upsample blocks to use. | |
| only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): | |
| Whether to include self-attention in the basic transformer blocks, see | |
| [`~models.attention.BasicTransformerBlock`]. | |
| block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): | |
| The tuple of output channels for each block. | |
| layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. | |
| downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. | |
| mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. | |
| dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | |
| act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. | |
| norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. | |
| If `None`, normalization and activation layers is skipped in post-processing. | |
| norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. | |
| cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): | |
| The dimension of the cross attention features. | |
| transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): | |
| The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for | |
| [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], | |
| [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. | |
| reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): | |
| The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling | |
| blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for | |
| [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], | |
| [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. | |
| encoder_hid_dim (`int`, *optional*, defaults to None): | |
| If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` | |
| dimension to `cross_attention_dim`. | |
| encoder_hid_dim_type (`str`, *optional*, defaults to `None`): | |
| If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text | |
| embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. | |
| attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. | |
| num_attention_heads (`int`, *optional*): | |
| The number of attention heads. If not defined, defaults to `attention_head_dim` | |
| resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config | |
| for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. | |
| class_embed_type (`str`, *optional*, defaults to `None`): | |
| The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, | |
| `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. | |
| addition_embed_type (`str`, *optional*, defaults to `None`): | |
| Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or | |
| "text". "text" will use the `TextTimeEmbedding` layer. | |
| addition_time_embed_dim: (`int`, *optional*, defaults to `None`): | |
| Dimension for the timestep embeddings. | |
| num_class_embeds (`int`, *optional*, defaults to `None`): | |
| Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing | |
| class conditioning with `class_embed_type` equal to `None`. | |
| time_embedding_type (`str`, *optional*, defaults to `positional`): | |
| The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. | |
| time_embedding_dim (`int`, *optional*, defaults to `None`): | |
| An optional override for the dimension of the projected time embedding. | |
| time_embedding_act_fn (`str`, *optional*, defaults to `None`): | |
| Optional activation function to use only once on the time embeddings before they are passed to the rest of | |
| the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. | |
| timestep_post_act (`str`, *optional*, defaults to `None`): | |
| The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. | |
| time_cond_proj_dim (`int`, *optional*, defaults to `None`): | |
| The dimension of `cond_proj` layer in the timestep embedding. | |
| conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. | |
| conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. | |
| projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when | |
| `class_embed_type="projection"`. Required when `class_embed_type="projection"`. | |
| class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time | |
| embeddings with the class embeddings. | |
| mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): | |
| Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If | |
| `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the | |
| `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` | |
| otherwise. | |
| """ | |
| _supports_gradient_checkpointing = True | |
| _no_split_modules = ["MatryoshkaTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] | |
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 3,
    out_channels: int = 3,
    center_input_sample: bool = False,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: Union[int, Tuple[int]] = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    dropout: float = 0.0,
    act_fn: str = "silu",
    norm_type: str = "layer_norm",
    norm_num_groups: Optional[int] = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: Union[int, Tuple[int]] = 1280,
    transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
    reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
    encoder_hid_dim: Optional[int] = None,
    encoder_hid_dim_type: Optional[str] = None,
    attention_head_dim: Union[int, Tuple[int]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    dual_cross_attention: bool = False,
    use_attention_ffn: bool = True,
    use_linear_projection: bool = False,
    class_embed_type: Optional[str] = None,
    addition_embed_type: Optional[str] = None,
    addition_time_embed_dim: Optional[int] = None,
    num_class_embeds: Optional[int] = None,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    time_embedding_type: str = "positional",
    time_embedding_dim: Optional[int] = None,
    time_embedding_act_fn: Optional[str] = None,
    timestep_post_act: Optional[str] = None,
    time_cond_proj_dim: Optional[int] = None,
    conv_in_kernel: int = 3,
    conv_out_kernel: int = 3,
    projection_class_embeddings_input_dim: Optional[int] = None,
    attention_type: str = "default",
    attention_pre_only: bool = False,
    masked_cross_attention: bool = False,
    micro_conditioning_scale: Optional[int] = None,
    class_embeddings_concat: bool = False,
    mid_block_only_cross_attention: Optional[bool] = None,
    cross_attention_norm: Optional[str] = None,
    addition_embed_type_num_heads: int = 64,
    temporal_mode: bool = False,
    temporal_spatial_ds: bool = False,
    skip_cond_emb: bool = False,
    nesting: Optional[int] = False,
):
    """
    Build the Matryoshka UNet; see the class docstring for the meaning of each argument.

    `@register_to_config` records every argument in `self.config`, which diffusers' `ConfigMixin` relies on for
    `save_pretrained` / `from_pretrained`. Several arguments (`center_input_sample`, `addition_time_embed_dim`,
    `masked_cross_attention`, `micro_conditioning_scale`, `temporal_mode`, `temporal_spatial_ds`, `skip_cond_emb`,
    `nesting`) are not read in this method; presumably they are consumed via `self.config` elsewhere.

    Raises:
        ValueError: if `num_attention_heads` is passed, or if `_check_config` rejects the block configuration.
    """
    super().__init__()

    self.sample_size = sample_size

    if num_attention_heads is not None:
        raise ValueError(
            "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
        )

    # If `num_attention_heads` is not defined (which is the case for most models)
    # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
    # The reason for this behavior is to correct for incorrectly named variables that were introduced
    # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
    # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
    # which is why we correct for the naming here.
    num_attention_heads = num_attention_heads or attention_head_dim

    # Check inputs
    self._check_config(
        down_block_types=down_block_types,
        up_block_types=up_block_types,
        only_cross_attention=only_cross_attention,
        block_out_channels=block_out_channels,
        layers_per_block=layers_per_block,
        cross_attention_dim=cross_attention_dim,
        transformer_layers_per_block=transformer_layers_per_block,
        reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
        attention_head_dim=attention_head_dim,
        num_attention_heads=num_attention_heads,
    )

    # input
    conv_in_padding = (conv_in_kernel - 1) // 2
    self.conv_in = nn.Conv2d(
        in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
    )

    # time
    time_embed_dim, timestep_input_dim = self._set_time_proj(
        time_embedding_type,
        block_out_channels=block_out_channels,
        flip_sin_to_cos=flip_sin_to_cos,
        freq_shift=freq_shift,
        time_embedding_dim=time_embedding_dim,
    )

    # NOTE(review): when `time_embedding_dim` is given, the embedder's input width is `time_embedding_dim // 4`
    # instead of `timestep_input_dim`; this presumably matches the projection width `_set_time_proj` picks in
    # that case — confirm against that method.
    self.time_embedding = TimestepEmbedding(
        time_embedding_dim // 4 if time_embedding_dim is not None else timestep_input_dim,
        time_embed_dim,
        act_fn=act_fn,
        post_act_fn=timestep_post_act,
        cond_proj_dim=time_cond_proj_dim,
    )

    self._set_encoder_hid_proj(
        encoder_hid_dim_type,
        cross_attention_dim=cross_attention_dim,
        encoder_hid_dim=encoder_hid_dim,
    )

    # class embedding
    self._set_class_embedding(
        class_embed_type,
        act_fn=act_fn,
        num_class_embeds=num_class_embeds,
        projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
        time_embed_dim=time_embed_dim,
        timestep_input_dim=timestep_input_dim,
    )

    # NOTE(review): the `addition_time_embed_dim` argument is not used here; `timestep_input_dim` is passed in its
    # place. Confirm this is intentional.
    self._set_add_embedding(
        addition_embed_type,
        addition_embed_type_num_heads=addition_embed_type_num_heads,
        addition_time_embed_dim=timestep_input_dim,
        cross_attention_dim=cross_attention_dim,
        encoder_hid_dim=encoder_hid_dim,
        flip_sin_to_cos=flip_sin_to_cos,
        freq_shift=freq_shift,
        projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
        time_embed_dim=time_embed_dim,
    )

    if time_embedding_act_fn is None:
        self.time_embed_act = None
    else:
        self.time_embed_act = get_activation(time_embedding_act_fn)

    self.down_blocks = nn.ModuleList([])
    self.up_blocks = nn.ModuleList([])

    # Broadcast scalar per-block settings to one entry per down block.
    if isinstance(only_cross_attention, bool):
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = only_cross_attention

        only_cross_attention = [only_cross_attention] * len(down_block_types)

    if mid_block_only_cross_attention is None:
        mid_block_only_cross_attention = False

    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * len(down_block_types)

    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * len(down_block_types)

    if isinstance(cross_attention_dim, int):
        cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

    if isinstance(layers_per_block, int):
        layers_per_block = [layers_per_block] * len(down_block_types)

    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

    if class_embeddings_concat:
        # The time embeddings are concatenated with the class embeddings. The dimension of the
        # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
        # regular time embeddings
        blocks_time_embed_dim = time_embed_dim * 2
    else:
        blocks_time_embed_dim = time_embed_dim

    # down
    output_channel = block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block[i],
            transformer_layers_per_block=transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=blocks_time_embed_dim,
            add_downsample=not is_final_block,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            norm_type=norm_type,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim[i],
            num_attention_heads=num_attention_heads[i],
            downsample_padding=downsample_padding,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            attention_pre_only=attention_pre_only,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            cross_attention_norm=cross_attention_norm,
            use_attention_ffn=use_attention_ffn,
            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            dropout=dropout,
        )
        self.down_blocks.append(down_block)

    # mid
    self.mid_block = get_mid_block(
        mid_block_type,
        temb_channels=blocks_time_embed_dim,
        in_channels=block_out_channels[-1],
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        norm_type=norm_type,
        resnet_groups=norm_num_groups,
        output_scale_factor=mid_block_scale_factor,
        transformer_layers_per_block=1,
        num_attention_heads=num_attention_heads[-1],
        cross_attention_dim=cross_attention_dim[-1],
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        mid_block_only_cross_attention=mid_block_only_cross_attention,
        upcast_attention=upcast_attention,
        resnet_time_scale_shift=resnet_time_scale_shift,
        attention_type=attention_type,
        attention_pre_only=attention_pre_only,
        resnet_skip_time_act=resnet_skip_time_act,
        cross_attention_norm=cross_attention_norm,
        attention_head_dim=attention_head_dim[-1],
        dropout=dropout,
    )

    # count how many layers upsample the images
    self.num_upsamplers = 0

    # up
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_num_attention_heads = list(reversed(num_attention_heads))
    reversed_layers_per_block = list(reversed(layers_per_block))
    reversed_cross_attention_dim = list(reversed(cross_attention_dim))
    reversed_transformer_layers_per_block = (
        list(reversed(transformer_layers_per_block))
        if reverse_transformer_layers_per_block is None
        else reverse_transformer_layers_per_block
    )
    only_cross_attention = list(reversed(only_cross_attention))

    output_channel = reversed_block_out_channels[0]
    for i, up_block_type in enumerate(up_block_types):
        is_final_block = i == len(block_out_channels) - 1

        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

        # add upsample block for all BUT final layer
        if not is_final_block:
            add_upsample = True
            self.num_upsamplers += 1
        else:
            add_upsample = False

        # NOTE(review): `attention_head_dim` is indexed without reversal here (unlike the other per-block lists);
        # `get_up_block` does not forward `attention_head_dim` to the blocks it builds, so this has no effect.
        up_block = get_up_block(
            up_block_type,
            num_layers=reversed_layers_per_block[i] + 1,
            transformer_layers_per_block=reversed_transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            prev_output_channel=prev_output_channel,
            temb_channels=blocks_time_embed_dim,
            add_upsample=add_upsample,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            norm_type=norm_type,
            resolution_idx=i,
            resnet_groups=norm_num_groups,
            cross_attention_dim=reversed_cross_attention_dim[i],
            num_attention_heads=reversed_num_attention_heads[i],
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            attention_pre_only=attention_pre_only,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            cross_attention_norm=cross_attention_norm,
            use_attention_ffn=use_attention_ffn,
            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            dropout=dropout,
        )
        self.up_blocks.append(up_block)

    # out
    if norm_num_groups is not None:
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        )

        self.conv_act = get_activation(act_fn)

    else:
        self.conv_norm_out = None
        self.conv_act = None

    conv_out_padding = (conv_out_kernel - 1) // 2
    self.conv_out = nn.Conv2d(
        block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
    )

    self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)

    self.is_temporal = []
def _check_config(
    self,
    down_block_types: Tuple[str],
    up_block_types: Tuple[str],
    only_cross_attention: Union[bool, Tuple[bool]],
    block_out_channels: Tuple[int],
    layers_per_block: Union[int, Tuple[int]],
    cross_attention_dim: Union[int, Tuple[int]],
    transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
    reverse_transformer_layers_per_block: bool,
    attention_head_dim: int,
    num_attention_heads: Optional[Union[int, Tuple[int]]],
):
    """Validate that the per-block UNet configuration values are mutually consistent.

    Every per-block sequence argument must have exactly one entry per down block. Scalar values
    (``int`` / ``bool``) and ``None`` are not length-checked. Both ``list`` and ``tuple`` sequences
    are validated, matching the ``Tuple[...]`` annotations.

    Raises:
        ValueError: if a per-block sequence has the wrong length, or if an asymmetrical
            ``transformer_layers_per_block`` (a sequence containing sequences) is given without
            ``reverse_transformer_layers_per_block``.
    """
    num_blocks = len(down_block_types)

    def _is_seq(value) -> bool:
        # Accept both list and tuple: the annotations advertise tuples, while lists were the only
        # form previously recognised.
        return isinstance(value, (list, tuple))

    if len(up_block_types) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
        )
    if len(block_out_channels) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )
    if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
        )
    # `None` is allowed (see the Optional annotation) and is not a per-block sequence.
    if (
        num_attention_heads is not None
        and not isinstance(num_attention_heads, int)
        and len(num_attention_heads) != num_blocks
    ):
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )
    if (
        attention_head_dim is not None
        and not isinstance(attention_head_dim, int)
        and len(attention_head_dim) != num_blocks
    ):
        raise ValueError(
            f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
        )
    if _is_seq(cross_attention_dim) and len(cross_attention_dim) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
        )
    if not isinstance(layers_per_block, int) and len(layers_per_block) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
        )
    # Nested per-block entries describe an asymmetrical UNet, whose up path cannot be derived by
    # simply reversing the down path; the caller must spell it out.
    if (
        _is_seq(transformer_layers_per_block)
        and reverse_transformer_layers_per_block is None
        and any(_is_seq(n_layers) for n_layers in transformer_layers_per_block)
    ):
        raise ValueError("Must provide `reverse_transformer_layers_per_block` if using asymmetrical UNet.")
def _set_time_proj(
    self,
    time_embedding_type: str,
    block_out_channels: int,
    flip_sin_to_cos: bool,
    freq_shift: float,
    time_embedding_dim: int,
) -> Tuple[int, int]:
    """Create ``self.time_proj`` and return ``(time_embed_dim, timestep_input_dim)``.

    * ``"fourier"``: builds a ``GaussianFourierProjection`` of half the embedding dim; the
      embedding dim (default ``2 * block_out_channels[0]``) must be even and is also returned as
      ``timestep_input_dim``.
    * ``"positional"``: builds a sinusoidal ``Timesteps`` projection whose width depends on
      ``self.model_type`` and, for ``"nested_unet"``, on ``self.config.micro_conditioning_scale``.
      The embedding dim defaults to ``4 * block_out_channels[0]`` and ``timestep_input_dim`` is
      ``block_out_channels[0]``.

    NOTE(review): for a ``"positional"`` model_type / micro_conditioning_scale combination not
    listed below, ``self.time_proj`` is left unset. The returned ``timestep_input_dim`` is also
    ``block_out_channels[0]`` even when the projection is 4x/8x wider. Confirm both are intended.

    Raises:
        ValueError: for an odd Fourier embedding dim or an unknown ``time_embedding_type``.
    """
    if time_embedding_type == "fourier":
        embed_dim = time_embedding_dim or block_out_channels[0] * 2
        if embed_dim % 2:
            raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {embed_dim}.")
        self.time_proj = GaussianFourierProjection(
            embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
        )
        return embed_dim, embed_dim

    if time_embedding_type != "positional":
        raise ValueError(
            f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
        )

    embed_dim = time_embedding_dim or block_out_channels[0] * 4
    # Projection width per model variant (None => no projection is created here).
    if self.model_type == "unet":
        proj_width = block_out_channels[0]
    elif self.model_type == "nested_unet" and self.config.micro_conditioning_scale == 256:
        proj_width = block_out_channels[0] * 4
    elif self.model_type == "nested_unet" and self.config.micro_conditioning_scale == 1024:
        proj_width = block_out_channels[0] * 4 * 2
    else:
        proj_width = None
    if proj_width is not None:
        self.time_proj = Timesteps(proj_width, flip_sin_to_cos, freq_shift)
    return embed_dim, block_out_channels[0]
def _set_encoder_hid_proj(
    self,
    encoder_hid_dim_type: Optional[str],
    cross_attention_dim: Union[int, Tuple[int]],
    encoder_hid_dim: Optional[int],
):
    """Build ``self.encoder_hid_proj`` according to ``encoder_hid_dim_type``.

    When only ``encoder_hid_dim`` is given, the type defaults to ``"text_proj"`` and that choice is
    written back to the config. With no type at all, ``self.encoder_hid_proj`` is ``None``.

    Raises:
        ValueError: if a type is set without ``encoder_hid_dim``, or the type is not recognised.
    """
    if encoder_hid_dim is not None and encoder_hid_dim_type is None:
        encoder_hid_dim_type = "text_proj"
        self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
        logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

    if encoder_hid_dim_type is None:
        self.encoder_hid_proj = None
        return

    if encoder_hid_dim is None:
        raise ValueError(
            f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
        )

    if encoder_hid_dim_type == "text_proj":
        projection = nn.Linear(encoder_hid_dim, cross_attention_dim)
    elif encoder_hid_dim_type == "text_image_proj":
        # Kandinsky 2.1: image_embed_dim need not equal cross_attention_dim in general; it is tied
        # to it here only to keep the constructor small, since that is the one current use case.
        projection = TextImageProjection(
            text_embed_dim=encoder_hid_dim,
            image_embed_dim=cross_attention_dim,
            cross_attention_dim=cross_attention_dim,
        )
    elif encoder_hid_dim_type == "image_proj":
        # Kandinsky 2.2
        projection = ImageProjection(
            image_embed_dim=encoder_hid_dim,
            cross_attention_dim=cross_attention_dim,
        )
    else:
        raise ValueError(
            f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'."
        )
    self.encoder_hid_proj = projection
def _set_class_embedding(
    self,
    class_embed_type: Optional[str],
    act_fn: str,
    num_class_embeds: Optional[int],
    projection_class_embeddings_input_dim: Optional[int],
    time_embed_dim: int,
    timestep_input_dim: int,
):
    """Build ``self.class_embedding``; it is ``None`` when class conditioning is disabled.

    * ``None`` type: an ``nn.Embedding`` table when ``num_class_embeds`` is given, else ``None``.
    * ``"timestep"``: ``TimestepEmbedding`` from ``timestep_input_dim``.
    * ``"identity"``: ``nn.Identity`` passthrough.
    * ``"projection"``: ``TimestepEmbedding`` from ``projection_class_embeddings_input_dim``.
    * ``"simple_projection"``: one ``nn.Linear`` from ``projection_class_embeddings_input_dim``.
    * any other value: ``None``.

    Raises:
        ValueError: if a projection type is requested without ``projection_class_embeddings_input_dim``.
    """
    if class_embed_type is None:
        embedding = None if num_class_embeds is None else nn.Embedding(num_class_embeds, time_embed_dim)
    elif class_embed_type == "timestep":
        embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
    elif class_embed_type == "identity":
        embedding = nn.Identity(time_embed_dim, time_embed_dim)
    elif class_embed_type == "projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
            )
        # Unlike "timestep", the class labels are not turned into sinusoidal embeddings first, and
        # the input width is arbitrary. TimestepEmbedding is just linear layers plus activations,
        # so it can project arbitrary vectors.
        embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif class_embed_type == "simple_projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
            )
        embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
    else:
        embedding = None
    self.class_embedding = embedding
def _set_add_embedding(
    self,
    addition_embed_type: Optional[str],
    addition_embed_type_num_heads: int,
    addition_time_embed_dim: Optional[int],
    flip_sin_to_cos: bool,
    freq_shift: float,
    cross_attention_dim: Optional[int],
    encoder_hid_dim: Optional[int],
    projection_class_embeddings_input_dim: Optional[int],
    time_embed_dim: int,
):
    """Build ``self.add_embedding`` (plus ``self.add_time_proj`` for ``"text_time"``).

    Nothing is created when ``addition_embed_type`` is ``None``.

    Raises:
        ValueError: for an unrecognised ``addition_embed_type``. The message lists every
            supported value, including ``'matryoshka'``.
    """
    if addition_embed_type == "text":
        # Source width: encoder_hid_dim when given, otherwise cross_attention_dim.
        if encoder_hid_dim is not None:
            text_time_embedding_from_dim = encoder_hid_dim
        else:
            text_time_embedding_from_dim = cross_attention_dim
        self.add_embedding = TextTimeEmbedding(
            text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
        )
    elif addition_embed_type == "matryoshka":
        # The first argument is a quarter of config.time_embedding_dim when that is set,
        # otherwise addition_time_embed_dim.
        self.add_embedding = MatryoshkaCombinedTimestepTextEmbedding(
            self.config.time_embedding_dim // 4
            if self.config.time_embedding_dim is not None
            else addition_time_embed_dim,
            cross_attention_dim,
            time_embed_dim,
            self.model_type,  # if not self.config.nesting else "inner_" + self.model_type,
        )
    elif addition_embed_type == "text_image":
        # Kandinsky 2.1: text/image embed dims need not equal cross_attention_dim in general; they
        # are tied to it here because that is the one current use case.
        self.add_embedding = TextImageTimeEmbedding(
            text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
        )
    elif addition_embed_type == "text_time":
        self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif addition_embed_type == "image":
        # Kandinsky 2.2
        self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type == "image_hint":
        # Kandinsky 2.2 ControlNet
        self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type is not None:
        raise ValueError(
            f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'matryoshka', 'text_image', 'text_time', 'image', or 'image_hint'."
        )
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
    """Attach a GLIGEN ``self.position_net`` when ``attention_type`` is a gated variant.

    ``positive_len`` comes from ``cross_attention_dim``: the value itself for an ``int``, the first
    entry for a list/tuple, and 768 otherwise. Non-gated attention types leave the model untouched.
    """
    if attention_type not in ["gated", "gated-text-image"]:
        return

    if isinstance(cross_attention_dim, int):
        positive_len = cross_attention_dim
    elif isinstance(cross_attention_dim, (list, tuple)):
        positive_len = cross_attention_dim[0]
    else:
        positive_len = 768

    feature_type = "text-only" if attention_type == "gated" else "text-image"
    self.position_net = GLIGENTextBoundingboxProjection(
        positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
    )
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model,
        indexed by their weight name (``"<dotted.module.path>.processor"``).

    This is a property: callers read it as an attribute (``self.attn_processors.keys()``), which
    fails if it is a plain method.
    """
    processors = {}

    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
        # Any module exposing `get_processor` contributes one entry; keep descending so nested
        # attention modules are collected too.
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = module.get_processor()
        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
        return processors

    for name, module in self.named_children():
        fn_recursive_add_processors(name, module, processors)
    return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors. The dict is
            only read, never modified, so it can be reused (e.g. `original_attn_processors`).

    Raises:
        ValueError: if a dict is passed whose size differs from the number of attention layers.
        KeyError: if a dict is passed that lacks the key for one of the attention layers.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                # Plain lookup instead of `pop`: popping emptied the caller's mapping, which broke
                # reuse of the same dict (e.g. a second `unfuse_qkv_projections`).
                module.set_processor(processor[f"{name}.processor"])

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
    """
    Disables custom attention processors and sets the default attention implementation.

    Chooses `AttnAddedKVProcessor` when every current processor class is in
    `ADDED_KV_ATTENTION_PROCESSORS`. Otherwise it chooses `AttnProcessor` when every current
    processor class is in `CROSS_ATTENTION_PROCESSORS`.

    Raises:
        ValueError: if the current processors fit neither group.
    """
    current_types = [proc.__class__ for proc in self.attn_processors.values()]

    if all(proc_type in ADDED_KV_ATTENTION_PROCESSORS for proc_type in current_types):
        default_processor = AttnAddedKVProcessor()
    elif all(proc_type in CROSS_ATTENTION_PROCESSORS for proc_type in current_types):
        default_processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(default_processor)
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
    r"""
    Enable sliced attention computation.

    When this option is enabled, the attention module splits the input tensor in slices to compute attention in
    several steps. This is useful for saving some memory in exchange for a small decrease in speed.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
            `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
            provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
            must be a multiple of `slice_size`.

    Raises:
        ValueError: if a list of the wrong length is given, or any slice exceeds its layer's head dim.
    """
    head_dims = []

    def collect_head_dims(module: torch.nn.Module):
        # Every module that supports slicing reports its sliceable head dim, in traversal order.
        if hasattr(module, "set_attention_slice"):
            head_dims.append(module.sliceable_head_dim)
        for child in module.children():
            collect_head_dims(child)

    for child in self.children():
        collect_head_dims(child)

    if slice_size == "auto":
        # Halving each head dim is usually a good speed/memory trade-off.
        slice_size = [dim // 2 for dim in head_dims]
    elif slice_size == "max":
        # Smallest possible slice: one at a time.
        slice_size = [1] * len(head_dims)

    if not isinstance(slice_size, list):
        slice_size = [slice_size] * len(head_dims)

    if len(slice_size) != len(head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(head_dims)}."
        )

    for size, dim in zip(slice_size, head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Reversed so that `pop()` hands out sizes in the same traversal order they were collected in.
    pending_sizes = list(reversed(slice_size))

    def apply_slices(module: torch.nn.Module):
        if hasattr(module, "set_attention_slice"):
            module.set_attention_slice(pending_sizes.pop())
        for child in module.children():
            apply_slices(child)

    for child in self.children():
        apply_slices(child)
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
    r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.

    The suffixes after the scaling factors represent the stage blocks where they are being applied.

    Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
    are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

    Args:
        s1 (`float`):
            Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
            mitigate the "oversmoothing effect" in the enhanced denoising process.
        s2 (`float`):
            Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
            mitigate the "oversmoothing effect" in the enhanced denoising process.
        b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
        b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
    """
    # The same four factors are stored as attributes on every up block.
    factors = {"s1": s1, "s2": s2, "b1": b1, "b2": b2}
    for up_block in self.up_blocks:
        for attr_name, value in factors.items():
            setattr(up_block, attr_name, value)
def disable_freeu(self):
    """Disables the FreeU mechanism.

    Resets ``s1``/``s2``/``b1``/``b2`` to ``None`` on every up block that has them; blocks without
    those attributes are left untouched.
    """
    for up_block in self.up_blocks:
        for attr_name in ("s1", "s2", "b1", "b2"):
            # `getattr(..., None) is not None` implies `hasattr(...)`, so `hasattr` alone decides.
            if hasattr(up_block, attr_name):
                setattr(up_block, attr_name, None)
def fuse_qkv_projections(self):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
    are fused. For cross-attention modules, key and value projection matrices are fused.

    The current processors are saved in `self.original_attn_processors` before switching every
    `Attention` module to `FusedAttnProcessor2_0`.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    Raises:
        ValueError: if any current processor class name contains "Added" (added KV projections).
    """
    self.original_attn_processors = None

    has_added_kv = any(
        "Added" in str(attn_processor.__class__.__name__) for attn_processor in self.attn_processors.values()
    )
    if has_added_kv:
        raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    self.original_attn_processors = self.attn_processors

    attention_modules = (module for module in self.modules() if isinstance(module, Attention))
    for attention_module in attention_modules:
        attention_module.fuse_projections(fuse=True)

    self.set_attn_processor(FusedAttnProcessor2_0())
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    Restores the processors saved in `self.original_attn_processors`. Does nothing when that
    attribute is missing (projections never fused) or `None`. Safe to call repeatedly.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>
    """
    original = getattr(self, "original_attn_processors", None)
    if original is not None:
        # Pass a shallow copy: `set_attn_processor` may consume (pop from) the dict it receives,
        # which would otherwise empty the saved mapping and make a second call fail.
        self.set_attn_processor(dict(original) if isinstance(original, dict) else original)
def get_time_embed(
    self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
) -> Optional[torch.Tensor]:
    """Turn ``timestep`` into a per-sample projection via ``self.time_proj``.

    * A Python scalar becomes a 1-element tensor on ``sample.device``. Floats use float32 on
      mps/npu and float64 elsewhere; other scalars use int32 on mps/npu and int64 elsewhere.
    * A 0-dim tensor gets a leading axis and is moved to ``sample.device``.
    * Other tensors are used as-is.

    The result is expanded to ``sample.shape[0]`` before projection, and the projection is cast to
    ``sample.dtype``.
    """
    if torch.is_tensor(timestep):
        timesteps = timestep[None].to(sample.device) if len(timestep.shape) == 0 else timestep
    else:
        # TODO: building a tensor from a Python scalar needs a CPU/GPU sync; prefer passing tensors.
        reduced_precision = sample.device.type in ("mps", "npu")
        if isinstance(timestep, float):
            dtype = torch.float32 if reduced_precision else torch.float64
        else:
            dtype = torch.int32 if reduced_precision else torch.int64
        timesteps = torch.tensor([timestep], dtype=dtype, device=sample.device)

    # `expand` broadcasts to the batch dimension in an ONNX/Core ML friendly way.
    timesteps = timesteps.expand(sample.shape[0])

    # `Timesteps` holds no weights and returns float32, but the time embedding may run in fp16,
    # so match the sample dtype here.
    return self.time_proj(timesteps).to(dtype=sample.dtype)
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Embed ``class_labels`` with ``self.class_embedding``, cast to ``sample.dtype``.

    Returns ``None`` when the model has no class embedding. For ``class_embed_type == "timestep"``
    the labels are first passed through ``self.time_proj``.

    Raises:
        ValueError: if the model has a class embedding but ``class_labels`` is ``None``.
    """
    if self.class_embedding is None:
        return None
    if class_labels is None:
        raise ValueError("class_labels should be provided when num_class_embeds > 0")

    if self.config.class_embed_type == "timestep":
        # `Timesteps` has no weights and always returns float32; match the sample dtype.
        class_labels = self.time_proj(class_labels).to(dtype=sample.dtype)

    return self.class_embedding(class_labels).to(dtype=sample.dtype)
def get_aug_embed(
    self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> Optional[torch.Tensor]:
    """Compute the extra conditioning embedding selected by ``config.addition_embed_type``.

    Returns ``None`` when the type is not one of the handled values. For ``"matryoshka"`` the
    value is whatever ``self.add_embedding`` returns. NOTE(review): ``forward`` unpacks three
    values from this call, so that is presumably a tuple, and the ``Optional[torch.Tensor]``
    annotation looks inaccurate for that case.

    Raises:
        ValueError: if a key required by the selected type is missing from ``added_cond_kwargs``.
    """
    embed_type = self.config.addition_embed_type

    if embed_type == "text":
        return self.add_embedding(encoder_hidden_states)

    if embed_type == "matryoshka":
        return self.add_embedding(emb, encoder_hidden_states, added_cond_kwargs)

    if embed_type == "text_image":
        # Kandinsky 2.1 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        image_embs = added_cond_kwargs.get("image_embeds")
        text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
        return self.add_embedding(text_embs, image_embs)

    if embed_type == "text_time":
        # SDXL - style: pooled text embeds concatenated with projected time ids.
        if "text_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
            )
        text_embeds = added_cond_kwargs.get("text_embeds")
        if "time_ids" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
            )
        time_ids = added_cond_kwargs.get("time_ids")
        projected_ids = self.add_time_proj(time_ids.flatten()).reshape((text_embeds.shape[0], -1))
        combined = torch.concat([text_embeds, projected_ids], dim=-1).to(emb.dtype)
        return self.add_embedding(combined)

    if embed_type == "image":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        return self.add_embedding(added_cond_kwargs.get("image_embeds"))

    if embed_type == "image_hint":
        # Kandinsky 2.2 ControlNet - style
        if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
            )
        return self.add_embedding(added_cond_kwargs.get("image_embeds"), added_cond_kwargs.get("hint"))

    return None
def process_encoder_hidden_states(
    self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> torch.Tensor:
    """Apply ``self.encoder_hid_proj`` according to ``config.encoder_hid_dim_type``.

    The input is returned unchanged when there is no projection or the type is not handled.

    * ``"ip_image_proj"``: returns a ``(encoder_hidden_states, projected_image_embeds)`` tuple. The
      text side first goes through ``self.text_encoder_hid_proj`` when that attribute exists and
      is not ``None``.

    Raises:
        ValueError: if an image-based type is selected but ``image_embeds`` is missing from
            ``added_cond_kwargs``.
    """
    projection = self.encoder_hid_proj
    if projection is None:
        return encoder_hidden_states

    hid_type = self.config.encoder_hid_dim_type

    if hid_type == "text_proj":
        return projection(encoder_hidden_states)

    if hid_type == "text_image_proj":
        # Kandinsky 2.1 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        return projection(encoder_hidden_states, added_cond_kwargs.get("image_embeds"))

    if hid_type == "image_proj":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        return projection(added_cond_kwargs.get("image_embeds"))

    if hid_type == "ip_image_proj":
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        text_projection = getattr(self, "text_encoder_hid_proj", None)
        if text_projection is not None:
            encoder_hidden_states = text_projection(encoder_hidden_states)
        projected_images = projection(added_cond_kwargs.get("image_embeds"))
        return (encoder_hidden_states, projected_images)

    return encoder_hidden_states
@property
def model_type(self) -> str:
    """Identifier of this model variant (``"unet"``).

    Exposed as a property because callers read it as an attribute and compare it directly
    (e.g. ``self.model_type == "unet"`` when building the time projection). As a plain method,
    that comparison would always be ``False``.
    """
    return "unet"
| def forward( | |
| self, | |
| sample: torch.Tensor, | |
| timestep: Union[torch.Tensor, float, int], | |
| encoder_hidden_states: torch.Tensor, | |
| cond_emb: Optional[torch.Tensor] = None, | |
| class_labels: Optional[torch.Tensor] = None, | |
| timestep_cond: Optional[torch.Tensor] = None, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
| added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, | |
| down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, | |
| mid_block_additional_residual: Optional[torch.Tensor] = None, | |
| down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, | |
| encoder_attention_mask: Optional[torch.Tensor] = None, | |
| return_dict: bool = True, | |
| from_nested: bool = False, | |
| ) -> Union[MatryoshkaUNet2DConditionOutput, Tuple]: | |
| r""" | |
| The [`NestedUNet2DConditionModel`] forward method. | |
| Args: | |
| sample (`torch.Tensor`): | |
| The noisy input tensor with the following shape `(batch, channel, height, width)`. | |
| timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. | |
| encoder_hidden_states (`torch.Tensor`): | |
| The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. | |
| class_labels (`torch.Tensor`, *optional*, defaults to `None`): | |
| Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. | |
| timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): | |
| Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed | |
| through the `self.time_embedding` layer to obtain the timestep embeddings. | |
| attention_mask (`torch.Tensor`, *optional*, defaults to `None`): | |
| An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask | |
| is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large | |
| negative values to the attention scores corresponding to "discard" tokens. | |
| cross_attention_kwargs (`dict`, *optional*): | |
| A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under | |
| `self.processor` in | |
| [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). | |
| added_cond_kwargs: (`dict`, *optional*): | |
| A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that | |
| are passed along to the UNet blocks. | |
| down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): | |
| A tuple of tensors that if specified are added to the residuals of down unet blocks. | |
| mid_block_additional_residual: (`torch.Tensor`, *optional*): | |
| A tensor that if specified is added to the residual of the middle unet block. | |
| down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): | |
| additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) | |
| encoder_attention_mask (`torch.Tensor`): | |
| A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If | |
| `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, | |
| which adds large negative values to the attention scores corresponding to "discard" tokens. | |
| return_dict (`bool`, *optional*, defaults to `True`): | |
| Whether or not to return a [`~NestedUNet2DConditionOutput`] instead of a plain | |
| tuple. | |
| Returns: | |
| [`~NestedUNet2DConditionOutput`] or `tuple`: | |
| If `return_dict` is True, an [`~NestedUNet2DConditionOutput`] is returned, | |
| otherwise a `tuple` is returned where the first element is the sample tensor. | |
| """ | |
| # By default samples have to be AT least a multiple of the overall upsampling factor. | |
| # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). | |
| # However, the upsampling interpolation output size can be forced to fit any upsampling size | |
| # on the fly if necessary. | |
| default_overall_up_factor = 2**self.num_upsamplers | |
| # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` | |
| forward_upsample_size = False | |
| upsample_size = None | |
| if self.config.nesting: | |
| sample, sample_feat = sample | |
| if isinstance(sample, list) and len(sample) == 1: | |
| sample = sample[0] | |
| for dim in sample.shape[-2:]: | |
| if dim % default_overall_up_factor != 0: | |
| # Forward upsample size to force interpolation output size. | |
| forward_upsample_size = True | |
| break | |
| # ensure attention_mask is a bias, and give it a singleton query_tokens dimension | |
| # expects mask of shape: | |
| # [batch, key_tokens] | |
| # adds singleton query_tokens dimension: | |
| # [batch, 1, key_tokens] | |
| # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: | |
| # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) | |
| # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) | |
| if attention_mask is not None: | |
| # assume that mask is expressed as: | |
| # (1 = keep, 0 = discard) | |
| # convert mask into a bias that can be added to attention scores: | |
| # (keep = +0, discard = -10000.0) | |
| attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 | |
| attention_mask = attention_mask.unsqueeze(1) | |
| # 0. center input if necessary | |
| if self.config.center_input_sample: | |
| sample = 2 * sample - 1.0 | |
| # 1. time | |
| t_emb = self.get_time_embed(sample=sample, timestep=timestep) | |
| emb = self.time_embedding(t_emb, timestep_cond) | |
| class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) | |
| if class_emb is not None: | |
| if self.config.class_embeddings_concat: | |
| emb = torch.cat([emb, class_emb], dim=-1) | |
| else: | |
| emb = emb + class_emb | |
| added_cond_kwargs = added_cond_kwargs or {} | |
| added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention | |
| added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale | |
| added_cond_kwargs["from_nested"] = from_nested | |
| added_cond_kwargs["conditioning_mask"] = encoder_attention_mask | |
| if not from_nested: | |
| encoder_hidden_states = self.process_encoder_hidden_states( | |
| encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs | |
| ) | |
| aug_emb, encoder_attention_mask, cond_emb = self.get_aug_embed( | |
| emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs | |
| ) | |
| else: | |
| aug_emb, encoder_attention_mask, _ = self.get_aug_embed( | |
| emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs | |
| ) | |
| # convert encoder_attention_mask to a bias the same way we do for attention_mask | |
| if encoder_attention_mask is not None: | |
| encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0 | |
| encoder_attention_mask = encoder_attention_mask.unsqueeze(1) | |
| if self.config.addition_embed_type == "image_hint": | |
| aug_emb, hint = aug_emb | |
| sample = torch.cat([sample, hint], dim=1) | |
| emb = emb + aug_emb + cond_emb if aug_emb is not None else emb | |
| if self.time_embed_act is not None: | |
| emb = self.time_embed_act(emb) | |
| # 2. pre-process | |
| sample = self.conv_in(sample) | |
| if self.config.nesting: | |
| sample = sample + sample_feat | |
| # 2.5 GLIGEN position net | |
| if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: | |
| cross_attention_kwargs = cross_attention_kwargs.copy() | |
| gligen_args = cross_attention_kwargs.pop("gligen") | |
| cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} | |
| # 3. down | |
| # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated | |
| # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. | |
| if cross_attention_kwargs is not None: | |
| cross_attention_kwargs = cross_attention_kwargs.copy() | |
| lora_scale = cross_attention_kwargs.pop("scale", 1.0) | |
| else: | |
| lora_scale = 1.0 | |
| if USE_PEFT_BACKEND: | |
| # weight the lora layers by setting `lora_scale` for each PEFT layer | |
| scale_lora_layers(self, lora_scale) | |
| is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None | |
| # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets | |
| is_adapter = down_intrablock_additional_residuals is not None | |
| # maintain backward compatibility for legacy usage, where | |
| # T2I-Adapter and ControlNet both use down_block_additional_residuals arg | |
| # but can only use one or the other | |
| if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: | |
| deprecate( | |
| "T2I should not use down_block_additional_residuals", | |
| "1.3.0", | |
| "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ | |
| and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ | |
| for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", | |
| standard_warn=False, | |
| ) | |
| down_intrablock_additional_residuals = down_block_additional_residuals | |
| is_adapter = True | |
| down_block_res_samples = (sample,) | |
| for downsample_block in self.down_blocks: | |
| if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: | |
| # For t2i-adapter CrossAttnDownBlock2D | |
| additional_residuals = {} | |
| if is_adapter and len(down_intrablock_additional_residuals) > 0: | |
| additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) | |
| sample, res_samples = downsample_block( | |
| hidden_states=sample, | |
| temb=emb, | |
| encoder_hidden_states=encoder_hidden_states, | |
| attention_mask=attention_mask, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| encoder_attention_mask=encoder_attention_mask, | |
| **additional_residuals, | |
| ) | |
| else: | |
| sample, res_samples = downsample_block(hidden_states=sample, temb=emb) | |
| if is_adapter and len(down_intrablock_additional_residuals) > 0: | |
| sample += down_intrablock_additional_residuals.pop(0) | |
| down_block_res_samples += res_samples | |
| if is_controlnet: | |
| new_down_block_res_samples = () | |
| for down_block_res_sample, down_block_additional_residual in zip( | |
| down_block_res_samples, down_block_additional_residuals | |
| ): | |
| down_block_res_sample = down_block_res_sample + down_block_additional_residual | |
| new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) | |
| down_block_res_samples = new_down_block_res_samples | |
| # 4. mid | |
| if self.mid_block is not None: | |
| if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: | |
| sample = self.mid_block( | |
| sample, | |
| emb, | |
| encoder_hidden_states=encoder_hidden_states, | |
| attention_mask=attention_mask, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| encoder_attention_mask=encoder_attention_mask, | |
| ) | |
| else: | |
| sample = self.mid_block(sample, emb) | |
| # To support T2I-Adapter-XL | |
| if ( | |
| is_adapter | |
| and len(down_intrablock_additional_residuals) > 0 | |
| and sample.shape == down_intrablock_additional_residuals[0].shape | |
| ): | |
| sample += down_intrablock_additional_residuals.pop(0) | |
| if is_controlnet: | |
| sample = sample + mid_block_additional_residual | |
| # 5. up | |
| for i, upsample_block in enumerate(self.up_blocks): | |
| is_final_block = i == len(self.up_blocks) - 1 | |
| res_samples = down_block_res_samples[-len(upsample_block.resnets) :] | |
| down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] | |
| # if we have not reached the final block and need to forward the | |
| # upsample size, we do it here | |
| if not is_final_block and forward_upsample_size: | |
| upsample_size = down_block_res_samples[-1].shape[2:] | |
| if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: | |
| sample = upsample_block( | |
| hidden_states=sample, | |
| temb=emb, | |
| res_hidden_states_tuple=res_samples, | |
| encoder_hidden_states=encoder_hidden_states, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| upsample_size=upsample_size, | |
| attention_mask=attention_mask, | |
| encoder_attention_mask=encoder_attention_mask, | |
| ) | |
| else: | |
| sample = upsample_block( | |
| hidden_states=sample, | |
| temb=emb, | |
| res_hidden_states_tuple=res_samples, | |
| upsample_size=upsample_size, | |
| ) | |
| sample_inner = sample | |
| # 6. post-process | |
| if self.conv_norm_out: | |
| sample = self.conv_norm_out(sample_inner) | |
| sample = self.conv_act(sample) | |
| sample = self.conv_out(sample) | |
| if USE_PEFT_BACKEND: | |
| # remove `lora_scale` from each PEFT layer | |
| unscale_lora_layers(self, lora_scale) | |
| if not return_dict: | |
| return (sample,) | |
| if self.config.nesting: | |
| return MatryoshkaUNet2DConditionOutput(sample=sample, sample_inner=sample_inner) | |
| return MatryoshkaUNet2DConditionOutput(sample=sample) | |
@dataclass
class NestedUNet2DConditionOutput(BaseOutput):
    """
    Output type for the [`NestedUNet2DConditionModel`] model.

    `BaseOutput` subclasses must be dataclasses so that the keyword arguments passed at construction are stored as
    fields (and become readable as attributes such as `output.sample`).

    Args:
        sample (`list`, *optional*):
            Per-level outputs: the outermost level's output first, followed by the outputs returned by the inner
            UNet(s).
        sample_inner (`torch.Tensor`, *optional*):
            Hidden state taken before the final post-processing layers. Only set when the model runs in
            `nesting` mode (i.e. as the inner UNet of another nested UNet).
    """

    sample: Optional[list] = None
    sample_inner: Optional[torch.Tensor] = None
class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
    """
    Nested UNet model with condition for image denoising.

    The model runs its own down/up blocks on the first (outermost) sample and delegates the remaining samples to
    `self.inner_unet`. The inner UNet is built from `inner_config` and is either a
    `MatryoshkaUNet2DConditionModel` or, recursively, another `NestedUNet2DConditionModel`.

    Selected arguments:
        inner_config (`dict`):
            Keyword arguments used to build the inner UNet. Must contain `block_out_channels`. If it itself
            contains an `inner_config` key, a `NestedUNet2DConditionModel` is built recursively.
        skip_inner_unet_input (`bool`, defaults to `False`):
            If `True`, no `in_adapter` is created and the inner UNet receives `None` instead of features from
            this level.
        skip_normalization (`bool`, defaults to `True`):
            If `False`, `forward` divides the outermost sample by its standard deviation over dimensions
            `(1, 2, 3)` before `conv_in`.
        nesting (`bool`, defaults to `False`):
            Set when this model is the inner UNet of another nested UNet. `forward` then expects a
            `(samples, sample_feat)` pair and additionally returns `sample_inner`.
    """

    # `register_to_config` records every argument below in `self.config`; the body reads
    # `inner_config`, `skip_inner_unet_input` and `skip_normalization` from there.
    @register_to_config
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64, 128, 256),
        cross_attention_dim=2048,
        resnet_time_scale_shift="scale_shift",
        down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D"),
        mid_block_type=None,
        nesting=False,
        flip_sin_to_cos=False,
        transformer_layers_per_block=[0, 0, 0],
        layers_per_block=[2, 2, 1],
        masked_cross_attention=True,
        micro_conditioning_scale=256,
        addition_embed_type="matryoshka",
        skip_normalization=True,
        time_embedding_dim=1024,
        skip_inner_unet_input=False,
        temporal_mode=False,
        temporal_spatial_ds=False,
        initialize_inner_with_pretrained=None,
        use_attention_ffn=False,
        act_fn="silu",
        addition_embed_type_num_heads=64,
        addition_time_embed_dim=None,
        attention_head_dim=8,
        attention_pre_only=False,
        attention_type="default",
        center_input_sample=False,
        class_embed_type=None,
        class_embeddings_concat=False,
        conv_in_kernel=3,
        conv_out_kernel=3,
        cross_attention_norm=None,
        downsample_padding=1,
        dropout=0.0,
        dual_cross_attention=False,
        encoder_hid_dim=None,
        encoder_hid_dim_type=None,
        freq_shift=0,
        mid_block_only_cross_attention=None,
        mid_block_scale_factor=1,
        norm_eps=1e-05,
        norm_num_groups=32,
        norm_type="layer_norm",
        num_attention_heads=None,
        num_class_embeds=None,
        only_cross_attention=False,
        projection_class_embeddings_input_dim=None,
        resnet_out_scale_factor=1.0,
        resnet_skip_time_act=False,
        reverse_transformer_layers_per_block=None,
        sample_size=None,
        skip_cond_emb=False,
        time_cond_proj_dim=None,
        time_embedding_act_fn=None,
        time_embedding_type="positional",
        timestep_post_act=None,
        upcast_attention=False,
        use_linear_projection=False,
        is_temporal=None,
        inner_config={},
    ):
        # NOTE: only this subset of the arguments is forwarded to the parent constructor; the others are merely
        # recorded in the config by `@register_to_config`.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            cross_attention_dim=cross_attention_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            mid_block_type=mid_block_type,
            nesting=nesting,
            flip_sin_to_cos=flip_sin_to_cos,
            transformer_layers_per_block=transformer_layers_per_block,
            layers_per_block=layers_per_block,
            masked_cross_attention=masked_cross_attention,
            micro_conditioning_scale=micro_conditioning_scale,
            addition_embed_type=addition_embed_type,
            time_embedding_dim=time_embedding_dim,
            temporal_mode=temporal_mode,
            temporal_spatial_ds=temporal_spatial_ds,
            use_attention_ffn=use_attention_ffn,
            sample_size=sample_size,
        )

        # Build the inner UNet; recurse when its config describes yet another nesting level.
        if "inner_config" not in self.config.inner_config:
            self.inner_unet = MatryoshkaUNet2DConditionModel(**self.config.inner_config)
        else:
            self.inner_unet = NestedUNet2DConditionModel(**self.config.inner_config)

        # Adapters map between this level's bottleneck channels and the inner UNet's first-block channels.
        if not self.config.skip_inner_unet_input:
            self.in_adapter = nn.Conv2d(
                self.config.block_out_channels[-1],
                self.config.inner_config["block_out_channels"][0],
                kernel_size=3,
                padding=1,
            )
        else:
            self.in_adapter = None
        self.out_adapter = nn.Conv2d(
            self.config.inner_config["block_out_channels"][0],
            self.config.block_out_channels[-1],
            kernel_size=3,
            padding=1,
        )

        # One flag per nesting level, outermost first.
        self.is_temporal = [self.config.temporal_mode and (not self.config.temporal_spatial_ds)]
        if hasattr(self.inner_unet, "is_temporal"):
            self.is_temporal = self.is_temporal + self.inner_unet.is_temporal

        # Spatial ratio between this level and the inner one: 2 per downsampling stage, i.e.
        # 2 ** (num_blocks - 1); square-rooted for temporal levels. Accumulated across nested levels.
        nest_ratio = int(2 ** (len(self.config.block_out_channels) - 1))
        if self.is_temporal[0]:
            nest_ratio = int(np.sqrt(nest_ratio))
        if self.inner_unet.config.nesting and self.inner_unet.model_type == "nested_unet":
            self.nest_ratio = [nest_ratio * self.inner_unet.nest_ratio[0]] + self.inner_unet.nest_ratio
        else:
            self.nest_ratio = [nest_ratio]

    @property
    def model_type(self):
        """Identifier compared against by enclosing nested UNets (as an attribute, hence the property)."""
        return "nested_unet"

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        cond_emb: Optional[torch.Tensor] = None,
        from_nested: bool = False,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[NestedUNet2DConditionOutput, Tuple]:
        r"""
        The [`NestedUNet2DConditionModel`] forward method.

        Args:
            sample (sequence of `torch.Tensor`):
                Noisy inputs, one per nesting level, outermost level first. The first entry is processed by this
                model; the remaining entries are handed to `self.inner_unet`. When `config.nesting` is `True`,
                this is a `(samples, sample_feat)` pair instead, and `sample_feat` is added to the output of
                `conv_in`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            cond_emb (`torch.Tensor`, *optional*):
                Conditioning embedding supplied by an enclosing nested UNet. It is used as-is only when
                `config.nesting` is `True` and the inner UNet is a plain `unet`; otherwise it is recomputed by the
                inner UNet's `get_aug_embed`.
            from_nested (`bool`, *optional*, defaults to `False`):
                Accepted for signature compatibility with `MatryoshkaUNet2DConditionModel.forward`; not read here.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
                through the `self.time_embedding` layer to obtain the timestep embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)`. If `1` the position is kept, if `0` it is
                discarded. The mask is converted into a bias that adds large negative values to the attention scores
                of "discard" tokens.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks. Note: this dictionary is updated in place.
            down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
                Legacy T2I-Adapter residuals; only honoured (with a deprecation warning) when
                `mid_block_additional_residual` and `down_intrablock_additional_residuals` are both `None`.
            mid_block_additional_residual: (`torch.Tensor`, *optional*):
                Only used to decide whether `down_block_additional_residuals` is treated as legacy adapter input.
            down_intrablock_additional_residuals (`list` of `torch.Tensor`, *optional*):
                Additional residuals to be added within the down blocks of this level, for example from
                T2I-Adapter side model(s). Entries are consumed with `pop(0)`, so a mutable list is required.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                A cross-attention mask of shape `(batch, sequence_length)`. It is passed as `conditioning_mask` to
                `get_aug_embed`; the mask returned there is what the attention blocks receive.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~NestedUNet2DConditionOutput`] instead of a plain tuple. Ignored in
                `nesting` mode, which always returns a [`~NestedUNet2DConditionOutput`].

        Returns:
            [`~NestedUNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True (or `config.nesting` is set), a [`~NestedUNet2DConditionOutput`] whose
                `sample` is the list of per-level outputs (this level first) is returned. Otherwise a one-element
                `tuple` containing that list is returned.

        Raises:
            ValueError: If `self.inner_unet.model_type` is neither `"unet"` nor `"nested_unet"`.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers
        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        # 0. unpack inputs
        if self.config.nesting:
            # Called by an enclosing nested UNet with (samples, features from the enclosing level).
            sample, sample_feat = sample
        if isinstance(sample, list) and len(sample) == 1:
            sample = sample[0]

        # bh / bl: batch sizes of this level's sample and of the next (inner) level's sample.
        bsz = [x.size(0) for x in sample]
        bh, bl = bsz[0], bsz[1]
        x_t_low, sample = sample[1:], sample[0]

        for dim in sample.shape[-2:]:
            if dim % default_overall_up_factor != 0:
                # Forward upsample size to force interpolation output size.
                forward_upsample_size = True
                break

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # (1 = keep, 0 = discard) -> (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 2. time / class / additional embeddings
        t_emb = self.get_time_embed(sample=sample, timestep=timestep)
        emb = self.time_embedding(t_emb, timestep_cond)

        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
        if class_emb is not None:
            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        inner_type = self.inner_unet.model_type
        if inner_type == "unet":
            added_cond_kwargs = added_cond_kwargs or {}
            added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention
            added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
            added_cond_kwargs["conditioning_mask"] = encoder_attention_mask

            if not self.config.nesting:
                encoder_hidden_states = self.inner_unet.process_encoder_hidden_states(
                    encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
                )
                _, cond_mask, cond_emb = self.inner_unet.get_aug_embed(
                    emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
                )
                added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
                aug_emb, __, _ = self.get_aug_embed(
                    emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
                )
            else:
                # `cond_emb` comes from the enclosing nested UNet in this case.
                aug_emb, cond_mask, _ = self.get_aug_embed(
                    emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
                )
        elif inner_type == "nested_unet":
            added_cond_kwargs = added_cond_kwargs or {}
            added_cond_kwargs["masked_cross_attention"] = self.inner_unet.inner_unet.config.masked_cross_attention
            added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
            added_cond_kwargs["conditioning_mask"] = encoder_attention_mask

            encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states(
                encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
            )
            _, cond_mask, cond_emb = self.inner_unet.inner_unet.get_aug_embed(
                emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
            )
            aug_emb, __, _ = self.get_aug_embed(
                emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
            )
        else:
            raise ValueError(f"Unsupported inner UNet model type: {inner_type!r}")

        if self.config.addition_embed_type == "image_hint":
            aug_emb, hint = aug_emb
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb + cond_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        # 3. input layer (optionally normalize the input)
        if not self.config.skip_normalization:
            sample = sample / sample.std((1, 2, 3), keepdim=True)
        if isinstance(sample, list) and len(sample) == 1:
            sample = sample[0]
        sample = self.conv_in(sample)
        if self.config.nesting:
            sample = sample + sample_feat

        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
        # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
        if cross_attention_kwargs is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)

        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
        is_adapter = down_intrablock_additional_residuals is not None
        # maintain backward compatibility for legacy usage, where
        #       T2I-Adapter and ControlNet both use down_block_additional_residuals arg
        #       but can only use one or the other
        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
            deprecate(
                "T2I should not use down_block_additional_residuals",
                "1.3.0",
                (
                    "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
                    "and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used "
                    "for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. "
                ),
                standard_warn=False,
            )
            down_intrablock_additional_residuals = down_block_additional_residuals
            is_adapter = True

        # 4. downsample blocks in the outer layers
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_intrablock_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb[:bh],
                    encoder_hidden_states=encoder_hidden_states[:bh],
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
                if is_adapter and len(down_intrablock_additional_residuals) > 0:
                    sample += down_intrablock_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        # 5. run the inner unet on the lower-resolution samples
        x_inner = self.in_adapter(sample) if self.in_adapter is not None else None
        if x_inner is not None and bh < bl:
            # Pad with zeros so the batch matches the inner level's batch size `bl`.
            x_inner = torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0)
        inner_unet_output = self.inner_unet(
            (x_t_low, x_inner),
            timestep,
            cond_emb=cond_emb,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=cond_mask,
            from_nested=True,
        )
        x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner
        x_inner = self.out_adapter(x_inner)
        sample = sample + x_inner[:bh] if bh < bl else sample + x_inner

        # 6. upsample blocks in the outer layers
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb[:bh],
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states[:bh],
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                )

        # 7. post-process; `sample` itself is kept as the pre-output hidden state (`sample_inner`).
        # Initialise first so the model still works when it has no output normalization.
        sample_out = sample
        if self.conv_norm_out is not None:
            sample_out = self.conv_norm_out(sample_out)
            sample_out = self.conv_act(sample_out)
        sample_out = self.conv_out(sample_out)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        # 8. output this level's result followed by the inner levels' results
        if isinstance(x_low, list):
            out = [sample_out] + x_low
        else:
            out = [sample_out, x_low]

        if self.config.nesting:
            return NestedUNet2DConditionOutput(sample=out, sample_inner=sample)

        if not return_dict:
            return (out,)
        return NestedUNet2DConditionOutput(sample=out)
@dataclass
class MatryoshkaPipelineOutput(BaseOutput):
    """
    Output class for Matryoshka pipelines.

    `BaseOutput` subclasses must be dataclasses so that `images` passed at construction is stored as a field and is
    readable as `output.images`.

    Args:
        images (`List[PIL.Image.Image]`, `List[List[PIL.Image.Image]]`, `np.ndarray` or `List[np.ndarray]`):
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. The
            nested-list / list-of-arrays forms allow one entry per resolution level.
    """

    images: Union[List[Image.Image], List[List[Image.Image]], np.ndarray, List[np.ndarray]]
| class MatryoshkaPipeline( | |
| DiffusionPipeline, | |
| StableDiffusionMixin, | |
| TextualInversionLoaderMixin, | |
| StableDiffusionLoraLoaderMixin, | |
| IPAdapterMixin, | |
| FromSingleFileMixin, | |
| ): | |
| r""" | |
| Pipeline for text-to-image generation using Matryoshka Diffusion Models. | |
| This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods | |
| implemented for all pipelines (downloading, saving, running on a particular device, etc.). | |
| The pipeline also inherits the following loading methods: | |
| - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings | |
| - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights | |
| - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights | |
| - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files | |
| - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters | |
| Args: | |
| text_encoder ([`~transformers.T5EncoderModel`]): | |
| Frozen text-encoder ([flan-t5-xl](https://huggingface.co/google/flan-t5-xl)). | |
| tokenizer ([`~transformers.T5Tokenizer`]): | |
| A `T5Tokenizer` to tokenize text. | |
| unet ([`MatryoshkaUNet2DConditionModel`]): | |
| A `MatryoshkaUNet2DConditionModel` to denoise the encoded image latents. | |
| scheduler ([`SchedulerMixin`]): | |
| A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of | |
| [`MatryoshkaDDIMScheduler`] and other schedulers with proper modifications, see an example usage in README.md. | |
| feature_extractor ([`~transformers.<AnImageProcessor>`]): | |
| A `AnImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. | |
| """ | |
| model_cpu_offload_seq = "text_encoder->image_encoder->unet" | |
| _optional_components = ["unet", "feature_extractor", "image_encoder"] | |
| _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] | |
def __init__(
    self,
    text_encoder: T5EncoderModel,
    tokenizer: T5TokenizerFast,
    scheduler: MatryoshkaDDIMScheduler,
    unet: MatryoshkaUNet2DConditionModel = None,
    feature_extractor: CLIPImageProcessor = None,
    image_encoder: CLIPVisionModelWithProjection = None,
    trust_remote_code: bool = False,
    nesting_level: int = 0,
):
    """
    Build the pipeline and load the UNet matching `nesting_level`.

    The `unet` argument is always replaced: the UNet is loaded from
    `tolgacangoz/matryoshka-diffusion-models` (subfolder `unet/nesting_level_<n>`) whatever is passed in.
    `trust_remote_code` is accepted but not used here.

    Raises:
        ValueError: If `nesting_level` is not 0, 1 or 2.
    """
    super().__init__()

    # Choose the UNet class and Hub subfolder for the requested nesting depth, then load it once.
    if nesting_level == 0:
        unet_cls, unet_subfolder = MatryoshkaUNet2DConditionModel, "unet/nesting_level_0"
    elif nesting_level == 1:
        unet_cls, unet_subfolder = NestedUNet2DConditionModel, "unet/nesting_level_1"
    elif nesting_level == 2:
        unet_cls, unet_subfolder = NestedUNet2DConditionModel, "unet/nesting_level_2"
    else:
        raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
    unet = unet_cls.from_pretrained("tolgacangoz/matryoshka-diffusion-models", subfolder=unet_subfolder)

    # Legacy scheduler configs: warn and force `steps_offset` to 1.
    has_outdated_offset = scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1
    if has_outdated_offset:
        deprecate(
            "steps_offset!=1",
            "1.0.0",
            (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            ),
            standard_warn=False,
        )
        patched_scheduler_config = dict(scheduler.config)
        patched_scheduler_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(patched_scheduler_config)

    # NOTE: unlike the Stable Diffusion pipelines, `clip_sample` is intentionally NOT forced to False here.

    # Legacy (< 0.9.0) UNet configs with a tiny `sample_size`: warn and bump it to 64.
    is_legacy_unet = (
        unet is not None
        and hasattr(unet.config, "_diffusers_version")
        and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
    )
    has_small_sample_size = unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_legacy_unet and has_small_sample_size:
        deprecate(
            "sample_size<64",
            "1.0.0",
            (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            ),
            standard_warn=False,
        )
        patched_unet_config = dict(unet.config)
        patched_unet_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(patched_unet_config)

    # Nested UNets expose `nest_ratio`; the scheduler gets the matching scale list (plus the full-res 1).
    if hasattr(unet, "nest_ratio"):
        scheduler.scales = unet.nest_ratio + [1]
        if nesting_level == 2:
            scheduler.schedule_shifted_power = 2.0

    self.register_modules(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    self.register_to_config(nesting_level=nesting_level)
    self.image_processor = VaeImageProcessor(do_resize=False)
def change_nesting_level(self, nesting_level: int):
    """
    Swap the pipeline's UNet for the checkpoint of another nesting level and reconfigure the scheduler.

    Args:
        nesting_level (`int`):
            0 loads the single-scale `MatryoshkaUNet2DConditionModel`; 1 and 2 load a
            `NestedUNet2DConditionModel` from `unet/nesting_level_<n>` of
            `tolgacangoz/matryoshka-diffusion-models`.

    Raises:
        ValueError: If `nesting_level` is not 0, 1 or 2. It is raised before any pipeline state is modified.
    """
    if nesting_level == 0:
        unet_cls = MatryoshkaUNet2DConditionModel
    elif nesting_level in (1, 2):
        unet_cls = NestedUNet2DConditionModel
    else:
        raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")

    # Inspect the *current* UNet before replacing it.
    was_nested = hasattr(self.unet, "nest_ratio")
    # Load first so a failed download leaves the scheduler untouched.
    self.unet = unet_cls.from_pretrained(
        "tolgacangoz/matryoshka-diffusion-models", subfolder=f"unet/nesting_level_{nesting_level}"
    ).to(self.device)

    if nesting_level == 0:
        if was_nested:
            self.scheduler.scales = None
    else:
        self.scheduler.scales = self.unet.nest_ratio + [1]
        self.scheduler.schedule_shifted_power = 2.0 if nesting_level == 2 else 1.0

    # `self.config` is a FrozenDict: assigning `self.config.nesting_level` only sets an attribute and leaves
    # the dict entry (what gets serialized) stale. `register_to_config` rebuilds the config properly.
    self.register_to_config(nesting_level=nesting_level)

    # Release the previous UNet's memory.
    gc.collect()
    torch.cuda.empty_cache()
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    With classifier-free guidance, the negative and positive token sequences are right-padded to a common length
    and encoded in a single text-encoder forward pass.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt. NOTE(review): not used by this method; the
            returned embeddings are not repeated.
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance.
        prompt_embeds (`torch.Tensor`, *optional*):
            Not supported by this encoder path; passing it raises `ValueError`.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Not supported when guidance is enabled; passing it raises `ValueError`.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of final encoder layers to skip (non-guidance path only). A value of 1 means the output of the
            pre-final layer is used, passed through the encoder's final LayerNorm.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask)`.
        Without guidance, the two negative entries are `None`. With guidance, the embeddings are rows `1` and `0`
        of the jointly encoded batch, which assumes a single prompt.

    Raises:
        ValueError: On unsupported pre-computed embeddings, or a `negative_prompt` list whose length differs from
            the prompt batch.
        TypeError: If `negative_prompt` and `prompt` have different types.
    """
    # The text below is always tokenized from `prompt` / `negative_prompt`. Pre-computed embeddings would leave
    # the token ids undefined (UnboundLocalError), so reject them up front before touching any LoRA state.
    if prompt_embeds is not None or (do_classifier_free_guidance and negative_prompt_embeds is not None):
        raise ValueError(
            "Passing pre-computed `prompt_embeds` or `negative_prompt_embeds` is not supported by this pipeline;"
            " please pass `prompt` (and optionally `negative_prompt`) instead."
        )

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # textual inversion: process multi-vector tokens if necessary
    if isinstance(self, TextualInversionLoaderMixin):
        prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

    text_inputs = self.tokenizer(
        prompt,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because FLAN-T5-XL for this pipeline can only handle sequences up to"
            f" {self.tokenizer.model_max_length} tokens: {removed_text}"
        )

    use_attention_mask = bool(getattr(self.text_encoder.config, "use_attention_mask", False))
    prompt_attention_mask = text_inputs.attention_mask.to(device) if use_attention_mask else None

    # `self.text_encoder` has already been dereferenced above, so it is available here.
    prompt_embeds_dtype = self.text_encoder.dtype

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        uncond_input = self.tokenizer(
            uncond_tokens,
            return_tensors="pt",
        )
        uncond_input_ids = uncond_input.input_ids
        negative_prompt_attention_mask = uncond_input.attention_mask.to(device) if use_attention_mask else None

    if not do_classifier_free_guidance:
        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=True
            )
            # `[-1]` is the tuple of per-layer hidden states; index the layer `clip_skip` before the last one.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # Intermediate hidden states have not been through the encoder's final LayerNorm, so apply it here.
            # CLIP text encoders keep it under `text_model`; T5 encoders (this pipeline's default) under `encoder`.
            if hasattr(self.text_encoder, "text_model"):
                final_layer_norm = self.text_encoder.text_model.final_layer_norm
            else:
                final_layer_norm = self.text_encoder.encoder.final_layer_norm
            prompt_embeds = final_layer_norm(prompt_embeds)
    else:

        def _right_pad(tensor, length):
            # Append zeros on the last dim (0 is T5's pad id / "masked" in an attention mask).
            # Uses the tensor's own batch size, dtype and device.
            if tensor is None or tensor.shape[-1] >= length:
                return tensor
            return torch.cat([tensor, tensor.new_zeros(tensor.shape[0], length - tensor.shape[-1])], dim=-1)

        max_len = max(text_input_ids.shape[-1], uncond_input_ids.shape[-1])
        text_input_ids = _right_pad(text_input_ids, max_len)
        prompt_attention_mask = _right_pad(prompt_attention_mask, max_len)
        uncond_input_ids = _right_pad(uncond_input_ids, max_len)
        negative_prompt_attention_mask = _right_pad(negative_prompt_attention_mask, max_len)

        # Encode [uncond, cond] in a single forward pass.
        cfg_input_ids = torch.cat([uncond_input_ids, text_input_ids], dim=0)
        if prompt_attention_mask is None:
            cfg_attention_mask = None
        else:
            cfg_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
        prompt_embeds = self.text_encoder(
            cfg_input_ids.to(device),
            attention_mask=cfg_attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if not do_classifier_free_guidance:
        return prompt_embeds, None, prompt_attention_mask, None
    # Rows are [uncond, cond]; indexing 1 / 0 assumes a single prompt.
    return prompt_embeds[1], prompt_embeds[0], prompt_attention_mask, negative_prompt_attention_mask
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode an IP-Adapter image with `self.image_encoder`, returning `(cond, uncond)` features.

    Non-tensor inputs are first converted with `self.feature_extractor`. The image is cast to the encoder's
    parameter dtype and moved to `device`.

    When `output_hidden_states` is truthy, the features are the encoder's second-to-last hidden state. The
    unconditional features come from encoding an all-zero image. Otherwise the features are `image_embeds`, and the
    unconditional ones are zeros of the same shape. Both outputs are repeated `num_images_per_prompt` times along
    dim 0.
    """
    encoder = self.image_encoder
    target_dtype = next(encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=target_dtype)

    if not output_hidden_states:
        embeds = encoder(image).image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return embeds, torch.zeros_like(embeds)

    def _penultimate_hidden(pixels):
        # Second-to-last hidden state, repeated once per requested image.
        hidden = encoder(pixels, output_hidden_states=True).hidden_states[-2]
        return hidden.repeat_interleave(num_images_per_prompt, dim=0)

    return _penultimate_hidden(image), _penultimate_hidden(torch.zeros_like(image))
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build one embedding tensor per IP-Adapter.

    If `ip_adapter_image_embeds` is given, each entry is used as-is. With guidance, each entry is split in half
    along dim 0 into (negative, positive). Otherwise every image in `ip_adapter_image` is encoded with
    `self.encode_image`, using hidden states unless its projection layer is an `ImageProjection`.

    Each result is repeated `num_images_per_prompt` times along dim 0. With guidance, it is prefixed by the
    equally repeated negative embeddings, then moved to `device`.

    Raises:
        ValueError: If the number of images differs from the number of IP-Adapter projection layers.
    """
    positive_embeds = []
    negative_embeds = [] if do_classifier_free_guidance else None

    if ip_adapter_image_embeds is not None:
        for stacked in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_part, stacked = stacked.chunk(2)
                negative_embeds.append(negative_part)
            positive_embeds.append(stacked)
    else:
        images = ip_adapter_image if isinstance(ip_adapter_image, list) else [ip_adapter_image]
        projection_layers = self.unet.encoder_hid_proj.image_projection_layers
        if len(images) != len(projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(images)} images and {len(projection_layers)} IP Adapters."
            )
        for image, projection in zip(images, projection_layers):
            positive, negative = self.encode_image(image, device, 1, not isinstance(projection, ImageProjection))
            positive_embeds.append(positive[None, :])
            if do_classifier_free_guidance:
                negative_embeds.append(negative[None, :])

    prepared = []
    for idx, positive in enumerate(positive_embeds):
        merged = torch.cat([positive] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            repeated_negative = torch.cat([negative_embeds[idx]] * num_images_per_prompt, dim=0)
            merged = torch.cat([repeated_negative, merged], dim=0)
        prepared.append(merged.to(device=device))
    return prepared
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Return the optional keyword arguments that `self.scheduler.step` accepts.

    `eta` (the η of the DDIM paper, https://huggingface.co/papers/2010.02502, meant to be in [0, 1]) and `generator`
    are included only if they appear in the scheduler's `step` signature. Schedulers without them ignore them.
    """
    accepted = set(inspect.signature(self.scheduler.step).parameters.keys())
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in accepted}
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate `__call__` arguments, raising `ValueError` on the first problem found.

    Checks, in order:
    - spatial size is a multiple of 8;
    - `callback_steps` is a positive int (if given);
    - callback tensor names are allowed;
    - exactly one of `prompt` / `prompt_embeds` is given, and `prompt` is a str or list;
    - `negative_prompt` and `negative_prompt_embeds` are not both given;
    - prompt/negative embeddings have equal shapes;
    - the IP-Adapter image and embedding inputs are mutually exclusive and well formed.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    steps_invalid = callback_steps is not None and not (isinstance(callback_steps, int) and callback_steps > 0)
    if steps_invalid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None:
        unknown_inputs = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown_inputs:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown_inputs}"
            )

    # Exactly one prompt source, of a supported type.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    both_embeds_given = prompt_embeds is not None and negative_prompt_embeds is not None
    if both_embeds_given and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is None:
        return
    if not isinstance(ip_adapter_image_embeds, list):
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
        )
    if ip_adapter_image_embeds[0].ndim not in [3, 4]:
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
        )
def prepare_latents(
    self, batch_size, num_channels_latents, height, width, dtype, device, generator, scales, latents=None
):
    """
    Create (or move) the initial noise and scale it by `self.scheduler.init_noise_sigma`.

    Args:
        batch_size (`int`): Number of samples; must equal `len(generator)` when a list of generators is passed.
        num_channels_latents (`int`): Channels of the full-resolution sample.
        height (`int`), width (`int`): Spatial size of the full-resolution sample.
        dtype (`torch.dtype`), device (`torch.device`): dtype/device of newly created noise.
        generator: `None`, a single `torch.Generator`, or a list of generators (one per sample).
        scales (`List[int]` or `None`): Nesting scales. When given, the result is a list whose first entry is the
            full-resolution noise. Entry `i >= 1` is independent noise downsampled by `scales[0] // scales[i]`.
        latents: Optional pre-generated noise (a list of tensors when `scales` is given); only moved to `device`.

    Returns:
        A tensor (single scale) or a list of tensors (nested), each multiplied by `init_noise_sigma`.

    Raises:
        ValueError: If a list of generators does not match `batch_size`.
    """
    shape = (
        batch_size,
        num_channels_latents,
        int(height),
        int(width),
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if scales is not None:
            # `Tensor.normal_` accepts only a single generator on the tensor's device type. For a list of
            # generators, or a CPU generator with accelerator latents, use `randn_tensor`, which supports both.
            use_randn_tensor = isinstance(generator, list) or (
                generator is not None and generator.device.type != latents.device.type
            )
            out = [latents]
            for s in scales[1:]:
                ratio = scales[0] // s
                # Same shape as F.avg_pool2d(latents, ratio) (kernel = stride = ratio, floor mode).
                low_shape = (*latents.shape[:2], latents.shape[2] // ratio, latents.shape[3] // ratio)
                if use_randn_tensor:
                    sample_low = randn_tensor(low_shape, generator=generator, device=device, dtype=dtype)
                else:
                    # Same draws as the previous in-place `normal_` on the pooled tensor.
                    sample_low = torch.empty(low_shape, device=latents.device, dtype=latents.dtype).normal_(
                        generator=generator
                    )
                out.append(sample_low)
            latents = out
    elif scales is not None:
        latents = [latent.to(device=device) for latent in latents]
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    if scales is not None:
        return [latent * self.scheduler.init_noise_sigma for latent in latents]
    return latents * self.scheduler.init_noise_sigma
# Adapted from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scales; embedding vectors are generated for each to enrich timestep embeddings.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate (odd sizes are zero-padded by one column).
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Sinusoidal embeddings of shape `(len(w), embedding_dim)`, on `w`'s device.

    Raises:
        ValueError: If `w` is not 1-D.
    """
    if w.ndim != 1:
        raise ValueError(f"`w` must be a 1-D tensor, but got shape {tuple(w.shape)}.")
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Build the frequencies on `w`'s device so CUDA inputs don't hit a CPU/CUDA mismatch below.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    # Internal invariant, not input validation.
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
# Read-only views of the per-call state set in `__call__`. They must be properties: `__call__` reads them as
# values (e.g. `if self.do_classifier_free_guidance:`), and a plain method would always be truthy.
@property
def guidance_scale(self):
    return self._guidance_scale

@property
def guidance_rescale(self):
    return self._guidance_rescale

@property
def clip_skip(self):
    return self._clip_skip

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    # Disabled for guidance-distilled UNets, i.e. those configured with a `time_cond_proj_dim`.
    return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

@property
def cross_attention_kwargs(self):
    return self._cross_attention_kwargs

@property
def num_timesteps(self):
    return self._num_timesteps

@property
def interrupt(self):
    return self._interrupt
| def __call__( | |
| self, | |
| prompt: Union[str, List[str]] = None, | |
| height: Optional[int] = None, | |
| width: Optional[int] = None, | |
| num_inference_steps: int = 50, | |
| timesteps: List[int] = None, | |
| sigmas: List[float] = None, | |
| guidance_scale: float = 7.5, | |
| negative_prompt: Optional[Union[str, List[str]]] = None, | |
| num_images_per_prompt: Optional[int] = 1, | |
| eta: float = 0.0, | |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | |
| latents: Optional[torch.Tensor] = None, | |
| prompt_embeds: Optional[torch.Tensor] = None, | |
| negative_prompt_embeds: Optional[torch.Tensor] = None, | |
| ip_adapter_image: Optional[PipelineImageInput] = None, | |
| ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, | |
| output_type: Optional[str] = "pil", | |
| return_dict: bool = True, | |
| cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
| guidance_rescale: float = 0.0, | |
| clip_skip: Optional[int] = None, | |
| callback_on_step_end: Optional[ | |
| Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] | |
| ] = None, | |
| callback_on_step_end_tensor_inputs: List[str] = ["latents"], | |
| **kwargs, | |
| ): | |
| r""" | |
| The call function to the pipeline for generation. | |
| Args: | |
| prompt (`str` or `List[str]`, *optional*): | |
| The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. | |
| height (`int`, *optional*, defaults to `self.unet.config.sample_size`): | |
| The height in pixels of the generated image. | |
| width (`int`, *optional*, defaults to `self.unet.config.sample_size`): | |
| The width in pixels of the generated image. | |
| num_inference_steps (`int`, *optional*, defaults to 50): | |
| The number of denoising steps. More denoising steps usually lead to a higher quality image at the | |
| expense of slower inference. | |
| timesteps (`List[int]`, *optional*): | |
| Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument | |
| in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is | |
| passed will be used. Must be in descending order. | |
| sigmas (`List[float]`, *optional*): | |
| Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in | |
| their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed | |
| will be used. | |
| guidance_scale (`float`, *optional*, defaults to 7.5): | |
| A higher guidance scale value encourages the model to generate images closely linked to the text | |
| `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. | |
| negative_prompt (`str` or `List[str]`, *optional*): | |
| The prompt or prompts to guide what to not include in image generation. If not defined, you need to | |
| pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). | |
| num_images_per_prompt (`int`, *optional*, defaults to 1): | |
| The number of images to generate per prompt. | |
| eta (`float`, *optional*, defaults to 0.0): | |
| Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies | |
| to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. | |
| generator (`torch.Generator` or `List[torch.Generator]`, *optional*): | |
| A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make | |
| generation deterministic. | |
| latents (`torch.Tensor`, *optional*): | |
| Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image | |
| generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | |
| tensor is generated by sampling using the supplied random `generator`. | |
| prompt_embeds (`torch.Tensor`, *optional*): | |
| Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not | |
| provided, text embeddings are generated from the `prompt` input argument. | |
| negative_prompt_embeds (`torch.Tensor`, *optional*): | |
| Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If | |
| not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. | |
| ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. | |
| ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): | |
| Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of | |
| IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should | |
| contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not | |
| provided, embeddings are computed from the `ip_adapter_image` input argument. | |
| output_type (`str`, *optional*, defaults to `"pil"`): | |
| The output format of the generated image. Choose between `PIL.Image` or `np.array`. | |
| return_dict (`bool`, *optional*, defaults to `True`): | |
| Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | |
| plain tuple. | |
| cross_attention_kwargs (`dict`, *optional*): | |
| A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in | |
| [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). | |
| guidance_rescale (`float`, *optional*, defaults to 0.0): | |
| Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are | |
| Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when | |
| using zero terminal SNR. | |
| clip_skip (`int`, *optional*): | |
| Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that | |
| the output of the pre-final layer will be used for computing the prompt embeddings. | |
| callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): | |
| A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of | |
| each denoising step during the inference. with the following arguments: `callback_on_step_end(self: | |
| DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a | |
| list of all tensors as specified by `callback_on_step_end_tensor_inputs`. | |
| callback_on_step_end_tensor_inputs (`List`, *optional*): | |
| The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list | |
| will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the | |
| `._callback_tensor_inputs` attribute of your pipeline class. | |
| Examples: | |
| Returns: | |
| [`~MatryoshkaPipelineOutput`] or `tuple`: | |
| If `return_dict` is `True`, [`~MatryoshkaPipelineOutput`] is returned, | |
| otherwise a `tuple` is returned where the first element is a list with the generated images and the | |
| second element is a list of `bool`s indicating whether the corresponding generated image contains | |
| "not-safe-for-work" (nsfw) content. | |
| """ | |
| callback = kwargs.pop("callback", None) | |
| callback_steps = kwargs.pop("callback_steps", None) | |
| if callback is not None: | |
| deprecate( | |
| "callback", | |
| "1.0.0", | |
| "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", | |
| ) | |
| if callback_steps is not None: | |
| deprecate( | |
| "callback_steps", | |
| "1.0.0", | |
| "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", | |
| ) | |
| if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): | |
| callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs | |
| # 0. Default height and width to unet | |
| height = height or self.unet.config.sample_size | |
| width = width or self.unet.config.sample_size | |
| # to deal with lora scaling and other possible forward hooks | |
| # 1. Check inputs. Raise error if not correct | |
| self.check_inputs( | |
| prompt, | |
| height, | |
| width, | |
| callback_steps, | |
| negative_prompt, | |
| prompt_embeds, | |
| negative_prompt_embeds, | |
| ip_adapter_image, | |
| ip_adapter_image_embeds, | |
| callback_on_step_end_tensor_inputs, | |
| ) | |
| self._guidance_scale = guidance_scale | |
| self._guidance_rescale = guidance_rescale | |
| self._clip_skip = clip_skip | |
| self._cross_attention_kwargs = cross_attention_kwargs | |
| self._interrupt = False | |
| # 2. Define call parameters | |
| if prompt is not None and isinstance(prompt, str): | |
| batch_size = 1 | |
| elif prompt is not None and isinstance(prompt, list): | |
| batch_size = len(prompt) | |
| else: | |
| batch_size = prompt_embeds.shape[0] | |
| device = self._execution_device | |
| # 3. Encode input prompt | |
| lora_scale = ( | |
| self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None | |
| ) | |
| ( | |
| prompt_embeds, | |
| negative_prompt_embeds, | |
| prompt_attention_mask, | |
| negative_prompt_attention_mask, | |
| ) = self.encode_prompt( | |
| prompt, | |
| device, | |
| num_images_per_prompt, | |
| self.do_classifier_free_guidance, | |
| negative_prompt, | |
| prompt_embeds=prompt_embeds, | |
| negative_prompt_embeds=negative_prompt_embeds, | |
| lora_scale=lora_scale, | |
| clip_skip=self.clip_skip, | |
| ) | |
| # For classifier free guidance, we need to do two forward passes. | |
| # Here we concatenate the unconditional and text embeddings into a single batch | |
| # to avoid doing two forward passes | |
| if self.do_classifier_free_guidance: | |
| prompt_embeds = torch.cat([negative_prompt_embeds.unsqueeze(0), prompt_embeds.unsqueeze(0)]) | |
| attention_masks = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) | |
| else: | |
| attention_masks = prompt_attention_mask | |
| prompt_embeds = prompt_embeds * attention_masks.unsqueeze(-1) | |
| if ip_adapter_image is not None or ip_adapter_image_embeds is not None: | |
| image_embeds = self.prepare_ip_adapter_image_embeds( | |
| ip_adapter_image, | |
| ip_adapter_image_embeds, | |
| device, | |
| batch_size * num_images_per_prompt, | |
| self.do_classifier_free_guidance, | |
| ) | |
| # 4. Prepare timesteps | |
| timesteps, num_inference_steps = retrieve_timesteps( | |
| self.scheduler, num_inference_steps, device, timesteps, sigmas | |
| ) | |
| timesteps = timesteps[:-1] | |
| # 5. Prepare latent variables | |
| num_channels_latents = self.unet.config.in_channels | |
| latents = self.prepare_latents( | |
| batch_size * num_images_per_prompt, | |
| num_channels_latents, | |
| height, | |
| width, | |
| prompt_embeds.dtype, | |
| device, | |
| generator, | |
| self.scheduler.scales, | |
| latents, | |
| ) | |
| # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | |
| extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | |
| extra_step_kwargs |= {"use_clipped_model_output": True} | |
| # 6.1 Add image embeds for IP-Adapter | |
| added_cond_kwargs = ( | |
| {"image_embeds": image_embeds} | |
| if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) | |
| else None | |
| ) | |
| # 6.2 Optionally get Guidance Scale Embedding | |
| timestep_cond = None | |
| if self.unet.config.time_cond_proj_dim is not None: | |
| guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) | |
| timestep_cond = self.get_guidance_scale_embedding( | |
| guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim | |
| ).to(device=device, dtype=latents.dtype) | |
| # 7. Denoising loop | |
| num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order | |
| self._num_timesteps = len(timesteps) | |
| with self.progress_bar(total=num_inference_steps) as progress_bar: | |
| for i, t in enumerate(timesteps): | |
| if self.interrupt: | |
| continue | |
| # expand the latents if we are doing classifier free guidance | |
| if self.do_classifier_free_guidance and isinstance(latents, list): | |
| latent_model_input = [latent.repeat(2, 1, 1, 1) for latent in latents] | |
| elif self.do_classifier_free_guidance: | |
| latent_model_input = latents.repeat(2, 1, 1, 1) | |
| else: | |
| latent_model_input = latents | |
| latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | |
| # predict the noise residual | |
| noise_pred = self.unet( | |
| latent_model_input, | |
| t - 1, | |
| encoder_hidden_states=prompt_embeds, | |
| timestep_cond=timestep_cond, | |
| cross_attention_kwargs=self.cross_attention_kwargs, | |
| added_cond_kwargs=added_cond_kwargs, | |
| encoder_attention_mask=attention_masks, | |
| return_dict=False, | |
| )[0] | |
| # perform guidance | |
| if isinstance(noise_pred, list) and self.do_classifier_free_guidance: | |
| for i, (noise_pred_uncond, noise_pred_text) in enumerate(noise_pred): | |
| noise_pred[i] = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) | |
| elif self.do_classifier_free_guidance: | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) | |
| if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: | |
| # Based on 3.4. in https://huggingface.co/papers/2305.08891 | |
| noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) | |
| # compute the previous noisy sample x_t -> x_t-1 | |
| latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] | |
| if callback_on_step_end is not None: | |
| callback_kwargs = {} | |
| for k in callback_on_step_end_tensor_inputs: | |
| callback_kwargs[k] = locals()[k] | |
| callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) | |
| latents = callback_outputs.pop("latents", latents) | |
| prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) | |
| negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) | |
| # call the callback, if provided | |
| if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | |
| progress_bar.update() | |
| if callback is not None and i % callback_steps == 0: | |
| step_idx = i // getattr(self.scheduler, "order", 1) | |
| callback(step_idx, t, latents) | |
| if XLA_AVAILABLE: | |
| xm.mark_step() | |
| image = latents | |
| if self.scheduler.scales is not None: | |
| for i, img in enumerate(image): | |
| image[i] = self.image_processor.postprocess(img, output_type=output_type)[0] | |
| else: | |
| image = self.image_processor.postprocess(image, output_type=output_type) | |
| # Offload all models | |
| self.maybe_free_model_hooks() | |
| if not return_dict: | |
| return (image,) | |
| return MatryoshkaPipelineOutput(images=image) | |