Aatricks's picture
Upload folder using huggingface_hub
d9a2e19 verified
import torch
from modules.Utilities import util
from modules.NeuralNetwork import unet
# Maps a CLIP encoder-layer submodule path (dotted, as it appears in state-dict
# keys) to the flattened name used in LoRA keys. The flattened name is the
# dotted path with every "." replaced by "_".
LORA_CLIP_MAP = {
    module_path: module_path.replace(".", "_")
    for module_path in (
        "mlp.fc1",
        "mlp.fc2",
        "self_attn.k_proj",
        "self_attn.q_proj",
        "self_attn.v_proj",
        "self_attn.out_proj",
    )
}
def load_lora(lora: dict, to_load: dict) -> dict:
    """#### Load a LoRA model.

    For every LoRA key prefix in `to_load`, looks for `<prefix>.lora_up.weight`
    in `lora`. When present, a patch entry is recorded under the mapped model
    key; prefixes without an up weight are skipped.

    #### Args:
        - `lora` (dict): The LoRA model state dictionary.
        - `to_load` (dict): Maps LoRA key prefixes to the model keys they patch.

    #### Returns:
        - `dict`: Maps each model key to
          `("lora", (up, down, alpha, None, None))`. `alpha` is
          `lora["<prefix>.alpha"].item()` when that entry exists, else `None`.
          The two trailing slots (mid weight and DoRA scale) are always `None`.

    #### Raises:
        - `KeyError`: If an up weight is present but the matching
          `<prefix>.lora_down.weight` is missing.
    """
    patch_dict = {}
    for prefix, model_key in to_load.items():
        up_name = f"{prefix}.lora_up.weight"
        if up_name not in lora:
            continue

        down_name = f"{prefix}.lora_down.weight"
        alpha_name = f"{prefix}.alpha"
        alpha = lora[alpha_name].item() if alpha_name in lora else None

        patch_dict[model_key] = (
            "lora",
            (lora[up_name], lora[down_name], alpha, None, None),
        )
    return patch_dict
def model_lora_keys_clip(model: torch.nn.Module, key_map: dict = None) -> dict:
    """#### Get the keys for a LoRA model's CLIP component.

    Probes encoder layer indices 0..31 and every submodule in `LORA_CLIP_MAP`.
    For each `clip_l.transformer.text_model.encoder.layers.<i>.<sub>.weight`
    key found in the model's state dict, three LoRA key spellings are mapped
    to it.

    #### Args:
        - `model` (torch.nn.Module): The model whose `state_dict()` keys are inspected.
        - `key_map` (dict, optional): An existing map to extend in place.
          Defaults to None, in which case a new dict is created for this call.

    #### Returns:
        - `dict`: `key_map`, mapping LoRA key names to state-dict keys.
    """
    # A `{}` default would be created once and shared by every call that
    # omits `key_map`, leaking entries between unrelated models.
    if key_map is None:
        key_map = {}

    sdk = model.state_dict().keys()
    for b in range(32):
        for c, lora_name in LORA_CLIP_MAP.items():
            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k not in sdk:
                continue
            key_map["lora_te_text_model_encoder_layers_{}_{}".format(b, lora_name)] = k
            # SDXL base
            key_map["lora_te1_text_model_encoder_layers_{}_{}".format(b, lora_name)] = k
            # diffusers lora
            key_map["text_encoder.text_model.encoder.layers.{}.{}".format(b, c)] = k
    return key_map
