import math

import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from .configuration_spt import SPTConfig


def repeat_kv(hidden_states, repeat_times):
    """Repeat each K/V head `repeat_times` times so a small set of K/V heads
    can be shared across a larger set of query heads (grouped-query attention)."""
    if repeat_times == 1:
        return hidden_states
    batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, n_kv_heads, repeat_times, seq_len, head_dim)
    return hidden_states.reshape(batch, n_kv_heads * repeat_times, seq_len, head_dim)
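# Illustrative shape check (assumed values, not part of the model): with
# batch=2, n_kv_heads=4, repeat_times=3, seq_len=5, head_dim=8, repeat_kv maps
# a (2, 4, 5, 8) tensor to (2, 12, 5, 8), i.e. each K/V head is duplicated three times.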


class RMSNorm(nn.Module):
    """Root-mean-square layer norm: y = weight * x / sqrt(mean(x^2) + eps), computed in fp32."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
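# Quick numerical sketch (illustrative, hypothetical values): for x = [3.0, 4.0],
# mean(x^2) = 12.5 and rsqrt(12.5) ~= 0.2828, so RMSNorm returns roughly
# [0.849, 1.131] before the learned per-channel weight is applied.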


class Attention(nn.Module):
    """Multi-head causal self-attention with grouped-query K/V heads and no output projection."""

    def __init__(self, config):
        super().__init__()
        self.head_dim = config.hidden_size // config.n_attn_heads
        kv_size = config.n_kv_heads * self.head_dim
        self.hidden_size = config.hidden_size
        self.n_attn_heads = config.n_attn_heads
        self.n_kv_heads = config.n_kv_heads
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k = nn.Linear(config.hidden_size, kv_size, bias=False)
        self.v = nn.Linear(config.hidden_size, kv_size, bias=False)
        # Lower-triangular causal mask, pre-built up to the maximum sequence length.
        self.register_buffer(
            'tril',
            torch.tril(torch.ones(config.max_len, config.max_len)).view(1, 1, config.max_len, config.max_len),
        )

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.shape
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        # Split into heads: (batch, n_heads, seq_len, head_dim).
        q = q.view(batch_size, seq_len, self.n_attn_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Duplicate K/V heads so each query head has a matching K/V head.
        k = repeat_kv(k, self.n_attn_heads // self.n_kv_heads)
        v = repeat_kv(v, self.n_attn_heads // self.n_kv_heads)
        # Note: scores are scaled by 1/sqrt(hidden_size), not the conventional 1/sqrt(head_dim).
        attention = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.hidden_size))
        attention = attention.masked_fill(self.tril[:, :, :seq_len, :seq_len] == 0, float('-inf'))
        probs = nn.functional.softmax(attention, dim=-1)
        y = probs @ v
        # Merge heads back into (batch, seq_len, hidden_size).
        y = y.transpose(1, 2).contiguous().reshape(batch_size, seq_len, -1)
        return y
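# Grouped-query attention in concrete numbers (assumed config, for illustration):
# with n_attn_heads=8 and n_kv_heads=2, each K/V projection is repeated 8 // 2 = 4
# times, so every group of 4 query heads shares one K/V head.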


class MLP(nn.Module):
    """Gated feed-forward block: down(GELU(up(x) * gate(x)))."""

    def __init__(self, config):
        super().__init__()
        self.up = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.gate = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.GELU()

    def forward(self, x):
        up = self.up(x)
        gate = self.gate(x)
        return self.down(self.act_fn(up * gate))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block; the same RMSNorm instance is shared by the attention
    and MLP paths, and residual connections can be disabled via config.residual."""

    def __init__(self, config):
        super().__init__()
        self.attn = Attention(config)
        self.mlp = MLP(config)
        self.residual = config.residual
        self.norm = RMSNorm(config.hidden_size) if config.normalise else nn.Identity()

    def forward(self, x):
        if self.residual:
            x = x + self.attn(self.norm(x))
            x = x + self.mlp(self.norm(x))
        else:
            x = self.attn(self.norm(x))
            x = self.mlp(self.norm(x))
        return x


class SPTModel(PreTrainedModel):
    """Decoder-only backbone: token embedding -> stack of TransformerBlocks -> final norm.
    Note that no positional embeddings are added."""

    config_class = SPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.norm = RMSNorm(config.hidden_size) if config.normalise else nn.Identity()

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return x


class SPTForCausalLM(PreTrainedModel):
    """Causal language-modelling head on top of SPTModel."""

    config_class = SPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = SPTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids, labels=None):
        x = self.model(input_ids)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Labels are not shifted here, so targets must already be aligned with the logits.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=x,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    @staticmethod
    def _reorder_cache(past, beam_idx):
        # No KV cache is kept, so there is nothing to reorder for beam search.
        return past


# Register the custom model with the AutoModelForCausalLM factory.
AutoModelForCausalLM.register(SPTConfig, SPTForCausalLM)