from copy import deepcopy

from .import_utils import is_accelerate_available, is_bitsandbytes_available


if is_bitsandbytes_available():
    import bitsandbytes as bnb
    import torch
    import torch.nn as nn

if is_accelerate_available():
    from accelerate import init_empty_weights
    from accelerate.utils import find_tied_parameters


def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
    """
    A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
    `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
    function is adapted from the `set_module_tensor_to_device` function from accelerate, extended to support the
    class `Int8Params` from `bitsandbytes`.

    Args:
        module (`torch.nn.Module`):
            The module in which the tensor we want to move lives.
        tensor_name (`str`):
            The full (possibly dotted) name of the parameter/buffer.
        device (`int`, `str` or `torch.device`):
            The device on which to set the tensor.
        value (`torch.Tensor`, *optional*):
            The value of the tensor (useful when going from the meta device to any other device).

    Raises:
        ValueError: If the final module has no parameter or buffer called `tensor_name`, if the tensor lives on the
            meta device and no `value` is given for a non-meta target device, or if int8 weights are passed for a
            parameter that has to be quantized.
    """
    # Walk down the dotted path so that `module` is the direct owner of the tensor.
    if "." in tensor_name:
        splits = tensor_name.split(".")
        for split in splits[:-1]:
            new_module = getattr(module, split)
            if new_module is None:
                raise ValueError(f"{module} has no attribute {split}.")
            module = new_module
        tensor_name = splits[-1]

    if tensor_name not in module._parameters and tensor_name not in module._buffers:
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
    is_buffer = tensor_name in module._buffers
    old_value = getattr(module, tensor_name)

    if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")

    # Only parameters exposing a `has_fp16_weights` attribute take the 8-bit path below; buffers never do.
    if is_buffer:
        has_fp16_weights = None
    else:
        has_fp16_weights = getattr(module._parameters[tensor_name], "has_fp16_weights", None)

    if has_fp16_weights is not None:
        param = module._parameters[tensor_name]
        # A parameter that already sits on a CUDA device is left untouched.
        if param.device.type != "cuda":
            if value is None:
                new_value = old_value.to(device)
            elif isinstance(value, torch.Tensor):
                # Reject pre-quantized weights before doing any copy.
                if value.dtype == torch.int8:
                    # Implicit concatenation: the exception must carry ONE message string, not a tuple of two.
                    raise ValueError(
                        "You cannot load weights that are saved in int8 using `load_in_8bit=True`, make sure you are"
                        " using `load_in_8bit=True` on float32/float16/bfloat16 weights."
                    )
                new_value = value.to("cpu")
            else:
                new_value = torch.tensor(value, device="cpu")
            # Wrap the data in `Int8Params` first, then move the wrapped parameter to the target device.
            new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device)
            module._parameters[tensor_name] = new_value
    else:
        if value is None:
            new_value = old_value.to(device)
        elif isinstance(value, torch.Tensor):
            new_value = value.to(device)
        else:
            new_value = torch.tensor(value, device=device)

        if is_buffer:
            module._buffers[tensor_name] = new_value
        else:
            # Re-wrap in `nn.Parameter`, keeping the `requires_grad` flag of the tensor being replaced.
            new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
            module._parameters[tensor_name] = new_value


def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None):
    """
    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bitLt` modules from the
    `bitsandbytes` library. This will enable running your models using mixed int8 precision as described by the paper
    `LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the
    correct CUDA version of your hardware is installed before running this function. `pip install -i
    https://test.pypi.org/simple/ bitsandbytes`

    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that
    should be kept as a `torch.nn.Linear` module. The replacement is done under the `init_empty_weights` context
    manager so no CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by
    separating a matrix multiplication into two streams: (1) a systematic feature outlier stream matrix multiplied in
    fp16 (0.1%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
    predictive degradation is possible for very large models (>=176B parameters).

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        threshold (`float`, *optional*, defaults to 6.0):
            `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set
            to `6.0` as described by the paper.
        modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
            Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
            for numerical stability reasons.
        current_key_name (`List[`str`]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (part of
            it) is not in the list of modules to not convert (for instance modules that are offloaded to `cpu` or
            `disk`).

    Returns:
        `torch.nn.Module`: The same `model` object, modified in place.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = ["lm_head"]
    # The key list is shared (and mutated) across the whole recursion; create it once at the top-level call.
    if current_key_name is None:
        current_key_name = []

    for name, module in model.named_children():
        current_key_name.append(name)

        # Convert the descendants first.
        if list(module.children()):
            replace_8bit_linear(module, threshold, modules_to_not_convert, current_key_name)

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Skip the module as soon as any excluded key is a substring of its dotted path.
            current_key = ".".join(current_key_name)
            if not any(key in current_key for key in modules_to_not_convert):
                with init_empty_weights():
                    model._modules[name] = bnb.nn.Linear8bitLt(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=threshold,
                    )
                    # Force requires grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)

        # Remove the last key before moving on to the next sibling
        current_key_name.pop(-1)
    return model


def get_keys_to_not_convert(model):
    r"""
    A utility function to get the keys of the modules to keep in full precision, if any. For example for CausalLM
    models we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures,
    we want to keep the tied weights of the model. The function will return a list of the keys of the modules to not
    convert in int8.

    Parameters:
        model (`torch.nn.Module`):
            Input model. It must provide a `tie_weights()` method and a `base_model_prefix` attribute.

    Returns:
        `List[str]`: Module names (parameter keys with their trailing `.weight` / `.bias` removed) that should not be
        converted. Empty when the model has no tied parameters and no attribute named after its `base_model_prefix`.
    """
    # Work on a copy so that tying the weights does not alter the caller's model.
    # NOTE(review): the copy is assumed to be cheap because the model is expected to have been created under the
    # `init_empty_weights` context manager -- confirm against the caller.
    tied_model = deepcopy(model)
    tied_model.tie_weights()

    tied_params = find_tied_parameters(tied_model)
    # For compatibility with Accelerate < 0.18, which returns a dict instead of a list of groups
    if isinstance(tied_params, dict):
        tied_keys = list(tied_params.values())
    else:
        # Every group contributes all of its names except the first one.
        tied_keys = [key for group in tied_params for key in group[1:]]
    has_tied_params = len(tied_keys) > 0

    # A model without an attribute named after its `base_model_prefix` is treated as a base model
    is_base_model = not hasattr(model, model.base_model_prefix)

    # Ignore this for base models (BertModel, GPT2Model, etc.)
    if (not has_tied_params) and is_base_model:
        return []

    # Otherwise there is an attached head: keep the last registered parameter together with the tied weights.
    list_modules = list(model.named_parameters())
    list_untouched = list(tied_keys)
    # Guard against a model without any parameter (indexing `[-1]` would raise IndexError).
    if list_modules:
        last_param_name = list_modules[-1][0]
        if last_param_name not in tied_keys:
            list_untouched.append(last_param_name)

    # Turn parameter keys into module keys by dropping a single trailing ".weight" / ".bias". Only the suffix is
    # removed, so a module whose own name contains ".weight" or ".bias" is not mangled.
    names_to_remove = [".weight", ".bias"]
    filtered_module_names = []
    for name in list_untouched:
        for name_to_remove in names_to_remove:
            if name.endswith(name_to_remove):
                name = name[: -len(name_to_remove)]
                break
        filtered_module_names.append(name)

    return filtered_module_names