import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning
from safetensors.torch import save_file


class Config:
    vocab_size = 50304
    n_epochs = 50
    batch_size = 36
    lr = 3e-4
    wd = 1e-6
    n_embed = 256
    num_blocks = 12
    num_heads = 12
    head_size = n_embed // num_heads
    context_len = 224
    attn_dropout_val = 0.2
    mha_dropout_val = 0.2
    ffn_dropout_val = 0.2


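# Note: with n_embed = 256 and num_heads = 12, head_size = 256 // 12 = 21, so the
# concatenated output of all heads is 12 * 21 = 252 features wide; MultiHeadedAttention.proj
# maps it back to n_embed.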
class CausalAttentionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.query = nn.Linear(config.n_embed, config.head_size, bias=False)
        self.key = nn.Linear(config.n_embed, config.head_size, bias=False)
        self.value = nn.Linear(config.n_embed, config.head_size, bias=False)
        self.attn_drop = nn.Dropout(config.attn_dropout_val)

        # lower-triangular mask: each position may only attend to itself and earlier positions
        self.register_buffer("mask", torch.tril(torch.ones(config.context_len, config.context_len)))

    def forward(self, x):
        bs, context_len, embed_dim = x.shape
        q, k, v = self.query(x), self.key(x), self.value(x)
        # scaled dot-product attention: scores are scaled by sqrt(head_size)
        attn_filter = torch.bmm(q, k.transpose(1, 2)) / (self.config.head_size ** 0.5)
        attn_filter = attn_filter.masked_fill(self.mask[:context_len, :context_len] == 0, float("-inf"))
        attn_weights = F.softmax(attn_filter, dim=-1)
        attn_weights = self.attn_drop(attn_weights)
        output = torch.bmm(attn_weights, v)
        return output


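# MultiHeadedAttention runs num_heads independent causal heads in parallel and concatenates
# their outputs along the feature dimension before projecting back to n_embed.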
class MultiHeadedAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.heads = nn.ModuleList(
            [CausalAttentionHead(config) for _ in range(config.num_heads)]
        )
        self.proj = nn.Linear(config.num_heads * config.head_size, config.n_embed)
        self.mha_drop = nn.Dropout(config.mha_dropout_val)

    def forward(self, x):
        mha_output = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.mha_drop(self.proj(mha_output))


class FeedForwardNetwork(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.ffn = nn.Sequential(
            nn.Linear(config.n_embed, config.n_embed * 4),
            nn.GELU(),
            nn.Linear(config.n_embed * 4, config.n_embed),
            # use the configured dropout rate rather than nn.Dropout's default of 0.5
            nn.Dropout(config.ffn_dropout_val)
        )

    def forward(self, x):
        return self.ffn(x)


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mha = MultiHeadedAttention(config)
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.ffn = FeedForwardNetwork(config)
        self.ln2 = nn.LayerNorm(config.n_embed)

    def forward(self, x):
        # residual connections with post-layer-norm around attention and the feed-forward network
        x = self.ln1(x + self.mha(x))
        x = self.ln2(x + self.ffn(x))
        return x


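# The LightningModule below combines token and positional embeddings, a stack of num_blocks
# transformer blocks, and a linear language-model head, along with the Lightning training hooks.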
class GPT(lightning.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.save_hyperparameters()
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embed)
        self.positional_embedding = nn.Embedding(config.context_len, config.n_embed)
        self.backbone = nn.Sequential(*[Block(config) for _ in range(config.num_blocks)])
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size)
        # collect per-step training losses so they can be averaged at the end of each epoch
        self.training_step_losses = []

    def forward(self, x):
        tok_emb = self.token_embedding(x)
        pos_emb = self.positional_embedding(torch.arange(x.shape[1], device=self.device))
        x = tok_emb + pos_emb
        x = self.backbone(x)
        logits = self.lm_head(x)
        return logits

    def get_loss(self, predictions, target):
        # flatten (batch, context, vocab) logits and (batch, context) targets for cross-entropy
        B, C, V = predictions.shape
        predictions = predictions.view(B * C, V)
        target = target.view(B * C)
        loss = F.cross_entropy(predictions, target)
        return loss

    def training_step(self, batch, batch_idx):
        text, target = batch
        text = text.long()
        target = target.long()
        logits = self(text)
        loss = self.get_loss(logits, target)

        self.log('loss', loss, prog_bar=True)
        self.training_step_losses.append(loss.detach())

        return loss

    def on_train_epoch_end(self):
        # average the training losses collected over the epoch (Lightning no longer passes
        # step outputs to an epoch-end hook, so they are tracked manually)
        avg_loss = torch.stack(self.training_step_losses).mean()
        print(f"avg train loss: {avg_loss}")
        self.training_step_losses.clear()

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.config.lr, weight_decay=self.config.wd)
        return [opt], []


if __name__ == "__main__":
    config = Config()
    gpt = GPT(config)
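    # A minimal training sketch before saving, assuming a DataLoader named `train_loader`
    # (hypothetical, not defined here) that yields (input_ids, target_ids) batches:
    #
    #     trainer = lightning.Trainer(max_epochs=config.n_epochs)
    #     trainer.fit(gpt, train_loader)
    #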
    # save_file expects a flat dict of tensors, so save the model's state_dict
    save_file(gpt.state_dict(), "storyGPT.safetensors")