File size: 12,768 Bytes
dde5d93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 |
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import inspect
import random
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)
from diffusers import (
AutoencoderKL,
AutoPipelineForInpainting,
EulerDiscreteScheduler,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPAGInpaintPipeline,
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
load_image,
require_torch_gpu,
slow,
torch_device,
)
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import (
IPAdapterTesterMixin,
PipelineFromPipeTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
SDXLOptionalComponentsTesterMixin,
)
enable_full_determinism()
class StableDiffusionXLPAGInpaintPipelineFastTests(
PipelineTesterMixin,
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
PipelineFromPipeTesterMixin,
SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionXLPAGInpaintPipeline
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS.union({"pag_scale", "pag_adaptive_scale"})
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
image_params = frozenset([])
image_latents_params = frozenset([])
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
{"add_text_embeds", "add_time_ids", "mask", "masked_image_latents"}
)
# based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipelineFastTests.get_dummy_components
def get_dummy_components(
self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False
):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
time_cond_proj_dim=time_cond_proj_dim,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
# SD2-specific config below
attention_head_dim=(2, 4),
use_linear_projection=True,
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=72 if requires_aesthetics_score else 80, # 5 * 8 + 32
cross_attention_dim=64 if not skip_first_text_encoder else 32,
)
scheduler = EulerDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
steps_offset=1,
beta_schedule="scaled_linear",
timestep_spacing="leading",
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=128,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
# SD2-specific config below
hidden_act="gelu",
projection_dim=32,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=32,
image_size=224,
projection_dim=32,
intermediate_size=37,
num_attention_heads=4,
num_channels=3,
num_hidden_layers=5,
patch_size=14,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
feature_extractor = CLIPImageProcessor(
crop_size=224,
do_center_crop=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=224,
)
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder if not skip_first_text_encoder else None,
"tokenizer": tokenizer if not skip_first_text_encoder else None,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": image_encoder,
"feature_extractor": feature_extractor,
"requires_aesthetics_score": requires_aesthetics_score,
}
return components
def get_dummy_inputs(self, device, seed=0):
# TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
# create mask
image[8:, 8:, :] = 255
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": init_image,
"mask_image": mask_image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"strength": 1.0,
"pag_scale": 0.9,
"output_type": "np",
}
return inputs
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(requires_aesthetics_score=True)
# base pipeline
pipe_sd = StableDiffusionXLInpaintPipeline(**components)
pipe_sd = pipe_sd.to(device)
pipe_sd.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
assert (
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["pag_scale"] = 0.0
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
# pag enabled
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
def test_save_load_optional_components(self):
self._test_save_load_optional_components()
def test_pag_inference(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(requires_aesthetics_score=True)
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe_pag(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (
1,
64,
64,
3,
), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}"
expected_slice = np.array([0.8366, 0.5513, 0.6105, 0.6213, 0.6957, 0.7400, 0.6614, 0.6102, 0.5239])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"
@slow
@require_torch_gpu
class StableDiffusionXLPAGInpaintPipelineIntegrationTests(unittest.TestCase):
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0):
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
init_image = load_image(img_url).convert("RGB")
mask_image = load_image(mask_url).convert("RGB")
generator = torch.Generator(device=generator_device).manual_seed(seed)
inputs = {
"prompt": "A majestic tiger sitting on a bench",
"generator": generator,
"image": init_image,
"mask_image": mask_image,
"strength": 0.8,
"num_inference_steps": 3,
"guidance_scale": guidance_scale,
"pag_scale": 3.0,
"output_type": "np",
}
return inputs
def test_pag_cfg(self):
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = pipeline(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 1024, 1024, 3)
expected_slice = np.array(
[0.41385046, 0.39608297, 0.4360491, 0.26872507, 0.32187328, 0.4242474, 0.2603805, 0.34167895, 0.46561807]
)
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
), f"output is different from expected, {image_slice.flatten()}"
def test_pag_uncond(self):
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device, guidance_scale=0.0)
image = pipeline(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 1024, 1024, 3)
expected_slice = np.array(
[0.41597816, 0.39302617, 0.44287828, 0.2687074, 0.28315824, 0.40582314, 0.20877528, 0.2380802, 0.39447647]
)
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
), f"output is different from expected, {image_slice.flatten()}"
|