File size: 2,020 Bytes
f4e648b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
import tqdm
from torch.nn import functional as F
from  core.layers import LlamaBlock, RMSNorm

class LlamaLanguageModel(nn.Module):
    """Token-level language model built from a stack of ``LlamaBlock`` layers.

    The pipeline is: token embedding -> ``n_layer`` blocks -> ``RMSNorm`` ->
    linear head producing one logit per vocabulary entry.

    NOTE(review): ``LlamaBlock`` and ``RMSNorm`` come from ``core.layers`` and
    are not visible here; positional information and attention masking are
    presumably handled inside ``LlamaBlock`` -- confirm there.
    """

    def __init__(self, vocab_size: int, n_embd: int, block_size: int, n_head: int,
                 n_layer: int, dropout, device, name: str = "llama"):
        """Build the layers and initialise their weights.

        Args:
            vocab_size: Number of distinct token ids (embedding rows and
                output logits).
            n_embd: Width of the token embeddings / hidden states.
            block_size: Maximum number of trailing tokens fed to the model
                per step in ``generate``; also forwarded to every block.
            n_head: Forwarded to every ``LlamaBlock``.
            n_layer: Number of ``LlamaBlock`` layers to stack.
            dropout: Forwarded to every ``LlamaBlock``.
            device: Stored on the instance only; this class does not move
                any parameter to it.
            name: Free-form label stored on the instance.
        """
        super().__init__()
        self.name = name
        self.block_size = block_size
        self.device = device
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.Sequential(*[LlamaBlock(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = RMSNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # ``apply`` walks every submodule registered above (including the ones
        # nested inside the blocks), so it must run after they are created.
        self.apply(self._init_weights)
        # Starts empty; nothing in this class writes to it.
        self.history = {}
        self.vocab_size = vocab_size

    def _init_weights(self, module):
        """Initialise one submodule in place (callback for ``nn.Module.apply``).

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
        N(0, 0.02**2); ``nn.Linear`` biases, when present, are zeroed.
        Every other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        """Run the model on a batch of token ids.

        Args:
            idx: 2-D integer tensor of token ids, ``(batch, time)``.

        Returns:
            A ``(logits, hidden)`` tuple. ``hidden`` is the output of the
            final norm layer and ``logits`` is ``lm_head(hidden)``, i.e. one
            score per vocabulary entry for every position (assuming the
            blocks preserve the ``(batch, time, n_embd)`` shape -- TODO
            confirm in ``LlamaBlock``).

        Raises:
            ValueError: If ``idx`` is not 2-D.
        """
        # Unpacking is kept purely as a shape check: it fails unless idx is 2-D.
        B, T = idx.shape
        # No cache is maintained by this class; every block receives None.
        kv_cache = None
        token_embeddings = self.token_embedding_table(idx)
        # Iterate manually instead of calling the nn.Sequential, because each
        # block takes the extra ``kv_cache`` argument.
        for block in self.blocks:
            token_embeddings = block(token_embeddings, kv_cache)
        token_embeddings = self.ln_f(token_embeddings)
        logits = self.lm_head(token_embeddings)
        return logits, token_embeddings

    def generate(self, idx, max_new_tokens, max_seq_length=200, temperature=1.0):
        """Autoregressively extend ``idx``, yielding after every new token.

        This is a generator: nothing runs until it is iterated.

        Args:
            idx: 2-D integer tensor ``(batch, time)`` holding the prompt.
            max_new_tokens: Number of tokens to produce (one yield each).
            max_seq_length: Before each step the running sequence is cut to
                its last ``max_seq_length`` tokens, so a yielded tensor has
                at most ``max_seq_length + 1`` columns and may no longer
                contain the start of the prompt.
            temperature: Divisor applied to the last-position logits before
                the softmax. ``0`` selects greedy decoding (argmax) instead
                of sampling, since dividing by zero would produce invalid
                probabilities.

        Yields:
            The running token tensor with the newly chosen token appended
            as the last column.

        NOTE(review): the module's train/eval mode is not changed here; if
        the blocks apply dropout the caller presumably wants ``model.eval()``
        first -- confirm against callers.
        """
        for _ in range(max_new_tokens):
            if idx.size(1) > max_seq_length:
                idx = idx[:, -max_seq_length:]
            # The model only ever sees the trailing ``block_size`` tokens.
            idx_cond = idx[:, -self.block_size:]
            # Sampling never backpropagates, so skip building the autograd
            # graph. The context is closed before ``yield`` on purpose: a
            # ``no_grad`` block held open across a yield would leave grad mode
            # disabled in the caller's code between iterations.
            with torch.no_grad():
                logits, _ = self(idx_cond)
            last_logits = logits[:, -1, :]
            if temperature == 0:
                # Greedy decoding: the zero-temperature limit of sampling.
                idx_next = torch.argmax(last_logits, dim=-1, keepdim=True)
            else:
                probs = F.softmax(last_logits / temperature, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            yield idx