Spaces:
Runtime error
Runtime error
| import unittest | |
| from diffusers.pipelines.pipeline_utils import is_safetensors_compatible | |
class IsSafetensorsCompatibleTests(unittest.TestCase):
    """Exercise `is_safetensors_compatible` with lists of weight-file paths.

    Each test passes a list of `<component>/<weights file>` strings (optionally
    together with `variant="fp16"`) and asserts on the truthiness of the
    result. The "not compatible" cases are built by leaving out exactly one
    `.safetensors` entry while keeping its `.bin` sibling.
    """

    # ------------------------------------------------------------------
    # Default (no variant) filenames
    # ------------------------------------------------------------------

    def test_all_is_compatible(self):
        """All four components list both a `.bin` and a `.safetensors` file."""
        weight_files = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(weight_files))

    def test_diffusers_model_is_compatible(self):
        """A lone `unet` entry with both weight formats is expected to pass."""
        weight_files = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(weight_files))

    def test_diffusers_model_is_not_compatible(self):
        """The `unet` `.safetensors` file is left out, so the check must fail."""
        weight_files = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            # Removed: 'unet/diffusion_pytorch_model.safetensors',
        ]
        self.assertFalse(is_safetensors_compatible(weight_files))

    def test_transformer_model_is_compatible(self):
        """A lone `text_encoder` entry with both weight formats is expected to pass."""
        weight_files = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(weight_files))

    def test_transformer_model_is_not_compatible(self):
        """The `text_encoder` `.safetensors` file is left out, so the check must fail."""
        weight_files = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            # Removed: 'text_encoder/model.safetensors',
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertFalse(is_safetensors_compatible(weight_files))

    # ------------------------------------------------------------------
    # `variant="fp16"` requested
    # ------------------------------------------------------------------

    def test_all_is_compatible_variant(self):
        """All four components list both formats using `.fp16.` filenames."""
        weight_files = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(weight_files, variant="fp16"))

    def test_diffusers_model_is_compatible_variant(self):
        """A lone `unet` entry with both `.fp16.` weight formats is expected to pass."""
        weight_files = [
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(weight_files, variant="fp16"))

    def test_diffusers_model_is_compatible_variant_partial(self):
        """The variant is requested but the `unet` filenames carry no variant tag."""
        weight_files = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(weight_files, variant="fp16"))

    def test_diffusers_model_is_not_compatible_variant(self):
        """The `unet` `.fp16.safetensors` file is left out, so the check must fail."""
        weight_files = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            # Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
        ]
        self.assertFalse(is_safetensors_compatible(weight_files, variant="fp16"))

    def test_transformer_model_is_compatible_variant(self):
        """A lone `text_encoder` entry with both `.fp16.` weight formats is expected to pass."""
        weight_files = [
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(weight_files, variant="fp16"))

    def test_transformer_model_is_compatible_variant_partial(self):
        """The variant is requested but the `text_encoder` filenames carry no variant tag."""
        weight_files = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(weight_files, variant="fp16"))

    def test_transformer_model_is_not_compatible_variant(self):
        """The `text_encoder` `.fp16.safetensors` file is left out, so the check must fail."""
        weight_files = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            # Removed: 'text_encoder/model.fp16.safetensors',
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        self.assertFalse(is_safetensors_compatible(weight_files, variant="fp16"))