def model_lora_keys_unet(model: torch.nn.Module, key_map: dict = None) -> dict:
    """#### Get the keys for a LoRA model's UNet component.

    Two families of LoRA key names are mapped to state-dict keys:
    names derived directly from `diffusion_model.*.weight` state-dict keys, and
    names derived from the mapping returned by
    `unet.unet_to_diffusers(model.model_config.unet_config)`.

    #### Args:
        - `model` (torch.nn.Module): The model; its `state_dict()` and
          `model_config.unet_config` are read.
        - `key_map` (dict, optional): An existing map to extend in place.
          Defaults to None, in which case a new dict is created for this call.

    #### Returns:
        - `dict`: `key_map`, mapping LoRA key names to state-dict keys.
    """
    # A `{}` default would be created once and shared by every call that
    # omits `key_map`, leaking entries between unrelated models.
    if key_map is None:
        key_map = {}

    model_prefix = "diffusion_model."
    weight_suffix = ".weight"

    for k in model.state_dict().keys():
        if k.startswith(model_prefix) and k.endswith(weight_suffix):
            key_lora = k[len(model_prefix) : -len(weight_suffix)].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map["lora_prior_unet_{}".format(key_lora)] = k  # cascade lora

    diffusers_keys = unet.unet_to_diffusers(model.model_config.unet_config)
    diffusers_lora_prefix = ("", "unet.")
    for k, mapped in diffusers_keys.items():
        if not k.endswith(weight_suffix):
            continue
        unet_key = "diffusion_model.{}".format(mapped)
        stem = k[: -len(weight_suffix)]
        key_map["lora_unet_{}".format(stem.replace(".", "_"))] = unet_key

        processor_stem = stem.replace(".to_", ".processor.to_")
        for p in diffusers_lora_prefix:
            diffusers_lora_key = "{}{}".format(p, processor_stem)
            if diffusers_lora_key.endswith(".to_out.0"):
                # Drop the trailing ".0" so the key ends in ".to_out".
                diffusers_lora_key = diffusers_lora_key[:-2]
            key_map[diffusers_lora_key] = unet_key
    return key_map
def load_lora_for_models(
    model: object, clip: object, lora: dict, strength_model: float, strength_clip: float
) -> tuple:
    """#### Load a LoRA model for the given models.

    Builds the LoRA-key to state-dict-key map for whichever of `model` and
    `clip` is provided, extracts the matching patches from `lora`, and applies
    them to clones so the inputs themselves are not patched.

    #### Args:
        - `model` (object): The model patcher, or None to skip the UNet.
        - `clip` (object): The CLIP model, or None to skip the text encoder.
        - `lora` (dict): The LoRA model state dictionary.
        - `strength_model` (float): The strength passed to the model's `add_patches`.
        - `strength_clip` (float): The strength passed to the CLIP model's `add_patches`.

    #### Returns:
        - `tuple`: `(new_modelpatcher, new_clip)`; an element is None when the
          corresponding input was None.
    """
    key_map = {}
    if model is not None:
        key_map = model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)

    loaded = load_lora(lora, key_map)

    # Apply the same None guards used for key mapping above; cloning
    # unconditionally would raise AttributeError on a missing component.
    new_modelpatcher = None
    if model is not None:
        new_modelpatcher = model.clone()
        new_modelpatcher.add_patches(loaded, strength_model)

    new_clip = None
    if clip is not None:
        new_clip = clip.clone()
        new_clip.add_patches(loaded, strength_clip)

    return (new_modelpatcher, new_clip)
class LoraLoader:
    """#### Class for loading LoRA models.

    Keeps the most recently loaded LoRA state dict so that repeated requests
    for the same file path do not reload it from disk.
    """

    def __init__(self):
        """#### Initialize the LoraLoader class."""
        # `(lora_path, state_dict)` of the last file loaded, or None.
        self.loaded_lora = None

    def load_lora(
        self,
        model: object,
        clip: object,
        lora_name: str,
        strength_model: float,
        strength_clip: float,
    ) -> tuple:
        """#### Load a LoRA model.

        #### Args:
            - `model` (object): The model.
            - `clip` (object): The CLIP model.
            - `lora_name` (str): The name of the LoRA model, resolved via
              `util.get_full_path("loras", lora_name)`.
            - `strength_model` (float): The strength of the model.
            - `strength_clip` (float): The strength of the CLIP model.

        #### Returns:
            - `tuple`: The new model patcher and CLIP model.
        """
        lora_path = util.get_full_path("loras", lora_name)

        cached = self.loaded_lora
        if cached is not None and cached[0] == lora_path:
            # Same file as last time: reuse the state dict already in memory.
            lora = cached[1]
        else:
            # Drop the reference to any previously cached state dict before
            # loading a different file.
            cached = None
            self.loaded_lora = None
            lora = util.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        model_lora, clip_lora = load_lora_for_models(
            model, clip, lora, strength_model, strength_clip
        )
        return (model_lora, clip_lora)