Sparse High Rank Adapters
Sparse High Rank Adapters (SHiRA) are an alternative type of adapter that has been found to have significant advantages over low rank adapters. Specifically, SHiRA achieves better accuracy than LoRA on a variety of vision and language tasks. It also offers simpler, higher-quality multi-adapter fusion by significantly reducing concept loss, a common problem with low rank adapters. Rather than adding low rank matrices, SHiRA adapts the model to a task by directly finetuning a small number of the base model's parameters.
SHiRA currently has the following constraint:
- Only nn.Linear layers are supported.
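To make the idea concrete, here is a minimal sketch in plain PyTorch of a SHiRA-style wrapper around an nn.Linear layer, assuming a random mask and zero-initialized updates (mirroring the defaults documented for ShiraConfig). The class name SparseHighRankLinear and its seed argument are hypothetical and for illustration only; this is not PEFT's implementation.

```python
import torch
import torch.nn as nn


class SparseHighRankLinear(nn.Module):
    """Hypothetical SHiRA-style wrapper around nn.Linear (illustration only).

    The base weight stays frozen; only r * (m + n) randomly chosen entries
    receive a trainable update, stored as a flat vector of values.
    """

    def __init__(self, base_layer: nn.Linear, r: int = 32, seed: int = 0):
        super().__init__()
        self.base_layer = base_layer
        for p in self.base_layer.parameters():
            p.requires_grad_(False)  # freeze the pretrained weights
        weight = base_layer.weight
        m, n = weight.shape
        num_shira_parameters = r * (m + n)
        if num_shira_parameters > m * n:
            raise ValueError("r * (m + n) exceeds the number of weight entries")
        gen = torch.Generator().manual_seed(seed)
        # Flat indices of the weight entries that are allowed to change.
        idx = torch.randperm(m * n, generator=gen)[:num_shira_parameters]
        self.register_buffer("indices", idx.to(weight.device))
        # Zero init: the wrapped layer initially matches the base layer exactly.
        self.values = nn.Parameter(
            torch.zeros(num_shira_parameters, dtype=weight.dtype, device=weight.device)
        )

    def delta_weight(self) -> torch.Tensor:
        # Scatter the trainable values into an otherwise all-zero m x n matrix.
        w = self.base_layer.weight
        delta = torch.zeros(w.numel(), dtype=w.dtype, device=w.device)
        delta = delta.scatter(0, self.indices, self.values)  # differentiable w.r.t. values
        return delta.view_as(w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.base_layer.weight + self.delta_weight()
        return nn.functional.linear(x, weight, self.base_layer.bias)
```

Only the values vector is trainable, and it holds r(m + n) entries, the same budget as a rank-r LoRA adapter; the update itself is not constrained to low rank.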
The abstract from the paper is:
Low Rank Adaptation (LoRA) has gained massive attention in the recent generative AI research. One of the main advantages of LoRA is its ability to be fused with pretrained models, adding no overhead during inference. However, from a mobile deployment standpoint, we can either avoid inference overhead in the fused mode but lose the ability to switch adapters rapidly, or suffer significant (up to 30% higher) inference latency while enabling rapid switching in the unfused mode. LoRA also exhibits concept-loss when multiple adapters are used concurrently. In this paper, we propose Sparse High Rank Adapters (SHiRA), a new paradigm which incurs no inference overhead, enables rapid switching, and significantly reduces concept-loss. Specifically, SHiRA can be trained by directly tuning only 1-2% of the base model weights while leaving others unchanged. This results in a highly sparse adapter which can be switched directly in the fused mode. We further provide theoretical and empirical insights on how high sparsity in SHiRA can aid multi-adapter fusion by reducing concept loss. Our extensive experiments on LVMs and LLMs demonstrate that finetuning only a small fraction of the parameters in the base model significantly outperforms LoRA while enabling both rapid switching and multi-adapter fusion. Finally, we provide a latency- and memory-efficient SHiRA implementation based on Parameter-Efficient Finetuning (PEFT) Library which trains at nearly the same speed as LoRA while consuming up to 16% lower peak GPU memory, thus making SHiRA easy to adopt for practical use cases. To demonstrate rapid switching benefits during inference, we show that loading SHiRA on a base model can be 5x-16x faster than LoRA fusion on a CPU.
ShiraConfig
class peft.ShiraConfig(
    task_type: typing.Union[str, peft.utils.peft_types.TaskType, NoneType] = None,
    peft_type: typing.Union[str, peft.utils.peft_types.PeftType, NoneType] = None,
    auto_mapping: typing.Optional[dict] = None,
    base_model_name_or_path: typing.Optional[str] = None,
    revision: typing.Optional[str] = None,
    inference_mode: bool = False,
    r: int = 32,
    mask_type: Literal['random'] = 'random',
    random_seed: Optional[int] = None,
    target_modules: Optional[Union[list[str], str]] = None,
    fan_in_fan_out: bool = False,
    init_weights: bool = True,
    modules_to_save: Optional[list[str]] = None,
)
Parameters
- r (int, optional, defaults to 32): For a given target module with original weight dimensions m x n, the number of SHiRA parameters is r(m + n), the same as for a rank-r LoRA adapter. However, SHiRA is a high rank adapter: setting r does not restrict the rank of the adapter to r.
- mask_type (str, defaults to "random"): Type of mask function. Defaults to a random sparse mask. A user-defined mask_fn can also be supplied by instantiating config = ShiraConfig(...) and then setting config.mask_fn = <your custom mask function>. For a pretrained weight of shape m x n, the custom mask function must return a single binary (0 or 1) mask of shape m x n containing num_shira_parameters = r(m + n) ones for linear layers. The mask must have the same device and dtype as the base layer's weight. See mask_functions.py for more details and for the default random sparse mask implementation.
- random_seed (int, optional, defaults to None): Random seed for the torch generator used by random_mask.
- target_modules (Union[List[str], str]): List of module names, or a regex expression of module names, to replace with SHiRA. For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. Only linear layers are supported.
- fan_in_fan_out (bool, defaults to False): Set this to True if the layer to replace stores its weight as (fan_in, fan_out). For example, GPT-2 uses Conv1D, which stores weights as (fan_in, fan_out), so this should be set to True.
- init_weights (bool, defaults to True): Initialize SHiRA weights to zero. If set to False, SHiRA weights are initialized to randn values instead; this is used only for testing.
- modules_to_save (List[str]): List of modules apart from SHiRA layers to be set as trainable and saved in the final checkpoint.
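The r parameter sets a parameter budget rather than a rank. The minimal PyTorch sketch below compares a rank-r LoRA-style update with a random sparse update holding the same r(m + n) entries; the 256 x 256 weight, r = 8, and the helper name shira_num_params are assumptions chosen purely for illustration.

```python
import torch


def shira_num_params(m: int, n: int, r: int) -> int:
    """Number of SHiRA parameters for an m x n weight: r * (m + n)."""
    return r * (m + n)


m, n, r = 256, 256, 8
k = shira_num_params(m, n, r)
print(k, f"{k / (m * n):.2%}")  # 4096 6.25%

gen = torch.Generator().manual_seed(0)

# Rank-r LoRA-style update: B @ A with B of shape m x r and A of shape r x n,
# which also uses r * m + r * n = r * (m + n) values.
lora_delta = torch.randn(m, r, generator=gen) @ torch.randn(r, n, generator=gen)

# SHiRA-style update with the same budget: k random entries set to random values.
idx = torch.randperm(m * n, generator=gen)[:k]
shira_delta = torch.zeros(m * n)
shira_delta[idx] = torch.randn(k, generator=gen)
shira_delta = shira_delta.view(m, n)

print(torch.linalg.matrix_rank(lora_delta).item())   # at most r = 8
print(torch.linalg.matrix_rank(shira_delta).item())  # typically far above r
```

Both updates carry 4096 values, but only the LoRA-style product is capped at rank r.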
This is the configuration class to store the configuration of a ShiraModel.
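The mask_type description above allows a user-defined mask_fn. As a hedged illustration, the function below selects the largest-magnitude weights instead of random positions. The name magnitude_mask is hypothetical, the selection rule is an illustrative choice (not part of SHiRA or PEFT), and the argument list (base_layer, r, **kwargs) is an assumption, so check mask_functions.py for the exact signature PEFT calls.

```python
import torch
import torch.nn as nn


def magnitude_mask(base_layer: nn.Module, r: int, **kwargs) -> torch.Tensor:
    """Hypothetical custom SHiRA mask: keep the r * (m + n) largest-magnitude weights.

    The (base_layer, r, **kwargs) signature is an assumption; check
    mask_functions.py for the arguments PEFT actually passes to mask_fn.
    """
    weight = base_layer.weight
    m, n = weight.shape
    num_shira_parameters = r * (m + n)
    if num_shira_parameters > weight.numel():
        raise ValueError("r * (m + n) exceeds the number of weight entries")
    # Flat indices of the largest-magnitude entries of the pretrained weight.
    idx = torch.topk(weight.detach().abs().flatten(), k=num_shira_parameters).indices
    # Binary 0/1 mask on the same device and with the same dtype as the weight.
    mask = torch.zeros(weight.numel(), dtype=weight.dtype, device=weight.device)
    mask[idx] = 1
    return mask.view(m, n)
```

Following the mask_type description, it would be attached by creating config = ShiraConfig(r=32) and then setting config.mask_fn = magnitude_mask.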
ShiraModel
class peft.ShiraModel(
    model,
    peft_config: Union[PeftConfig, dict[str, PeftConfig]],
    adapter_name: str,
    low_cpu_mem_usage: bool = False,
    state_dict: Optional[dict[str, torch.Tensor]] = None,
) → torch.nn.Module
Parameters
- model (PreTrainedModel): The model to be adapted.
- peft_config (ShiraConfig): The configuration of the SHiRA model.
- adapter_name (str): The name of the adapter, defaults to "default".
Returns
torch.nn.Module: The SHiRA model.
Creates a Sparse High Rank Adapter (SHiRA) Model from a pretrained model.
Example:
>>> from transformers import AutoModelForCausalLM
>>> from peft import ShiraConfig, get_peft_model
>>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
>>> config = ShiraConfig(r=32)
>>> model = get_peft_model(base_model, config)
Attributes:
- model (PreTrainedModel): The model to be adapted.
- peft_config (ShiraConfig): The configuration of the SHiRA model.
delete_adapter(adapter_name: str)
Deletes an existing adapter.
merge_and_unload(progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None)
Parameters
- progressbar (bool, defaults to False): Whether to show a progress bar indicating the unload and merge process.
- safe_merge (bool, defaults to False): Whether to activate the safe merging check, which checks for potential NaNs in the adapter weights.
- adapter_names (list[str], optional): The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults to None.
This method merges the SHiRA layers into the base model. This is needed if you want to use the base model as a standalone model.
Example:
>>> from transformers import AutoModelForCausalLM
>>> from peft import ShiraConfig, get_peft_model
>>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
>>> config = ShiraConfig(r=32)
>>> model = get_peft_model(base_model, config)
>>> ## [Train the adapter] ##
>>> merged_model = model.merge_and_unload()
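Because a SHiRA update touches only r(m + n) entries, merging amounts to adding a sparse set of values to the base weight. The sketch below assumes the adapter is stored as flat indices plus values; merge_sparse_update and unmerge_sparse_update are hypothetical helpers, not PEFT's API, and the NaN check only loosely mirrors what safe_merge is documented to do.

```python
import torch


def merge_sparse_update(weight: torch.Tensor, indices: torch.Tensor,
                        values: torch.Tensor, safe_merge: bool = False) -> torch.Tensor:
    """Return a copy of weight with a sparse update added (hypothetical helper)."""
    merged = weight.clone().flatten()
    merged[indices] += values  # only the r * (m + n) selected entries change
    if safe_merge and torch.isnan(merged).any():
        # Loosely mirrors the documented safe_merge NaN check; not PEFT's code.
        raise ValueError("NaNs detected in the merged weight")
    return merged.view_as(weight)


def unmerge_sparse_update(merged: torch.Tensor, indices: torch.Tensor,
                          values: torch.Tensor) -> torch.Tensor:
    """Subtract the same sparse update to recover the original weight."""
    restored = merged.clone().flatten()
    restored[indices] -= values
    return restored.view_as(merged)


gen = torch.Generator().manual_seed(0)
m, n, r = 8, 16, 2
weight = torch.randn(m, n, generator=gen)
k = r * (m + n)  # 48 sparse entries
indices = torch.randperm(m * n, generator=gen)[:k]
values = torch.randn(k, generator=gen)

merged = merge_sparse_update(weight, indices, values, safe_merge=True)
restored = unmerge_sparse_update(merged, indices, values)
print(int((merged != weight).sum()))  # at most 48 entries differ
```

Switching adapters in the fused mode then amounts to unmerging one set of sparse values and merging another, which is the rapid-switching property the abstract describes.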
unload()
Gets back the base model by removing all the SHiRA modules without merging. This gives back the original base model.