Upload hy3dgen/texgen/hunyuanpaint/pipeline.py with huggingface_hub

hy3dgen/texgen/hunyuanpaint/pipeline.py (new file, 554 lines)
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

from typing import Any, Callable, Dict, List, Optional, Union

import numpy
import numpy as np
import torch
import torch.distributed
import torch.utils.checkpoint
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DiffusionPipeline,
    ImagePipelineOutput
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline, retrieve_timesteps, \
    rescale_noise_cfg
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import deprecate
from einops import rearrange
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from .unet.modules import UNet2p5DConditionModel


def to_rgb_image(maybe_rgba: Image.Image):
    """Convert an RGBA PIL image to RGB by compositing it over a mid-gray background."""
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    elif maybe_rgba.mode == 'RGBA':
        rgba = maybe_rgba
        # Fill the background with a constant gray (value 127) before pasting the RGBA image.
        img = numpy.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
        img = Image.fromarray(img, 'RGB')
        img.paste(rgba, mask=rgba.getchannel('A'))
        return img
    else:
        raise ValueError("Unsupported image type.", maybe_rgba.mode)

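# Illustrative sketch (added for documentation, not part of the original file):
# `to_rgb_image` is typically applied to the user-supplied reference image before
# it is fed to the VAE, e.g.:
#
#     ref = Image.open("reference.png")   # hypothetical path
#     ref = to_rgb_image(ref)             # RGBA inputs are composited onto gray
#     assert ref.mode == 'RGB'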

class HunyuanPaintPipeline(StableDiffusionPipeline):

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2p5DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        feature_extractor: CLIPImageProcessor,
        safety_checker=None,
        use_torch_compile=False,
    ):
        DiffusionPipeline.__init__(self)

        # The safety checker is intentionally disabled for this pipeline.
        safety_checker = None
        self.register_modules(
            vae=torch.compile(vae) if use_torch_compile else vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=torch.compile(feature_extractor) if use_torch_compile else feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

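    # Note (added comment, not in the original file): the constructor mirrors
    # `StableDiffusionPipeline.__init__` but always drops the safety checker and
    # optionally wraps the VAE (and feature extractor) in `torch.compile`. The UNet
    # is the custom `UNet2p5DConditionModel`, which consumes the extra multiview
    # conditions (reference latents, normal/position maps, camera info) passed via
    # `cached_condition` in `__call__`.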
    @torch.no_grad()
    def encode_images(self, images):
        """Encode a (B, N, C, H, W) batch of view images in [0, 1] into VAE latents."""
        B = images.shape[0]
        images = rearrange(images, 'b n c h w -> (b n) c h w')

        dtype = next(self.vae.parameters()).dtype
        images = (images - 0.5) * 2.0  # rescale from [0, 1] to [-1, 1]
        posterior = self.vae.encode(images.to(dtype)).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor

        latents = rearrange(latents, '(b n) c h w -> b n c h w', b=B)
        return latents

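    # Illustrative sketch (added for documentation, not part of the original file),
    # assuming a standard Stable Diffusion VAE with 8x spatial downsampling and
    # 4 latent channels:
    #
    #     views = torch.rand(1, 6, 3, 512, 512, device=pipe.device)  # hypothetical input
    #     lat = pipe.encode_images(views)
    #     # lat.shape == (1, 6, 4, 64, 64)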
    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image = None,
        prompt=None,
        negative_prompt='watermark, ugly, deformed, noisy, blurry, low contrast',
        *args,
        num_images_per_prompt: Optional[int] = 1,
        guidance_scale=2.0,
        output_type: Optional[str] = "pil",
        width=512,
        height=512,
        num_inference_steps=28,
        return_dict=True,
        **cached_condition,
    ):
        if image is None:
            raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
        assert not isinstance(image, torch.Tensor)

        image = to_rgb_image(image)

        # Shape after reshaping: (1, 1, 3, H, W), i.e. a single reference view for a batch of one.
        image_vae = torch.tensor(np.array(image) / 255.0)
        image_vae = image_vae.unsqueeze(0).permute(0, 3, 1, 2).unsqueeze(0)
        image_vae = image_vae.to(device=self.vae.device, dtype=self.vae.dtype)

        batch_size = image_vae.shape[0]
        assert batch_size == 1
        assert num_images_per_prompt == 1

        ref_latents = self.encode_images(image_vae)

        def convert_pil_list_to_tensor(images):
            bg_c = [1., 1., 1.]
            images_tensor = []
            for batch_imgs in images:
                view_imgs = []
                for pil_img in batch_imgs:
                    img = numpy.asarray(pil_img, dtype=numpy.float32) / 255.
                    if img.shape[2] > 3:
                        alpha = img[:, :, 3:]
                        img = img[:, :, :3] * alpha + bg_c * (1 - alpha)
                    img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).contiguous().half().to("cuda")
                    view_imgs.append(img)
                view_imgs = torch.cat(view_imgs, dim=0)
                images_tensor.append(view_imgs.unsqueeze(0))

            images_tensor = torch.cat(images_tensor, dim=0)
            return images_tensor

        if "normal_imgs" in cached_condition:
            if isinstance(cached_condition["normal_imgs"], List):
                cached_condition["normal_imgs"] = convert_pil_list_to_tensor(cached_condition["normal_imgs"])

            cached_condition['normal_imgs'] = self.encode_images(cached_condition["normal_imgs"])

        if "position_imgs" in cached_condition:
            if isinstance(cached_condition["position_imgs"], List):
                cached_condition["position_imgs"] = convert_pil_list_to_tensor(cached_condition["position_imgs"])

            cached_condition["position_imgs"] = self.encode_images(cached_condition["position_imgs"])

        if 'camera_info_gen' in cached_condition:
            camera_info = cached_condition['camera_info_gen']  # B, N
            if isinstance(camera_info, List):
                camera_info = torch.tensor(camera_info)
            camera_info = camera_info.to(image_vae.device).to(torch.int64)
            cached_condition['camera_info_gen'] = camera_info
        if 'camera_info_ref' in cached_condition:
            camera_info = cached_condition['camera_info_ref']  # B, N
            if isinstance(camera_info, List):
                camera_info = torch.tensor(camera_info)
            camera_info = camera_info.to(image_vae.device).to(torch.int64)
            cached_condition['camera_info_ref'] = camera_info

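        # Note (added for clarity, not in the original file): at this point
        # `cached_condition` may hold, per batch item, the multiview conditioning
        # inputs consumed by UNet2p5DConditionModel:
        #   - "normal_imgs" / "position_imgs": already encoded to VAE latents above,
        #   - "camera_info_gen" / "camera_info_ref": integer camera indices,
        #   - plus "num_in_batch" (number of views), forwarded to `denoise` via kwargs.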
        cached_condition['ref_latents'] = ref_latents

        if guidance_scale > 1:
            # Classifier-free guidance: duplicate every condition so the first half of the
            # batch is the unconditional branch and the second half is the conditional one.
            negative_ref_latents = torch.zeros_like(cached_condition['ref_latents'])
            cached_condition['ref_latents'] = torch.cat([negative_ref_latents, cached_condition['ref_latents']])
            cached_condition['ref_scale'] = torch.as_tensor([0.0, 1.0]).to(cached_condition['ref_latents'])
            if "normal_imgs" in cached_condition:
                cached_condition['normal_imgs'] = torch.cat(
                    (cached_condition['normal_imgs'], cached_condition['normal_imgs']))

            if "position_imgs" in cached_condition:
                cached_condition['position_imgs'] = torch.cat(
                    (cached_condition['position_imgs'], cached_condition['position_imgs']))

            if 'position_maps' in cached_condition:
                cached_condition['position_maps'] = torch.cat(
                    (cached_condition['position_maps'], cached_condition['position_maps']))

            if 'camera_info_gen' in cached_condition:
                cached_condition['camera_info_gen'] = torch.cat(
                    (cached_condition['camera_info_gen'], cached_condition['camera_info_gen']))
            if 'camera_info_ref' in cached_condition:
                cached_condition['camera_info_ref'] = torch.cat(
                    (cached_condition['camera_info_ref'], cached_condition['camera_info_ref']))

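        # Note (added for clarity, not in the original file): text conditioning is not
        # derived from `prompt` here; the UNet carries a learned text embedding
        # (`learned_text_clip_gen`) that replaces CLIP text encoding, and its CFG
        # "negative" counterpart is simply an all-zeros embedding.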
        prompt_embeds = self.unet.learned_text_clip_gen.repeat(num_images_per_prompt, 1, 1)
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)

        latents: torch.Tensor = self.denoise(
            None,
            *args,
            cross_attention_kwargs=None,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            num_inference_steps=num_inference_steps,
            output_type='latent',
            width=width,
            height=height,
            **cached_condition
        ).images

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

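    # Note (added for clarity, not in the original file): `__call__` above only prepares
    # the reference/multiview conditions and the learned text embedding, then delegates
    # the actual sampling to `denoise`, which is a lightly adapted copy of
    # `StableDiffusionPipeline.__call__` that keeps the (batch, num_views) grouping of
    # latents intact between the scheduler and the UNet.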
    def denoise(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during inference with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        # to deal with lora scaling and other possible forward hooks

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )
        assert num_images_per_prompt == 1
        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * kwargs['num_in_batch'],  # num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.1 Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
            else None
        )

        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latents = rearrange(latents, '(b n) c h w -> b n c h w', n=kwargs['num_in_batch'])
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = rearrange(latent_model_input, 'b n c h w -> (b n) c h w')
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = rearrange(latent_model_input, '(b n) c h w -> b n c h w', n=kwargs['num_in_batch'])

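                # Note (added for clarity, not in the original file): the latents stay
                # flattened as (b*n, c, h, w) for the scheduler, but are regrouped to
                # (b, n, c, h, w) before each UNet call so UNet2p5DConditionModel can
                # attend across the n views of the same object.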
                # predict the noise residual

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False, **kwargs
                )[0]
                latents = rearrange(latents, 'b n c h w -> (b n) c h w')
                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents[:, :num_channels_latents, :, :], **extra_step_kwargs, return_dict=False
                )[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False, generator=generator
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
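
# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for documentation, not part of the original
# file). Paths and condition shapes below are assumptions; in this project the
# pipeline is normally constructed and fed by the higher-level texture-generation
# wrapper, which builds `cached_condition` from rendered views of the mesh.
#
#     pipe = HunyuanPaintPipeline.from_pretrained("path/to/hunyuan-paint")  # hypothetical path
#     pipe = pipe.to("cuda", torch.float16)
#
#     out = pipe(
#         image=Image.open("reference.png"),        # single RGB(A) reference view
#         normal_imgs=[[n0, n1, n2, n3, n4, n5]],   # per-view normal maps (PIL)
#         position_imgs=[[p0, p1, p2, p3, p4, p5]], # per-view position maps (PIL)
#         camera_info_gen=[[0, 1, 2, 3, 4, 5]],     # integer camera indices
#         camera_info_ref=[[0]],
#         num_in_batch=6,                           # number of views denoised jointly
#         guidance_scale=2.0,
#         num_inference_steps=28,
#     )
#     textures = out.images                         # list of PIL images, one per view
# ---------------------------------------------------------------------------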