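# Conversion script: downloads original Wan 2.1 checkpoints from the Hugging Face Hub, remaps the
# transformer and VAE state dict keys to the diffusers layout, and assembles a complete pipeline
# (text encoder, tokenizer, scheduler, and, for I2V/FLF2V, a CLIP image encoder) for save_pretrained().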
import argparse
import pathlib
from typing import Any, Dict, Tuple

import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel

from diffusers import (
    AutoencoderKLWan,
    UniPCMultistepScheduler,
    WanImageToVideoPipeline,
    WanPipeline,
    WanTransformer3DModel,
    WanVACEPipeline,
    WanVACETransformer3DModel,
)

TRANSFORMER_KEYS_RENAME_DICT = {
    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
    "time_projection.1": "condition_embedder.time_proj",
    "head.modulation": "scale_shift_table",
    "head.head": "proj_out",
    "modulation": "scale_shift_table",
    "ffn.0": "ffn.net.0.proj",
    "ffn.2": "ffn.net.2",
    # Hack to swap the layer names
    # The original model calls the norms in the following order: norm1, norm3, norm2
    # We convert it to: norm1, norm2, norm3
    "norm2": "norm__placeholder",
    "norm3": "norm2",
    "norm__placeholder": "norm3",
    # For the I2V model
    "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
    "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
    "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
    "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
    # For the FLF2V model
    "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
    # Attention component mappings
    "self_attn.q": "attn1.to_q",
    "self_attn.k": "attn1.to_k",
    "self_attn.v": "attn1.to_v",
    "self_attn.o": "attn1.to_out.0",
    "self_attn.norm_q": "attn1.norm_q",
    "self_attn.norm_k": "attn1.norm_k",
    "cross_attn.q": "attn2.to_q",
    "cross_attn.k": "attn2.to_k",
    "cross_attn.v": "attn2.to_v",
    "cross_attn.o": "attn2.to_out.0",
    "cross_attn.norm_q": "attn2.norm_q",
    "cross_attn.norm_k": "attn2.norm_k",
    "attn2.to_k_img": "attn2.add_k_proj",
    "attn2.to_v_img": "attn2.add_v_proj",
    "attn2.norm_k_img": "attn2.norm_added_k",
}

VACE_TRANSFORMER_KEYS_RENAME_DICT = {
    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
    "time_projection.1": "condition_embedder.time_proj",
    "head.modulation": "scale_shift_table",
    "head.head": "proj_out",
    "modulation": "scale_shift_table",
    "ffn.0": "ffn.net.0.proj",
    "ffn.2": "ffn.net.2",
    # Hack to swap the layer names
    # The original model calls the norms in the following order: norm1, norm3, norm2
    # We convert it to: norm1, norm2, norm3
    "norm2": "norm__placeholder",
    "norm3": "norm2",
    "norm__placeholder": "norm3",
    # # For the I2V model
    # "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
    # "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
    # "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
    # "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
    # # For the FLF2V model
    # "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
    # Attention component mappings
    "self_attn.q": "attn1.to_q",
    "self_attn.k": "attn1.to_k",
    "self_attn.v": "attn1.to_v",
    "self_attn.o": "attn1.to_out.0",
    "self_attn.norm_q": "attn1.norm_q",
    "self_attn.norm_k": "attn1.norm_k",
    "cross_attn.q": "attn2.to_q",
    "cross_attn.k": "attn2.to_k",
    "cross_attn.v": "attn2.to_v",
    "cross_attn.o": "attn2.to_out.0",
    "cross_attn.norm_q": "attn2.norm_q",
    "cross_attn.norm_k": "attn2.norm_k",
    "attn2.to_k_img": "attn2.add_k_proj",
    "attn2.to_v_img": "attn2.add_v_proj",
    "attn2.norm_k_img": "attn2.norm_added_k",
    "before_proj": "proj_in",
    "after_proj": "proj_out",
}
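
# Optional key-specific handlers applied in place by convert_transformer(); they are empty for the
# current Wan 2.1 checkpoints but are kept so special-case remapping can be registered if needed.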
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    # Rename a key in place (the trailing underscore marks the in-place convention).
    state_dict[new_key] = state_dict.pop(old_key)


def load_sharded_safetensors(dir: pathlib.Path):
    file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
    state_dict = {}
    for path in file_paths:
        state_dict.update(load_file(path))
    return state_dict


def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
    if model_type == "Wan-T2V-1.3B":
        config = {
            "model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 12,
                "num_layers": 30,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-T2V-14B": | |
config = { | |
"model_id": "StevenZhang/Wan2.1-T2V-14B-Diff", | |
"diffusers_config": { | |
"added_kv_proj_dim": None, | |
"attention_head_dim": 128, | |
"cross_attn_norm": True, | |
"eps": 1e-06, | |
"ffn_dim": 13824, | |
"freq_dim": 256, | |
"in_channels": 16, | |
"num_attention_heads": 40, | |
"num_layers": 40, | |
"out_channels": 16, | |
"patch_size": [1, 2, 2], | |
"qk_norm": "rms_norm_across_heads", | |
"text_dim": 4096, | |
}, | |
} | |
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT | |
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP | |
elif model_type == "Wan-I2V-14B-480p": | |
config = { | |
"model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff", | |
"diffusers_config": { | |
"image_dim": 1280, | |
"added_kv_proj_dim": 5120, | |
"attention_head_dim": 128, | |
"cross_attn_norm": True, | |
"eps": 1e-06, | |
"ffn_dim": 13824, | |
"freq_dim": 256, | |
"in_channels": 36, | |
"num_attention_heads": 40, | |
"num_layers": 40, | |
"out_channels": 16, | |
"patch_size": [1, 2, 2], | |
"qk_norm": "rms_norm_across_heads", | |
"text_dim": 4096, | |
}, | |
} | |
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT | |
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP | |
elif model_type == "Wan-I2V-14B-720p": | |
config = { | |
"model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff", | |
"diffusers_config": { | |
"image_dim": 1280, | |
"added_kv_proj_dim": 5120, | |
"attention_head_dim": 128, | |
"cross_attn_norm": True, | |
"eps": 1e-06, | |
"ffn_dim": 13824, | |
"freq_dim": 256, | |
"in_channels": 36, | |
"num_attention_heads": 40, | |
"num_layers": 40, | |
"out_channels": 16, | |
"patch_size": [1, 2, 2], | |
"qk_norm": "rms_norm_across_heads", | |
"text_dim": 4096, | |
}, | |
} | |
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT | |
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP | |
elif model_type == "Wan-FLF2V-14B-720P": | |
config = { | |
"model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder | |
"diffusers_config": { | |
"image_dim": 1280, | |
"added_kv_proj_dim": 5120, | |
"attention_head_dim": 128, | |
"cross_attn_norm": True, | |
"eps": 1e-06, | |
"ffn_dim": 13824, | |
"freq_dim": 256, | |
"in_channels": 36, | |
"num_attention_heads": 40, | |
"num_layers": 40, | |
"out_channels": 16, | |
"patch_size": [1, 2, 2], | |
"qk_norm": "rms_norm_across_heads", | |
"text_dim": 4096, | |
"rope_max_seq_len": 1024, | |
"pos_embed_seq_len": 257 * 2, | |
}, | |
} | |
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT | |
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP | |
elif model_type == "Wan-VACE-1.3B": | |
config = { | |
"model_id": "Wan-AI/Wan2.1-VACE-1.3B", | |
"diffusers_config": { | |
"added_kv_proj_dim": None, | |
"attention_head_dim": 128, | |
"cross_attn_norm": True, | |
"eps": 1e-06, | |
"ffn_dim": 8960, | |
"freq_dim": 256, | |
"in_channels": 16, | |
"num_attention_heads": 12, | |
"num_layers": 30, | |
"out_channels": 16, | |
"patch_size": [1, 2, 2], | |
"qk_norm": "rms_norm_across_heads", | |
"text_dim": 4096, | |
"vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28], | |
"vace_in_channels": 96, | |
}, | |
} | |
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT | |
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP | |
elif model_type == "Wan-VACE-14B": | |
config = { | |
"model_id": "Wan-AI/Wan2.1-VACE-14B", | |
"diffusers_config": { | |
"added_kv_proj_dim": None, | |
"attention_head_dim": 128, | |
"cross_attn_norm": True, | |
"eps": 1e-06, | |
"ffn_dim": 13824, | |
"freq_dim": 256, | |
"in_channels": 16, | |
"num_attention_heads": 40, | |
"num_layers": 40, | |
"out_channels": 16, | |
"patch_size": [1, 2, 2], | |
"qk_norm": "rms_norm_across_heads", | |
"text_dim": 4096, | |
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35], | |
"vace_in_channels": 96, | |
}, | |
} | |
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT | |
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP | |
    else:
        raise ValueError(f"Unsupported model_type: {model_type!r}")

    return config, RENAME_DICT, SPECIAL_KEYS_REMAP


def convert_transformer(model_type: str):
    config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
    diffusers_config = config["diffusers_config"]
    model_id = config["model_id"]
    model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
    original_state_dict = load_sharded_safetensors(model_dir)

    with init_empty_weights():
        if "VACE" not in model_type:
            transformer = WanTransformer3DModel.from_config(diffusers_config)
        else:
            transformer = WanVACETransformer3DModel.from_config(diffusers_config)

    # Apply the substring renames to every key in the original checkpoint
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Apply any registered key-specific handlers in place (none for the current checkpoints)
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer
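

# convert_vae() remaps the original Wan2.1_VAE.pth checkpoint keys into the AutoencoderKLWan layout.
# Note that AutoencoderKLWan() below is instantiated from its default configuration and the weights
# are loaded with strict=True, so any key or shape mismatch with the checkpoint will raise an error.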
def convert_vae():
    vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
    new_state_dict = {}

    # Create mappings for specific components
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Create a mapping for attention blocks
    attention_mapping = {
        # Encoder middle attention
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder middle attention
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Create a mapping for the head components
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Create a mapping for the quant components
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Process each key in the state dict
    for key, value in old_state_dict.items():
        # Handle middle block keys using the mapping
        if key in middle_key_mapping:
            new_key = middle_key_mapping[key]
            new_state_dict[new_key] = value
        # Handle attention blocks using the mapping
        elif key in attention_mapping:
            new_key = attention_mapping[key]
            new_state_dict[new_key] = value
        # Handle head keys using the mapping
        elif key in head_mapping:
            new_key = head_mapping[key]
            new_state_dict[new_key] = value
        # Handle quant keys using the mapping
        elif key in quant_mapping:
            new_key = quant_mapping[key]
            new_state_dict[new_key] = value
        # Handle encoder conv1
        elif key == "encoder.conv1.weight":
            new_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            new_state_dict["encoder.conv_in.bias"] = value
        # Handle decoder conv1
        elif key == "decoder.conv1.weight":
            new_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            new_state_dict["decoder.conv_in.bias"] = value
        # Handle encoder downsamples
        elif key.startswith("encoder.downsamples."):
            # Convert to down_blocks
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

            # Convert residual block naming but keep the original structure
            if ".residual.0.gamma" in new_key:
                new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
            elif ".residual.2.bias" in new_key:
                new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
            elif ".residual.2.weight" in new_key:
                new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
            elif ".residual.3.gamma" in new_key:
                new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
            elif ".residual.6.bias" in new_key:
                new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
            elif ".residual.6.weight" in new_key:
                new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
            elif ".shortcut.bias" in new_key:
                new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
            elif ".shortcut.weight" in new_key:
                new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")

            new_state_dict[new_key] = value
        # Handle decoder upsamples
        elif key.startswith("decoder.upsamples."):
            # Convert to up_blocks
            parts = key.split(".")
            block_idx = int(parts[2])

            # Group residual blocks
            if "residual" in key:
                if block_idx in [0, 1, 2]:
                    new_block_idx = 0
                    resnet_idx = block_idx
                elif block_idx in [4, 5, 6]:
                    new_block_idx = 1
                    resnet_idx = block_idx - 4
                elif block_idx in [8, 9, 10]:
                    new_block_idx = 2
                    resnet_idx = block_idx - 8
                elif block_idx in [12, 13, 14]:
                    new_block_idx = 3
                    resnet_idx = block_idx - 12
                else:
                    # Keep as is for other blocks
                    new_state_dict[key] = value
                    continue

                # Convert residual block naming
                if ".residual.0.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
                elif ".residual.2.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
                elif ".residual.2.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
                elif ".residual.3.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
                elif ".residual.6.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
                elif ".residual.6.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
                else:
                    new_key = key

                new_state_dict[new_key] = value
            # Handle shortcut connections
            elif ".shortcut." in key:
                if block_idx == 4:
                    new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
                    new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                    new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
                new_state_dict[new_key] = value
            # Handle upsamplers
            elif ".resample." in key or ".time_conv." in key:
                if block_idx == 3:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
                elif block_idx == 7:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
                elif block_idx == 11:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                new_state_dict[new_key] = value
            else:
                new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                new_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            new_state_dict[key] = value

    with init_empty_weights():
        vae = AutoencoderKLWan()

    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default=None)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}
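
# Example invocation (the script name and paths are illustrative):
#   python convert_wan_to_diffusers.py --model_type Wan-T2V-1.3B --output_path ./Wan2.1-T2V-1.3B-Diffusers --dtype bf16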

if __name__ == "__main__":
    args = get_args()

    transformer = convert_transformer(args.model_type)
    vae = convert_vae()
    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
    flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
    scheduler = UniPCMultistepScheduler(
        prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
    )

    # If the user has specified "none", we keep the original dtypes of the state dict without any conversion
    if args.dtype != "none":
        dtype = DTYPE_MAPPING[args.dtype]
        transformer.to(dtype)

    if "I2V" in args.model_type or "FLF2V" in args.model_type:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
        )
        image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
        pipe = WanImageToVideoPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            image_encoder=image_encoder,
            image_processor=image_processor,
        )
    elif "VACE" in args.model_type:
        pipe = WanVACEPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
        )
    else:
        pipe = WanPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
        )

    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
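
    # The saved directory can then be reloaded with the matching pipeline class, e.g. for a T2V model
    # ("<output_path>" stands for whatever was passed via --output_path):
    #   from diffusers import WanPipeline
    #   pipe = WanPipeline.from_pretrained("<output_path>", torch_dtype=torch.bfloat16)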