# app.py (second app by claude)
import gradio as gr
import torch

# ===================================================================================
# inference code part-1
# -------------------------------------------------
# 1. Download the model weights (from huggingface)
# -------------------------------------------------
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="Aananda-giri/LLAMA3-Nepali", filename="parameters_300m/model_pg_398000_steps.pth", local_dir="./")

# ----------------------
# 2. Load the tokenizer
# ----------------------
from transformers import PreTrainedTokenizerFast

# Load the tokenizer
tokenizer = PreTrainedTokenizerFast.from_pretrained("Aananda-giri/LLAMA3-Nepali")
tokenizer.save_pretrained("NepaliBPE")

# Llama 3.2 ~300M scaled version
LLAMA32_CONFIG = {
    "vocab_size": 50006,        # reduced from 128_256; the Nepali tokenizer's vocabulary size (50006)
    "context_length": 512,      # reduced from 131_072 (unrelated to model size, but a higher context length consumes more RAM)
    "emb_dim": 1320,            # reduced from 2048; embedding dimension
    "n_heads": 20,              # reduced from 32; number of attention heads
    "n_layers": 10,             # reduced from 16; number of layers
    "hidden_dim": 5280,         # reduced from 8192; size of the intermediate dimension in FeedForward
    "n_kv_groups": 5,           # reduced from 8; key-value groups for grouped-query attention
    "rope_base": 500_000.0,     # the base in RoPE's "theta" (unchanged: 500_000)
    "dtype": torch.bfloat16,    # lower-precision dtype to reduce memory usage
    "rope_freq": {              # RoPE frequency scaling
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }
}
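# Rough parameter-count sanity check derived from the config above (an illustrative sketch
# added here, not part of the original pipeline; `estimate_param_count` is a hypothetical
# helper that mirrors the layer shapes defined further below and ignores the small RMSNorm
# weights, so the result is approximate).
def estimate_param_count(cfg):
    head_dim = cfg["emb_dim"] // cfg["n_heads"]
    kv_dim = cfg["n_kv_groups"] * head_dim
    attn = 2 * cfg["emb_dim"] * cfg["emb_dim"] + 2 * cfg["emb_dim"] * kv_dim  # W_query + out_proj, W_key + W_value
    ff = 3 * cfg["emb_dim"] * cfg["hidden_dim"]                               # fc1, fc2, fc3 (SwiGLU)
    emb = cfg["vocab_size"] * cfg["emb_dim"]                                  # token embedding; the output head has the same shape
    return cfg["n_layers"] * (attn + ff) + 2 * emb

# print(f"~{estimate_param_count(LLAMA32_CONFIG) / 1e6:.0f}M parameters")  # roughly 385M total, ~319M excluding the input embedding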
# ===================================================================================
# Could not make importing from previous_chapters.py work, so I copied all the code
# from previous_chapters.py here.
#
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-4.
# This file can be run as a standalone script.

import json

# modified from `import tiktoken`
from transformers import PreTrainedTokenizerFast

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

# modified: added for create_dataloader_v2
from datasets import load_dataset


#####################################
# 1. Dataloader
#####################################
def create_dataloader_v3(batch_size, shuffle=True, drop_last=True, num_workers=0):
    '''
    modified:
    * parameter `text` removed
    * parameters `max_length` and `stride` removed: they were set while preparing tokenized_datasets
    * parameter `context_length` removed (the dataset is pre-tokenized)
    '''
    print('downloading dataset...')

    # Download the whole dataset
    base_url = "https://huggingface.co/datasets/Aananda-giri/nepali_llm_datasets/resolve/main/pre_tokenized/"

    # data_files = {"train": base_url + "nepberta_" + str(context_length) + ".parquet"}

    # previous version: stride = .75*512, context_len=512
    # data_files = {
    #     "train": base_url + "iriisnepal_u_nepberta_train_512.parquet",
    #     "test": base_url + "iriisnepal_u_nepberta_test_512.parquet"
    # }

    # context_len=512, stride=512
    data_files = {
        "train": base_url + "iriis_u_nepbert_512_512_train.parquet",
        "validation": base_url + "iriis_u_nepbert_512_512_test.parquet"
    }
    dataset = load_dataset("parquet", data_files=data_files, cache_dir='hf_cache', streaming=True)
    print(dataset)

    # and split it later
    # dataset = dataset.train_test_split(train_size=train_ratio, seed=42)

    # Convert Hugging Face Dataset to PyTorch tensors (we can directly use the dataset as it is already in the correct format)
    # dataset.set_format(type="torch", columns=['input_ids,target_ids'])  # Directly set columns to torch tensors

    # Define the custom collate_fn function
    def collate_fn(batch):
        # Extract 'input_ids' and 'target_ids' from the batch and return them as a list of tensors
        input_ids = []
        target_ids = []
        for data_item in batch:
            splitted_data_item = data_item['input_ids,target_ids'].split("\",")
            input_ids.append(torch.tensor(json.loads(splitted_data_item[0].replace('\"', ''))))
            # print(f'input_ids: {type(input_ids)} {input_ids}')
            target_ids.append(torch.tensor(json.loads(splitted_data_item[1].replace('\"', ''))))
            # print(f'target_ids: {type(target_ids)} {target_ids}')

        # Convert to tensors (if not already)
        input_ids_tensor = torch.stack(input_ids)
        target_ids_tensor = torch.stack(target_ids)
        return [input_ids_tensor, target_ids_tensor]

    # Create the DataLoader for the 'train' split of the dataset with the custom collate_fn
    train_loader = DataLoader(
        dataset['train'],
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        dataset['validation'],
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    return train_loader, val_loader
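# Example usage (a sketch, commented out; the dataloader is only needed for training, not for this
# inference app). Note that the dataset is loaded with streaming=True (an iterable dataset), so
# DataLoader-level shuffling does not apply and shuffle=False should be passed here.
# train_loader, val_loader = create_dataloader_v3(batch_size=8, shuffle=False)
# input_ids, target_ids = next(iter(train_loader))  # expected shape (8, 512) each, given the 512-token pre-tokenized rows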
#####################################
# 2. Architecture Code
#####################################
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)
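# This is the SwiGLU feed-forward used in Llama-style models:
# fc3(silu(fc1(x)) * fc2(x)), with fc1/fc2 mapping emb_dim -> hidden_dim and fc3 mapping back to emb_dim.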
def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))

    # Frequency adjustments
    if freq_config is not None:
        low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
        high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]

        wavelen = 2 * torch.pi / inv_freq

        inv_freq_llama = torch.where(
            wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
        )

        smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
            freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
        )

        smoothed_inv_freq = (
            (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
        )

        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        inv_freq = inv_freq_llama

    # Generate position indices
    positions = torch.arange(context_length)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin
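# Quick shape check for the precomputed RoPE tables (illustrative, commented out;
# head_dim=66 comes from emb_dim / n_heads = 1320 / 20 in the config above):
# _cos, _sin = precompute_rope_params(head_dim=66, theta_base=LLAMA32_CONFIG["rope_base"], context_length=512, freq_config=LLAMA32_CONFIG["rope_freq"])
# print(_cos.shape, _sin.shape)  # both torch.Size([512, 66]): one row of angles per position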
def compute_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2 :]  # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    return x_rotated.to(dtype=x.dtype)
class SharedBuffers:
    _buffers = {}

    @staticmethod
    def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):
        key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)

        if key not in SharedBuffers._buffers:
            # Create or fetch the buffers
            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
            cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)
            if dtype is not None:
                cos = cos.to(dtype)
                sin = sin.to(dtype)
            SharedBuffers._buffers[key] = (mask, cos, sin)

        return SharedBuffers._buffers[key]
class GroupedQueryAttention(nn.Module):
    def __init__(
        self, d_in, d_out, context_length, num_heads,
        num_kv_groups,
        rope_base=10_000,
        rope_config=None,
        dtype=None
    ):
        super().__init__()
        assert d_out % num_heads == 0, f"d_out:{d_out} must be divisible by num_heads:{num_heads}"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

        # Fetch buffers using SharedBuffers
        mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)
        self.register_buffer("mask", mask)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        queries = self.W_query(x)  # Shape: (b, num_tokens, d_out)
        keys = self.W_key(x)       # Shape: (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_value(x)   # Shape: (b, num_tokens, num_kv_groups * head_dim)

        # Reshape queries, keys, and values
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)

        # Transpose keys, values, and queries
        keys = keys.transpose(1, 2)        # Shape: (b, num_kv_groups, num_tokens, head_dim)
        values = values.transpose(1, 2)    # Shape: (b, num_kv_groups, num_tokens, head_dim)
        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)

        # Apply RoPE
        keys = compute_rope(keys, self.cos, self.sin)
        queries = compute_rope(queries, self.cos, self.sin)

        # Expand keys and values to match the number of heads
        keys = keys.repeat_interleave(self.group_size, dim=1)      # Shape: (b, num_heads, num_tokens, head_dim)
        values = values.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)

        # For example, before repeat_interleave along dim=1 (query groups):
        #   [K1, K2]
        # After repeat_interleave (each query group is repeated group_size times):
        #   [K1, K1, K2, K2]
        # If we used regular repeat instead of repeat_interleave, we'd get:
        #   [K1, K2, K1, K2]

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Shape: (b, num_heads, num_tokens, num_tokens)

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        assert keys.shape[-1] == self.head_dim

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
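# With the config used here (n_heads=20, n_kv_groups=5), group_size = 20 // 5 = 4:
# each of the 5 key/value heads is repeat_interleaved 4 times so that every one of the
# 20 query heads has a matching key/value head.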
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            rope_base=cfg["rope_base"],
            rope_config=cfg["rope_freq"],
            dtype=cfg["dtype"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5)
        self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5)

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x.to(torch.bfloat16))  # Shape: [batch_size, num_tokens, emb_size]
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x.to(torch.bfloat16))
        x = x + shortcut  # Add the original input back

        return x
class Llama3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5)
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

    def forward(self, in_idx):
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x.to(torch.bfloat16))
        return logits
#####################################
# 3. Load Tokenizer
#####################################
import os
from transformers import PreTrainedTokenizerFast


class Tokenizer:
    def __init__(self, tokenizer_model_path):
        assert os.path.isfile(tokenizer_model_path), f"Tokenizer model file {tokenizer_model_path} not found"

        # load the tokenizer here
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_model_path)

        # previously added
        # self.special_tokens = {
        #     "<|begin_of_text|>": 128000,
        #     "<|end_of_text|>": 128001,
        #     "<|start_header_id|>": 128006,
        #     "<|end_header_id|>": 128007,
        #     "<|eot_id|>": 128009,
        # }

    def encode(self, text, bos=False, eos=False):
        '''
        parameter `allowed_special` removed
        parameter `disallowed_special` removed
        '''
        if bos:
            # tokens = [self.special_tokens["<|begin_of_text|>"]]
            tokens = self.tokenizer.encode('<|begin_of_text|>')  # [50000]
        else:
            tokens = []

        tokens += self.tokenizer.encode(text)

        if eos:
            # tokens.append(self.special_tokens["<|end_of_text|>"])
            tokens.append(self.tokenizer.encode('<|end_of_text|>')[0])  # [50001]
        return tokens

    def decode(self, tokens):
        # return self.model.decode(tokens)
        return self.tokenizer.decode(tokens)
class ChatFormat:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encode_header(self, message):
        tokens = []
        tokens.append(self.tokenizer.tokenizer.encode('<|start_header_id|>')[0])  # 50002
        tokens.extend(self.tokenizer.encode(message["भूमिका"], bos=False, eos=False))
        tokens.append(self.tokenizer.tokenizer.encode('<|end_header_id|>')[0])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))

        # tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        # tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
        # tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        # tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode(self, text):
        message = {
            "भूमिका": "प्रयोगकर्ता",
            "सन्दर्भ": text
        }

        tokens = self.encode_header(message)
        tokens.extend(
            self.tokenizer.encode(message["सन्दर्भ"].strip(), bos=False, eos=False)
        )
        # tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
        tokens.append(self.tokenizer.tokenizer.encode('<|eot_id|>')[0])
        return tokens

    def decode(self, token_ids):
        return self.tokenizer.decode(token_ids)


_tokenizer = Tokenizer("NepaliBPE/tokenizer.json")
chat_tokenizer = ChatFormat(_tokenizer)

# text = "नेपाल विद्युत प्राधिकरणका कार्यकारी निर्देशक कुलमान घिसिङले माथिल्लो अरुण जलविद्युत आयोजना विश्व बैंक र एडीबीबाट वित्तीय व्यवस्थापन नभए नेपाली जनताको लगानीमा बनाउने तयारी रहेको बताएका छन् ।"
# # normal tokenizer
# print([_tokenizer.tokenizer.decode([token]) for token in _tokenizer.encode(text)])
# # formatted tokenizer
# print([_tokenizer.tokenizer.decode([token]) for token in chat_tokenizer.encode(text)])
#####################################
# 4. Generate Text
#####################################
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor

# '''
# We have modified the return statement above (from Sebastian's version) because there are no tokens
# like 'start_header_id' and 'end_header_id', so the tokenizer returns None, which in turn raises an error.
# TODO: add the special tokens 'start_header_id' and 'end_header_id' and restore the original return statement.
# '''
# print(encoded_tensor)
# return torch.tensor([token for token in encoded_tensor])  # TODO: use additional vocab like encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())


def generate(model, idx, max_new_tokens, context_length, temperature=0.0, top_k=None, eos_id=None):
    # For-loop is the same as before: get logits and only focus on the last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_length:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # New: filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

        # New: apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
def generate_and_print_sample(PROMPT, tokenizer, chat_tokenizer, model, device, context_length):
    # PROMPT = "What do llamas eat?"
    # PROMPT = "रामले भात"
    torch.manual_seed(123)

    # token_ids = generate(
    #     model=model,
    #     idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),
    #     max_new_tokens=150,
    #     context_length=context_length,
    #     temperature=0.5,
    #     top_k=1,
    #     eos_id=tokenizer.eos_token_id
    # )
    # output_text = token_ids_to_text(token_ids, tokenizer)

    # We have re-defined the generate function below.
    output_text = generate(
        model=model,
        prompt=PROMPT,
        tokenizer=tokenizer,
        max_new_tokens=150,
    )
    print("Output text:\n", clean_text(output_text))

    # -------------------------------------------------------------
    # Generate sample text
    # PROMPT = "लामा हरु ले के खान्छन् ?"
    # torch.manual_seed(123)
    # token_ids = generate(
    #     model=model,
    #     idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),
    #     max_new_tokens=150,
    #     context_size=LLAMA32_CONFIG["context_length"],
    #     top_k=1,
    #     temperature=0.
    # )
    # output_text = token_ids_to_text(token_ids, tokenizer)
    # -------------------------------------------------------------


def clean_text(text, header_end="प्रयोगकर्ता <|end_header_id|>\n\n"):
    # Find the index of the first occurrence of "<|end_header_id|>"
    index = text.find(header_end)

    if index != -1:
        # Return the substring starting after "<|end_header_id|>"
        return text[index + len(header_end):].strip()  # strip() removes leading/trailing whitespace
    else:
        # If the token is not found, return the original text
        return text

# print("Output text:\n", clean_text(output_text))
##########################################################################
# Chapter 5 (keep everything as it is except the `generate_and_print_sample` function)
##########################################################################
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None, len_data_loader=0):
    '''
    - parameter `len_data_loader` added to supply the loader length, since we can't compute len(data_loader) for an iterable dataset
    '''
    total_loss = 0.
    if len_data_loader == 0:  # len(data_loader)
        return float("nan")
    elif num_batches is None:
        num_batches = len_data_loader
    else:
        num_batches = min(num_batches, len_data_loader)

    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


def evaluate_model(model, train_loader, val_loader, device, eval_iter, len_train_loader=0, len_val_loader=0):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter, len_data_loader=len_train_loader)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter, len_data_loader=len_val_loader)
    model.train()
    return train_loss, val_loss
def generate(
    model,
    prompt,
    tokenizer,
    max_new_tokens,
    temperature=0.7,
    top_k=50,
    top_p=None,  # New parameter for nucleus sampling
    eos_id=None,
    repetition_penalty=1.2,
    penalize_len_below=50,
    context_size=512
):
    # context_size = GPT_CONFIG_124M['context_length']
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    idx = text_to_token_ids(prompt, tokenizer).to(device)

    if not eos_id:
        encoded_endoftext = tokenizer.encode("<|endoftext|>")
        eos_id = encoded_endoftext[0] if encoded_endoftext else None

    token_freq = {}
    for step in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # Apply repetition penalty
        for token_id in idx[0].tolist():
            if token_id in token_freq:
                logits[0, token_id] /= repetition_penalty
            else:
                token_freq[token_id] = 1

        # Penalize the EOS token for shorter sequences
        if eos_id is not None and step < penalize_len_below:
            logits[0, eos_id] /= (penalize_len_below - step) / penalize_len_below

        # Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

        # Convert logits to probabilities
        probs = torch.softmax(logits, dim=-1)

        # Apply top-p (nucleus) sampling if specified
        if top_p:
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to also keep the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Create a mask for indices to remove
            indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
            probs = probs.masked_fill(indices_to_remove, 0.0)

            # Renormalize probabilities
            probs = probs / probs.sum(dim=-1, keepdim=True)

        # If top_p is None, apply top-k sampling
        elif top_k:
            top_probs, top_indices = torch.topk(probs, top_k)
            probs = torch.zeros_like(probs).scatter_(-1, top_indices, top_probs)

            # Renormalize probabilities
            probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample from the filtered distribution
        if temperature > 0.0:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(probs, dim=-1, keepdim=True)

        if idx_next == eos_id:
            break

        idx = torch.cat((idx, idx_next), dim=1)

    text = token_ids_to_text(idx, tokenizer)
    return text
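# Example call (a sketch, commented out; `model` is only instantiated further down in this file,
# and the exact output depends on the checkpoint):
# sample = generate(model=model, prompt="नेपाल एउटा", tokenizer=_tokenizer, max_new_tokens=50,
#                   temperature=0.7, top_k=50, context_size=LLAMA32_CONFIG["context_length"])
# print(sample)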
def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, output_dir):
    fig, ax1 = plt.subplots()

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig(output_dir / "losses.pdf")
# --------------------------------------------------------------------------------
# -------------------------- New chat function -----------------------------------
# --------------------------------------------------------------------------------
def generate_chat_optimized(
    model,
    prompt,
    tokenizer,
    chat_tokenizer,
    max_new_tokens,
    context_size,
    temperature=0.7,
    top_k=50,
    top_p=None,
    eos_id=None,
    repetition_penalty=1.2,
    penalize_len_below=50,
    device=None,
    batch_size=1,  # Added parameter
    clean_the_text=True
):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    idx = text_to_token_ids(prompt, chat_tokenizer).to(device)

    # Find the EOS token once instead of checking every time
    if not eos_id:
        if "<|endoftext|>" in tokenizer.get_vocab():
            encoded_endoftext = tokenizer.encode("<|endoftext|>")
            eos_id = encoded_endoftext[0] if encoded_endoftext else None
        elif "<|eot_id|>" in tokenizer.get_vocab():
            encoded_endoftext = tokenizer.encode("<|eot_id|>")
            eos_id = encoded_endoftext[0] if encoded_endoftext else None

    # Pre-compute token frequencies for the initial context
    token_freq = {}
    for token_id in idx[0].tolist():
        if token_id in token_freq:
            token_freq[token_id] += 1
        else:
            token_freq[token_id] = 1

    # Process tokens in batches for efficiency
    with torch.no_grad():  # kept outside the loop
        for step in range(0, max_new_tokens, batch_size):
            batch_end = min(step + batch_size, max_new_tokens)
            current_batch_size = batch_end - step

            idx_cond = idx[:, -context_size:]
            logits = model(idx_cond)
            logits = logits[:, -1, :]

            # Apply the repetition penalty once for the batch
            for token_id in idx[0].tolist()[-current_batch_size:]:
                if token_id in token_freq:
                    token_freq[token_id] += 1
                    logits[0, token_id] /= repetition_penalty
                else:
                    token_freq[token_id] = 1

            # Process each token in the batch
            for i in range(current_batch_size):
                current_step = step + i

                # Penalize the EOS token for shorter sequences
                current_logits = logits.clone()  # Work with a copy
                if eos_id is not None and current_step < penalize_len_below:
                    penalty_factor = 1.0 + (penalize_len_below - current_step) / penalize_len_below
                    current_logits[0, eos_id] /= penalty_factor

                # Apply temperature scaling
                if temperature > 0.0:
                    current_logits = current_logits / temperature

                # Convert logits to probabilities
                probs = torch.softmax(current_logits, dim=-1)

                # Apply sampling strategies
                if top_p and top_p > 0.0:
                    # Nucleus sampling implementation
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
                    probs = probs.masked_fill(indices_to_remove, 0.0)
                    probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)
                elif top_k and top_k > 0:
                    # Top-k sampling implementation
                    top_probs, top_indices = torch.topk(probs, min(top_k, probs.size(-1)))
                    probs = torch.zeros_like(probs).scatter_(-1, top_indices, top_probs)
                    probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)

                # Sample from the filtered distribution
                if temperature > 0.0:
                    idx_next = torch.multinomial(probs, num_samples=1)
                else:
                    idx_next = torch.argmax(probs, dim=-1, keepdim=True)

                # Add the next token to the sequence
                idx = torch.cat((idx, idx_next), dim=1)

                # Check for the end-of-sequence token and stop early if needed
                if idx_next.item() == eos_id:
                    output_text = token_ids_to_text(idx, tokenizer)
                    if clean_the_text:
                        # Clean the output
                        # cleaned_text = clean_chat_output(output_text)
                        cleaned_text = clean_text(output_text)
                        if '<|eot_id|>' in cleaned_text:
                            cleaned_text = cleaned_text.replace('<|eot_id|>', '')
                        # print("Generated text:\n", cleaned_text)
                        return cleaned_text
                    return output_text

    # No end-of-text token was produced; stop because max_new_tokens was reached
    output_text = token_ids_to_text(idx, tokenizer)
    if clean_the_text:
        # Clean the output
        # cleaned_text = clean_chat_output(output_text)
        cleaned_text = clean_text(output_text)
        if '<|eot_id|>' in cleaned_text:
            cleaned_text = cleaned_text.replace('<|eot_id|>', '')
        # print("Generated text:\n", cleaned_text)
        return cleaned_text
    return output_text
# ===================================================================================
# ===================================================================================

# =============================================
# Below is the code to initialize the model
# =============================================
# import torch  # already imported
# from previous_chapters4 import (
#     Llama3Model,
#     ChatFormat,
#     Tokenizer,
#     generate_and_print_sample
# )
old_context_length = 131_072  # original context length of the Llama 3.2 model
new_context_length = LLAMA32_CONFIG["context_length"]  # 512, our new context length


def rescale_theta(theta_old, context_length_old, context_length_new):
    # original linear scaling
    scaling_factor = context_length_new / context_length_old
    theta_new = theta_old * scaling_factor
    return theta_new


LLAMA32_CONFIG["rope_base"] = rescale_theta(
    LLAMA32_CONFIG["rope_base"],
    old_context_length,
    new_context_length
)
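# With the values above: 500_000.0 * 512 / 131_072 ≈ 1953.125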
print("New RoPE theta (i.e. LLAMA32_CONFIG[\"rope_base\"]):", LLAMA32_CONFIG["rope_base"]) | |
model = Llama3Model(LLAMA32_CONFIG) | |
# Todo: don't compile? (claude sonnet 3.7 said compiling would speed up inference speed) | |
# compile the model | |
if True: | |
print("compiling the model... (takes a ~minute)") | |
unoptimized_model = model | |
model = torch.compile(model) # requires PyTorch 2.0 | |
model.eval() # eval mode | |
# Check buffers
# --------------
print('The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:')
print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)
print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)
print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin)

# Display the number of parameters
# --------------------------------
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

# Account for weight tying (i.e. report the count excluding the input embedding)
total_params_normalized = total_params - model.tok_emb.weight.numel()
print(f"\nTotal number of unique parameters: {total_params_normalized:,}")
# Display model memory size
# -----------------------------------------------------------------------
def model_memory_size(model, input_dtype=torch.float32):
    total_params = 0
    total_grads = 0
    for param in model.parameters():
        # Calculate total number of elements per parameter
        param_size = param.numel()
        total_params += param_size
        # Check if gradients are stored for this parameter
        if param.requires_grad:
            total_grads += param_size

    # Calculate buffer size (non-parameters that require memory)
    total_buffers = sum(buf.numel() for buf in model.buffers())

    # Size in bytes = (number of elements) * (size of each element in bytes)
    # We assume parameters and gradients are stored in the same type as the input dtype
    element_size = torch.tensor(0, dtype=input_dtype).element_size()
    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size

    # Convert bytes to gigabytes
    total_memory_gb = total_memory_bytes / (1024**3)

    return total_memory_gb


print(f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")
# -----------------------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device)
print(f'device: {device}')

latest_model_checkpoint = "parameters_300m/model_pg_398000_steps.pth"
checkpoint = torch.load(latest_model_checkpoint, map_location=device, weights_only=False)

# modified (added model loading code)
model.load_state_dict(checkpoint["model_state_dict"])
# generate_and_print_sample(PROMPT="रामले भात", tokenizer=_tokenizer, chat_tokenizer=chat_tokenizer, model=model, device=device, context_length=LLAMA32_CONFIG["context_length"])

# from previous_chapters import generate_and_print_chat
# generated_text = generate_and_print_chat(
#     prompt="रामले भात",
#     tokenizer=tokenizer,
#     chat_tokenizer=chat_tokenizer,
#     model=model,
#     device=None,
#     max_new_tokens=150,
#     context_length=None,
#     temperature=0.1,
#     top_k=50,
#     top_p=0.9,
#     repetition_penalty=1.2,
#     clean_the_text=True
# )
# print(generated_text)
# =============================================
# Gradio interface
# =============================================
def generate_text(prompt, max_new_tokens, top_k, top_p, temperature, repetition_penalty, penalize_len_below):
    return generate_chat_optimized(
        model=model,
        prompt=prompt,
        tokenizer=tokenizer,
        chat_tokenizer=chat_tokenizer,
        max_new_tokens=max_new_tokens,
        context_size=LLAMA32_CONFIG['context_length'],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        eos_id=None,
        repetition_penalty=repetition_penalty,
        penalize_len_below=penalize_len_below,
        device=device,
        batch_size=1
    )
css = """ | |
#bright-textbox { | |
background-color: #ffeb3b; /* Bright yellow */ | |
color: #000000; /* Black text for contrast */ | |
border: 2px solid #fbc02d; /* Slightly darker yellow for the border */ | |
font-size: 16px; | |
padding: 10px; | |
border-radius: 5px; | |
} | |
""" | |
# Create Gradio interface | |
with gr.Blocks(title="LLAMA3_Nepali_318M Text Generator", css=css) as interface: | |
gr.Markdown("# LLAMA3_Nepali_318M Text Generator") | |
gr.Markdown("Enter Nepali (नेपाली) text to generate content using the custom LLAMA3_Nepali_318M model.") | |
with gr.Row(): | |
with gr.Column(): | |
prompt = gr.Textbox( | |
label="Prompt", | |
placeholder="यहाँ नेपाली मा इन्पुट दिनु होस् ... (please Enter Nepali text here...)" #, | |
# value="रामले भात" | |
) | |
max_tokens = gr.Slider(minimum=1, maximum=512, value=25, step=1, label="Max New Tokens") | |
with gr.Row(): | |
with gr.Column(): | |
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.3, step=0.1, label="Temperature") | |
repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty") | |
with gr.Column(): | |
top_k = gr.Slider(minimum=0, maximum=100, value=5, step=1, label="Top K (set to 0 to use Top P)") | |
top_p = gr.Slider(minimum=0, maximum=1.0, value=0.9, step=0.05, label="Top P (set above 0 to use instead of Top K)") | |
min_length = gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Minimum Length Penalty") | |
generate_btn = gr.Button("Generate Text") | |
with gr.Column(): | |
output = gr.Textbox(label="Generated Text", lines=10) | |
# Add examples if you have any | |
gr.Examples( | |
examples=[ | |
# ["रामले भात", 25, 10, 0, 0.7, 1.2, 15], | |
# ["नेपाल एउटा", 25, 10, 0.9, 0.5, 1.2, 10], | |
["नेपाल का वर्तमान प्रधानमन्त्री ", 25, 10, 0.4, 0.8, 1.2, 10], | |
["भारतीय प्रधानमन्त्री ", 25, 10, 0.9, 0.5, 1.2, 15], | |
["अमिरिकी रास्ट्रपति डोनाल्ड", 25, 10, 0.9, 0.6, 1.2, 15], | |
], | |
inputs=[prompt, max_tokens, top_k, top_p, temperature, repetition_penalty, min_length], | |
outputs=output, | |
fn=generate_text, | |
cache_examples=True, | |
) | |
generate_btn.click( | |
fn=generate_text, | |
inputs=[prompt, max_tokens, top_p, top_k, temperature, repetition_penalty, min_length], | |
outputs=output | |
) | |
interface.launch() |