Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import inspect | |
| import math | |
| from typing import Any, Callable, Dict, List, Optional, Union | |
| import numpy as np | |
| import torch | |
| from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor | |
| from diffusers.image_processor import PipelineImageInput, VaeImageProcessor | |
| from diffusers.loaders import QwenImageLoraLoaderMixin | |
| from diffusers.models import AutoencoderKLQwenImage, QwenImageTransformer2DModel | |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler | |
| from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring | |
| from diffusers.utils.torch_utils import randn_tensor | |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline | |
| from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput | |
# Optional dependency probe: `XLA_AVAILABLE` records whether torch_xla could be imported so that
# XLA-specific code paths elsewhere in the module can be gated on it.
XLA_AVAILABLE = False
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Usage example for the pipeline's call docstring. The module imports `replace_example_docstring`,
# which presumably splices this text into `__call__`'s "Examples:" section — confirm the decorator
# is applied on `__call__`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from PIL import Image
        >>> from diffusers import QwenImageEditPipeline
        >>> from diffusers.utils import load_image
        >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")
        >>> image = load_image(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
        ... ).convert("RGB")
        >>> prompt = (
        ...     "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
        ... )
        >>> # Depending on the variant being used, the pipeline call will slightly vary.
        >>> # Refer to the pipeline documentation for more details.
        >>> image = pipe(image, prompt, num_inference_steps=50).images[0]
        >>> image.save("qwenimage_edit.png")
        ```
"""
# Adapted from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map an image token count to the scheduler shift value ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``.
    Sequence lengths outside that range are extrapolated, not clamped.

    Args:
        image_seq_len: Number of (packed) image tokens.
        base_seq_len: Sequence length that maps to ``base_shift``.
        max_seq_len: Sequence length that maps to ``max_shift``.
        base_shift: Shift at ``base_seq_len``.
        max_shift: Shift at ``max_seq_len``.

    Returns:
        The interpolated shift ``mu``.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return slope * image_seq_len + intercept
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` via its `set_timesteps` method and return the resulting schedule. Custom
    `timesteps` or `sigmas` are supported when the scheduler's `set_timesteps` accepts them; any extra
    kwargs are forwarded to `scheduler.set_timesteps`.
    Args:
        scheduler (`SchedulerMixin`):
            Scheduler whose `set_timesteps` is called and whose `timesteps` attribute is returned.
        num_inference_steps (`int`):
            Number of diffusion steps when using the scheduler's default spacing. Ignored when custom
            `timesteps` or `sigmas` are given.
        device (`str` or `torch.device`, *optional*):
            Device forwarded to `set_timesteps`; `None` leaves placement to the scheduler.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing. Mutually exclusive with `timesteps`.
    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's timestep schedule and the number of inference steps
        (the schedule length when a custom schedule was given).
    Raises:
        ValueError: if both `timesteps` and `sigmas` are passed, or the scheduler's `set_timesteps`
            has no parameter for the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Default spacing: the scheduler derives the schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Exactly one custom schedule was supplied; it is passed by keyword, so the scheduler must declare it.
    if timesteps is not None:
        schedule_kwarg, schedule_values, schedule_label = "timesteps", timesteps, "timestep"
    else:
        schedule_kwarg, schedule_values, schedule_label = "sigmas", sigmas, "sigmas"

    if schedule_kwarg not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {schedule_label} schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(**{schedule_kwarg: schedule_values}, device=device, **kwargs)
    custom_schedule = scheduler.timesteps
    return custom_schedule, len(custom_schedule)
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull a latent tensor out of a VAE encoder output.

    Args:
        encoder_output: Object exposing either ``latent_dist`` (a distribution) or ``latents``.
        generator: RNG forwarded to ``latent_dist.sample`` when sampling.
        sample_mode: ``"sample"`` draws from ``latent_dist``; ``"argmax"`` returns ``latent_dist.mode()``.
            Any other mode, or an output without ``latent_dist``, falls back to ``latents``.

    Returns:
        The selected latent tensor.

    Raises:
        AttributeError: If no usable latent attribute is found.
    """
    has_distribution = hasattr(encoder_output, "latent_dist")
    if has_distribution and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_distribution and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
def calculate_dimensions(target_area, ratio):
    """Choose a ``(width, height)`` with aspect ``ratio`` whose area is close to ``target_area``.

    Each side is rounded to the nearest multiple of 32.

    Args:
        target_area: Desired pixel count (``width * height``).
        ratio: Desired ``width / height``.

    Returns:
        ``(width, height, None)``; the trailing ``None`` keeps the three-value unpacking callers use.
    """
    exact_width = math.sqrt(target_area * ratio)
    exact_height = exact_width / ratio
    return round(exact_width / 32) * 32, round(exact_height / 32) * 32, None
| class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): | |
| r""" | |
| The Qwen-Image-Edit pipeline for image editing. | |
| Args: | |
| transformer ([`QwenImageTransformer2DModel`]): | |
| Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. | |
| scheduler ([`FlowMatchEulerDiscreteScheduler`]): | |
| A scheduler to be used in combination with `transformer` to denoise the encoded image latents. | |
        vae ([`AutoencoderKLQwenImage`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
            [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), which encodes each edit
            instruction together with the input image.
        tokenizer ([`Qwen2Tokenizer`]):
            Tokenizer of class
            [Qwen2Tokenizer](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2Tokenizer).
        processor ([`Qwen2VLProcessor`]):
            Processor that prepares the templated prompt and the input image for `text_encoder`.
| """ | |
| model_cpu_offload_seq = "text_encoder->transformer->vae" | |
| _callback_tensor_inputs = ["latents", "prompt_embeds"] | |
| def __init__( | |
| self, | |
| scheduler: FlowMatchEulerDiscreteScheduler, | |
| vae: AutoencoderKLQwenImage, | |
| text_encoder: Qwen2_5_VLForConditionalGeneration, | |
| tokenizer: Qwen2Tokenizer, | |
| processor: Qwen2VLProcessor, | |
| transformer: QwenImageTransformer2DModel, | |
| ): | |
| super().__init__() | |
| self.register_modules( | |
| vae=vae, | |
| text_encoder=text_encoder, | |
| tokenizer=tokenizer, | |
| processor=processor, | |
| transformer=transformer, | |
| scheduler=scheduler, | |
| ) | |
| self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 | |
| self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 | |
| # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible | |
| # by the patch size. So the vae scale factor is multiplied by the patch size to account for this | |
| self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) | |
| self.vl_processor = processor | |
| self.tokenizer_max_length = 1024 | |
| self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" | |
| self.prompt_template_encode_start_idx = 64 | |
| self.default_sample_size = 128 | |
| # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden | |
| def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): | |
| bool_mask = mask.bool() | |
| valid_lengths = bool_mask.sum(dim=1) | |
| selected = hidden_states[bool_mask] | |
| split_result = torch.split(selected, valid_lengths.tolist(), dim=0) | |
| return split_result | |
| def _get_qwen_prompt_embeds( | |
| self, | |
| prompt: Union[str, List[str]] = None, | |
| image: Optional[torch.Tensor] = None, | |
| device: Optional[torch.device] = None, | |
| dtype: Optional[torch.dtype] = None, | |
| ): | |
| device = device or self._execution_device | |
| dtype = dtype or self.text_encoder.dtype | |
| prompt = [prompt] if isinstance(prompt, str) else prompt | |
| template = self.prompt_template_encode | |
| drop_idx = self.prompt_template_encode_start_idx | |
| txt = [template.format(e) for e in prompt] | |
| model_inputs = self.processor( | |
| text=txt, | |
| images=image, | |
| padding=True, | |
| return_tensors="pt", | |
| ).to(device) | |
| outputs = self.text_encoder( | |
| input_ids=model_inputs.input_ids, | |
| attention_mask=model_inputs.attention_mask, | |
| pixel_values=model_inputs.pixel_values, | |
| image_grid_thw=model_inputs.image_grid_thw, | |
| output_hidden_states=True, | |
| ) | |
| hidden_states = outputs.hidden_states[-1] | |
| split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) | |
| split_hidden_states = [e[drop_idx:] for e in split_hidden_states] | |
| attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] | |
| max_seq_len = max([e.size(0) for e in split_hidden_states]) | |
| prompt_embeds = torch.stack( | |
| [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] | |
| ) | |
| encoder_attention_mask = torch.stack( | |
| [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] | |
| ) | |
| prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) | |
| return prompt_embeds, encoder_attention_mask | |
| def encode_prompt( | |
| self, | |
| prompt: Union[str, List[str]], | |
| image: Optional[torch.Tensor] = None, | |
| device: Optional[torch.device] = None, | |
| num_images_per_prompt: int = 1, | |
| prompt_embeds: Optional[torch.Tensor] = None, | |
| prompt_embeds_mask: Optional[torch.Tensor] = None, | |
| max_sequence_length: int = 1024, | |
| ): | |
| r""" | |
| Args: | |
| prompt (`str` or `List[str]`, *optional*): | |
| prompt to be encoded | |
| image (`torch.Tensor`, *optional*): | |
| image to be encoded | |
| device: (`torch.device`): | |
| torch device | |
| num_images_per_prompt (`int`): | |
| number of images that should be generated per prompt | |
| prompt_embeds (`torch.Tensor`, *optional*): | |
| Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not | |
| provided, text embeddings will be generated from `prompt` input argument. | |
| """ | |
| device = device or self._execution_device | |
| prompt = [prompt] if isinstance(prompt, str) else prompt | |
| batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] | |
| if prompt_embeds is None: | |
| prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) | |
| _, seq_len, _ = prompt_embeds.shape | |
| prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) | |
| prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) | |
| prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) | |
| prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) | |
| return prompt_embeds, prompt_embeds_mask | |
| def check_inputs( | |
| self, | |
| prompt, | |
| height, | |
| width, | |
| negative_prompt=None, | |
| prompt_embeds=None, | |
| negative_prompt_embeds=None, | |
| prompt_embeds_mask=None, | |
| negative_prompt_embeds_mask=None, | |
| callback_on_step_end_tensor_inputs=None, | |
| max_sequence_length=None, | |
| ): | |
| if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: | |
| logger.warning( | |
| f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" | |
| ) | |
| if callback_on_step_end_tensor_inputs is not None and not all( | |
| k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs | |
| ): | |
| raise ValueError( | |
| f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" | |
| ) | |
| if prompt is not None and prompt_embeds is not None: | |
| raise ValueError( | |
| f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" | |
| " only forward one of the two." | |
| ) | |
| elif prompt is None and prompt_embeds is None: | |
| raise ValueError( | |
| "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." | |
| ) | |
| elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): | |
| raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | |
| if negative_prompt is not None and negative_prompt_embeds is not None: | |
| raise ValueError( | |
| f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" | |
| f" {negative_prompt_embeds}. Please make sure to only forward one of the two." | |
| ) | |
| if prompt_embeds is not None and prompt_embeds_mask is None: | |
| raise ValueError( | |
| "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." | |
| ) | |
| if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: | |
| raise ValueError( | |
| "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." | |
| ) | |
| if max_sequence_length is not None and max_sequence_length > 1024: | |
| raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") | |
| # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents | |
| def _pack_latents(latents, batch_size, num_channels_latents, height, width): | |
| latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) | |
| latents = latents.permute(0, 2, 4, 1, 3, 5) | |
| latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) | |
| return latents | |
| # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents | |
| def _unpack_latents(latents, height, width, vae_scale_factor): | |
| batch_size, num_patches, channels = latents.shape | |
| # VAE applies 8x compression on images but we must also account for packing which requires | |
| # latent height and width to be divisible by 2. | |
| height = 2 * (int(height) // (vae_scale_factor * 2)) | |
| width = 2 * (int(width) // (vae_scale_factor * 2)) | |
| latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) | |
| latents = latents.permute(0, 3, 1, 4, 2, 5) | |
| latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) | |
| return latents | |
| def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): | |
| if isinstance(generator, list): | |
| image_latents = [ | |
| retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") | |
| for i in range(image.shape[0]) | |
| ] | |
| image_latents = torch.cat(image_latents, dim=0) | |
| else: | |
| image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") | |
| latents_mean = ( | |
| torch.tensor(self.vae.config.latents_mean) | |
| .view(1, self.latent_channels, 1, 1, 1) | |
| .to(image_latents.device, image_latents.dtype) | |
| ) | |
| latents_std = ( | |
| torch.tensor(self.vae.config.latents_std) | |
| .view(1, self.latent_channels, 1, 1, 1) | |
| .to(image_latents.device, image_latents.dtype) | |
| ) | |
| image_latents = (image_latents - latents_mean) / latents_std | |
| return image_latents | |
| def enable_vae_slicing(self): | |
| r""" | |
| Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to | |
| compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. | |
| """ | |
| self.vae.enable_slicing() | |
| def disable_vae_slicing(self): | |
| r""" | |
| Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to | |
| computing decoding in one step. | |
| """ | |
| self.vae.disable_slicing() | |
| def enable_vae_tiling(self): | |
| r""" | |
| Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to | |
| compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow | |
| processing larger images. | |
| """ | |
| self.vae.enable_tiling() | |
| def disable_vae_tiling(self): | |
| r""" | |
| Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to | |
| computing decoding in one step. | |
| """ | |
| self.vae.disable_tiling() | |
| def prepare_latents( | |
| self, | |
| image, | |
| batch_size, | |
| num_channels_latents, | |
| height, | |
| width, | |
| dtype, | |
| device, | |
| generator, | |
| latents=None, | |
| ): | |
| # VAE applies 8x compression on images but we must also account for packing which requires | |
| # latent height and width to be divisible by 2. | |
| height = 2 * (int(height) // (self.vae_scale_factor * 2)) | |
| width = 2 * (int(width) // (self.vae_scale_factor * 2)) | |
| shape = (batch_size, 1, num_channels_latents, height, width) | |
| image_latents = None | |
| if image is not None: | |
| image = image.to(device=device, dtype=dtype) | |
| if image.shape[1] != self.latent_channels: | |
| image_latents = self._encode_vae_image(image=image, generator=generator) | |
| else: | |
| image_latents = image | |
| if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: | |
| # expand init_latents for batch_size | |
| additional_image_per_prompt = batch_size // image_latents.shape[0] | |
| image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) | |
| elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: | |
| raise ValueError( | |
| f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." | |
| ) | |
| else: | |
| image_latents = torch.cat([image_latents], dim=0) | |
| image_latent_height, image_latent_width = image_latents.shape[3:] | |
| image_latents = self._pack_latents( | |
| image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width | |
| ) | |
| if isinstance(generator, list) and len(generator) != batch_size: | |
| raise ValueError( | |
| f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | |
| f" size of {batch_size}. Make sure the batch size matches the length of the generators." | |
| ) | |
| if latents is None: | |
| latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) | |
| latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) | |
| else: | |
| latents = latents.to(device=device, dtype=dtype) | |
| return latents, image_latents | |
| def guidance_scale(self): | |
| return self._guidance_scale | |
| def attention_kwargs(self): | |
| return self._attention_kwargs | |
| def num_timesteps(self): | |
| return self._num_timesteps | |
| def current_timestep(self): | |
| return self._current_timestep | |
| def interrupt(self): | |
| return self._interrupt | |
| def __call__( | |
| self, | |
| image: Optional[PipelineImageInput] = None, | |
| prompt: Union[str, List[str]] = None, | |
| negative_prompt: Union[str, List[str]] = None, | |
| true_cfg_scale: float = 4.0, | |
| height: Optional[int] = None, | |
| width: Optional[int] = None, | |
| num_inference_steps: int = 50, | |
| sigmas: Optional[List[float]] = None, | |
| guidance_scale: float = 1.0, | |
| num_images_per_prompt: int = 1, | |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | |
| latents: Optional[torch.Tensor] = None, | |
| prompt_embeds: Optional[torch.Tensor] = None, | |
| prompt_embeds_mask: Optional[torch.Tensor] = None, | |
| negative_prompt_embeds: Optional[torch.Tensor] = None, | |
| negative_prompt_embeds_mask: Optional[torch.Tensor] = None, | |
| output_type: Optional[str] = "pil", | |
| return_dict: bool = True, | |
| attention_kwargs: Optional[Dict[str, Any]] = None, | |
| callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, | |
| callback_on_step_end_tensor_inputs: List[str] = ["latents"], | |
| max_sequence_length: int = 512, | |
| ): | |
| r""" | |
| Function invoked when calling the pipeline for generation. | |
| Args: | |
| prompt (`str` or `List[str]`, *optional*): | |
| The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. | |
| instead. | |
| negative_prompt (`str` or `List[str]`, *optional*): | |
| The prompt or prompts not to guide the image generation. If not defined, one has to pass | |
| `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is | |
| not greater than `1`). | |
| true_cfg_scale (`float`, *optional*, defaults to 1.0): | |
| When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. | |
| height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): | |
| The height in pixels of the generated image. This is set to 1024 by default for the best results. | |
| width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): | |
| The width in pixels of the generated image. This is set to 1024 by default for the best results. | |
| num_inference_steps (`int`, *optional*, defaults to 50): | |
| The number of denoising steps. More denoising steps usually lead to a higher quality image at the | |
| expense of slower inference. | |
| sigmas (`List[float]`, *optional*): | |
| Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in | |
| their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed | |
| will be used. | |
| guidance_scale (`float`, *optional*, defaults to 3.5): | |
| Guidance scale as defined in [Classifier-Free Diffusion | |
| Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. | |
| of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting | |
| `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to | |
| the text `prompt`, usually at the expense of lower image quality. | |
| This parameter in the pipeline is there to support future guidance-distilled models when they come up. | |
| Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance, | |
| please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should | |
| enable classifier-free guidance computations. | |
| num_images_per_prompt (`int`, *optional*, defaults to 1): | |
| The number of images to generate per prompt. | |
| generator (`torch.Generator` or `List[torch.Generator]`, *optional*): | |
| One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) | |
| to make generation deterministic. | |
| latents (`torch.Tensor`, *optional*): | |
| Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image | |
| generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | |
| tensor will be generated by sampling using the supplied random `generator`. | |
| prompt_embeds (`torch.Tensor`, *optional*): | |
| Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not | |
| provided, text embeddings will be generated from `prompt` input argument. | |
| negative_prompt_embeds (`torch.Tensor`, *optional*): | |
| Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt | |
| weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input | |
| argument. | |
| output_type (`str`, *optional*, defaults to `"pil"`): | |
| The output format of the generate image. Choose between | |
| [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. | |
| return_dict (`bool`, *optional*, defaults to `True`): | |
| Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. | |
| attention_kwargs (`dict`, *optional*): | |
| A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under | |
| `self.processor` in | |
| [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). | |
| callback_on_step_end (`Callable`, *optional*): | |
| A function that calls at the end of each denoising steps during the inference. The function is called | |
| with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, | |
| callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by | |
| `callback_on_step_end_tensor_inputs`. | |
| callback_on_step_end_tensor_inputs (`List`, *optional*): | |
| The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list | |
| will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the | |
| `._callback_tensor_inputs` attribute of your pipeline class. | |
| max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. | |
| Examples: | |
| Returns: | |
| [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: | |
| [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When | |
| returning a tuple, the first element is a list with the generated images. | |
| """ | |
| image_size = image[0].size if isinstance(image, list) else image.size | |
| calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) | |
| height = height or calculated_height | |
| width = width or calculated_width | |
| multiple_of = self.vae_scale_factor * 2 | |
| width = width // multiple_of * multiple_of | |
| height = height // multiple_of * multiple_of | |
| # 1. Check inputs. Raise error if not correct | |
| self.check_inputs( | |
| prompt, | |
| height, | |
| width, | |
| negative_prompt=negative_prompt, | |
| prompt_embeds=prompt_embeds, | |
| negative_prompt_embeds=negative_prompt_embeds, | |
| prompt_embeds_mask=prompt_embeds_mask, | |
| negative_prompt_embeds_mask=negative_prompt_embeds_mask, | |
| callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, | |
| max_sequence_length=max_sequence_length, | |
| ) | |
| self._guidance_scale = guidance_scale | |
| self._attention_kwargs = attention_kwargs | |
| self._current_timestep = None | |
| self._interrupt = False | |
| # 2. Define call parameters | |
| if prompt is not None and isinstance(prompt, str): | |
| batch_size = 1 | |
| elif prompt is not None and isinstance(prompt, list): | |
| batch_size = len(prompt) | |
| else: | |
| batch_size = prompt_embeds.shape[0] | |
| device = self._execution_device | |
| # 3. Preprocess image | |
| if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): | |
| image = self.image_processor.resize(image, calculated_height, calculated_width) | |
| prompt_image = image | |
| image = self.image_processor.preprocess(image, calculated_height, calculated_width) | |
| image = image.unsqueeze(2) | |
| has_neg_prompt = negative_prompt is not None or ( | |
| negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None | |
| ) | |
| do_true_cfg = true_cfg_scale > 1 and has_neg_prompt | |
| prompt_embeds, prompt_embeds_mask = self.encode_prompt( | |
| image=prompt_image, | |
| prompt=prompt, | |
| prompt_embeds=prompt_embeds, | |
| prompt_embeds_mask=prompt_embeds_mask, | |
| device=device, | |
| num_images_per_prompt=num_images_per_prompt, | |
| max_sequence_length=max_sequence_length, | |
| ) | |
| if do_true_cfg: | |
| negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( | |
| image=prompt_image, | |
| prompt=negative_prompt, | |
| prompt_embeds=negative_prompt_embeds, | |
| prompt_embeds_mask=negative_prompt_embeds_mask, | |
| device=device, | |
| num_images_per_prompt=num_images_per_prompt, | |
| max_sequence_length=max_sequence_length, | |
| ) | |
| # 4. Prepare latent variables | |
| num_channels_latents = self.transformer.config.in_channels // 4 | |
| latents, image_latents = self.prepare_latents( | |
| image, | |
| batch_size * num_images_per_prompt, | |
| num_channels_latents, | |
| height, | |
| width, | |
| prompt_embeds.dtype, | |
| device, | |
| generator, | |
| latents, | |
| ) | |
| img_shapes = [ | |
| [ | |
| (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), | |
| (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2), | |
| ] | |
| ] * batch_size | |
| # 5. Prepare timesteps | |
| sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas | |
| image_seq_len = latents.shape[1] | |
| mu = calculate_shift( | |
| image_seq_len, | |
| self.scheduler.config.get("base_image_seq_len", 256), | |
| self.scheduler.config.get("max_image_seq_len", 4096), | |
| self.scheduler.config.get("base_shift", 0.5), | |
| self.scheduler.config.get("max_shift", 1.15), | |
| ) | |
| timesteps, num_inference_steps = retrieve_timesteps( | |
| self.scheduler, | |
| num_inference_steps, | |
| device, | |
| sigmas=sigmas, | |
| mu=mu, | |
| ) | |
| num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) | |
| self._num_timesteps = len(timesteps) | |
| # handle guidance | |
| if self.transformer.config.guidance_embeds: | |
| guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) | |
| guidance = guidance.expand(latents.shape[0]) | |
| else: | |
| guidance = None | |
| if self.attention_kwargs is None: | |
| self._attention_kwargs = {} | |
| txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None | |
| negative_txt_seq_lens = ( | |
| negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None | |
| ) | |
| image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=latents.device) | |
| # 6. Denoising loop | |
| self.scheduler.set_begin_index(0) | |
| with self.progress_bar(total=num_inference_steps) as progress_bar: | |
| for i, t in enumerate(timesteps): | |
| if self.interrupt: | |
| continue | |
| self._current_timestep = t | |
| latent_model_input = latents | |
| if image_latents is not None: | |
| latent_model_input = torch.cat([latents, image_latents], dim=1) | |
| # broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
| timestep = t.expand(latents.shape[0]).to(latents.dtype) | |
| with self.transformer.cache_context("cond"): | |
| noise_pred = self.transformer( | |
| hidden_states=latent_model_input, | |
| timestep=timestep / 1000, | |
| guidance=guidance, | |
| encoder_hidden_states_mask=prompt_embeds_mask, | |
| encoder_hidden_states=prompt_embeds, | |
| image_rotary_emb=image_rotary_emb, | |
| attention_kwargs=self.attention_kwargs, | |
| return_dict=False, | |
| )[0] | |
| noise_pred = noise_pred[:, : latents.size(1)] | |
| if do_true_cfg: | |
| with self.transformer.cache_context("uncond"): | |
| neg_noise_pred = self.transformer( | |
| hidden_states=latent_model_input, | |
| timestep=timestep / 1000, | |
| guidance=guidance, | |
| encoder_hidden_states_mask=negative_prompt_embeds_mask, | |
| encoder_hidden_states=negative_prompt_embeds, | |
| image_rotary_emb=image_rotary_emb, | |
| attention_kwargs=self.attention_kwargs, | |
| return_dict=False, | |
| )[0] | |
| neg_noise_pred = neg_noise_pred[:, : latents.size(1)] | |
| comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) | |
| cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) | |
| noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) | |
| noise_pred = comb_pred * (cond_norm / noise_norm) | |
| # compute the previous noisy sample x_t -> x_t-1 | |
| latents_dtype = latents.dtype | |
| latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] | |
| if latents.dtype != latents_dtype: | |
| if torch.backends.mps.is_available(): | |
| # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 | |
| latents = latents.to(latents_dtype) | |
| if callback_on_step_end is not None: | |
| callback_kwargs = {} | |
| for k in callback_on_step_end_tensor_inputs: | |
| callback_kwargs[k] = locals()[k] | |
| callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) | |
| latents = callback_outputs.pop("latents", latents) | |
| prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) | |
| # call the callback, if provided | |
| if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | |
| progress_bar.update() | |
| if XLA_AVAILABLE: | |
| xm.mark_step() | |
| self._current_timestep = None | |
| if output_type == "latent": | |
| image = latents | |
| else: | |
| latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) | |
| latents = latents.to(self.vae.dtype) | |
| latents_mean = ( | |
| torch.tensor(self.vae.config.latents_mean) | |
| .view(1, self.vae.config.z_dim, 1, 1, 1) | |
| .to(latents.device, latents.dtype) | |
| ) | |
| latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( | |
| latents.device, latents.dtype | |
| ) | |
| latents = latents / latents_std + latents_mean | |
| image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] | |
| image = self.image_processor.postprocess(image, output_type=output_type) | |
| # Offload all models | |
| self.maybe_free_model_hooks() | |
| if not return_dict: | |
| return (image,) | |
| return QwenImagePipelineOutput(images=image) | |