Commit 6db4d44 · verified · committed by Debito · 1 parent: 1535ec7

Upload 2 files

Files changed (2):
  1. config.json +32 -0
  2. config.py +44 -0
config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "architectures": ["MambaSwarmForCausalLM"],
+   "auto_map": {
+     "AutoConfig": "configuration_mamba_swarm.MambaSwarmConfig",
+     "AutoModelForCausalLM": "modeling_mamba_swarm.MambaSwarmForCausalLM"
+   },
+   "model_type": "mamba_swarm",
+   "num_mamba_encoders": 5,
+   "max_mamba_encoders": 1000,
+   "d_model": 768,
+   "d_state": 16,
+   "d_conv": 4,
+   "expand_factor": 2,
+   "vocab_size": 50257,
+   "max_sequence_length": 2048,
+   "pad_token_id": 50256,
+   "bos_token_id": 50256,
+   "eos_token_id": 50256,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float16",
+   "transformers_version": "4.36.0",
+   "use_cache": true,
+   "gating_config": {
+     "gating_type": "learned",
+     "top_k": 2,
+     "load_balancing_loss_coef": 0.01
+   },
+   "routing_config": {
+     "routing_strategy": "dynamic",
+     "aggregation_method": "weighted_average"
+   }
+ }
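
Because auto_map points at custom classes (configuration_mamba_swarm.MambaSwarmConfig, modeling_mamba_swarm.MambaSwarmForCausalLM) rather than classes bundled with transformers, the checkpoint can only be loaded with trust_remote_code=True. A minimal loading sketch, assuming the two custom modules sit next to config.json in the repo (the repo id below is a placeholder, not taken from the commit):

    from transformers import AutoConfig

    # Placeholder repo id; configuration_mamba_swarm.py must live in the same
    # repo so that auto_map can resolve the custom config class.
    config = AutoConfig.from_pretrained("Debito/mamba-swarm", trust_remote_code=True)

    print(config.model_type)          # "mamba_swarm"
    print(config.num_mamba_encoders)  # 5 encoders active, out of a cap of 1000

The gating_config block describes a mixture-of-experts-style setup: a learned gate routes each input to the top_k = 2 encoders, and load_balancing_loss_coef = 0.01 weights an auxiliary loss that keeps the encoders evenly utilized.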
config.py ADDED
@@ -0,0 +1,44 @@
+ # =============================================================================
+ # core/config.py
+ # =============================================================================
+ import torch
+ from dataclasses import dataclass
+ from typing import List, Optional
+
+ @dataclass
+ class MambaConfig:
+     # Model architecture
+     vocab_size: int = 50257
+     d_model: int = 1024
+     n_layers: int = 12
+     d_inner: int = 2048
+     d_state: int = 16
+     d_conv: int = 4
+     dt_rank: Optional[int] = None  # derived in __post_init__ when left unset
+     bias: bool = False
+     conv_bias: bool = True
+
+     # Training
+     max_seq_len: int = 2048
+     batch_size: int = 8
+     learning_rate: float = 1e-4
+     weight_decay: float = 0.1
+     warmup_steps: int = 1000
+     max_steps: int = 100000
+
+     # Swarm specific
+     num_specialists: int = 100
+     specialist_domains: Optional[List[str]] = None  # filled in __post_init__
+     shared_embedding: bool = True
+     hierarchical_sharing: bool = True
+
+     # Hardware
+     device: str = "cuda" if torch.cuda.is_available() else "cpu"
+     dtype: torch.dtype = torch.float16
+
+     def __post_init__(self):
+         if self.dt_rank is None:
+             self.dt_rank = max(16, self.d_model // 16)
+         if self.specialist_domains is None:
+             self.specialist_domains = [f"domain_{i}" for i in range(self.num_specialists)]
+
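
For reference, a short usage sketch showing the defaults that __post_init__ derives; the values follow directly from the dataclass above, and the import path assumes core/ is a package on PYTHONPATH:

    from core.config import MambaConfig

    cfg = MambaConfig()                # all defaults
    print(cfg.dt_rank)                 # max(16, 1024 // 16) = 64
    print(cfg.specialist_domains[:2])  # ['domain_0', 'domain_1']

    # Overrides propagate into the derived fields.
    small = MambaConfig(d_model=256, num_specialists=4)
    print(small.dt_rank)               # max(16, 256 // 16) = 16
    print(small.specialist_domains)    # ['domain_0', 'domain_1', 'domain_2', 'domain_3']

Note that the device default is evaluated once, when the class is defined, so every instance created in the same process inherits the same device string unless it is overridden explicitly.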