import argparse
import json
import os
from collections import OrderedDict

import torch
from safetensors.torch import load_file
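# Generates a LoRA keymap for SDXL LoCon: converts the base SDXL keymap to
# LoRA key naming, sanity-checks it by round-tripping two LoRA state dicts
# (a kohya LoCon file vs. ours), and writes the result to toolkit/keymaps.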
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
keymap_path = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', 'stable_diffusion_sdxl.json')
# load keymap
with open(keymap_path, 'r') as f:
    keymap = json.load(f)
lora_keymap = OrderedDict()
# convert keymap to lora key naming
for ldm_key, diffusers_key in keymap['ldm_diffusers_keymap'].items():
    if ldm_key.endswith('.bias') or diffusers_key.endswith('.bias'):
        # bias terms are not part of the LoRA mapping; skip them
        continue
    # SDXL uses the same text encoder (te) keys for LoCon in kohya's format
    # and ours, so skip conditioner keys
    if ldm_key.startswith('conditioner'):
        continue
    # ignore vae
    if ldm_key.startswith('first_stage_model'):
        continue
    ldm_key = ldm_key.replace('model.diffusion_model.', 'lora_unet_')
    ldm_key = ldm_key.replace('.weight', '')
    ldm_key = ldm_key.replace('.', '_')
    diffusers_key = diffusers_key.replace('unet_', 'lora_unet_')
    diffusers_key = diffusers_key.replace('.weight', '')
    diffusers_key = diffusers_key.replace('.', '_')
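    # every mapped module yields three LoRA tensors: alpha, lora_down, lora_up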
    lora_keymap[f"{ldm_key}.alpha"] = f"{diffusers_key}.alpha"
    lora_keymap[f"{ldm_key}.lora_down.weight"] = f"{diffusers_key}.lora_down.weight"
    lora_keymap[f"{ldm_key}.lora_up.weight"] = f"{diffusers_key}.lora_up.weight"
parser = argparse.ArgumentParser()
parser.add_argument("input", help="path to the kohya LoCon safetensors file")
parser.add_argument("input2", help="path to our LoRA safetensors file")
args = parser.parse_args()
# name = args.name
# if args.sdxl:
#     name += '_sdxl'
# elif args.sd2:
#     name += '_sd2'
# else:
#     name += '_sd1'
name = 'stable_diffusion_locon_sdxl'
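# load both state dicts so their key sets can be compared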
locon_save = load_file(args.input)
our_save = load_file(args.input2)
our_extra_keys = list(set(our_save.keys()) - set(locon_save.keys()))
locon_extra_keys = list(set(locon_save.keys()) - set(our_save.keys()))
print(f"we have {len(our_extra_keys)} extra keys")
print(f"locon has {len(locon_extra_keys)} extra keys")
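# tensors are cast to this dtype on CPU during conversion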
save_dtype = torch.float16
print(f"our extra keys: {our_extra_keys}")
print(f"locon extra keys: {locon_extra_keys}")
def export_state_dict(our_save):
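    """Convert our diffusers-style LoRA keys to ldm/kohya naming via lora_keymap."""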
    converted_state_dict = OrderedDict()
    for key, value in our_save.items():
        # text encoders share the same keys between the two formats; copy as-is
        if key.startswith('lora_te'):
            converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
        else:
            converted_key = key
            for ldm_key, diffusers_key in lora_keymap.items():
                if converted_key == diffusers_key:
                    converted_key = ldm_key
            converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
    return converted_state_dict
def import_state_dict(loaded_state_dict):
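    """Convert ldm/kohya LoRA keys back to our diffusers-style naming."""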
    converted_state_dict = OrderedDict()
    for key, value in loaded_state_dict.items():
        if key.startswith('lora_te'):
            converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
        else:
            converted_key = key
            for ldm_key, diffusers_key in lora_keymap.items():
                if converted_key == ldm_key:
                    converted_key = diffusers_key
            converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
    return converted_state_dict
# sanity check: convert our keys to kohya naming and diff against the locon file
converted_state_dict = export_state_dict(our_save)
converted_extra_keys = list(set(converted_state_dict.keys()) - set(locon_save.keys()))
locon_extra_keys = list(set(locon_save.keys()) - set(converted_state_dict.keys()))
print(f"we have {len(converted_extra_keys)} extra keys")
print(f"locon has {len(locon_extra_keys)} extra keys")
print(f"our extra keys: {converted_extra_keys}")
# round-trip: convert back to our naming and verify nothing was lost
cycle_state_dict = import_state_dict(converted_state_dict)
cycle_extra_keys = list(set(cycle_state_dict.keys()) - set(our_save.keys()))
our_extra_keys = list(set(our_save.keys()) - set(cycle_state_dict.keys()))
print(f"we have {len(our_extra_keys)} extra keys")
print(f"cycle has {len(cycle_extra_keys)} extra keys")
# save the generated keymap to toolkit/keymaps
to_save = OrderedDict()
to_save['ldm_diffusers_keymap'] = lora_keymap
with open(os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', f'{name}.json'), 'w') as f:
    json.dump(to_save, f, indent=4)