|
import torch
|
|
import safetensors
|
|
from accelerate import init_empty_weights
|
|
from accelerate.utils.modeling import set_module_tensor_to_device
|
|
from safetensors.torch import load_file, save_file
|
|
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
|
from typing import List
|
|
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
|
from library import model_util
|
|
from library import sdxl_original_unet
|
|
from .utils import setup_logging
|
|
|
|
import logging

# Run the project's logging configuration helper, then create the logger used by
# the functions in this module.
setup_logging()

logger = logging.getLogger(__name__)
|
|
|
|
# Scale constant associated with the VAE. It is not used in this part of the file;
# presumably the latent scaling factor applied by callers -- confirm at the use sites.
VAE_SCALE_FACTOR = 0.13025

# Identifier string for the SDXL base v1.0 model version.
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"

# Model id used by save_diffusers_checkpoint as the fallback source for the
# scheduler, tokenizers and VAE when no pretrained model name or path is given.
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"

# Keyword arguments passed to diffusers' UNet2DConditionModel when an SDXL U-Net
# is rebuilt for saving in Diffusers format.
DIFFUSERS_SDXL_UNET_CONFIG = {
    "act_fn": "silu",
    "addition_embed_type": "text_time",
    "addition_embed_type_num_heads": 64,
    "addition_time_embed_dim": 256,
    "attention_head_dim": [5, 10, 20],
    "block_out_channels": [320, 640, 1280],
    "center_input_sample": False,
    "class_embed_type": None,
    "class_embeddings_concat": False,
    "conv_in_kernel": 3,
    "conv_out_kernel": 3,
    "cross_attention_dim": 2048,
    "cross_attention_norm": None,
    "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
    "downsample_padding": 1,
    "dual_cross_attention": False,
    "encoder_hid_dim": None,
    "encoder_hid_dim_type": None,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 4,
    "layers_per_block": 2,
    "mid_block_only_cross_attention": None,
    "mid_block_scale_factor": 1,
    "mid_block_type": "UNetMidBlock2DCrossAttn",
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "num_attention_heads": None,
    "num_class_embeds": None,
    "only_cross_attention": False,
    "out_channels": 4,
    "projection_class_embeddings_input_dim": 2816,
    "resnet_out_scale_factor": 1.0,
    "resnet_skip_time_act": False,
    "resnet_time_scale_shift": "default",
    "sample_size": 128,
    "time_cond_proj_dim": None,
    "time_embedding_act_fn": None,
    "time_embedding_dim": None,
    "time_embedding_type": "positional",
    "timestep_post_act": None,
    "transformer_layers_per_block": [1, 2, 10],
    "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
    "upcast_attention": False,
    "use_linear_projection": True,
}
|
|
|
|
|
|
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
    """Rename SDXL text-encoder-2 checkpoint keys to the ``text_model.*`` layout.

    Keys are expected to start with ``conditioner.embedders.1.model.``. The fused
    ``attn.in_proj_weight`` / ``attn.in_proj_bias`` tensors of each resblock are
    split into three equal chunks and stored as ``q_proj``, ``k_proj`` and
    ``v_proj``. ``logit_scale`` is left out of the converted dict and returned
    separately.

    Args:
        checkpoint: Mapping of SDXL key -> tensor.
        max_length: Accepted for interface compatibility; not used here.

    Returns:
        ``(new_sd, logit_scale)`` where ``logit_scale`` is ``None`` if the
        checkpoint has no such entry.

    Raises:
        ValueError: If a ``resblocks`` key matches none of the known patterns.
    """
    prefix = "conditioner.embedders.1.model."

    def rename(name):
        """Return the converted key, or None when the key is skipped in this pass."""
        name = name.replace(prefix + "transformer.", "text_model.encoder.")
        name = name.replace(prefix, "text_model.")

        if "resblocks" in name:
            name = name.replace(".resblocks.", ".layers.")
            if ".ln_" in name:
                return name.replace(".ln_", ".layer_norm")
            if ".mlp." in name:
                return name.replace(".c_fc.", ".fc1.").replace(".c_proj.", ".fc2.")
            if ".attn.out_proj" in name:
                return name.replace(".attn.out_proj.", ".self_attn.out_proj.")
            if ".attn.in_proj" in name:
                # Fused q/k/v projection: split in the second pass below.
                return None
            raise ValueError(f"unexpected key in SD: {name}")

        if ".positional_embedding" in name:
            return name.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        if ".text_projection" in name:
            return name.replace("text_model.text_projection", "text_projection.weight")
        if ".logit_scale" in name:
            # Returned separately instead of being stored in the converted dict.
            return None
        if ".token_embedding" in name:
            return name.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        if ".ln_final" in name:
            return name.replace(".ln_final", ".final_layer_norm")
        if ".embeddings.position_ids" in name:
            # Dropped from the converted state dict.
            return None
        return name

    source_keys = list(checkpoint.keys())

    # First pass: plain one-to-one renames.
    new_sd = {}
    for source_key in source_keys:
        target_key = rename(source_key)
        if target_key is not None:
            new_sd[target_key] = checkpoint[source_key]

    # Second pass: split each fused attention input projection into q/k/v.
    for source_key in source_keys:
        if ".resblocks" not in source_key or ".attn.in_proj_" not in source_key:
            continue

        parts = torch.chunk(checkpoint[source_key], 3)
        suffix = ".weight" if "weight" in source_key else ".bias"
        stem = (
            source_key.replace(prefix + "transformer.resblocks.", "text_model.encoder.layers.")
            .replace("_weight", "")
            .replace("_bias", "")
            .replace(".attn.in_proj", ".self_attn.")
        )
        new_sd[stem + "q_proj" + suffix] = parts[0]
        new_sd[stem + "k_proj" + suffix] = parts[1]
        new_sd[stem + "v_proj" + suffix] = parts[2]

    logit_scale = checkpoint.get(prefix + "logit_scale", None)

    # A source key that already ends in "text_projection.weight" is renamed to
    # "text_projection.weight.weight" by the first pass; collapse the doubled suffix.
    doubled_key = "text_projection.weight.weight"
    if doubled_key in new_sd:
        logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
        new_sd["text_projection.weight"] = new_sd.pop(doubled_key)

    return new_sd, logit_scale
