import copy

import torch
|
|
def bloom_model_postprocess_past_key_value(past_key_values):
    past_key_values = torch.cat(past_key_values)
    total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape
    keys = past_key_values[: total_layers // 2]
    keys = keys.transpose(2, 3).reshape(
        total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens
    )
    values = past_key_values[total_layers // 2 :]
    values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim)

    return tuple(zip(keys, values))
|
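# Shape sketch (illustrative only; the dummy sizes below are assumptions, not the
# values the library passes in at runtime):
#
#     n_layers, bsz, n_heads, n_virtual, head_dim = 2, 1, 4, 5, 8
#     pkv = (torch.randn(2 * n_layers, bsz, n_heads, n_virtual, head_dim),)
#     out = bloom_model_postprocess_past_key_value(pkv)
#     len(out)         # n_layers, i.e. one (key, value) pair per layer
#     out[0][0].shape  # (bsz * n_heads, head_dim, n_virtual)  -> key layout bloom expects
#     out[0][1].shape  # (bsz * n_heads, n_virtual, head_dim)  -> value layout bloom expects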
|
def prepare_model_for_int8_training(
    model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=["layer_norm"]
):
    r"""
    This method wraps the entire protocol for preparing a model before running a training. This includes:
        1- Cast the layernorm in fp32
        2- Make the output embedding layer require grads
        3- Upcast the lm head to fp32

    Args:
        model (`transformers.PreTrainedModel`):
            The loaded model from `transformers`
    """
    loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False)

    for name, param in model.named_parameters():
        # freeze the base model's parameters
        param.requires_grad = False

        if loaded_in_8bit:
            # cast layer norms to fp32 for numerical stability
            if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
                param.data = param.data.to(torch.float32)

    if loaded_in_8bit and use_gradient_checkpointing:
        # make sure the inputs to the frozen embedding layer still produce gradients,
        # otherwise gradient checkpointing breaks the backward pass
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # enable gradient checkpointing to trade compute for memory
        model.gradient_checkpointing_enable()

    if hasattr(model, output_embedding_layer_name):
        output_embedding_layer = getattr(model, output_embedding_layer_name)
        input_dtype = output_embedding_layer.weight.dtype

        class CastOutputToFloat(torch.nn.Sequential):
            r"""
            Manually cast to the expected dtype of the lm_head as sometimes there is a final layer norm that is cast
            in fp32
            """

            def forward(self, x):
                return super().forward(x.to(input_dtype)).to(torch.float32)

        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))

    return model
|
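# Typical usage, as a hedged sketch: assumes the model was loaded in 8-bit with
# `bitsandbytes` installed (the checkpoint name is just a placeholder).
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", load_in_8bit=True, device_map="auto")
#     model = prepare_model_for_int8_training(model)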
|
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids
        pad_token_id (`int`): The id of the `padding` token.
        decoder_start_token_id (`int`): The id of the `start` token.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
|
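# Worked example (derived from the code above): the sequence is shifted right, the
# decoder start token is prepended, and any -100 label padding becomes `pad_token_id`.
#
#     shift_tokens_right(torch.tensor([[5, 6, -100, -100]]), pad_token_id=0, decoder_start_token_id=2)
#     # tensor([[2, 5, 6, 0]])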
|
class ModulesToSaveWrapper(torch.nn.Module):
    def __init__(self, module_to_save, adapter_name):
        super().__init__()
        self.original_module = module_to_save
        self.modules_to_save = torch.nn.ModuleDict({})
        self.update(adapter_name)
        self.active_adapter = adapter_name

    def update(self, adapter_name):
        self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))

    def forward(self, *args, **kwargs):
        if self.active_adapter not in self.modules_to_save:
            return self.original_module(*args, **kwargs)
        return self.modules_to_save[self.active_adapter](*args, **kwargs)
|
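# Minimal usage sketch: wrap a layer, register a second adapter copy, and switch
# between them by changing `active_adapter`.
#
#     layer = torch.nn.Linear(4, 4)
#     wrapped = ModulesToSaveWrapper(layer, "default")
#     wrapped.update("other")            # deep-copies the original module under a new name
#     wrapped.active_adapter = "other"
#     y = wrapped(torch.randn(1, 4))     # runs the "other" copy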
|
def _get_submodules(model, key):
    parent = model.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = model.get_submodule(key)
    return parent, target, target_name
|
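# Example: for a dotted module name, this returns the parent module, the target
# module itself, and the last path component.
#
#     model = torch.nn.Sequential(torch.nn.Linear(2, 2))
#     parent, target, target_name = _get_submodules(model, "0")
#     # parent is `model`, target is the Linear layer, target_name == "0"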
|
def _freeze_adapter(model, adapter_name):
    for n, p in model.named_parameters():
        if adapter_name in n:
            p.requires_grad = False
|
|
def _set_trainable(model, adapter_name):
    key_list = [key for key, _ in model.named_modules()]
    for key in key_list:
        target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save)
        if target_module_found:
            parent, target, target_name = _get_submodules(model, key)
            if isinstance(target, ModulesToSaveWrapper):
                target.update(adapter_name)
            else:
                for param in target.parameters():
                    param.requires_grad = True
                setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name))
|
|
def _set_adapter(model, adapter_name):
    for module in model.modules():
        if isinstance(module, ModulesToSaveWrapper):
            module.active_adapter = adapter_name
|
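# Hedged sketch of how these two helpers fit together: `_set_trainable` expects the
# model to expose a `modules_to_save` collection of module-name suffixes (the `Toy`
# module and attribute below are only illustrative).
#
#     class Toy(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.classifier = torch.nn.Linear(4, 2)
#
#     model = Toy()
#     model.modules_to_save = ["classifier"]
#     _set_trainable(model, "default")  # wraps `classifier` in a ModulesToSaveWrapper
#     _set_adapter(model, "default")    # points every wrapper at the "default" copy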
|
def fsdp_auto_wrap_policy(model):
    import functools
    import os

    from accelerate import FullyShardedDataParallelPlugin
    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    def lambda_policy_fn(module):
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
            PrefixEncoder,
            PromptEncoder,
            PromptEmbedding,
            FullyShardedDataParallelPlugin.get_module_class_from_name(
                model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "")
            ),
        ),
    )

    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy
|
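# Hedged usage sketch with accelerate's FSDP integration (the environment variable
# value and the plugin wiring are assumptions that may differ between versions):
#
#     os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "LlamaDecoderLayer"
#     accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)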
|
def transpose(weight, fan_in_fan_out):
    return weight.T if fan_in_fan_out else weight
|
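# Example: layers such as GPT-2's `Conv1D` store their weight transposed relative to
# `nn.Linear`, which is what the `fan_in_fan_out` flag accounts for.
#
#     w = torch.randn(3, 5)
#     transpose(w, fan_in_fan_out=True).shape   # torch.Size([5, 3])
#     transpose(w, fan_in_fan_out=False).shape  # torch.Size([3, 5])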
|
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
    "t5": ["q", "v"],
    "mt5": ["q", "v"],
    "bart": ["q_proj", "v_proj"],
    "gpt2": ["c_attn"],
    "bloom": ["query_key_value"],
    "blip-2": ["q", "v", "q_proj", "v_proj"],
    "opt": ["q_proj", "v_proj"],
    "gptj": ["q_proj", "v_proj"],
    "gpt_neox": ["query_key_value"],
    "gpt_neo": ["q_proj", "v_proj"],
    "bert": ["query", "value"],
    "roberta": ["query", "value"],
    "xlm-roberta": ["query", "value"],
    "electra": ["query", "value"],
    "deberta-v2": ["query_proj", "value_proj"],
    "deberta": ["in_proj"],
    "layoutlm": ["query", "value"],
    "llama": ["q_proj", "v_proj"],
    "chatglm": ["query_key_value"],
}
|
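# Lookup sketch: the keys correspond to `config.model_type` in transformers, e.g.
#
#     from transformers import AutoConfig
#
#     config = AutoConfig.from_pretrained("gpt2")
#     TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[config.model_type]
#     # ['c_attn']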
|
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = {
    "t5": ["q", "k", "v", "o", "wi", "wo"],
    "mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
    "bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
    "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
    "roberta": ["query", "key", "value", "dense"],
    "deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"],
}
|
|
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
    "bloom": bloom_model_postprocess_past_key_value,
}
|
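# Dispatch sketch: only model types listed here need their prompt-generated cache
# reshaped before it can be fed back as `past_key_values`.
#
#     postprocess_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get("bloom")
#     # postprocess_fn is bloom_model_postprocess_past_key_value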
|
WEIGHTS_NAME = "adapter_model.bin"
CONFIG_NAME = "adapter_config.json"
|
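# An adapter checkpoint directory is expected to contain both files side by side,
# e.g. (illustrative layout):
#
#     my-adapter/
#     ├── adapter_config.json
#     └── adapter_model.bin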
|