# storyGPT / model.py
# Sartc's picture
# Create model.py
# 64d28a7 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning
from safetensors.torch import save_file
class Config:
    """Hyperparameters for the storyGPT model and its training run.

    Every value is a plain class attribute, so it can be read from the
    class itself or from an instance.
    """

    # --- optimisation / training loop ---
    n_epochs = 50
    batch_size = 36
    lr = 3e-4  # learning rate handed to AdamW
    wd = 1e-6  # weight decay handed to AdamW

    # --- model dimensions ---
    vocab_size = 50304
    context_len = 224
    n_embed = 256
    num_blocks = 12
    num_heads = 12
    # Floor division: 256 // 12 == 21, so the concatenated heads span
    # 12 * 21 == 252 features rather than the full n_embed.
    head_size = n_embed // num_heads

    # --- dropout probabilities ---
    attn_dropout_val = 0.2
    mha_dropout_val = 0.2
    ffn_dropout_val = 0.2
class CausalAttentionHead(nn.Module):
    """A single head of causal (masked) self-attention.

    The input is projected to queries, keys and values of width
    ``config.head_size``; each position may attend only to itself and to
    earlier positions.

    Args:
        config: object exposing ``n_embed``, ``head_size``, ``context_len``
            and ``attn_dropout_val``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        in_dim, out_dim = config.n_embed, config.head_size
        self.query = nn.Linear(in_dim, out_dim, bias=False)
        self.key = nn.Linear(in_dim, out_dim, bias=False)
        self.value = nn.Linear(in_dim, out_dim, bias=False)
        self.attn_drop = nn.Dropout(config.attn_dropout_val)
        # Lower-triangular matrix of ones: entry (i, j) is 1 when position i
        # is allowed to attend to position j (j <= i). Registered as a buffer
        # so it is part of the module state without being a parameter.
        causal = torch.ones(config.context_len, config.context_len).tril()
        self.register_buffer("mask", causal)

    def forward(self, x):
        """Apply masked attention to ``x``.

        Args:
            x: 3-D tensor whose last dimension is ``config.n_embed``.

        Returns:
            Tensor with the same two leading dimensions as ``x`` and a last
            dimension of ``config.head_size``.
        """
        _, seq_len, _ = x.shape
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        # NOTE(review): the scores are divided by head_size itself, not by
        # its square root as in standard scaled dot-product attention.
        # Left unchanged here because altering the scale would change the
        # outputs of any weights trained with it -- confirm intent.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / self.config.head_size

        # Positions in the future get -inf so softmax assigns them zero weight.
        blocked = self.mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(blocked, float("-inf"))

        weights = self.attn_drop(F.softmax(scores, dim=-1))
        return torch.bmm(weights, values)
class MultiHeadedAttention(nn.Module):
    """Runs several causal attention heads side by side and merges them.

    Each head sees the full input; their outputs are concatenated along the
    feature axis, projected to ``config.n_embed`` and passed through dropout.

    Args:
        config: object exposing ``num_heads``, ``head_size``, ``n_embed`` and
            ``mha_dropout_val``, plus whatever ``CausalAttentionHead`` reads.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        head_modules = [CausalAttentionHead(config) for _ in range(config.num_heads)]
        self.heads = nn.ModuleList(head_modules)
        # The concatenated width may be smaller than n_embed when n_embed is
        # not divisible by num_heads; the projection maps it to n_embed.
        concat_width = config.num_heads * config.head_size
        self.proj = nn.Linear(concat_width, config.n_embed)
        self.mha_drop = nn.Dropout(config.mha_dropout_val)

    def forward(self, x):
        """Return the projected, dropout-regularised concatenation of all heads."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.mha_drop(projected)
class FeedForwardNetwork(nn.Module):
    """Feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is four times as wide as ``config.n_embed``; the output
    has the same width as the input.

    Args:
        config: object exposing ``n_embed`` and ``ffn_dropout_val``.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = config.n_embed * 4
        self.ffn = nn.Sequential(
            nn.Linear(config.n_embed, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, config.n_embed),
            # Use the configured rate: a bare nn.Dropout() silently falls
            # back to p=0.5 and ignores config.ffn_dropout_val.
            nn.Dropout(config.ffn_dropout_val),
        )

    def forward(self, x):
        """Apply the feed-forward stack to the last dimension of ``x``."""
        return self.ffn(x)