|
|
|
|
|
|
|
|
def _load_state_dict_on_device(model, state_dict, device, dtype=None):
|
|
|
|
missing_keys = list(model.state_dict().keys() - state_dict.keys())
|
|
unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
|
|
|
|
|
|
if not missing_keys and not unexpected_keys:
|
|
for k in list(state_dict.keys()):
|
|
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
|
|
return "<All keys matched successfully>"
|
|
|
|
|
|
error_msgs: List[str] = []
|
|
if missing_keys:
|
|
error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
|
|
if unexpected_keys:
|
|
error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))
|
|
|
|
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
|
|
|
|
|
|
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
    """Load the U-Net, both text encoders and the VAE from one SDXL checkpoint file.

    Each model is instantiated under ``init_empty_weights`` and then filled from
    the checkpoint with ``_load_state_dict_on_device``.

    Args:
        model_version: Not used by this function.
        ckpt_path: Path of the checkpoint. Files recognised by
            ``model_util.is_safetensors`` are read with safetensors; anything else
            is read with ``torch.load``.
        map_location: Device used when reading the file and when placing weights.
        dtype: Optional dtype applied to the U-Net and VAE weights. The text
            encoders are loaded without a dtype override.
        disable_mmap: If True, a safetensors file is read fully into memory and
            deserialized from bytes instead of going through ``load_file``.

    Returns:
        ``(text_model1, text_model2, vae, unet, logit_scale, ckpt_info)``.
        ``ckpt_info`` is ``(epoch, global_step)`` for non-safetensors checkpoints
        and ``None`` for safetensors files.

    Raises:
        RuntimeError: From ``_load_state_dict_on_device`` when the keys of a
            sub-model do not match its (converted) state dict.
    """
    # ---- read the raw state dict ----
    if model_util.is_safetensors(ckpt_path):
        checkpoint = None
        if disable_mmap:
            # Context manager so the file handle is closed deterministically.
            with open(ckpt_path, "rb") as f:
                state_dict = safetensors.torch.load(f.read())
        else:
            try:
                state_dict = load_file(ckpt_path, device=map_location)
            except Exception as e:
                # Best-effort fallback: retry without an explicit device. Only
                # ``Exception`` is caught so Ctrl-C / SystemExit still propagate.
                logger.warning(f"load_file with device={map_location} failed ({e}); retrying without device")
                state_dict = load_file(ckpt_path)
        # safetensors files carry no epoch / step information here.
        epoch = None
        global_step = None
    else:
        # NOTE(security): torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint.get("epoch", 0)
            global_step = checkpoint.get("global_step", 0)
        else:
            state_dict = checkpoint
            epoch = 0
            global_step = 0
        # Drop the extra reference to the loaded object.
        checkpoint = None

    # ---- U-Net ----
    logger.info("building U-Net")
    with init_empty_weights():
        unet = sdxl_original_unet.SdxlUNet2DConditionModel()

    logger.info("loading U-Net from checkpoint")
    unet_sd = {}
    # Entries are popped so that ``state_dict`` shrinks as each sub-model is extracted.
    for k in list(state_dict.keys()):
        if k.startswith("model.diffusion_model."):
            unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
    info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
    logger.info(f"U-Net: {info}")

    # ---- text encoders ----
    logger.info("building text encoders")

    text_model1_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=77,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=768,
    )
    with init_empty_weights():
        text_model1 = CLIPTextModel._from_config(text_model1_cfg)

    text_model2_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=1280,
        intermediate_size=5120,
        num_hidden_layers=32,
        num_attention_heads=20,
        max_position_embeddings=77,
        hidden_act="gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=1280,
    )
    with init_empty_weights():
        text_model2 = CLIPTextModelWithProjection(text_model2_cfg)

    logger.info("loading text encoders from checkpoint")
    te1_sd = {}
    te2_sd = {}
    for k in list(state_dict.keys()):
        if k.startswith("conditioner.embedders.0.transformer."):
            te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
        elif k.startswith("conditioner.embedders.1.model."):
            # Keys keep their prefix; the converter below strips and renames them.
            te2_sd[k] = state_dict.pop(k)

    # Drop position_ids if present so the strict key comparison below does not fail on it.
    if "text_model.embeddings.position_ids" in te1_sd:
        te1_sd.pop("text_model.embeddings.position_ids")

    info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)
    logger.info(f"text encoder 1: {info1}")

    converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
    info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location)
    logger.info(f"text encoder 2: {info2}")

    # ---- VAE ----
    logger.info("building VAE")
    vae_config = model_util.create_vae_diffusers_config()
    with init_empty_weights():
        vae = AutoencoderKL(**vae_config)

    logger.info("loading VAE from checkpoint")
    converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
    info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
    logger.info(f"VAE: {info}")

    ckpt_info = (epoch, global_step) if epoch is not None else None
    return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
