Spaces:
Paused
Paused
# 185860 | |
import random | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import math | |
from dataclasses import dataclass | |
# from torchtune.modules import RMSNorm | |
from tokenizers import Tokenizer | |
from pathlib import Path | |
import torch.multiprocessing as mp | |
from torch.utils.data.distributed import DistributedSampler | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
from torch.distributed import init_process_group, destroy_process_group | |
import torch | |
from datasets import Dataset | |
from torch.utils.data import DataLoader | |
from transformers.models.prophetnet.modeling_prophetnet import ProphetNetDecoderModelOutput | |
import wandb | |
from tqdm import tqdm | |
from functools import partial | |
import torch.optim as optim | |
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts | |
# Load model directly | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import os | |
# import wandb | |
# wandb.login() | |
# from torch.utils.tensorboard import SummaryWriter | |
from datasets import load_dataset, concatenate_datasets | |
# use name="sample-10BT" to use the 10BT sample | |
# fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False) | |
# print(fw_train) | |
# Select only 1000 rows from the dataset | |
# fw_train = fw_train.select(range(1000000)) | |
# alpaca = load_dataset("yahma/alpaca-cleaned", split='train') | |
# dolly = load_dataset("llm-wizard/dolly-15k-instruction-alpaca-format", split='train') | |
# merged_dataset = concatenate_datasets([alpaca, dolly]) | |
# Load the instruction dataset and hold out 10% of it for validation.
# The split is seeded on purpose: train() reads LOCAL_RANK and joins a process
# group, so this module-level code presumably runs once per rank. An unseeded
# train_test_split shuffles differently in every process, which would give
# each rank a different train/test partition (validation rows leaking into
# another rank's training set, and DistributedSampler indexing different data
# per rank).
dataset = load_dataset("swype/instruct", split='train', trust_remote_code=True)
merged_dataset = dataset.train_test_split(test_size=0.1, seed=42)
print(merged_dataset)
# fw_train = fw_train.train_test_split(test_size=0.2) | |
# print(fw_train) | |
# Access the splits | |
# train_dataset = train_val_split['train'] | |
# val_dataset = train_val_split['test'] | |
# train_dataset = fw_train.train_test_split(test_size=0.2) | |
def setup(rank=None, world_size=None):
    """Create the default distributed process group using the NCCL backend.

    ``rank`` and ``world_size`` are accepted but not used; only the backend
    name is passed to ``init_process_group``.
    """
    init_process_group(backend="nccl")
def cleanup():
    """Tear down the default distributed process group."""
    destroy_process_group()
class ModelArgs:
    """Hyperparameters for the model and the training run.

    Values are read directly as class attributes (``ModelArgs.block_size``);
    several of them are captured as default arguments by the model classes
    when those classes are defined.
    """

    # --- Training loop ---
    epochs = 5
    val_epochs = 2
    batch_size = 64
    device = 'cuda'

    # --- Optimiser ---
    max_lr = 2e-4
    weight_decay_optim = 0.1
    beta_1 = 0.9
    beta_2 = 0.95
    clip = 1.0

    # --- Architecture ---
    block_size = 128  # sequence length used as the tokenizer's max_length
    # NOTE(review): 786 may be a typo for 768; it still divides evenly by
    # no_of_heads (6), so nothing breaks. Changing it invalidates checkpoints.
    embeddings_dims = 786
    no_of_heads = 6  # IMP needs to be thoroughly calculated
    no_kv_heads = 2
    no_of_decoder_layers = 6  # IMP needs to be thoroughly calculated
    # presumably the GPT-2 vocabulary plus the added '[PAD]' token -- confirm
    vocab_size = 50258

    # --- Regularisation ---
    attn_dropout = 0.1
    dropout = 0.1
