Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,562 Bytes
c1bc1cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
# extract approximating LoRA by svd from two FLUX models
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo!
import argparse
import json
import os
import time
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
from library import flux_utils, sai_model_spec, model_util, sdxl_model_util
import lora
from library.utils import MemoryEfficientSafeOpen
from library.utils import setup_logging
from networks import lora_flux
setup_logging()
import logging
logger = logging.getLogger(__name__)
# CLAMP_QUANTILE = 0.99
# MIN_DIFF = 1e-1
def save_to_file(file_name, state_dict, metadata, dtype):
    """Write ``state_dict`` to ``file_name`` with ``save_file``, optionally casting tensors.

    Args:
        file_name: Destination path handed to ``save_file``.
        state_dict: Mapping of names to values. When ``dtype`` is given, tensor
            entries are replaced in place by their cast copies, so the caller's
            dict is modified.
        metadata: Metadata mapping forwarded to ``save_file`` unchanged.
        dtype: Target ``torch.dtype`` for tensor entries, or ``None`` to save
            the tensors as they are.
    """
    if dtype is not None:
        # Snapshot the items so entries can be reassigned while looping.
        for key, value in list(state_dict.items()):
            # isinstance (rather than an exact type comparison) also covers
            # torch.Tensor subclasses.
            if isinstance(value, torch.Tensor):
                state_dict[key] = value.to(dtype)

    save_file(state_dict, file_name, metadata=metadata)
def svd(
    model_org=None,
    model_tuned=None,
    save_to=None,
    dim=4,
    device=None,
    save_precision=None,
    clamp_quantile=0.99,
    min_diff=0.01,
    no_metadata=False,
    mem_eff_safe_open=False,
):
    """Extract an approximating LoRA from the difference of two model files.

    For every selected weight the difference ``tuned - original`` is computed
    in float32, factorized with a truncated SVD, clamped, and stored as a
    ``lora_up`` / ``lora_down`` pair. The result is written to ``save_to``.

    Only tensors whose key contains ``single_block`` or ``double_block`` and
    contains neither ``.bias`` nor ``norm`` are processed.

    Args:
        model_org: Path of the original model (safetensors).
        model_tuned: Path of the tuned model (safetensors).
        save_to: Path of the LoRA file to write.
        dim: Requested LoRA rank; per tensor it is capped by the tensor's
            first two dimensions.
        device: Device the difference is moved to before the SVD; a falsy
            value leaves it where it was loaded.
        save_precision: ``"float"``, ``"fp16"`` or ``"bf16"``. Any other value
            (including ``None``) keeps the float32 results as they are.
        clamp_quantile: Quantile of the combined factor values used as the
            symmetric clamping bound.
        min_diff: Accepted for interface compatibility; not used.
        no_metadata: When true, only the minimal ``ss_*`` metadata is saved
            and the sai modelspec metadata is skipped.
        mem_eff_safe_open: When true, read the files with
            ``MemoryEfficientSafeOpen`` instead of ``safetensors.safe_open``.

    Raises:
        ValueError: If ``dim`` is smaller than 1.
    """
    # Fail before opening the model files: a rank below 1 yields empty
    # factors, on which the quantile computation cannot work.
    if dim < 1:
        raise ValueError(f"dim must be a positive integer, got {dim}")

    calc_dtype = torch.float
    save_dtype = {
        "float": torch.float,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }.get(save_precision)
    store_device = "cpu"

    if mem_eff_safe_open:
        logger.info("Using memory efficient safe_open")

    def open_fn(path):
        # Both readers are used below as context managers exposing
        # keys() and get_tensor().
        if mem_eff_safe_open:
            return MemoryEfficientSafeOpen(path)
        return safe_open(path, framework="pt")

    lora_weights = {}
    with open_fn(model_org) as f_org, open_fn(model_tuned) as f_tuned:
        # filter keys
        keys = [
            key
            for key in f_org.keys()
            if ("single_block" in key or "double_block" in key)
            and ".bias" not in key
            and "norm" not in key
        ]

        for key in tqdm(keys):
            # get tensors and calculate difference
            value_o = f_org.get_tensor(key)
            value_t = f_tuned.get_tensor(key)
            mat = value_t.to(calc_dtype) - value_o.to(calc_dtype)
            # Drop the source tensors before the memory-hungry SVD.
            del value_o, value_t

            if device:
                mat = mat.to(device)

            # NOTE(review): assumes every selected tensor has at least two
            # dimensions (a linear weight) - confirm against the checkpoints.
            out_dim, in_dim = mat.size()[0:2]
            rank = min(dim, in_dim, out_dim)  # LoRA rank cannot exceed the original dim
            mat = mat.squeeze()

            up, down = _low_rank_factors(mat, rank, clamp_quantile)
            del mat

            lora_weights[key] = (
                up.to(store_device, dtype=save_dtype).contiguous(),
                down.to(store_device, dtype=save_dtype).contiguous(),
            )
            del up, down

    # make state dict for LoRA
    lora_sd = {}
    for key, (up_weight, down_weight) in lora_weights.items():
        module_name = key.replace(".weight", "").replace(".", "_")
        lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name

        lora_sd[lora_name + ".lora_up.weight"] = up_weight
        lora_sd[lora_name + ".lora_down.weight"] = down_weight
        lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])  # same as rank

    # minimum metadata
    net_kwargs = {}
    metadata = {
        "ss_v2": str(False),
        "ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1,
        "ss_network_module": "networks.lora_flux",
        "ss_network_dim": str(dim),
        "ss_network_alpha": str(float(dim)),
        "ss_network_args": json.dumps(net_kwargs),
    }

    if not no_metadata:
        title = os.path.splitext(os.path.basename(save_to))[0]
        sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev")
        metadata.update(sai_metadata)

    save_to_file(save_to, lora_sd, metadata, save_dtype)

    logger.info(f"LoRA weights saved to {save_to}")


