import argparse
from typing import Any, Dict

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from diffusers import AutoencoderDC


def remap_qkv_(key: str, state_dict: Dict[str, Any]):
    qkv = state_dict.pop(key)
    q, k, v = torch.chunk(qkv, 3, dim=0)
    parent_module, _, _ = key.rpartition(".qkv.conv.weight")
    state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
    state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
    state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()


def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
    parent_module, _, _ = key.rpartition(".proj.conv.weight")
    state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()


AE_KEYS_RENAME_DICT = {
    # common
    "main.": "",
    "op_list.": "",
    "context_module": "attn",
    "local_module": "conv_out",
    # NOTE: The two entries below work because the scales in the available configs
    # have a tuple length of exactly 1. With more scales there would be more layers,
    # and a loop would be a better way to generate these entries.
    "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
    "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
    "depth_conv.conv": "conv_depth",
    "inverted_conv.conv": "conv_inverted",
    "point_conv.conv": "conv_point",
    "point_conv.norm": "norm",
    "conv.conv.": "conv.",
    "conv1.conv": "conv1",
    "conv2.conv": "conv2",
    "conv2.norm": "norm",
    "proj.norm": "norm_out",
    # encoder
    "encoder.project_in.conv": "encoder.conv_in",
    "encoder.project_out.0.conv": "encoder.conv_out",
    "encoder.stages": "encoder.down_blocks",
    # decoder
    "decoder.project_in.conv": "decoder.conv_in",
    "decoder.project_out.0": "decoder.norm_out",
    "decoder.project_out.2.conv": "decoder.conv_out",
    "decoder.stages": "decoder.up_blocks",
}
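
# Illustration (hypothetical original key): applying the substring replacements
# above in order maps
#     "encoder.stages.2.op_list.1.main.depth_conv.conv.weight"
# to
#     "encoder.down_blocks.2.1.conv_depth.weight"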


AE_F32C32_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F64C128_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F128C512_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_SPECIAL_KEYS_REMAP = {
    "qkv.conv.weight": remap_qkv_,
    "proj.conv.weight": remap_proj_conv_,
}


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Unwrap common checkpoint containers; check the current level each time so
    # that nested wrappers (e.g. {"model": {"state_dict": ...}}) are also handled.
    state_dict = saved_dict
    if "model" in state_dict.keys():
        state_dict = state_dict["model"]
    if "module" in state_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in state_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def convert_ae(config_name: str, dtype: torch.dtype):
    config = get_ae_config(config_name)
    hub_id = f"mit-han-lab/{config_name}"
    ckpt_path = hf_hub_download(hub_id, "model.safetensors")
    original_state_dict = get_state_dict(load_file(ckpt_path))

    ae = AutoencoderDC(**config).to(dtype=dtype)

    # First pass: apply the plain substring renames.
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Second pass: apply the special handlers that split/reshape tensors in place.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    ae.load_state_dict(original_state_dict, strict=True)
    return ae
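
# Sketch (not part of the conversion itself): once saved, the converted model
# should load with the standard diffusers API, e.g.
#
#     from diffusers import AutoencoderDC
#     ae = AutoencoderDC.from_pretrained("path/to/output")  # hypothetical output path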


def get_ae_config(name: str):
    if name in ["dc-ae-f32c32-sana-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": (3, 3, 3, 3, 3, 3),
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }
    elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
        config = {
            "latent_channels": 32,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f32c32-in-1.0":
            config["scaling_factor"] = 0.3189
        elif name == "dc-ae-f32c32-mix-1.0":
            config["scaling_factor"] = 0.4552
    elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
        config = {
            "latent_channels": 128,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f64c128-in-1.0":
            config["scaling_factor"] = 0.2889
        elif name == "dc-ae-f64c128-mix-1.0":
            config["scaling_factor"] = 0.4538
    elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
        config = {
            "latent_channels": 512,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f128c512-in-1.0":
            config["scaling_factor"] = 0.4883
        elif name == "dc-ae-f128c512-mix-1.0":
            config["scaling_factor"] = 0.3620
    else:
        raise ValueError(f"Invalid config name {name!r} provided.")

    return config
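
# Naming note: in "dc-ae-fXcY", fX is the spatial downsampling factor and cY the
# number of latent channels, e.g. dc-ae-f32c32 compresses 32x spatially into
# 32-channel latents.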


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_name",
        type=str,
        default="dc-ae-f32c32-sana-1.0",
        choices=[
            "dc-ae-f32c32-sana-1.0",
            "dc-ae-f32c32-in-1.0",
            "dc-ae-f32c32-mix-1.0",
            "dc-ae-f64c128-in-1.0",
            "dc-ae-f64c128-mix-1.0",
            "dc-ae-f128c512-in-1.0",
            "dc-ae-f128c512-mix-1.0",
        ],
        help="The DCAE checkpoint to convert",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path where the converted model should be saved")
    # Restricting --dtype to the supported values makes a bad input fail at argument
    # parsing instead of with a KeyError in DTYPE_MAPPING below.
    parser.add_argument(
        "--dtype", default="fp32", choices=["fp32", "fp16", "bf16"], help="Torch dtype to save the model in."
    )
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}
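
# Note: a non-None variant changes the saved weight filename, e.g. "fp16" yields
# diffusion_pytorch_model.fp16.safetensors instead of diffusion_pytorch_model.safetensors.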


if __name__ == "__main__":
    args = get_args()

    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]

    ae = convert_ae(args.config_name, dtype)
    ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
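
# Example invocation (assuming this script is saved as convert_dcae_to_diffusers.py):
#
#     python convert_dcae_to_diffusers.py \
#         --config_name dc-ae-f32c32-sana-1.0 \
#         --output_path ./dc-ae-f32c32-sana-1.0-diffusers \
#         --dtype fp32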