Sartc committed on
Commit 75ae50f · verified · 1 Parent(s): 2213b69

Delete model

Files changed (1)
  1. model.py +0 -134
model.py DELETED
@@ -1,134 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import lightning
- from safetensors.torch import save_file
-
- class Config:
-     vocab_size = 50304
-     n_epochs = 50
-     batch_size = 36
-     lr = 3e-4
-     wd = 1e-6
-     n_embed = 256
-     num_blocks = 12
-     num_heads = 12
-     head_size = n_embed // num_heads
-     context_len = 224
-     attn_dropout_val = 0.2
-     mha_dropout_val = 0.2
-     ffn_dropout_val = 0.2
-
- class CausalAttentionHead(nn.Module):
-     def __init__(self, config):
-         super(CausalAttentionHead, self).__init__()
-         self.config = config
-
-         self.query = nn.Linear(config.n_embed, config.head_size, bias=False)
-         self.key = nn.Linear(config.n_embed, config.head_size, bias=False)
-         self.value = nn.Linear(config.n_embed, config.head_size, bias=False)
-         self.attn_drop = nn.Dropout(config.attn_dropout_val)
-         # mask for causal attention during training
-         self.register_buffer("mask", torch.tril(torch.ones(config.context_len, config.context_len)))
-
-     def forward(self, x):
-         bs, context_len, embed_dim = x.shape
-         q, k, v = self.query(x), self.key(x), self.value(x)
-         # scaled dot-product attention: scale scores by sqrt(head_size)
-         attn_filter = torch.divide(torch.bmm(q, k.transpose(1, 2)), self.config.head_size ** 0.5)
-         attn_filter = attn_filter.masked_fill(self.mask[:context_len, :context_len] == 0, float("-inf"))
-         attn_weights = F.softmax(attn_filter, dim=-1)
-         attn_weights = self.attn_drop(attn_weights)
-         output = torch.bmm(attn_weights, v)
-         return output
-
- class MultiHeadedAttention(nn.Module):
-     def __init__(self, config):
-         super(MultiHeadedAttention, self).__init__()
-         self.config = config
-         self.heads = nn.ModuleList(
-             [CausalAttentionHead(config) for _ in range(config.num_heads)]
-         )
-         self.proj = nn.Linear(config.num_heads * config.head_size, config.n_embed)
-         self.mha_drop = nn.Dropout(config.mha_dropout_val)
-
-     def forward(self, x):
-         mha_output = torch.cat([head(x) for head in self.heads], dim=-1)
-         return self.mha_drop(self.proj(mha_output))
-
- class FeedForwardNetwork(nn.Module):
-     def __init__(self, config):
-         super(FeedForwardNetwork, self).__init__()
-
-         self.ffn = nn.Sequential(
-             nn.Linear(config.n_embed, config.n_embed * 4),
-             nn.GELU(),
-             nn.Linear(config.n_embed * 4, config.n_embed),
-             nn.Dropout(config.ffn_dropout_val),  # use the configured FFN dropout rate
-         )
-
-     def forward(self, x):
-         return self.ffn(x)
-
- class Block(nn.Module):
-     def __init__(self, config):
-         super(Block, self).__init__()
-         self.mha = MultiHeadedAttention(config)
-         self.ln1 = nn.LayerNorm(config.n_embed)
-         self.ffn = FeedForwardNetwork(config)
-         self.ln2 = nn.LayerNorm(config.n_embed)
-
-     def forward(self, x):
-         x = self.ln1(x + self.mha(x))
-         x = self.ln2(x + self.ffn(x))
-         return x
-
- class GPT(lightning.LightningModule):
-     def __init__(self, config):
-         super(GPT, self).__init__()
-         self.config = config
-         self.save_hyperparameters()
-         self.token_embedding = nn.Embedding(config.vocab_size, config.n_embed)
-         self.positional_embedding = nn.Embedding(config.context_len, config.n_embed)
-         self.backbone = nn.Sequential(*[Block(config) for _ in range(config.num_blocks)])
-         self.lm_head = nn.Linear(config.n_embed, config.vocab_size)
-         self.training_step_outputs = []  # per-step losses collected for the epoch-end average
-
-     def forward(self, x):
-         tok_emb = self.token_embedding(x)
-         pos_emb = self.positional_embedding(torch.arange(x.shape[1], device=self.device))
-         x = tok_emb + pos_emb
-         x = self.backbone(x)
-         logits = self.lm_head(x)
-         return logits
-
-     def get_loss(self, predictions, target):
-         B, C, V = predictions.shape
-         predictions = predictions.view(B * C, V)
-         target = target.view(B * C)
-         loss = F.cross_entropy(predictions, target)
-         return loss
-
-     def training_step(self, batch, batch_idx):
-         text, target = batch
-         text = text.long()
-         target = target.long()
-         logits = self(text)
-         loss = self.get_loss(logits, target)
-
-         self.log('loss', loss.item(), prog_bar=True)
-         self.training_step_outputs.append(loss.detach())
-
-         return loss
-
-     def on_train_epoch_end(self):
-         # average the per-step losses collected during the epoch
-         avg_loss = torch.stack(self.training_step_outputs).mean()
-         self.training_step_outputs.clear()
-         print(f"train_loss: {avg_loss}")
-
-     def configure_optimizers(self):
-         opt = torch.optim.AdamW(self.parameters(), lr=self.config.lr, weight_decay=self.config.wd)
-         return [opt], []
-
- if __name__ == "__main__":
-     config = Config()
-     gpt = GPT(config)
-     # save_file expects a flat dict of tensors, so pass the state_dict rather than the module
-     save_file(gpt.state_dict(), "storyGPT.safetensors")
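
For reference, a minimal sketch (not part of the deleted file) of how the weights written by the final block could be loaded back, assuming the same Config defaults and the storyGPT.safetensors path used above:

from safetensors.torch import load_file

config = Config()
gpt = GPT(config)
state_dict = load_file("storyGPT.safetensors")  # returns a flat dict of tensor names -> tensors
gpt.load_state_dict(state_dict)
gpt.eval()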