|
|
|
|
|
|
def make_unet_conversion_map():
    """Build the list of ``(sd_prefix, hf_prefix)`` key-prefix pairs for the SDXL U-Net.

    The first element of each pair uses the ``input_blocks`` / ``middle_block`` /
    ``output_blocks`` naming, the second the ``down_blocks`` / ``mid_block`` /
    ``up_blocks`` naming. Pairs are generated for every block index and are not
    filtered against an actual model, so some of them may match no real key.

    Returns:
        A list of 2-tuples of strings, each ending with ``"."``.
    """
    num_blocks = 3

    # Layer-level prefixes; resnet entries are expanded into sub-layer prefixes below.
    layer_pairs = []
    for i in range(num_blocks):
        # Down path: two resnet/attention pairs per block.
        for j in range(2):
            sd_index = 3 * i + j + 1
            layer_pairs.append((f"input_blocks.{sd_index}.0.", f"down_blocks.{i}.resnets.{j}."))
            layer_pairs.append((f"input_blocks.{sd_index}.1.", f"down_blocks.{i}.attentions.{j}."))

        # Up path: three resnet/attention pairs per block.
        for j in range(3):
            sd_index = 3 * i + j
            layer_pairs.append((f"output_blocks.{sd_index}.0.", f"up_blocks.{i}.resnets.{j}."))
            layer_pairs.append((f"output_blocks.{sd_index}.1.", f"up_blocks.{i}.attentions.{j}."))

        # Down/up samplers for this block.
        layer_pairs.append((f"input_blocks.{3 * (i + 1)}.0.op.", f"down_blocks.{i}.downsamplers.0.conv."))
        layer_pairs.append((f"output_blocks.{3 * i + 2}.2.", f"up_blocks.{i}.upsamplers.0."))

    layer_pairs.append(("middle_block.1.", "mid_block.attentions.0."))
    for j in range(2):
        layer_pairs.append((f"middle_block.{2 * j}.", f"mid_block.resnets.{j}."))

    # Sub-layer names inside a resnet block, as (sd, hf) pairs.
    resnet_pairs = [
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map = []
    for sd_prefix, hf_prefix in layer_pairs:
        if "resnets" in hf_prefix:
            unet_conversion_map.extend((sd_prefix + sd_res, hf_prefix + hf_res) for sd_res, hf_res in resnet_pairs)
        else:
            unet_conversion_map.append((sd_prefix, hf_prefix))

    # Time embedding and additional ("label") embedding MLPs.
    for j in range(2):
        unet_conversion_map.append((f"time_embed.{j * 2}.", f"time_embedding.linear_{j + 1}."))
    for j in range(2):
        unet_conversion_map.append((f"label_emb.0.{j * 2}.", f"add_embedding.linear_{j + 1}."))

    # Input / output convolutions and the final norm.
    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    return unet_conversion_map
|
|
|
|
|
|
def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
    """Convert a U-Net state dict from Diffusers key names to SDXL key names.

    Inverts the ``(sd_prefix, hf_prefix)`` pairs from ``make_unet_conversion_map``
    into an ``hf_prefix -> sd_prefix`` lookup and applies it with
    ``convert_unet_state_dict``.
    """
    hf_to_sd = {}
    for sd_prefix, hf_prefix in make_unet_conversion_map():
        hf_to_sd[hf_prefix] = sd_prefix
    return convert_unet_state_dict(du_sd, hf_to_sd)
|
|
|
|
|
|
def convert_unet_state_dict(src_sd, conversion_map):
    """Rename state-dict keys by replacing their longest mapped prefix.

    For each key, the last dot-separated component is set aside and the
    remaining prefix is shortened one component at a time from the right until
    it (with a trailing ``"."``) is found in ``conversion_map``. The matched
    prefix is replaced by its mapped value and the rest of the key is kept.

    Args:
        src_sd: Mapping of source key -> value.
        conversion_map: Mapping of source prefix -> target prefix; prefixes end with ``"."``.

    Returns:
        A new dict with converted keys and the original values.

    Raises:
        AssertionError: If no prefix of a key is present in ``conversion_map``.
    """
    converted_sd = {}
    for src_key, value in src_sd.items():
        # Drop the final component (e.g. "weight"/"bias"); it is never part of a prefix.
        fragments = src_key.split(".")[:-1]

        # Try the longest candidate prefix first, then progressively shorter ones.
        for end in range(len(fragments), 0, -1):
            src_prefix = ".".join(fragments[:end]) + "."
            if src_prefix in conversion_map:
                converted_sd[conversion_map[src_prefix] + src_key[len(src_prefix) :]] = value
                break
        else:
            # Raised explicitly (not via ``assert``) so the check survives ``python -O``;
            # the exception type is unchanged for existing callers.
            raise AssertionError(f"key {src_key} not found in conversion map")

    return converted_sd
|
|
|
|
|
|
def convert_sdxl_unet_state_dict_to_diffusers(sd):
    """Convert a U-Net state dict from SDXL key names to Diffusers key names.

    Uses the ``(sd_prefix, hf_prefix)`` pairs from ``make_unet_conversion_map``
    as an ``sd_prefix -> hf_prefix`` lookup and applies it with
    ``convert_unet_state_dict``.
    """
    sd_to_hf = dict(make_unet_conversion_map())
    return convert_unet_state_dict(sd, sd_to_hf)
|
|
|
|
|
|
def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
    """Rename text-encoder-2 keys from the ``text_model.*`` layout to the SDXL layout.

    Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors of each layer are
    concatenated (in that order) into a single ``attn.in_proj_weight`` /
    ``attn.in_proj_bias`` entry. Keys containing ``.position_ids`` are dropped.

    Args:
        checkpoint: Mapping of key -> tensor in the ``text_model.*`` layout.
        logit_scale: Value stored under ``"logit_scale"`` in the result; omitted when ``None``.

    Returns:
        The converted state dict (keys without the ``conditioner.embedders.1.model.`` prefix).

    Raises:
        ValueError: If a ``layers`` key matches none of the known patterns.
    """

    def rename(name):
        """Return the converted key, or None when the key is skipped in this pass."""
        if ".position_ids" in name:
            return None

        name = name.replace("text_model.encoder.", "transformer.")
        name = name.replace("text_model.", "")

        if "layers" in name:
            name = name.replace(".layers.", ".resblocks.")
            if ".layer_norm" in name:
                return name.replace(".layer_norm", ".ln_")
            if ".mlp." in name:
                return name.replace(".fc1.", ".c_fc.").replace(".fc2.", ".c_proj.")
            if ".self_attn.out_proj" in name:
                return name.replace(".self_attn.out_proj.", ".attn.out_proj.")
            if ".self_attn." in name:
                # q/k/v projections are fused in the second pass below.
                return None
            raise ValueError(f"unexpected key in DiffUsers model: {name}")

        if ".position_embedding" in name:
            return name.replace("embeddings.position_embedding.weight", "positional_embedding")
        if ".token_embedding" in name:
            return name.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        if "text_projection" in name:
            return name.replace("text_projection.weight", "text_projection")
        if "final_layer_norm" in name:
            return name.replace("final_layer_norm", "ln_final")
        return name

    source_keys = list(checkpoint.keys())

    # First pass: plain one-to-one renames.
    new_sd = {}
    for source_key in source_keys:
        target_key = rename(source_key)
        if target_key is not None:
            new_sd[target_key] = checkpoint[source_key]

    # Second pass: for every q_proj entry, fuse q/k/v into one in_proj tensor.
    for source_key in source_keys:
        if "layers" not in source_key or "q_proj" not in source_key:
            continue

        fused = torch.cat(
            [
                checkpoint[source_key],
                checkpoint[source_key.replace("q_proj", "k_proj")],
                checkpoint[source_key.replace("q_proj", "v_proj")],
            ]
        )
        fused_key = source_key.replace("text_model.encoder.layers.", "transformer.resblocks.")
        fused_key = fused_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
        new_sd[fused_key] = fused

    if logit_scale is not None:
        new_sd["logit_scale"] = logit_scale

    return new_sd
|
|
|
|
|
|
def save_stable_diffusion_checkpoint(
    output_file,
    text_encoder1,
    text_encoder2,
    unet,
    epochs,
    steps,
    ckpt_info,
    vae,
    logit_scale,
    metadata,
    save_dtype=None,
):
    """Write the U-Net, both text encoders and the VAE into one SDXL checkpoint file.

    Args:
        output_file: Destination path. Saved with safetensors when
            ``model_util.is_safetensors`` recognises it, otherwise with ``torch.save``.
        text_encoder1: First text encoder; stored under ``conditioner.embedders.0.transformer.``.
        text_encoder2: Second text encoder; its keys are converted and stored
            under ``conditioner.embedders.1.model.``.
        unet: U-Net; stored under ``model.diffusion_model.``.
        epochs: Epoch count; ``ckpt_info[0]`` is added to it when ``ckpt_info`` is given.
        steps: Step count; ``ckpt_info[1]`` is added to it when ``ckpt_info`` is given.
        ckpt_info: Optional ``(epoch, global_step)`` pair of the source checkpoint.
        vae: VAE; its keys are converted and stored under ``first_stage_model.``.
        logit_scale: Forwarded to the text-encoder-2 conversion.
        metadata: Passed to ``save_file``; not used for the ``torch.save`` path.
        save_dtype: If given, every tensor is detached, cloned, moved to CPU and
            cast to this dtype before saving; otherwise tensors are stored as-is.

    Returns:
        The number of entries in the merged state dict.
    """
    merged = {}

    def add_with_prefix(prefix, tensors):
        """Copy ``tensors`` into ``merged`` under ``prefix``, casting if requested."""
        for name, tensor in tensors.items():
            if save_dtype is not None:
                tensor = tensor.detach().clone().to("cpu").to(save_dtype)
            merged[prefix + name] = tensor

    # U-Net
    add_with_prefix("model.diffusion_model.", unet.state_dict())

    # Text encoders
    add_with_prefix("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
    add_with_prefix(
        "conditioner.embedders.1.model.",
        convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale),
    )

    # VAE
    add_with_prefix("first_stage_model.", model_util.convert_vae_state_dict(vae.state_dict()))

    key_count = len(merged)

    # Continue counting from the source checkpoint when its info is available.
    if ckpt_info is not None:
        epochs += ckpt_info[0]
        steps += ckpt_info[1]

    if model_util.is_safetensors(output_file):
        # Only the tensors and the metadata are written; epoch/step are not stored here.
        save_file(merged, output_file, metadata)
    else:
        torch.save({"state_dict": merged, "epoch": epochs, "global_step": steps}, output_file)

    return key_count
|
|
|
|
|
|
def save_diffusers_checkpoint(
    output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
):
    """Save the models as a Diffusers ``StableDiffusionXLPipeline`` directory.

    The U-Net state dict is converted to Diffusers key names and loaded into a
    freshly built ``UNet2DConditionModel``. The scheduler and tokenizers (and the
    VAE, if none is passed) are loaded from ``pretrained_model_name_or_path``.

    Args:
        output_dir: Directory passed to ``pipeline.save_pretrained``.
        text_encoder1: First text encoder, used as the pipeline's ``text_encoder``.
        text_encoder2: Second text encoder, used as the pipeline's ``text_encoder_2``.
        unet: U-Net whose ``state_dict()`` uses SDXL key names.
        pretrained_model_name_or_path: Source for scheduler/tokenizers/VAE;
            ``DIFFUSERS_REF_MODEL_ID_SDXL`` is used when ``None``.
        vae: Optional VAE; loaded from the reference model when ``None``.
        use_safetensors: Forwarded as ``safe_serialization``.
        save_dtype: If given, the rebuilt U-Net and then the whole pipeline are
            cast to this dtype before saving.
    """
    from diffusers import StableDiffusionXLPipeline

    # Rebuild the U-Net in Diffusers form.
    du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet.state_dict())

    diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
    if save_dtype is not None:
        diffusers_unet.to(save_dtype)
    diffusers_unet.load_state_dict(du_unet_sd)

    # Remaining pipeline components come from the reference model.
    if pretrained_model_name_or_path is None:
        pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL

    scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
    if vae is None:
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    def remove_name_or_path(model):
        """Clear ``config._name_or_path`` on components that expose a ``config``."""
        if hasattr(model, "config"):
            model.config._name_or_path = None

    # Presumably keeps the source path out of the saved configs -- confirm against save_pretrained.
    for component in (diffusers_unet, text_encoder1, text_encoder2, scheduler, tokenizer1, tokenizer2, vae):
        remove_name_or_path(component)

    pipeline = StableDiffusionXLPipeline(
        unet=diffusers_unet,
        text_encoder=text_encoder1,
        text_encoder_2=text_encoder2,
        vae=vae,
        scheduler=scheduler,
        tokenizer=tokenizer1,
        tokenizer_2=tokenizer2,
    )
    if save_dtype is not None:
        pipeline.to(None, save_dtype)
    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
|
|