def _low_rank_factors(mat, rank, clamp_quantile):
    """Return the clamped rank-``rank`` factors ``(up, down)`` of ``mat``.

    ``up`` holds the leading left singular vectors scaled by their singular
    values and ``down`` the leading right singular vectors, so ``up @ down``
    approximates ``mat``.
    """
    # Only the first `rank` singular triplets are kept, so the reduced SVD is
    # enough; the default full_matrices=True would also build the complete
    # square U and Vh just to discard most of them.
    U, S, Vh = torch.linalg.svd(mat, full_matrices=False)

    # Scaling the columns by broadcasting equals U @ diag(S) without
    # materializing the diagonal matrix.
    up = U[:, :rank] * S[:rank]
    down = Vh[:rank, :]

    # Clamp both factors symmetrically at the requested quantile of their
    # combined values.
    hi_val = torch.quantile(torch.cat([up.flatten(), down.flatten()]), clamp_quantile)
    low_val = -hi_val
    return up.clamp(low_val, hi_val), down.clamp(low_val, hi_val)
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the LoRA extraction script.

    The destination names of the arguments match the keyword parameters of
    ``svd`` so the parsed namespace can be passed to it as ``**vars(args)``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_precision",
        type=str,
        default=None,
        choices=[None, "float", "fp16", "bf16"],
        # When omitted, nothing is cast and the float32 results are saved.
        help="precision in saving, float if omitted / 保存時に精度を変更して保存する、省略時はfloat",
    )
    parser.add_argument(
        "--model_org",
        type=str,
        default=None,
        required=True,
        help="Original model: safetensors file / 元モデル、safetensors",
    )
    parser.add_argument(
        "--model_tuned",
        type=str,
        default=None,
        required=True,
        # Only safetensors files can be read by this script.
        help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、safetensors",
    )
    parser.add_argument(
        "--mem_eff_safe_open",
        action="store_true",
        help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough."
        " / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。",
    )
    parser.add_argument(
        "--save_to",
        type=str,
        default=None,
        required=True,
        help="destination file name: safetensors file / 保存先のファイル名、safetensors",
    )
    parser.add_argument(
        "--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)"
    )
    parser.add_argument(
        "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
    )
    parser.add_argument(
        "--clamp_quantile",
        type=float,
        default=0.99,
        help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
    )
    # NOTE: --min_diff is intentionally not exposed; svd() still accepts a
    # min_diff parameter but does not use it.
    parser.add_argument(
        "--no_metadata",
        action="store_true",
        help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
        + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
    )

    return parser
if __name__ == "__main__":
    # Script entry point: parse the command line and forward every option to
    # svd() as a keyword argument.
    cli_options = vars(setup_parser().parse_args())
    svd(**cli_options)
|