# QuillGPT: core/models/llama.py
import torch
import torch.nn as nn
from torch.nn import functional as F

from core.layers import LlamaBlock, RMSNorm

class LlamaLanguageModel(nn.Module):
    """Decoder-only Llama-style language model built from LlamaBlock layers."""

    def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout, device, name="llama"):
        super().__init__()
        self.name = name
        self.block_size = block_size
        self.device = device
        self.vocab_size = vocab_size
        # Token embedding table; no learned positional embedding is added here,
        # so positional information is expected to come from the blocks themselves.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # Stack of transformer blocks; nn.Sequential is used only as a container,
        # the blocks are iterated manually in forward().
        self.blocks = nn.Sequential(*[LlamaBlock(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = RMSNorm(n_embd)                    # final RMS normalization
        self.lm_head = nn.Linear(n_embd, vocab_size)   # projection to vocabulary logits
        self.apply(self._init_weights)
        self.history = {}                              # placeholder for training history

    def _init_weights(self, module):
        # Standard GPT-style initialization: N(0, 0.02) for linear and embedding weights.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        # idx: (B, T) tensor of token indices, with T <= block_size.
        token_embeddings = self.token_embedding_table(idx)   # (B, T, n_embd)
        # No KV cache is maintained here; None is passed so each block runs a full forward pass.
        kv_cache = None
        x = token_embeddings
        for block in self.blocks:
            x = block(x, kv_cache)
        x = self.ln_f(x)
        logits = self.lm_head(x)                             # (B, T, vocab_size)
        # Return both the logits and the final hidden states.
        return logits, x

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, max_seq_length=200, temperature=1.0):
        # Autoregressive sampling; yields the growing sequence after each new token
        # so callers can stream intermediate results.
        for _ in range(max_new_tokens):
            # Keep the running sequence bounded, then crop the conditioning
            # context to the model's block size.
            if idx.size(1) > max_seq_length:
                idx = idx[:, -max_seq_length:]
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            # Take the logits at the last position and apply temperature scaling.
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            # Sample the next token and append it to the sequence.
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            yield idx
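

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): builds a tiny model and streams a
# few sampled tokens. It assumes core.layers provides LlamaBlock and RMSNorm
# as imported above; the hyperparameters and the zero start token below are
# arbitrary placeholder choices, not values used by the project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = LlamaLanguageModel(
        vocab_size=256,   # placeholder byte-level vocabulary size
        n_embd=64,
        block_size=32,
        n_head=4,
        n_layer=2,
        dropout=0.1,
        device=device,
        name="llama-demo",
    ).to(device)
    model.eval()

    # Start from a single token id 0 for a batch of one sequence.
    context = torch.zeros((1, 1), dtype=torch.long, device=device)

    # generate() yields the running sequence after each sampled token,
    # so partial outputs can be streamed to the caller.
    for step, seq in enumerate(model.generate(context, max_new_tokens=5), start=1):
        print(f"step {step}: sequence length = {seq.size(1)}")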