from copy import deepcopy

from .import_utils import is_accelerate_available, is_bitsandbytes_available


if is_bitsandbytes_available():
    import bitsandbytes as bnb
    import torch
    import torch.nn as nn

if is_accelerate_available():
    from accelerate import init_empty_weights
    from accelerate.utils import find_tied_parameters


def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
    """
    A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
    `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
    function is adapted from the `set_module_tensor_to_device` function of `accelerate`, extended to support the
    `Int8Params` class from `bitsandbytes`.

    Args:
        module (`torch.nn.Module`):
            The module in which the tensor we want to move lives.
        tensor_name (`str`):
            The full name of the parameter/buffer.
        device (`int`, `str` or `torch.device`):
            The device on which to set the tensor.
        value (`torch.Tensor`, *optional*):
            The value of the tensor (useful when going from the meta device to any other device).
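
    Example (a minimal sketch, assuming `bitsandbytes` is installed with CUDA support; the toy module and tensor
    value below are purely illustrative):

        >>> import torch
        >>> import torch.nn as nn
        >>> linear = nn.Linear(2, 2)
        >>> # Moves `linear.weight` onto GPU 0, replacing it with the provided value.
        >>> set_module_8bit_tensor_to_device(linear, "weight", 0, value=torch.ones(2, 2))
        >>> linear.weight.device
        device(type='cuda', index=0)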
    """

    if "." in tensor_name:
        # Walk down to the submodule that actually owns the tensor when the name is dotted.
        splits = tensor_name.split(".")
        for split in splits[:-1]:
            new_module = getattr(module, split)
            if new_module is None:
                raise ValueError(f"{module} has no attribute {split}.")
            module = new_module
        tensor_name = splits[-1]

    if tensor_name not in module._parameters and tensor_name not in module._buffers:
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
    is_buffer = tensor_name in module._buffers
    old_value = getattr(module, tensor_name)

    if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put it on {device}.")

    # `Int8Params` from `bitsandbytes` exposes a `has_fp16_weights` attribute; regular parameters and buffers do not.
    if is_buffer:
        has_fp16_weights = None
    else:
        has_fp16_weights = getattr(module._parameters[tensor_name], "has_fp16_weights", None)

    if has_fp16_weights is not None:
        param = module._parameters[tensor_name]
        if param.device.type != "cuda":
            if value is None:
                new_value = old_value.to(device)
            elif isinstance(value, torch.Tensor):
                new_value = value.to("cpu")
                if value.dtype == torch.int8:
                    raise ValueError(
                        "You cannot load weights that are saved in int8 using `load_in_8bit=True`, make sure you are"
                        " using `load_in_8bit=True` on float32/float16/bfloat16 weights."
                    )
            else:
                new_value = torch.tensor(value, device="cpu")
            # Wrapping in `Int8Params` quantizes the weight to int8 when it is moved to the target (CUDA) device.
            new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device)
            module._parameters[tensor_name] = new_value
    else:
        if value is None:
            new_value = old_value.to(device)
        elif isinstance(value, torch.Tensor):
            new_value = value.to(device)
        else:
            new_value = torch.tensor(value, device=device)

        if is_buffer:
            module._buffers[tensor_name] = new_value
        else:
            new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
            module._parameters[tensor_name] = new_value


def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None):
    """
    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bitLt` modules from the
    `bitsandbytes` library. This enables running your models using mixed int8 precision as described by the paper
    `LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the
    correct CUDA version of your hardware is installed before running this function. `pip install -i
    https://test.pypi.org/simple/ bitsandbytes`

    The function is run recursively and replaces all `torch.nn.Linear` modules except for the `lm_head`, which should
    be kept as a `torch.nn.Linear` module. The replacement is done under the `init_empty_weights` context manager, so
    no CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating
    a matrix multiplication into two streams: (1) a systematic feature outlier stream matrix multiplied in fp16
    (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
    predictive degradation is possible for very large models (>=176B parameters).

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        threshold (`float`, *optional*, defaults to 6.0):
            `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
            `6.0` as described by the paper.
        modules_to_not_convert (`List[str]`, *optional*, defaults to `["lm_head"]`):
            Names of the modules to not convert into `Linear8bitLt`. In practice we keep the `lm_head` in full
            precision for numerical stability reasons.
        current_key_name (`List[str]`, *optional*):
            A list to track the current key of the recursion. This is used to check whether the current key (or part
            of it) is not in the list of modules to not convert (for instance, modules that are offloaded to `cpu` or
            `disk`).
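
    Example (a minimal sketch, assuming `bitsandbytes` is installed with CUDA support; the toy model below is purely
    illustrative):

        >>> import torch.nn as nn
        >>> model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
        >>> model = replace_8bit_linear(model)
        >>> type(model[0]).__name__  # the replaced modules are created empty and expect weights to be loaded later
        'Linear8bitLt'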
    """
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if len(list(module.children())) > 0:
            # Recurse into submodules that have children of their own.
            replace_8bit_linear(module, threshold, modules_to_not_convert, current_key_name)

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Check that the full key (e.g. `model.decoder.layers.0.fc1`) does not match a module to keep untouched.
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                with init_empty_weights():
                    model._modules[name] = bnb.nn.Linear8bitLt(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=threshold,
                    )

                # The int8 module is inference-only, so freeze it to avoid unexpected gradient errors.
                model._modules[name].requires_grad_(False)

        # Remove the last key before moving on to the next sibling in the recursion.
        current_key_name.pop(-1)
    return model


def get_keys_to_not_convert(model):
    r"""
    A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM
    modules we may want to keep the `lm_head` in full precision for numerical stability reasons. For other
    architectures, we want to keep the tied weights of the model. The function returns a list of the keys of the
    modules to not convert to int8.

    Parameters:
        model (`torch.nn.Module`):
            Input model
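
    Example (a minimal sketch, assuming `transformers` is available and the `gpt2` checkpoint can be downloaded; the
    exact keys returned depend on the architecture):

        >>> from transformers import AutoModelForCausalLM
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> keys = get_keys_to_not_convert(model)  # typically includes the output embedding, e.g. `lm_head`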
    """
    # Create a copy of the model, tie its weights, then check whether it contains tied parameters.
    tied_model = deepcopy(model)
    tied_model.tie_weights()

    tied_params = find_tied_parameters(tied_model)
    # Depending on the `accelerate` version, `find_tied_parameters` returns a dict or a list of lists.
    if isinstance(tied_params, dict):
        tied_keys = list(tied_params.values())
    else:
        tied_keys = sum([x[1:] for x in tied_params], [])
    has_tied_params = len(tied_keys) > 0

    # If the model has no attribute named after `base_model_prefix`, it is the base model itself (no head on top).
    is_base_model = not hasattr(model, model.base_model_prefix)

    # A base model without tied weights has nothing to keep in full precision.
    if (not has_tied_params) and is_base_model:
        return []

    # Otherwise, keep the last module (usually the output embedding / head) in full precision.
    list_modules = list(model.named_parameters())
    list_last_module = [list_modules[-1][0]]

    # Add the last module together with the tied weights, avoiding duplicates.
    intersection = set(list_last_module) - set(tied_keys)
    list_untouched = tied_keys + list(intersection)

    # Strip the `.weight` / `.bias` suffixes to recover module names from parameter names.
    names_to_remove = [".weight", ".bias"]
    filtered_module_names = []
    for name in list_untouched:
        for name_to_remove in names_to_remove:
            if name_to_remove in name:
                name = name.replace(name_to_remove, "")
        filtered_module_names.append(name)

    return filtered_module_names