# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import inspect
import os
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union

import safetensors
import torch
from huggingface_hub.utils import EntryNotFoundError

from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_FILE_EXTENSION,
    WEIGHTS_INDEX_NAME,
    _add_variant,
    _get_model_file,
    deprecate,
    is_accelerate_available,
    is_torch_version,
    logging,
)


logger = logging.get_logger(__name__)

_CLASS_REMAPPING_DICT = {
    "Transformer2DModel": {
        "ada_norm_zero": "DiTTransformer2DModel",
        "ada_norm_single": "PixArtTransformer2DModel",
    }
}


if is_accelerate_available():
    from accelerate import infer_auto_device_map
    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device


# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(
    model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
):
    if isinstance(device_map, str):
        special_dtypes = {}
        if hf_quantizer is not None:
            special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
        special_dtypes.update(
            {
                name: torch.float32
                for name, _ in model.named_parameters()
                if any(m in name for m in keep_in_fp32_modules)
            }
        )

        target_dtype = torch_dtype
        if hf_quantizer is not None:
            target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)

        no_split_modules = model._get_no_split_modules(device_map)
        device_map_kwargs = {"no_split_module_classes": no_split_modules}

        if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
            device_map_kwargs["special_dtypes"] = special_dtypes
        elif len(special_dtypes) > 0:
            logger.warning(
                "This model has some weights that should be kept in higher precision, you need to upgrade "
                "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
            )

        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                dtype=torch_dtype,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **device_map_kwargs,
            )
        else:
            max_memory = get_max_memory(max_memory)

        if hf_quantizer is not None:
            max_memory = hf_quantizer.adjust_max_memory(max_memory)

        device_map_kwargs["max_memory"] = max_memory
        device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)

        if hf_quantizer is not None:
            hf_quantizer.validate_environment(device_map=device_map)

    return device_map
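

# Illustrative note (a sketch, not exhaustive): when `device_map` is a string such as "auto" or
# "balanced", accelerate's `infer_auto_device_map` returns a dict mapping submodule names to devices.
# A hypothetical result on a two-GPU machine could look like:
#
#     {"conv_in": 0, "down_blocks": 0, "mid_block": 1, "up_blocks": 1, "conv_out": 1}
#
# The module names above are placeholders; the real keys depend on the model's submodules and on
# `no_split_module_classes`.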


def _fetch_remapped_cls_from_config(config, old_class):
    previous_class_name = old_class.__name__
    remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None)

    # Details:
    # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
    if remapped_class_name:
        # load the diffusers library to import the compatible and original class
        diffusers_library = importlib.import_module(__name__.split(".")[0])
        remapped_class = getattr(diffusers_library, remapped_class_name)
        logger.info(
            f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type. "
            f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
            " DOESN'T affect the final results."
        )
        return remapped_class
    else:
        return old_class


def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.
    """
    # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
    # when refactoring the _merge_sharded_checkpoints() method later.
    if isinstance(checkpoint_file, dict):
        return checkpoint_file
    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
        else:
            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
            return torch.load(
                checkpoint_file,
                map_location="cpu",
                **weights_only_kwarg,
            )
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            raise OSError(
                f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
                f"at '{checkpoint_file}'."
            )


def load_model_dict_into_meta(
    model,
    state_dict: OrderedDict,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    model_name_or_path: Optional[str] = None,
    hf_quantizer=None,
    keep_in_fp32_modules=None,
) -> List[str]:
    if hf_quantizer is None:
        device = device or torch.device("cpu")
    dtype = dtype or torch.float32
    is_quantized = hf_quantizer is not None
    is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES

    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
    empty_state_dict = model.state_dict()
    unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]

    for param_name, param in state_dict.items():
        if param_name not in empty_state_dict:
            continue

        set_module_kwargs = {}
        # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
        # in int/uint/bool and not cast them.
        # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
        if torch.is_floating_point(param):
            if (
                keep_in_fp32_modules is not None
                and any(
                    module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
                )
                and dtype == torch.float16
            ):
                param = param.to(torch.float32)
                if accepts_dtype:
                    set_module_kwargs["dtype"] = torch.float32
            else:
                param = param.to(dtype)
                if accepts_dtype:
                    set_module_kwargs["dtype"] = dtype

        # bnb params are flattened.
        if empty_state_dict[param_name].shape != param.shape:
            if (
                is_quant_method_bnb
                and hf_quantizer.pre_quantized
                and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
            ):
                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
            elif not is_quant_method_bnb:
                model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                raise ValueError(
                    f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
                )

        if is_quantized and (
            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
        ):
            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
        else:
            if accepts_dtype:
                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
            else:
                set_module_tensor_to_device(model, param_name, device, value=param)

    return unexpected_keys


def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
    # Convert old format to new format if needed from a PyTorch state_dict
    # copy state_dict so _load_from_state_dict can modify it
    state_dict = state_dict.copy()
    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: torch.nn.Module, prefix: str = ""):
        args = (state_dict, prefix, {}, True, [], [], error_msgs)
        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model_to_load)

    return error_msgs


def _fetch_index_file(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
):
    if is_local:
        index_file = Path(
            pretrained_model_name_or_path,
            subfolder or "",
            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
        )
    else:
        index_file_in_repo = Path(
            subfolder or "",
            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
        ).as_posix()
        try:
            index_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=index_file_in_repo,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=None,
                user_agent=user_agent,
                commit_hash=commit_hash,
            )
            index_file = Path(index_file)
        except (EntryNotFoundError, EnvironmentError):
            index_file = None

    return index_file


# Adapted from
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
    weight_map = sharded_metadata.get("weight_map", None)
    if weight_map is None:
        raise KeyError("'weight_map' key not found in the shard index file.")

    # Collect all unique safetensors files from weight_map
    files_to_load = set(weight_map.values())
    is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
    merged_state_dict = {}

    # Load tensors from each unique file
    for file_name in files_to_load:
        part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
        if not os.path.exists(part_file_path):
            raise FileNotFoundError(f"Part file {file_name} not found.")

        if is_safetensors:
            with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
                for tensor_key in f.keys():
                    if tensor_key in weight_map:
                        merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
        else:
            merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))

    return merged_state_dict
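

# For reference, `sharded_metadata` comes from a shard index file (e.g.
# "diffusion_pytorch_model.safetensors.index.json"), whose layout is roughly the following
# (values are illustrative):
#
#     {
#         "metadata": {"total_size": 28512825344},
#         "weight_map": {
#             "conv_in.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
#             "time_embedding.linear_1.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
#             ...
#         }
#     }
#
# `_merge_sharded_checkpoints` loads every shard referenced in `weight_map` and collects the tensors
# into a single flat state dict.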


def _fetch_index_file_legacy(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
):
    if is_local:
        index_file = Path(
            pretrained_model_name_or_path,
            subfolder or "",
            SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
        ).as_posix()
        splits = index_file.split(".")
        split_index = -3 if ".cache" in index_file else -2
        splits = splits[:-split_index] + [variant] + splits[-split_index:]
        index_file = ".".join(splits)
        if os.path.exists(index_file):
            deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you remove the existing files associated with the current variant ({variant}) and re-obtain them by running `save_pretrained()`."
            deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
            index_file = Path(index_file)
        else:
            index_file = None
    else:
        if variant is not None:
            index_file_in_repo = Path(
                subfolder or "",
                SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
            ).as_posix()
            splits = index_file_in_repo.split(".")
            split_index = -2
            splits = splits[:-split_index] + [variant] + splits[-split_index:]
            index_file_in_repo = ".".join(splits)
            try:
                index_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=index_file_in_repo,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=None,
                    user_agent=user_agent,
                    commit_hash=commit_hash,
                )
                index_file = Path(index_file)
                deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you remove the existing files associated with the current variant ({variant}) and re-obtain them by running `save_pretrained()`."
                deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
            except (EntryNotFoundError, EnvironmentError):
                index_file = None
        else:
            # No legacy variant index to look up when no variant is requested.
            index_file = None

    return index_file
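

# Note on the legacy naming handled above: the variant is spliced into the index file name just
# before "index.json", so for variant "fp16" the expected file is e.g.
# "diffusion_pytorch_model.safetensors.fp16.index.json" (assuming no extra dots in the path),
# whereas `_fetch_index_file` builds the non-legacy name via `_add_variant` instead.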