import os
from tqdm import tqdm
import argparse
from collections import OrderedDict

parser = argparse.ArgumentParser(description="Extract LoRA from Flex")
parser.add_argument("--base", type=str, default="ostris/Flex.1-alpha", help="Base model path")
parser.add_argument("--tuned", type=str, required=True, help="Tuned model path")
parser.add_argument("--output", type=str, required=True, help="Output path for lora")
parser.add_argument("--rank", type=int, default=32, help="LoRA rank for extraction")
parser.add_argument("--gpu", type=int, default=0, help="GPU to process extraction")
parser.add_argument("--full", action="store_true", help="Do a full transformer extraction, not just transformer blocks")

args = parser.parse_args()
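
# CUDA_VISIBLE_DEVICES is set below before torch is imported, so the --gpu flag
# reliably controls which device torch can see; that is why the heavy imports
# come after argument parsing instead of at the top of the file.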
# set cuda environment variable
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

import torch
from safetensors.torch import load_file, save_file
from lycoris.utils import extract_linear, extract_conv, make_sparse
from diffusers import FluxTransformer2DModel

base = args.base
tuned = args.tuned
output_path = args.output
dim = args.rank

# make sure the output directory exists (skip when the path has no directory part)
if os.path.dirname(output_path):
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

output_dict = {}


def extract_diff(
    base_unet,
    db_unet,
    mode="fixed",
    linear_mode_param=0,
    conv_mode_param=0,
    extract_device="cpu",
    use_bias=False,
    sparsity=0.98,
    # small_conv=True,
    small_conv=False,
):
    """Extract a LoRA-style state dict from the weight difference between two transformers."""
    UNET_TARGET_REPLACE_MODULE = [
        "Linear",
        "Conv2d",
        "LayerNorm",
        "GroupNorm",
        "GroupNorm32",
        "LoRACompatibleLinear",
        "LoRACompatibleConv",
    ]
    LORA_PREFIX_UNET = "transformer"

    def make_state_dict(
        prefix,
        root_module: torch.nn.Module,
        target_module: torch.nn.Module,
        target_replace_modules,
    ):
        loras = {}
        temp = {}

        # collect the matching modules from the base model, keyed by module path
        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = module

        # walk the tuned model and diff each module against its base counterpart
        for name, module in tqdm(
            list((n, m) for n, m in target_module.named_modules() if n in temp)
        ):
            weights = temp[name]
            lora_name = prefix + "." + name
            # lora_name = lora_name.replace(".", "_")
            layer = module.__class__.__name__
            # only extract the transformer blocks unless --full was requested
            if "transformer_blocks" not in lora_name and not args.full:
                continue
            if layer in {
                "Linear",
                "Conv2d",
                "LayerNorm",
                "GroupNorm",
                "GroupNorm32",
                "Embedding",
                "LoRACompatibleLinear",
                "LoRACompatibleConv",
            }:
                root_weight = module.weight
                try:
                    # skip modules whose weights did not change during tuning
                    if torch.allclose(root_weight, weights.weight):
                        continue
                except Exception:
                    continue
            else:
                continue

            module = module.to(extract_device, torch.float32)
            weights = weights.to(extract_device, torch.float32)

            if mode == "full":
                decompose_mode = "full"
            elif layer == "Linear":
                weight, decompose_mode = extract_linear(
                    (root_weight - weights.weight),
                    mode,
                    linear_mode_param,
                    device=extract_device,
                )
                if decompose_mode == "low rank":
                    extract_a, extract_b, diff = weight
            elif layer == "Conv2d":
                # a 1x1 conv is effectively a linear layer, so use the linear rank for it
                is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1
                weight, decompose_mode = extract_conv(
                    (root_weight - weights.weight),
                    mode,
                    linear_mode_param if is_linear else conv_mode_param,
                    device=extract_device,
                )
                if decompose_mode == "low rank":
                    extract_a, extract_b, diff = weight
                if small_conv and not is_linear and decompose_mode == "low rank":
                    dim = extract_a.size(0)
                    (extract_c, extract_a, _), _ = extract_conv(
                        extract_a.transpose(0, 1),
                        "fixed",
                        dim,
                        extract_device,
                        True,
                    )
                    extract_a = extract_a.transpose(0, 1)
                    extract_c = extract_c.transpose(0, 1)
                    loras[f"{lora_name}.lora_mid.weight"] = (
                        extract_c.detach().cpu().contiguous().half()
                    )
                    diff = (
                        (
                            root_weight
                            - torch.einsum(
                                "i j k l, j r, p i -> p r k l",
                                extract_c,
                                extract_a.flatten(1, -1),
                                extract_b.flatten(1, -1),
                            )
                        )
                        .detach()
                        .cpu()
                        .contiguous()
                    )
                    del extract_c
            else:
                # other layer types are only handled in "full" mode
                module = module.to("cpu")
                weights = weights.to("cpu")
                continue

            if decompose_mode == "low rank":
                loras[f"{lora_name}.lora_A.weight"] = (
                    extract_a.detach().cpu().contiguous().half()
                )
                loras[f"{lora_name}.lora_B.weight"] = (
                    extract_b.detach().cpu().contiguous().half()
                )
                # loras[f"{lora_name}.alpha"] = torch.Tensor([extract_a.shape[0]]).half()
                if use_bias:
                    # store the residual missed by the low-rank factors as a sparse tensor
                    diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                    sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()

                    indices = sparse_diff.indices().to(torch.int16)
                    values = sparse_diff.values().half()
                    loras[f"{lora_name}.bias_indices"] = indices
                    loras[f"{lora_name}.bias_values"] = values
                    loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to(
                        torch.int16
                    )
                del extract_a, extract_b, diff
            elif decompose_mode == "full":
                if "Norm" in layer:
                    w_key = "w_norm"
                    b_key = "b_norm"
                else:
                    w_key = "diff"
                    b_key = "diff_b"
                weight_diff = module.weight - weights.weight
                loras[f"{lora_name}.{w_key}"] = (
                    weight_diff.detach().cpu().contiguous().half()
                )
                if getattr(weights, "bias", None) is not None:
                    bias_diff = module.bias - weights.bias
                    loras[f"{lora_name}.{b_key}"] = (
                        bias_diff.detach().cpu().contiguous().half()
                    )
            else:
                raise NotImplementedError

            module = module.to("cpu", torch.bfloat16)
            weights = weights.to("cpu", torch.bfloat16)
        return loras

    all_loras = {}
    all_loras |= make_state_dict(
        LORA_PREFIX_UNET,
        base_unet,
        db_unet,
        UNET_TARGET_REPLACE_MODULE,
    )

    del base_unet, db_unet
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # report how many modules actually produced extracted weights
    all_lora_name = set()
    for k in all_loras:
        lora_name, weight = k.rsplit(".", 1)
        all_lora_name.add(lora_name)
    print(len(all_lora_name))

    return all_loras
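

# extract_diff returns a flat dict of fp16 tensors with keys like
# "transformer.<module path>.lora_A.weight" / ".lora_B.weight" (PEFT-style
# naming), which is what gets written to the safetensors file below.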
# load the base and tuned transformers in bf16
print("Loading Base")
base_model = FluxTransformer2DModel.from_pretrained(base, subfolder="transformer", torch_dtype=torch.bfloat16)
print("Loading Tuned")
tuned_model = FluxTransformer2DModel.from_pretrained(tuned, subfolder="transformer", torch_dtype=torch.bfloat16)

output_dict = extract_diff(
    base_model,
    tuned_model,
    mode="fixed",
    linear_mode_param=dim,
    conv_mode_param=dim,
    extract_device="cuda",
    use_bias=False,
    sparsity=0.98,
    small_conv=False,
)

meta = OrderedDict()
meta["format"] = "pt"

save_file(output_dict, output_path, metadata=meta)
print("Done")
