| """ | |
| HF Transformers MambaConfig | |
| """ | |
from transformers import PretrainedConfig


class MambaConfig(PretrainedConfig):
    """
    Configuration class for the Mamba state space model.
    """

    model_type = "mamba"

    def __init__(
        self,
        vocab_size=50280,
        d_model=2560,
        n_layer=64,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        pad_vocab_size_multiple=8,
        pad_token_id=50277,
        bos_token_id=0,
        eos_token_id=0,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model  # hidden (embedding) dimension
        self.n_layer = n_layer  # number of Mamba blocks
        self.rms_norm = rms_norm  # use RMSNorm instead of LayerNorm
        self.residual_in_fp32 = residual_in_fp32  # keep the residual stream in fp32
        self.fused_add_norm = fused_add_norm  # use the fused add + norm kernel
        self.pad_vocab_size_multiple = pad_vocab_size_multiple  # pad vocab size up to this multiple
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
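

# A minimal usage sketch, not part of the original file: it assumes the standard
# PretrainedConfig serialization helpers (save_pretrained / from_pretrained) and
# uses hypothetical, smaller-than-default values purely for illustration.
if __name__ == "__main__":
    cfg = MambaConfig(d_model=768, n_layer=24)
    cfg.save_pretrained("./mamba-config")  # writes config.json to this directory
    reloaded = MambaConfig.from_pretrained("./mamba-config")
    print(reloaded.d_model, reloaded.n_layer)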