from dataclasses import dataclass
from typing import Literal

import torch
from torch import nn
from torch.nn import functional as F
def gelu_approx(x):
    # Tanh-approximated GELU, the activation used by the MLP blocks below.
    return F.gelu(x, approximate="tanh")
@dataclass
class LinearWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    return F.linear(x, w.weight, w.bias)
@dataclass
class LayerNormWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
@dataclass
class MLPWeights:
    fc1: LinearWeights
    fc2: LinearWeights
    act: Literal["gelu_approx"] = "gelu_approx"


def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
    # Two-layer feed-forward block: fc1 -> activation -> fc2.
    x = linear(x, w.fc1)
    if w.act == "gelu_approx":
        x = gelu_approx(x)
    else:
        raise NotImplementedError(f"Activation function {w.act} not implemented.")
    x = linear(x, w.fc2)
    return x
@dataclass
class AttentionWeights:
    qkv: LinearWeights
    proj: LinearWeights
    n_heads: int


def attn(x: torch.Tensor, w: AttentionWeights) -> torch.Tensor:
    bsz, q_len, d_model = x.shape
    n_heads, head_dim = w.n_heads, d_model // w.n_heads

    # Fused QKV projection, split into per-head (bsz, n_heads, q_len, head_dim) tensors.
    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]
    out = F.scaled_dot_product_attention(q, k, v)
    # Merge heads back into the model dimension, then apply the output projection.
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out
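

# Minimal usage sketch (not part of the original module): builds randomly
# initialized weights and runs one attention + MLP pass to show how the
# dataclasses and functional layers fit together. The sizes (d_model=64,
# n_heads=4) and the rand_linear helper are illustrative assumptions only,
# not values or helpers from the model itself.
if __name__ == "__main__":
    d_model, n_heads = 64, 4

    def rand_linear(in_dim: int, out_dim: int) -> LinearWeights:
        # Hypothetical helper for this demo: small random weight, zero bias.
        return LinearWeights(
            weight=torch.randn(out_dim, in_dim) * 0.02,
            bias=torch.zeros(out_dim),
        )

    attn_w = AttentionWeights(
        qkv=rand_linear(d_model, 3 * d_model),
        proj=rand_linear(d_model, d_model),
        n_heads=n_heads,
    )
    mlp_w = MLPWeights(
        fc1=rand_linear(d_model, 4 * d_model),
        fc2=rand_linear(4 * d_model, d_model),
    )

    x = torch.randn(2, 16, d_model)  # (batch, sequence, d_model)
    y = mlp(attn(x, attn_w), mlp_w)
    print(y.shape)  # torch.Size([2, 16, 64])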