File size: 1,810 Bytes
e6d86b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from transformers import PretrainedConfig

class MambaSwarmConfig(PretrainedConfig):
    """Configuration object for the ``mamba_swarm`` model type.

    The special-token ids and ``tie_word_embeddings`` are forwarded to
    ``PretrainedConfig.__init__`` along with any extra ``**kwargs``. Every
    other argument is stored unchanged as an instance attribute of the same
    name.

    Args:
        num_mamba_encoders: Stored as ``self.num_mamba_encoders``
            (presumably the number of encoders in use -- confirm against
            the model code).
        max_mamba_encoders: Stored as ``self.max_mamba_encoders``
            (presumably an upper bound on the encoder count -- confirm).
        d_model: Stored as ``self.d_model``.
        d_state: Stored as ``self.d_state``.
        d_conv: Stored as ``self.d_conv``.
        expand_factor: Stored as ``self.expand_factor``.
        vocab_size: Stored as ``self.vocab_size``.
        max_sequence_length: Stored as ``self.max_sequence_length``.
        pad_token_id: Forwarded to the base class.
        bos_token_id: Forwarded to the base class.
        eos_token_id: Forwarded to the base class.
        tie_word_embeddings: Forwarded to the base class.
        use_cache: Stored as ``self.use_cache``.
        gating_config: Dict of gating options. When ``None``, a new dict
            with ``gating_type="learned"``, ``top_k=2`` and
            ``load_balancing_loss_coef=0.01`` is used. A supplied dict is
            stored as-is; it is neither copied nor merged with the defaults.
        routing_config: Dict of routing options. When ``None``, a new dict
            with ``routing_strategy="dynamic"`` and
            ``aggregation_method="weighted_average"`` is used. A supplied
            dict is stored as-is.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "mamba_swarm"

    def __init__(
        self,
        num_mamba_encoders=5,
        max_mamba_encoders=1000,
        d_model=768,
        d_state=16,
        d_conv=4,
        expand_factor=2,
        vocab_size=50257,
        max_sequence_length=2048,
        pad_token_id=50256,
        bos_token_id=50256,
        eos_token_id=50256,
        tie_word_embeddings=False,
        use_cache=True,
        gating_config=None,
        routing_config=None,
        **kwargs,
    ):
        # Base-class setup runs first, exactly as before, so the attributes
        # assigned below are written after whatever the parent sets.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Swarm sizing.
        self.num_mamba_encoders = num_mamba_encoders
        self.max_mamba_encoders = max_mamba_encoders

        # Per-encoder dimensions.
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand_factor = expand_factor

        # Vocabulary, sequence length and caching flag.
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.use_cache = use_cache

        # Gating: fall back to a freshly built default dict only when the
        # caller passed None (an empty dict is kept as given).
        self.gating_config = (
            gating_config
            if gating_config is not None
            else {
                "gating_type": "learned",
                "top_k": 2,
                "load_balancing_loss_coef": 0.01,
            }
        )

        # Routing: same None-only fallback as gating.
        self.routing_config = (
            routing_config
            if routing_config is not None
            else {
                "routing_strategy": "dynamic",
                "aggregation_method": "weighted_average",
            }
        )