Spaces: Running on Zero

Commit · a5c3b38
1 Parent(s): 296d048

small fixes pipeline

Browse files: mixture_tiling_sdxl.py (+60 -73)

mixture_tiling_sdxl.py CHANGED
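Before the hunks, a minimal usage sketch of the pipeline this commit touches. It assumes the file is loaded as a `diffusers` community pipeline; the `custom_pipeline` id and the checkpoint are illustrative, while the grid-shaped `prompt` and the tiling arguments follow the `__call__` signature shown in the diff below.

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="mixture_tiling_sdxl",  # assumed community-pipeline id
    torch_dtype=torch.float16,
).to("cuda")

# `prompt` is a grid: a list of rows, each row a list of per-tile prompts.
image = pipe(
    prompt=[["a snowy mountain", "a pine forest", "a frozen lake"]],
    tile_height=1024,
    tile_width=1024,
    tile_row_overlap=128,
    tile_col_overlap=128,
    guidance_scale=5.0,
    num_inference_steps=50,
).images[0]
```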
@@ -12,23 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from enum import Enum
 import inspect
+from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import (
-    CLIPImageProcessor,
     CLIPTextModel,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
-    CLIPVisionModelWithProjection,
 )
 
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import (
     FromSingleFileMixin,
-    IPAdapterMixin,
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
@@ -39,6 +36,8 @@ from diffusers.models.attention_processor import (
     XFormersAttnProcessor,
 )
 from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
 from diffusers.utils import (
     USE_PEFT_BACKEND,
@@ -50,11 +49,10 @@ from diffusers.utils import (
     unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
-
-from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
 
 try:
-    from ligo.segments import segment
+    from ligo.segments import segment
 except ImportError:
     raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline")
 
@@ -87,6 +85,7 @@ EXAMPLE_DOC_STRING = """
 ```
 """
 
+
 def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
     """Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image
 
@@ -151,6 +150,7 @@ def _tile2latent_exclusive_indices(
     # return row_init, row_end, col_init, col_end
     return row_segment[0], row_segment[1], col_segment[0], col_segment[1]
 
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     r"""
@@ -210,7 +210,7 @@ def retrieve_timesteps(
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         second element is the number of inference steps.
     """
-
+
     if timesteps is not None and sigmas is not None:
         raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
     if timesteps is not None:
@@ -245,7 +245,6 @@ class StableDiffusionXLTilingPipeline(
     FromSingleFileMixin,
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
-    IPAdapterMixin,
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion XL.
@@ -258,7 +257,6 @@ class StableDiffusionXLTilingPipeline(
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
-        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -298,10 +296,8 @@ class StableDiffusionXLTilingPipeline(
         "tokenizer_2",
         "text_encoder",
         "text_encoder_2",
-        "image_encoder",
-        "feature_extractor",
     ]
-
+
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -311,8 +307,6 @@ class StableDiffusionXLTilingPipeline(
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
-        image_encoder: CLIPVisionModelWithProjection = None,
-        feature_extractor: CLIPImageProcessor = None,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
     ):
@@ -326,8 +320,6 @@ class StableDiffusionXLTilingPipeline(
             tokenizer_2=tokenizer_2,
             unet=unet,
             scheduler=scheduler,
-            image_encoder=image_encoder,
-            feature_extractor=feature_extractor,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
@@ -351,7 +343,7 @@ class StableDiffusionXLTilingPipeline(
 
         FULL = "full"
         EXCLUSIVE = "exclusive"
-
+
     def encode_prompt(
         self,
         prompt: str,
@@ -589,14 +581,14 @@ class StableDiffusionXLTilingPipeline(
             unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
-
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
-
+
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
@@ -608,15 +600,7 @@ class StableDiffusionXLTilingPipeline(
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
 
-    def check_inputs(
-        self,
-        prompt,
-        height,
-        width,
-        grid_cols,
-        seed_tiles_mode,
-        tiles_mode
-    ):
+    def check_inputs(self, prompt, height, width, grid_cols, seed_tiles_mode, tiles_mode):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
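The hunk above flattens the `check_inputs` signature onto one line without changing the checks that follow: dimensions divisible by 8 and a rectangular prompt grid. For illustration (values made up):

```py
# Passes the rectangularity check: every row has grid_cols == 3 entries.
prompt = [
    ["castle on a hill", "rolling meadow", "dark forest"],
    ["village square", "market stalls", "river bridge"],
]

# Raises "All prompt rows must have the same number of prompt columns":
bad_prompt = [
    ["castle on a hill", "rolling meadow"],
    ["village square", "market stalls", "river bridge"],
]
```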
@@ -625,18 +609,18 @@ class StableDiffusionXLTilingPipeline(
 
         if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt):
             raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}")
-
+
         if not all(len(row) == grid_cols for row in prompt):
             raise ValueError("All prompt rows must have the same number of prompt columns")
-
+
         if not isinstance(seed_tiles_mode, str) and (
             not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)
         ):
             raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
-
+
         if any(mode not in tiles_mode for row in seed_tiles_mode for mode in row):
             raise ValueError(f"Seed tiles mode must be one of {tiles_mode}")
-
+
     def _get_add_time_ids(
         self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
     ):
@@ -678,8 +662,8 @@ class StableDiffusionXLTilingPipeline(
         weights_np = np.outer(y_probs, x_probs)
         weights_torch = torch.tensor(weights_np, device=device)
         weights_torch = weights_torch.to(dtype)
-        return torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))
-
+        return torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))
+
     def upcast_vae(self):
         dtype = self.vae.dtype
         self.vae.to(dtype=torch.float32)
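For reference, the mask returned above is a separable Gaussian: two per-axis bell curves combined with `np.outer`, then tiled across batch and channel dimensions so overlapping tiles blend smoothly. A self-contained sketch of the construction; the variance constant and the `// 8` latent scaling are assumptions based on the surrounding code, not guaranteed to match the pipeline's exact values:

```py
import numpy as np
import torch

def gaussian_tile_weights(tile_width, tile_height, var=0.01):
    """Separable Gaussian blending mask: centre pixels dominate, edges fade."""
    lw, lh = tile_width // 8, tile_height // 8  # tile size in latent space
    mx, my = (lw - 1) / 2, (lh - 1) / 2         # tile midpoints
    x = np.exp(-((np.arange(lw) - mx) ** 2) / (lw * lw) / (2 * var))
    y = np.exp(-((np.arange(lh) - my) ** 2) / (lh * lh) / (2 * var))
    # Outer product turns the two 1D profiles into a 2D weight mask.
    return torch.tensor(np.outer(y, x))
```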
@@ -760,25 +744,25 @@ class StableDiffusionXLTilingPipeline(
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
+        prompt: Union[str, List[str]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        num_inference_steps: int = 50,
+        num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         original_size: Optional[Tuple[int, int]] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         target_size: Optional[Tuple[int, int]] = None,
         negative_original_size: Optional[Tuple[int, int]] = None,
         negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
         negative_target_size: Optional[Tuple[int, int]] = None,
-        clip_skip: Optional[int] = None,
+        clip_skip: Optional[int] = None,
         tile_height: Optional[int] = 1024,
         tile_width: Optional[int] = 1024,
         tile_row_overlap: Optional[int] = 128,
@@ -786,7 +770,7 @@ class StableDiffusionXLTilingPipeline(
         guidance_scale_tiles: Optional[List[List[float]]] = None,
         seed_tiles: Optional[List[List[int]]] = None,
         seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full",
-        seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,
+        seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,
         **kwargs,
     ):
         r"""
@@ -795,7 +779,7 @@ class StableDiffusionXLTilingPipeline(
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead.
+                instead.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
                 Anything below 512 pixels won't work well for
@@ -808,7 +792,7 @@ class StableDiffusionXLTilingPipeline(
                 and checkpoints that are not specifically fine-tuned on low resolutions.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                expense of slower inference.
+                expense of slower inference.
             guidance_scale (`float`, *optional*, defaults to 5.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -818,7 +802,7 @@ class StableDiffusionXLTilingPipeline(
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
+                less than `1`).
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
@@ -826,7 +810,7 @@ class StableDiffusionXLTilingPipeline(
                 [`schedulers.DDIMScheduler`], will be ignored for others.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-                to make generation deterministic.
+                to make generation deterministic.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -836,7 +820,7 @@ class StableDiffusionXLTilingPipeline(
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                 `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -881,10 +865,10 @@ class StableDiffusionXLTilingPipeline(
             seed_tiles_mode (`Union[str, List[List[str]]]`, *optional*, defaults to `"full"`):
                 Mode for seeding tiles, can be `"full"` or `"exclusive"`. If `"full"`, all the latents affected by the tile will be overridden. If `"exclusive"`, only the latents that are exclusively affected by this tile (and no other tiles) will be overridden.
             seed_reroll_regions (`List[Tuple[int, int, int, int, int]]`, *optional*):
-                A list of tuples in the form of `(start_row, end_row, start_column, end_column, seed)` defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over `seed_tiles`.
+                A list of tuples in the form of `(start_row, end_row, start_column, end_column, seed)` defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over `seed_tiles`.
             **kwargs (`Dict[str, Any]`, *optional*):
                 Additional optional keyword arguments to be passed to the `unet.__call__` and `scheduler.step` functions.
-
+
         Examples:
 
         Returns:
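To make the tuple format concrete, a hedged call sketch reusing the hypothetical `pipe` from the first example (all values illustrative):

```py
image = pipe(
    prompt=[["ocean sunset", "sailing boat"]],
    seed_tiles=[[7, 42]],                          # one seed per tile in the grid
    seed_tiles_mode="full",                        # or "exclusive" for latents no other tile touches
    seed_reroll_regions=[(0, 256, 0, 256, 1234)],  # (start_row, end_row, start_column, end_column, seed)
).images[0]
```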
@@ -902,7 +886,7 @@ class StableDiffusionXLTilingPipeline(
 
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
-        self._cross_attention_kwargs = cross_attention_kwargs
+        self._cross_attention_kwargs = cross_attention_kwargs
         self._interrupt = False
 
         grid_rows = len(prompt)
@@ -912,12 +896,12 @@ class StableDiffusionXLTilingPipeline(
 
         if isinstance(seed_tiles_mode, str):
             seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
-
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt,
+            prompt,
             height,
-            width,
+            width,
             grid_cols,
             seed_tiles_mode,
             tiles_mode,
@@ -933,7 +917,7 @@ class StableDiffusionXLTilingPipeline(
         # update height and width tile size and tile overlap size
         height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)
         width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)
-
+
         # 3. Encode input prompt
         lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
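The two formulas above grow the canvas by one tile minus one overlap per extra row or column. Worked through with the defaults on an assumed 2x3 grid:

```py
tile_height, tile_width = 1024, 1024
tile_row_overlap = tile_col_overlap = 128
grid_rows, grid_cols = 2, 3

height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)  # 1024 + 1 * 896 = 1920
width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)     # 1024 + 2 * 896 = 2816
```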
@@ -941,11 +925,11 @@ class StableDiffusionXLTilingPipeline(
         text_embeddings = [
             [
                 self.encode_prompt(
-                    prompt=col,
+                    prompt=col,
                     device=device,
                     num_images_per_prompt=num_images_per_prompt,
                     do_classifier_free_guidance=self.do_classifier_free_guidance,
-                    negative_prompt=negative_prompt,
+                    negative_prompt=negative_prompt,
                     prompt_embeds=None,
                     negative_prompt_embeds=None,
                     pooled_prompt_embeds=None,
@@ -956,10 +940,10 @@ class StableDiffusionXLTilingPipeline(
                 for col in row
             ]
             for row in prompt
-        ]
+        ]
 
         # 3. Prepare latents
-        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
         dtype = text_embeddings[0][0][0].dtype
         latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
 
@@ -1008,7 +992,7 @@ class StableDiffusionXLTilingPipeline(
             extra_set_kwargs["offset"] = 1
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler, num_inference_steps, device, None, None, **extra_set_kwargs
-        )
+        )
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -1023,7 +1007,7 @@ class StableDiffusionXLTilingPipeline(
         for row in range(grid_rows):
             addition_embed_type_row = []
             for col in range(grid_cols):
-                #extract generated values
+                # extract generated values
                 prompt_embeds = text_embeddings[row][col][0]
                 negative_prompt_embeds = text_embeddings[row][col][1]
                 pooled_prompt_embeds = text_embeddings[row][col][2]
@@ -1051,7 +1035,7 @@ class StableDiffusionXLTilingPipeline(
                     )
                 else:
                     negative_add_time_ids = add_time_ids
-
+
                 if self.do_classifier_free_guidance:
                     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
                     add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
@@ -1062,14 +1046,14 @@ class StableDiffusionXLTilingPipeline(
                 add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
                 addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
             embeddings_and_added_time.append(addition_embed_type_row)
-
+
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
         # 7. Mask for tile weights strength
         tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size, device, torch.float32)
 
         # 8. Denoising loop
-        self._num_timesteps = len(timesteps)
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # Diffuse each tile
@@ -1084,16 +1068,21 @@ class StableDiffusionXLTilingPipeline(
                         )
                         tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
                         # expand the latents if we are doing classifier free guidance
-                        latent_model_input = torch.cat([tile_latents] * 2) if self.do_classifier_free_guidance else tile_latents
+                        latent_model_input = (
+                            torch.cat([tile_latents] * 2) if self.do_classifier_free_guidance else tile_latents
+                        )
                         latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                         # predict the noise residual
-                        added_cond_kwargs = {"text_embeds": embeddings_and_added_time[row][col][1], "time_ids": embeddings_and_added_time[row][col][2]}
-                        with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):
+                        added_cond_kwargs = {
+                            "text_embeds": embeddings_and_added_time[row][col][1],
+                            "time_ids": embeddings_and_added_time[row][col][2],
+                        }
+                        with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):
                             noise_pred = self.unet(
                                 latent_model_input,
                                 t,
-                                encoder_hidden_states=embeddings_and_added_time[row][col][0],
+                                encoder_hidden_states=embeddings_and_added_time[row][col][0],
                                 cross_attention_kwargs=self.cross_attention_kwargs,
                                 added_cond_kwargs=added_cond_kwargs,
                                 return_dict=False,
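The hunk above reformats the classifier-free-guidance batching and the SDXL conditioning dict around the UNet call. The CFG pattern in isolation (shapes illustrative, random tensors standing in for real latents and the UNet output):

```py
import torch

do_cfg = True
tile_latents = torch.randn(1, 4, 128, 128)  # (batch, channels, h/8, w/8)

# Batch the unconditional and conditional branches into one forward pass...
latent_model_input = torch.cat([tile_latents] * 2) if do_cfg else tile_latents

# ...then split the prediction and recombine with the guidance weight.
noise_pred = torch.randn_like(latent_model_input)  # stand-in for the UNet output
if do_cfg:
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    guidance = 5.0
    noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
```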
@@ -1110,7 +1099,7 @@ class StableDiffusionXLTilingPipeline(
                             noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
                             noise_preds_row.append(noise_pred_tile)
                     noise_preds.append(noise_preds_row)
-
+
                 # Stitch noise predictions for all tiles
                 noise_pred = torch.zeros(latents.shape, device=device)
                 contributors = torch.zeros(latents.shape, device=device)
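The `noise_pred`/`contributors` buffers initialized above accumulate each tile's prediction weighted by the Gaussian mask and are then normalized per pixel. A minimal sketch of that accumulate-then-normalize pattern (sizes and weights illustrative):

```py
import torch

# Two horizontally overlapping tiles on a small latent canvas.
canvas = torch.zeros(1, 4, 128, 240)
contributors = torch.zeros_like(canvas)
tile_weights = torch.ones(1, 4, 128, 128)  # stand-in for the Gaussian mask

tiles = [((0, 128, 0, 128), torch.randn(1, 4, 128, 128)),
         ((0, 128, 112, 240), torch.randn(1, 4, 128, 128))]

for (r0, r1, c0, c1), tile_pred in tiles:
    canvas[:, :, r0:r1, c0:c1] += tile_pred * tile_weights
    contributors[:, :, r0:r1, c0:c1] += tile_weights

canvas /= contributors.clamp(min=1e-8)  # weighted average where tiles overlap
```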
@@ -1140,7 +1129,7 @@ class StableDiffusionXLTilingPipeline(
 
                 # update progress bar
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
+                    progress_bar.update()
 
                 if XLA_AVAILABLE:
                     xm.mark_step()
@@ -1173,7 +1162,7 @@ class StableDiffusionXLTilingPipeline(
             latents = latents / self.vae.config.scaling_factor
 
             image = self.vae.decode(latents, return_dict=False)[0]
-
+
             # cast back to fp16 if needed
             if needs_upcasting:
                 self.vae.to(dtype=torch.float16)
@@ -1184,7 +1173,7 @@ class StableDiffusionXLTilingPipeline(
             # apply watermark if available
             if self.watermark is not None:
                 image = self.watermark.apply_watermark(image)
-
+
             image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload all models
@@ -1194,5 +1183,3 @@ class StableDiffusionXLTilingPipeline(
             return (image,)
 
         return StableDiffusionXLPipelineOutput(images=image)
-
-
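The tail of `__call__` above keeps the standard diffusers return convention. Both ways of consuming it, reusing the hypothetical `pipe` from the first sketch:

```py
# With return_dict=True (the default) the pipeline returns a
# StableDiffusionXLPipelineOutput whose `images` field holds the results.
out = pipe(prompt=[["a lighthouse at dawn"]])
image = out.images[0]

# With return_dict=False it returns a plain tuple instead.
(images,) = pipe(prompt=[["a lighthouse at dawn"]], return_dict=False)
image = images[0]
```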
|
| 1197 |
-
|
| 1198 |
-
|
|
|
|