import json
import torch
from safetensors.torch import load_file, save_file
from pathlib import Path
import gc
import gguf
from dequant import dequantize_tensor

import os
import argparse
import gradio as gr

import spaces

flux_dev_repo = "ChuckMcSneed/FLUX.1-dev"
flux_schnell_repo = "black-forest-labs/FLUX.1-schnell"
system_temp_dir = "temp"

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)

GGUF_QTYPE = [gguf.GGMLQuantizationType.Q8_0, gguf.GGMLQuantizationType.Q5_1,
              gguf.GGMLQuantizationType.Q5_0, gguf.GGMLQuantizationType.Q4_1,
              gguf.GGMLQuantizationType.Q4_0, gguf.GGMLQuantizationType.F32,
              gguf.GGMLQuantizationType.F16]

TORCH_DTYPE = [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half,
               torch.bfloat16, torch.complex32, torch.chalf, torch.complex64, torch.cfloat,
               torch.complex128, torch.cdouble, torch.uint8, torch.uint16, torch.uint32, torch.uint64,
               torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,
               torch.bool, torch.float8_e4m3fn, torch.float8_e5m2]

TORCH_QUANTIZED_DTYPE = [torch.quint8, torch.qint8, torch.qint32, torch.quint4x2]


def list_sub(a, b):
    return [e for e in a if e not in b]


def is_repo_name(s):
    import re
    return re.fullmatch(r'^[^/,\s]+?/[^/,\s]+?$', s)
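
# Illustrative behavior of is_repo_name (returns a re.Match, which is truthy, or None):
#   >>> bool(is_repo_name("black-forest-labs/FLUX.1-dev"))
#   True
#   >>> bool(is_repo_name("https://example.com/model.safetensors"))
#   False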


def print_resource_usage():
    import psutil
    cpu_usage = psutil.cpu_percent()
    ram_usage = psutil.virtual_memory().used / psutil.virtual_memory().total * 100
    print(f"CPU usage: {cpu_usage}% / RAM usage: {ram_usage:.1f}%")


def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Start downloading...")
    url = url.strip()
    if "drive.google.com" in url:
        original_dir = os.getcwd()
        os.chdir(directory)
        os.system(f"gdown --fuzzy {url}")
        os.chdir(original_dir)
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
            os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
        else:
            os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
    elif "civitai.com" in url:
        if "?" in url:
            url = url.split("?")[0]
        if civitai_api_key:
            url = url + f"?token={civitai_api_key}"
            os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
        else:
            print("You need an API key to download Civitai models.")
    else:
        os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")


def get_local_model_list(dir_path):
    model_list = []
    valid_extensions = ('.safetensors',)  # note the trailing comma: a 1-tuple, not a bare string
    for file in Path(dir_path).glob("*"):
        if file.suffix in valid_extensions:
            file_path = str(Path(f"{dir_path}/{file.name}"))
            model_list.append(file_path)
    return model_list


def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    if "http" not in url and is_repo_name(url) and not Path(url).exists():
        print(f"Use HF Repo: {url}")
        new_file = url
    elif "http" not in url and Path(url).exists():
        print(f"Use local file: {url}")
        new_file = url
    elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
        print(f"File to download already exists: {url}")
        new_file = f"{temp_dir}/{url.split('/')[-1]}"
    else:
        print(f"Start downloading: {url}")
        before = get_local_model_list(temp_dir)
        try:
            download_thing(temp_dir, url.strip(), civitai_key)
        except Exception:
            print(f"Download failed: {url}")
            return ""
        after = get_local_model_list(temp_dir)
        new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
    if not new_file:
        print(f"Download failed: {url}")
        return ""
    print(f"Download completed: {url}")
    return new_file
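
# Illustrative resolution order for get_download_file (inputs are hypothetical):
#   get_download_file(".", "black-forest-labs/FLUX.1-schnell", "")  -> returned as-is (HF repo id)
#   get_download_file(".", "./flux1-dev.safetensors", "")           -> reused as a local file
#   get_download_file(".", "https://huggingface.co/.../x.safetensors", "")  -> downloaded via aria2c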


def save_readme_md(dir, url):
    orig_url = url if "http" in url else ""
    md = """---
license: other
license_name: flux-1-dev-non-commercial-license
license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md
language:
- en
library_name: diffusers
pipeline_tag: text-to-image
tags:
- text-to-image
- Flux
---
"""
    if orig_url:
        md += f"Converted from [{orig_url}]({orig_url}).\n"
    path = str(Path(dir, "README.md"))
    with open(path, mode='w', encoding="utf-8") as f:
        f.write(md)


def is_repo_exists(repo_id):
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        return api.repo_exists(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to connect {repo_id}. {e}")
        return True  # fail safe: treat the repo as existing so we don't clobber it


def create_diffusers_repo(new_repo_id, diffusers_folder, progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import HfApi
    hf_token = os.environ.get("HF_TOKEN")
    api = HfApi()
    try:
        progress(0, desc="Start uploading...")
        api.create_repo(repo_id=new_repo_id, token=hf_token, private=True, exist_ok=True)
        for path in Path(diffusers_folder).glob("*"):
            if path.is_dir():
                api.upload_folder(repo_id=new_repo_id, folder_path=str(path), path_in_repo=path.name, token=hf_token)
            elif path.is_file():
                api.upload_file(repo_id=new_repo_id, path_or_fileobj=str(path), path_in_repo=path.name, token=hf_token)
        progress(1, desc="Uploaded.")
        url = f"https://huggingface.co/{new_repo_id}"
    except Exception as e:
        print(f"Error: Failed to upload to {new_repo_id}. {e}")
        return ""
    return url


@torch.jit.script
def swap_scale_shift(weight):
    # Reorders an adaLN modulation tensor from [shift, scale] to [scale, shift].
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight
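
# Illustrative doctest for swap_scale_shift (values are arbitrary):
#   >>> swap_scale_shift(torch.tensor([1., 2., 3., 4.]))
#   tensor([3., 4., 1., 2.])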


@torch.no_grad()
def convert_flux_transformer_checkpoint_to_diffusers(
        original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0,
        progress=gr.Progress(track_tqdm=True)):
    def conv(cdict: dict, odict: dict, ckey: str, okey: str):
        if okey in odict:
            progress(0, desc=f"Converting {okey} => {ckey}")
            print(f"Converting {okey} => {ckey}")
            cdict[ckey] = odict.pop(okey)
            gc.collect()

    def convswap(cdict: dict, odict: dict, ckey: str, okey: str):
        if okey in odict:
            progress(0, desc=f"Converting (swap) {okey} => {ckey}")
            print(f"Converting {okey} => {ckey} (swap)")
            cdict[ckey] = swap_scale_shift(odict.pop(okey))
            gc.collect()

    def convqkv(cdict: dict, odict: dict, i: int):
        keys = odict.keys()
        qkv_keys = (f"double_blocks.{i}.img_attn.qkv.weight", f"double_blocks.{i}.txt_attn.qkv.weight",
                    f"double_blocks.{i}.img_attn.qkv.bias", f"double_blocks.{i}.txt_attn.qkv.bias")
        # Error out if only some of the four fused tensors are present.
        if any(k in keys for k in qkv_keys) and not all(k in keys for k in qkv_keys):
            progress(0, desc=f"Key error in converting Q, K, V (double_blocks.{i}).")
            print(f"Key error in converting Q, K, V (double_blocks.{i}).")
            return
        progress(0, desc=f"Converting Q, K, V (double_blocks.{i}).")
        print(f"Converting Q, K, V (double_blocks.{i}).")
        sample_q, sample_k, sample_v = torch.chunk(
            odict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
        )
        context_q, context_k, context_v = torch.chunk(
            odict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            odict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            odict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
        )
        # block_prefix is read from the enclosing loop below.
        cdict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
        cdict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
        cdict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
        cdict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
        cdict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
        cdict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
        cdict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
        cdict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        cdict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
        cdict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        cdict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
        cdict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
        gc.collect()
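    # Layout sketch for the fused attention weights above: FLUX.1's qkv.weight is
    # [3 * inner_dim, inner_dim] with Q, K, V stacked along dim 0, so with the
    # standard inner_dim of 3072 each torch.chunk(..., 3, dim=0) piece is [3072, 3072].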

    def convqkvmlp(cdict: dict, odict: dict, i: int, inner_dim: int, mlp_ratio: float):
        keys = odict.keys()
        linear1_keys = (f"single_blocks.{i}.linear1.weight", f"single_blocks.{i}.linear1.bias")
        # Error out if only one of weight/bias is present.
        if any(k in keys for k in linear1_keys) and not all(k in keys for k in linear1_keys):
            progress(0, desc=f"Key error in converting Q, K, V, mlp (single_blocks.{i}).")
            print(f"Key error in converting Q, K, V, mlp (single_blocks.{i}).")
            return
        progress(0, desc=f"Converting Q, K, V, mlp (single_blocks.{i}).")
        print(f"Converting Q, K, V, mlp (single_blocks.{i}).")
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
        q, k, v, mlp = torch.split(odict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            odict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        cdict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
        cdict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
        cdict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
        cdict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
        cdict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
        cdict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
        cdict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
        cdict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
        gc.collect()
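    # Worked example of the split above: with inner_dim=3072 and mlp_ratio=4.0,
    # mlp_hidden_dim = 12288, so linear1.weight of shape [3*3072 + 12288, 3072]
    # splits into Q/K/V of [3072, 3072] each plus an MLP projection of [12288, 3072].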

    converted_state_dict = {}
    progress(0, desc="Converting FLUX.1 state dict to Diffusers format.")

    # time_text_embed.timestep_embedder <- time_in
    conv(converted_state_dict, original_state_dict, "time_text_embed.timestep_embedder.linear_1.weight", "time_in.in_layer.weight")
    conv(converted_state_dict, original_state_dict, "time_text_embed.timestep_embedder.linear_1.bias", "time_in.in_layer.bias")
    conv(converted_state_dict, original_state_dict, "time_text_embed.timestep_embedder.linear_2.weight", "time_in.out_layer.weight")
    conv(converted_state_dict, original_state_dict, "time_text_embed.timestep_embedder.linear_2.bias", "time_in.out_layer.bias")

    # time_text_embed.text_embedder <- vector_in
    conv(converted_state_dict, original_state_dict, "time_text_embed.text_embedder.linear_1.weight", "vector_in.in_layer.weight")
    conv(converted_state_dict, original_state_dict, "time_text_embed.text_embedder.linear_1.bias", "vector_in.in_layer.bias")
    conv(converted_state_dict, original_state_dict, "time_text_embed.text_embedder.linear_2.weight", "vector_in.out_layer.weight")
    conv(converted_state_dict, original_state_dict, "time_text_embed.text_embedder.linear_2.bias", "vector_in.out_layer.bias")

    # the guidance embedder is present in dev-style checkpoints only
    has_guidance = any("guidance" in k for k in original_state_dict)
    if has_guidance:
        conv(converted_state_dict, original_state_dict, "time_text_embed.guidance_embedder.linear_1.weight", "guidance_in.in_layer.weight")
        conv(converted_state_dict, original_state_dict, "time_text_embed.guidance_embedder.linear_1.bias", "guidance_in.in_layer.bias")
        conv(converted_state_dict, original_state_dict, "time_text_embed.guidance_embedder.linear_2.weight", "guidance_in.out_layer.weight")
        conv(converted_state_dict, original_state_dict, "time_text_embed.guidance_embedder.linear_2.bias", "guidance_in.out_layer.bias")

    # context_embedder <- txt_in, x_embedder <- img_in
    conv(converted_state_dict, original_state_dict, "context_embedder.weight", "txt_in.weight")
    conv(converted_state_dict, original_state_dict, "context_embedder.bias", "txt_in.bias")
    conv(converted_state_dict, original_state_dict, "x_embedder.weight", "img_in.weight")
    conv(converted_state_dict, original_state_dict, "x_embedder.bias", "img_in.bias")

    progress(0.25, desc="Converting FLUX.1 state dict to Diffusers format.")

    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."

        # norms
        conv(converted_state_dict, original_state_dict, f"{block_prefix}norm1.linear.weight", f"double_blocks.{i}.img_mod.lin.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}norm1.linear.bias", f"double_blocks.{i}.img_mod.lin.bias")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}norm1_context.linear.weight", f"double_blocks.{i}.txt_mod.lin.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}norm1_context.linear.bias", f"double_blocks.{i}.txt_mod.lin.bias")

        # fused Q, K, V
        convqkv(converted_state_dict, original_state_dict, i)

        # qk norms
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.norm_q.weight", f"double_blocks.{i}.img_attn.norm.query_norm.scale")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.norm_k.weight", f"double_blocks.{i}.img_attn.norm.key_norm.scale")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.norm_added_q.weight", f"double_blocks.{i}.txt_attn.norm.query_norm.scale")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.norm_added_k.weight", f"double_blocks.{i}.txt_attn.norm.key_norm.scale")

        # feed-forwards
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff.net.0.proj.weight", f"double_blocks.{i}.img_mlp.0.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff.net.0.proj.bias", f"double_blocks.{i}.img_mlp.0.bias")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff.net.2.weight", f"double_blocks.{i}.img_mlp.2.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff.net.2.bias", f"double_blocks.{i}.img_mlp.2.bias")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff_context.net.0.proj.weight", f"double_blocks.{i}.txt_mlp.0.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff_context.net.0.proj.bias", f"double_blocks.{i}.txt_mlp.0.bias")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff_context.net.2.weight", f"double_blocks.{i}.txt_mlp.2.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff_context.net.2.bias", f"double_blocks.{i}.txt_mlp.2.bias")

        # output projections
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.to_out.0.weight", f"double_blocks.{i}.img_attn.proj.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.to_out.0.bias", f"double_blocks.{i}.img_attn.proj.bias")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.to_add_out.weight", f"double_blocks.{i}.txt_attn.proj.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.to_add_out.bias", f"double_blocks.{i}.txt_attn.proj.bias")

    progress(0.5, desc="Converting FLUX.1 state dict to Diffusers format.")

    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."

        conv(converted_state_dict, original_state_dict, f"{block_prefix}norm.linear.weight", f"single_blocks.{i}.modulation.lin.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}norm.linear.bias", f"single_blocks.{i}.modulation.lin.bias")

        convqkvmlp(converted_state_dict, original_state_dict, i, inner_dim, mlp_ratio)

        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.norm_q.weight", f"single_blocks.{i}.norm.query_norm.scale")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.norm_k.weight", f"single_blocks.{i}.norm.key_norm.scale")

        conv(converted_state_dict, original_state_dict, f"{block_prefix}proj_out.weight", f"single_blocks.{i}.linear2.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}proj_out.bias", f"single_blocks.{i}.linear2.bias")

    progress(0.75, desc="Converting FLUX.1 state dict to Diffusers format.")
    conv(converted_state_dict, original_state_dict, "proj_out.weight", "final_layer.linear.weight")
    conv(converted_state_dict, original_state_dict, "proj_out.bias", "final_layer.linear.bias")
    convswap(converted_state_dict, original_state_dict, "norm_out.linear.weight", "final_layer.adaLN_modulation.1.weight")
    convswap(converted_state_dict, original_state_dict, "norm_out.linear.bias", "final_layer.adaLN_modulation.1.bias")

    progress(1, desc="Converting FLUX.1 state dict to Diffusers format.")
    return converted_state_dict


def read_safetensors_metadata(path):
    with open(path, 'rb') as f:
        header_size = int.from_bytes(f.read(8), 'little')
        header_json = f.read(header_size).decode('utf-8')
        header = json.loads(header_json)
        metadata = header.get('__metadata__', {})
        return metadata
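
# Safetensors layout assumed above: the file starts with an 8-byte little-endian
# header length N, followed by N bytes of JSON; user metadata, if any, lives under
# the "__metadata__" key. For example, a file written with
#   save_file({"w": torch.zeros(1)}, "x.safetensors", metadata={"format": "pt"})
# would make read_safetensors_metadata("x.safetensors") return {"format": "pt"}.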


def normalize_key(k: str):
    return k.replace("vae.", "").replace("model.diffusion_model.", "")\
        .replace("text_encoders.clip_l.transformer.text_model.", "")\
        .replace("text_encoders.t5xxl.transformer.", "")
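
# Illustrative doctest for normalize_key:
#   >>> normalize_key("model.diffusion_model.double_blocks.0.img_mod.lin.weight")
#   'double_blocks.0.img_mod.lin.weight'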


def load_json_list(path: str):
    try:
        with open(path, encoding='utf-8') as f:
            return list(json.load(f))
    except Exception as e:
        print(e)
        return []


@torch.no_grad()
def to_safetensors(sd: dict, path: str, pattern: str, size: str, progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import save_torch_state_dict
    print(f"Saving a temporary file to disk: {path}")
    os.makedirs(path, exist_ok=True)
    try:
        for k, v in sd.items():
            sd[k] = v.to(device="cpu")
        save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
    except Exception as e:
        print(e)


@torch.no_grad()
def to_safetensors_flux_module(sd: dict, path: str, pattern: str, size: str,
                               quantization: bool = False, name: str = "",
                               metadata: dict | None = None, progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import save_torch_state_dict
    try:
        progress(0, desc=f"Preparing to save FLUX.1 {name} to Diffusers format.")
        print(f"Preparing to save FLUX.1 {name} to Diffusers format.")
        for k, v in sd.items():
            sd[k] = v.to(device="cpu")
        progress(0, desc=f"Loading FLUX.1 {name}.")
        print(f"Loading FLUX.1 {name}.")
        os.makedirs(path, exist_ok=True)
        if quantization:
            progress(0.5, desc=f"Saving quantized FLUX.1 {name} to: {path}")
            print(f"Saving quantized FLUX.1 {name} to: {path}")
        else:
            progress(0.5, desc=f"Saving FLUX.1 {name} to: {path}")
            print(f"Saving FLUX.1 {name} to: {path}")
        if metadata is not None:
            save_torch_state_dict(state_dict=sd, save_directory=path,
                                  filename_pattern=pattern, max_shard_size=size, metadata=metadata)
        else:
            save_torch_state_dict(state_dict=sd, save_directory=path,
                                  filename_pattern=pattern, max_shard_size=size)
        progress(1, desc=f"Saved FLUX.1 {name} to: {path}")
        print(f"Saved FLUX.1 {name} to: {path}")
    except Exception as e:
        print(e)
    finally:
        gc.collect()


flux_transformer_json = "flux_transformer_keys.json"
flux_t5xxl_json = "flux_t5xxl_keys.json"
flux_clip_json = "flux_clip_keys.json"
flux_vae_json = "flux_vae_keys.json"
keys_flux_t5xxl = set(load_json_list(flux_t5xxl_json))
keys_flux_transformer = set(load_json_list(flux_transformer_json))
keys_flux_clip = set(load_json_list(flux_clip_json))
keys_flux_vae = set(load_json_list(flux_vae_json))


@torch.no_grad()
def dequant_tensor(v: torch.Tensor, dtype: torch.dtype, dequant: bool):
    try:
        if dequant:
            qtype = getattr(v, "tensor_type", None)  # set on GGUF-loaded tensors only
            # Check GGUF quantization before plain dtypes: GGUF payloads are stored
            # in regular torch dtypes (e.g. uint8) and must be dequantized, not cast.
            if qtype in GGUF_QTYPE: return dequantize_tensor(v, dtype)
            elif v.dtype in TORCH_DTYPE: return v.to(dtype) if v.dtype != dtype else v
            elif v.dtype in TORCH_QUANTIZED_DTYPE: return torch.dequantize(v).to(dtype)
            else: return torch.dequantize(v).to(dtype)
        else: return v.to(dtype) if v.dtype != dtype else v
    except Exception as e:
        print(e)
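
# Illustrative dispatch of dequant_tensor with dequant=True (inputs hypothetical):
#   - GGUF tensor (tensor_type set) -> dequantize_tensor(v, dtype) from dequant.py
#   - plain torch tensor            -> cast to `dtype` (no-op if it already matches)
#   - torch-quantized tensor        -> torch.dequantize(v).to(dtype)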


@torch.no_grad()
def normalize_flux_state_dict(path: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                              dequant: bool = False, progress=gr.Progress(track_tqdm=True)):
    progress(0, desc=f"Loading and normalizing FLUX.1 safetensors: {path}")
    print(f"Loading and normalizing FLUX.1 safetensors: {path}")
    new_sd = dict()
    state_dict = load_file(path, device="cpu")
    try:
        for k in list(state_dict.keys()):
            v = state_dict.pop(k)
            nk = normalize_key(k)
            print(f"{k} => {nk}")
            new_sd[nk] = dequant_tensor(v, dtype, dequant)
    except Exception as e:
        print(e)
        return
    finally:
        del state_dict
        torch.cuda.empty_cache()
        gc.collect()
    new_path = str(Path(savepath, Path(path).stem + "_fixed" + Path(path).suffix))
    metadata = read_safetensors_metadata(path)
    progress(0.5, desc=f"Saving FLUX.1 safetensors: {new_path}")
    print(f"Saving FLUX.1 safetensors: {new_path}")
    os.makedirs(savepath, exist_ok=True)
    save_file(new_sd, new_path, metadata={"format": "pt", **metadata})
    progress(1, desc=f"Saved FLUX.1 safetensors: {new_path}")
    print(f"Saved FLUX.1 safetensors: {new_path}")
    del new_sd
    torch.cuda.empty_cache()
    gc.collect()


@torch.no_grad()
def extract_norm_flux_module_sd(path: str, dtype: torch.dtype = torch.bfloat16,
                                dequant: bool = False, name: str = "", keys: set = frozenset(),
                                progress=gr.Progress(track_tqdm=True)):
    progress(0, desc=f"Loading and normalizing FLUX.1 {name} safetensors: {path}")
    print(f"Loading and normalizing FLUX.1 {name} safetensors: {path}")
    new_sd = dict()
    state_dict = load_file(path, device="cpu")
    try:
        for k in list(state_dict.keys()):
            if k not in keys: state_dict.pop(k)
        gc.collect()
        for k in list(state_dict.keys()):
            v = state_dict.pop(k)
            if k in keys:
                nk = normalize_key(k)
                progress(0.5, desc=f"{k} => {nk}")
                print(f"{k} => {nk}")
                new_sd[nk] = dequant_tensor(v, dtype, dequant)
    except Exception as e:
        print(e)
        return None
    finally:
        progress(1, desc=f"Normalized FLUX.1 {name} safetensors: {path}")
        print(f"Normalized FLUX.1 {name} safetensors: {path}")
        del state_dict
        torch.cuda.empty_cache()
        gc.collect()
    return new_sd


@torch.no_grad()
def convert_flux_transformer_sd_to_diffusers(sd: dict, progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Converting FLUX.1 state dict to Diffusers format.")
    print("Converting FLUX.1 state dict to Diffusers format.")
    # FLUX.1 (dev/schnell) transformer geometry
    num_layers = 19
    num_single_layers = 38
    inner_dim = 3072
    mlp_ratio = 4.0
    try:
        sd = convert_flux_transformer_checkpoint_to_diffusers(
            sd, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
        )
    except Exception as e:
        print(e)
    finally:
        progress(1, desc="Converted FLUX.1 state dict to Diffusers format.")
        print("Converted FLUX.1 state dict to Diffusers format.")
        gc.collect()
    return sd
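
# Minimal usage sketch (file name hypothetical): load a BFL-format checkpoint,
# strip bundled-checkpoint prefixes, then remap the transformer keys.
#   sd = load_file("flux1-dev.safetensors", device="cpu")
#   sd = {normalize_key(k): v for k, v in sd.items()}
#   diffusers_sd = convert_flux_transformer_sd_to_diffusers(sd)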


@torch.no_grad()
def load_sharded_safetensors(path: str):
    import glob
    sd = {}
    try:
        for filepath in glob.glob(f"{path}/*.safetensors"):
            sharded_sd = load_file(str(filepath), device="cpu")
            for k, v in sharded_sd.items():
                sharded_sd[k] = v.to(device="cpu")
            sd |= sharded_sd
            del sharded_sd
            torch.cuda.empty_cache()
            gc.collect()
    except Exception as e:
        print(e)
    return sd


@torch.no_grad()
def convert_flux_transformer_sd_to_diffusers_sharded(sd: dict, path: str, pattern: str,
                                                     size: str, progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import save_torch_state_dict
    import glob
    try:
        progress(0, desc=f"Saving temporary files to disk: {path}")
        print(f"Saving temporary files to disk: {path}")
        os.makedirs(path, exist_ok=True)
        for k, v in sd.items():
            if k in keys_flux_transformer: sd[k] = v.to(device="cpu")
        save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
        del sd
        torch.cuda.empty_cache()
        gc.collect()
        progress(0.25, desc=f"Saved temporary files to disk: {path}")
        print(f"Saved temporary files to disk: {path}")
        for filepath in glob.glob(f"{path}/*.safetensors"):
            progress(0.25, desc=f"Processing temporary files: {filepath}")
            print(f"Processing temporary files: {filepath}")
            sharded_sd = load_file(str(filepath), device="cpu")
            sharded_sd = convert_flux_transformer_sd_to_diffusers(sharded_sd)
            for k, v in sharded_sd.items():
                sharded_sd[k] = v.to(device="cpu")
            save_file(sharded_sd, str(filepath))
            del sharded_sd
            torch.cuda.empty_cache()
            gc.collect()
        print(f"Loading temporary files from disk: {path}")
        sd = load_sharded_safetensors(path)
        print(f"Loaded temporary files from disk: {path}")
    except Exception as e:
        print(e)
        sd = {}
    return sd


@torch.no_grad()
def extract_normalized_flux_state_dict_sharded(loadpath: str, dtype: torch.dtype,
                                               dequant: bool, path: str, pattern: str, size: str,
                                               progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import save_torch_state_dict
    import glob
    sd = {}
    try:
        progress(0, desc=f"Loading model file: {loadpath}")
        print(f"Loading model file: {loadpath}")
        sd = load_file(loadpath, device="cpu")
        progress(0, desc=f"Saving temporary files to disk: {path}")
        print(f"Saving temporary files to disk: {path}")
        os.makedirs(path, exist_ok=True)
        for k, v in sd.items():
            sd[k] = v.to(device="cpu")
        save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
        del sd
        torch.cuda.empty_cache()
        gc.collect()
        progress(0.25, desc=f"Saved temporary files to disk: {path}")
        print(f"Saved temporary files to disk: {path}")
        for filepath in glob.glob(f"{path}/*.safetensors"):
            progress(0.25, desc=f"Processing temporary files: {filepath}")
            print(f"Processing temporary files: {filepath}")
            sharded_sd = extract_norm_flux_module_sd(str(filepath), dtype, dequant,
                                                     "Transformer", keys_flux_transformer)
            for k, v in sharded_sd.items():
                sharded_sd[k] = v.to(device="cpu")
            save_file(sharded_sd, str(filepath))
            del sharded_sd
            torch.cuda.empty_cache()
            gc.collect()
            print(f"Processed temporary files: {filepath}")
        print(f"Loading temporary files from disk: {path}")
        sd = load_sharded_safetensors(path)
        print(f"Loaded temporary files from disk: {path}")
    except Exception as e:
        print(e)
        sd = {}
    return sd


def download_repo(repo_name, path, use_original=["vae", "text_encoder"], progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import snapshot_download
    print(f"Downloading {repo_name}.")
    try:
        if "text_encoder_2" in use_original:
            snapshot_download(repo_id=repo_name, local_dir=path, ignore_patterns=["transformer/diffusion*.*", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.png", "*.webp"])
        else:
            snapshot_download(repo_id=repo_name, local_dir=path, ignore_patterns=["transformer/diffusion*.*", "text_encoder_2/model*.*", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.png", "*.webp"])
    except Exception as e:
        print(e)


def copy_nontensor_files(from_path, to_path, use_original=["vae", "text_encoder"]):
    import shutil
    if "text_encoder_2" in use_original:
        te_from = str(Path(from_path, "text_encoder_2"))
        te_to = str(Path(to_path, "text_encoder_2"))
        print(f"Copying Text Encoder 2 files {te_from} to {te_to}")
        shutil.copytree(te_from, te_to, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp"), dirs_exist_ok=True)
    if "text_encoder" in use_original:
        te1_from = str(Path(from_path, "text_encoder"))
        te1_to = str(Path(to_path, "text_encoder"))
        print(f"Copying Text Encoder 1 files {te1_from} to {te1_to}")
        shutil.copytree(te1_from, te1_to, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp"), dirs_exist_ok=True)
    if "vae" in use_original:
        vae_from = str(Path(from_path, "vae"))
        vae_to = str(Path(to_path, "vae"))
        print(f"Copying VAE files {vae_from} to {vae_to}")
        shutil.copytree(vae_from, vae_to, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp"), dirs_exist_ok=True)
    tn2_from = str(Path(from_path, "tokenizer_2"))
    tn2_to = str(Path(to_path, "tokenizer_2"))
    print(f"Copying Tokenizer 2 files {tn2_from} to {tn2_to}")
    shutil.copytree(tn2_from, tn2_to, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp"), dirs_exist_ok=True)
    print(f"Copying non-tensor files {from_path} to {to_path}")
    shutil.copytree(from_path, to_path, ignore=shutil.ignore_patterns("*.safetensors", "*.bin", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.png", "*.webp", "*.index.json"), dirs_exist_ok=True)


def save_flux_other_diffusers(path: str, model_type: str = "dev", use_original: list = ["vae", "text_encoder"], progress=gr.Progress(track_tqdm=True)):
    import shutil
    progress(0, desc="Loading FLUX.1 Components.")
    print("Loading FLUX.1 Components.")
    temppath = system_temp_dir
    repo = flux_schnell_repo if model_type == "schnell" else flux_dev_repo
    os.makedirs(temppath, exist_ok=True)
    os.makedirs(path, exist_ok=True)
    download_repo(repo, temppath, use_original)
    progress(0.5, desc="Saving FLUX.1 Components.")
    print("Saving FLUX.1 Components.")
    copy_nontensor_files(temppath, path, use_original)
    shutil.rmtree(temppath)


@torch.no_grad()
def fix_flux_safetensors(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                         quantization: bool = False, model_type: str = "dev", dequant: bool = False):
    save_flux_other_diffusers(savepath, model_type)
    normalize_flux_state_dict(loadpath, savepath, dtype, dequant)
    torch.cuda.empty_cache()
    gc.collect()


@torch.no_grad()
def flux_to_diffusers_lowmem(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                             quantization: bool = False, model_type: str = "dev",
                             dequant: bool = False, use_original: list = ["vae", "text_encoder"],
                             new_repo_id: str = "", local: bool = False, progress=gr.Progress(track_tqdm=True)):
    unet_sd_path = savepath.removesuffix("/") + "/transformer"
    unet_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    unet_sd_size = "10GB"
    te_sd_path = savepath.removesuffix("/") + "/text_encoder_2"
    te_sd_pattern = "model{suffix}.safetensors"
    te_sd_size = "5GB"
    clip_sd_path = savepath.removesuffix("/") + "/text_encoder"
    clip_sd_pattern = "model{suffix}.safetensors"
    clip_sd_size = "10GB"
    vae_sd_path = savepath.removesuffix("/") + "/vae"
    vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    vae_sd_size = "10GB"
    metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
    if "vae" not in use_original:
        vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE",
                                             keys_flux_vae)
        to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size,
                                   quantization, "VAE", None)
        del vae_sd
        torch.cuda.empty_cache()
        gc.collect()
    if "text_encoder" not in use_original:
        clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder",
                                              keys_flux_clip)
        to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size,
                                   quantization, "Text Encoder", None)
        del clip_sd
        torch.cuda.empty_cache()
        gc.collect()
    if "text_encoder_2" not in use_original:
        te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2",
                                            keys_flux_t5xxl)
        to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size,
                                   quantization, "Text Encoder 2", None)
        del te_sd
        torch.cuda.empty_cache()
        gc.collect()
    unet_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Transformer",
                                          keys_flux_transformer)
    if not local: os.remove(loadpath)  # free disk space before writing the shards
    to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
                               quantization, "Transformer", metadata)
    del unet_sd
    torch.cuda.empty_cache()
    gc.collect()
    save_flux_other_diffusers(savepath, model_type, use_original)


@torch.no_grad()
def flux_to_diffusers_lowmem2(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                              quantization: bool = False, model_type: str = "dev",
                              dequant: bool = False, use_original: list = ["vae", "text_encoder"],
                              new_repo_id: str = "", progress=gr.Progress(track_tqdm=True)):
    unet_sd_path = savepath.removesuffix("/") + "/transformer"
    unet_temp_path = system_temp_dir.removesuffix("/") + "/sharded"
    unet_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    unet_sd_size = "10GB"
    unet_temp_size = "5GB"
    te_sd_path = savepath.removesuffix("/") + "/text_encoder_2"
    te_sd_pattern = "model{suffix}.safetensors"
    te_sd_size = "5GB"
    clip_sd_path = savepath.removesuffix("/") + "/text_encoder"
    clip_sd_pattern = "model{suffix}.safetensors"
    clip_sd_size = "10GB"
    vae_sd_path = savepath.removesuffix("/") + "/vae"
    vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    vae_sd_size = "10GB"
    metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
    if "vae" not in use_original:
        vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE",
                                             keys_flux_vae)
        to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size,
                                   quantization, "VAE", None)
        del vae_sd
        torch.cuda.empty_cache()
        gc.collect()
    if "text_encoder" not in use_original:
        clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder",
                                              keys_flux_clip)
        to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size,
                                   quantization, "Text Encoder", None)
        del clip_sd
        torch.cuda.empty_cache()
        gc.collect()
    if "text_encoder_2" not in use_original:
        te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2",
                                            keys_flux_t5xxl)
        to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size,
                                   quantization, "Text Encoder 2", None)
        del te_sd
        torch.cuda.empty_cache()
        gc.collect()
    unet_sd = extract_normalized_flux_state_dict_sharded(loadpath, dtype, dequant,
                                                         unet_temp_path, unet_sd_pattern, unet_temp_size)
    unet_sd = convert_flux_transformer_sd_to_diffusers_sharded(unet_sd, unet_temp_path,
                                                               unet_sd_pattern, unet_temp_size)
    to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
                               quantization, "Transformer", metadata)
    del unet_sd
    torch.cuda.empty_cache()
    gc.collect()
    save_flux_other_diffusers(savepath, model_type, use_original)


def convert_url_to_diffusers_flux(url, civitai_key="", is_upload_sf=False, data_type="bf16",
                                  model_type="dev", dequant=False, use_original=["vae", "text_encoder"],
                                  hf_user="", hf_repo="", q=None, progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Start converting...")
    temp_dir = "."
    new_file = get_download_file(temp_dir, url, civitai_key)
    if not new_file:
        print(f"Not found: {url}")
        if q is not None: q.put("")  # unblock the parent process on failure
        return ""
    new_repo_name = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_")

    dtype = torch.bfloat16
    quantization = False
    if data_type == "fp8": dtype = torch.float8_e4m3fn
    elif data_type == "fp16": dtype = torch.float16
    elif data_type == "qfloat8":
        dtype = torch.bfloat16
        quantization = True
    else: dtype = torch.bfloat16

    new_repo_id = f"{hf_user}/{Path(new_repo_name).stem}"
    if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}"
    flux_to_diffusers_lowmem(new_file, new_repo_name, dtype, quantization, model_type, dequant, use_original, new_repo_id)

    """if is_upload_sf:
        import shutil
        shutil.move(str(Path(new_file).resolve()), str(Path(new_repo_name, Path(new_file).name).resolve()))
    else: os.remove(new_file)"""

    progress(1, desc="Converted.")
    if q is not None: q.put(new_repo_name)  # report back when run in a subprocess
    return new_repo_name


def convert_url_to_fixed_flux_safetensors(url, civitai_key="", is_upload_sf=False, data_type="bf16",
                                          model_type="dev", dequant=False, q=None, progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Start converting...")
    temp_dir = "."
    new_file = get_download_file(temp_dir, url, civitai_key)
    if not new_file:
        print(f"Not found: {url}")
        if q is not None: q.put("")  # unblock the parent process on failure
        return ""
    new_repo_name = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_")

    dtype = torch.bfloat16
    quantization = False
    if data_type == "fp8": dtype = torch.float8_e4m3fn
    elif data_type == "fp16": dtype = torch.float16
    elif data_type == "qfloat8":
        dtype = torch.bfloat16
        quantization = True
    else: dtype = torch.bfloat16

    fix_flux_safetensors(new_file, new_repo_name, dtype, quantization, model_type, dequant)

    os.remove(new_file)

    progress(1, desc="Converted.")
    if q is not None: q.put(new_repo_name)
    return new_repo_name


def convert_url_to_diffusers_repo_flux(dl_url, hf_user, hf_repo, hf_token, civitai_key="",
                                       is_upload_sf=False, data_type="bf16", model_type="dev", dequant=False,
                                       repo_urls=[], fix_only=False, use_original=["vae", "text_encoder"],
                                       progress=gr.Progress(track_tqdm=True)):
    import multiprocessing as mp
    import shutil
    if not hf_user:
        print(f"Invalid user name: {hf_user}")
        progress(1, desc=f"Invalid user name: {hf_user}")
        return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="")
    if hf_token and not os.environ.get("HF_TOKEN"): os.environ['HF_TOKEN'] = hf_token
    if not civitai_key and os.environ.get("CIVITAI_API_KEY"): civitai_key = os.environ.get("CIVITAI_API_KEY")
    q = mp.Queue()
    if fix_only:
        p = mp.Process(target=convert_url_to_fixed_flux_safetensors, args=(dl_url, civitai_key,
                       is_upload_sf, data_type, model_type, dequant, q))
    else:
        p = mp.Process(target=convert_url_to_diffusers_flux, args=(dl_url, civitai_key,
                       is_upload_sf, data_type, model_type, dequant, use_original, hf_user, hf_repo, q))
    p.start()
    new_path = q.get()
    p.join()
    if not new_path:
        return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="")
    new_repo_id = f"{hf_user}/{Path(new_path).stem}"
    if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}"
    if not is_repo_name(new_repo_id):
        print(f"Invalid repo name: {new_repo_id}")
        progress(1, desc=f"Invalid repo name: {new_repo_id}")
        return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="")
    if is_repo_exists(new_repo_id):
        print(f"Repo already exists: {new_repo_id}")
        progress(1, desc=f"Repo already exists: {new_repo_id}")
        return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="")
    save_readme_md(new_path, dl_url)
    repo_url = create_diffusers_repo(new_repo_id, new_path)
    shutil.rmtree(new_path)
    if not repo_urls: repo_urls = []
    repo_urls.append(repo_url)
    md = "Your new repo:<br>"
    for u in repo_urls:
        md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
    return gr.update(value=repo_urls, choices=repo_urls), gr.update(value=md)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default=None, type=str, required=False, help="URL of the model to convert.")
    parser.add_argument("--file", default=None, type=str, required=False, help="Filename of the model to convert.")
    parser.add_argument("--fix", action="store_true", help="Only fix the keys of the local model.")
    parser.add_argument("--civitai_key", default=None, type=str, required=False, help="Civitai API key (needed to download files from Civitai).")
    parser.add_argument("--dtype", type=str, default="fp8")
    parser.add_argument("--model", type=str, default="dev")
    parser.add_argument("--dequant", action="store_true", help="Dequantize the model.")
    args = parser.parse_args()
    assert (args.url, args.file) != (None, None), "Must provide --url or --file!"

    dtype = torch.bfloat16
    quantization = False
    if args.dtype == "fp8": dtype = torch.float8_e4m3fn
    elif args.dtype == "fp16": dtype = torch.float16
    elif args.dtype == "qfloat8":
        dtype = torch.bfloat16
        quantization = True
    else: dtype = torch.bfloat16

    use_original = ["vae", "text_encoder"]
    new_repo_id = ""
    use_local = True

    if args.file is not None and Path(args.file).exists():
        if args.fix: normalize_flux_state_dict(args.file, ".", dtype, args.dequant)
        else: flux_to_diffusers_lowmem(args.file, Path(args.file).stem, dtype, quantization,
                                       args.model, args.dequant, use_original, new_repo_id, use_local)
    elif args.url is not None:
        convert_url_to_diffusers_flux(args.url, args.civitai_key, False, args.dtype, args.model,
                                      args.dequant)
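
# Example CLI invocations (script name and file paths are illustrative assumptions):
#   python app.py --file flux1-dev.safetensors --dtype bf16
#   python app.py --file flux1-dev.safetensors --fix --dequant --dtype bf16
#   python app.py --url https://huggingface.co/<user>/<repo>/blob/main/model.safetensors --dtype fp8 --model dev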