class Block(nn.Module):
    """One transformer layer: attention followed by a feed-forward network.

    Each sub-layer's output is added to its input and the sum is then
    layer-normalised (normalisation is applied after the residual add).

    Args:
        config: object exposing ``n_embed`` plus the fields read by
            ``MultiHeadedAttention`` and ``FeedForwardNetwork``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embed
        self.mha = MultiHeadedAttention(config)
        self.ln1 = nn.LayerNorm(width)
        self.ffn = FeedForwardNetwork(config)
        self.ln2 = nn.LayerNorm(width)

    def forward(self, x):
        """Run attention and feed-forward sub-layers, returning a same-shaped tensor."""
        after_attention = self.ln1(x + self.mha(x))
        after_ffn = self.ln2(after_attention + self.ffn(after_attention))
        return after_ffn
class GPT(lightning.LightningModule):
    """Decoder-only transformer language model packaged as a LightningModule.

    Token embeddings and learned positional embeddings are summed, passed
    through ``config.num_blocks`` transformer blocks and projected to
    vocabulary logits.

    Args:
        config: object exposing ``vocab_size``, ``n_embed``, ``context_len``,
            ``num_blocks``, ``lr`` and ``wd``, plus the fields the blocks read.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.save_hyperparameters()
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embed)
        self.positional_embedding = nn.Embedding(config.context_len, config.n_embed)
        self.backbone = nn.Sequential(*[Block(config) for _ in range(config.num_blocks)])
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size)

    def forward(self, x):
        """Return next-token logits for a batch of token ids.

        Args:
            x: integer tensor of token ids; its second dimension is the
                sequence length and must not exceed ``config.context_len``.

        Returns:
            Logits with a trailing dimension of ``config.vocab_size``.

        Raises:
            IndexError: if the sequence is longer than ``config.context_len``
                (there is no positional embedding for the extra positions).
        """
        seq_len = x.shape[1]
        if seq_len > self.config.context_len:
            # Fail early with a readable message instead of an out-of-range
            # lookup inside the positional embedding.
            raise IndexError(
                f"sequence length {seq_len} exceeds context_len "
                f"{self.config.context_len}"
            )
        # Build the position indices on the input's own device so they always
        # match the tensor they are combined with.
        positions = torch.arange(seq_len, device=x.device)
        tok_emb = self.token_embedding(x)
        pos_emb = self.positional_embedding(positions)
        hidden = self.backbone(tok_emb + pos_emb)
        return self.lm_head(hidden)

    def get_loss(self, predictions, target):
        """Cross-entropy between per-position logits and target token ids.

        Args:
            predictions: 3-D logits tensor (batch, positions, vocab).
            target: integer tensor holding one class id per (batch, position).

        Returns:
            Scalar loss tensor averaged over all positions.
        """
        B, C, V = predictions.shape
        # reshape (rather than view) also accepts non-contiguous inputs,
        # e.g. a target produced by slicing a longer token tensor.
        flat_predictions = predictions.reshape(B * C, V)
        flat_target = target.reshape(B * C)
        return F.cross_entropy(flat_predictions, flat_target)

    def training_step(self, batch, batch_idx):
        """Compute and log the loss for one ``(text, target)`` batch."""
        text, target = batch
        logits = self(text.long())
        loss = self.get_loss(logits, target.long())
        # Log the detached tensor: calling .item() here would force a
        # host/device synchronisation on every step.
        self.log('loss', loss.detach(), prog_bar=True)
        logs = {'loss': loss}
        return {"log": logs, "loss": loss}

    def training_end(self, outputs):
        """Average the per-step losses found in ``outputs`` and print them.

        NOTE(review): this looks like a legacy hook name; confirm that the
        installed Lightning version still calls it. The printed label says
        ``val_loss`` although the values come from ``training_step`` outputs.
        """
        avg_loss = torch.stack([x['log']['loss'] for x in outputs]).mean()
        logs = {"log": avg_loss}
        print(f"val_loss: {avg_loss}")
        return {"log": logs}

    def configure_optimizers(self):
        """Return a single AdamW optimiser and no learning-rate schedulers."""
        opt = torch.optim.AdamW(
            self.parameters(), lr=self.config.lr, weight_decay=self.config.wd
        )
        return [opt], []
if __name__ == "__main__":
    # Build a freshly initialised (untrained) model and write its weights
    # to disk in safetensors format.
    config = Config()
    gpt = GPT(config)
    # save_file expects a flat {name: tensor} mapping, not the module itself,
    # so hand it the state dict.
    save_file(gpt.state_dict(), "storyGPT.safetensors")