| """ | |
| Patches to support multipack for mixtral | |
| """ | |
import transformers


def replace_mixtral_attn_with_multipack_flash_attn():
    # Imported lazily so the patched implementations are only loaded when
    # the patch is actually applied.
    from .modeling_mixtral import (
        MixtralMultipackFlashAttention2,
        mixtral_decoder_layer_forward,
        mixtral_model_forward,
    )

    # Swap in the multipack-aware forward passes for the decoder layer and
    # the base model.
    transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
        mixtral_decoder_layer_forward
    )
    transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
        mixtral_model_forward
    )
    # Route the "flash_attention_2" implementation to the multipack variant.
    transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
        "flash_attention_2"
    ] = MixtralMultipackFlashAttention2
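

# A minimal usage sketch (an assumption, not part of this module): the patch
# has to run before the model is instantiated, so that the patched
# MIXTRAL_ATTENTION_CLASSES entry is the one picked up when the config selects
# flash_attention_2. Something like:
#
#     from transformers import AutoModelForCausalLM
#
#     replace_mixtral_attn_with_multipack_flash_attn()
#     model = AutoModelForCausalLM.from_pretrained(
#         "mistralai/Mixtral-8x7B-v0.1",
#         attn_implementation="flash_attention_2",
#     )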