from transformers import PretrainedConfig


class MambaSwarmConfig(PretrainedConfig):
    """Configuration for a swarm of Mamba encoders with gated, dynamic routing."""

    model_type = "mamba_swarm"

    def __init__(
        self,
        num_mamba_encoders=5,       # number of active Mamba experts in the swarm
        max_mamba_encoders=1000,    # upper bound on the swarm size
        d_model=768,                # hidden (embedding) dimension
        d_state=16,                 # SSM state dimension per channel
        d_conv=4,                   # local convolution kernel width
        expand_factor=2,            # inner-dimension expansion (d_inner = expand_factor * d_model)
        vocab_size=50257,           # GPT-2 vocabulary size
        max_sequence_length=2048,
        pad_token_id=50256,         # GPT-2's <|endoftext|> token reused for pad/bos/eos
        bos_token_id=50256,
        eos_token_id=50256,
        tie_word_embeddings=False,
        use_cache=True,
        gating_config=None,
        routing_config=None,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.num_mamba_encoders = num_mamba_encoders
        self.max_mamba_encoders = max_mamba_encoders
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand_factor = expand_factor
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.use_cache = use_cache

        # Default gating configuration: learned top-k gating with a
        # load-balancing auxiliary loss to keep expert usage even.
        if gating_config is None:
            gating_config = {
                "gating_type": "learned",
                "top_k": 2,
                "load_balancing_loss_coef": 0.01,
            }
        self.gating_config = gating_config

        # Default routing configuration: dynamic routing with weighted-average
        # aggregation of the selected experts' outputs.
        if routing_config is None:
            routing_config = {
                "routing_strategy": "dynamic",
                "aggregation_method": "weighted_average",
            }
        self.routing_config = routing_config
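

# A minimal usage sketch (not part of the original file): instantiating the
# config and round-tripping it through the standard PretrainedConfig
# serialization helpers from transformers. The directory name is arbitrary.
if __name__ == "__main__":
    config = MambaSwarmConfig(num_mamba_encoders=8)
    print(config.gating_config["top_k"])  # -> 2 (default learned top-k gating)

    config.save_pretrained("./mamba_swarm")  # writes config.json to the directory
    reloaded = MambaSwarmConfig.from_pretrained("./mamba_swarm")
    assert reloaded.num_mamba_encoders == 8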