from transformers import PretrainedConfig


class MambaSwarmConfig(PretrainedConfig):
    """Configuration for a swarm of Mamba encoders with gating and routing settings."""

    model_type = "mamba_swarm"

    def __init__(
        self,
        num_mamba_encoders=5,
        max_mamba_encoders=1000,
        d_model=768,
        d_state=16,
        d_conv=4,
        expand_factor=2,
        vocab_size=50257,
        max_sequence_length=2048,
        pad_token_id=50256,
        bos_token_id=50256,
        eos_token_id=50256,
        tie_word_embeddings=False,
        use_cache=True,
        gating_config=None,
        routing_config=None,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Swarm size: encoders active by default and the upper bound the swarm may scale to
        self.num_mamba_encoders = num_mamba_encoders
        self.max_mamba_encoders = max_mamba_encoders

        # Per-encoder Mamba block dimensions
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand_factor = expand_factor

        # Tokenizer and sequence settings
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.use_cache = use_cache

        # Default gating configuration: learned top-k gating with a load-balancing auxiliary loss
        if gating_config is None:
            gating_config = {
                "gating_type": "learned",
                "top_k": 2,
                "load_balancing_loss_coef": 0.01,
            }
        self.gating_config = gating_config

        # Default routing configuration: dynamic routing with weighted-average aggregation
        if routing_config is None:
            routing_config = {
                "routing_strategy": "dynamic",
                "aggregation_method": "weighted_average",
            }
        self.routing_config = routing_config
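

# Usage sketch (illustrative, not part of the original module): since MambaSwarmConfig
# subclasses PretrainedConfig, it inherits save_pretrained / from_pretrained, so a config
# can be round-tripped through a config.json on disk. The directory path below is an
# arbitrary example.
if __name__ == "__main__":
    config = MambaSwarmConfig(num_mamba_encoders=8, d_model=512)
    print(config.model_type)       # "mamba_swarm"
    print(config.gating_config)    # default learned top-2 gating

    config.save_pretrained("./mamba_swarm_config")  # writes config.json
    reloaded = MambaSwarmConfig.from_pretrained("./mamba_swarm_config")
    assert reloaded.d_model == 512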