# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# Copyright (c) 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import FluxControlNetPipeline
from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.utils import replace_example_docstring, is_torch_xla_available, logging

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly map an image token sequence length to the timestep-shift value `mu`.

    The line passes through `(base_seq_len, base_shift)` and `(max_seq_len, max_shift)`;
    lengths outside that range are extrapolated along the same line.
    """
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom schedule that was given.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class FluxInfuseNetPipeline(FluxControlNetPipeline):
    r"""
    `FluxControlNetPipeline` variant used for InfuseNet.

    Compared with the parent `__call__`, this pipeline:
    - lets the ControlNet branch use its own text embeddings (`controlnet_prompt_embeds`) and its own guidance
      value (`controlnet_guidance_scale`);
    - optionally runs "true" classifier-free guidance with a negative prompt (`true_guidance_scale`);
    - moves `self.controlnet` to CPU while the prompts are encoded and moves `text_encoder_2` back to CPU
      afterwards, to reduce peak device memory.
    """

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 3.5,
        controlnet_guidance_scale: float = 1.0,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        control_image: PipelineImageInput = None,
        control_mode: Optional[Union[int, List[int]]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        # ID-specific parameters
        controlnet_prompt_embeds: Optional[torch.FloatTensor] = None,
        # True CFG parameters
        true_guidance_scale: float = 1.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
                be used instead.
            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 28):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. Ignored when `timesteps` is given.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, a default sigma schedule derived from
                `num_inference_steps` is used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Guidance value passed to the transformer's guidance embedding. Only used when
                `self.transformer.config.guidance_embeds` is set.
            controlnet_guidance_scale (`float`, *optional*, defaults to 1.0):
                Guidance value passed to the ControlNet's guidance embedding. Only used when the ControlNet config
                (the first net of a Multi-ControlNet) has `guidance_embeds` set.
            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
                The percentage of total steps at which the ControlNet starts applying.
            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
                The percentage of total steps at which the ControlNet stops applying.
            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition. It is resized to `height` x `width`, VAE-encoded and packed when the
                ControlNet has no `input_hint_block`. If multiple ControlNets are specified in `init`, images must be
                passed as a list such that each element of the list can be correctly batched for input to a single
                ControlNet.
            control_mode (`int` or `List[int]`, *optional*, defaults to None):
                The control mode when applying ControlNet-Union.
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are passed
                to the transformer. If multiple ControlNets are specified in `init`, you can set the corresponding
                scale as a list; a single float is applied to every ControlNet.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. `"latent"` returns
                the packed latents without VAE decoding.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512):
                Maximum sequence length to use with the `prompt`.
            controlnet_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text-side embeddings for the InfuseNet ControlNet branch (*e.g.* image/identity
                embeddings). If not provided, the encoded `prompt_embeds` are reused.
            true_guidance_scale (`float`, *optional*, defaults to 1.0):
                True CFG scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                An extra unconditional transformer pass is run only when negative inputs are given and this value
                differs from 1.0.
            negative_prompt (`str` or `List[str]`, *optional*):
                The negative prompt or negative prompts to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead.
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The negative prompt or negative prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not
                defined, `negative_prompt` will be used instead.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative text embeddings will be generated from `negative_prompt` input
                argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative pooled text embeddings will be generated from `negative_prompt`
                input argument.

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # Normalise start/end into equally long lists (one entry per ControlNet).
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = (
                mult * [control_guidance_start],
                mult * [control_guidance_end],
            )

        # A scalar conditioning scale is shared by every net of a Multi-ControlNet; otherwise the per-step
        # `zip(controlnet_conditioning_scale, controlnet_keep[i])` below would fail on a float.
        if isinstance(self.controlnet, FluxMultiControlNetModel) and isinstance(
            controlnet_conditioning_scale, (int, float)
        ):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._controlnet_guidance_scale = controlnet_guidance_scale
        self._true_guidance_scale = true_guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # Manual memory juggling: park the ControlNet on CPU and bring T5 onto the device for encoding.
        # NOTE(review): this bypasses any accelerate offload hooks; confirm it is not combined with
        # `enable_model_cpu_offload`.
        self.controlnet.cpu()
        torch.cuda.empty_cache()
        self.text_encoder_2.to(device)

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # Negative branch for true CFG; `negative_text_ids` stays None when no negative inputs are given.
        has_neg_prompt = negative_prompt is not None or (
            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
        )
        negative_text_ids = None
        if has_neg_prompt:
            (
                negative_prompt_embeds,
                negative_pooled_prompt_embeds,
                negative_text_ids,
            ) = self.encode_prompt(
                prompt=negative_prompt,
                prompt_2=negative_prompt_2,
                prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=negative_pooled_prompt_embeds,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                lora_scale=lora_scale,
            )
        # With a scale of exactly 1.0 the CFG combination reduces to the conditional prediction, so the
        # extra unconditional transformer pass is skipped.
        do_true_cfg = (
            has_neg_prompt
            and negative_prompt_embeds is not None
            and negative_pooled_prompt_embeds is not None
            and negative_text_ids is not None
            and true_guidance_scale != 1.0
        )

        # The ControlNet branch falls back to the main prompt embeddings. Encoding again (with embeddings
        # already given) yields `controlnet_text_ids` matching the ControlNet sequence length.
        if controlnet_prompt_embeds is None:
            controlnet_prompt_embeds = prompt_embeds
        (
            controlnet_prompt_embeds,
            pooled_prompt_embeds,
            controlnet_text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=controlnet_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # Text encoding done: park T5 on CPU and bring the ControlNet back onto the device.
        self.text_encoder_2.cpu()
        torch.cuda.empty_cache()
        self.controlnet.to(device)

        # 3. Prepare control image
        num_channels_latents = self.transformer.config.in_channels // 4
        if isinstance(self.controlnet, FluxControlNetModel):
            control_image = self.prepare_image(
                image=control_image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=self.vae.dtype,
            )
            height, width = control_image.shape[-2:]

            # xlab controlnet has a input_hint_block and instantx controlnet does not
            controlnet_blocks_repeat = self.controlnet.input_hint_block is not None
            if self.controlnet.input_hint_block is None:
                # vae encode
                control_image = self.vae.encode(control_image).latent_dist.sample()
                control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

                # pack
                height_control_image, width_control_image = control_image.shape[2:]
                control_image = self._pack_latents(
                    control_image,
                    batch_size * num_images_per_prompt,
                    num_channels_latents,
                    height_control_image,
                    width_control_image,
                )

            # Here we ensure that `control_mode` has the same length as the control_image.
            if control_mode is not None:
                if not isinstance(control_mode, int):
                    raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
                control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
                control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)

        elif isinstance(self.controlnet, FluxMultiControlNetModel):
            control_images = []
            # xlab controlnet has a input_hint_block and instantx controlnet does not
            controlnet_blocks_repeat = self.controlnet.nets[0].input_hint_block is not None
            for control_image_ in control_image:
                control_image_ = self.prepare_image(
                    image=control_image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=self.vae.dtype,
                )
                height, width = control_image_.shape[-2:]

                if self.controlnet.nets[0].input_hint_block is None:
                    # vae encode
                    control_image_ = self.vae.encode(control_image_).latent_dist.sample()
                    control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

                    # pack
                    height_control_image, width_control_image = control_image_.shape[2:]
                    control_image_ = self._pack_latents(
                        control_image_,
                        batch_size * num_images_per_prompt,
                        num_channels_latents,
                        height_control_image,
                        width_control_image,
                    )
                control_images.append(control_image_)

            control_image = control_images

            # Here we ensure that `control_mode` has the same length as the control_image.
            if isinstance(control_mode, list) and len(control_mode) != len(control_image):
                raise ValueError(
                    "For Multi-ControlNet, `control_mode` must be a list of the same "
                    + " length as the number of controlnets (control images) specified"
                )
            if not isinstance(control_mode, list):
                control_mode = [control_mode] * len(control_image)
            # set control mode; `None` is encoded as -1
            control_modes = []
            for cmode in control_mode:
                if cmode is None:
                    cmode = -1
                control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
                control_modes.append(control_mode)
            control_mode = control_modes

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        # `retrieve_timesteps` rejects receiving both `timesteps` and `sigmas`, so the default sigma
        # schedule is only built when the caller did not supply custom timesteps.
        sigmas = None if timesteps is not None else np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 6. Create tensor stating which controlnets to keep
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)

        # Guidance tensors do not depend on the step, so they are built once.
        if isinstance(self.controlnet, FluxMultiControlNetModel):
            use_controlnet_guidance = self.controlnet.nets[0].config.guidance_embeds
        else:
            use_controlnet_guidance = self.controlnet.config.guidance_embeds
        controlnet_guidance = (
            torch.tensor([controlnet_guidance_scale], device=device).expand(latents.shape[0])
            if use_controlnet_guidance
            else None
        )
        transformer_guidance = (
            torch.tensor([guidance_scale], device=device).expand(latents.shape[0])
            if self.transformer.config.guidance_embeds
            else None
        )

        # 7. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                # controlnet
                controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
                    hidden_states=latents,
                    controlnet_cond=control_image,
                    controlnet_mode=control_mode,
                    conditioning_scale=cond_scale,
                    timestep=timestep / 1000,
                    guidance=controlnet_guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    txt_ids=controlnet_text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=transformer_guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    controlnet_block_samples=controlnet_block_samples,
                    controlnet_single_block_samples=controlnet_single_block_samples,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    controlnet_blocks_repeat=controlnet_blocks_repeat,
                )[0]

                # perform true CFG; the unconditional pass gets no ControlNet residuals
                if do_true_cfg:
                    noise_pred_uncond = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        guidance=transformer_guidance,
                        pooled_projections=negative_pooled_prompt_embeds,
                        encoder_hidden_states=negative_prompt_embeds,
                        controlnet_block_samples=None,
                        controlnet_single_block_samples=None,
                        txt_ids=negative_text_ids,
                        img_ids=latent_image_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                        controlnet_blocks_repeat=controlnet_blocks_repeat,
                    )[0]
                    noise_pred = noise_pred_uncond + true_guidance_scale * (noise_pred - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # call the callback, if provided; it may replace `latents` / `prompt_embeds`
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)