| """ | |
| Modeling module for Mamba models | |
| """ | |
| import importlib | |
def check_mamba_ssm_installed() -> None:
    """Verify that the optional ``mamba_ssm`` package can be imported.

    Only the import spec is looked up; ``mamba_ssm`` itself is not imported
    by this check.

    Raises:
        ImportError: if no import spec for ``mamba_ssm`` can be found.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is bound as an attribute; import it explicitly so the lookup
    # below cannot fail with AttributeError.
    import importlib.util  # pylint: disable=import-outside-toplevel

    mamba_ssm_spec = importlib.util.find_spec("mamba_ssm")
    if mamba_ssm_spec is None:
        raise ImportError(
            "MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`"
        )
def fix_mamba_attn_for_loss():
    """Swap ``mamba_ssm``'s ``MambaLMHeadModel`` for this package's version.

    After confirming ``mamba_ssm`` is installed, the ``MambaLMHeadModel``
    attribute of ``mamba_ssm.models.mixer_seq_simple`` is rebound to the
    class defined in the sibling ``modeling_mamba`` module.

    Returns:
        The class now bound to ``mixer_seq_simple.MambaLMHeadModel``.

    Raises:
        ImportError: if ``mamba_ssm`` is not installed.
    """
    check_mamba_ssm_installed()

    # Imports are deferred until after the availability check above, so a
    # missing dependency surfaces through that check's error message.
    from mamba_ssm.models import mixer_seq_simple as upstream_module

    from .modeling_mamba import MambaLMHeadModel as patched_model_cls

    # Rebind the upstream module attribute to the local class.
    upstream_module.MambaLMHeadModel = patched_model_cls

    # Read the attribute back from the patched module, as callers receive
    # whatever the upstream module now exposes.
    return upstream_module.MambaLMHeadModel