import logging

import torch

from modules.Utilities import util
from modules.AutoEncoders import VariationalAE
from modules.Device import Device
from modules.Model import ModelPatcher
from modules.NeuralNetwork import unet
from modules.clip import Clip


def load_checkpoint_guess_config(
    ckpt_path: str,
    output_vae: bool = True,
    output_clip: bool = True,
    output_clipvision: bool = False,
    embedding_directory: str = None,
    output_model: bool = True,
) -> tuple:
    """#### Load a checkpoint and guess the configuration.

    #### Args:
        - `ckpt_path` (str): The path to the checkpoint file.
        - `output_vae` (bool, optional): Whether to output the VAE. Defaults to True.
        - `output_clip` (bool, optional): Whether to output the CLIP. Defaults to True.
        - `output_clipvision` (bool, optional): Whether to output the CLIP vision. Defaults to False.
        - `embedding_directory` (str, optional): The embedding directory. Defaults to None.
        - `output_model` (bool, optional): Whether to output the model. Defaults to True.

    #### Returns:
        - `tuple`: The model patcher, CLIP, VAE, and CLIP vision.
    """
    sd = util.load_torch_file(ckpt_path)

    clip = None
    clipvision = None
    vae = None
    model = None
    model_patcher = None
    clip_target = None

    parameters = util.calculate_parameters(sd, "model.diffusion_model.")
    load_device = Device.get_torch_device()

    # Guess the model configuration from the UNet weights, then pick the
    # inference dtype it supports on the current device.
    model_config = unet.model_config_from_unet(sd, "model.diffusion_model.")
    unet_dtype = unet.unet_dtype1(
        model_params=parameters,
        supported_dtypes=model_config.supported_inference_dtypes,
    )
    manual_cast_dtype = Device.unet_manual_cast(
        unet_dtype, load_device, model_config.supported_inference_dtypes
    )
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

    if output_model:
        initial_load_device = Device.unet_inital_load_device(parameters, unet_dtype)
        offload_device = Device.unet_offload_device()
        model = model_config.get_model(
            sd, "model.diffusion_model.", device=initial_load_device
        )
        model.load_model_weights(sd, "model.diffusion_model.")

    if output_vae:
        # Strip the VAE key prefix so the state dict matches the standalone VAE.
        vae_sd = util.state_dict_prefix_replace(
            sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True
        )
        vae_sd = model_config.process_vae_state_dict(vae_sd)
        vae = VariationalAE.VAE(sd=vae_sd)

    if output_clip:
        clip_target = model_config.clip_target()
        if clip_target is not None:
            clip_sd = model_config.process_clip_state_dict(sd)
            if len(clip_sd) > 0:
                clip = Clip.CLIP(clip_target, embedding_directory=embedding_directory)
                m, u = clip.load_sd(clip_sd, full_model=True)
                if len(m) > 0:
                    # Missing logit_scale / text_projection keys are expected;
                    # only warn when other text encoder weights are absent.
                    m_filter = list(
                        filter(
                            lambda a: ".logit_scale" not in a
                            and ".transformer.text_projection.weight" not in a,
                            m,
                        )
                    )
                    if len(m_filter) > 0:
                        logging.warning("clip missing: {}".format(m))
                    else:
                        logging.debug("clip missing: {}".format(m))
                if len(u) > 0:
                    logging.debug("clip unexpected: {}".format(u))
            else:
                logging.warning(
                    "no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded."
                )

    left_over = sd.keys()
    if len(left_over) > 0:
        logging.debug("left over keys: {}".format(left_over))

    if output_model:
        model_patcher = ModelPatcher.ModelPatcher(
            model,
            load_device=load_device,
            offload_device=offload_device,
            current_device=initial_load_device,
        )
        if initial_load_device != torch.device("cpu"):
            logging.info("loaded straight to GPU")
            Device.load_model_gpu(model_patcher)

    return (model_patcher, clip, vae, clipvision)


class CheckpointLoaderSimple:
    """#### Class for loading checkpoints."""

    def load_checkpoint(
        self, ckpt_name: str, output_vae: bool = True, output_clip: bool = True
    ) -> tuple:
        """#### Load a checkpoint.
        #### Args:
            - `ckpt_name` (str): The name of the checkpoint.
            - `output_vae` (bool, optional): Whether to output the VAE. Defaults to True.
            - `output_clip` (bool, optional): Whether to output the CLIP. Defaults to True.

        #### Returns:
            - `tuple`: The model patcher, CLIP, and VAE.
        """
        ckpt_path = ckpt_name
        print("loading", ckpt_path)
        out = load_checkpoint_guess_config(
            ckpt_path,
            output_vae=output_vae,
            output_clip=output_clip,
            embedding_directory="./_internal/embeddings/",
        )
        return out[:3]
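
# Minimal usage sketch (assumptions: the checkpoint path below is hypothetical,
# and the `modules` package must be importable from the working directory; any
# Stable-Diffusion-style checkpoint readable by util.load_torch_file should work).
if __name__ == "__main__":
    loader = CheckpointLoaderSimple()
    # Returns the first three elements of load_checkpoint_guess_config's tuple:
    # (model_patcher, clip, vae); clipvision is dropped by out[:3].
    model_patcher, clip, vae = loader.load_checkpoint(
        "./_internal/checkpoints/model.safetensors"  # hypothetical path
    )
    print(type(model_patcher), type(clip), type(vae))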