Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- utils/conv_layer.py +33 -0
- utils/domain_configs.py +116 -0
- utils/selective_scan.py +55 -0
utils/conv_layer.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# =============================================================================
|
2 |
+
# utils/conv_layer.py
|
3 |
+
# =============================================================================
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
class Mamba1DConv(nn.Module):
|
8 |
+
def __init__(self, d_inner: int, d_conv: int = 4, bias: bool = True):
|
9 |
+
super().__init__()
|
10 |
+
self.d_conv = d_conv
|
11 |
+
|
12 |
+
self.conv1d = nn.Conv1d(
|
13 |
+
in_channels=d_inner,
|
14 |
+
out_channels=d_inner,
|
15 |
+
kernel_size=d_conv,
|
16 |
+
bias=bias,
|
17 |
+
groups=d_inner, # Depthwise convolution
|
18 |
+
padding=d_conv - 1
|
19 |
+
)
|
20 |
+
|
21 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
22 |
+
"""
|
23 |
+
Args:
|
24 |
+
x: [batch, seq_len, d_inner]
|
25 |
+
Returns:
|
26 |
+
x: [batch, seq_len, d_inner]
|
27 |
+
"""
|
28 |
+
# Conv1d expects [batch, channels, seq_len]
|
29 |
+
x = x.transpose(1, 2) # [batch, d_inner, seq_len]
|
30 |
+
x = self.conv1d(x)
|
31 |
+
x = x[:, :, :-(self.d_conv-1)] # Remove padding
|
32 |
+
x = x.transpose(1, 2) # [batch, seq_len, d_inner]
|
33 |
+
return x
|
utils/domain_configs.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# =============================================================================
|
2 |
+
# utils/domain_configs.py
|
3 |
+
# =============================================================================
|
4 |
+
from typing import Dict, List
|
5 |
+
from core.config import MambaConfig
|
6 |
+
|
7 |
+
class DomainConfigs:
|
8 |
+
"""Configurations for different specialist domains"""
|
9 |
+
|
10 |
+
DOMAINS = {
|
11 |
+
# STEM domains
|
12 |
+
"mathematics": {
|
13 |
+
"keywords": ["equation", "theorem", "proof", "calculate", "derivative", "integral", "matrix", "algebra", "geometry", "statistics"],
|
14 |
+
"description": "Mathematical reasoning and computation"
|
15 |
+
},
|
16 |
+
"physics": {
|
17 |
+
"keywords": ["force", "energy", "momentum", "quantum", "relativity", "particle", "wave", "thermodynamics", "mechanics"],
|
18 |
+
"description": "Physics concepts and problems"
|
19 |
+
},
|
20 |
+
"chemistry": {
|
21 |
+
"keywords": ["molecule", "atom", "reaction", "compound", "bond", "element", "organic", "inorganic", "catalyst"],
|
22 |
+
"description": "Chemistry and molecular science"
|
23 |
+
},
|
24 |
+
"biology": {
|
25 |
+
"keywords": ["cell", "DNA", "protein", "organism", "evolution", "genetics", "ecology", "anatomy", "physiology"],
|
26 |
+
"description": "Biological sciences"
|
27 |
+
},
|
28 |
+
|
29 |
+
# Programming domains
|
30 |
+
"python": {
|
31 |
+
"keywords": ["def", "class", "import", "python", "pandas", "numpy", "matplotlib", "sklearn", "tensorflow"],
|
32 |
+
"description": "Python programming and data science"
|
33 |
+
},
|
34 |
+
"javascript": {
|
35 |
+
"keywords": ["function", "var", "let", "const", "javascript", "react", "node", "async", "promise"],
|
36 |
+
"description": "JavaScript and web development"
|
37 |
+
},
|
38 |
+
"systems": {
|
39 |
+
"keywords": ["linux", "server", "network", "database", "docker", "kubernetes", "cloud", "devops"],
|
40 |
+
"description": "Systems programming and infrastructure"
|
41 |
+
},
|
42 |
+
|
43 |
+
# Language domains
|
44 |
+
"writing": {
|
45 |
+
"keywords": ["essay", "article", "story", "paragraph", "thesis", "narrative", "prose", "literature"],
|
46 |
+
"description": "Creative and technical writing"
|
47 |
+
},
|
48 |
+
"translation": {
|
49 |
+
"keywords": ["translate", "language", "spanish", "french", "german", "chinese", "japanese", "korean"],
|
50 |
+
"description": "Language translation and linguistics"
|
51 |
+
},
|
52 |
+
|
53 |
+
# Business domains
|
54 |
+
"business": {
|
55 |
+
"keywords": ["market", "strategy", "finance", "management", "revenue", "profit", "customer", "sales"],
|
56 |
+
"description": "Business and economics"
|
57 |
+
},
|
58 |
+
"legal": {
|
59 |
+
"keywords": ["law", "contract", "court", "legal", "attorney", "judge", "case", "statute", "regulation"],
|
60 |
+
"description": "Legal reasoning and analysis"
|
61 |
+
},
|
62 |
+
|
63 |
+
# Other domains
|
64 |
+
"history": {
|
65 |
+
"keywords": ["war", "empire", "civilization", "century", "ancient", "medieval", "revolution", "dynasty"],
|
66 |
+
"description": "Historical knowledge and analysis"
|
67 |
+
},
|
68 |
+
"philosophy": {
|
69 |
+
"keywords": ["ethics", "moral", "logic", "metaphysics", "epistemology", "consciousness", "existence"],
|
70 |
+
"description": "Philosophical reasoning"
|
71 |
+
},
|
72 |
+
"medical": {
|
73 |
+
"keywords": ["patient", "diagnosis", "treatment", "disease", "medicine", "surgery", "therapy", "symptom"],
|
74 |
+
"description": "Medical knowledge and healthcare"
|
75 |
+
},
|
76 |
+
"arts": {
|
77 |
+
"keywords": ["painting", "music", "sculpture", "artist", "gallery", "museum", "aesthetic", "culture"],
|
78 |
+
"description": "Arts and cultural topics"
|
79 |
+
}
|
80 |
+
}
|
81 |
+
|
82 |
+
@classmethod
|
83 |
+
def get_domain_configs(cls, num_specialists: int = 100) -> List[Dict]:
|
84 |
+
"""Generate configurations for specialist domains"""
|
85 |
+
configs = []
|
86 |
+
base_domains = list(cls.DOMAINS.keys())
|
87 |
+
|
88 |
+
# Create configurations
|
89 |
+
for i in range(num_specialists):
|
90 |
+
if i < len(base_domains):
|
91 |
+
# Use predefined domains
|
92 |
+
domain_name = base_domains[i]
|
93 |
+
domain_info = cls.DOMAINS[domain_name]
|
94 |
+
else:
|
95 |
+
# Create sub-specializations or general domains
|
96 |
+
base_idx = i % len(base_domains)
|
97 |
+
domain_name = f"{base_domains[base_idx]}_sub_{i}"
|
98 |
+
domain_info = cls.DOMAINS[base_domains[base_idx]]
|
99 |
+
|
100 |
+
config = {
|
101 |
+
"id": i,
|
102 |
+
"name": domain_name,
|
103 |
+
"keywords": domain_info["keywords"],
|
104 |
+
"description": domain_info["description"],
|
105 |
+
"weight": 1.0 # Can be adjusted based on importance
|
106 |
+
}
|
107 |
+
configs.append(config)
|
108 |
+
|
109 |
+
return configs
|
110 |
+
|
111 |
+
@classmethod
|
112 |
+
def create_specialist_config(cls, base_config: MambaConfig, domain_id: int) -> MambaConfig:
|
113 |
+
"""Create a specialist configuration for a specific domain"""
|
114 |
+
specialist_config = MambaConfig(**base_config.__dict__)
|
115 |
+
specialist_config.specialist_id = domain_id
|
116 |
+
return specialist_config
|
utils/selective_scan.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# =============================================================================
|
2 |
+
# utils/selective_scan.py
|
3 |
+
# =============================================================================
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from typing import Tuple
|
7 |
+
|
8 |
+
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False):
|
9 |
+
"""
|
10 |
+
Selective scan function - core of Mamba's state space model
|
11 |
+
|
12 |
+
Args:
|
13 |
+
u: input sequence [batch, seq_len, d_inner]
|
14 |
+
delta: time step [batch, seq_len, d_inner]
|
15 |
+
A: state matrix [d_inner, d_state]
|
16 |
+
B: input matrix [batch, seq_len, d_state]
|
17 |
+
C: output matrix [batch, seq_len, d_state]
|
18 |
+
D: skip connection [d_inner]
|
19 |
+
z: gating [batch, seq_len, d_inner] (optional)
|
20 |
+
delta_bias: bias for delta (optional)
|
21 |
+
delta_softplus: whether to apply softplus to delta
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
y: output [batch, seq_len, d_inner]
|
25 |
+
"""
|
26 |
+
batch_size, seq_len, d_inner = u.shape
|
27 |
+
d_state = A.shape[1]
|
28 |
+
|
29 |
+
if delta_bias is not None:
|
30 |
+
delta = delta + delta_bias[None, None, :]
|
31 |
+
|
32 |
+
if delta_softplus:
|
33 |
+
delta = F.softplus(delta)
|
34 |
+
|
35 |
+
# Discretization
|
36 |
+
deltaA = torch.exp(delta.unsqueeze(-1) * A) # [batch, seq_len, d_inner, d_state]
|
37 |
+
deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1) # [batch, seq_len, d_inner, d_state]
|
38 |
+
|
39 |
+
# Initialize hidden state
|
40 |
+
h = torch.zeros(batch_size, d_inner, d_state, device=u.device, dtype=u.dtype)
|
41 |
+
|
42 |
+
outputs = []
|
43 |
+
for i in range(seq_len):
|
44 |
+
h = deltaA[:, i] * h + deltaB_u[:, i] # State update
|
45 |
+
y = torch.sum(h * C[:, i].unsqueeze(1), dim=-1) # Output projection
|
46 |
+
if D is not None:
|
47 |
+
y = y + D * u[:, i]
|
48 |
+
outputs.append(y)
|
49 |
+
|
50 |
+
y = torch.stack(outputs, dim=1) # [batch, seq_len, d_inner]
|
51 |
+
|
52 |
+
if z is not None:
|
53 |
+
y = y * F.silu(z)
|
54 |
+
|
55 |
+
return y
|