# =============================================================================
# core/model.py
# =============================================================================
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from core.config import MambaConfig
from core.embedding import MambaEmbedding
from core.mamba import MambaLayer, RMSNorm

class MambaModel(nn.Module):
    def __init__(self, config: MambaConfig):
        super().__init__()
        self.config = config

        # Embeddings
        self.embedding = MambaEmbedding(config)

        # Mamba layers
        self.layers = nn.ModuleList([
            MambaLayer(config) for _ in range(config.n_layers)
        ])

        # Final normalization
        self.norm_f = RMSNorm(config.d_model)

        # Language modeling head
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Tie the LM head to the token embedding if the config requests it
        if getattr(config, 'tie_word_embeddings', False):
            self.lm_head.weight = self.embedding.token_embedding.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """
        Args:
            input_ids: [batch, seq_len]
            targets: [batch, seq_len] (optional, for training)

        Returns:
            if targets is None: logits [batch, seq_len, vocab_size]
            else: (logits, loss)
        """
        # Get embeddings
        x = self.embedding(input_ids)  # [batch, seq_len, d_model]

        # Apply Mamba layers
        for layer in self.layers:
            x = layer(x)

        # Final normalization
        x = self.norm_f(x)

        # Language modeling head
        logits = self.lm_head(x)  # [batch, seq_len, vocab_size]

        if targets is not None:
            # Compute cross-entropy over all positions; -100 marks ignored targets
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100
            )
            return logits, loss

        return logits

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100,
                 temperature: float = 1.0, top_k: Optional[int] = None):
        """Generate tokens autoregressively by sampling from the model."""
        self.eval()
        for _ in range(max_new_tokens):
            # Get logits for the last position and apply temperature
            logits = self.forward(input_ids)
            logits = logits[:, -1, :] / temperature

            # Apply top-k filtering (clamped to the vocabulary size)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Sample the next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to the sequence
            input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids

    def get_num_params(self):
        """Get the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
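

# =============================================================================
# Usage sketch (illustrative only, not part of the model definition).
# It assumes MambaConfig exposes vocab_size, d_model, and n_layers and accepts
# them as keyword arguments; the actual fields live in core/config.py, and the
# layers come from the rest of the core package.
# =============================================================================
if __name__ == "__main__":
    config = MambaConfig(vocab_size=32000, d_model=512, n_layers=8)
    model = MambaModel(config)
    print(f"Parameters: {model.get_num_params():,}")

    # Training-style call: passing targets returns (logits, loss)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    targets = torch.randint(0, config.vocab_size, (2, 16))
    logits, loss = model(input_ids, targets)
    print(logits.shape, loss.item())

    # Inference-style call: autoregressive sampling from a short prompt
    prompt = torch.randint(0, config.vocab_size, (1, 4))
    generated = model.generate(prompt, max_new_tokens=20, temperature=1.0, top_k=50)
    print(generated.shape)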