# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.

import os
from dataclasses import dataclass, field

import torch
from torch import nn
from xlstm.blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig
from xlstm.xlstm_large import xLSTMLargeConfig
from xlstm.xlstm_large.components import RMSNorm
from xlstm.xlstm_large.model import FeedForward, mLSTMBlock, mLSTMStateType


def skip_cuda():
    return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")
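
# Usage sketch (not part of the original module): the CUDA sLSTM backend can be
# disabled via the TIREX_NO_CUDA environment variable, e.g. on CPU-only machines:
#
#   os.environ["TIREX_NO_CUDA"] = "1"
#   assert skip_cuda()  # init_cell below then selects the "vanilla" backend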

def init_cell(config: xLSTMLargeConfig, block_idx, num_blocks):
    return sLSTMLayer(
        sLSTMLayerConfig(
            embedding_dim=config.embedding_dim,
            num_heads=config.num_heads,
            conv1d_kernel_size=0,  # 0 means no convolution included
            group_norm_weight=True,
            dropout=0,
            # CellConfig
            backend="vanilla" if skip_cuda() else "cuda",
            bias_init="powerlaw_blockdependent",
            recurrent_weight_init="zeros",
            num_gates=4,
            gradient_recurrent_cut=False,
            gradient_recurrent_clipval=None,
            forward_clipval=None,
            batch_size=8,  # needed?
            _block_idx=block_idx,
            _num_blocks=num_blocks,
        )
    )
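
# Usage sketch (illustrative only; the xLSTMLargeConfig field values below are
# assumptions, not taken from this module):
#
#   cfg = xLSTMLargeConfig(embedding_dim=512, num_heads=4, num_blocks=8, vocab_size=1)
#   layer = init_cell(cfg, block_idx=0, num_blocks=cfg.num_blocks)  # one sLSTMLayer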

# Per-layer recurrent state as handled in sLSTMBlock.forward below:
# a (conv_state, slstm_state) pair.
sLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor]
# Map from block index to that block's layer state.
sLSTMStateType = dict[int, sLSTMLayerStateType]

class sLSTMBlock(nn.Module):
    def __init__(self, config: xLSTMLargeConfig, block_idx: int, num_blocks: int):
        super().__init__()
        self.config = config
        self.norm_slstm = RMSNorm(
            num_features=config.embedding_dim,
            eps=config.norm_eps,
            use_weight=True,
            use_bias=config.use_bias,
            force_float32_reductions=config.norm_reduction_force_float32,
        )
        self.slstm_layer = init_cell(config, block_idx, num_blocks)
        self.norm_ffn = RMSNorm(
            num_features=config.embedding_dim,
            eps=config.norm_eps,
            use_weight=True,
            use_bias=config.use_bias,
            force_float32_reductions=config.norm_reduction_force_float32,
        )
        self.ffn = FeedForward(config)

    def forward(
        self, x: torch.Tensor, state: sLSTMLayerStateType | None = None
    ) -> tuple[torch.Tensor, sLSTMLayerStateType]:
        # Pre-norm residual sLSTM sub-block.
        x_slstm = self.norm_slstm(x)
        if state is None:
            conv_state, slstm_state = None, None
        else:
            conv_state, slstm_state = state
        x_slstm, last_state = self.slstm_layer(x_slstm, conv_state, slstm_state, return_last_state=True)
        x = x + x_slstm
        # Pre-norm residual feed-forward sub-block.
        x_ffn = self.norm_ffn(x)
        x_ffn = self.ffn(x_ffn)
        x = x + x_ffn
        return x, (last_state["conv_state"], last_state["slstm_state"])
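
# Forward-pass sketch (illustrative; the config and tensor shapes are
# assumptions; note the cell config above fixes batch_size=8):
#
#   cfg = xLSTMLargeConfig(embedding_dim=512, num_heads=4, num_blocks=8, vocab_size=1)
#   block = sLSTMBlock(cfg, block_idx=0, num_blocks=cfg.num_blocks)
#   x = torch.randn(8, 16, cfg.embedding_dim)  # (batch, seq_len, embedding_dim)
#   y, (conv_state, slstm_state) = block(x)    # first call: states initialized internally
#   y2, state2 = block(x, (conv_state, slstm_state))  # later calls reuse the state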

@dataclass
class xLSTMMixedLargeConfig(xLSTMLargeConfig):
    # Indices of blocks that use an sLSTM layer instead of an mLSTM layer.
    slstm_at: list[int] = field(default_factory=list)
    # If True, every block is an sLSTM block, regardless of slstm_at.
    all_slstm: bool = True

    @property
    def block_types(self):
        return ["s" if (i in self.slstm_at or self.all_slstm) else "m" for i in range(self.num_blocks)]

class xLSTMMixedLargeBlockStack(nn.Module):
    config_class = xLSTMMixedLargeConfig

    def __init__(self, config: xLSTMMixedLargeConfig):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList(
            [
                sLSTMBlock(config, block_idx=i, num_blocks=config.num_blocks) if t == "s" else mLSTMBlock(config)
                for i, t in enumerate(config.block_types)
            ]
        )
        if self.config.add_out_norm:
            self.out_norm = RMSNorm(
                num_features=config.embedding_dim,
                eps=config.norm_eps,
                use_weight=True,
                use_bias=config.use_bias,
                force_float32_reductions=config.norm_reduction_force_float32,
            )
        else:
            self.out_norm = nn.Identity()

    def forward(
        self, x: torch.Tensor, state: mLSTMStateType | sLSTMStateType | None = None
    ) -> tuple[torch.Tensor, mLSTMStateType | sLSTMStateType]:
        if state is None:
            state = {i: None for i in range(len(self.blocks))}
        for i, block in enumerate(self.blocks):
            block_state = state[i]
            x, block_state_new = block(x, block_state)
            if block_state is None:
                state[i] = block_state_new
            else:
                # A layer state is a tuple of tensors (c, n, m for mLSTM blocks;
                # conv_state and slstm_state for sLSTM blocks). It could be updated
                # in place to avoid creating new tensors:
                # for state_idx in range(len(block_state)):
                #     state[i][state_idx].copy_(block_state_new[state_idx])
                pass
        x = self.out_norm(x)
        return x, state
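
# End-to-end sketch (illustrative; the config field values are assumptions):
#
#   cfg = xLSTMMixedLargeConfig(embedding_dim=512, num_heads=4, num_blocks=4, vocab_size=1)
#   stack = xLSTMMixedLargeBlockStack(cfg)
#   x = torch.randn(8, 16, cfg.embedding_dim)  # (batch, seq_len, embedding_dim)
#   y, state = stack(x)          # state maps block index -> per-block state
#   y2, state = stack(x, state)  # continue from the returned state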