# Based on Llama from Meta (https://github.com/meta-llama/llama/blob/main/llama/model.py)

import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

from dataclasses import dataclass

import wandb
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer

torch.manual_seed(1337)
torch.cuda.manual_seed(1337)

# Dataset selection: TinyStories for quick runs, FineWeb for larger pretraining.
tinystories = True
fw = False
fw_train = None
fw_test = None
if tinystories:
    fw_train = load_dataset("roneneldan/TinyStories", split="train")
    fw_test = load_dataset("roneneldan/TinyStories", split="validation")
    print(fw_train)
    print(fw_test)
if fw:
    # use name="sample-10BT" to use the 10BT sample
    fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False)
    # Hold out 1% of the data as a validation split.
    fw_train = fw_train.train_test_split(test_size=0.01)
    print(fw_train)


def setup(rank=None, world_size=None):
    # torchrun supplies MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE via the environment.
    init_process_group("nccl")


def cleanup():
    destroy_process_group()
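
# A minimal sketch (never called by the training loop; the helper name is made up)
# of what setup() relies on: torchrun exports these environment variables before
# this script starts, and init_process_group("nccl") reads them. The defaults
# shown are illustrative only.
def _show_ddp_env():
    for var, default in [("RANK", "0"), ("LOCAL_RANK", "0"),
                         ("WORLD_SIZE", "1"), ("MASTER_ADDR", "localhost"),
                         ("MASTER_PORT", "29500")]:
        print(f"{var} = {os.environ.get(var, default + ' (default, not set)')}")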
@dataclass
class ModelArgs:
    # Hyperparameters
    epochs = 4
    block_size = 512
    batch_size = 64
    embeddings_dims = 512
    attn_dropout = 0.1
    no_of_heads = 8
    dropout = 0.1
    val_epochs = 2
    max_lr = 6e-4
    no_of_decoder_layers = 8  # IMP: needs to be thoroughly calculated
    weight_decay_optim = 0.1
    beta_1 = 0.9
    beta_2 = 0.95
    clip = 1.0
    device = 'cuda'
    no_kv_heads = 2
    vocab_size = 50304  # powers of 2, so nice!
    eps = 1e-5
    inference = False  # flipped to True at generation time
    dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'


def _save_snapshot(model, optimizer, scheduler, epoch, step):
    snapshot = {
        "MODEL_STATE": model.module.state_dict(),
        "OPTIMIZER_STATE": optimizer.state_dict(),
        # "SCHEDULER_STATE": scheduler.state_dict(),
        "EPOCHS_RUN": epoch,
        "STEP_RUN": step,
    }
    torch.save(snapshot, f"snapshot_{step}.pt")
    print(f"Epoch: {epoch} | Step: {step} | Snapshot saved.")


def _load_snapshot(snapshot_path, model, optimizer, scheduler):
    snapshot = torch.load(snapshot_path)
    model.load_state_dict(snapshot["MODEL_STATE"])
    optimizer.load_state_dict(snapshot["OPTIMIZER_STATE"])
    # scheduler.load_state_dict(snapshot["SCHEDULER_STATE"])
    epoch = snapshot["EPOCHS_RUN"]
    step = snapshot["STEP_RUN"]
    print(f"Resuming from Epoch {epoch}, Step {step}")
    return epoch, step


tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", hf_token='...')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})


def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        max_length=ModelArgs.block_size,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )


def prepare_dataset(split, device, batch_size):
    print("Device is: ", device)

    def collate_fn(batch):
        texts = [item["text"] for item in batch]
        input_encodings = tokenizer(texts, max_length=ModelArgs.block_size,
                                    padding='max_length', truncation=True, return_tensors="pt")
        # Next-token prediction: the label at position t is the input token at
        # position t + 1 (a shift left), and the last position is labeled EOS.
        input_encodings["labels"] = input_encodings["input_ids"].clone()
        input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:]
        input_encodings["labels"][:, -1] = tokenizer.eos_token_id
        return input_encodings

    data_loader = None
    if tinystories:
        if split == 'train':
            data_loader = DataLoader(
                fw_train,
                batch_size=batch_size,
                sampler=DistributedSampler(fw_train, shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False  # the DistributedSampler does the shuffling
            )
        elif split == 'val':
            data_loader = DataLoader(
                fw_test,
                batch_size=batch_size,
                sampler=DistributedSampler(fw_test, shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
    elif fw:
        if split == 'train':
            data_loader = DataLoader(
                fw_train['train'],
                batch_size=batch_size,
                sampler=DistributedSampler(fw_train['train'], shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
        elif split == 'val':
            data_loader = DataLoader(
                fw_train['test'],
                batch_size=batch_size,
                sampler=DistributedSampler(fw_train["test"]),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
    return data_loader
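
# A quick, self-contained illustration (assumed example, never called) of the
# label-shifting convention used in collate_fn above: the label at position t is
# the input token at position t + 1, and the final label is EOS.
def _demo_label_shift():
    input_ids = torch.tensor([[10, 11, 12, 13]])
    labels = input_ids.clone()
    labels[:, :-1] = input_ids[:, 1:]
    labels[:, -1] = tokenizer.eos_token_id
    print(labels)  # tensor([[11, 12, 13, <eos_id>]])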

class Normalization(nn.Module):
    def __init__(self, embeddings_dims: int = ModelArgs.embeddings_dims):
        super().__init__()
        self.rmsnorm_layer = torch.nn.RMSNorm(normalized_shape=embeddings_dims)

    def forward(self, x):
        x = self.rmsnorm_layer(x)
        return x


class RotaryEmbeddings(nn.Module):
    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        batch_size: int = ModelArgs.batch_size
    ):
        super().__init__()
        self.embeddings_dims = embeddings_dims
        self.block_size = block_size
        self.batch_size = batch_size
        self.device = device

    def apply_rope(self, seq):
        batch_size, seq_len, embeds_dims = seq.shape
        # The rotation angle must grow with the token position: that position
        # dependence is what makes RoPE a positional encoding.
        position = torch.arange(seq_len, dtype=torch.float32, device=self.device).unsqueeze(1)        # (T, 1)
        dims = torch.arange(0, embeds_dims, 2, dtype=torch.float32, device=self.device).unsqueeze(0)  # (1, d/2)
        theta = 10000 ** (-dims / embeds_dims)  # per-pair frequencies
        angles = position * theta  # (T, d/2): one angle per position and dimension pair
        # Dimensions are rotated in (even, odd) pairs, hence the halved last dim.
        x_reshaped = seq.view(batch_size, seq_len, embeds_dims // 2, 2)
        cos_angles = torch.cos(angles)
        sin_angles = torch.sin(angles)
        out = torch.stack([
            x_reshaped[..., 0] * cos_angles - x_reshaped[..., 1] * sin_angles,
            x_reshaped[..., 1] * cos_angles + x_reshaped[..., 0] * sin_angles,
        ], dim=-1)
        out = out.view(batch_size, seq_len, embeds_dims)
        return out

    def forward(self, x):
        # RoPE is applied at both training and inference time.
        return self.apply_rope(x)


class RotaryAttentionHead(nn.Module):
    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        no_of_heads: int = ModelArgs.no_of_heads,
        attn_dropout: float = ModelArgs.attn_dropout
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
        self.rope = RotaryEmbeddings(embeddings_dims=self.head_size, device=device)
        self.dropout = nn.Dropout(p=attn_dropout)
        self.device = device

    def forward(self, x):
        batch, block_size, embeddings_dims = x.shape
        query = self.query(x)
        key = self.key(x)
        values = self.value(x)
        rotary_q = self.rope(query)
        rotary_k = self.rope(key)
        # Causal mask: position t may only attend to positions <= t.
        masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))
        weights = rotary_q @ rotary_k.transpose(-2, -1)  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        weights_masked = weights.masked_fill(masked == 0, float('-inf'))
        scaled_weights = weights_masked / torch.sqrt(torch.tensor(key.shape[-1]))
        scaled_weights = F.softmax(scaled_weights, dim=-1)
        value = scaled_weights @ values
        out = self.dropout(value)
        return out
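
# A minimal sanity check (illustrative only; the helper is never called by the
# training loop). A pure rotation preserves vector norms, so applying RoPE
# should leave per-token embedding norms unchanged.
def _check_rope_preserves_norm():
    rope = RotaryEmbeddings(device='cpu', embeddings_dims=64)
    x = torch.randn(2, 8, 64)
    y = rope(x)
    print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5))  # expect True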

class MQA(nn.Module):
    def __init__(
        self,
        device,
        no_of_q_heads: int,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
    ):
        super().__init__()
        self.no_of_kv_heads = 2  # I want to have a kv for each pair of query heads
        self.head_size = embeddings_dims // no_of_q_heads
        self.rotary = RotaryEmbeddings(embeddings_dims=self.head_size, device=device)
        # One shared key/value projection, several query projections: multi-query attention.
        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device=device)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device=device)
        self.dropout = nn.Dropout(p=ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=self.head_size * self.no_of_kv_heads, out_features=embeddings_dims, dtype=torch.float32, bias=False, device=device)
        self.device = device
        self.multi_query = nn.ModuleList([
            nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, device=self.device)
            for _ in range(self.no_of_kv_heads)
        ])
    def scaled_dot_product(self, q, k, v, block_size):
        q = self.rotary(q)
        masked_table = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))
        weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5)
        masked_values = weights.masked_fill(masked_table[:block_size, :block_size] == 0, float('-inf'))
        # Normalize along the key dimension so each query row sums to 1.
        weights_normalized = nn.functional.softmax(masked_values, dim=-1)
        weights_normalized = self.dropout(weights_normalized)
        out = weights_normalized @ v
        return out

    def forward(self, x):
        batch, block_size, embeddings_dims = x.shape
        key = self.key(x)
        values = self.value(x)
        rotary_key = self.rotary(key)
        # Every query head attends against the single shared (rotated) key and value.
        multi_query_concat = torch.cat(
            [self.scaled_dot_product(query(x), rotary_key, values, block_size) for query in self.multi_query],
            dim=-1
        )
        linear_layer = self.linear_layer(multi_query_concat)
        return linear_layer


class GQA(nn.Module):
    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        mqa_heads: int = ModelArgs.no_kv_heads
    ):
        super().__init__()
        self.no_of_q_heads = ModelArgs.no_of_heads // mqa_heads
        self.dropout = nn.Dropout(p=ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_q_heads, out_features=embeddings_dims, dtype=torch.float32, bias=False, device=device)
        self.device = device
        # Basically, MQA blocks are grouped into GQA; this class just consolidates
        # everything into one module.
        self.mqa = nn.ModuleList([
            MQA(no_of_q_heads=self.no_of_q_heads, embeddings_dims=embeddings_dims, device=self.device, block_size=block_size)
            for _ in range(self.no_of_q_heads)
        ])

    def forward(self, x):
        batch, block_size, embeddings_dims = x.shape
        grouped_query_concat = torch.cat([group(x) for group in self.mqa], dim=-1)
        linear_layer = self.linear_layer(grouped_query_concat)
        out = self.dropout(linear_layer)
        return out


class Swish(nn.Module):
    def __init__(
        self,
        device,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        swish = x * self.sig(x)
        return swish
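
# Illustrative shape walk-through for the GQA block above (assumed numbers, the
# helper is never called): with embeddings_dims=512, no_of_heads=8 and
# mqa_heads=2 there are 4 MQA groups; each group emits (B, T, 512), the concat
# is (B, T, 2048), and the final linear maps back to (B, T, 512).
def _check_gqa_shapes():
    gqa = GQA(device='cpu', embeddings_dims=512, block_size=16)
    x = torch.randn(2, 16, 512)
    print(gqa(x).shape)  # expect torch.Size([2, 16, 512])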

class SWiGLU(nn.Module):
    def __init__(
        self,
        device,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        # Hidden width follows the Llama convention of 2/3 * (4 * d), which keeps
        # the gated FFN's parameter count close to a standard 4*d MLP.
        self.hidden_dims = int(2 * (4 * embeddings_dims) / 3)
        self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
        self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device=device)
        self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device=device)
        self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device=device)

    def forward(self, x):
        swish_res = self.swish(self.linear_layer1(x))  # gate branch
        x_V = self.linear_layer2(x)                    # value branch
        res = torch.mul(swish_res, x_V)
        out = self.linear_layer3(res)
        return out


class FFN(nn.Module):
    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        vocab_size: int = ModelArgs.vocab_size,
        dropout=ModelArgs.dropout
    ):
        super().__init__()
        self.swiglu = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.swiglu(x)
        x = self.dropout(x)
        return x


class DecoderLayer(nn.Module):
    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        dropout=ModelArgs.dropout,
        block_size: int = ModelArgs.block_size,
        vocab_size: int = ModelArgs.vocab_size,
    ):
        super().__init__()
        self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device=device)
        self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, mqa_heads=2, device=device)
        self.norm1 = Normalization(embeddings_dims=embeddings_dims)
        self.norm2 = Normalization(embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # Pre-norm residual blocks, as in Llama.
        x = x + self.gqa(self.norm1(x))
        x = x + self.feedforward_network(self.norm2(x))
        return x


class Llama(nn.Module):
    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
        block_size: int = ModelArgs.block_size,
        vocab_size: int = ModelArgs.vocab_size,
        dropout=ModelArgs.dropout
    ):
        super().__init__()
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device=device)
        self.decoder = nn.Sequential(*[
            DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device=device)
            for _ in range(no_of_decoder_layers)
        ])
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device=device)
        self.dropout = nn.Dropout(p=dropout)

        # Weight tying: the input embedding and output projection share one matrix.
        self.embeddings.weight = self.linear_layer.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        x = self.embeddings(x)
        x = self.dropout(x)
        x = self.decoder(x)
        x = self.linear_layer(x)
        return x
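
# Back-of-the-envelope check (illustrative, never called) of what weight tying
# saves: the tied embedding/output matrix is vocab_size x embeddings_dims,
# i.e. 50304 * 512 = 25,755,648 parameters that would otherwise be duplicated.
def _tied_embedding_savings():
    shared = ModelArgs.vocab_size * ModelArgs.embeddings_dims
    print(f"Parameters shared by tying: {shared:,}")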

# Top-k sampling, adapted from Andrej Karpathy's nanoGPT.
def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0):
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    ModelArgs.inference = True
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model.module(input_ids)
            logits = outputs[:, -1, :]
            # Temperature scaling belongs on the logits, before the softmax.
            logits = logits / temperature
            probs = F.softmax(logits, dim=-1)
            # Top-k filtering: sample only among the k most likely tokens.
            top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
            next_token = torch.multinomial(top_k_probs, num_samples=1)
            xcol = torch.gather(top_k_indices, -1, next_token)
            input_ids = torch.cat([input_ids, xcol], dim=1)  # dim 1 is the sequence dimension
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


def beam_search(model, tokenizer, prompt, beam_width=5, max_length=50, temperature=1.0):
    device = next(model.module.parameters()).device
    input_ids = tokenizer(prompt, return_tensors="pt").to(device)['input_ids']
    beam_scores = torch.zeros(beam_width, device=device)
    beam_sequences = input_ids.repeat(beam_width, 1)

    for _ in range(max_length):
        outputs = model(beam_sequences)
        logits = outputs[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        top_probs, top_indices = torch.topk(probs, beam_width, dim=-1)

        # Expand beams: every surviving beam proposes beam_width continuations.
        beam_scores = beam_scores.unsqueeze(-1) + torch.log(top_probs)
        beam_scores = beam_scores.view(-1)
        top_indices = top_indices.view(-1)

        # Keep the beam_width best (sequence, next-token) pairs overall.
        beam_scores, top_beams = torch.topk(beam_scores, beam_width)
        beam_sequences = torch.cat([
            beam_sequences[top_beams // beam_width],
            top_indices[top_beams].unsqueeze(-1)
        ], dim=-1)

    # Return the best sequence.
    best_sequence = beam_sequences[0]
    return tokenizer.decode(best_sequence, skip_special_tokens=True)


model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
model = model.to(ModelArgs.device)

# Print a summary of the architecture (pip install torchinfo).
from torchinfo import summary

idx = torch.randint(
    low=0,
    high=ModelArgs.vocab_size,
    size=(ModelArgs.batch_size, ModelArgs.block_size),
    dtype=torch.long
)
idx = idx.to(ModelArgs.device)

summary(model=model,
        input_data=idx,
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])


def find_unused_parameters(model):
    unused = []
    for name, param in model.named_parameters():
        if param.grad is None:
            unused.append(name)
    return unused
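
# Small illustration (assumed values, helper never called) of why temperature is
# applied to logits: T < 1 sharpens the distribution before top-k keeps the k
# most likely tokens, while T > 1 flattens it.
def _demo_temperature():
    logits = torch.tensor([2.0, 1.0, 0.5])
    for t in (0.5, 1.0, 2.0):
        print(t, F.softmax(logits / t, dim=-1))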

def greedy_decode(
    model,
    tokenizer,
    prompt,
    device,
    max_length=50,
    repetition_penalty=1.2,
    context_window=10,
    temperature=1.0,
    eos_token_id=None,
):
    input_ids = tokenizer(prompt, return_tensors="pt").to(device)['input_ids']
    generated_tokens = []
    eos_token_id = eos_token_id or tokenizer.eos_token_id  # use EOS token if provided

    for _ in range(max_length):
        with torch.no_grad():
            outputs = model.module(input_ids)
            logits = outputs[:, -1, :]  # logits for the last position

            # Temperature scaling.
            if temperature != 1.0:
                logits = logits / temperature

            # Repetition penalty over a recent context window (CTRL-style:
            # shrink positive logits, grow negative ones).
            if repetition_penalty != 1.0 and len(generated_tokens) > 0:
                for token in set(generated_tokens[-context_window:]):
                    if logits[0, token] > 0:
                        logits[0, token] /= repetition_penalty
                    else:
                        logits[0, token] *= repetition_penalty

            # Greedy selection.
            next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
            generated_tokens.append(next_token.item())

            # Stop if the EOS token is generated.
            if next_token.item() == eos_token_id:
                break

            # Append the new token to the input.
            input_ids = torch.cat([input_ids, next_token], dim=1)

    return tokenizer.decode(generated_tokens, skip_special_tokens=True)


def save_to_file(text):
    with open('generations.txt', 'a') as f:
        f.writelines(text + "\n\n")


# Train the model

from torch.optim.lr_scheduler import LambdaLR


def warmup_fn(step):
    # Warmup phase: LR ramps up linearly for the first 2000 steps.
    if step < 2000:
        return step / 2000
    return 1.0


def trapezoidal_lr_scheduler(optimizer, max_lr, total_steps, warmup_steps, plateau_steps, decay_steps):
    """
    Trapezoidal learning rate scheduler:
    - Increases linearly for `warmup_steps` steps.
    - Remains constant for `plateau_steps` steps.
    - Decreases linearly for `decay_steps` steps.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            # Linear warmup
            return float(step) / float(max(1, warmup_steps))
        elif step < warmup_steps + plateau_steps:
            # Constant plateau
            return 1.0
        else:
            # Linear decay
            decay_step = step - (warmup_steps + plateau_steps)
            return max(0.0, float(decay_steps - decay_step) / float(max(1, decay_steps)))
    return LambdaLR(optimizer, lr_lambda)


torch.set_float32_matmul_precision('high')

scaler = torch.amp.GradScaler(enabled=(ModelArgs.dtype == 'float16'))

save_checkpoint_iter = 50
total_iters = 10000
eval_iters = 50
eval_check = 100
warmup_iters = 700
min_lr = 0.1 * ModelArgs.max_lr
lr_decay_iters = 10000
total_batch_size = 524288  # measured in tokens per optimizer step
micro_batch_size = ModelArgs.batch_size
gradient_accumulation_steps = total_batch_size // (micro_batch_size * ModelArgs.block_size * torch.cuda.device_count())


# Learning rate decay scheduler (cosine with warmup), from
# https://github.com/karpathy/nanoGPT/blob/master/train.py
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return ModelArgs.max_lr * (it + 1) / (warmup_iters + 1)
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (ModelArgs.max_lr - min_lr)
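
# Spot-check (illustrative, never called) of the cosine schedule above: the LR
# should ramp to roughly max_lr by warmup_iters and decay to min_lr by
# lr_decay_iters.
def _print_lr_schedule():
    for it in (0, 350, 700, 5000, 10000):
        print(f"step {it:>6}: lr = {get_lr(it):.2e}")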

def train():
    setup()
    device = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(int(device))

    print(f"Start running DDP on rank {device}.")

    if device == 0:
        # Initialise the wandb run (rank 0 only).
        wandb.init(
            project='Llama-DDP-Pretrain-10-billion-tokens',
        )
        print("wandb initialized")

    model = Llama(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=device)
    print(f"Model on device {device} is ready")

    # To resume: _load_snapshot('/kaggle/input/models/snapshot2.pt', model.module, None, None)
    optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim, eps=ModelArgs.eps)
    # model = torch.compile(model)
    model = model.to(device)
    model = DDP(model, device_ids=[device])

    model.eval()
    world_size = torch.cuda.device_count()
    @torch.inference_mode()
    def estimate_loss(val_loader, val_iterator, device):
        out = {}
        epoch_losses = []

        for split in ['val']:
            print(f"Starting with {split} evaluation...")
            for step in range(eval_check):
                try:
                    batch = next(val_iterator)
                except StopIteration:
                    # Restart the iterator once the loader is exhausted.
                    val_iterator = iter(val_loader)
                    batch = next(val_iterator)

                idx = batch['input_ids'].to(device)
                targets = batch['labels'].to(device)

                with torch.autocast(device_type=ModelArgs.device, dtype=torch.bfloat16):
                    logits = model(idx)
                    batch_size, block_size, vocab = logits.shape
                    logits = logits.view(batch_size * block_size, vocab)  # flatten tokens
                    targets = targets.view(batch_size * block_size)
                    loss = F.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)

                epoch_losses.append(loss.item())

            out[split] = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0.0
            epoch_losses = []

        model.train()
        return out

    model.train()
    count = 0

    train_dataloader = prepare_dataset('train', device, ModelArgs.batch_size)
    val_loader = prepare_dataset('val', device, ModelArgs.batch_size)
    print("Loaders ready both")
    epochs = ModelArgs.epochs

    train_loader_length = 0
    train_data_iterator = iter(train_dataloader)
    val_data_iterator = iter(val_loader)
    token_count = 0
    if device == 0:
        train_loader_length = len(train_dataloader)

    for step in tqdm(range(total_iters)):
        if device == 0:
            print("Step : ", step, "/", total_iters)
            print('Total batches: ', len(train_dataloader))
            print("Total gradient accumulation steps: ", gradient_accumulation_steps)
            print("Total tokens processed: ", token_count)

        # Every once in a while, evaluate the loss on the val set.
        if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
            losses = estimate_loss(val_loader, val_data_iterator, device)
            print(f"[GPU {device}] | Step: {step} / {total_iters} | Val Loss: {losses['val']:.4f}")

            # Aggregate the average val loss across all GPUs.
            avg_val_loss = torch.Tensor([losses['val']]).to(device)
            torch.distributed.reduce(avg_val_loss, dst=0, op=torch.distributed.ReduceOp.SUM)

            if device == 0:
                all_gpus_avg_val_loss = avg_val_loss / world_size
                print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss.item():.4f}")

                wandb.log({
                    "All_GPUs_Val_losses": all_gpus_avg_val_loss.item(),
                    "val_step_loss": losses['val'],
                })

        if step % save_checkpoint_iter == 0 and device == 0 and step != 0:
            print(f"Saving the model checkpoint for step: {step}")
            _save_snapshot(model, optimizer, None, None, step)

        accumulated_loss = 0.0

        optimizer.zero_grad(set_to_none=True)
        for micro_step in range(gradient_accumulation_steps):
            try:
                batch = next(train_data_iterator)
            except StopIteration:
                train_data_iterator = iter(train_dataloader)
                batch = next(train_data_iterator)

            idx = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)
            token_count += idx.numel()  # count tokens, not rows

            with torch.autocast(device_type=ModelArgs.device, dtype=torch.bfloat16):
                logits = model(idx)
                batch_size, block_size, vocab = logits.shape
                logits = logits.view(batch_size * block_size, vocab)
                targets = targets.view(batch_size * block_size)
                loss = nn.functional.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)

            # Divide by the number of micro steps so the accumulated gradient matches
            # what one large batch of total_batch_size tokens would have produced:
            # each mini batch contributes equally to the overall big batch.
            loss = loss / gradient_accumulation_steps
            accumulated_loss += loss.detach()

            # Only sync gradients across GPUs on the last micro step, so that we
            # don't synchronize on every micro batch.
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
            scaler.scale(loss).backward()

            # Check for unused parameters.
            unused_params = find_unused_parameters(model)
            if unused_params:
                print(f"Unused parameters: {unused_params}")

            if device == 0 and micro_step % 10 == 0:
                print("Micro Batch : ", micro_step)
                print("Step : ", step, "/", total_iters)
                print("Total gradient accumulation steps: ", gradient_accumulation_steps)
                print("Total tokens processed: ", token_count)

        lr = get_lr(step)
        for params in optimizer.param_groups:
            params['lr'] = lr

        if ModelArgs.clip != 0.0:
            # Unscale first so clipping happens on true (fp32) gradients.
            scaler.unscale_(optimizer)
            # Compute gradient norms before clipping.
            total_norm_before = torch.norm(
                torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters() if p.grad is not None]), 2
            )
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=ModelArgs.clip)
            # Compute gradient norms after clipping.
            total_norm_after = torch.norm(
                torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters() if p.grad is not None]), 2
            )
            if device == 0 and step != 0:
                print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
                print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")

        scaler.step(optimizer)
        scaler.update()

        torch.cuda.synchronize()
        torch.distributed.reduce(loss, dst=0, op=torch.distributed.ReduceOp.SUM)

        if device == 0:
            wandb.log({
                "Learning Rate": lr,
                "All_GPUs_Train_losses": accumulated_loss.item(),
                "Step": step,
            })

        if device == 0 and step % 5 == 0:
            # Generate a few samples on the main process to eyeball progress.
            count = 3
            while count:
                prompt = "Once upon a time"
                generated_text = topk_sampling(model, prompt, max_length=50, top_k=50, temperature=1.0, device=device)
                print(f" Step: {step} | Generated Text: {generated_text}")
                # save_to_file(generated_text)
                count -= 1
    # Cleanup
    if device == 0:
        wandb.finish()
    cleanup()


world_size = torch.cuda.device_count()
print(f"World size: {world_size}")
train()
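
# How this script is meant to be launched (assuming a single node with N GPUs;
# torchrun sets the environment variables that setup() relies on). The script
# filename below is a placeholder:
#
#   torchrun --standalone --nproc_per_node=<num_gpus> train_llama.py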