import logging

import torch

from modules.AutoEncoders import VariationalAE
from modules.clip import Clip
from modules.Device import Device
from modules.Model import ModelPatcher
from modules.NeuralNetwork import unet
from modules.Utilities import util


def load_checkpoint_guess_config(
    ckpt_path: str,
    output_vae: bool = True,
    output_clip: bool = True,
    output_clipvision: bool = False,
    embedding_directory: str = None,
    output_model: bool = True,
) -> tuple:
"""#### Load a checkpoint and guess the configuration. | |
#### Args: | |
- `ckpt_path` (str): The path to the checkpoint file. | |
- `output_vae` (bool, optional): Whether to output the VAE. Defaults to True. | |
- `output_clip` (bool, optional): Whether to output the CLIP. Defaults to True. | |
- `output_clipvision` (bool, optional): Whether to output the CLIP vision. Defaults to False. | |
- `embedding_directory` (str, optional): The embedding directory. Defaults to None. | |
- `output_model` (bool, optional): Whether to output the model. Defaults to True. | |
#### Returns: | |
- `tuple`: The model patcher, CLIP, VAE, and CLIP vision. | |
""" | |
    sd = util.load_torch_file(ckpt_path)
    clip = None
    clipvision = None
    vae = None
    model = None
    model_patcher = None
    clip_target = None
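    # Guess the model architecture from the UNet weights, then choose an
    # inference dtype the device supports (with a manual-cast fallback dtype
    # for devices that cannot compute in the storage dtype directly).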
    parameters = util.calculate_parameters(sd, "model.diffusion_model.")
    load_device = Device.get_torch_device()

    model_config = unet.model_config_from_unet(sd, "model.diffusion_model.")
    unet_dtype = unet.unet_dtype1(
        model_params=parameters,
        supported_dtypes=model_config.supported_inference_dtypes,
    )
    manual_cast_dtype = Device.unet_manual_cast(
        unet_dtype, load_device, model_config.supported_inference_dtypes
    )
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
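    # Instantiate the diffusion model and load its weights, optionally
    # straight onto the load device.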
    if output_model:
        initial_load_device = Device.unet_inital_load_device(parameters, unet_dtype)
        offload_device = Device.unet_offload_device()
        model = model_config.get_model(
            sd, "model.diffusion_model.", device=initial_load_device
        )
        model.load_model_weights(sd, "model.diffusion_model.")
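    # Pull the VAE weights out of the checkpoint by stripping the
    # architecture-specific key prefix.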
    if output_vae:
        vae_sd = util.state_dict_prefix_replace(
            sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True
        )
        vae_sd = model_config.process_vae_state_dict(vae_sd)
        vae = VariationalAE.VAE(sd=vae_sd)
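    # Build the text encoder from whatever CLIP weights the checkpoint
    # carries; missing or unexpected keys are logged rather than fatal.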
    if output_clip:
        clip_target = model_config.clip_target()
        if clip_target is not None:
            clip_sd = model_config.process_clip_state_dict(sd)
            if len(clip_sd) > 0:
                clip = Clip.CLIP(clip_target, embedding_directory=embedding_directory)
                m, u = clip.load_sd(clip_sd, full_model=True)
                if len(m) > 0:
                    m_filter = list(
                        filter(
                            lambda a: ".logit_scale" not in a
                            and ".transformer.text_projection.weight" not in a,
                            m,
                        )
                    )
                    if len(m_filter) > 0:
                        logging.warning("clip missing: {}".format(m))
                    else:
                        logging.debug("clip missing: {}".format(m))
                if len(u) > 0:
                    logging.debug("clip unexpected: {}".format(u))
            else:
                logging.warning(
                    "no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded."
                )
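    # Keys still present in the state dict at this point were not consumed by
    # the model, VAE, or CLIP loaders above.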
    left_over = sd.keys()
    if len(left_over) > 0:
        logging.debug("left over keys: {}".format(left_over))
    if output_model:
        model_patcher = ModelPatcher.ModelPatcher(
            model,
            load_device=load_device,
            offload_device=offload_device,
            current_device=initial_load_device,
        )
        if initial_load_device != torch.device("cpu"):
            logging.info("loaded straight to GPU")
            Device.load_model_gpu(model_patcher)

    return (model_patcher, clip, vae, clipvision)
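

# A minimal usage sketch for the low-level loader (the checkpoint path below
# is hypothetical):
#
#   model_patcher, clip, vae, clipvision = load_checkpoint_guess_config(
#       "./_internal/checkpoints/sd15.safetensors",
#       embedding_directory="./_internal/embeddings/",
#   )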


class CheckpointLoaderSimple:
    """#### Class for loading checkpoints."""
    def load_checkpoint(
        self, ckpt_name: str, output_vae: bool = True, output_clip: bool = True
    ) -> tuple:
        """#### Load a checkpoint.

        #### Args:
            - `ckpt_name` (str): The name of the checkpoint.
            - `output_vae` (bool, optional): Whether to output the VAE. Defaults to True.
            - `output_clip` (bool, optional): Whether to output the CLIP. Defaults to True.

        #### Returns:
            - `tuple`: The model patcher, CLIP, and VAE.
        """
        ckpt_path = ckpt_name
        logging.info("loading {}".format(ckpt_path))
        out = load_checkpoint_guess_config(
            ckpt_path,
            output_vae=output_vae,
            output_clip=output_clip,
            embedding_directory="./_internal/embeddings/",
        )
        return out[:3]
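

if __name__ == "__main__":
    # Minimal smoke test; a sketch only, assuming a real Stable Diffusion
    # checkpoint exists at this (hypothetical) path.
    model_patcher, clip, vae = CheckpointLoaderSimple().load_checkpoint(
        "./_internal/checkpoints/model.safetensors"
    )
    logging.info("loaded: %s", type(model_patcher).__name__)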