import torch
import comfy.model_management
import comfy.utils
import folder_paths
import os
import logging
from tqdm import tqdm

device = comfy.model_management.get_torch_device()

CLAMP_QUANTILE = 0.99
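# U and Vh are later clamped to ±(this quantile of their combined value
# distribution) to suppress extreme outliers in the extracted factors.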
def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0):
    """
    Extracts LoRA up/down weight matrices from a weight-difference tensor using SVD.
    """
    conv2d = (len(diff.shape) == 4)
    kernel_size = None if not conv2d else diff.size()[2:4]
    conv2d_3x3 = conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]

    if conv2d:
        if conv2d_3x3:
            # Fold the kernel dims into the input dim so SVD sees a 2D matrix.
            diff = diff.flatten(start_dim=1)
        else:
            # 1x1 conv: drop only the kernel dims, so an out_dim or in_dim of
            # size 1 is not squeezed away as well.
            diff = diff.squeeze(-1).squeeze(-1)

    diff_float = diff.float()
    if algorithm == "svd_lowrank":
        # Approximate truncated SVD: faster, but less accurate than full SVD.
        lora_rank = min(rank, in_dim, out_dim)
        U, S, V = torch.svd_lowrank(diff_float, q=lora_rank, niter=lowrank_iters)
        U = U @ torch.diag(S)
        Vh = V.t()
    else:
        U, S, Vh = torch.linalg.svd(diff_float)
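        # Illustrative numbers (not from the source): with S = [8, 4, 2, 1],
        #   adaptive_ratio    (adaptive_param=0.4)  keeps S > 3.2               -> rank 2
        #   adaptive_energy   (adaptive_param=0.95) covers 95% of sum(S^2) = 85 -> rank 3
        #   adaptive_quantile (adaptive_param=0.9)  covers 90% of sum(S) = 15   -> rank 2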
        # Flexible rank selection logic like locon: https://github.com/KohakuBlueleaf/LyCORIS/blob/main/tools/extract_locon.py
        if "adaptive" in lora_type:
            if lora_type == "adaptive_ratio":
                # Keep singular values larger than a fraction of the largest one.
                min_s = torch.max(S) * adaptive_param
                lora_rank = torch.sum(S > min_s).item()
            elif lora_type == "adaptive_energy":
                # Keep enough singular values to retain the requested share of the
                # total squared-singular-value energy.
                energy = torch.cumsum(S**2, dim=0)
                total_energy = torch.sum(S**2)
                threshold = adaptive_param * total_energy  # e.g., adaptive_param=0.95 for 95%
                lora_rank = torch.sum(energy < threshold).item() + 1
            elif lora_type == "adaptive_quantile":
                # Keep enough singular values to retain the requested share of the
                # cumulative singular-value sum.
                s_cum = torch.cumsum(S, dim=0)
                min_cum_sum = adaptive_param * torch.sum(S)
                lora_rank = torch.sum(s_cum < min_cum_sum).item()
            logging.info(f"{key} extracted LoRA rank: {lora_rank}")
        else:
            lora_rank = rank

        lora_rank = max(1, lora_rank)
        lora_rank = min(out_dim, in_dim, lora_rank)

        U = U[:, :lora_rank]
        S = S[:lora_rank]
        U = U @ torch.diag(S)
        Vh = Vh[:lora_rank, :]
    dist = torch.cat([U.flatten(), Vh.flatten()])
    if dist.numel() > 100_000:
        # Sample 100,000 elements for quantile estimation
        idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
        dist_sample = dist[idx]
        hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE)
    else:
        hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val

    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)

    if conv2d:
        # Restore conv shapes: lora_up is a 1x1 conv, lora_down carries the kernel.
        U = U.reshape(out_dim, lora_rank, 1, 1)
        Vh = Vh.reshape(lora_rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)
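
# Illustrative usage (not from the source): for a 2D diff the returned pair
# factorizes it as diff ≈ lora_up @ lora_down, up to the rank truncation and
# the quantile clamp:
#   up, down = extract_lora(torch.randn(320, 640), "some.key", 16, "svd_linalg", "standard")
#   # up: (320, 16), down: (16, 640)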
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0):
    comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
    model_diff.model.diffusion_model.cpu()
    sd = model_diff.model_state_dict(filter_prefix=prefix_model)
    del model_diff
    comfy.model_management.soft_empty_cache()

    # Move everything to CPU so extraction can stream tensors to the GPU one at a time.
    for k, v in sd.items():
        if isinstance(v, torch.Tensor):
            sd[k] = v.cpu()

    # Get total number of keys to process for progress bar
    total_keys = len([k for k in sd if k.endswith(".weight") or (bias_diff and k.endswith(".bias"))])

    # Create progress bar
    progress_bar = tqdm(total=total_keys, desc=f"Extracting LoRA ({prefix_lora.strip('.')})")
    comfy_pbar = comfy.utils.ProgressBar(total_keys)
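    # tqdm reports progress in the terminal; comfy.utils.ProgressBar mirrors it
    # in the ComfyUI frontend.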
    for k in sd:
        if k.endswith(".weight"):
            weight_diff = sd[k]
            if weight_diff.ndim == 5:
                logging.info(f"Skipping 5D tensor for key {k}")  # skip patch embed
                progress_bar.update(1)
                comfy_pbar.update(1)
                continue
            if lora_type != "full":
                if weight_diff.ndim < 2:
                    # 1D tensors (e.g. norm weights) cannot be factorized; store
                    # them as a plain diff when bias_diff is enabled.
                    if bias_diff:
                        output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()
                    progress_bar.update(1)
                    comfy_pbar.update(1)
                    continue
                try:
                    out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param)
                    output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().to(out_dtype).cpu()
                    output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().to(out_dtype).cpu()
                except Exception as e:
                    logging.warning(f"Could not generate lora weights for key {k}, error {e}")
            else:
                # "full" mode stores the raw weight difference instead of a factorization.
                output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()
            progress_bar.update(1)
            comfy_pbar.update(1)
        elif bias_diff and k.endswith(".bias"):
            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().to(out_dtype).cpu()
            progress_bar.update(1)
            comfy_pbar.update(1)
    progress_bar.close()
    return output_sd
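
# Key naming: the model prefix is kept and ".weight" is replaced, so
# "diffusion_model.foo.weight" yields "diffusion_model.foo.lora_up.weight" and
# "diffusion_model.foo.lora_down.weight" (or "diffusion_model.foo.diff" in full mode).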
class LoraExtractKJ:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "finetuned_model": ("MODEL",),
            "original_model": ("MODEL",),
            "filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
            "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
            "lora_type": (["standard", "full", "adaptive_ratio", "adaptive_quantile", "adaptive_energy"],),
            "algorithm": (["svd_linalg", "svd_lowrank"], {"default": "svd_linalg", "tooltip": "SVD algorithm to use; svd_lowrank is faster but less accurate."}),
            "lowrank_iters": ("INT", {"default": 7, "min": 1, "max": 100, "step": 1, "tooltip": "The number of subspace iterations for the lowrank SVD algorithm."}),
            "output_dtype": (["fp16", "bf16", "fp32"], {"default": "fp16"}),
            "bias_diff": ("BOOLEAN", {"default": True}),
            "adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, the fraction of the maximum singular value; for quantile/energy modes, the share of the cumulative singular-value sum/energy to retain."}),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "KJNodes/lora"
    def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param):
        if algorithm == "svd_lowrank" and lora_type != "standard":
            raise ValueError("The svd_lowrank algorithm is only supported for standard LoRA extraction.")
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[output_dtype]
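        # Build the weight-difference model: patching the original weights into a
        # clone of the finetuned model at strength -1.0 leaves finetuned - original.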
        m = finetuned_model.clone()
        kp = original_model.get_key_patches("diffusion_model.")
        for k in kp:
            m.add_patches({k: kp[k]}, -1.0, 1.0)
        model_diff = m

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)

        output_sd = {}
        if model_diff is not None:
            output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param)

        if "adaptive" in lora_type:
            rank_str = f"{lora_type}_{adaptive_param:.2f}"
        else:
            rank_str = rank

        output_checkpoint = f"{filename}_rank_{rank_str}_{output_dtype}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
        return {}
NODE_CLASS_MAPPINGS = {
    "LoraExtractKJ": LoraExtractKJ
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LoraExtractKJ": "LoraExtractKJ"
}