Spaces:
Sleeping
Sleeping
| import torch | |
| from pathlib import Path | |
| from utils import get_download_file | |
| from stkey import read_safetensors_key | |
| try: | |
| from diffusers import BitsAndBytesConfig | |
| is_nf4 = True | |
| except Exception: | |
| is_nf4 = False | |
| DTYPE_DEFAULT = "default" | |
| DTYPE_DICT = { | |
| "fp16": torch.float16, | |
| "bf16": torch.bfloat16, | |
| "fp32": torch.float32, | |
| "fp8": torch.float8_e4m3fn, | |
| } | |
| #QTYPES = ["NF4"] if is_nf4 else [] | |
| QTYPES = [] | |
| def get_dtypes(): | |
| return list(DTYPE_DICT.keys()) + [DTYPE_DEFAULT] + QTYPES | |
| def get_dtype(dtype: str): | |
| if dtype in set(QTYPES): return torch.bfloat16 | |
| return DTYPE_DICT.get(dtype, torch.float16) | |
| from diffusers import ( | |
| DPMSolverMultistepScheduler, | |
| DPMSolverSinglestepScheduler, | |
| KDPM2DiscreteScheduler, | |
| EulerDiscreteScheduler, | |
| EulerAncestralDiscreteScheduler, | |
| HeunDiscreteScheduler, | |
| LMSDiscreteScheduler, | |
| DDIMScheduler, | |
| DEISMultistepScheduler, | |
| UniPCMultistepScheduler, | |
| LCMScheduler, | |
| PNDMScheduler, | |
| KDPM2AncestralDiscreteScheduler, | |
| DPMSolverSDEScheduler, | |
| EDMDPMSolverMultistepScheduler, | |
| DDPMScheduler, | |
| EDMEulerScheduler, | |
| TCDScheduler, | |
| ) | |
# Sampler display name -> (diffusers scheduler class, extra scheduler kwargs).
# Insertion order is preserved exactly; callers may list the names in this order.
# NOTE(review): the kwargs are presumably applied on top of the pipeline's
# existing scheduler config by the caller -- confirm at the call site.
SCHEDULER_CONFIG_MAP = {
    # DPM-Solver(++) multistep / singlestep variants
    "DPM++ 2M": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": False}),
    "DPM++ 2M Karras": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": True}),
    "DPM++ 2M SDE": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
    "DPM++ 2M SDE Karras": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
    "DPM++ 2S": (DPMSolverSinglestepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": False}),
    "DPM++ 2S Karras": (DPMSolverSinglestepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": True}),
    "DPM++ 1S": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "solver_order": 1}),
    "DPM++ 1S Karras": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "solver_order": 1, "use_karras_sigmas": True}),
    "DPM++ 3M": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "solver_order": 3}),
    "DPM++ 3M Karras": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "solver_order": 3, "use_karras_sigmas": True}),
    "DPM 3M": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver", "final_sigmas_type": "sigma_min", "solver_order": 3}),
    # DPM++ SDE
    "DPM++ SDE": (DPMSolverSDEScheduler, {"use_karras_sigmas": False}),
    "DPM++ SDE Karras": (DPMSolverSDEScheduler, {"use_karras_sigmas": True}),
    # KDPM2 (plain and ancestral)
    "DPM2": (KDPM2DiscreteScheduler, {}),
    "DPM2 Karras": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
    "DPM2 a": (KDPM2AncestralDiscreteScheduler, {}),
    "DPM2 a Karras": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
    # Euler (plain and ancestral)
    "Euler": (EulerDiscreteScheduler, {}),
    "Euler a": (EulerAncestralDiscreteScheduler, {}),
    "Euler trailing": (EulerDiscreteScheduler, {"timestep_spacing": "trailing", "prediction_type": "sample"}),
    "Euler a trailing": (EulerAncestralDiscreteScheduler, {"timestep_spacing": "trailing"}),
    # Other classic samplers
    "Heun": (HeunDiscreteScheduler, {}),
    "Heun Karras": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
    "LMS": (LMSDiscreteScheduler, {}),
    "LMS Karras": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
    "DDIM": (DDIMScheduler, {}),
    "DDIM trailing": (DDIMScheduler, {"timestep_spacing": "trailing"}),
    "DEIS": (DEISMultistepScheduler, {}),
    "UniPC": (UniPCMultistepScheduler, {}),
    "UniPC Karras": (UniPCMultistepScheduler, {"use_karras_sigmas": True}),
    "PNDM": (PNDMScheduler, {}),
    # EDM schedulers.
    # NOTE(review): diffusers' EDM schedulers pick their sigma schedule via a
    # `sigma_schedule` argument; check whether `use_karras_sigmas` has any effect
    # on them, or whether the "Karras" entries behave like the plain ones.
    "Euler EDM": (EDMEulerScheduler, {}),
    "Euler EDM Karras": (EDMEulerScheduler, {"use_karras_sigmas": True}),
    "DPM++ 2M EDM": (
        EDMDPMSolverMultistepScheduler,
        {"solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++"},
    ),
    "DPM++ 2M EDM Karras": (
        EDMDPMSolverMultistepScheduler,
        {"use_karras_sigmas": True, "solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++"},
    ),
    "DDPM": (DDPMScheduler, {}),
    # DPM++ 2M with Lu lambdas / Euler at the final step
    "DPM++ 2M Lu": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "use_lu_lambdas": True}),
    "DPM++ 2M Ef": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "euler_at_final": True}),
    "DPM++ 2M SDE Lu": (DPMSolverMultistepScheduler, {"use_lu_lambdas": True, "algorithm_type": "sde-dpmsolver++"}),
    "DPM++ 2M SDE Ef": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "euler_at_final": True}),
    # LCM / TCD. The "Auto-Loader" entries carry no extra kwargs here; any special
    # handling for them presumably happens elsewhere -- confirm.
    "LCM": (LCMScheduler, {}),
    "TCD": (TCDScheduler, {}),
    "LCM trailing": (LCMScheduler, {"timestep_spacing": "trailing"}),
    "TCD trailing": (TCDScheduler, {"timestep_spacing": "trailing"}),
    "LCM Auto-Loader": (LCMScheduler, {}),
    "TCD Auto-Loader": (TCDScheduler, {}),
    "EDM": (EDMDPMSolverMultistepScheduler, {}),
    "EDM Karras": (EDMDPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
    # v-prediction variants
    "Euler (V-Prediction)": (EulerDiscreteScheduler, {"prediction_type": "v_prediction", "rescale_betas_zero_snr": True}),
    "Euler a (V-Prediction)": (EulerAncestralDiscreteScheduler, {"prediction_type": "v_prediction", "rescale_betas_zero_snr": True}),
    "Euler EDM (V-Prediction)": (EDMEulerScheduler, {"prediction_type": "v_prediction"}),
    "Euler EDM Karras (V-Prediction)": (EDMEulerScheduler, {"use_karras_sigmas": True, "prediction_type": "v_prediction"}),
    "DPM++ 2M EDM (V-Prediction)": (
        EDMDPMSolverMultistepScheduler,
        {"solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++", "prediction_type": "v_prediction"},
    ),
    "DPM++ 2M EDM Karras (V-Prediction)": (
        EDMDPMSolverMultistepScheduler,
        {"use_karras_sigmas": True, "solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++", "prediction_type": "v_prediction"},
    ),
    "EDM (V-Prediction)": (EDMDPMSolverMultistepScheduler, {"prediction_type": "v_prediction"}),
    "EDM Karras (V-Prediction)": (EDMDPMSolverMultistepScheduler, {"use_karras_sigmas": True, "prediction_type": "v_prediction"}),
}
def get_scheduler_config(name: str):
    """Look up the (scheduler class, kwargs) pair registered under ``name``.

    Names missing from SCHEDULER_CONFIG_MAP fall back to the "Euler a" entry.
    """
    try:
        return SCHEDULER_CONFIG_MAP[name]
    except KeyError:
        return SCHEDULER_CONFIG_MAP["Euler a"]
