Debito commited on
Commit
3fb2fb4
·
verified ·
1 Parent(s): bb70da7

Upload 3 files

Browse files
utils/conv_layer.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # utils/conv_layer.py
3
+ # =============================================================================
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ class Mamba1DConv(nn.Module):
8
+ def __init__(self, d_inner: int, d_conv: int = 4, bias: bool = True):
9
+ super().__init__()
10
+ self.d_conv = d_conv
11
+
12
+ self.conv1d = nn.Conv1d(
13
+ in_channels=d_inner,
14
+ out_channels=d_inner,
15
+ kernel_size=d_conv,
16
+ bias=bias,
17
+ groups=d_inner, # Depthwise convolution
18
+ padding=d_conv - 1
19
+ )
20
+
21
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
22
+ """
23
+ Args:
24
+ x: [batch, seq_len, d_inner]
25
+ Returns:
26
+ x: [batch, seq_len, d_inner]
27
+ """
28
+ # Conv1d expects [batch, channels, seq_len]
29
+ x = x.transpose(1, 2) # [batch, d_inner, seq_len]
30
+ x = self.conv1d(x)
31
+ x = x[:, :, :-(self.d_conv-1)] # Remove padding
32
+ x = x.transpose(1, 2) # [batch, seq_len, d_inner]
33
+ return x
utils/domain_configs.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # utils/domain_configs.py
3
+ # =============================================================================
4
+ from typing import Dict, List
5
+ from core.config import MambaConfig
6
+
7
+ class DomainConfigs:
8
+ """Configurations for different specialist domains"""
9
+
10
+ DOMAINS = {
11
+ # STEM domains
12
+ "mathematics": {
13
+ "keywords": ["equation", "theorem", "proof", "calculate", "derivative", "integral", "matrix", "algebra", "geometry", "statistics"],
14
+ "description": "Mathematical reasoning and computation"
15
+ },
16
+ "physics": {
17
+ "keywords": ["force", "energy", "momentum", "quantum", "relativity", "particle", "wave", "thermodynamics", "mechanics"],
18
+ "description": "Physics concepts and problems"
19
+ },
20
+ "chemistry": {
21
+ "keywords": ["molecule", "atom", "reaction", "compound", "bond", "element", "organic", "inorganic", "catalyst"],
22
+ "description": "Chemistry and molecular science"
23
+ },
24
+ "biology": {
25
+ "keywords": ["cell", "DNA", "protein", "organism", "evolution", "genetics", "ecology", "anatomy", "physiology"],
26
+ "description": "Biological sciences"
27
+ },
28
+
29
+ # Programming domains
30
+ "python": {
31
+ "keywords": ["def", "class", "import", "python", "pandas", "numpy", "matplotlib", "sklearn", "tensorflow"],
32
+ "description": "Python programming and data science"
33
+ },
34
+ "javascript": {
35
+ "keywords": ["function", "var", "let", "const", "javascript", "react", "node", "async", "promise"],
36
+ "description": "JavaScript and web development"
37
+ },
38
+ "systems": {
39
+ "keywords": ["linux", "server", "network", "database", "docker", "kubernetes", "cloud", "devops"],
40
+ "description": "Systems programming and infrastructure"
41
+ },
42
+
43
+ # Language domains
44
+ "writing": {
45
+ "keywords": ["essay", "article", "story", "paragraph", "thesis", "narrative", "prose", "literature"],
46
+ "description": "Creative and technical writing"
47
+ },
48
+ "translation": {
49
+ "keywords": ["translate", "language", "spanish", "french", "german", "chinese", "japanese", "korean"],
50
+ "description": "Language translation and linguistics"
51
+ },
52
+
53
+ # Business domains
54
+ "business": {
55
+ "keywords": ["market", "strategy", "finance", "management", "revenue", "profit", "customer", "sales"],
56
+ "description": "Business and economics"
57
+ },
58
+ "legal": {
59
+ "keywords": ["law", "contract", "court", "legal", "attorney", "judge", "case", "statute", "regulation"],
60
+ "description": "Legal reasoning and analysis"
61
+ },
62
+
63
+ # Other domains
64
+ "history": {
65
+ "keywords": ["war", "empire", "civilization", "century", "ancient", "medieval", "revolution", "dynasty"],
66
+ "description": "Historical knowledge and analysis"
67
+ },
68
+ "philosophy": {
69
+ "keywords": ["ethics", "moral", "logic", "metaphysics", "epistemology", "consciousness", "existence"],
70
+ "description": "Philosophical reasoning"
71
+ },
72
+ "medical": {
73
+ "keywords": ["patient", "diagnosis", "treatment", "disease", "medicine", "surgery", "therapy", "symptom"],
74
+ "description": "Medical knowledge and healthcare"
75
+ },
76
+ "arts": {
77
+ "keywords": ["painting", "music", "sculpture", "artist", "gallery", "museum", "aesthetic", "culture"],
78
+ "description": "Arts and cultural topics"
79
+ }
80
+ }
81
+
82
+ @classmethod
83
+ def get_domain_configs(cls, num_specialists: int = 100) -> List[Dict]:
84
+ """Generate configurations for specialist domains"""
85
+ configs = []
86
+ base_domains = list(cls.DOMAINS.keys())
87
+
88
+ # Create configurations
89
+ for i in range(num_specialists):
90
+ if i < len(base_domains):
91
+ # Use predefined domains
92
+ domain_name = base_domains[i]
93
+ domain_info = cls.DOMAINS[domain_name]
94
+ else:
95
+ # Create sub-specializations or general domains
96
+ base_idx = i % len(base_domains)
97
+ domain_name = f"{base_domains[base_idx]}_sub_{i}"
98
+ domain_info = cls.DOMAINS[base_domains[base_idx]]
99
+
100
+ config = {
101
+ "id": i,
102
+ "name": domain_name,
103
+ "keywords": domain_info["keywords"],
104
+ "description": domain_info["description"],
105
+ "weight": 1.0 # Can be adjusted based on importance
106
+ }
107
+ configs.append(config)
108
+
109
+ return configs
110
+
111
+ @classmethod
112
+ def create_specialist_config(cls, base_config: MambaConfig, domain_id: int) -> MambaConfig:
113
+ """Create a specialist configuration for a specific domain"""
114
+ specialist_config = MambaConfig(**base_config.__dict__)
115
+ specialist_config.specialist_id = domain_id
116
+ return specialist_config
utils/selective_scan.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # utils/selective_scan.py
3
+ # =============================================================================
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from typing import Tuple
7
+
8
+ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False):
9
+ """
10
+ Selective scan function - core of Mamba's state space model
11
+
12
+ Args:
13
+ u: input sequence [batch, seq_len, d_inner]
14
+ delta: time step [batch, seq_len, d_inner]
15
+ A: state matrix [d_inner, d_state]
16
+ B: input matrix [batch, seq_len, d_state]
17
+ C: output matrix [batch, seq_len, d_state]
18
+ D: skip connection [d_inner]
19
+ z: gating [batch, seq_len, d_inner] (optional)
20
+ delta_bias: bias for delta (optional)
21
+ delta_softplus: whether to apply softplus to delta
22
+
23
+ Returns:
24
+ y: output [batch, seq_len, d_inner]
25
+ """
26
+ batch_size, seq_len, d_inner = u.shape
27
+ d_state = A.shape[1]
28
+
29
+ if delta_bias is not None:
30
+ delta = delta + delta_bias[None, None, :]
31
+
32
+ if delta_softplus:
33
+ delta = F.softplus(delta)
34
+
35
+ # Discretization
36
+ deltaA = torch.exp(delta.unsqueeze(-1) * A) # [batch, seq_len, d_inner, d_state]
37
+ deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1) # [batch, seq_len, d_inner, d_state]
38
+
39
+ # Initialize hidden state
40
+ h = torch.zeros(batch_size, d_inner, d_state, device=u.device, dtype=u.dtype)
41
+
42
+ outputs = []
43
+ for i in range(seq_len):
44
+ h = deltaA[:, i] * h + deltaB_u[:, i] # State update
45
+ y = torch.sum(h * C[:, i].unsqueeze(1), dim=-1) # Output projection
46
+ if D is not None:
47
+ y = y + D * u[:, i]
48
+ outputs.append(y)
49
+
50
+ y = torch.stack(outputs, dim=1) # [batch, seq_len, d_inner]
51
+
52
+ if z is not None:
53
+ y = y * F.silu(z)
54
+
55
+ return y