from dataclasses import dataclass


@dataclass
class BitsAndBytesConfig:
""" |
|
This is a wrapper class about all possible attributes and features that you can play with a model that has been |
|
loaded using `bitsandbytes`. |
|
|
|
This replaces `load_in_8bit` therefore both options are mutually exclusive. |
|
|
|
For now, only arguments that are relative to `LLM.int8()` are supported, therefore the arguments are all termed as |
|
`llm_int8_*`. If more methods are added to `bitsandbytes`, then more arguments will be added to this class. |
|
|
|
Args: |
|
load_in_8bit (`bool`, *optional*, defaults to `False`): |
|
This flag is used to enable 8-bit quantization with LLM.int8(). |
|
llm_int8_threshold (`float`, *optional*, defaults to 6): |
|
This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix |
|
Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value |
|
that is above this threshold will be considered an outlier and the operation on those values will be done |
|
in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but |
|
there are some exceptional systematic outliers that are very differently distributed for large models. |
|
These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of |
|
magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, |
|
but a lower threshold might be needed for more unstable models (small models, fine-tuning). |
|
llm_int8_skip_modules (`List[str]`, *optional*): |
|
An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as |
|
Jukebox that has several heads in different places and not necessarily at the last position. For example |
|
for `CausalLM` models, the last `lm_head` is kept in its original `dtype`. |
|
llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`): |
|
This flag is used for advanced use cases and users that are aware of this feature. If you want to split |
|
your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use |
|
this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8 |
|
operations will not be run on CPU. |
|
""" |
    def __init__(
        self,
        load_in_8bit=False,
        llm_int8_threshold=6.0,
        llm_int8_skip_modules=None,
        llm_int8_enable_fp32_cpu_offload=False,
    ):
        self.load_in_8bit = load_in_8bit
        self.llm_int8_threshold = llm_int8_threshold
        self.llm_int8_skip_modules = llm_int8_skip_modules
        self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload

        self.post_init()

    def post_init(self):
        r"""
        Safety checker that the arguments passed to the constructor are of the correct type.
        """
        if not isinstance(self.llm_int8_threshold, float):
            raise ValueError("llm_int8_threshold must be a float")

        if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list):
            raise ValueError("llm_int8_skip_modules must be a list of strings")

        if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool):
            raise ValueError("llm_int8_enable_fp32_cpu_offload must be a boolean")

    @classmethod
    def from_dict(cls, config_dict, return_unused_kwargs, **kwargs):
        """
        Instantiates a [`BitsAndBytesConfig`] from a Python dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
            return_unused_kwargs (`bool`):
                Whether to also return the keyword arguments that were not used to update the configuration object.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            [`BitsAndBytesConfig`]: The configuration object instantiated from those parameters.
        """
        config = cls(**config_dict)

        # Consume the extra kwargs that match existing attributes; leave the rest untouched.
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        if return_unused_kwargs:
            return config, kwargs
        else:
            return config
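

# Minimal usage sketch (illustrative, not part of the original module): builds a config from a dictionary via
# `from_dict` and shows how unused keyword arguments are handed back when `return_unused_kwargs` is `True`.
# The dictionary keys mirror the constructor arguments; `some_unused_kwarg` is a hypothetical extra key used
# only for illustration.
if __name__ == "__main__":
    config_dict = {
        "load_in_8bit": True,
        "llm_int8_threshold": 6.0,
        "llm_int8_skip_modules": ["lm_head"],
        "llm_int8_enable_fp32_cpu_offload": False,
    }
    config, unused = BitsAndBytesConfig.from_dict(
        config_dict, return_unused_kwargs=True, some_unused_kwarg=123
    )
    print(config.llm_int8_skip_modules)  # ['lm_head']
    print(unused)  # {'some_unused_kwarg': 123}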