# =============================================================================
# core/config.py
# =============================================================================
import torch
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class MambaConfig:
    # Model architecture
    vocab_size: int = 50257
    d_model: int = 1024
    n_layers: int = 12
    d_inner: int = 2048
    d_state: int = 16
    d_conv: int = 4
    dt_rank: Optional[int] = None
    bias: bool = False
    conv_bias: bool = True
    
    # Training
    max_seq_len: int = 2048
    batch_size: int = 8
    learning_rate: float = 1e-4
    weight_decay: float = 0.1
    warmup_steps: int = 1000
    max_steps: int = 100000
    
    # Swarm specific
    num_specialists: int = 100
    specialist_domains: Optional[List[str]] = None
    shared_embedding: bool = True
    hierarchical_sharing: bool = True
    
    # Hardware
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dtype: torch.dtype = torch.float16
    
    def __post_init__(self):
        if self.dt_rank is None:
            self.dt_rank = max(16, self.d_model // 16)
        if self.specialist_domains is None:
            self.specialist_domains = [f"domain_{i}" for i in range(self.num_specialists)]
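

# -----------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of the module): shows how __post_init__
# derives dt_rank and fills specialist_domains when both are left as None.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = MambaConfig(d_model=512, num_specialists=4)
    assert cfg.dt_rank == 32                  # max(16, 512 // 16)
    print(cfg.device, cfg.dtype)              # e.g. "cuda" torch.float16
    print(cfg.specialist_domains)             # ['domain_0', ..., 'domain_3']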