from pathlib import Path | |
# Local data directory, relative to the working directory.
# It is created at import time if it does not exist yet.
data_path = Path('data')
data_path.mkdir(parents=False, exist_ok=True)
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt | |
# !cp input.txt data/input.txt | |
#Datasets | |
# Using tinyshakespeare | |
# with open('data/input.txt', 'r', encoding='utf-8') as f: | |
# text = f.read() | |
# Load the tokenizer | |
# tokenizer = Tokenizer.from_file("bpe_tokenizer_30k.json") | |
# Encode and decode functions | |
# encode = lambda s: tokenizer.encode(s).ids | |
# decode = lambda l: tokenizer.decode(l) | |
def _save_snapshot(model, optimizer, scheduler, epoch, step): | |
snapshot = { | |
"MODEL_STATE": model.module.state_dict(), | |
"OPTIMIZER_STATE": optimizer.state_dict(), | |
"SCHEDULER_STATE": scheduler.state_dict(), # NEW: Save scheduler state | |
"EPOCHS_RUN": epoch, | |
"STEP_RUN": step | |
} | |
torch.save(snapshot, "/kaggle/working/snapshot_fine_tuned_model_with_gradient_clipping_3.pt") | |
print(f"Epoch: {epoch} | Step: {step} | Snapshot saved.") | |
def _load_snapshot(snapshot_path, model, optimizer, scheduler): | |
snapshot = torch.load(snapshot_path) | |
model.load_state_dict(snapshot["MODEL_STATE"]) | |
# optimizer.load_state_dict(snapshot["OPTIMIZER_STATE"]) | |
# scheduler.load_state_dict(snapshot["SCHEDULER_STATE"]) # Load scheduler state | |
epoch = snapshot["EPOCHS_RUN"] | |
step = snapshot["STEP_RUN"] | |
print(f"Resuming from Epoch {epoch}, Step {step}") | |
return epoch, step | |
#Subword level tokenization | |
#Loading custom trained BPE | |
# Load the tokenizer | |
# tokenizer = Tokenizer.from_file("data/bpe_tokenizer_tinyshakespeare_1k.json") | |
# vocab_size = tokenizer.get_vocab_size() | |
# Encode and decode functions | |
# encode = lambda s: tokenizer.encode(s).ids | |
# decode = lambda l: tokenizer.decode(l) | |
############################################################################### | |
#Character level tokenization | |
# # here are all the unique characters that occur in this text | |
# chars = sorted(list(set(text))) | |
# vocab_size = len(chars) | |
# # create a mapping from characters to integers | |
# stoi = { ch: i for i,ch in enumerate(chars) } | |
# itos = { i:ch for i,ch in enumerate(chars) } | |
# encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers | |
# decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string | |
# Convert the dataset to Hugging Face Dataset format | |
# train_hf_dataset = Dataset.from_dict({"text": train_dataset['train']['text']}) | |
# val_hf_dataset = Dataset.from_dict({"text": train_dataset['test']['text']}) | |
# Tokenize the dataset using the `map` function | |
# from google.colab import userdata | |
# HF_TOKEN = userdata.get('HF_TOKEN') | |
# SECURITY: a Hugging Face access token used to be hard-coded here. Secrets
# must not live in source control; that token should be treated as leaked and
# revoked. The token is now taken from the HF_TOKEN environment variable
# (None when unset) and passed as `token`, replacing the previous `hf_token`
# kwarg.
tokenizer = AutoTokenizer.from_pretrained(
    "openai-community/gpt2",
    token=os.environ.get("HF_TOKEN"),
)
# Register a dedicated '[PAD]' token so that padding='max_length' can be used.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# print("ADDED THE TOKENS: ", tokenizer.pad_token_id) | |
# tokenizer.bos_token = "[INST]" | |
# tokenizer.eos_token = "[/INST]" | |
# model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") | |
def tokenize_function(examples):
    """Tokenize ``examples['text']`` with the module-level tokenizer.

    Sequences are padded and truncated to ``ModelArgs.block_size`` and
    returned as PyTorch tensors.
    """
    encode_options = dict(
        max_length=ModelArgs.block_size,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    return tokenizer(examples['text'], **encode_options)
## Load the tokenizer | |
# tokenizer = Tokenizer.from_file("bpe_tokenizer_30k.json") | |
# # Tokenization functions | |
# def encode_train(examples): | |
# tokens = [] | |
# for example in examples['text']: | |
# out = tokenizer.encode(example).ids | |
# tokens.append(out) # Append the tokenized sequence (do not flatten) | |
# return {"tokens": tokens} | |
# def encode_val(examples): | |
# tokens = [] | |
# for example in examples['text']: | |
# out = tokenizer.encode(example).ids | |
# tokens.append(out) # Append the tokenized sequence (do not flatten) | |
# return {"tokens": tokens} | |
# Apply tokenization with batching | |
# train_data = train_dataset['train'].map(tokenize_function, batched=True, batch_size=8000, remove_columns=['id', 'dump', 'url', 'date', 'file_path', 'language', 'language_score', 'token_count'], num_proc=8) | |
# val_data = train_dataset['test'].map(tokenize_function, batched=True, batch_size=8000, remove_columns=['id', 'dump', 'url', 'date', 'file_path', 'language', 'language_score', 'token_count'], num_proc=8) | |
# # # Extract tokens from the processed datasets | |
# # train_tokens = train_data['tokens'] | |
# # val_tokens = val_data['tokens'] | |
# # Flatten the tokenized data | |
# # train_tokens = [token_id for seq in train_data['input_ids'] for token_id in seq] | |
# # val_tokens = [token_id for seq in val_data['input_ids'] for token_id in seq] | |
# try: | |
# train_tensors = [torch.tensor(seq) for seq in tqdm(train_data['input_ids'], desc="Converting train_data to tensors")] | |
# train_data_tensor = torch.cat(train_tensors) | |
# except Exception as e: | |
# print(f"Error during tensor conversion: {e}") | |
# try: | |
# train_tensors = [torch.tensor(seq) for seq in tqdm(val_data['input_ids'], desc="Converting train_data to tensors")] | |
# val_data_tensor = torch.cat(train_tensors) | |
# except Exception as e: | |
# print(f"Error during tensor conversion: {e}") | |
# print("Train tokens count: ", train_data_tensor) | |
# print("Val tokens count: ", val_data_tensor) | |
def prepare_dataset(split, batch_size):
    """Build a distributed ``DataLoader`` over one split of ``merged_dataset``.

    Args:
        split: ``'train'`` (uses ``merged_dataset['train']``) or ``'val'``
            (uses ``merged_dataset['test']``).
        batch_size: Per-process batch size.

    Returns:
        A ``DataLoader`` that shards the split with a shuffling
        ``DistributedSampler``, drops the last incomplete batch and tokenizes
        each batch with ``collate_fn``.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``. The
            original code initialised ``dataloader`` but returned
            ``data_loader``, so this case ended in an UnboundLocalError.
    """
    # Public split name -> key inside ``merged_dataset``.
    split_keys = {'train': 'train', 'val': 'test'}
    if split not in split_keys:
        raise ValueError(f"Unknown split {split!r}; expected 'train' or 'val'.")

    def collate_fn(batch):
        """Tokenize the prompts as inputs and the completions as labels."""
        texts = [item['prompt'] for item in batch]
        outputs = [item['completion'] for item in batch]

        input_encodings = tokenizer(
            texts,
            max_length=ModelArgs.block_size,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )
        # NOTE(review): "labels" receives the tokenizer's whole return value
        # (not just its input_ids). Kept as-is because the consumer of these
        # batches is outside this view -- confirm what it expects.
        input_encodings["labels"] = tokenizer(
            outputs,
            max_length=ModelArgs.block_size,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )
        return input_encodings

    subset = merged_dataset[split_keys[split]]
    # Shuffling is done by the sampler, so the loader's own shuffle stays off.
    return DataLoader(
        subset,
        batch_size=batch_size,
        sampler=DistributedSampler(subset, shuffle=True),
        collate_fn=collate_fn,
        drop_last=True,
        shuffle=False,
    )
# Convert to tensors | |
# train_data_tensor = torch.tensor(train_tokens, dtype=torch.long) | |
# val_data_tensor = torch.tensor(val_tokens, dtype=torch.long) | |
# # Debug output | |
# print("Number of train tokens:", len(train_data_tensor)) | |
# print("Number of validation tokens:", len(val_data_tensor)) | |
# def create_sequences(data, block_size): | |
# sequences = [] | |
# for seq in data: | |
# if len(seq) < block_size: | |
# # while(len(sequence) < block_size): | |
# # sequence = data[i:i + block_size + 1] | |
# # Pad the sequence if it's shorter than block_size | |
# padding_length = block_size - len(seq) | |
# seq = torch.cat([seq, torch.full((padding_length,), tokenizer.pad_token_id, dtype=torch.long)]) | |
# sequences.append(seq) | |
# out = torch.tensor(sequences, dtype=torch.long) | |
# return out | |
# train_data = create_sequences(train_data['input_ids'], ModelArgs.block_size) | |
# val_data = create_sequences(val_data['input_ids'], ModelArgs.block_size) | |
def get_batch(split):
    """Sample a random batch of input/target windows from a token sequence.

    Reads the module-level ``train_data`` (when ``split == 'train'``) or
    ``val_data`` (any other value). NOTE(review): neither name is assigned in
    the visible part of this file -- confirm they exist before calling.

    Returns:
        ``(x, y)`` moved to ``ModelArgs.device``. Each holds
        ``ModelArgs.batch_size`` windows of ``ModelArgs.block_size`` items,
        with ``y`` shifted one position to the right of ``x``.
    """
    source = val_data if split != 'train' else train_data
    window = ModelArgs.block_size
    starts = torch.randint(len(source) - window, (ModelArgs.batch_size,))
    inputs = torch.stack([source[s:s + window] for s in starts])
    targets = torch.stack([source[s + 1:s + window + 1] for s in starts])
    return inputs.to(ModelArgs.device), targets.to(ModelArgs.device)
from torch.utils.data import Dataset | |
class TokenDataset(Dataset):
    """Sliding-window dataset over a flat token sequence.

    Item ``idx`` is the pair ``(data[idx:idx+block_size],
    data[idx+1:idx+block_size+1])``: an input window and the same window
    shifted one token to the right.
    """

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # Each item needs block_size + 1 tokens. Clamp at zero because a
        # negative value makes len() raise ValueError when the data is
        # shorter than one block.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.block_size]
        y = self.data[idx + 1:idx + self.block_size + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
# train_rows = 11895089 | |
# encoded_data = torch.tensor(encode(fw_train['text']), dtype=torch.long) | |
# train_data = train_data[:train_rows] | |
# val_data = val_data[train_rows:] | |
# train_dataset = TokenDataset(train_data_tensor, ModelArgs.block_size) | |
# val_dataset = TokenDataset(val_data_tensor, ModelArgs.block_size) | |
# encoded_data = torch.tensor(encode(text), dtype=torch.long) | |
# print(train_data) | |
# print(val_data) | |
# train_dataset = TextDataset(train_data, ModelArgs.block_size) | |
# val_dataset = TextDataset(val_data, ModelArgs.block_size) | |
# print(train_dataset) | |
# print(val_dataset) | |
# # Convert the tokenized data into a list of sequences | |
# train_sequences = [train_data[i:i + ModelArgs.block_size] for i in range(0, len(train_data) - ModelArgs.block_size)] | |
# val_sequences = [val_data[i:i + ModelArgs.block_size] for i in range(0, len(val_data) - ModelArgs.block_size)] | |
# Define collate_fn | |
# def collate_fn(batch): | |
# block_size = ModelArgs.block_size | |
# batch_size = len(batch) | |
# x = torch.zeros((batch_size, block_size), dtype=torch.long) | |
# y = torch.zeros((batch_size, block_size), dtype=torch.long) | |
# for i, sequence in enumerate(batch): | |
# print("Shape x: ", sequence[:-1].shape) | |
# print("Shape of y: ", len(sequence[1:])) | |
# x[i] = sequence[:-1] # Input is all tokens except the last one | |
# y[i] = sequence[1:] # Target is all tokens except the first one | |
# return x, y | |
def create_sequences(data, block_size, pad_token_id=None):
    """Pad or truncate every sequence to ``block_size`` and stack the result.

    Args:
        data: Iterable of token-id sequences (lists or 1-D tensors).
        block_size: Target length of every output row.
        pad_token_id: Id used for right-padding. Defaults to the module-level
            tokenizer's ``pad_token_id``.

    Returns:
        A ``torch.long`` tensor of shape ``(len(data), block_size)``; shape
        ``(0, block_size)`` when ``data`` is empty.
    """
    if pad_token_id is None:
        # The original used ``tokenizer.encode('[PAD]').ids[0]``, which is the
        # `tokenizers` library API. The Hugging Face tokenizer used in this
        # file returns a plain list from encode(), so ``.ids`` would fail.
        pad_token_id = tokenizer.pad_token_id

    sequences = []
    for seq in data:
        seq = torch.as_tensor(seq, dtype=torch.long)
        missing = block_size - len(seq)
        if missing > 0:
            padding = torch.full((missing,), pad_token_id, dtype=torch.long, device=seq.device)
            seq = torch.cat([seq, padding])
        else:
            seq = seq[:block_size]
        sequences.append(seq)

    if not sequences:
        return torch.empty((0, block_size), dtype=torch.long)
    # torch.tensor() cannot build a 2-D tensor from a list of 1-D tensors.
    return torch.stack(sequences)
# train_data = create_sequences(train_data_flat['input_ids'], ModelArgs.block_size) | |
# val_data = create_sequences(val_data['input_ids'], ModelArgs.block_size) | |
# Define collate_fn | |
def collate_fn(split, batch):
    """Sample random input/target windows from a module-level token tensor.

    ``batch`` is accepted for collate-function compatibility but is not used:
    the windows are drawn at random from ``train_data_tensor``
    (``split == 'train'``) or ``val_data_tensor`` (``split == 'test'``).
    NOTE(review): neither tensor is assigned in the visible part of this file
    -- confirm they exist before calling.

    Returns:
        ``(x, y)``: ``ModelArgs.batch_size`` stacked windows of
        ``ModelArgs.block_size`` tokens, with ``y`` shifted right by one.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``
            (previously an UnboundLocalError on ``data``).
    """
    if split == 'train':
        data = train_data_tensor
    elif split == 'test':
        data = val_data_tensor
    else:
        raise ValueError(f"Unknown split {split!r}; expected 'train' or 'test'.")

    window = ModelArgs.block_size
    ix = torch.randint(len(data) - window, (ModelArgs.batch_size,))
    x = torch.stack([data[i:i + window] for i in ix])
    y = torch.stack([data[i + 1:i + window + 1] for i in ix])
    return x, y
class Normalization(nn.Module):
    """Thin wrapper around ``torch.nn.RMSNorm`` over ``embeddings_dims``."""

    def __init__(self, embeddings_dims: int = ModelArgs.embeddings_dims):
        super().__init__()
        self.rmsnorm_layer = torch.nn.RMSNorm(normalized_shape=embeddings_dims)

    def forward(self, x):
        """Return ``x`` passed through the RMSNorm layer."""
        return self.rmsnorm_layer(x)
# import numpy as np | |
class RotaryEmbeddings(nn.Module):
    """Builds a per-position ``(dim, dim)`` positional matrix.

    Calling the module with an integer sequence length returns a tensor of
    shape ``(seq_len, embeddings_dims, embeddings_dims)``. It holds no
    learnable parameters and rebuilds the matrix on every call.

    NOTE(review): this is not the standard RoPE rotation. The angle depends
    on the position only (not on the dimension index), and each 2x2 block is
    ``[[cos, cos], [-sin, sin]]`` rather than ``[[cos, -sin], [sin, cos]]``.
    It is reproduced exactly, because changing it would alter the behaviour
    of checkpoints trained with it.
    """

    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        batch_size: int = ModelArgs.batch_size
    ):
        super().__init__()
        self.device = device
        self.embeddings_dims = embeddings_dims
        self.block_size = block_size
        self.batch_size = batch_size
        self.theta = 0  # never updated by the code in this class

    def init_matrix(self, seq_len):
        """Build, store on ``self.matrix`` and return the positional matrix."""
        dim = self.embeddings_dims
        self.matrix = torch.zeros(
            (seq_len, dim, dim),
            dtype=torch.float32,
            requires_grad=False,
            device=self.device,
        )

        # Column vector of positions, shape (seq_len, 1), so the trig values
        # broadcast across all even/odd index pairs below.
        positions = torch.arange(seq_len, dtype=torch.float32, device=self.device).unsqueeze(1)
        theta = 10000 ** (-2 * (positions - 1) / self.embeddings_dims)
        angles = positions * theta
        cos_angles = torch.cos(angles)
        sin_angles = torch.sin(angles)

        even_indices = torch.arange(0, dim, 2, dtype=torch.int64, device=self.device)
        odd_indices = torch.arange(1, dim, 2, dtype=torch.int64, device=self.device)

        # The four assignments write to disjoint entries of the matrix.
        self.matrix[:, even_indices, even_indices] = cos_angles
        self.matrix[:, odd_indices, odd_indices] = sin_angles
        self.matrix[:, odd_indices, even_indices] = -sin_angles
        self.matrix[:, even_indices, odd_indices] = cos_angles
        return self.matrix

    def forward(self, x):
        """Return the positional matrix for sequence length ``x``."""
        differs_from_block = x > self.block_size or x < self.block_size
        seq_len = x if differs_from_block else self.block_size
        return self.init_matrix(seq_len)
class RotaryAttentionHead(nn.Module):
    """Single causal self-attention head using the ``RotaryEmbeddings`` matrix.

    Query, key and value projections all map ``embeddings_dims`` to
    ``embeddings_dims``. ``head_size`` is stored but not used by ``forward``.
    """

    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        no_of_heads: int = ModelArgs.no_of_heads,
        attn_dropout: int = ModelArgs.attn_dropout
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        projection = dict(
            in_features=embeddings_dims,
            out_features=embeddings_dims,
            bias=False,
            dtype=torch.float32,
            device=device,
        )
        self.query = nn.Linear(**projection)
        self.key = nn.Linear(**projection)
        self.value = nn.Linear(**projection)
        self.rotary_matrix = RotaryEmbeddings(embeddings_dims=embeddings_dims, device=device)
        self.dropout = nn.Dropout(p=attn_dropout)
        self.device = device

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq, dim)``; same shape out."""
        _, seq_len, _ = x.shape
        query = self.query(x)
        key = self.key(x)
        values = self.value(x)
        matrix = self.rotary_matrix(seq_len)

        causal = torch.tril(torch.ones((seq_len, seq_len), requires_grad=False, device=self.device))

        # (T, C, C) @ (T, C, B) -> (T, C, B), then back to (B, T, C).
        rotary_query = (matrix @ query.permute(1, 2, 0)).permute(2, 0, 1)
        rotary_key = (matrix @ key.permute(1, 2, 0)).permute(2, 0, 1)

        # (B, T, C) @ (B, C, T) -> (B, T, T) attention scores.
        scores = rotary_query @ rotary_key.transpose(-2, -1)
        scores = scores.masked_fill(causal == 0, float('-inf'))
        scores = scores / (torch.sqrt(torch.tensor(key.shape[-1])))
        attention = F.softmax(scores, dim=-1)
        return self.dropout(attention @ values)
class MQA(nn.Module):
    """Multi-query attention: several query projections share one key/value.

    The number of query projections is ``no_of_heads // no_of_kv_heads``.
    NOTE(review): ``no_of_kv_heads`` defaults to ``ModelArgs.no_of_heads``,
    so with default arguments there is exactly one query projection.
    ``linear_layer`` expects ``embeddings_dims`` inputs, which only matches
    the concatenated width in that one-projection case -- confirm before
    changing the defaults.
    """

    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        no_of_kv_heads: int = ModelArgs.no_of_heads,
        no_of_heads: int = ModelArgs.no_of_heads,
    ):
        super().__init__()
        self.no_of_kv_heads = no_of_kv_heads
        self.no_of_q_heads = no_of_heads // no_of_kv_heads
        self.head_size = embeddings_dims // self.no_of_q_heads
        self.rotary_matrix = RotaryEmbeddings(embeddings_dims=embeddings_dims, device=device)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, bias=False, device=device)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, bias=False, device=device)
        self.dropout = nn.Dropout(p=ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, bias=False, device=device)
        self.device = device
        query_projections = [
            nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, device=self.device)
            for _ in range(self.no_of_q_heads)
        ]
        self.multi_query = nn.ModuleList(query_projections)

    def scaled_dot_product(self, q, k, v, block_size, matrix):
        """Causal attention of ``q`` against ``k``/``v`` with ``matrix`` applied."""
        causal = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))

        # (T, C, C) @ (T, C, B) -> (T, C, B), then back to (B, T, C).
        rotary_query = (matrix @ q.permute(1, 2, 0)).permute(2, 0, 1)
        rotary_key = (matrix @ k.permute(1, 2, 0)).permute(2, 0, 1)

        # (B, T, C) @ (B, C, T) -> (B, T, T) attention scores.
        scores = rotary_query @ rotary_key.transpose(-2, -1)
        scores = scores.masked_fill(causal == 0, float('-inf'))
        scores = scores / (torch.sqrt(torch.tensor(k.shape[-1])))
        attention = F.softmax(scores, dim=-1)
        value = attention @ v
        # NOTE(review): the dropout result is computed and then discarded;
        # the un-dropped tensor is what gets returned. Kept exactly as in the
        # original -- decide whether to return it or remove the call.
        out = self.dropout(value)
        return value

    def forward(self, x):
        """Run every query projection, concatenate, project and apply dropout."""
        _, seq_len, _ = x.shape
        matrix = self.rotary_matrix(seq_len)
        key = self.key(x)
        values = self.value(x)
        per_query = [
            self.scaled_dot_product(query(x), key, values, seq_len, matrix)
            for query in self.multi_query
        ]
        merged = torch.cat(per_query, dim=-1)
        return self.dropout(self.linear_layer(merged))
class GQA(nn.Module):
    """Grouped-query attention built from ``no_of_kv_heads`` ``MQA`` groups.

    Each group sees the full input. The group outputs are concatenated on the
    last dimension, projected back to ``embeddings_dims`` and passed through
    dropout.
    """

    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        no_of_q_heads: int = ModelArgs.no_of_heads,
        no_of_kv_heads: int = ModelArgs.no_kv_heads
    ):
        super().__init__()
        self.no_of_kv_heads = no_of_kv_heads
        self.no_of_q_heads = no_of_q_heads
        self.dropout = nn.Dropout(p=ModelArgs.attn_dropout)
        # The input width is one embeddings_dims slice per MQA group.
        self.linear_layer = nn.Linear(
            in_features=embeddings_dims * self.no_of_kv_heads,
            out_features=embeddings_dims,
            dtype=torch.float32,
            bias=False,
            device=device,
        )
        self.device = device
        groups = [
            MQA(embeddings_dims=embeddings_dims, device=self.device, block_size=block_size)
            for _ in range(self.no_of_kv_heads)
        ]
        self.mqa = nn.ModuleList(groups)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, seq, dim)`` to the same shape."""
        group_outputs = [group(x) for group in self.mqa]
        merged = torch.cat(group_outputs, dim=-1)
        return self.dropout(self.linear_layer(merged))
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``.

    The constructor arguments are accepted for signature parity with the
    other modules in this file and are not used.
    """

    def __init__(
        self,
        device,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` gated element-wise by its own sigmoid."""
        gate = self.sig(x)
        return x * gate
