Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from modules.Utilities import util | |
| from modules.NeuralNetwork import unet | |
# CLIP text-encoder sub-module paths (dotted, as they appear in state-dict
# keys) mapped to the underscore-joined fragment used when composing
# `lora_te_…` / `lora_te1_…` LoRA key names. Each value is the key with every
# "." replaced by "_"; insertion order is preserved.
LORA_CLIP_MAP = {
    module_path: module_path.replace(".", "_")
    for module_path in (
        "mlp.fc1",
        "mlp.fc2",
        "self_attn.k_proj",
        "self_attn.q_proj",
        "self_attn.v_proj",
        "self_attn.out_proj",
    )
}
def load_lora(lora: dict, to_load: dict) -> dict:
    """#### Load a LoRA model.

    Builds a patch dictionary from a LoRA state dict. For each LoRA key prefix
    in `to_load`, the up/down weight pair is looked up under the first of these
    naming schemes whose up-weight is present:

    - `{prefix}.lora_up.weight` / `{prefix}.lora_down.weight`
    - `{prefix}_lora.up.weight` / `{prefix}_lora.down.weight`
    - `{prefix}.lora_linear_layer.up.weight` / `{prefix}.lora_linear_layer.down.weight`

    Prefixes with no matching up-weight are skipped.

    #### Args:
        - `lora` (dict): The LoRA model state dictionary.
        - `to_load` (dict): Maps LoRA key prefixes to the model keys they patch.

    #### Returns:
        - `dict`: Maps each patched model key to
          `("lora", (up, down, alpha, mid, dora_scale))`, where `alpha` is the
          Python scalar from `{prefix}.alpha` (or `None` if absent). `mid` and
          `dora_scale` are always `None` here.

    #### Raises:
        - `KeyError`: If an up-weight is present but its matching down-weight is not.
    """
    # (up suffix, down suffix) pairs, tried in order. The first is the
    # "lora_up"/"lora_down" layout; the other two are alternate layouts,
    # presumably used by diffusers-format files whose prefixes
    # model_lora_keys_unet / model_lora_keys_clip register.
    naming_schemes = (
        (".lora_up.weight", ".lora_down.weight"),
        ("_lora.up.weight", "_lora.down.weight"),
        (".lora_linear_layer.up.weight", ".lora_linear_layer.down.weight"),
    )

    patch_dict = {}
    for prefix, model_key in to_load.items():
        alpha_name = "{}.alpha".format(prefix)
        alpha = lora[alpha_name].item() if alpha_name in lora else None

        for up_suffix, down_suffix in naming_schemes:
            up_name = prefix + up_suffix
            if up_name not in lora:
                continue
            # NOTE(review): `lora_mid` / `dora_scale` tensors are not forwarded;
            # their slots in the patch tuple stay None.
            patch_dict[model_key] = (
                "lora",
                (lora[up_name], lora[prefix + down_suffix], alpha, None, None),
            )
            break
    return patch_dict
def model_lora_keys_clip(model: torch.nn.Module, key_map: "dict | None" = None) -> dict:
    """#### Get the keys for a LoRA model's CLIP component.

    For each of the first 32 encoder layers and each sub-module in
    `LORA_CLIP_MAP`, if `clip_l.transformer.text_model.encoder.layers.{b}.{c}.weight`
    exists in the model's state dict, three LoRA key spellings are mapped to it:
    `lora_te_…`, `lora_te1_…` (SDXL base) and `text_encoder.text_model.…`
    (diffusers).

    #### Args:
        - `model` (torch.nn.Module): The text-encoder model whose state dict is scanned.
        - `key_map` (dict, optional): Mapping to extend in place. Defaults to a
          new empty dict on every call.

    #### Returns:
        - `dict`: `key_map`, mapping LoRA key names to model state-dict keys.
    """
    # A `{}` default would be shared between calls, and this function mutates
    # it, so a fresh dict is created per call instead.
    if key_map is None:
        key_map = {}
    sdk = model.state_dict().keys()
    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    for b in range(32):
        for c, lora_name in LORA_CLIP_MAP.items():
            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k not in sdk:
                continue
            key_map[text_model_lora_key.format(b, lora_name)] = k
            # SDXL base
            key_map["lora_te1_text_model_encoder_layers_{}_{}".format(b, lora_name)] = k
            # diffusers lora
            key_map["text_encoder.text_model.encoder.layers.{}.{}".format(b, c)] = k
    return key_map