def fuse_loras(pipe, lora_dict: dict, temp_dir: str, civitai_key: str = "", dkwargs: "dict | None" = None):
    """Fetch each LoRA in ``lora_dict``, load it as an adapter, and fuse it into ``pipe``.

    Args:
        pipe: Diffusers pipeline. The code calls ``load_lora_weights``,
            ``set_adapters``, ``fuse_lora`` and ``unload_lora_weights`` on it.
        lora_dict: Mapping of LoRA source (handed to ``get_download_file``) to
            adapter weight. Empty keys are skipped. Anything that is not a
            non-empty dict makes this a no-op.
        temp_dir: Directory passed to ``get_download_file``.
        civitai_key: API key passed to ``get_download_file``.
        dkwargs: Extra keyword arguments forwarded to ``pipe.load_lora_weights``.

    Returns:
        The same ``pipe`` object. If at least one LoRA was loaded, the LoRAs
        have been fused and the separate adapter weights unloaded.

    Raises:
        Whatever ``pipe.load_lora_weights`` raises. The fetched file is still
        deleted in that case.
    """
    if not lora_dict or not isinstance(lora_dict, dict):
        return pipe
    if dkwargs is None:
        dkwargs = {}

    adapter_names = []
    adapter_weights = []
    for source, weight in lora_dict.items():
        if not source:
            continue
        lora_file = get_download_file(temp_dir, source, civitai_key)
        if not lora_file or not Path(lora_file).exists():
            print(f"LoRA file not found: {source}")
            continue
        lora_path = Path(lora_file)

        # Two different sources can share a file stem. Reusing the stem as the
        # adapter name would then collide, so add a numeric suffix.
        base_name = lora_path.stem
        adapter_name = base_name
        suffix = 1
        while adapter_name in adapter_names:
            adapter_name = f"{base_name}_{suffix}"
            suffix += 1

        try:
            pipe.load_lora_weights(
                lora_file,
                weight_name=lora_path.name,
                adapter_name=adapter_name,
                low_cpu_mem_usage=False,
                **dkwargs,
            )
        finally:
            # Remove the fetched file even if loading failed, so it does not leak.
            # NOTE(review): this deletes whatever path get_download_file returned;
            # confirm that is always a temporary copy and never a user's own file.
            lora_path.unlink(missing_ok=True)
        adapter_names.append(adapter_name)
        adapter_weights.append(weight)

    if not adapter_names:
        return pipe
    pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
    pipe.fuse_lora(adapter_names=adapter_names, lora_scale=1.0)
    # Unload the separate adapter weights now that they have been fused.
    pipe.unload_lora_weights()
    return pipe
