import torch

from modules.Utilities import util
from modules.NeuralNetwork import unet

LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
    "mlp.fc2": "mlp_fc2",
    "self_attn.k_proj": "self_attn_k_proj",
    "self_attn.q_proj": "self_attn_q_proj",
    "self_attn.v_proj": "self_attn_v_proj",
    "self_attn.out_proj": "self_attn_out_proj",
}
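
# Illustrative sketch: how LORA_CLIP_MAP combines with a layer index to build a
# kohya-style LoRA key for the CLIP text encoder. The default arguments are
# arbitrary example values chosen only for illustration.
def _example_clip_lora_key(layer_index: int = 0, submodule: str = "self_attn.q_proj") -> str:
    # e.g. "lora_te_text_model_encoder_layers_0_self_attn_q_proj"
    return "lora_te_text_model_encoder_layers_{}_{}".format(
        layer_index, LORA_CLIP_MAP[submodule]
    )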

def load_lora(lora: dict, to_load: dict) -> dict:
    """#### Load a LoRA model.

    #### Args:
        - `lora` (dict): The LoRA state dictionary.
        - `to_load` (dict): Mapping from LoRA key prefixes to model state-dict keys.

    #### Returns:
        - `dict`: The patch dictionary, mapping model weight keys to LoRA patches.
    """
    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
        # Optional per-key alpha scaling factor.
        alpha_name = "{}.alpha".format(x)
        alpha = None
        if alpha_name in lora.keys():
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        # Optional DoRA scale tensor.
        dora_scale_name = "{}.dora_scale".format(x)
        dora_scale = None
        if dora_scale_name in lora.keys():
            dora_scale = lora[dora_scale_name]
            loaded_keys.add(dora_scale_name)

        # Only the standard "lora_up"/"lora_down" naming convention is handled;
        # alternative layouts such as "{}_lora.up.weight",
        # "{}.lora_linear_layer.up.weight" and "{}.lora_mid.weight" are not.
        regular_lora = "{}.lora_up.weight".format(x)
        A_name = None
        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)

        if A_name is not None:
            mid = None
            patch_dict[to_load[x]] = (
                "lora",
                (lora[A_name], lora[B_name], alpha, mid, dora_scale),
            )
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
    return patch_dict
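
# Illustrative sketch of what `load_lora` produces, using a minimal hand-built
# LoRA state dict with a single entry. The key names and tensor shapes are
# arbitrary placeholders chosen only for the example.
def _example_load_lora() -> dict:
    lora_sd = {
        "lora_unet_foo.lora_up.weight": torch.zeros(4, 8),
        "lora_unet_foo.lora_down.weight": torch.zeros(8, 4),
        "lora_unet_foo.alpha": torch.tensor(8.0),
    }
    to_load = {"lora_unet_foo": "diffusion_model.foo.weight"}
    patches = load_lora(lora_sd, to_load)
    # patches == {"diffusion_model.foo.weight":
    #             ("lora", (up_weight, down_weight, 8.0, None, None))}
    return patches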

def model_lora_keys_clip(model: torch.nn.Module, key_map: dict = None) -> dict:
    """#### Get the keys for a LoRA model's CLIP component.

    #### Args:
        - `model` (torch.nn.Module): The CLIP text-encoder model.
        - `key_map` (dict, optional): An existing key map to extend. Defaults to a new dict.

    #### Returns:
        - `dict`: The key map for the CLIP component, mapping LoRA key names to model state-dict keys.
    """
    if key_map is None:
        key_map = {}
    sdk = model.state_dict().keys()

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    for b in range(32):
        for c in LORA_CLIP_MAP:
            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(
                    b, LORA_CLIP_MAP[c]
                )  # SDXL base
                key_map[lora_key] = k
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(
                    b, c
                )  # diffusers lora
                key_map[lora_key] = k
    return key_map
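
# For each CLIP weight present in the model, the mapping above registers three
# naming variants that all point at the same state-dict key, e.g. for layer 0,
# q_proj:
#   "lora_te_text_model_encoder_layers_0_self_attn_q_proj"        (default)
#   "lora_te1_text_model_encoder_layers_0_self_attn_q_proj"       (SDXL base)
#   "text_encoder.text_model.encoder.layers.0.self_attn.q_proj"   (diffusers)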

def model_lora_keys_unet(model: torch.nn.Module, key_map: dict = None) -> dict:
    """#### Get the keys for a LoRA model's UNet component.

    #### Args:
        - `model` (torch.nn.Module): The UNet (diffusion) model.
        - `key_map` (dict, optional): An existing key map to extend. Defaults to a new dict.

    #### Returns:
        - `dict`: The key map for the UNet component, mapping LoRA key names to model state-dict keys.
    """
    if key_map is None:
        key_map = {}
    sdk = model.state_dict().keys()

    # kohya-style keys derived directly from the model's state-dict names.
    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model.") : -len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map["lora_prior_unet_{}".format(key_lora)] = k  # cascade lora

    # diffusers-style keys, translated through the UNet-to-diffusers key mapping.
    diffusers_keys = unet.unet_to_diffusers(model.model_config.unet_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[: -len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key

            diffusers_lora_prefix = ["", "unet."]
            for p in diffusers_lora_prefix:
                diffusers_lora_key = "{}{}".format(
                    p, k[: -len(".weight")].replace(".to_", ".processor.to_")
                )
                if diffusers_lora_key.endswith(".to_out.0"):
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key
    return key_map
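
# Illustrative sketch: the kohya-style UNet key derived from a model state-dict
# name, mirroring the first loop above. The default argument is an arbitrary
# example key.
def _example_unet_lora_key(
    state_dict_key: str = "diffusion_model.middle_block.1.proj_in.weight",
) -> str:
    key_lora = state_dict_key[len("diffusion_model.") : -len(".weight")].replace(".", "_")
    return "lora_unet_{}".format(key_lora)  # -> "lora_unet_middle_block_1_proj_in"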

def load_lora_for_models(
    model: object, clip: object, lora: dict, strength_model: float, strength_clip: float
) -> tuple:
    """#### Load a LoRA model for the given models.

    #### Args:
        - `model` (object): The model patcher for the diffusion model.
        - `clip` (object): The CLIP model.
        - `lora` (dict): The LoRA state dictionary.
        - `strength_model` (float): The LoRA strength applied to the diffusion model.
        - `strength_clip` (float): The LoRA strength applied to the CLIP model.

    #### Returns:
        - `tuple`: The new model patcher and CLIP model.
    """
    key_map = {}
    if model is not None:
        key_map = model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)

    loaded = load_lora(lora, key_map)
    new_modelpatcher = model.clone()
    new_modelpatcher.add_patches(loaded, strength_model)
    new_clip = clip.clone()
    new_clip.add_patches(loaded, strength_clip)
    return (new_modelpatcher, new_clip)
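
# Note on strengths: `strength_model` and `strength_clip` are passed straight to
# `add_patches`, which conventionally scales the LoRA delta, so 0.0 leaves that
# model unchanged and 1.0 applies the LoRA at full strength; the exact semantics
# are defined by the patcher's `add_patches` implementation.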

class LoraLoader:
    """#### Class for loading LoRA models."""

    def __init__(self):
        """#### Initialize the LoraLoader class."""
        # Cache of the most recently loaded LoRA as (path, state_dict).
        self.loaded_lora = None

    def load_lora(
        self,
        model: object,
        clip: object,
        lora_name: str,
        strength_model: float,
        strength_clip: float,
    ) -> tuple:
        """#### Load a LoRA model and apply it to the given models.

        #### Args:
            - `model` (object): The model patcher for the diffusion model.
            - `clip` (object): The CLIP model.
            - `lora_name` (str): The filename of the LoRA model in the `loras` directory.
            - `strength_model` (float): The LoRA strength applied to the diffusion model.
            - `strength_clip` (float): The LoRA strength applied to the CLIP model.

        #### Returns:
            - `tuple`: The new model patcher and CLIP model.
        """
        lora_path = util.get_full_path("loras", lora_name)
        lora = None
        # Reuse the cached state dict if the same LoRA file was loaded last time.
        if self.loaded_lora is not None and self.loaded_lora[0] == lora_path:
            lora = self.loaded_lora[1]
        if lora is None:
            lora = util.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        model_lora, clip_lora = load_lora_for_models(
            model, clip, lora, strength_model, strength_clip
        )
        return (model_lora, clip_lora)
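
# Illustrative usage sketch, assuming `model` and `clip` are the patcher objects
# normally produced by the checkpoint loader and that "my_lora.safetensors"
# exists in the configured `loras` directory (both are placeholders).
def _example_apply_lora(model: object, clip: object) -> tuple:
    loader = LoraLoader()
    return loader.load_lora(
        model, clip, "my_lora.safetensors", strength_model=1.0, strength_clip=1.0
    )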