import contextlib
import io
import re
import unittest
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AnimateDiffPipeline,
AnimateDiffVideoToVideoPipeline,
AutoencoderKL,
DDIMScheduler,
MotionAdapter,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
from diffusers.utils.testing_utils import require_torch_accelerator, torch_device
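# `is_safetensors_compatible` should report True only when every relevant model
# component in the file list has a safetensors counterpart for its PyTorch weights.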
class IsSafetensorsCompatibleTests(unittest.TestCase):
def test_all_is_compatible(self):
filenames = [
"safety_checker/pytorch_model.bin",
"safety_checker/model.safetensors",
"vae/diffusion_pytorch_model.bin",
"vae/diffusion_pytorch_model.safetensors",
"text_encoder/pytorch_model.bin",
"text_encoder/model.safetensors",
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
def test_diffusers_model_is_compatible(self):
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
def test_diffusers_model_is_not_compatible(self):
filenames = [
"safety_checker/pytorch_model.bin",
"safety_checker/model.safetensors",
"vae/diffusion_pytorch_model.bin",
"vae/diffusion_pytorch_model.safetensors",
"text_encoder/pytorch_model.bin",
"text_encoder/model.safetensors",
"unet/diffusion_pytorch_model.bin",
# Removed: 'unet/diffusion_pytorch_model.safetensors',
]
self.assertFalse(is_safetensors_compatible(filenames))
def test_transformer_model_is_compatible(self):
filenames = [
"text_encoder/pytorch_model.bin",
"text_encoder/model.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
def test_transformer_model_is_not_compatible(self):
filenames = [
"safety_checker/pytorch_model.bin",
"safety_checker/model.safetensors",
"vae/diffusion_pytorch_model.bin",
"vae/diffusion_pytorch_model.safetensors",
"text_encoder/pytorch_model.bin",
# Removed: 'text_encoder/model.safetensors',
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames))
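# When a `variant` is passed, only weights carrying that suffix (e.g. `*.fp16.safetensors`)
# satisfy the check; a variant-only repo is not compatible unless the variant is requested.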
def test_all_is_compatible_variant(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_model_is_compatible_variant(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_model_is_compatible_variant_mixed(self):
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_model_is_not_compatible_variant(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
"unet/diffusion_pytorch_model.fp16.bin",
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
]
self.assertFalse(is_safetensors_compatible(filenames))
def test_transformer_model_is_compatible_variant(self):
filenames = [
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_transformer_model_is_not_compatible_variant(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames))
def test_transformer_model_is_compatible_variant_extra_folder(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}, variant="fp16"))
def test_transformer_model_is_not_compatible_variant_extra_folder(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames, folder_names={"text_encoder"}))
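# Sharded safetensors checkpoints (`*-00001-of-00002.safetensors`) also count as valid
# counterparts to a single `.bin` file.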
def test_transformers_is_compatible_sharded(self):
filenames = [
"text_encoder/pytorch_model.bin",
"text_encoder/model-00001-of-00002.safetensors",
"text_encoder/model-00002-of-00002.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
def test_transformers_is_compatible_variant_sharded(self):
filenames = [
"text_encoder/pytorch_model.bin",
"text_encoder/model.fp16-00001-of-00002.safetensors",
"text_encoder/model.fp16-00001-of-00002.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_is_compatible_sharded(self):
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model-00001-of-00002.safetensors",
"unet/diffusion_pytorch_model-00002-of-00002.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
def test_diffusers_is_compatible_variant_sharded(self):
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_is_compatible_only_variants(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_is_compatible_no_components(self):
filenames = [
"diffusion_pytorch_model.bin",
]
self.assertFalse(is_safetensors_compatible(filenames))
def test_diffusers_is_compatible_no_components_only_variants(self):
filenames = [
"diffusion_pytorch_model.fp16.bin",
]
self.assertFalse(is_safetensors_compatible(filenames))
def test_is_compatible_mixed_variants(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors",
"vae/diffusion_pytorch_model.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_is_compatible_variant_and_non_safetensors(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors",
"vae/diffusion_pytorch_model.bin",
]
self.assertFalse(is_safetensors_compatible(filenames, variant="fp16"))
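# `variant_compatible_siblings` resolves which repo files should be downloaded for a given
# variant, falling back to non-variant weights for components that have no variant files,
# and also returns the subset of those filenames that are variant files.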
class VariantCompatibleSiblingsTest(unittest.TestCase):
def test_only_non_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.{variant}.safetensors",
"vae/diffusion_pytorch_model.safetensors",
f"text_encoder/model.{variant}.safetensors",
"text_encoder/model.safetensors",
f"unet/diffusion_pytorch_model.{variant}.safetensors",
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert all(variant not in f for f in model_filenames)
def test_only_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.{variant}.safetensors",
"vae/diffusion_pytorch_model.safetensors",
f"text_encoder/model.{variant}.safetensors",
"text_encoder/model.safetensors",
f"unet/diffusion_pytorch_model.{variant}.safetensors",
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
non_variant_file = "text_encoder/model.safetensors"
filenames = [
f"vae/diffusion_pytorch_model.{variant}.safetensors",
"vae/diffusion_pytorch_model.safetensors",
"text_encoder/model.safetensors",
f"unet/diffusion_pytorch_model.{variant}.safetensors",
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
def test_non_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
"model.safetensors",
f"model.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert all(variant not in f for f in model_filenames)
def test_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
"model.safetensors",
f"model.{variant}.safetensors",
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
non_variant_file = "model.safetensors"
filenames = [
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
"model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
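# Sharded checkpoints ship a `*.safetensors.index.json` file; the variant index uses the
# `.index.{variant}.json` naming, and resolution should pick the shards and index that
# match the requested variant.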
def test_sharded_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
"diffusion_pytorch_model.safetensors.index.json",
"diffusion_pytorch_model-00001-of-00003.safetensors",
"diffusion_pytorch_model-00002-of-00003.safetensors",
"diffusion_pytorch_model-00003-of-00003.safetensors",
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
"diffusion_pytorch_model.safetensors.index.json",
"diffusion_pytorch_model-00001-of-00003.safetensors",
"diffusion_pytorch_model-00002-of-00003.safetensors",
"diffusion_pytorch_model-00003-of-00003.safetensors",
f"diffusion_pytorch_model.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
"diffusion_pytorch_model.safetensors.index.json",
"diffusion_pytorch_model-00001-of-00003.safetensors",
"diffusion_pytorch_model-00002-of-00003.safetensors",
"diffusion_pytorch_model-00003-of-00003.safetensors",
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert all(variant not in f for f in model_filenames)
def test_sharded_non_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert all(variant not in f for f in model_filenames)
def test_sharded_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
assert model_filenames == variant_filenames
def test_single_variant_with_sharded_non_variant_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"unet/diffusion_pytorch_model.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
allowed_non_variant = "unet"
filenames = [
"vae/diffusion_pytorch_model.safetensors.index.json",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model.{variant}.safetensors",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
def test_sharded_mixed_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
allowed_non_variant = "unet"
filenames = [
f"vae/diffusion_pytorch_model.safetensors.index.{variant}.json",
"vae/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
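# Requesting a variant when the repo contains no variant files at all should raise instead
# of silently downloading the non-variant weights.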
def test_downloading_when_no_variant_exists(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"]
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
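# `ignore_patterns=["*.safetensors"]` mirrors `use_safetensors=False`: only `.bin` weights
# should be selected.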
def test_downloading_use_safetensors_false(self):
ignore_patterns = ["*.safetensors"]
filenames = [
"text_encoder/model.bin",
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert all(".safetensors" not in f for f in model_filenames)
def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
allowed_non_variant = "diffusion_pytorch_model.safetensors"
filenames = [
f"unet/diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
def test_download_variants_when_component_has_no_safetensors_variant(self):
ignore_patterns = None
variant = "fp16"
filenames = [
f"unet/diffusion_pytorch_model.{variant}.bin",
"vae/diffusion_pytorch_model.safetensors",
f"vae/diffusion_pytorch_model.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert {
f"unet/diffusion_pytorch_model.{variant}.bin",
f"vae/diffusion_pytorch_model.{variant}.safetensors",
} == model_filenames
def test_error_when_download_sharded_variants_when_component_has_no_safetensors_variant(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
"vae/diffusion_pytorch_model.safetensors.index.json",
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
]
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self):
ignore_patterns = ["*.safetensors"]
allowed_non_variant = "unet"
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
"vae/diffusion_pytorch_model.safetensors.index.json",
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
def test_download_sharded_legacy_variants(self):
ignore_patterns = None
variant = "fp16"
filenames = [
f"vae/transformer/diffusion_pytorch_model.safetensors.{variant}.index.json",
"vae/diffusion_pytorch_model.safetensors.index.json",
f"vae/diffusion_pytorch_model-00002-of-00002.{variant}.safetensors",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model-00001-of-00002.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
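# ONNX and Flax checkpoints should be returned untouched when PyTorch/safetensors weights
# are excluded via `ignore_patterns`.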
def test_download_onnx_models(self):
ignore_patterns = ["*.safetensors"]
filenames = [
"vae/model.onnx",
"unet/model.onnx",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert model_filenames == set(filenames)
def test_download_flax_models(self):
ignore_patterns = ["*.safetensors", "*.bin"]
filenames = [
"vae/diffusion_flax_model.msgpack",
"unet/diffusion_flax_model.msgpack",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert model_filenames == set(filenames)
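# These tests run tiny pipelines end to end and capture stderr (where tqdm writes) to verify
# that the progress bar reaches its final step and that `set_progress_bar_config(disable=True)`
# silences it.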
class ProgressBarTests(unittest.TestCase):
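# Minimal Stable Diffusion components (tiny UNet, VAE and CLIP text encoder) so the
# pipelines below run quickly during tests.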
def get_dummy_components_image_generation(self):
cross_attention_dim = 8
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=1,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=cross_attention_dim,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[4, 8],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=cross_attention_dim,
intermediate_size=16,
layer_norm_eps=1e-05,
num_attention_heads=2,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
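# AnimateDiff variant of the dummy components: adds a MotionAdapter whose
# block_out_channels match the UNet's.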
def get_dummy_components_video_generation(self):
cross_attention_dim = 8
block_out_channels = (8, 8)
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=block_out_channels,
layers_per_block=2,
sample_size=8,
in_channels=4,
out_channels=4,
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=cross_attention_dim,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="linear",
clip_sample=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=block_out_channels,
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=cross_attention_dim,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
torch.manual_seed(0)
motion_adapter = MotionAdapter(
block_out_channels=block_out_channels,
motion_layers_per_block=2,
motion_norm_num_groups=2,
motion_num_attention_heads=4,
)
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"motion_adapter": motion_adapter,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"feature_extractor": None,
"image_encoder": None,
}
return components
def test_text_to_image(self):
components = self.get_dummy_components_image_generation()
pipe = StableDiffusionPipeline(**components)
pipe.to(torch_device)
inputs = {"prompt": "a cute cat", "num_inference_steps": 2}
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# the total number of progress steps can't be known beforehand (e.g. it depends on `strength` for img2img),
# so extract it from tqdm's stderr output, i.e. the "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
def test_image_to_image(self):
components = self.get_dummy_components_image_generation()
pipe = StableDiffusionImg2ImgPipeline(**components)
pipe.to(torch_device)
image = Image.new("RGB", (32, 32))
inputs = {"prompt": "a cute cat", "num_inference_steps": 2, "strength": 0.5, "image": image}
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# the total number of progress steps can't be known beforehand (e.g. it depends on `strength` for img2img),
# so extract it from tqdm's stderr output, i.e. the "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
def test_inpainting(self):
components = self.get_dummy_components_image_generation()
pipe = StableDiffusionInpaintPipeline(**components)
pipe.to(torch_device)
image = Image.new("RGB", (32, 32))
mask = Image.new("RGB", (32, 32))
inputs = {
"prompt": "a cute cat",
"num_inference_steps": 2,
"strength": 0.5,
"image": image,
"mask_image": mask,
}
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# the total number of progress steps can't be known beforehand (e.g. it depends on `strength` for img2img),
# so extract it from tqdm's stderr output, i.e. the "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
def test_text_to_video(self):
components = self.get_dummy_components_video_generation()
pipe = AnimateDiffPipeline(**components)
pipe.to(torch_device)
inputs = {"prompt": "a cute cat", "num_inference_steps": 2, "num_frames": 2}
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# the total number of progress steps can't be known beforehand (e.g. it depends on `strength` for img2img),
# so extract it from tqdm's stderr output, i.e. the "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
def test_video_to_video(self):
components = self.get_dummy_components_video_generation()
pipe = AnimateDiffVideoToVideoPipeline(**components)
pipe.to(torch_device)
num_frames = 2
video = [Image.new("RGB", (32, 32))] * num_frames
inputs = {"prompt": "a cute cat", "num_inference_steps": 2, "video": video}
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# the total number of progress steps can't be known beforehand (e.g. it depends on `strength` for img2img),
# so extract it from tqdm's stderr output, i.e. the "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
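# With components deliberately moved to different devices and dtypes, `pipe.device` and
# `pipe.dtype` must still resolve to a single deterministic value (checked against the
# class-level expectations below).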
@require_torch_accelerator
class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
expected_pipe_device = torch.device(f"{torch_device}:0")
expected_pipe_dtype = torch.float64
def get_dummy_components_image_generation(self):
cross_attention_dim = 8
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=1,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=cross_attention_dim,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[4, 8],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=cross_attention_dim,
intermediate_size=16,
layer_norm_eps=1e-05,
num_attention_heads=2,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
def test_deterministic_device(self):
components = self.get_dummy_components_image_generation()
pipe = StableDiffusionPipeline(**components)
pipe.to(device=torch_device, dtype=torch.float32)
pipe.unet.to(device="cpu")
pipe.vae.to(device=torch_device)
pipe.text_encoder.to(device=f"{torch_device}:0")
pipe_device = pipe.device
self.assertEqual(
self.expected_pipe_device,
pipe_device,
f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.",
)
def test_deterministic_dtype(self):
components = self.get_dummy_components_image_generation()
pipe = StableDiffusionPipeline(**components)
pipe.to(device=torch_device, dtype=torch.float32)
pipe.unet.to(dtype=torch.float16)
pipe.vae.to(dtype=torch.float32)
pipe.text_encoder.to(dtype=torch.float64)
pipe_dtype = pipe.dtype
self.assertEqual(
self.expected_pipe_dtype,
pipe_dtype,
f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.",
)