"""Convert a Flux transformer checkpoint from the diffusers folder layout to
the original Flux single-file safetensors layout, optionally quantizing the
weights to float8_e4m3fn."""

import argparse
import json
from collections import OrderedDict
from datetime import date
from pathlib import Path

import safetensors
import safetensors.torch
import torch
import tqdm

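# Example invocation (script name and paths are illustrative):
#   python convert_flux_from_diffusers.py ./FLUX.1-dev ./flux1-dev.safetensors --do_8bit_scaled
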
parser = argparse.ArgumentParser()
parser.add_argument("diffusers_path", type=str,
                    help="Path to the original Flux diffusers folder.")
parser.add_argument("flux_path", type=str,
                    help="Output path for the Flux safetensors file.")
parser.add_argument("--do_8_bit", action="store_true",
                    help="Use 8-bit weights with stochastic rounding instead of bf16.")
parser.add_argument("--do_8bit_scaled", action="store_true",
                    help="Use scaled 8-bit weights instead of bf16.")
args = parser.parse_args()

flux_path = Path(args.flux_path)
diffusers_path = Path(args.diffusers_path)

# Accept either the diffusers checkpoint root (which keeps the transformer in
# a "transformer" subfolder) or the transformer folder itself.
if (diffusers_path / "transformer").exists():
    diffusers_path = diffusers_path / "transformer"

do_8_bit = args.do_8_bit
do_8bit_scaled = args.do_8bit_scaled

if do_8_bit and do_8bit_scaled:
    print("Error: Cannot use both --do_8_bit and --do_8bit_scaled at the same time.")
    exit()

flux_path.parent.mkdir(parents=True, exist_ok=True)

if not diffusers_path.exists():
    print(f"Error: Missing transformer folder: {diffusers_path}")
    exit()

original_json_path = diffusers_path / "diffusion_pytorch_model.safetensors.index.json"

if not original_json_path.exists():
    print(f"Error: Missing transformer index json: {original_json_path}")
    exit()

with open(original_json_path, "r", encoding="utf-8") as f:
    original_json = json.load(f)

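# original_json["weight_map"] maps each tensor name to the shard file that
# stores it, e.g. "x_embedder.weight" ->
# "diffusion_pytorch_model-00001-of-00003.safetensors" (shard name illustrative).
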
# Mapping from original Flux tensor names to the diffusers tensor names they
# are built from. "()" is a placeholder for the block index; entries with
# several sources are concatenated along dim 0 (the fused qkv/mlp inputs).
diffusers_map = {
    "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
    "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
    "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
    "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
    "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
    "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
    "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
    "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
    "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
    "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
    "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
    "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
    "txt_in.weight": ["context_embedder.weight"],
    "txt_in.bias": ["context_embedder.bias"],
    "img_in.weight": ["x_embedder.weight"],
    "img_in.bias": ["x_embedder.bias"],
    "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
    "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
    "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
    "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
    "double_blocks.().img_attn.qkv.weight": [
        "attn.to_q.weight",
        "attn.to_k.weight",
        "attn.to_v.weight",
    ],
    "double_blocks.().img_attn.qkv.bias": [
        "attn.to_q.bias",
        "attn.to_k.bias",
        "attn.to_v.bias",
    ],
    "double_blocks.().txt_attn.qkv.weight": [
        "attn.add_q_proj.weight",
        "attn.add_k_proj.weight",
        "attn.add_v_proj.weight",
    ],
    "double_blocks.().txt_attn.qkv.bias": [
        "attn.add_q_proj.bias",
        "attn.add_k_proj.bias",
        "attn.add_v_proj.bias",
    ],
    "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
    "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
    "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
    "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
    "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
    "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
    "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
    "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
    "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
    "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
    "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
    "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
    "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
    "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
    "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
    "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
    "single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
    "single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
    "single_blocks.().linear1.weight": [
        "attn.to_q.weight",
        "attn.to_k.weight",
        "attn.to_v.weight",
        "proj_mlp.weight",
    ],
    "single_blocks.().linear1.bias": [
        "attn.to_q.bias",
        "attn.to_k.bias",
        "attn.to_v.bias",
        "proj_mlp.bias",
    ],
    "single_blocks.().linear2.weight": ["proj_out.weight"],
    "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
    "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
    "single_blocks.().linear2.bias": ["proj_out.bias"],
    "final_layer.linear.weight": ["proj_out.weight"],
    "final_layer.linear.bias": ["proj_out.bias"],
    "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
    "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}

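# For example, with block index 3 the key "double_blocks.().img_attn.qkv.weight"
# becomes "double_blocks.3.img_attn.qkv.weight", built by concatenating
# "transformer_blocks.3.attn.to_{q,k,v}.weight" along dim 0.

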
def is_in_diffusers_map(k):
    """Return True if diffusers tensor name k feeds some original Flux tensor."""
    for values in diffusers_map.values():
        for value in values:
            if k.endswith(value):
                return True
    return False


# Map each needed diffusers tensor name to the safetensors shard that holds it.
diffusers = {k: diffusers_path / v
             for k, v in original_json["weight_map"].items()
             if is_in_diffusers_map(k)}

original_safetensors = set(diffusers.values())

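# The block counts are inferred from the tensor names rather than read from
# config.json; for reference, FLUX.1 dev and schnell have 19 double blocks
# and 38 single blocks.
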
# Count the double (transformer_blocks) and single (single_transformer_blocks)
# blocks present in the checkpoint.
transformer_blocks = 0
single_transformer_blocks = 0
for key in diffusers.keys():
    if key.startswith("transformer_blocks."):
        block = int(key.split(".")[1])
        if block >= transformer_blocks:
            transformer_blocks = block + 1
    elif key.startswith("single_transformer_blocks."):
        block = int(key.split(".")[1])
        if block >= single_transformer_blocks:
            single_transformer_blocks = block + 1

print(f"Transformer blocks: {transformer_blocks}")
print(f"Single transformer blocks: {single_transformer_blocks}")

for file in original_safetensors:
    if not file.exists():
        print(f"Error: Missing transformer safetensors file: {file}")
        exit()

# Open each shard lazily; tensors are only read when requested.
original_safetensors = {f: safetensors.safe_open(f, framework="pt", device="cpu")
                        for f in original_safetensors}


def swap_scale_shift(weight):
    """The shift and scale halves of the AdaLN parameters are stored in
    opposite order in the diffusers and original Flux layouts, so swap them."""
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)


flux_values = {}

# Resolve the "()" placeholders: for every block index, record which diffusers
# tensors feed each original Flux tensor, skipping entries whose sources are
# missing from the checkpoint.
for b in range(transformer_blocks):
    for key, weights in diffusers_map.items():
        if key.startswith("double_blocks."):
            block_prefix = f"transformer_blocks.{b}."
            if all(f"{block_prefix}{weight}" in diffusers for weight in weights):
                flux_values[key.replace("()", str(b))] = [
                    f"{block_prefix}{weight}" for weight in weights]

for b in range(single_transformer_blocks):
    for key, weights in diffusers_map.items():
        if key.startswith("single_blocks."):
            block_prefix = f"single_transformer_blocks.{b}."
            if all(f"{block_prefix}{weight}" in diffusers for weight in weights):
                flux_values[key.replace("()", str(b))] = [
                    f"{block_prefix}{weight}" for weight in weights]

for key, weights in diffusers_map.items():
    if not key.startswith(("double_blocks.", "single_blocks.")):
        if all(weight in diffusers for weight in weights):
            flux_values[key] = list(weights)

flux = {}

# Read each tensor, concatenating multi-source entries along dim 0.
for key, values in tqdm.tqdm(flux_values.items()):
    if len(values) == 1:
        flux[key] = original_safetensors[diffusers[values[0]]].get_tensor(values[0]).to("cpu")
    else:
        flux[key] = torch.cat(
            [original_safetensors[diffusers[value]].get_tensor(value).to("cpu")
             for value in values])

# The final AdaLN modulation halves come in the opposite order in diffusers.
if "norm_out.linear.weight" in diffusers:
    flux["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(
        original_safetensors[diffusers["norm_out.linear.weight"]]
        .get_tensor("norm_out.linear.weight").to("cpu"))
if "norm_out.linear.bias" in diffusers:
    flux["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(
        original_safetensors[diffusers["norm_out.linear.bias"]]
        .get_tensor("norm_out.linear.bias").to("cpu"))


def stochastic_round_to(tensor, dtype=torch.float8_e4m3fn):
    """Stochastically round to multiples of 1/256, then cast to dtype: the
    upper grid neighbour is chosen with probability proportional to the
    distance from the lower one, so the grid rounding is unbiased in
    expectation."""
    min_val = torch.finfo(dtype).min
    max_val = torch.finfo(dtype).max

    # Clamp to the representable range of the target dtype.
    tensor = torch.clamp(tensor, min_val, max_val)
    tensor = tensor.float()

    # Nearest grid points below and above each value.
    lower = torch.floor(tensor * 256) / 256
    upper = torch.ceil(tensor * 256) / 256

    # Probability of rounding up; clamp the span to avoid 0/0 when a value
    # already sits exactly on the grid (upper == lower).
    prob = (tensor - lower) / torch.clamp(upper - lower, min=1e-12)

    rand = torch.rand_like(tensor)
    rounded = torch.where(rand < prob, upper, lower)
    return rounded.to(dtype)


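# Worked example: 0.3 lies between 76/256 = 0.296875 and 77/256 = 0.30078125;
# it rounds up with probability (0.3 - 0.296875) * 256 = 0.8, so on average
# the rounded value equals the input.
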
# Tensors kept unscaled in the scaled-8-bit path: anything that is not a
# ".weight" tensor, plus anything with "embed" in its name.
blacklist = []
for key in flux.keys():
    if not key.endswith(".weight") or "embed" in key:
        blacklist.append(key)


def scale_weights_to_8bit(tensor, max_value=416.0, dtype=torch.float8_e4m3fn):
    """Rescale a 2-D weight so its absolute maximum maps to max_value (a bit
    below float8_e4m3fn's limit of 448, leaving headroom), cast to dtype, and
    return the tensor together with its dequantization scale."""
    min_val = torch.finfo(dtype).min
    max_val = torch.finfo(dtype).max

    if tensor.dim() == 2:
        abs_max = torch.max(torch.abs(tensor))
        scale = abs_max / max_value
        scaled_tensor = (tensor / scale).clip(min=min_val, max=max_val).to(dtype)
        return scaled_tensor, scale
    else:
        # Non-matrix tensors are cast directly, with no scale.
        return tensor.clip(min=min_val, max=max_val).to(dtype), None


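# E.g. a weight with absolute maximum 0.13 is divided by 0.13 / 416, so its
# largest entry lands at ±416; a loader multiplies the fp8 tensor by the
# stored ".scale_weight" to dequantize.
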
if do_8_bit:
    print("Converting to 8-bit with stochastic rounding...")
    for key in flux.keys():
        flux[key] = stochastic_round_to(flux[key], torch.float8_e4m3fn).to("cpu")
elif do_8bit_scaled:
    print("Converting to scaled 8-bit...")
    scales = {}
    for key in tqdm.tqdm(flux.keys()):
        if key.endswith(".weight") and key not in blacklist:
            flux[key], scale = scale_weights_to_8bit(flux[key])
            if scale is not None:
                scale_key = key[:-len(".weight")] + ".scale_weight"
                scales[scale_key] = scale
        else:
            # Blacklisted tensors are cast to fp8 without scaling.
            min_val = torch.finfo(torch.float8_e4m3fn).min
            max_val = torch.finfo(torch.float8_e4m3fn).max
            flux[key] = flux[key].clip(min=min_val, max=max_val).to(torch.float8_e4m3fn).to("cpu")

    flux.update(scales)

    # Empty marker tensor flagging the checkpoint as scaled fp8 to loaders.
    flux["scaled_fp8"] = torch.tensor([]).to(torch.float8_e4m3fn)
else:
    print("Converting to bfloat16...")
    for key in flux.keys():
        flux[key] = flux[key].clone().to("cpu", torch.bfloat16)

meta = OrderedDict()
meta["format"] = "pt"
meta["modelspec.date"] = date.today().strftime("%Y-%m-%d")

print(f"Saving to {flux_path}")
safetensors.torch.save_file(flux, flux_path, metadata=meta)
print("Done.")