Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	File size: 1,974 Bytes
			
			| fcc02a2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | from typing import TYPE_CHECKING
from toolkit.config_modules import NetworkConfig
from toolkit.lora_special import LoRASpecialNetwork
from safetensors.torch import load_file
if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
def load_assistant_lora_from_path(adapter_path, sd: 'StableDiffusion') -> LoRASpecialNetwork:
    """Load an assistant LoRA from a safetensors file and attach it to a Flux model.

    The LoRA rank is taken from the first dimension of one known ``lora_A``
    weight in the file; alpha is set equal to that rank rather than being read
    from the file. The network is applied to the pipeline's transformer only
    (the text encoder is passed through but neither applied to nor trained).

    Args:
        adapter_path: Location of the adapter's safetensors file. It is handed
            directly to ``safetensors.torch.load_file``.
        sd: Model wrapper. Must have ``is_flux`` set and provide ``pipeline``,
            ``device_torch`` and ``torch_dtype``.

    Returns:
        The constructed ``LoRASpecialNetwork`` with the file's weights loaded,
        switched to eval mode and marked active.

    Raises:
        ValueError: If ``sd`` is not a Flux model.
        KeyError: If the weight used to infer the LoRA rank is not in the file.
    """
    if not sd.is_flux:
        raise ValueError("Only Flux models can load assistant adapters currently.")

    pipe = sd.pipeline
    print(f"Loading assistant adapter from {adapter_path}")

    lora_state_dict = load_file(adapter_path)

    # The rank is inferred from a single representative weight in the file.
    rank_probe_key = 'transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight'
    try:
        rank_probe = lora_state_dict[rank_probe_key]
    except KeyError as err:
        raise KeyError(
            f"Cannot infer LoRA rank: '{rank_probe_key}' not found in {adapter_path}"
        ) from err
    linear_dim = int(rank_probe.shape[0])

    # Alpha is deliberately tied to the rank instead of being read from the file.
    linear_alpha = linear_dim

    # Without an alpha entry for proj_out the adapter is treated as transformer-only.
    transformer_only = 'transformer.proj_out.alpha' not in lora_state_dict

    network_config = NetworkConfig(
        linear=linear_dim,
        linear_alpha=linear_alpha,
        transformer_only=transformer_only,
    )

    network = LoRASpecialNetwork(
        text_encoder=pipe.text_encoder,
        unet=pipe.transformer,
        lora_dim=network_config.linear,
        multiplier=1.0,
        alpha=network_config.linear_alpha,
        train_unet=True,
        train_text_encoder=False,
        is_flux=True,
        network_config=network_config,
        network_type=network_config.type,
        transformer_only=network_config.transformer_only,
        is_assistant_adapter=True
    )

    # Only the transformer receives LoRA modules; the text encoder is skipped.
    network.apply_to(
        pipe.text_encoder,
        pipe.transformer,
        apply_text_encoder=False,
        apply_unet=True
    )

    # Setup order is kept as in the original: place, eval, refresh multiplier,
    # then load the weights and activate.
    network.force_to(sd.device_torch, dtype=sd.torch_dtype)
    network.eval()
    network._update_torch_multiplier()
    network.load_weights(lora_state_dict)
    network.is_active = True

    return network
 | 
