import torch
import torch.nn as nn
from torch.nn import functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# one head of self-attention using scaled dot-product attention
class Head(nn.Module):
    def __init__(self, n_embed, head_size, context_size, dropout=0.1):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # lower-triangular causal mask, registered once so it moves with the module's device
        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)
        # scale by sqrt(head_size) (the key dimension), not the embedding size
        wei = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)  # (B, T, T)
        # causal mask: each position attends only to itself and earlier positions
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        out = wei @ v  # (B, T, head_size)
        return out
# several attention heads in parallel, concatenated and projected back to n_embed
class MultiHeadAttention(nn.Module):
    def __init__(self, n_embed, num_heads, context_size, head_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList([
            Head(n_embed, head_size, context_size, dropout)
            for _ in range(num_heads)
        ])
        self.projection = nn.Linear(num_heads * head_size, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)
        out = self.projection(out)
        return self.dropout(out)
# simple feed-forward layer
class FeedForward(nn.Module):
    def __init__(self, n_embeds, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embeds, 4 * n_embeds),
            nn.ReLU(),
            # projection layer
            nn.Linear(4 * n_embeds, n_embeds),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
# Transformer block: self-attention (communication) followed by feed-forward (computation)
class Block(nn.Module):
    def __init__(self, n_embeds, n_head, context_size, dropout):
        super().__init__()
        head_size = n_embeds // n_head
        self.sa = MultiHeadAttention(n_embeds, n_head, context_size, head_size, dropout)
        self.ffwd = FeedForward(n_embeds, dropout)
        self.ln1 = nn.LayerNorm(n_embeds)
        self.ln2 = nn.LayerNorm(n_embeds)

    def forward(self, x):
        # pre-norm residual connections
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
# decoder-only transformer language model
class DecoderTransformer(nn.Module):
    def __init__(self, vocab_size, n_embed, context_size, n_layer, n_head, dropout):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(context_size, n_embed)
        self.blocks = nn.Sequential(
            *[Block(
                n_embeds=n_embed,
                n_head=n_head,
                context_size=context_size,
                dropout=dropout
            ) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embed)  # final layer norm
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B, T) tensors of token indices
        token_embeds = self.token_embedding_table(idx)  # (B, T, C)
        pos_embeds = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, C)
        x = token_embeds + pos_embeds  # (B, T, C)
        x = self.ln_f(self.blocks(x))
        logits = self.lm_head(x)  # (B, T, vocab_size)
        if targets is None:
            return logits, None
        # flatten batch and time dimensions for cross-entropy
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss
    def generate(self, idx, max_new_tokens, context_size):
        for _ in range(max_new_tokens):
            # crop the running sequence to the last context_size tokens
            idx_cond = idx[:, -context_size:]
            logits, _ = self(idx_cond)
            # keep only the logits for the last time step
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            # sample the next token and append it to the sequence
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
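

# Minimal usage sketch: the hyperparameter values below (vocab_size=65, n_embed=64,
# n_layer=4, n_head=4, context_size=32) are illustrative assumptions, not the values
# this Space was trained with. It only shows how the model is instantiated and sampled.
if __name__ == "__main__":
    context_size = 32
    model = DecoderTransformer(
        vocab_size=65,          # assumed: e.g. a small character-level vocabulary
        n_embed=64,
        context_size=context_size,
        n_layer=4,
        n_head=4,
        dropout=0.1,
    ).to(device)

    # start from a single "token 0" context and sample 100 new tokens
    start = torch.zeros((1, 1), dtype=torch.long, device=device)
    sample = model.generate(start, max_new_tokens=100, context_size=context_size)
    print(sample.shape)  # torch.Size([1, 101])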