Debito committed
Commit 055a9c8 · verified · 1 parent: 85d4a54

Upload 8 files

core/config.py ADDED
@@ -0,0 +1,44 @@
+ # =============================================================================
+ # core/config.py
+ # =============================================================================
+ import torch
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional
+
+ @dataclass
+ class MambaConfig:
+     # Model architecture
+     vocab_size: int = 50257
+     d_model: int = 1024
+     n_layers: int = 12
+     d_inner: int = 2048
+     d_state: int = 16
+     d_conv: int = 4
+     dt_rank: Optional[int] = None
+     bias: bool = False
+     conv_bias: bool = True
+
+     # Training
+     max_seq_len: int = 2048
+     batch_size: int = 8
+     learning_rate: float = 1e-4
+     weight_decay: float = 0.1
+     warmup_steps: int = 1000
+     max_steps: int = 100000
+
+     # Swarm specific
+     num_specialists: int = 100
+     specialist_domains: Optional[List[str]] = None
+     shared_embedding: bool = True
+     hierarchical_sharing: bool = True
+
+     # Hardware
+     device: str = "cuda" if torch.cuda.is_available() else "cpu"
+     dtype: torch.dtype = torch.float16
+
+     def __post_init__(self):
+         if self.dt_rank is None:
+             self.dt_rank = max(16, self.d_model // 16)
+         if self.specialist_domains is None:
+             self.specialist_domains = [f"domain_{i}" for i in range(self.num_specialists)]
+
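A quick usage sketch (illustrative, not part of the upload; assumes the repository root is on PYTHONPATH) showing how __post_init__ fills in the derived defaults:

    from core.config import MambaConfig

    cfg = MambaConfig(d_model=512, num_specialists=4)
    print(cfg.dt_rank)             # max(16, 512 // 16) -> 32
    print(cfg.specialist_domains)  # ['domain_0', 'domain_1', 'domain_2', 'domain_3']
    print(cfg.device, cfg.dtype)   # 'cuda' or 'cpu', torch.float16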
core/embedding.py ADDED
@@ -0,0 +1,32 @@
+ # =============================================================================
+ # core/embedding.py
+ # =============================================================================
+ import torch
+ import torch.nn as nn
+ import math
+ from core.config import MambaConfig
+
+ class MambaEmbedding(nn.Module):
+     def __init__(self, config: MambaConfig):
+         super().__init__()
+         self.config = config
+
+         # Token embeddings (no positional encoding needed for Mamba)
+         self.token_embedding = nn.Embedding(
+             config.vocab_size,
+             config.d_model,
+             dtype=config.dtype
+         )
+
+         # Initialize embeddings
+         nn.init.normal_(self.token_embedding.weight, std=0.02)
+
+     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             input_ids: [batch_size, seq_len]
+         Returns:
+             embeddings: [batch_size, seq_len, d_model]
+         """
+         embeddings = self.token_embedding(input_ids)
+         return embeddings
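A minimal shape check for MambaEmbedding (illustrative; dtype is overridden to float32 so the lookup runs on CPU without mixing precisions):

    import torch
    from core.config import MambaConfig
    from core.embedding import MambaEmbedding

    cfg = MambaConfig(d_model=256, dtype=torch.float32)
    emb = MambaEmbedding(cfg)
    input_ids = torch.randint(0, cfg.vocab_size, (2, 16))  # [batch_size, seq_len]
    print(emb(input_ids).shape)                            # torch.Size([2, 16, 256])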
core/mamba.py ADDED
@@ -0,0 +1,81 @@
+ # =============================================================================
+ # core/mamba.py
+ # =============================================================================
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from core.stateSpace import StateSpaceModel
+ from utils.conv_layer import Mamba1DConv
+
+ class RMSNorm(nn.Module):
+     def __init__(self, d_model: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(d_model))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         norm = x.norm(dim=-1, keepdim=True) * (x.shape[-1] ** -0.5)
+         return x / (norm + self.eps) * self.weight
+
+ class MambaBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         # Projections
+         self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias)
+         self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
+
+         # Convolution for local context
+         self.conv1d = Mamba1DConv(config.d_inner, config.d_conv, config.conv_bias)
+
+         # State space model
+         self.ssm = StateSpaceModel(
+             d_inner=config.d_inner,
+             d_state=config.d_state,
+             dt_rank=config.dt_rank,
+             bias=config.bias
+         )
+
+         # Activation
+         self.act = F.silu
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: [batch, seq_len, d_model]
+         Returns:
+             output: [batch, seq_len, d_model]
+         """
+         batch_size, seq_len, d_model = x.shape
+
+         # Input projection
+         xz = self.in_proj(x)  # [batch, seq_len, 2*d_inner]
+         x, z = xz.chunk(2, dim=-1)  # Each [batch, seq_len, d_inner]
+
+         # Apply convolution
+         x = self.act(self.conv1d(x))
+
+         # Apply state space model
+         y = self.ssm(x)
+
+         # Apply gating with z
+         y = y * self.act(z)
+
+         # Output projection
+         output = self.out_proj(y)
+
+         return output
+
+ class MambaLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.norm = RMSNorm(config.d_model)
+         self.mamba_block = MambaBlock(config)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Pre-norm architecture
+         residual = x
+         x = self.norm(x)
+         x = self.mamba_block(x)
+         return x + residual
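MambaBlock depends on utils.conv_layer.Mamba1DConv and utils.selective_scan.selective_scan_fn, which are not part of this upload, so the full block can't be exercised from these files alone. The RMSNorm math, however, is easy to check standalone; a sketch replicating RMSNorm.forward with plain tensor ops:

    import torch

    x = torch.randn(2, 10, 64)
    weight = torch.ones(64)
    eps = 1e-5

    # Same computation as RMSNorm.forward above
    norm = x.norm(dim=-1, keepdim=True) * (x.shape[-1] ** -0.5)  # per-position RMS
    y = x / (norm + eps) * weight

    print(y.shape)                         # torch.Size([2, 10, 64])
    print(y.pow(2).mean(-1).sqrt()[0, 0])  # ~1.0: unit RMS after normalization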
core/mamba_swarm_integration.py ADDED
@@ -0,0 +1,323 @@
+ #!/usr/bin/env python3
+ """
+ Mamba Encoder Swarm - Integration with Existing Mamba Implementation
+ Uses your existing Mamba components as building blocks for the swarm architecture
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import List, Optional, Tuple
+
+ # Import your existing Mamba components
+ from core.config import MambaConfig
+ from core.model import MambaModel
+ from core.mamba import MambaLayer, RMSNorm
+ from core.embedding import MambaEmbedding
+
+ class SwarmRouter(nn.Module):
+     """
+     Routes input tokens to different encoder instances
+     This is the NEW component that enables the swarm architecture
+     """
+
+     def __init__(self, d_model: int, num_encoders: int, routing_strategy: str = "learned"):
+         super().__init__()
+         self.d_model = d_model
+         self.num_encoders = num_encoders
+         self.routing_strategy = routing_strategy
+
+         if routing_strategy == "learned":
+             # Neural router that learns optimal token distribution
+             self.router_network = nn.Sequential(
+                 nn.Linear(d_model, d_model // 2),
+                 nn.SiLU(),
+                 # Emit raw logits; F.gumbel_softmax in forward() does the normalization
+                 nn.Linear(d_model // 2, num_encoders)
+             )
+
+         # Load balancing coefficient
+         self.load_balance_coef = 0.01
+
+     def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor]:
+         """
+         Route tokens to encoder instances
+
+         Args:
+             x: [batch, seq_len, d_model]
+
+         Returns:
+             encoder_inputs: List of inputs for each encoder
+             routing_weights: Weights for aggregation [batch, seq_len, num_encoders]
+             load_balance_loss: Loss term for training
+         """
+         batch_size, seq_len, d_model = x.shape
+
+         if self.routing_strategy == "learned":
+             # Learn routing patterns
+             routing_logits = self.router_network(x)  # [batch, seq_len, num_encoders]
+             routing_weights = F.gumbel_softmax(routing_logits, tau=1.0, hard=False)
+
+             # Load balancing loss to encourage equal usage
+             avg_routing = routing_weights.mean(dim=[0, 1])
+             load_balance_loss = self.load_balance_coef * torch.var(avg_routing)
+
+         else:  # Round-robin for simplicity
+             seq_indices = torch.arange(seq_len, device=x.device)
+             encoder_ids = seq_indices % self.num_encoders
+             routing_weights = F.one_hot(encoder_ids, self.num_encoders).float()
+             routing_weights = routing_weights.unsqueeze(0).expand(batch_size, -1, -1)
+             load_balance_loss = torch.tensor(0.0, device=x.device)
+
+         # Create weighted inputs for each encoder
+         encoder_inputs = []
+         for i in range(self.num_encoders):
+             weight = routing_weights[:, :, i:i+1]  # [batch, seq_len, 1]
+             encoder_input = x * weight
+             encoder_inputs.append(encoder_input)
+
+         return encoder_inputs, routing_weights, load_balance_loss
+
+ class SwarmAggregator(nn.Module):
+     """
+     Aggregates outputs from all encoder instances
+     This is the NEW component that combines swarm outputs
+     """
+
+     def __init__(self, d_model: int, num_encoders: int):
+         super().__init__()
+         self.d_model = d_model
+         self.num_encoders = num_encoders
+
+         # Attention-based aggregation
+         self.attention = nn.MultiheadAttention(
+             embed_dim=d_model,
+             num_heads=8,
+             batch_first=True
+         )
+
+         # Output processing
+         self.norm = RMSNorm(d_model)
+         self.output_proj = nn.Linear(d_model, d_model)
+
+     def forward(self, encoder_outputs: List[torch.Tensor], routing_weights: torch.Tensor) -> torch.Tensor:
+         """
+         Aggregate encoder outputs using learned attention
+
+         Args:
+             encoder_outputs: List of [batch, seq_len, d_model] tensors
+             routing_weights: [batch, seq_len, num_encoders]
+
+         Returns:
+             aggregated: [batch, seq_len, d_model]
+         """
+         batch_size, seq_len, d_model = encoder_outputs[0].shape
+
+         # Stack and weight encoder outputs
+         stacked = torch.stack(encoder_outputs, dim=2)  # [batch, seq_len, num_encoders, d_model]
+         routing_expanded = routing_weights.unsqueeze(-1)  # [batch, seq_len, num_encoders, 1]
+         weighted = stacked * routing_expanded
+
+         # Initial aggregation
+         initial = weighted.sum(dim=2)  # [batch, seq_len, d_model]
+
+         # Attention-based refinement
+         encoder_sequence = stacked.view(batch_size, seq_len * self.num_encoders, d_model)
+         refined, _ = self.attention(initial, encoder_sequence, encoder_sequence)
+
+         # Final processing
+         output = self.output_proj(refined)
+         output = self.norm(output + initial)  # Residual connection
+
+         return output
+
+ class MambaEncoderSwarmModel(nn.Module):
+     """
+     Complete Swarm Model using your existing Mamba components
+
+     Architecture:
+     1. Use your MambaEmbedding for input processing
+     2. NEW: Router distributes tokens to encoder swarm
+     3. Use your MambaLayer instances as shared encoders
+     4. NEW: Aggregator combines encoder outputs
+     5. Use your MambaLayer instances for decoder
+     6. Use your existing LM head for output
+     """
+
+     def __init__(self, config: MambaConfig, num_encoders: int = 8, routing_strategy: str = "learned"):
+         super().__init__()
+         self.config = config
+         self.num_encoders = num_encoders
+
+         # Use your existing embedding
+         self.embedding = MambaEmbedding(config)
+
+         # NEW: Swarm components
+         self.router = SwarmRouter(config.d_model, num_encoders, routing_strategy)
+
+         # Shared encoder (using your MambaLayer)
+         # All encoder instances will use this same layer (weight sharing!)
+         self.shared_encoder_layer = MambaLayer(config)
+
+         # NEW: Aggregator
+         self.aggregator = SwarmAggregator(config.d_model, num_encoders)
+
+         # Decoder layers (using your MambaLayer)
+         self.decoder_layers = nn.ModuleList([
+             MambaLayer(config) for _ in range(config.n_layers)
+         ])
+
+         # Use your existing components
+         self.norm_f = RMSNorm(config.d_model)
+         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+         # Initialize weights
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, input_ids: torch.Tensor, targets: torch.Tensor = None):
+         """
+         Forward pass through swarm architecture
+
+         Args:
+             input_ids: [batch, seq_len]
+             targets: [batch, seq_len] (optional, for training)
+
+         Returns:
+             if targets is None: logits [batch, seq_len, vocab_size]
+             else: (logits, loss, load_balance_loss)
+         """
+         # 1. Embedding (using your existing component)
+         x = self.embedding(input_ids)  # [batch, seq_len, d_model]
+
+         # 2. Route to encoder swarm
+         encoder_inputs, routing_weights, load_balance_loss = self.router(x)
+
+         # 3. Process through shared encoder instances
+         encoder_outputs = []
+         for encoder_input in encoder_inputs:
+             # Each instance uses the SAME shared_encoder_layer (weight sharing!)
+             encoder_output = self.shared_encoder_layer(encoder_input)
+             encoder_outputs.append(encoder_output)
+
+         # 4. Aggregate encoder outputs
+         x = self.aggregator(encoder_outputs, routing_weights)
+
+         # 5. Process through decoder (using your existing layers)
+         for decoder_layer in self.decoder_layers:
+             x = decoder_layer(x)
+
+         # 6. Final processing (using your existing components)
+         x = self.norm_f(x)
+         logits = self.lm_head(x)  # [batch, seq_len, vocab_size]
+
+         if targets is not None:
+             # Compute loss
+             loss = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 targets.view(-1),
+                 ignore_index=-100
+             )
+             return logits, loss, load_balance_loss
+
+         return logits
+
+     def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100,
+                  temperature: float = 1.0, top_k: Optional[int] = None):
+         """Generate using swarm architecture"""
+         self.eval()
+
+         for _ in range(max_new_tokens):
+             with torch.no_grad():
+                 logits = self.forward(input_ids)
+                 logits = logits[:, -1, :] / temperature
+
+                 if top_k is not None:
+                     v, _ = torch.topk(logits, top_k)
+                     logits[logits < v[:, [-1]]] = -float('Inf')
+
+                 probs = F.softmax(logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+                 input_ids = torch.cat([input_ids, next_token], dim=1)
+
+         return input_ids
+
+     def get_num_params(self):
+         """Get number of parameters"""
+         return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+ def create_swarm_from_existing_config(config: MambaConfig, num_encoders: int = 8) -> MambaEncoderSwarmModel:
+     """
+     Create swarm model using your existing configuration
+     """
+     swarm_model = MambaEncoderSwarmModel(config, num_encoders, routing_strategy="learned")
+
+     num_params = swarm_model.get_num_params()
+     print(f"🚀 Swarm model created with {num_params:,} parameters ({num_params/1e6:.1f}M)")
+     print(f"📊 Using {num_encoders} encoder instances with shared weights")
+
+     return swarm_model
+
+ def compare_architectures(config: MambaConfig):
+     """
+     Compare standard Mamba vs Swarm architecture
+     """
+     print("🔍 Architecture Comparison")
+     print("=" * 50)
+
+     # Standard model (your existing)
+     standard_model = MambaModel(config)
+     standard_params = standard_model.get_num_params()
+
+     # Swarm model (new architecture)
+     swarm_model = create_swarm_from_existing_config(config, num_encoders=8)
+     swarm_params = swarm_model.get_num_params()
+
+     print(f"📈 Standard Mamba: {standard_params:,} parameters ({standard_params/1e6:.1f}M)")
+     print(f"🔥 Swarm Mamba: {swarm_params:,} parameters ({swarm_params/1e6:.1f}M)")
+     print(f"💡 Parameter overhead: {((swarm_params - standard_params) / standard_params * 100):.1f}%")
+
+     return standard_model, swarm_model
+
+ if __name__ == "__main__":
+     # Test with your existing config
+     from core.config import MambaConfig
+
+     # Create a test config
+     config = MambaConfig(
+         vocab_size=50257,
+         d_model=512,
+         n_layers=8,
+         d_state=16,
+         d_conv=4,
+         bias=False, dtype=torch.float32  # fp32 so the default fp16 embeddings don't clash with fp32 layers on CPU
+     )
+
+     print("🧪 Testing Swarm Integration")
+     print("=" * 40)
+
+     # Compare architectures
+     standard_model, swarm_model = compare_architectures(config)
+
+     # Test forward pass
+     batch_size, seq_len = 2, 32
+     input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
+
+     # Test standard model
+     with torch.no_grad():
+         standard_logits = standard_model(input_ids)
+     print(f"✅ Standard model output: {standard_logits.shape}")
+
+     # Test swarm model
+     with torch.no_grad():
+         swarm_logits = swarm_model(input_ids)
+     print(f"✅ Swarm model output: {swarm_logits.shape}")
+
+     print(f"\n🎉 Both architectures working! Ready to train the swarm.")
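A self-contained sketch of the routing mechanics in isolation (toy sizes, plain PyTorch; the names below are illustrative, not from the file): the learned path turns raw logits into soft per-token encoder weights via gumbel_softmax and penalizes uneven average usage, while the round-robin path one-hot assigns position t to encoder t % num_encoders:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    batch, seq_len, d_model, num_encoders = 2, 8, 32, 4
    x = torch.randn(batch, seq_len, d_model)

    # Learned path: raw logits -> gumbel_softmax (mirrors SwarmRouter.forward)
    router = nn.Sequential(nn.Linear(d_model, d_model // 2), nn.SiLU(),
                           nn.Linear(d_model // 2, num_encoders))
    routing_logits = router(x)                                   # [batch, seq_len, num_encoders]
    routing_weights = F.gumbel_softmax(routing_logits, tau=1.0, hard=False)
    load_balance_loss = 0.01 * torch.var(routing_weights.mean(dim=[0, 1]))

    # Round-robin path: position t -> encoder t % num_encoders
    encoder_ids = torch.arange(seq_len) % num_encoders
    rr_weights = F.one_hot(encoder_ids, num_encoders).float().unsqueeze(0).expand(batch, -1, -1)

    # Each encoder instance sees the input scaled by its routing weight
    encoder_inputs = [x * routing_weights[:, :, i:i+1] for i in range(num_encoders)]
    print(routing_weights.shape, rr_weights.shape, len(encoder_inputs))  # [2, 8, 4], [2, 8, 4], 4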
core/model.py ADDED
@@ -0,0 +1,106 @@
+ # =============================================================================
+ # core/model.py
+ # =============================================================================
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from core.config import MambaConfig
+ from core.embedding import MambaEmbedding
+ from core.mamba import MambaLayer, RMSNorm
+
+ class MambaModel(nn.Module):
+     def __init__(self, config: MambaConfig):
+         super().__init__()
+         self.config = config
+
+         # Embeddings
+         self.embedding = MambaEmbedding(config)
+
+         # Mamba layers
+         self.layers = nn.ModuleList([
+             MambaLayer(config) for _ in range(config.n_layers)
+         ])
+
+         # Final normalization
+         self.norm_f = RMSNorm(config.d_model)
+
+         # Language modeling head
+         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+         # Tie weights with embedding if specified
+         if hasattr(config, 'tie_word_embeddings') and config.tie_word_embeddings:
+             self.lm_head.weight = self.embedding.token_embedding.weight
+
+         # Initialize weights
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, input_ids: torch.Tensor, targets: torch.Tensor = None):
+         """
+         Args:
+             input_ids: [batch, seq_len]
+             targets: [batch, seq_len] (optional, for training)
+         Returns:
+             if targets is None: logits [batch, seq_len, vocab_size]
+             else: (logits, loss)
+         """
+         # Get embeddings
+         x = self.embedding(input_ids)  # [batch, seq_len, d_model]
+
+         # Apply Mamba layers
+         for layer in self.layers:
+             x = layer(x)
+
+         # Final normalization
+         x = self.norm_f(x)
+
+         # Language modeling head
+         logits = self.lm_head(x)  # [batch, seq_len, vocab_size]
+
+         if targets is not None:
+             # Compute loss
+             loss = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 targets.view(-1),
+                 ignore_index=-100
+             )
+             return logits, loss
+
+         return logits
+
+     def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100,
+                  temperature: float = 1.0, top_k: int = None):
+         """Generate text autoregressively"""
+         self.eval()
+
+         for _ in range(max_new_tokens):
+             with torch.no_grad():
+                 # Get logits for the last position
+                 logits = self.forward(input_ids)
+                 logits = logits[:, -1, :] / temperature
+
+                 # Apply top-k filtering
+                 if top_k is not None:
+                     v, _ = torch.topk(logits, top_k)
+                     logits[logits < v[:, [-1]]] = -float('Inf')
+
+                 # Sample next token
+                 probs = F.softmax(logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+
+                 # Append to sequence
+                 input_ids = torch.cat([input_ids, next_token], dim=1)
+
+         return input_ids
+
+     def get_num_params(self):
+         """Get number of parameters"""
+         return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
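The top-k sampling step inside generate(), shown in isolation (a sketch with made-up sizes, plain PyTorch): values below each row's k-th largest logit are masked to -inf before softmax and multinomial sampling:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 50257) / 0.8         # [batch, vocab_size], already divided by temperature
    top_k = 50

    v, _ = torch.topk(logits, top_k)             # v[:, -1] is the k-th largest value per row
    logits[logits < v[:, [-1]]] = -float('Inf')  # mask everything outside the top-k

    probs = F.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)  # [batch, 1]
    print(next_token.shape)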
core/preprocess.py ADDED
@@ -0,0 +1,54 @@
+ # =============================================================================
+ # core/preprocess.py
+ # =============================================================================
+ import re
+ import unicodedata
+ from core.config import MambaConfig
+ from typing import List, Dict, Any
+
+ class TextPreprocessor:
+     def __init__(self, config: MambaConfig):
+         self.config = config
+         self.max_length = config.max_seq_len
+
+     def clean_text(self, text: str) -> str:
+         """Basic text cleaning"""
+         # Normalize unicode
+         text = unicodedata.normalize('NFKC', text)
+
+         # Remove excessive whitespace
+         text = re.sub(r'\s+', ' ', text)
+
+         # Remove control characters except newlines and tabs
+         text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)
+
+         return text.strip()
+
+     def chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
+         """Split text into chunks for distributed processing"""
+         if chunk_size is None:
+             chunk_size = self.max_length // 2
+
+         words = text.split()
+         chunks = []
+         current_chunk = []
+         current_length = 0
+
+         for word in words:
+             if current_length + len(word) + 1 > chunk_size and current_chunk:
+                 chunks.append(' '.join(current_chunk))
+                 current_chunk = [word]
+                 current_length = len(word)
+             else:
+                 current_chunk.append(word)
+                 current_length += len(word) + 1
+
+         if current_chunk:
+             chunks.append(' '.join(current_chunk))
+
+         return chunks
+
+     def preprocess_batch(self, texts: List[str]) -> List[str]:
+         """Preprocess a batch of texts"""
+         return [self.clean_text(text) for text in texts]
+
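A short usage sketch for chunk_text (illustrative; assumes core.config is importable). Chunks are built on word boundaries and stay at roughly chunk_size characters or fewer, unless a single word exceeds the limit:

    from core.config import MambaConfig
    from core.preprocess import TextPreprocessor

    pre = TextPreprocessor(MambaConfig())
    text = pre.clean_text("Lorem  ipsum\tdolor sit amet, " * 20)  # whitespace collapsed, trimmed
    chunks = pre.chunk_text(text, chunk_size=80)
    print(len(chunks), [len(c) for c in chunks])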
core/stateSpace.py ADDED
@@ -0,0 +1,71 @@
+ # =============================================================================
+ # core/stateSpace.py
+ # =============================================================================
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from utils.selective_scan import selective_scan_fn
+
+ class StateSpaceModel(nn.Module):
+     def __init__(self, d_inner: int, d_state: int = 16, dt_rank: int = None, bias: bool = False):
+         super().__init__()
+         self.d_inner = d_inner
+         self.d_state = d_state
+         self.dt_rank = dt_rank if dt_rank is not None else max(16, d_inner // 16)
+
+         # State space parameters
+         self.A_log = nn.Parameter(torch.randn(d_inner, d_state))
+         self.D = nn.Parameter(torch.ones(d_inner))
+
+         # Projection layers
+         self.x_proj = nn.Linear(d_inner, self.dt_rank + d_state * 2, bias=False)
+         self.dt_proj = nn.Linear(self.dt_rank, d_inner, bias=True)
+
+         # Initialize parameters
+         self._init_parameters()
+
+     def _init_parameters(self):
+         # Initialize A_log so that A = -exp(A_log) is small and negative (stable dynamics)
+         nn.init.uniform_(self.A_log, -4.0, -1.0)
+
+         # Initialize dt_proj bias with a small uniform range scaled by dt_rank
+         dt_init_std = self.dt_rank**-0.5
+         with torch.no_grad():
+             self.dt_proj.bias.uniform_(-dt_init_std, dt_init_std)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: [batch, seq_len, d_inner]
+         Returns:
+             y: [batch, seq_len, d_inner]
+         """
+         batch_size, seq_len, d_inner = x.shape
+
+         # Project x to get delta, B, C
+         x_dbl = self.x_proj(x)  # [batch, seq_len, dt_rank + 2*d_state]
+
+         delta, B, C = torch.split(
+             x_dbl,
+             [self.dt_rank, self.d_state, self.d_state],
+             dim=-1
+         )
+
+         # Project delta to d_inner
+         delta = self.dt_proj(delta)  # [batch, seq_len, d_inner]
+
+         # Get A matrix (ensure it's negative for stability)
+         A = -torch.exp(self.A_log)  # [d_inner, d_state]
+
+         # Apply selective scan
+         y = selective_scan_fn(
+             u=x,
+             delta=delta,
+             A=A,
+             B=B,
+             C=C,
+             D=self.D,
+             delta_softplus=True
+         )
+
+         return y
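utils.selective_scan is not included in this upload. As a reference point only, here is a naive, unoptimized sketch of what selective_scan_fn is assumed to compute, inferred from the call above (discretized recurrence h_t = exp(Δ_t A) h_{t-1} + Δ_t B_t u_t, output y_t = C_t · h_t + D ⊙ u_t); the real implementation is presumably a fused or parallel scan:

    import torch
    import torch.nn.functional as F

    def selective_scan_reference(u, delta, A, B, C, D, delta_softplus=True):
        # Assumed shapes: u, delta [batch, seq_len, d_inner]; A [d_inner, d_state];
        # B, C [batch, seq_len, d_state]; D [d_inner]
        if delta_softplus:
            delta = F.softplus(delta)
        batch, seq_len, d_inner = u.shape
        deltaA = torch.exp(delta.unsqueeze(-1) * A)                 # [batch, seq_len, d_inner, d_state]
        deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)
        h = torch.zeros(batch, d_inner, A.shape[-1], device=u.device, dtype=u.dtype)
        ys = []
        for t in range(seq_len):
            h = deltaA[:, t] * h + deltaB_u[:, t]                   # recurrent state update
            ys.append((h * C[:, t].unsqueeze(1)).sum(dim=-1))       # project state with C_t -> [batch, d_inner]
        return torch.stack(ys, dim=1) + u * D                       # skip connection through D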
core/tokenizer.py ADDED
@@ -0,0 +1,63 @@
+ # =============================================================================
+ # core/tokenizer.py
+ # =============================================================================
+ from transformers import AutoTokenizer
+ import torch
+ from core.config import MambaConfig
+ from typing import List, Dict, Union
+
+ class MambaTokenizer:
+     def __init__(self, config: MambaConfig, tokenizer_name: str = "gpt2"):
+         self.config = config
+         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+         # Add special tokens if needed
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         self.vocab_size = len(self.tokenizer)
+
+     def encode(self, text: str, max_length: int = None) -> Dict[str, torch.Tensor]:
+         """Encode text to token ids"""
+         if max_length is None:
+             max_length = self.config.max_seq_len
+
+         encoded = self.tokenizer(
+             text,
+             max_length=max_length,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt"
+         )
+
+         return {
+             "input_ids": encoded["input_ids"],
+             "attention_mask": encoded["attention_mask"]
+         }
+
+     def encode_batch(self, texts: List[str], max_length: int = None) -> Dict[str, torch.Tensor]:
+         """Encode batch of texts"""
+         if max_length is None:
+             max_length = self.config.max_seq_len
+
+         encoded = self.tokenizer(
+             texts,
+             max_length=max_length,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt"
+         )
+
+         return {
+             "input_ids": encoded["input_ids"],
+             "attention_mask": encoded["attention_mask"]
+         }
+
+     def decode(self, token_ids: torch.Tensor, skip_special_tokens: bool = True) -> str:
+         """Decode token ids to text"""
+         return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
+     def decode_batch(self, token_ids: torch.Tensor, skip_special_tokens: bool = True) -> List[str]:
+         """Decode batch of token ids"""
+         return self.tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)
+
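Usage sketch (illustrative; downloads the gpt2 tokenizer from the Hugging Face Hub on first use):

    from core.config import MambaConfig
    from core.tokenizer import MambaTokenizer

    tok = MambaTokenizer(MambaConfig(), tokenizer_name="gpt2")
    batch = tok.encode_batch(["Hello swarm!", "Mamba encoders share weights."], max_length=16)
    print(batch["input_ids"].shape, batch["attention_mask"].shape)  # torch.Size([2, 16]) each
    print(tok.decode(batch["input_ids"][0]))                        # "Hello swarm!" (EOS padding stripped)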