import os
import json
import math
from typing import Optional, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from huggingface_hub import PyTorchModelHubMixin, hf_hub_download


def norm(x):
    # RMSNorm without a learnable scale.
    return F.rms_norm(x, (x.size(-1),))


class CastedLinear(nn.Linear):
    # Bias-free linear layer that casts its weight to the activation dtype at call time.
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=False)

    @torch.inference_mode()
    def forward(self, x):
        return F.linear(x, self.weight.type_as(x))


class Rotary(nn.Module):
    # Rotary position embedding. Only half of the head dimensions receive a nonzero
    # rotation frequency; the remaining half get zero frequency and pass through unrotated.
    def __init__(self, dim, max_seq_len=65536):
        super().__init__()
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim // 4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim // 4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum('i,j -> ij', t, angular_freq)
        self.register_buffer('cos', theta.cos(), persistent=False)
        self.register_buffer('sin', theta.sin(), persistent=False)

    @torch.inference_mode()
    def forward(self, x):
        # x: (B, T, num_heads, head_dim); slice the precomputed tables to the sequence length.
        cos, sin = self.cos[None, :x.size(-3), None, :], self.sin[None, :x.size(-3), None, :]
        x1, x2 = x.float().chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), dim=-1).type_as(x)


class CausalSelfAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.c_q = CastedLinear(dim, dim)
        self.c_k = CastedLinear(dim, dim)
        self.c_v = CastedLinear(dim, dim)
        # Learned mixing weights between the value projection and the value embedding.
        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
        self.rotary = Rotary(self.head_dim)
        self.c_proj = CastedLinear(dim, dim)
        self.register_buffer('kv_cache', None, persistent=False)

    @torch.inference_mode()
    def forward(self, x, ve):
        B, T = x.size(0), x.size(1)
        q = self.c_q(x).view(B, T, self.num_heads, self.head_dim)
        k = self.c_k(x).view(B, T, self.num_heads, self.head_dim)
        v = self.c_v(x).view(B, T, self.num_heads, self.head_dim)
        if ve is not None:
            # Mix the per-token value embedding into the value stream.
            v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v)
        else:
            v = self.lambdas[0] * v
        q, k = norm(q), norm(k)  # QK-norm
        q, k = self.rotary(q), self.rotary(k)
        if self.kv_cache is not None:
            k = torch.cat([self.kv_cache[0], k], dim=1)
            v = torch.cat([self.kv_cache[1], v], dim=1)
        self.kv_cache = torch.stack([k, v])
        if hasattr(F, 'scaled_dot_product_attention'):
            y = F.scaled_dot_product_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                is_causal=True
            )
        else:
            # Manual fallback for older PyTorch: move heads to dim 1 before the matmuls
            # and offset the causal mask in case cached keys precede the new queries.
            qt, kt, vt = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            att = (qt @ kt.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
            S = kt.size(-2)
            causal_mask = torch.triu(torch.ones(T, S, device=x.device, dtype=torch.bool),
                                     diagonal=1 + S - T)
            att = att.masked_fill(causal_mask, float('-inf'))
            att = F.softmax(att, dim=-1)
            y = att @ vt
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y


class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.c_fc = CastedLinear(dim, 4 * dim)
        self.c_proj = CastedLinear(4 * dim, dim)
        self.c_proj.weight.data.zero_()  # zero-init the output projection

    @torch.inference_mode()
    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()  # ReLU^2 activation
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, model_dim, num_heads, use_attn=True):
        super().__init__()
        self.attn = CausalSelfAttention(model_dim, num_heads) if use_attn else None
        self.mlp = MLP(model_dim)
        # Learned interpolation between the residual stream and the initial embedding x0.
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))

    @torch.inference_mode()
    def forward(self, x, ve, x0):
        x = self.lambdas[0] * x + self.lambdas[1] * x0
        if self.attn is not None:
            x = x + self.attn(norm(x), ve)
        x = x + self.mlp(norm(x))
        return x


class ValueEmbedding(nn.Module):
    def __init__(self, vocab_size, model_dim):
        super().__init__()
        self.embed = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])

    @torch.inference_mode()
    def forward(self, inputs):
        # Three value embeddings, shared between the first three and last three layers;
        # the middle layers receive None (layout assumes a 12-layer model).
        ve = [emb(inputs).bfloat16() for emb in self.embed]
        ve = [ve[0], ve[1], ve[2], None, None, None, None, None, None, ve[0], ve[1], ve[2]]
        return ve


class ChronoGPT(nn.Module, PyTorchModelHubMixin):
    def __init__(self, vocab_size, num_layers, num_heads, model_dim, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, model_dim)
        # The block at index 7 is attention-free (MLP only).
        self.blocks = nn.ModuleList([Block(model_dim, num_heads, use_attn=(i != 7))
                                     for i in range(num_layers)])
        self.value_embeds = ValueEmbedding(vocab_size, model_dim)
        self.lm_head = CastedLinear(model_dim, vocab_size)
        self.lm_head.weight.data.zero_()
        # U-Net style layout: the encoder half feeds weighted skip connections into the decoder half.
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))

    @torch.inference_mode()
    def forward(self, inputs, past_key_values=None):
        # Accept a single 1D token sequence by promoting it to a batch of one
        # before reading off the batch size.
        if inputs.dim() == 1:
            inputs = inputs.unsqueeze(0)
        B = inputs.size(0)
        layer_outputs = []
        x0 = norm(self.embed(inputs).bfloat16())
        x = x0
        layer_outputs.append(norm(x))
        # Per-sample value embeddings, restacked into one batched tensor per layer.
        ve = [self.value_embeds(inputs[i].view(-1)) for i in range(B)]
        ve = [torch.stack([ve[b][i] for b in range(B)]) if ve[0][i] is not None else None
              for i in range(len(ve[0]))]
        ve_enc, ve_dec = ve[:self.num_encoder_layers], ve[self.num_encoder_layers:]
        if past_key_values is not None:
            for i, block in enumerate(self.blocks):
                if block.attn is not None:
                    block.attn.kv_cache = past_key_values[i]
        present = []
        skip_connections = []
        # Encoder half: run blocks and collect skip connections.
        for i in range(self.num_encoder_layers):
            block = self.blocks[i]
            x = block(x, ve_enc[i], x0)
            if block.attn is not None:
                present.append(block.attn.kv_cache)
                block.attn.kv_cache = None
            skip_connections.append(x)
            layer_outputs.append(norm(x))
        # Decoder half: consume the skip connections in reverse order.
        for i in range(self.num_decoder_layers):
            x = x + self.skip_weights[i] * skip_connections.pop()
            block = self.blocks[self.num_encoder_layers + i]
            x = block(x, ve_dec[i], x0)
            layer_outputs.append(norm(x))
            if block.attn is not None:
                present.append(block.attn.kv_cache)
                block.attn.kv_cache = None
        x = norm(x)
        logits = self.lm_head(x)
        logits = 15 * torch.tanh(logits / 15)  # soft-cap logits to (-15, 15)
        return logits.float(), layer_outputs

    def save_pretrained(self, save_directory, **kwargs):
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        config = {
            "model_type": "ChronoGPT",
            "vocab_size": self.embed.num_embeddings,
            "num_layers": len(self.blocks),
            "num_heads": self.num_heads,
            "model_dim": self.embed.embedding_dim
        }
        torch.save(config, os.path.join(save_directory, "config.pt"))
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(config, f)

    @classmethod
    def from_pretrained(cls, repo_id, cache_dir=None, **kwargs):
        config_path = hf_hub_download(repo_id=repo_id, filename="config.pt", cache_dir=cache_dir)
        bin_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", cache_dir=cache_dir)
        config = torch.load(config_path, map_location="cpu")
        model = cls(**config)
        # map_location="cpu" lets checkpoints saved on GPU load on CPU-only machines.
        model.load_state_dict(torch.load(bin_path, map_location="cpu"))
        return model
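

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal sketch of loading a checkpoint and running a forward pass. The repo
# id below and the use of tiktoken's GPT-2 BPE are assumptions for illustration,
# not claims about the published checkpoints; substitute the repo id and the
# tokenizer that your checkpoint was actually trained with.
if __name__ == "__main__":
    import tiktoken  # assumed GPT-2-style BPE tokenizer

    model = ChronoGPT.from_pretrained("manelalab/chrono-gpt-v1-19991231")  # hypothetical repo id
    model.eval()

    enc = tiktoken.get_encoding("gpt2")
    tokens = torch.tensor(enc.encode("The market opened higher today"), dtype=torch.long)

    # forward() accepts a 1D token tensor and returns (logits, per-layer hidden states).
    logits, layer_outputs = model(tokens)
    next_token = logits[0, -1].argmax().item()  # greedy next-token prediction
    print("predicted next token:", enc.decode([next_token]))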