"""Convert original HunyuanVideo checkpoints (transformer, VAE, and the Llama/CLIP text encoders)
into the diffusers HunyuanVideoPipeline format."""

import argparse
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKLHunyuanVideo,
    FlowMatchEulerDiscreteScheduler,
    HunyuanVideoPipeline,
    HunyuanVideoTransformer3DModel,
)


def remap_norm_scale_shift_(key, state_dict):
    # The original final layer stores (shift, scale); diffusers expects (scale, shift).
    weight = state_dict.pop(key)
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight


def remap_txt_in_(key, state_dict):
    def rename_key(key):
        new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
        new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
        new_key = new_key.replace("txt_in", "context_embedder")
        new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
        new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
        new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
        new_key = new_key.replace("mlp", "ff")
        return new_key

    if "self_attn_qkv" in key:
        # Split the token refiner's fused QKV projection into separate Q, K, V tensors.
        weight = state_dict.pop(key)
        to_q, to_k, to_v = weight.chunk(3, dim=0)
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
    else:
        state_dict[rename_key(key)] = state_dict.pop(key)


def remap_img_attn_qkv_(key, state_dict):
    weight = state_dict.pop(key)
    to_q, to_k, to_v = weight.chunk(3, dim=0)
    state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
    state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
    state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v


def remap_txt_attn_qkv_(key, state_dict):
    weight = state_dict.pop(key)
    to_q, to_k, to_v = weight.chunk(3, dim=0)
    state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
    state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
    state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v


def remap_single_transformer_blocks_(key, state_dict):
    hidden_size = 3072

    if "linear1.weight" in key:
        # linear1 fuses Q, K, V and the MLP input projection; split it into the four diffusers modules.
        linear1_weight = state_dict.pop(key)
        split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
        q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
        new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
        state_dict[f"{new_key}.attn.to_q.weight"] = q
        state_dict[f"{new_key}.attn.to_k.weight"] = k
        state_dict[f"{new_key}.attn.to_v.weight"] = v
        state_dict[f"{new_key}.proj_mlp.weight"] = mlp
    elif "linear1.bias" in key:
        linear1_bias = state_dict.pop(key)
        split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
        new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
        state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
        state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
        state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
        state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
    else:
        new_key = key.replace("single_blocks", "single_transformer_blocks")
        new_key = new_key.replace("linear2", "proj_out")
        new_key = new_key.replace("q_norm", "attn.norm_q")
        new_key = new_key.replace("k_norm", "attn.norm_k")
        state_dict[new_key] = state_dict.pop(key)

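
# Keys that convert with a plain substring rename are listed in the dict below; keys that need
# tensor surgery (fused QKV splits, scale/shift reordering, the token-refiner submodule) are
# routed to the handler functions above via TRANSFORMER_SPECIAL_KEYS_REMAP.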
"x_embedder", "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", "double_blocks": "transformer_blocks", "img_attn_q_norm": "attn.norm_q", "img_attn_k_norm": "attn.norm_k", "img_attn_proj": "attn.to_out.0", "txt_attn_q_norm": "attn.norm_added_q", "txt_attn_k_norm": "attn.norm_added_k", "txt_attn_proj": "attn.to_add_out", "img_mod.linear": "norm1.linear", "img_norm1": "norm1.norm", "img_norm2": "norm2", "img_mlp": "ff", "txt_mod.linear": "norm1_context.linear", "txt_norm1": "norm1.norm", "txt_norm2": "norm2_context", "txt_mlp": "ff_context", "self_attn_proj": "attn.to_out.0", "modulation.linear": "norm.linear", "pre_norm": "norm.norm", "final_layer.norm_final": "norm_out.norm", "final_layer.linear": "proj_out", "fc1": "net.0.proj", "fc2": "net.2", "input_embedder": "proj_in", } TRANSFORMER_SPECIAL_KEYS_REMAP = { "txt_in": remap_txt_in_, "img_attn_qkv": remap_img_attn_qkv_, "txt_attn_qkv": remap_txt_attn_qkv_, "single_blocks": remap_single_transformer_blocks_, "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, } VAE_KEYS_RENAME_DICT = {} VAE_SPECIAL_KEYS_REMAP = {} def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: state_dict[new_key] = state_dict.pop(old_key) def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: state_dict = saved_dict if "model" in saved_dict.keys(): state_dict = state_dict["model"] if "module" in saved_dict.keys(): state_dict = state_dict["module"] if "state_dict" in saved_dict.keys(): state_dict = state_dict["state_dict"] return state_dict def convert_transformer(ckpt_path: str): original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) with init_empty_weights(): transformer = HunyuanVideoTransformer3DModel() for key in list(original_state_dict.keys()): new_key = key[:] for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) update_state_dict_(original_state_dict, key, new_key) for key in list(original_state_dict.keys()): for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): if special_key not in key: continue handler_fn_inplace(key, original_state_dict) transformer.load_state_dict(original_state_dict, strict=True, assign=True) return transformer def convert_vae(ckpt_path: str): original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) with init_empty_weights(): vae = AutoencoderKLHunyuanVideo() for key in list(original_state_dict.keys()): new_key = key[:] for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) update_state_dict_(original_state_dict, key, new_key) for key in list(original_state_dict.keys()): for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): if special_key not in key: continue handler_fn_inplace(key, original_state_dict) vae.load_state_dict(original_state_dict, strict=True, assign=True) return vae def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" ) parser.add_argument("--vae_ckpt_path", type=str, 
default=None, help="Path to original VAE checkpoint") parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint") parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer") parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint") parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") return parser.parse_args() DTYPE_MAPPING = { "fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16, } if __name__ == "__main__": args = get_args() transformer = None dtype = DTYPE_MAPPING[args.dtype] if args.save_pipeline: assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None assert args.text_encoder_path is not None assert args.tokenizer_path is not None assert args.text_encoder_2_path is not None if args.transformer_ckpt_path is not None: transformer = convert_transformer(args.transformer_ckpt_path) transformer = transformer.to(dtype=dtype) if not args.save_pipeline: transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") if args.vae_ckpt_path is not None: vae = convert_vae(args.vae_ckpt_path) if not args.save_pipeline: vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") if args.save_pipeline: text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16) tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right") text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16) tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path) scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) pipe = HunyuanVideoPipeline( transformer=transformer, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, scheduler=scheduler, ) pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")