class SWiGLU(nn.Module):
    """SwiGLU block: ``W3(swish(W1 x) * (W2 x))``.

    All three projections are bias-free and keep the width at
    ``embeddings_dims``.
    """

    def __init__(
        self,
        device,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
        projection = dict(
            in_features=embeddings_dims,
            out_features=embeddings_dims,
            bias=False,
            dtype=torch.float32,
            device=device,
        )
        self.linear_layer1 = nn.Linear(**projection)
        self.linear_layer2 = nn.Linear(**projection)
        self.linear_layer3 = nn.Linear(**projection)

    def forward(self, x):
        """Apply the gated projection to ``x``."""
        gate = self.swish(self.linear_layer1(x))
        candidate = self.linear_layer2(x)
        return self.linear_layer3(torch.mul(gate, candidate))
class FFN(nn.Module):
    """Feed-forward sub-layer: SwiGLU, then a linear layer, then dropout.

    ``vocab_size`` is accepted but not used.
    """

    def __init__(self,
                 device,
                 embeddings_dims: int = ModelArgs.embeddings_dims,
                 block_size: int = ModelArgs.block_size,
                 vocab_size: int = ModelArgs.vocab_size,
                 dropout = ModelArgs.dropout
                 ):
        super().__init__()
        self.linear_layer = nn.Linear(
            in_features=embeddings_dims,
            out_features=embeddings_dims,
            dtype=torch.float32,
            device=device,
        )
        self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return ``dropout(linear(swiglu(x)))``."""
        hidden = self.swiglue(x)
        projected = self.linear_layer(hidden)
        return self.dropout(projected)
class DecoderLayer(nn.Module):
    """One decoder block: attention and feed-forward, each with a residual.

    Normalisation is applied after each residual sum (post-norm).
    ``self.dropout`` is created but not used in ``forward``.
    """

    def __init__(self,
                 device,
                 embeddings_dims: int = ModelArgs.embeddings_dims,
                 dropout = ModelArgs.dropout,
                 block_size: int = ModelArgs.block_size,
                 vocab_size: int = ModelArgs.vocab_size,
                 ):
        super().__init__()
        self.feedforward_network = FFN(
            embeddings_dims=embeddings_dims,
            block_size=block_size,
            vocab_size=vocab_size,
            device=device,
        )
        self.gqa = GQA(
            embeddings_dims=embeddings_dims,
            block_size=block_size,
            no_of_kv_heads=ModelArgs.no_kv_heads,
            no_of_q_heads=ModelArgs.no_of_heads,
            device=device,
        )
        self.norm1 = Normalization(embeddings_dims=embeddings_dims)
        self.norm2 = Normalization(embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers to ``x``."""
        attended = self.norm1(x + self.gqa(x))
        return self.norm2(attended + self.feedforward_network(attended))
class Llama(nn.Module):
    """Decoder-only language model.

    Token embedding, dropout, a stack of ``DecoderLayer`` blocks, then a
    linear head producing ``vocab_size`` logits per position. No final
    normalisation is applied.
    """

    def __init__(self,
                 device,
                 embeddings_dims: int = ModelArgs.embeddings_dims,
                 no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
                 block_size: int = ModelArgs.block_size,
                 vocab_size: int = ModelArgs.vocab_size,
                 dropout = ModelArgs.dropout
                 ):
        super().__init__()
        self.embeddings = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embeddings_dims,
            dtype=torch.float32,
            device=device,
        )
        layers = [
            DecoderLayer(
                embeddings_dims=embeddings_dims,
                block_size=block_size,
                vocab_size=vocab_size,
                dropout=dropout,
                device=device,
            )
            for _ in range(no_of_decoder_layers)
        ]
        self.decoder = nn.Sequential(*layers)
        self.linear_layer = nn.Linear(
            in_features=embeddings_dims,
            out_features=vocab_size,
            dtype=torch.float32,
            device=device,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Map token ids ``x`` to per-position vocabulary logits."""
        hidden = self.dropout(self.embeddings(x))
        hidden = self.decoder(hidden)
        return self.linear_layer(hidden)
# device = "cuda" if torch.cuda.is_available() else "cpu" | |
# # device = "cpu" | |
# ModelArgs.device = device | |
# model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout) | |
# model = model.to(ModelArgs.device) | |
#Printing a summary of the architecture | |
# !pip install torchinfo | |
# from torchinfo import summary | |
# # idx, targets = get_batch('test') | |
# idx = torch.randint( | |
# low=0, | |
# high=ModelArgs.vocab_size, | |
# size=(ModelArgs.batch_size, ModelArgs.block_size), | |
# dtype=torch.long | |
# ) | |
# # sample_idx = random.randint(range(len(train_dataset))) | |
# # idx, targets = train_dataset[0] | |
# idx = idx.to(ModelArgs.device) | |
# # targets = targets.to(ModelArgs.device) | |
# summary(model=model, | |
# input_data=idx, | |
# # input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims), | |
# col_names=["input_size", "output_size", "num_params", "trainable"], | |
# col_width=20, | |
# row_settings=["var_names"]) | |
def find_unused_parameters(model):
    """Return the names of parameters whose ``.grad`` is ``None``.

    Names are listed in ``model.named_parameters()`` order.
    """
    return [name for name, param in model.named_parameters() if param.grad is None]
def greedy_decode(
    model,
    tokenizer,
    prompt,
    max_length=50,
    repetition_penalty=1.2,
    context_window=10,
    temperature=1.0,
    eos_token_id=None
):
    """Greedily generate up to ``max_length`` tokens that continue ``prompt``.

    Args:
        model: Callable mapping token ids ``(1, T)`` to logits ``(1, T, V)``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: Text to continue.
        max_length: Maximum number of new tokens.
        repetition_penalty: Penalty applied to tokens generated within the
            last ``context_window`` steps; ``1.0`` disables it.
        context_window: How many recent tokens the penalty looks at.
        temperature: Logit divisor; ``1.0`` leaves the logits unchanged.
        eos_token_id: Stop token; defaults to ``tokenizer.eos_token_id``.

    Returns:
        The decoded generated tokens, without the prompt.
    """
    device = next(model.parameters()).device
    input_ids = tokenizer(prompt, return_tensors="pt").to(device)['input_ids']
    generated_tokens = []
    # ``eos_token_id or ...`` would discard a legitimate id of 0.
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    # Inference only: skip autograd bookkeeping so memory does not grow with
    # every generated token.
    with torch.no_grad():
        for _ in range(max_length):
            # Clone so the in-place penalty never touches the model's output.
            logits = model(input_ids)[:, -1, :].clone()

            if temperature != 1.0:
                logits = logits / temperature

            if repetition_penalty != 1.0 and generated_tokens:
                recent = torch.tensor(
                    sorted(set(generated_tokens[-context_window:])),
                    dtype=torch.long,
                    device=logits.device,
                )
                scores = logits[0, recent]
                # Dividing a negative logit moves it toward zero and makes the
                # token MORE likely, so negative logits are multiplied instead.
                logits[0, recent] = torch.where(
                    scores > 0,
                    scores / repetition_penalty,
                    scores * repetition_penalty,
                )

            next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
            token_id = next_token.item()
            generated_tokens.append(token_id)

            if token_id == eos_token_id:
                break

            input_ids = torch.cat([input_ids, next_token], dim=1)

    return tokenizer.decode(generated_tokens, skip_special_tokens=True)
def save_to_file(text):
    """Append ``text`` plus a blank line to ``generations.txt``.

    The file is opened in the current working directory. UTF-8 is used
    explicitly so that non-ASCII text does not depend on the platform's
    default encoding.
    """
    with open('generations.txt', 'a', encoding='utf-8') as f:
        # write(), not writelines(): writelines() on a str iterates over it
        # and writes one character at a time.
        f.write(text + "\n\n")
#Train the model | |
# writer = SummaryWriter(log_dir="runs/experiment") | |
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR | |
# Warmup phase for 2000 steps | |
def warmup_fn(step):
    """Return the LR multiplier for ``step``: a linear ramp over 2000 steps.

    The multiplier is ``step / 2000`` below 2000 steps and ``1.0`` from then on.
    """
    warmup_steps = 2000
    return step / warmup_steps if step < warmup_steps else 1.0
from torch.optim.lr_scheduler import LambdaLR | |
def trapezoidal_lr_scheduler(optimizer, max_lr, total_steps, warmup_steps, plateau_steps, decay_steps):
    """
    Build a ``LambdaLR`` with a trapezoidal multiplier schedule:
    - linear ramp from 0 to 1 over the first `warmup_steps` steps,
    - constant 1 for the next `plateau_steps` steps,
    - linear decay from 1 to 0 over `decay_steps` steps, then 0.

    `max_lr` and `total_steps` are accepted but not used; the peak learning
    rate is whatever the optimizer was constructed with.
    """
    plateau_end = warmup_steps + plateau_steps

    def lr_lambda(step):
        # Warmup phase.
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        # Plateau phase.
        if step < plateau_end:
            return 1.0
        # Decay phase, clamped at zero once decay_steps have elapsed.
        elapsed = step - plateau_end
        return max(0.0, float(decay_steps - elapsed) / float(max(1, decay_steps)))

    return LambdaLR(optimizer, lr_lambda)
# Select the 'high' internal precision mode for float32 matrix multiplications
# process-wide; per the PyTorch docs this permits faster reduced-precision
# matmul kernels where the hardware provides them.
torch.set_float32_matmul_precision('high')
def train():
    """Run the DDP training loop on the GPU selected by ``LOCAL_RANK``.

    Steps performed:
      1. Call ``setup()`` and bind this process to its GPU.
      2. Build the ``Llama`` model and call ``_load_snapshot`` on
         ``/kaggle/input/models/snapshot2.pt`` (presumably restoring weights
         from it -- confirm against ``_load_snapshot``), then create a fresh
         AdamW optimizer driven by a trapezoidal LR schedule.
      3. Wrap the model in ``DistributedDataParallel`` and train for
         ``ModelArgs.epochs`` epochs, periodically evaluating, logging to
         wandb (local rank 0 only) and saving checkpoints (local rank 0 only).
      4. Finish the wandb run and call ``cleanup()``.

    Relies on names defined elsewhere in this file: ``setup``, ``cleanup``,
    ``Llama``, ``ModelArgs``, ``prepare_dataset``, ``tokenizer``,
    ``_load_snapshot`` and ``_save_snapshot``.
    """
    setup()
    device = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(device)
    print(f"Start running DDP on rank {device}.")

    if device == 0:
        # Only one process per node opens a wandb run.
        wandb.init(project='Llama-DDP-Pretrain-10-billion-tokens')

    model = Llama(
        embeddings_dims=ModelArgs.embeddings_dims,
        block_size=ModelArgs.block_size,
        vocab_size=ModelArgs.vocab_size,
        dropout=ModelArgs.dropout,
        device=device,
    )
    model = model.to(device)
    print(f"Model on device {device} is ready")

    # Optimizer/scheduler pair that is only handed to _load_snapshot; the
    # optimizer is replaced with a fresh one right after loading.
    optimizer = optim.AdamW(
        model.parameters(),
        lr=ModelArgs.max_lr,
        betas=(ModelArgs.beta_1, ModelArgs.beta_2),
        weight_decay=ModelArgs.weight_decay_optim,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30000, eta_min=1e-6)
    _load_snapshot('/kaggle/input/models/snapshot2.pt', model, optimizer, scheduler)

    # Fresh optimizer so training restarts at ModelArgs.max_lr.
    optimizer = optim.AdamW(
        model.parameters(),
        lr=ModelArgs.max_lr,
        betas=(ModelArgs.beta_1, ModelArgs.beta_2),
        weight_decay=ModelArgs.weight_decay_optim,
    )

    # Trapezoidal learning-rate schedule.
    total_steps = 100000   # Total steps (40k + 20k + 40k)
    warmup_steps = 40000   # Steps for warmup (increase)
    plateau_steps = 20000  # Steps for plateau (constant)
    decay_steps = 40000    # Steps for decay (decrease)
    new_scheduler = trapezoidal_lr_scheduler(
        optimizer, ModelArgs.max_lr, total_steps, warmup_steps, plateau_steps, decay_steps
    )

    # Wrap with DDP only after the model is on its GPU and weights are loaded.
    model = DDP(model, device_ids=[device])
    print(f"Model on device {device} is ready")

    save_checkpoint_iter = 1000
    total_iters = 20000
    eval_iters = 200   # evaluate every `eval_iters` training steps
    eval_check = 100   # batches scored per split in one evaluation

    # Number of processes taking part in the reduce below. This is the
    # process-group size, which need not equal the local GPU count.
    world_size = torch.distributed.get_world_size()

    def estimate_loss(val_loader, train_loader=None):
        """Return ``{'train': loss, 'val': loss}`` averaged over up to
        ``eval_check`` batches of each loader, with the model in eval mode."""
        out = {}
        model.eval()
        try:
            # Evaluation needs no gradients; skipping autograd saves memory.
            with torch.no_grad():
                for split, loader in (('train', train_loader), ('val', val_loader)):
                    print(f"Starting with {split} evaluation...")
                    batch_losses = []
                    # One iterator per split: rebuilding it for every batch
                    # would restart the loader each time and keep returning
                    # its first batch. zip() stops after `eval_check` batches
                    # or when the loader is exhausted, whichever comes first.
                    for _, batch in zip(range(eval_check), loader):
                        idx = batch['input_ids'].to(device)
                        targets = batch['labels']['input_ids'].to(device)
                        logits = model(idx)
                        batch_size, block_size, num_classes = logits.shape
                        loss = F.cross_entropy(
                            logits.view(batch_size * block_size, num_classes),  # flatten tokens
                            targets.view(batch_size * block_size),
                            ignore_index=tokenizer.pad_token_id,
                        )
                        batch_losses.append(loss.item())
                    out[split] = sum(batch_losses) / len(batch_losses) if batch_losses else 0.0
        finally:
            # Always hand the model back in training mode.
            model.train()
        return out

    model.train()
    train_dataloader = prepare_dataset('train', ModelArgs.batch_size)
    val_loader = prepare_dataset('val', ModelArgs.batch_size)

    for epoch in range(ModelArgs.epochs):
        # Tell both samplers which epoch this is before iterating.
        train_dataloader.sampler.set_epoch(epoch)
        val_loader.sampler.set_epoch(epoch)
        print("Loaders ready both")

        num_batches = len(train_dataloader)
        if device == 0:
            print("Total batches: ", num_batches)

        for step, batch in enumerate(train_dataloader):
            if device == 0 and step % 100 == 0:
                print("Batch : ", step, "/", num_batches)

            # Every once in a while evaluate the loss on train and val sets.
            # All ranks must enter this branch: it contains collectives.
            if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
                losses = estimate_loss(val_loader, train_dataloader)
                print(f"[GPU {device}] | Epoch {epoch}/{ModelArgs.epochs}| |Step: {step} | Train Loss: {losses['train']:.4f} | Val Loss: {losses['val']:.4f}")

                # Sum the per-rank averages onto rank 0, then divide there.
                avg_train_loss = torch.tensor([losses['train']]).to(device)
                avg_val_loss = torch.tensor([losses['val']]).to(device)
                torch.distributed.reduce(avg_train_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
                torch.distributed.reduce(avg_val_loss, dst=0, op=torch.distributed.ReduceOp.SUM)

                if device == 0:
                    all_gpus_avg_train_loss = (avg_train_loss / world_size).item()
                    print(f"All_GPUs_Train_losses: {all_gpus_avg_train_loss:.4f}")
                    all_gpus_avg_val_loss = (avg_val_loss / world_size).item()
                    print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss:.4f}")

                    # Log plain Python floats rather than CUDA tensors.
                    wandb.log({
                        "Learning Rate": new_scheduler.get_last_lr()[0],
                        "All_GPUs_Train_losses": all_gpus_avg_train_loss,
                        "All_GPUs_Val_losses": all_gpus_avg_val_loss,
                        "training_step_loss": losses['train'],
                        "val_step_loss": losses['val'],
                        "Step": step,
                        "Epoch": epoch,
                    })

            if step % save_checkpoint_iter == 0 and device == 0 and step != 0:
                print(f"Saving the model checkpoint for step: {step}")
                # Save the scheduler that is actually stepped below, not the
                # cosine scheduler that was only used while loading.
                _save_snapshot(model, optimizer, new_scheduler, epoch, step)

            # ---- one optimisation step ----
            idx = batch['input_ids'].to(device)
            targets = batch['labels']['input_ids'].to(device)

            logits = model(idx)
            batch_size, block_size, num_classes = logits.shape
            logits = logits.view(batch_size * block_size, num_classes)
            targets = targets.view(batch_size * block_size)
            loss = nn.functional.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()

            # clip_grad_norm_ returns the total gradient norm measured before
            # clipping, so no separate pass is needed for the "before" value.
            total_norm_before = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=ModelArgs.clip
            )

            if device == 0 and step != 0 and step % 100 == 0:
                # Only pay for the post-clip norm when it is printed; skip
                # parameters that received no gradient.
                total_norm_after = torch.norm(
                    torch.stack([
                        torch.norm(p.grad.detach(), 2)
                        for p in model.parameters()
                        if p.grad is not None
                    ]),
                    2,
                )
                print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
                print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")

            optimizer.step()
            new_scheduler.step()

    # Cleanup
    if device == 0:
        wandb.finish()
    cleanup()
# Launch training only when this file is executed as a script. Without the
# guard, merely importing the module would start a training run.
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"World size: {world_size}")
    train()