# 185860 import random import torch import torch.nn as nn import torch.nn.functional as F import math from dataclasses import dataclass # from torchtune.modules import RMSNorm from tokenizers import Tokenizer from pathlib import Path import torch.multiprocessing as mp from torch.utils.data.distributed import DistributedSampler from torch.nn.parallel import DistributedDataParallel as DDP from torch.distributed import init_process_group, destroy_process_group import torch from datasets import Dataset from torch.utils.data import DataLoader from transformers.models.prophetnet.modeling_prophetnet import ProphetNetDecoderModelOutput import wandb from tqdm import tqdm from functools import partial import torch.optim as optim from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts # Load model directly from transformers import AutoTokenizer, AutoModelForCausalLM import os # import wandb # wandb.login() # from torch.utils.tensorboard import SummaryWriter from datasets import load_dataset, concatenate_datasets # use name="sample-10BT" to use the 10BT sample # fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False) # print(fw_train) # Select only 1000 rows from the dataset # fw_train = fw_train.select(range(1000000)) # alpaca = load_dataset("yahma/alpaca-cleaned", split='train') # dolly = load_dataset("llm-wizard/dolly-15k-instruction-alpaca-format", split='train') # merged_dataset = concatenate_datasets([alpaca, dolly]) dataset = load_dataset("swype/instruct", split='train', trust_remote_code=True) # print(fw_train) # Split the dataset into training and validation sets merged_dataset = dataset.train_test_split(test_size=0.1) print(merged_dataset) # fw_train = fw_train.train_test_split(test_size=0.2) # print(fw_train) # Access the splits # train_dataset = train_val_split['train'] # val_dataset = train_val_split['test'] # train_dataset = fw_train.train_test_split(test_size=0.2) def setup(rank=None, world_size=None): # os.environ['MASTER_ADDR'] = 'localhost' # os.environ['MASTER_PORT'] = '12355' init_process_group("nccl") # torch.cuda.set_device(int(os.environ['LOCAL_RANK'])) def cleanup(): destroy_process_group() @dataclass class ModelArgs: #Hyperparameters epochs = 5 block_size = 128 batch_size = 64 embeddings_dims = 786 attn_dropout = 0.1 no_of_heads = 6 #IMP needs to be thoroughly calculated dropout = 0.1 # epochs = 100 val_epochs = 2 max_lr = 2e-4 no_of_decoder_layers = 6 #IMP needs to be thoroughly calculated weight_decay_optim = 0.1 beta_1 = 0.9 beta_2 = 0.95 clip = 1.0 device = 'cuda' no_kv_heads = 2 vocab_size = 50258 from pathlib import Path data_path = Path('data') data_path.mkdir(exist_ok=True) # !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt # !cp input.txt data/input.txt #Datasets # Using tinyshakespeare # with open('data/input.txt', 'r', encoding='utf-8') as f: # text = f.read() # Load the tokenizer # tokenizer = Tokenizer.from_file("bpe_tokenizer_30k.json") # Encode and decode functions # encode = lambda s: tokenizer.encode(s).ids # decode = lambda l: tokenizer.decode(l) def _save_snapshot(model, optimizer, scheduler, epoch, step): snapshot = { "MODEL_STATE": model.module.state_dict(), "OPTIMIZER_STATE": optimizer.state_dict(), "SCHEDULER_STATE": scheduler.state_dict(), # NEW: Save scheduler state "EPOCHS_RUN": epoch, "STEP_RUN": step } torch.save(snapshot, "/kaggle/working/snapshot_fine_tuned_model_with_gradient_clipping_3.pt") print(f"Epoch: {epoch} | Step: {step} | Snapshot saved.") def _load_snapshot(snapshot_path, model, optimizer, scheduler): snapshot = torch.load(snapshot_path) model.load_state_dict(snapshot["MODEL_STATE"]) # optimizer.load_state_dict(snapshot["OPTIMIZER_STATE"]) # scheduler.load_state_dict(snapshot["SCHEDULER_STATE"]) # Load scheduler state epoch = snapshot["EPOCHS_RUN"] step = snapshot["STEP_RUN"] print(f"Resuming from Epoch {epoch}, Step {step}") return epoch, step #Subword level tokenization #Loading custom trained BPE # Load the tokenizer # tokenizer = Tokenizer.from_file("data/bpe_tokenizer_tinyshakespeare_1k.json") # vocab_size = tokenizer.get_vocab_size() # Encode and decode functions # encode = lambda s: tokenizer.encode(s).ids # decode = lambda l: tokenizer.decode(l) ############################################################################### #Character level tokenization # # here are all the unique characters that occur in this text # chars = sorted(list(set(text))) # vocab_size = len(chars) # # create a mapping from characters to integers # stoi = { ch: i for i,ch in enumerate(chars) } # itos = { i:ch for i,ch in enumerate(chars) } # encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers # decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string # Convert the dataset to Hugging Face Dataset format # train_hf_dataset = Dataset.from_dict({"text": train_dataset['train']['text']}) # val_hf_dataset = Dataset.from_dict({"text": train_dataset['test']['text']}) # Tokenize the dataset using the `map` function # from google.colab import userdata # HF_TOKEN = userdata.get('HF_TOKEN') tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", hf_token = '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') # tokenizer.pad_token = tokenizer.eos_token # if tokenizer.pad_token is None: tokenizer.add_special_tokens({'pad_token': '[PAD]'}) # print("ADDED THE TOKENS: ", tokenizer.pad_token_id) # tokenizer.bos_token = "[INST]" # tokenizer.eos_token = "[/INST]" # model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") def tokenize_function(examples): return tokenizer( examples['text'], max_length=ModelArgs.block_size, padding='max_length', truncation=True, return_tensors='pt' ) ## Load the tokenizer # tokenizer = Tokenizer.from_file("bpe_tokenizer_30k.json") # # Tokenization functions # def encode_train(examples): # tokens = [] # for example in examples['text']: # out = tokenizer.encode(example).ids # tokens.append(out) # Append the tokenized sequence (do not flatten) # return {"tokens": tokens} # def encode_val(examples): # tokens = [] # for example in examples['text']: # out = tokenizer.encode(example).ids # tokens.append(out) # Append the tokenized sequence (do not flatten) # return {"tokens": tokens} # Apply tokenization with batching # train_data = train_dataset['train'].map(tokenize_function, batched=True, batch_size=8000, remove_columns=['id', 'dump', 'url', 'date', 'file_path', 'language', 'language_score', 'token_count'], num_proc=8) # val_data = train_dataset['test'].map(tokenize_function, batched=True, batch_size=8000, remove_columns=['id', 'dump', 'url', 'date', 'file_path', 'language', 'language_score', 'token_count'], num_proc=8) # # # Extract tokens from the processed datasets # # train_tokens = train_data['tokens'] # # val_tokens = val_data['tokens'] # # Flatten the tokenized data # # train_tokens = [token_id for seq in train_data['input_ids'] for token_id in seq] # # val_tokens = [token_id for seq in val_data['input_ids'] for token_id in seq] # try: # train_tensors = [torch.tensor(seq) for seq in tqdm(train_data['input_ids'], desc="Converting train_data to tensors")] # train_data_tensor = torch.cat(train_tensors) # except Exception as e: # print(f"Error during tensor conversion: {e}") # try: # train_tensors = [torch.tensor(seq) for seq in tqdm(val_data['input_ids'], desc="Converting train_data to tensors")] # val_data_tensor = torch.cat(train_tensors) # except Exception as e: # print(f"Error during tensor conversion: {e}") # print("Train tokens count: ", train_data_tensor) # print("Val tokens count: ", val_data_tensor) def prepare_dataset(split, batch_size): # alpaca_prompt = ''' # ### Instruction: # {} # ### Response: # {} # ''' # Load a subset of the C4 dataset with a glob pattern for specific training files # dataset = load_dataset("allenai/c4", data_files=["en/c4-train.00001-of-01024.json.gz"], trust_remote_code=True) # Initialize tokenizer # tokenizer = AutoTokenizer.from_pretrained("gpt2") def collate_fn(batch): # Extract text data # texts = [item ["text"] for item in batch] # Set the pad token if it isn't set already # if tokenizer.pad_token is None: # tokenizer.pad_token = tokenizer.eos_token outputs = [] texts = [] for item in batch: instruction = item['prompt'] # input = item['input'] output = item['completion'] # out = alpaca_prompt.format(instruction, output) texts.append(instruction) outputs.append(output) # Tokenize text data input_encodings = tokenizer(texts, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt") # output_encodings = tokenizer(outputs, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt") input_encodings["labels"] = tokenizer(outputs, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt") # out = {"input": input_encodings} # input_encodings["labels"] = input_encodings["input_ids"].clone() # Use `input_ids` as labels # input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:] # Shift right # input_encodings["labels"][:, -1] = tokenizer.pad_token_id # Ignore the last token (no target for it) # Return tokenized input tensors # return out return input_encodings # Create DistributedSampler for proper shuffling and partitioning across processes # dist_sampler = DistributedSampler(fw_train["text"], shuffle=True) # Create DataLoader with custom collate_fn # print(fw_dataset) dataloader = None if(split == 'train'): data_loader = DataLoader( merged_dataset['train'], batch_size=batch_size, sampler=DistributedSampler(merged_dataset['train'], shuffle=True), collate_fn=collate_fn, drop_last=True, shuffle=False ) elif(split == 'val'): data_loader = DataLoader( merged_dataset['test'], batch_size=batch_size, sampler=DistributedSampler(merged_dataset["test"], shuffle=True), collate_fn=collate_fn, drop_last=True, shuffle=False ) return data_loader # Convert to tensors # train_data_tensor = torch.tensor(train_tokens, dtype=torch.long) # val_data_tensor = torch.tensor(val_tokens, dtype=torch.long) # # Debug output # print("Number of train tokens:", len(train_data_tensor)) # print("Number of validation tokens:", len(val_data_tensor)) # def create_sequences(data, block_size): # sequences = [] # for seq in data: # if len(seq) < block_size: # # while(len(sequence) < block_size): # # sequence = data[i:i + block_size + 1] # # Pad the sequence if it's shorter than block_size # padding_length = block_size - len(seq) # seq = torch.cat([seq, torch.full((padding_length,), tokenizer.pad_token_id, dtype=torch.long)]) # sequences.append(seq) # out = torch.tensor(sequences, dtype=torch.long) # return out # train_data = create_sequences(train_data['input_ids'], ModelArgs.block_size) # val_data = create_sequences(val_data['input_ids'], ModelArgs.block_size) def get_batch(split): # generate a small batch of data of inputs x and targets y data = train_data if split == 'train' else val_data ix = torch.randint(len(data) - ModelArgs.block_size, (ModelArgs.batch_size,)) x = torch.stack([data[i:i+ModelArgs.block_size] for i in ix]) y = torch.stack([data[i+1:i+ModelArgs.block_size+1] for i in ix]) x, y = x.to(ModelArgs.device), y.to(ModelArgs.device) return x, y from torch.utils.data import Dataset class TokenDataset(Dataset): def __init__(self, data, block_size): self.data = data self.block_size = block_size def __len__(self): return len(self.data) - self.block_size # Ensure valid indexing def __getitem__(self, idx): x = self.data[idx:idx + self.block_size] y = self.data[idx + 1:idx + self.block_size + 1] return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long) # train_rows = 11895089 # encoded_data = torch.tensor(encode(fw_train['text']), dtype=torch.long) # train_data = train_data[:train_rows] # val_data = val_data[train_rows:] # train_dataset = TokenDataset(train_data_tensor, ModelArgs.block_size) # val_dataset = TokenDataset(val_data_tensor, ModelArgs.block_size) # encoded_data = torch.tensor(encode(text), dtype=torch.long) # print(train_data) # print(val_data) # train_dataset = TextDataset(train_data, ModelArgs.block_size) # val_dataset = TextDataset(val_data, ModelArgs.block_size) # print(train_dataset) # print(val_dataset) # # Convert the tokenized data into a list of sequences # train_sequences = [train_data[i:i + ModelArgs.block_size] for i in range(0, len(train_data) - ModelArgs.block_size)] # val_sequences = [val_data[i:i + ModelArgs.block_size] for i in range(0, len(val_data) - ModelArgs.block_size)] # Define collate_fn # def collate_fn(batch): # block_size = ModelArgs.block_size # batch_size = len(batch) # x = torch.zeros((batch_size, block_size), dtype=torch.long) # y = torch.zeros((batch_size, block_size), dtype=torch.long) # for i, sequence in enumerate(batch): # print("Shape x: ", sequence[:-1].shape) # print("Shape of y: ", len(sequence[1:])) # x[i] = sequence[:-1] # Input is all tokens except the last one # y[i] = sequence[1:] # Target is all tokens except the first one # return x, y def create_sequences(data, block_size): sequences = [] for seq in data: len(seq) if len(seq) < block_size: # while(len(sequence) < block_size): # sequence = data[i:i + block_size + 1] # Pad the sequence if it's shorter than block_size padding_length = block_size - len(seq) seq = torch.cat([seq, torch.full((padding_length,), tokenizer.encode('[PAD]').ids[0], dtype=torch.long)]) else: if len(seq) > block_size: seq = seq[:block_size] # while(len(sequence) < block_size): # sequence = data[i:i + block_size + 1] # Pad the sequence if it's shorter than block_size # padding_length = block_size - len(seq) # seq = torch.cat([seq, torch.full((padding_length,), tokenizer.encode('[PAD]').ids[0], dtype=torch.long)]) sequences.append(seq) out = torch.tensor(sequences, dtype=torch.long) return out # train_data = create_sequences(train_data_flat['input_ids'], ModelArgs.block_size) # val_data = create_sequences(val_data['input_ids'], ModelArgs.block_size) # Define collate_fn def collate_fn(split , batch): block_size = ModelArgs.block_size batch_size = len(batch) if(split == 'train'): data = train_data_tensor elif(split == 'test'): data = val_data_tensor ix = torch.randint(len(data) - ModelArgs.block_size, (ModelArgs.batch_size,)) x = torch.stack([data[i:i+ModelArgs.block_size] for i in ix]) y = torch.stack([data[i+1:i+ModelArgs.block_size+1] for i in ix]) # print("Shape of x: ", len(x)) # print("Length of y: ", len(y)) # x, y = x.to(ModelArgs.device), y.to(ModelArgs.device) # x = torch.zeros((batch_size, block_size), dtype=torch.long) # y = torch.zeros((batch_size, block_size), dtype=torch.long) # for i, sequence in enumerate(batch): # print("Seq: ", sequence) # print("Shape x: ", sequence[:-1].shape) # print("Shape of y: ", len(sequence[1:])) # x[i] = sequence[:-1] # Input is all tokens except the last one # y[i] = sequence[1:] # Target is all tokens except the first one return x, y class Normalization(nn.Module): def __init__( self, embeddings_dims: int = ModelArgs.embeddings_dims ): super().__init__() self.rmsnorm_layer = torch.nn.RMSNorm(normalized_shape=embeddings_dims) def forward(self, x): x = self.rmsnorm_layer(x) return x # import numpy as np class RotaryEmbeddings(nn.Module): def __init__( self, device, embeddings_dims: int = ModelArgs.embeddings_dims, block_size: int = ModelArgs.block_size, batch_size: int = ModelArgs.batch_size ): super().__init__() self.embeddings_dims = embeddings_dims self.block_size = block_size self.batch_size = batch_size self.theta = 0 # def init_matrix(self, seq_len): # self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False) # for pos in range(seq_len): # for j in range(1, self.embeddings_dims // 2): # self.theta = 10000 ** (-2*(pos-1) / self.embeddings_dims) # self.matrix[pos, 2*j + 1, 2*j + 1] = np.cos((pos*self.theta)) # self.matrix[pos, 2*j + 1, j + 1] = -np.sin((pos* self.theta)) # self.matrix[pos, 2*j , 2*j ] = -np.cos((pos* self.theta)) # self.matrix[pos, 2*j + 1, 2*j + 1] = np.sin((pos* self.theta)) # return self.matrix self.device=device def init_matrix(self, seq_len): self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device = self.device) positions = torch.arange(seq_len, dtype=torch.float32, device = self.device).unsqueeze(1) # dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32) theta = 10000 ** (-2 * (positions - 1) / self.embeddings_dims) angles = positions * theta cos_angles = torch.cos(angles) sin_angles = torch.sin(angles) indices = torch.arange(self.embeddings_dims, dtype=torch.int64, device = self.device) # print(indices) # print(indices.shape) # print(indices[::2]) even_indices = indices[::2] odd_indices = indices[1::2] self.matrix[:, even_indices, even_indices] = cos_angles self.matrix[:, odd_indices, odd_indices] = sin_angles self.matrix[:, odd_indices, even_indices] = -sin_angles self.matrix[:, even_indices, odd_indices] = cos_angles return self.matrix def forward(self, x): # B,T,C = x.shape # print("MATRIX:",x) if(x > self.block_size or x < self.block_size): matrix = self.init_matrix(x) return matrix else: matrix = self.init_matrix(self.block_size) return matrix class RotaryAttentionHead(nn.Module): def __init__( self, device, embeddings_dims: int = ModelArgs.embeddings_dims, no_of_heads: int = ModelArgs.no_of_heads, attn_dropout: int = ModelArgs.attn_dropout ): super().__init__() self.head_size = embeddings_dims // no_of_heads self.query = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device) self.key = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device) self.value = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device) self.rotary_matrix = RotaryEmbeddings(embeddings_dims=embeddings_dims, device = device) self.dropout = nn.Dropout(p = attn_dropout) self.device = device def forward(self,x): # print(x.shape) batch, block_size, embeddings_dims = x.shape query = self.query(x) # print(query) key = self.key(x) values = self.value(x) matrix = self.rotary_matrix(block_size) # print(matrix.shape) # print(query.shape) masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device)) rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T) rotary_key = matrix @ key.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T) weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T) weights_masked = weights.masked_fill(masked == 0, float('-inf')) scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1]))) scaled_weights = F.softmax(scaled_weights, dim=-1) value = scaled_weights @ values out = self.dropout(value) return out class MQA(nn.Module): def __init__( self, device, embeddings_dims: int = ModelArgs.embeddings_dims, block_size: int = ModelArgs.block_size, no_of_kv_heads: int = ModelArgs.no_of_heads, no_of_heads: int = ModelArgs.no_of_heads, ): super().__init__() self.no_of_kv_heads = no_of_kv_heads self.no_of_q_heads = no_of_heads // no_of_kv_heads self.head_size = embeddings_dims // self.no_of_q_heads self.rotary_matrix = RotaryEmbeddings(embeddings_dims=embeddings_dims, device = device) # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False) self.key = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device) self.value = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device) self.dropout = nn.Dropout(p = ModelArgs.attn_dropout) self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device) self.device = device self.multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, device = self.device) for _ in range(self.no_of_q_heads)]) def scaled_dot_product(self, q, k, v, block_size, matrix): # masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device)) masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device)) rotary_query = matrix @ q.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T) rotary_key = matrix @ k.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T) weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T) weights_masked = weights.masked_fill(masked == 0, float('-inf')) scaled_weights = weights_masked / (torch.sqrt(torch.tensor(k.shape[-1]))) scaled_weights = F.softmax(scaled_weights, dim=-1) value = scaled_weights @ v out = self.dropout(value) return value def forward(self,x): # print("MQA: ", x.shape) batch, block_size, embeddings_dims = x.shape # query = self.query(x) matrix = self.rotary_matrix(block_size) key = self.key(x) values = self.value(x) multi_query_concat = torch.cat([self.scaled_dot_product(query(x), key, values, block_size, matrix) for query in self.multi_query], dim=-1) linear_layer= self.linear_layer(multi_query_concat) out = self.dropout(linear_layer) return out class GQA(nn.Module): def __init__( self, device, embeddings_dims: int = ModelArgs.embeddings_dims, block_size: int = ModelArgs.block_size, no_of_q_heads: int = ModelArgs.no_of_heads, no_of_kv_heads: int = ModelArgs.no_kv_heads ): super().__init__() self.no_of_kv_heads = no_of_kv_heads self.no_of_q_heads = no_of_q_heads self.dropout = nn.Dropout(p = ModelArgs.attn_dropout) self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_kv_heads, out_features=embeddings_dims , dtype=torch.float32, bias=False, device = device) self.device = device self.mqa = nn.ModuleList([MQA(embeddings_dims=embeddings_dims, device = self.device, block_size=block_size) for _ in range(self.no_of_kv_heads)]) def forward(self,x): batch, block_size, embeddings_dims = x.shape grouped_query_concat = torch.cat([group(x) for group in self.mqa], dim=-1) linear_layer= self.linear_layer(grouped_query_concat) out = self.dropout(linear_layer) return out class Swish(nn.Module): def __init__( self, device, block_size: int = ModelArgs.block_size, embeddings_dims: int = ModelArgs.embeddings_dims ): super().__init__() self.sig = torch.nn.Sigmoid() def forward(self, x): swish = x * self.sig(x) return swish class SWiGLU(nn.Module): def __init__( self, device, block_size: int = ModelArgs.block_size, embeddings_dims: int = ModelArgs.embeddings_dims ): super().__init__() self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device) self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device) self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device) self.linear_layer3 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device) def forward(self, x): swish_res = self.swish(self.linear_layer1(x)) x_V = self.linear_layer2(x) res = torch.mul(swish_res, x_V) out = self.linear_layer3(res) return out class FFN(nn.Module): def __init__(self, device, embeddings_dims: int = ModelArgs.embeddings_dims, block_size: int = ModelArgs.block_size, vocab_size: int = ModelArgs.vocab_size, dropout = ModelArgs.dropout ): super().__init__() self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device = device) self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims, device = device) self.dropout = nn.Dropout(p = dropout) def forward(self, x): x = self.swiglue(x) x = self.linear_layer(x) x = self.dropout(x) return x class DecoderLayer(nn.Module): def __init__(self, device, embeddings_dims: int = ModelArgs.embeddings_dims, dropout = ModelArgs.dropout, block_size: int = ModelArgs.block_size, vocab_size: int = ModelArgs.vocab_size, ) : super().__init__() self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device = device) self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, no_of_kv_heads=ModelArgs.no_kv_heads, no_of_q_heads=ModelArgs.no_of_heads, device = device) # self.norm = Normalization(embeddings_dims=embeddings_dims) self.norm1 = Normalization(embeddings_dims=embeddings_dims) self.norm2 = Normalization(embeddings_dims=embeddings_dims) self.dropout = nn.Dropout(p = dropout) def forward(self, x): x = self.norm1(x + self.gqa(x)) x = self.norm2(x + self.feedforward_network(x)) return x class Llama(nn.Module): def __init__(self, device, embeddings_dims: int = ModelArgs.embeddings_dims, no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers, block_size: int = ModelArgs.block_size, vocab_size: int = ModelArgs.vocab_size, dropout = ModelArgs.dropout ) : super().__init__() self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device = device) self.decoder = nn.Sequential(*[DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device = device) for _ in range(no_of_decoder_layers)]) self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device = device) self.dropout = nn.Dropout(p = dropout) # self.norm = Normalization(embeddings_dims) def forward(self, x): x = self.embeddings(x) x = self.dropout(x) x = self.decoder(x) # x = self.norm(x) x = self.linear_layer(x) # out = self.norm(x) return x # device = "cuda" if torch.cuda.is_available() else "cpu" # # device = "cpu" # ModelArgs.device = device # model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout) # model = model.to(ModelArgs.device) #Printing a summary of the architecture # !pip install torchinfo # from torchinfo import summary # # idx, targets = get_batch('test') # idx = torch.randint( # low=0, # high=ModelArgs.vocab_size, # size=(ModelArgs.batch_size, ModelArgs.block_size), # dtype=torch.long # ) # # sample_idx = random.randint(range(len(train_dataset))) # # idx, targets = train_dataset[0] # idx = idx.to(ModelArgs.device) # # targets = targets.to(ModelArgs.device) # summary(model=model, # input_data=idx, # # input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims), # col_names=["input_size", "output_size", "num_params", "trainable"], # col_width=20, # row_settings=["var_names"]) def find_unused_parameters(model): unused = [] for name, param in model.named_parameters(): if param.grad is None: unused.append(name) return unused def greedy_decode( model, tokenizer, prompt, max_length=50, repetition_penalty=1.2, context_window=10, temperature=1.0, eos_token_id=None ): device = next(model.parameters()).device input_ids = tokenizer(prompt, return_tensors="pt").to(device)['input_ids'] generated_tokens = [] eos_token_id = eos_token_id or tokenizer.eos_token_id # Use EOS token if provided for _ in range(max_length): outputs = model(input_ids) logits = outputs[:, -1, :] # Get logits for the last token # Apply temperature scaling if temperature != 1.0: logits = logits / temperature # Apply repetition penalty if repetition_penalty != 1.0 and len(generated_tokens) > 0: for token in set(generated_tokens[-context_window:]): # Penalize recent tokens logits[0, token] /= repetition_penalty # Greedy selection next_token = torch.argmax(logits, dim=-1).unsqueeze(0) generated_tokens.append(next_token.item()) # Stop if EOS token is generated if next_token.item() == eos_token_id: break # Append the new token to the input input_ids = torch.cat([input_ids, next_token], dim=1) # Decode the generated tokens return tokenizer.decode(generated_tokens, skip_special_tokens=True) def save_to_file(text): with open('generations.txt', 'a') as f: f.writelines(text + "\n\n") #Train the model # writer = SummaryWriter(log_dir="runs/experiment") from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR # Warmup phase for 2000 steps def warmup_fn(step): if step < 2000: return step / 2000 # LR gradually increases return 1.0 from torch.optim.lr_scheduler import LambdaLR def trapezoidal_lr_scheduler(optimizer, max_lr, total_steps, warmup_steps, plateau_steps, decay_steps): """ Trapezoidal learning rate scheduler: - Increases linearly for `warmup_steps` steps. - Remains constant for `plateau_steps` steps. - Decreases linearly for `decay_steps` steps. """ def lr_lambda(step): if step < warmup_steps: # Linear warmup return float(step) / float(max(1, warmup_steps)) elif step < warmup_steps + plateau_steps: # Constant plateau return 1.0 else: # Linear decay decay_step = step - (warmup_steps + plateau_steps) return max(0.0, float(decay_steps - decay_step) / float(max(1, decay_steps))) return LambdaLR(optimizer, lr_lambda) torch.set_float32_matmul_precision('high') def train(): setup() device = int(os.environ["LOCAL_RANK"]) torch.cuda.set_device(int(device)) # train_dataloader = prepare_dataset(ModelArgs.batch_size) # rank = torch.distributed.get_rank() print(f"Start running DDP on rank {device}.") # # create model and move it to GPU with id rank # device_id = rank % torch.cuda.device_count() # CFG = ModelArgs() if(device == 0): # # Initialise run wandb.init( # entity = 'rajceo2031', project = 'Llama-DDP-Pretrain-10-billion-tokens', # config = CFG, # save_code = True, #group = 'ANN', #job_type = 'train' ) model = Llama(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=device) # Optimizer setup and scheduler steup model = model.to(device) print(f"Model on device {device} is ready") # Wrap model with DDP after moving to GPU # model = DDP(model, device_ids=[device]) optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim) # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=4000, T_mult=1, eta_min=1e-5) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30000, eta_min=1e-6) _load_snapshot('/kaggle/input/models/snapshot2.pt', model, optimizer, scheduler) optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim) # model = torch.compile(model) # Define the trapezoidal learning rate scheduler total_steps = 100000 # Total steps (40k + 20k + 40k) warmup_steps = 40000 # Steps for warmup (increase) plateau_steps = 20000 # Steps for plateau (constant) decay_steps = 40000 # Steps for decay (decrease) # new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25000, eta_min=1e-6) #with the prev optim snapshot new_scheduler = trapezoidal_lr_scheduler(optimizer, ModelArgs.max_lr, total_steps, warmup_steps, plateau_steps, decay_steps) # warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup_fn) # new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20000, eta_min=1e-6) # Cosine decay after warmup # new_scheduler = CosineAnnealingLR(optimizer, T_max=20000, eta_min=1e-6) # Combine both schedulers # scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, new_scheduler], milestones=[2000]) # Reset learning rate to 1e-4 # for param_group in optimizer.param_groups: # param_group['lr'] = ModelArgs.max_lr # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2000, T_mult=1, eta_min=1e-6) # print("Old optimizer with new lr ready") model = DDP(model, device_ids=[device]) print(f"Model on device {device} is ready") # optimizer = torch.optim.AdamW(params=model.parameters(), lr=ModelArgs.max_lr) # Create DataLoader with collate_fn # train_loader = DataLoader(train_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True, num_replicas=int(os.environ["WORLD_SIZE"]), rank=device)) # val_loader = DataLoader(val_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True, num_replicas=int(os.environ["WORLD_SIZE"]), rank=device)) # print("Loader is ready") # print(train_loader) # print(next(iter(train_loader))) save_chechpoint_iter = 1000 total_iters = 20000 eval_iters = 200 eval_check = 100 # for X,y in train_loader: # print(X.shape) # print(y.shape) # alpaca_prompt = ''' # ### Instruction: # {instruction} # ### Input: # {input} # ### Response: # ''' # Only create progress bar for rank 0 # eval_epoch_iterator = range(eval_iters) # train_epoch_iterator = range(total_iters) # if device == 0: # train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training") # train_epoch_iterator = range(ModelArgs.epochs) # if device == 0: # Ensure tqdm only runs on rank 0 # train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training Progress", position=0, leave=True) # lr_scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max= total_steps - initial_iters) world_size = torch.cuda.device_count() @torch.inference_mode() def estimate_loss(val_loader, train_loader=None): out = {} # train_loader = prepare_dataset('train', ModelArgs.batch_size) model.eval() loader = None epoch_loss = None epoch_losses = [] # print("Starting the eval...") for split in ['train', 'val']: print(f"Starting with {split} evaluation...") # losses = torch.zeros(ModelArgs.val_epochs) if(split == 'train'): loader = train_loader if(split == 'val'): loader = val_loader for step in range(eval_check): total_loss = 0 # loader.sampler.set_epoch(step) total_batches = 0 batch = next(iter(loader)) # for batch in loader: # Loop through DataLoader batches idx = batch['input_ids'] targets = batch['labels']['input_ids'] idx = idx.to(device) targets = targets.to(device) logits = model(idx) batch_size, block_size, embeddings_dims = logits.shape logits = logits.view(batch_size * block_size, embeddings_dims) # Flatten tokens targets = targets.view(batch_size * block_size) loss = F.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id) total_loss += loss.item() total_batches += 1 # Compute mean loss for this epoch epoch_loss = total_loss / total_batches if total_batches > 0 else 0.0 epoch_losses.append(epoch_loss) # print(f"Epoch {epoch + 1}/{ModelArgs.val_epochs}: Loss = {epoch_loss:.4f}") # Compute mean loss across all evaluation epochs out[split] = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0.0 epoch_loss = None epoch_losses = [] model.train() return out # model = model.to(rank) model.train() train_dataloader = prepare_dataset('train', ModelArgs.batch_size) val_loader= prepare_dataset('val', ModelArgs.batch_size) # for step in tqdm(range(total_iters)): for epoch in range(ModelArgs.epochs): # torch.cuda.synchronize() train_dataloader.sampler.set_epoch(epoch) val_loader.sampler.set_epoch(epoch) print("Loaders ready both") epochs = ModelArgs.epochs # train_step_iterator = range(len(train_dataloader)) # if device == 0: # Only create progress bar on rank 0 # train_step_iterator = tqdm(train_step_iterator, desc="Training Progress", position=0, leave=True) # Print progress on rank 0 train_loader_length = 0 if(device == 0): train_loader_length = len(train_dataloader) print("Total batches: ", train_loader_length) # print("Length of : ", len(train_dataloader)) # print("Length of val: ", len(val_loader)) for step, batch in enumerate(train_dataloader): # print("Dataloader things: ", batch) # print("Total batches: ", len(train_dataloader)) if(device == 0): if(step % 100 == 0): # if(step == train_loader_length): # break print("Batch : ", step, "/", len(train_dataloader)) # all_gpus_avg_train_loss = None # all_gpus_avg_val_loss = None # every once in a while evaluate the loss on train and val sets if (step % eval_iters == 0 and step != 0) or step == total_iters - 1: losses = estimate_loss( val_loader, train_dataloader) avg_train_loss = losses['train'] avg_val_loss = losses['val'] # print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") # if device == 0: # Only print on main process print(f"[GPU {device}] | Epoch {epoch}/{ModelArgs.epochs}| |Step: {step} | Train Loss: {losses['train']:.4f} | Val Loss: {losses['val']:.4f}") # print(f"[GPU {device}] | Epoch {epoch}/{ModelArgs.epochs}| |Step: {step} | Train Loss: {losses['train']:.4f}") # print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") # Log training loss more frequently # Aggregate average loss across all GPUs avg_train_loss = torch.Tensor([losses['train']]).to(device) avg_val_loss = torch.Tensor([losses['val']]).to(device) torch.distributed.reduce(avg_train_loss, dst=0, op=torch.distributed.ReduceOp.SUM) torch.distributed.reduce(avg_val_loss, dst=0, op=torch.distributed.ReduceOp.SUM) if device == 0: all_gpus_avg_train_loss = avg_train_loss / world_size print(f"All_GPUs_Train_losses: {all_gpus_avg_train_loss.item():.4f}") all_gpus_avg_val_loss = avg_val_loss / world_size print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss.item():.4f}") # if device == 0: # writer.add_scalar("All_GPUs_Train_losses", all_gpus_avg_train_loss.item(), global_step=step) # writer.add_scalar("All_GPUs_Val_losses", all_gpus_avg_val_loss.item(), global_step=step) # writer.add_scalar("training_step_loss", losses['train'], global_step=step) # writer.add_scalar("val_step_loss", losses['val'], global_step=step) # writer.add_scalar("GPU", device, global_step=step) # writer.add_scalar("Epoch", epoch, global_step=step) wandb.log({ "Learning Rate": new_scheduler.get_last_lr()[0] , "All_GPUs_Train_losses": all_gpus_avg_train_loss, "All_GPUs_Val_losses": all_gpus_avg_val_loss, "training_step_loss": losses['train'], "val_step_loss": losses['val'], "Step": step, "Epoch": epoch }) #Loading a checkpoint # if(os.path.exists('snapshot.pt')): # model, optimizer = _load_snapshot(model=model, optimizer=optimizer, epoch=epoch, step=step, snapshot_path='snapshot.pt') # if(step % save_chechpoint_iter == 0 and device == 0 and step != 0): # _save_snapshot(epoch=epoch, model=model, optimizer=optimizer, step=step) if step % save_chechpoint_iter == 0 and device == 0 and step != 0: print(f"Saving the model checkpoint for step: {step}") _save_snapshot(model, optimizer, scheduler, epoch, step) # batch = {k: v.to(self.local_rank) for k, v in batch.items()} idx = batch['input_ids'].to(device) # idx, targets = get_batch(split='train') # print(f"Starting the train step: {step}...") # for idx, targets in train_loader: # idx, targets = next(iter(train_loader)) # print("Idx: ", idx) # print("Targets: ", targets) # idx = idx.to(device) # print("Idx: ", idx) # print("Targets: ", targets) targets = batch['labels']['input_ids'].to(device) # with torch.autocast(device_type=device, dtype=torch.bfloat16()): logits = model(idx) batch_size, block_size, embeddings_dims = logits.shape # print(logits.shape) # print(targets) logits = logits.view(batch_size*block_size, embeddings_dims) # print("OK") targets = targets.view(batch_size * block_size) # print("OK2") loss = nn.functional.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id) optimizer.zero_grad(set_to_none=True) loss.backward() # Compute gradient norms before clipping total_norm_before = torch.norm( torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2 ) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=ModelArgs.clip) # Compute gradient norms after clipping total_norm_after = torch.norm( torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2 ) if(device == 0 and step !=0 and step % 100 == 0): print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}") print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}") optimizer.step() new_scheduler.step() # torch.cuda.synchronize() # print(loss.item()) # if(step % 100 == 0): # print(f'Step : {step} | GPU: {device} Loss: {loss.item()}') # if device == 0: # print("loss: ", loss.item()) # train_epoch_iterator.set_postfix({"loss": f"{loss.item():.4f}"}) # print(loss.item()) # break # if step != 0 and (step % eval_iters == 0 or step == total_steps -1) : # loss_values = estimate_loss() # print("Train Loss at {} steps : {}".format(step, loss.item()), "Val Loss at {} steps : {}".format(step, loss_values['val'])) # Add after a training step: # unused_params = find_unused_parameters(model) # print("Unused parameters:", unused_params) # break # if device == 0 and step % 200 == 0 and step != 0: # count = 5 # while(count): # Only generate text on the main process # print("Generating text...") # alpaca_prompt = ''' # ### Instruction: # {} # ### Input: # {} # ### Response: # ''' # prompt = alpaca_prompt.format("You are a helpful assistant.", "Say a joke.", "") # generated_text = greedy_decode( # model, # tokenizer, # prompt, # max_length=60, # repetition_penalty=1.2, # context_window=10, # temperature=0.7 # Lower temperature for more deterministic output # ) # # generated_text = beam_search(model, tokenizer, prompt, beam_width=5, max_length=50, temperature=1.0) # print(f" Step: {step} | Generated Text: {generated_text}") # save_to_file(generated_text) # count -= 1 # if step != 0: # train_step_iterator.set_postfix({"Train loss": f"{all_gpus_avg_train_loss.item():.4f} | Val Loss : {all_gpus_avg_val_loss.item():.4f}"}) # break # Cleanup if device == 0: # writer.close() wandb.finish() cleanup() world_size = torch.cuda.device_count() print(f"World size: {world_size}") train()