import os
import math
import numpy as np
import time
from dataclasses import dataclass
import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# import code; code.interact(local=locals())
from model import GPT
from dataloader import DataLoaderLite
from hellaswag_eval import render_example, iterate_examples, get_most_likely_row
torch.set_float32_matmul_precision('high') # enable TF32 precision
# set use_torch_compile to True (if it doesn't throw any errors) to speed up training
use_torch_compile = False
class Trainer:
def __init__(
self,
model,
optimizer,
train_loader,
val_loader,
token_encoder,
eval_freq,
grad_accum_steps,
ddp,
ddp_rank,
ddp_world_size,
device,
logpath
):
self.ddp = ddp
self.ddp_rank = ddp_rank
self.master_process = ddp_rank == 0
self.ddp_world_size = ddp_world_size
self.model = model
self.optimizer = optimizer
self.train_loader = train_loader
self.val_loader = val_loader
self.token_encoder = token_encoder
self.eval_freq = eval_freq
self.grad_accum_steps = grad_accum_steps
self.device = device
self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
self.logpath = logpath
def train(
self,
max_steps,
warmup_steps,
max_lr,
min_lr
):
for step in range(max_steps):
t0 = time.time()
self.is_last_step = (step == max_steps - 1)
# evaluate validation loss
if step % self.eval_freq == 0 or self.is_last_step:
self.evaluate_validation(step)
# evaluate model performance on HellaSwag every once in a while
if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
self.evaluate_hellaswag(step)
# generate sequences from the model every once in a while
if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
self.generate_sequences(num_seq=5, max_tokens=32)
# training loop starts here
self.model.train() # sets model to train mode
self.optimizer.zero_grad() # resets all gradients
batch_loss = 0.0
for mini_step in range(self.grad_accum_steps):
inp, tar = self.train_loader.next_batch()
inp, tar = inp.to(self.device), tar.to(self.device)
# FORWARD PASS !!!
# autocast to bfloat16 for faster compute and memory efficiency
with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
logits, loss = self.model(inp, tar)
# loss is scaled to account for gradient accumulation, because the gradients just add
# on each successive backward() call. Addition of gradients corresponds to SUM in the objective,
# but we want MEAN instead of a SUM
loss /= self.grad_accum_steps
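# e.g. with the defaults below (total_batch_size=524288, mini_batch_size=32, context_length=1024,
# single GPU) grad_accum_steps=16, so each micro-batch contributes loss/16 and the gradients summed
# over 16 backward() calls equal the gradient of the mean loss over the full 2^19-token batch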
batch_loss += loss.detach()
if self.ddp:
# sync and average gradients across all processes only on the final mini_step
# (the 'no_sync()' context manager is the official alternative; see the sketch below)
self.model.require_backward_grad_sync = (mini_step == self.grad_accum_steps - 1)
# each process accumulates gradients separately when 'require_backward_grad_sync'=False
# in the final 'mini_step', 'require_backward_grad_sync' becomes True, therefore
# gradients are averaged across all processes and shared among them by loss.backward()
loss.backward()
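# a minimal alternative sketch using DDP's official no_sync() context manager (assumes
# 'import contextlib'; equivalent to toggling 'require_backward_grad_sync' above):
#     sync_ctx = self.model.no_sync() if (self.ddp and mini_step < self.grad_accum_steps - 1) \
#         else contextlib.nullcontext()
#     with sync_ctx:
#         loss.backward()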
if self.ddp:
# 'batch_loss' lives outside the DDP container, so we need an explicit 'all_reduce' to
# average it across the processes of all ranks ('batch_loss' exists as a tensor on every GPU).
# 'all_reduce' with ReduceOp.AVG averages and deposits the result on all processes
dist.all_reduce(batch_loss, op=dist.ReduceOp.AVG)
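# e.g. if two ranks hold batch_loss values 1.0 and 3.0, ReduceOp.AVG leaves 2.0 on both ranks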
# once gradients are computed, clip the global l2-norm of the gradient at 1.0
norm = nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) # monitor/print 'norm'
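# note: clip_grad_norm_ returns the total gradient norm computed before clipping, so a spike in
# 'norm' flags an unstable batch even though the applied update is clipped to 1.0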
# determine learning rate with decay
lr = self.estimate_lr(step, warmup_steps, max_steps, max_lr, min_lr)
# set learning rate for this iteration
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.optimizer.step()
if self.device_type == 'cuda':
torch.cuda.synchronize() # wait for the GPU to finish work
dt = (time.time() - t0) * 1000.0 # in ms
tokens_processed = self.train_loader.B * self.train_loader.T * self.grad_accum_steps * self.ddp_world_size
tokens_per_sec = tokens_processed / (dt / 1000.0) # dt is in ms, convert to seconds
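# e.g. with the defaults (B=32, T=1024, grad_accum_steps=16, 1 GPU): 32*1024*16 = 524288 tokens
# processed per optimizer step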
if self.master_process:
print(f'step {step:4d} | loss: {batch_loss.item():.6f} | lr: {lr:.2e} | norm: {norm:.4f} | dt: {dt:.4f}ms | tok/sec: {tokens_per_sec:.4f}')
with open(self.logpath, 'a') as f:
f.write(f'{step} train {batch_loss.item():.6f}\n')
def evaluate_validation(self, step):
self.model.eval() # sets model to eval mode
self.val_loader.reset()
# evaluate the model on validation set
with torch.no_grad():
val_loss_accum = 0.0
val_steps = 20
for _ in range(val_steps):
inp, tar = self.val_loader.next_batch()
inp, tar = inp.to(self.device), tar.to(self.device)
with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
logits, loss = self.model(inp, tar)
loss /= val_steps
val_loss_accum += loss.detach()
if self.ddp:
dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
if self.master_process:
print(f'Val loss: {val_loss_accum.item():.4f}')
with open(self.logpath, 'a') as f:
f.write(f'{step} val {val_loss_accum.item():.4f}\n')
if step > 0 and (step % 10000 == 0 or self.is_last_step):
raw_model = self.model.module if self.ddp else self.model
logdir = os.path.dirname(self.logpath)
ckpt_path = os.path.join(logdir, f'model_{step:05d}.pt')
checkpoint = {
'model': raw_model.state_dict(),
'config': raw_model.config,
'step': step,
'val_loss': val_loss_accum.item()
} # add optimizer.state_dict(), rng_seeds, etc. if resuming training
torch.save(checkpoint, ckpt_path)
def evaluate_hellaswag(self, step):
"""
Evaluate the model on the HellaSwag validation set: each example is rendered
as 4 candidate completions of a shared context, and the completion with the
lowest average loss over its completion tokens (see get_most_likely_row) is
taken as the prediction.
"""
n_total = 0
n_correct_norm = 0
for i, example in enumerate(iterate_examples('val')):
# only process examples where i % ddp_world_size == ddp_rank
if i % self.ddp_world_size != self.ddp_rank:
continue
# render the example into tokens and labels
_, tokens, mask, label = render_example(example) # tokens: (4,N), mask: (4,N), label: int
tokens, mask = tokens.to(self.device), mask.to(self.device)
with torch.no_grad():
with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
logits, loss = self.model(tokens)
pred_norm = get_most_likely_row(tokens, mask, logits)
n_total += 1
n_correct_norm += int(pred_norm == label)
# reduce the stats across all processes
if self.ddp:
n_total = torch.tensor(n_total, device=self.device, dtype=torch.long)
n_correct_norm = torch.tensor(n_correct_norm, device=self.device, dtype=torch.long)
dist.all_reduce(n_total, op=dist.ReduceOp.SUM)
dist.all_reduce(n_correct_norm, op=dist.ReduceOp.SUM)
n_total = n_total.item()
n_correct_norm = n_correct_norm.item()
acc_norm = n_correct_norm / n_total
if self.master_process:
print(f'HellaSwag accuracy: {n_correct_norm}/{n_total}={acc_norm:.4f}')
with open(self.logpath, 'a') as f:
f.write(f'{step} hellaswag {acc_norm:.4f}\n')
def generate_sequences(self, num_seq=4, max_tokens=32):
self.model.eval()
tokens = self.token_encoder.encode("Hello, I am a language model")
tokens = torch.tensor(tokens, dtype=torch.long) # (n,) n : current sequence length
tokens = tokens.unsqueeze(0).repeat(num_seq, 1) # (1,n) --> (num_seq, n)
gen_tokens = tokens.to(self.device)
# create a different rng generator so as not to impact the global rng state used for training
sample_rng = torch.Generator(device=self.device)
# adding 'ddp_rank' in seeding to generate different tokens for different rank processes
sample_rng.manual_seed(42 + self.ddp_rank)
# generate new tokens one token at a time until the sequence length becomes 'max_tokens'
while gen_tokens.shape[-1] < max_tokens:
with torch.no_grad():
with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
logits, loss = self.model(gen_tokens) # (num_seq, n, vocab_size)
logits = logits[:, -1, :] # (num_seq, vocab_size)
probs = F.softmax(logits, dim=-1) # (num_seq, vocab_size)
# take top-k 50 probs
topk_probs, topk_indices = torch.topk(probs, 50, dim=-1) # (num_seq, 50), (num_seq, 50)
# sample a token from top-50 probabilities
ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng) # (num_seq, 1)
next_tok = torch.gather(topk_indices, -1, ix) # (num_seq, 1)
gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)
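# sampling from only the top-50 tokens (the HuggingFace pipeline default) truncates the long tail
# of the distribution, so very unlikely tokens are never emitted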
# decode generated tokens and print generated text
for i in range(num_seq):
tokens = gen_tokens[i, :max_tokens].tolist()
gen_text = self.token_encoder.decode(tokens)
print(f"> rank {self.ddp_rank} sample {i}: {gen_text}")
def estimate_lr(self, step, warmup_steps, max_steps, max_lr, min_lr):
"""
Learning-rate scheduler: cosine decay to min_lr with linear warmup
"""
# 1) linear warmup for 'warmup_steps' steps
if step < warmup_steps:
return max_lr * (step+1) / warmup_steps
# 2) if step > max_steps, return min lr
if step > max_steps:
return min_lr
# 3) in between, use cosine decay down to min lr
decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0
return min_lr + coeff * (max_lr - min_lr)
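# e.g. with the defaults below (max_lr=1e-3, min_lr=1e-4, warmup_steps=715): lr ramps linearly to
# 1e-3 by step 714, halfway through the cosine decay coeff=0.5 gives
# lr = 1e-4 + 0.5*(1e-3 - 1e-4) = 5.5e-4, and lr decays toward 1e-4 as step approaches max_steps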
@dataclass
class GPTConfig:
context_length: int = 1024 # max context / sequence length
vocab_size: int = 50257 # number of tokens: 50000 BPE merges + 256 bytes tokens + 1 <endoftext> token
num_layers: int = 12
embd_size: int = 768 # embedding dim
num_heads: int = 12
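# these defaults correspond to the GPT-2 small configuration (12 layers, 12 heads, 768-dim
# embeddings), roughly 124M parameters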
def get_args():
import argparse
parser = argparse.ArgumentParser(description="Hyperparameter Configuration")
parser.add_argument("--total_batch_size", type=int, default=524288, help="number of tokens processed for each weight update") # =2^19 tokens/step update, (~0.5M tokens used in openai gpt3 paper)
parser.add_argument("--mini_batch_size", type=int, default=32, help="setting of mini_batch_size is just a performance optimization. bigger gpu, bigger mini_batch_size")
parser.add_argument("--context_length", type=int, default=1024) # max sequence length (can also try 2048)
parser.add_argument("--num_layers", type=int, default=12)
parser.add_argument("--embd_size", type=int, default=768)
parser.add_argument("--num_heads", type=int, default=12)
parser.add_argument("--max_lr", type=float, default=1e-3)
parser.add_argument("--min_lr", type=float, default=1e-3 * 0.1)
parser.add_argument("--warmup_steps", type=int, default=715)
parser.add_argument("--weight_decay", type=float, default=0.1)
parser.add_argument("--num_epochs", type=int, default=5)
parser.add_argument("--steps_per_epoch", type=int, default=19073) # 10^10 / 2^19 ~ 19073 for 1 epoch on FineWebEdu-sample10BT
parser.add_argument("--eval_freq", type=int, default=250)
# parser.add_argument("--use_torch_compile", action='store_true') # default False
parser.add_argument("--seed", type=int, default=1337, help="Random seed for reproducibility")
parser.add_argument("--logdir", type=str, default="./logs/")
return parser.parse_args()
def main():
args = get_args()
# Print the hyperparameters
print("Hyperparameter Configuration:")
for key, value in vars(args).items():
print(f"{key}: {value}")
# create the logs directory if it doesn't exist
os.makedirs(args.logdir, exist_ok=True)
logpath = os.path.join(args.logdir, 'log.txt')
with open(logpath, 'w') as f:
pass
# set up DDP (distributed data parallel)
# 'torchrun' command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
# RANK and LOCAL_RANK are the same for single-node multi-GPU runs, but may differ for
# multi-node multi-GPU runs.
ddp = int(os.environ.get('RANK', -1)) != -1 # if this is a ddp run or not
if ddp:
# use of ddp requires CUDA
assert torch.cuda.is_available(), 'use of DDP requires CUDA'
dist.init_process_group(backend='nccl')
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
# master process (arbitrarily set to 0) will do printing, logging, checkpointing, etc.
master_process = ddp_rank == 0
else:
# not using ddp
ddp_rank = 0
ddp_local_rank = 0
ddp_world_size = 1
master_process = True # ddp_rank == 0
device = 'cpu'
if torch.cuda.is_available():
device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
device = 'mps' # for apple macbook GPUs
print(f'using device: {device}')
device_type = 'cuda' if device.startswith('cuda') else 'cpu'
# setting seed for reproducibility
np.random.seed(args.seed)
torch.manual_seed(args.seed) # sets seed for random number generation on CPU
if torch.cuda.is_available():
torch.cuda.manual_seed(args.seed) # sets seed for random number generation on GPU
torch.cuda.manual_seed_all(args.seed) # sets seed for all GPUs
assert args.total_batch_size % (args.mini_batch_size * args.context_length * ddp_world_size) == 0, f'ensure total_batch_size divisible by B*T*ddp_world_size'
grad_accum_steps = args.total_batch_size // (args.mini_batch_size * args.context_length * ddp_world_size)
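# e.g. with the defaults: 524288 // (32 * 1024 * 1) = 16 accumulation steps on a single GPU,
# or 2 steps when training on 8 GPUs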
if master_process:
print(f'desired batch size (number of tokens): {args.total_batch_size}')
print(f'gradient accumulation steps: {grad_accum_steps}')
print(f'GPU: {ddp_rank}, {ddp_local_rank}')
train_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='train')
val_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='val')
# create the GPT model. each ddp process creates its own instance of the model, but since the
# seed is fixed, they all create identical models
gpt_config = GPTConfig(vocab_size=50304, # 50304 (nice number, lots of power of 2s) used instead of 50257 (bad, odd number)
context_length=args.context_length,
num_layers=args.num_layers,
num_heads=args.num_heads,
embd_size=args.embd_size
)
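# 50304 = 50257 rounded up to the next multiple of 128 (393 * 128); the padded token ids are never
# produced by the GPT-2 tokenizer, so they only pad the embedding/logit dims to a kernel-friendly size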
model = GPT(config=gpt_config)
# model = GPT.from_pretrained('gpt2') # init from OpenAI GPT-2
model.to(device) # move model to device
if use_torch_compile:
# use torch compile almost always unless debugging (requires compilation time, but makes training faster)
# speedup comes from reducing python overhead and GPU read/write
model = torch.compile(model)
if ddp:
# wraps the model in the DDP container (the forward pass is unchanged, but after the backward pass
# the gradients computed in each process are averaged by DDP using 'AllReduce' and shared across
# all processes, so every process ends up with the same gradients)
model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module if ddp else model
optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, lr=args.max_lr, device_type=device_type, master_process=master_process)
token_encoder = tiktoken.get_encoding('gpt2')
start_time = time.time()
# init the trainer object
trainer = Trainer(model, optimizer, train_loader, val_loader, token_encoder, args.eval_freq, grad_accum_steps,
ddp, ddp_rank, ddp_world_size, device, logpath)
max_steps = args.steps_per_epoch * args.num_epochs
trainer.train(max_steps, args.warmup_steps, args.max_lr, args.min_lr)
dt = (time.time() - start_time) / (60*60)
print(f"Total training time: {dt:.4f}hr")
if ddp:
dist.destroy_process_group()
if __name__ == "__main__":
main()