def model_lora_keys_unet(model: torch.nn.Module, key_map: "dict | None" = None) -> dict:
    """#### Get the keys for a LoRA model's UNet component.

    Maps LoRA key names to `diffusion_model.*.weight` state-dict keys:

    - `lora_unet_{path}` and `lora_prior_unet_{path}` for every
      `diffusion_model.{path}.weight` key (dots in `path` become underscores).
    - `lora_unet_{diffusers_path}`, plus `{prefix}{diffusers_path}` with `.to_`
      rewritten to `.processor.to_` (prefix `""` or `"unet."`), for every
      `.weight` entry returned by `unet.unet_to_diffusers`.

    #### Args:
        - `model` (torch.nn.Module): The diffusion model; must expose
          `state_dict()` and `model_config.unet_config`.
        - `key_map` (dict, optional): Mapping to extend in place. Defaults to a
          new empty dict on every call.

    #### Returns:
        - `dict`: `key_map`, mapping LoRA key names to model state-dict keys.
    """
    # A `{}` default would be shared between calls, and this function mutates
    # it, so a fresh dict is created per call instead.
    if key_map is None:
        key_map = {}

    sdk = model.state_dict().keys()
    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model.") : -len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map["lora_prior_unet_{}".format(key_lora)] = k  # cascade lora

    diffusers_keys = unet.unet_to_diffusers(model.model_config.unet_config)
    diffusers_lora_prefix = ("", "unet.")
    for k in diffusers_keys:
        if not k.endswith(".weight"):
            continue
        unet_key = "diffusion_model.{}".format(diffusers_keys[k])
        base = k[: -len(".weight")]
        key_map["lora_unet_{}".format(base.replace(".", "_"))] = unet_key
        for p in diffusers_lora_prefix:
            diffusers_lora_key = "{}{}".format(p, base.replace(".to_", ".processor.to_"))
            # `to_out.0` is addressed as `to_out` in this key form.
            if diffusers_lora_key.endswith(".to_out.0"):
                diffusers_lora_key = diffusers_lora_key[:-2]
            key_map[diffusers_lora_key] = unet_key
    return key_map
def load_lora_for_models(
    model: object, clip: object, lora: dict, strength_model: float, strength_clip: float
) -> tuple:
    """#### Load a LoRA model for the given models.

    Builds the LoRA key map from whichever of `model` / `clip` is provided,
    extracts the matching patches from `lora`, and applies them to clones of
    the given objects (the objects passed in are not patched directly).

    #### Args:
        - `model` (object): The model patcher, or `None` to skip the model.
        - `clip` (object): The CLIP model, or `None` to skip CLIP.
        - `lora` (dict): The LoRA model state dictionary.
        - `strength_model` (float): The strength of the model.
        - `strength_clip` (float): The strength of the CLIP model.

    #### Returns:
        - `tuple`: `(new_modelpatcher, new_clip)`; an entry is `None` when the
          corresponding input was `None`.
    """
    key_map = {}
    if model is not None:
        key_map = model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
    loaded = load_lora(lora, key_map)

    # Only clone/patch components that were supplied; calling .clone() on None
    # would raise AttributeError.
    new_modelpatcher = None
    if model is not None:
        new_modelpatcher = model.clone()
        new_modelpatcher.add_patches(loaded, strength_model)

    new_clip = None
    if clip is not None:
        new_clip = clip.clone()
        new_clip.add_patches(loaded, strength_clip)

    return (new_modelpatcher, new_clip)
class LoraLoader:
    """#### Class for loading LoRA models.

    Keeps the most recently loaded LoRA state dict so that repeated calls with
    the same file reuse it instead of reading the file again.
    """

    def __init__(self):
        """#### Initialize the LoraLoader class."""
        # (resolved path, state dict) of the last LoRA loaded, or None.
        self.loaded_lora = None

    def load_lora(
        self,
        model: object,
        clip: object,
        lora_name: str,
        strength_model: float,
        strength_clip: float,
    ) -> tuple:
        """#### Load a LoRA model.

        #### Args:
            - `model` (object): The model.
            - `clip` (object): The CLIP model.
            - `lora_name` (str): The name of the LoRA model.
            - `strength_model` (float): The strength of the model.
            - `strength_clip` (float): The strength of the CLIP model.

        #### Returns:
            - `tuple`: The new model patcher and CLIP model.
        """
        lora_path = util.get_full_path("loras", lora_name)

        lora = None
        if self.loaded_lora is not None:
            cached_path, cached_lora = self.loaded_lora
            if cached_path == lora_path:
                lora = cached_lora
            else:
                # Release the previous state dict before loading a different file.
                self.loaded_lora = None

        if lora is None:
            lora = util.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        model_lora, clip_lora = load_lora_for_models(
            model, clip, lora, strength_model, strength_clip
        )
        return (model_lora, clip_lora)