| """ | |
| Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention | |
| """ | |
| from axolotl.monkeypatch.utils import ( | |
| patched_prepare_4d_causal_attention_mask, | |
| patched_prepare_4d_causal_attention_mask_for_sdpa, | |
| ) | |
def hijack_llama_prepare_4d_mask():
    """
    Swap the 4D causal attention mask builders for the patched versions.

    Both the canonical helpers in ``transformers.modeling_attn_mask_utils`` and the
    references already imported into ``transformers.models.llama.modeling_llama`` are
    overwritten, so the patch applies no matter which binding the attention code
    resolves at call time.
    """
    import transformers.modeling_attn_mask_utils
    import transformers.models.llama.modeling_llama

    transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask_for_sdpa
    )
    transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask_for_sdpa
    )
    transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask
    )
    transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask
    )
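

# Usage sketch (not part of the original module; the import path below is an
# assumption for illustration): the hijack is expected to run before the Llama
# model is constructed, so that transformers resolves the patched mask builders
# when attention is first computed.
#
#     from axolotl.monkeypatch.llama_attn_hijack_sdpa import hijack_llama_prepare_4d_mask
#     from transformers import LlamaForCausalLM
#
#     hijack_llama_prepare_4d_mask()
#     model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")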