# Checkpoint tensor name -> model type label. get_model_type_from_key tests these
# keys in insertion order and returns the label of the first one found, so the
# order below matters.
MODEL_TYPE_KEY = {
    "model.diffusion_model.output_blocks.1.1.norm.bias": "SDXL",
    "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "SD 1.5",
    # The same FLUX key, listed both without and with the "model.diffusion_model."
    # prefix (presumably to match checkpoints saved either way).
    "double_blocks.0.img_attn.norm.key_norm.scale": "FLUX",
    "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale": "FLUX",
    "model.diffusion_model.joint_blocks.9.x_block.attn.ln_k.weight": "SD 3.5",
}
def get_model_type_from_key(path: str):
    """Guess a checkpoint's model type from the tensor names it contains.

    The keys of ``MODEL_TYPE_KEY`` are checked in insertion order, and the label
    of the first key present in the file is returned.

    Args:
        path: Checkpoint path, passed to ``read_safetensors_key``.

    Returns:
        The matched model type label. Returns ``"SDXL"`` when no key matches or
        when the file's keys cannot be read (best effort, never raises).
    """
    default = "SDXL"
    try:
        # Build the membership set once instead of once per candidate key.
        key_set = set(read_safetensors_key(path))
    except Exception as e:
        # Deliberate best effort: fall back to the default, but report why.
        print(f"Could not read keys from {path}: {e}")
        return default
    for key, model_type in MODEL_TYPE_KEY.items():
        if key in key_set:
            print(f"Model type is {model_type}.")
            return model_type
    print("Model type could not be identified.")
    return default
def get_process_dtype(dtype: str, model_type: str):
    """Return the processing dtype for a precision option and a model type.

    "fp8" and any quantization option (QTYPES) resolve to bfloat16 for the
    "FLUX" and "SD 3.5" model types and to float16 for all others. Every other
    option is looked up in DTYPE_DICT, falling back to float16.
    """
    if dtype not in ("fp8", *QTYPES):
        return DTYPE_DICT.get(dtype, torch.float16)
    if model_type in ("FLUX", "SD 3.5"):
        return torch.bfloat16
    return torch.float16