from typing import Optional, Tuple, MutableMapping, Union
import math
from contextlib import nullcontext

import torch
import torch as T
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.attention import SDPBackend

from einops import rearrange

from utils import si_module, default, exists, load_ckpt


CACHE_FILL_VALUE = -1
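

# The KV cache is allocated up front and filled with CACHE_FILL_VALUE as a
# sentinel; get_cache_len() below recovers how many positions have actually
# been written by checking which rows no longer hold the sentinel.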
def get_cache_len(cache: Optional[Tensor]) -> int:
    """
    cache: (batch, seq_len, 2, kv_heads, head_dim)
    """
    if cache is None:
        return 0
    nonzeros = T.any(cache.flatten(2) != CACHE_FILL_VALUE, dim=-1)
    length = nonzeros.sum(dim=-1).int()
    assert T.all(length == length[0])
    return length[0]
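

# Rotary position embedding (RoPE) helpers: rotate_half pairs up the two
# halves of the feature dimension, and apply_rotary_pos_emb applies the
# precomputed cos/sin tables to a (B, S, H, Dh) tensor starting at `offset`
# (the current cache length when decoding incrementally).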
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, cos, sin, offset: int = 0):
    assert (
        cos.shape[1] >= offset + x.shape[1]
    ), f"Offset and/or input sequence is too large,\
        \n offset: {offset}, seq_len: {x.shape[1]}, max: {cos.shape[1]}"
    cos_out = cos[:, offset : offset + x.shape[1], :, :]
    sin_out = sin[:, offset : offset + x.shape[1], :, :]
    return (x * cos_out) + (rotate_half(x) * sin_out)
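

# ShapeRotator precomputes the RoPE cos/sin tables once per device (keyed by
# device index) and caches them; rotate() then applies them to query and key
# tensors of shape (B, S, H, Dh) at the given cache offset.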
# Adapted from https://github.com/foundation-model-stack/foundation-model-stack
class ShapeRotator:
    def __init__(
        self,
        dim: int,
        end: int,
        theta: float = 10_000,
    ):
        super().__init__()
        self.dim = dim
        self.ratio = theta
        self.cached_freqs: MutableMapping[int, MutableMapping[int, torch.Tensor]] = {}
        self.max_seq_len_cached: MutableMapping[int, int] = {}
        self.ntk_scaling = False
        self.max_seq_len = end

    def compute_freqs_cis(self, device, max_seq_len=None):
        alpha = 1
        dev_idx = device.index
        max_seq_len = default(max_seq_len, self.max_seq_len)

        if dev_idx not in self.cached_freqs:
            self.cached_freqs[dev_idx] = {}
        if dev_idx not in self.max_seq_len_cached:
            self.max_seq_len_cached[dev_idx] = 0

        if self.max_seq_len_cached[dev_idx] > 0:
            return 1
        max_seq_len = max(max_seq_len, self.max_seq_len)

        if (
            1 in self.cached_freqs[dev_idx]
            and max_seq_len <= self.max_seq_len_cached[dev_idx]
        ):
            return 1

        ratio = self.ratio
        dim = self.dim

        freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2, device=device).float() / dim))

        t = torch.arange(max_seq_len, device=device, dtype=freqs.dtype)
        freqs = torch.einsum("i,j->ij", t, freqs)
        emb = torch.cat((freqs, freqs), dim=-1).to(device)

        cos_to_cache = emb.cos()[None, :, None, :]
        sin_to_cache = emb.sin()[None, :, None, :]

        self.max_seq_len_cached[dev_idx] = max_seq_len
        self.cached_freqs[dev_idx][alpha] = torch.stack(
            [
                cos_to_cache,
                sin_to_cache,
            ],
            dim=-1,
        )

        return alpha

    def rotate(
        self,
        q: Tensor,
        k: Tensor,
        offset: int = 0,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args
        ----
        q : torch.Tensor
            Embedded query tensor, expected size is B x S x H x Eh
        k : torch.Tensor
            Embedded key tensor, expected size is B x S x H x Eh
        """
        assert len(q.size()) == 4
        assert len(k.size()) == 4

        seq_len = self.max_seq_len
        alpha = self.compute_freqs_cis(q.device, seq_len)
        freqs = self.cached_freqs[q.device.index][alpha]

        freqs = freqs.float()  # (1, L, 1, D, 2): cos/sin stacked along the last dim

        q_out = apply_rotary_pos_emb(q, freqs[..., 0], freqs[..., 1], offset=offset).type_as(q)
        k_out = apply_rotary_pos_emb(k, freqs[..., 0], freqs[..., 1], offset=offset).type_as(k)

        return q_out.view_as(q), k_out.view_as(k)
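

# Minimal usage sketch for ShapeRotator (illustrative only; the shapes and
# values below are assumptions, not part of this module):
#
#   rotator = ShapeRotator(dim=64, end=2048)
#   q = torch.randn(2, 16, 8, 64)   # (B, S, H, Dh)
#   k = torch.randn(2, 16, 8, 64)
#   q_rot, k_rot = rotator.rotate(q, k, offset=0)
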
class Linear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, bias=False)


class Norm(nn.Module):
    def __init__(self,
            dim: int,
            eps: float = 1e-5,) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(T.ones((dim,)))

    def forward(self, input: Tensor) -> Tensor:
        return F.layer_norm(input, (self.weight.shape[0],), weight=self.weight, bias=None, eps=self.eps)
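

# Gated feed-forward network (SwiGLU-style): a single fused projection produces
# the gate and up branches, the gate is passed through SiLU, and the product is
# projected back down. The default expand_dim is roughly 8/3 * dim, rounded up
# to a multiple of 256.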
class FFNN(nn.Module):
    def __init__(self,
            dim: int,
            expand_dim: int = None,):
        super().__init__()
        expand_dim = default(expand_dim, 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256))
        self.dim = dim
        self.expand_dim = expand_dim

        self.gateup_proj = Linear(dim, 2 * expand_dim)
        self.down_proj = Linear(expand_dim, dim)

    def forward(self, x):
        gate, up = self.gateup_proj(x).chunk(2, dim=-1)
        return self.down_proj(up * F.silu(gate))
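

# Grouped-query attention: n_head query heads share kv_heads key/value heads
# (kv_heads defaults to n_head, i.e. standard multi-head attention). Queries
# and keys are normalized before the rotary embedding is applied, and an
# optional pre-allocated kv cache is read from and written back in place.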
class GQA(nn.Module):
    def __init__(self,
            dim: int,
            n_head: int,
            shape_rotator: ShapeRotator,
            kv_heads: Optional[int] = None,
            eps: float = 1e-5,
            causal: bool = True,):
        super().__init__()
        self.n_heads = n_head
        self.kv_heads = default(kv_heads, n_head)
        self.head_dim = dim // n_head
        self.causal = causal

        self.proj_qkv = Linear(dim, self.head_dim * (n_head + 2 * self.kv_heads))
        self.norm_q = Norm(self.head_dim * n_head, eps=eps)
        self.norm_k = Norm(self.head_dim * self.kv_heads, eps=eps)
        self.attn_out = Linear(dim, dim)

        self.shape_rotator = shape_rotator

    def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Expand the kv heads to match the number of query heads (grouped-query attention).
        k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
        v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
        # Prefer the flash-attention SDPA backend on CUDA; use the default elsewhere.
        with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION) if k.device.type == 'cuda' else nullcontext():
            x = F.scaled_dot_product_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                # Causal masking only applies when q and k span the same positions,
                # i.e. not when decoding single steps against a longer cache.
                is_causal=False if (q.size(1) != k.size(1)) else self.causal,
            )
        x = x.transpose(1, 2).contiguous()
        return x

    def _attend(self, q: Tensor, k: Tensor, v: Tensor, kv_cache: Optional[Tensor] = None,):
        cache_len = get_cache_len(kv_cache)
        q, k = self.shape_rotator.rotate(q, k, offset=cache_len)
        if exists(kv_cache):
            # Prepend the cached keys/values, then write the full k/v back into the cache.
            k = T.cat([kv_cache[:, :cache_len, 0], k], dim=1)
            v = T.cat([kv_cache[:, :cache_len, 1], v], dim=1)
            kv_cache[:, :k.size(1), 0] = k
            kv_cache[:, :v.size(1), 1] = v
        x = self._sdpa(q, k, v)
        return self.attn_out(rearrange(x, 'b s h d -> b s (h d)'))

    def _project(self, x):
        # Split the fused qkv projection by head counts so kv_heads may differ from n_heads.
        full_q, full_k, full_v = self.proj_qkv(x).split(
            [self.head_dim * self.n_heads, self.head_dim * self.kv_heads, self.head_dim * self.kv_heads],
            dim=-1,
        )
        normed_full_q = self.norm_q(full_q).to(full_q.dtype)
        normed_full_k = self.norm_k(full_k).to(full_k.dtype)

        q = rearrange(normed_full_q, 'b s (h d) -> b s h d', h=self.n_heads)
        k = rearrange(normed_full_k, 'b s (h d) -> b s h d', h=self.kv_heads)
        v = rearrange(full_v, 'b s (h d) -> b s h d', h=self.kv_heads)
        return q, k, v

    def forward(self,
            x: Tensor,
            kv: Optional[Tensor] = None,):
        """
        x: (B, S, D)
        kv: (B, S, 2, KVH, Dh) per-layer kv-cache slice, or None
        """
        q, k, v = self._project(x)
        return self._attend(q, k, v, kv_cache=kv)
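

# Pre-norm residual wrappers: each sublayer normalizes its input, applies
# attention or the feed-forward network, and adds the result back onto the
# residual stream.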
class PreNormAttn(nn.Module):
    def __init__(self,
            dim: int,
            n_head: int,
            shape_rotator: ShapeRotator,
            kv_heads: Optional[int] = None,
            eps: float = 1e-5,
            causal: bool = True,):
        super().__init__()
        self.attn_norm = Norm(dim, eps=eps)
        self.attn = GQA(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)

    def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
        """
        x: (B, S, D)
        kv: (B, S, 2, KVH, Dh) per-layer kv-cache slice, or None
        """
        return x + self.attn(self.attn_norm(x), kv)


class PreNormFFNN(nn.Module):
    def __init__(self,
            dim: int,
            ff_dim: int,
            eps: float = 1e-5,):
        super().__init__()
        self.ffnn_norm = Norm(dim, eps=eps)
        self.ffnn = FFNN(dim, ff_dim)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.ffnn(self.ffnn_norm(x))
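

# A full transformer block: pre-norm attention followed by the pre-norm
# feed-forward sublayer. Input projections are initialized from a truncated
# normal scaled by 1/sqrt(dim), and the feed-forward down projection by
# 1/sqrt(expand_dim).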
class Block(nn.Module):
    def __init__(self,
            dim: int,
            layer_id: int = 0,
            n_head: int = 16,
            kv_heads: Optional[int] = None,
            ff_dim: Optional[int] = None,
            eps: float = 1e-5,
            causal: bool = True,
            shape_rotator: ShapeRotator = None):
        super().__init__()
        self.attn = PreNormAttn(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
        self.ffnn = PreNormFFNN(dim, ff_dim, eps=eps)
        self.dim = dim
        self.layer_id = layer_id

        self.head_dim = dim // n_head
        self.expand_dim = self.ffnn.ffnn.expand_dim

        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.dim)
        nn.init.trunc_normal_(self.ffnn.ffnn.gateup_proj.weight, std=std, a=-3 * std, b=3 * std)
        nn.init.trunc_normal_(self.attn.attn.proj_qkv.weight, std=std, a=-3 * std, b=3 * std)
        nn.init.trunc_normal_(self.attn.attn.attn_out.weight, std=std, a=-3 * std, b=3 * std)

        xstd = 1.0 / math.sqrt(self.expand_dim)
        nn.init.trunc_normal_(self.ffnn.ffnn.down_proj.weight, std=xstd, a=-3 * xstd, b=3 * xstd)

    def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
        """
        x: (B, S, D)
        kv: (B, S, 2, KVH, Dh) per-layer kv-cache slice, or None
        """
        h = self.attn(x, kv)
        out = self.ffnn(h)
        return out
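

# Final projection head: a last Norm followed by a bias-free Linear onto the
# vocabulary.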
class GPTOutput(nn.Module):
    def __init__(self, dim, vocab_size):
        super().__init__()
        self.dim = dim
        self.norm = Norm(dim)
        self.output = Linear(dim, vocab_size)

        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.dim**2)
        nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)

    def forward(self, x):
        return self.output(self.norm(x))
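

# Stack: the full decoder stack. It owns the shared ShapeRotator, one Block per
# layer, and optionally a pre-allocated KV cache with one slice per layer;
# init_cache() fills the cache with CACHE_FILL_VALUE and deinit_cache() drops it.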
@si_module
class Stack(nn.Module):
    class Config:
        layers: int
        dim: int
        seq_len: int
        n_head: int = 32
        ff_dim: Optional[int] = None
        kv_heads: Optional[int] = None
        eps: float = 1e-5
        theta: Union[int, float] = 10_000
        causal: bool = True

        from_pretrained: Optional[Tuple[str, int]] = None

    def __init__(self, c: Config):
        super().__init__()

        from_pretrained = c.from_pretrained
        if exists(from_pretrained):
            checkpoint = load_ckpt(c.from_pretrained)

        self.shape_rotator = ShapeRotator(c.dim // c.n_head, c.seq_len, theta=c.theta)

        self.layers = nn.ModuleList([
            Block(
                dim=c.dim,
                layer_id=l,
                n_head=c.n_head,
                kv_heads=c.kv_heads,
                ff_dim=c.ff_dim,
                eps=c.eps,
                causal=c.causal,
                shape_rotator=self.shape_rotator,
            ) for l in range(c.layers)
        ])

        kv_heads = c.kv_heads or c.n_head
        head_dim = c.dim // c.n_head
        cache_shape = [c.layers, c.seq_len, 2, kv_heads, head_dim]
        self.cache_shape = cache_shape
        self.cache = [None] * c.layers

        if exists(from_pretrained):
            self.load_state_dict(checkpoint)

    def init_cache(self, bsize, device, dtype, length: Optional[int] = None):
        if self.cache_shape is None:
            return

        cache_shape = self.cache_shape.copy()
        cache_shape[1] = length or cache_shape[1]
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)

    def deinit_cache(self):
        self.cache = [None] * len(self.cache)

    def forward(self, x: Tensor) -> Tensor:
        for l, layer in enumerate(self.layers):
            x = layer(x, kv=self.cache[l])
        return x
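

# Minimal end-to-end usage sketch (assumptions: exact construction depends on
# the si_module decorator defined in utils, which is not shown here; the
# config values below are illustrative, not taken from this file):
#
#   cfg = Stack.Config(layers=4, dim=512, seq_len=2048, n_head=8)
#   model = Stack(cfg)
#   model.init_cache(bsize=1, device=torch.device('cpu'), dtype=torch.float32)
#   x = torch.randn(1, 16, 512)   # (B, S, D)
#   y = model(x)                  # (B, S, D); the per-layer kv cache is updated in place
#   model.deinit_cache()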