import os
import math
import numpy as np
import time
from dataclasses import dataclass
import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# import code; code.interact(local=locals())
from model import GPT
from dataloader import DataLoaderLite
from hellaswag_eval import render_example, iterate_examples, get_most_likely_row
torch.set_float32_matmul_precision('high') # enable TF32 precision
# set use_torch_compile to True to speed up training (as long as torch.compile doesn't throw any errors)
use_torch_compile = False
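
# Typical launches (assuming this file is saved as train.py; adjust to match the actual filename):
#   single GPU        : python train.py
#   DDP (e.g. 8 GPUs) : torchrun --standalone --nproc_per_node=8 train.py
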
class Trainer:
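    """
    Wraps the training loop: gradient-accumulated optimization with periodic
    validation-loss evaluation, HellaSwag evaluation, sample generation and
    checkpointing. Works for both single-device and DDP (multi-GPU) runs.
    """
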
    def __init__(
        self,
        model,
        optimizer,
        train_loader,
        val_loader,
        token_encoder,
        eval_freq,
        grad_accum_steps,
        ddp,
        ddp_rank,
        ddp_world_size,
        device,
        logpath
    ):
        self.ddp = ddp
        self.ddp_rank = ddp_rank
        self.master_process = ddp_rank == 0
        self.ddp_world_size = ddp_world_size
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.token_encoder = token_encoder
        self.eval_freq = eval_freq
        self.grad_accum_steps = grad_accum_steps
        self.device = device
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
        self.logpath = logpath

    def train(
        self,
        max_steps,
        warmup_steps,
        max_lr,
        min_lr
    ):
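        """
        Run the optimization loop for 'max_steps' steps. Each step accumulates
        gradients over 'grad_accum_steps' micro-batches, clips the global grad
        norm at 1.0 and applies a warmup + cosine-decay learning-rate schedule.
        Every 'eval_freq' steps it also reports validation loss, HellaSwag
        accuracy and a few generated samples.
        """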
        for step in range(max_steps):
            t0 = time.time()
            self.is_last_step = (step == max_steps - 1)

            # evaluate validation loss
            if step % self.eval_freq == 0 or self.is_last_step:
                self.evaluate_validation(step)
            # evaluate model performance on HellaSwag every once in a while
            if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
                self.evaluate_hellaswag(step)
            # generate sequences from the model every once in a while
            if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
                self.generate_sequences(num_seq=5, max_tokens=32)

            # training loop starts here
            self.model.train()          # sets model to train mode
            self.optimizer.zero_grad()  # resets all gradients
            batch_loss = 0.0
            for mini_step in range(self.grad_accum_steps):
                inp, tar = self.train_loader.next_batch()
                inp, tar = inp.to(self.device), tar.to(self.device)
                # FORWARD PASS !!!
                # autocast to bfloat16 for faster compute and memory efficiency
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(inp, tar)
                # the loss is scaled to account for gradient accumulation, because gradients just add up
                # on each successive backward() call; the sum of gradients corresponds to a SUM in the
                # objective, but we want the MEAN instead
                loss /= self.grad_accum_steps
                batch_loss += loss.detach()
                if self.ddp:
                    # only sync and average gradients across all processes on the final mini_step
                    # (the 'no_sync()' context manager is the official alternative)
                    self.model.require_backward_grad_sync = (mini_step == self.grad_accum_steps - 1)
                # each process accumulates gradients separately while 'require_backward_grad_sync' is False;
                # on the final 'mini_step' it becomes True, so loss.backward() averages the gradients
                # across all processes and shares the result among them
                loss.backward()
            if self.ddp:
                # 'batch_loss' lives outside the DDP container, so 'all_reduce' is needed to average
                # it across the processes of all ranks; the tensor exists on every GPU and
                # 'all_reduce' deposits the averaged result on all processes
                dist.all_reduce(batch_loss, op=dist.ReduceOp.AVG)
            # once gradients are computed, clip the global l2-norm of the gradient at 1.0
            norm = nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)  # monitor/print 'norm'
            # determine learning rate with decay
            lr = self.estimate_lr(step, warmup_steps, max_steps, max_lr, min_lr)
            # set learning rate for this iteration
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            self.optimizer.step()
            if self.device_type == 'cuda':
                torch.cuda.synchronize()  # wait for the GPU to finish work
            dt = (time.time() - t0) * 1000.0  # in ms
            tokens_processed = self.train_loader.B * self.train_loader.T * self.grad_accum_steps * self.ddp_world_size
            tokens_per_sec = tokens_processed / (dt / 1000.0)  # dt is in ms, so convert back to seconds
            if self.master_process:
                print(f'step {step:4d} | loss: {batch_loss.item():.6f} | lr: {lr:.2e} | norm: {norm:.4f} | dt: {dt:.4f}ms | tok/sec: {tokens_per_sec:.4f}')
                with open(self.logpath, 'a') as f:
                    f.write(f'{step} train {batch_loss.item():.6f}\n')

    def evaluate_validation(self, step):
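        """
        Compute the average validation loss over a fixed number of mini-batches and,
        every 10,000 steps (and on the last step), write a model checkpoint.
        """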
        self.model.eval()  # sets model to eval mode
        self.val_loader.reset()
        # evaluate the model on the validation set
        with torch.no_grad():
            val_loss_accum = 0.0
            val_steps = 20
            for _ in range(val_steps):
                inp, tar = self.val_loader.next_batch()
                inp, tar = inp.to(self.device), tar.to(self.device)
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(inp, tar)
                loss /= val_steps
                val_loss_accum += loss.detach()
        if self.ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if self.master_process:
            print(f'Val loss: {val_loss_accum.item():.4f}')
            with open(self.logpath, 'a') as f:
                f.write(f'{step} val {val_loss_accum.item():.4f}\n')
            if step > 0 and (step % 10000 == 0 or self.is_last_step):
                raw_model = self.model.module if self.ddp else self.model
                logdir = os.path.dirname(self.logpath)
                ckpt_path = os.path.join(logdir, f'model_{step:05d}.pt')
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'config': raw_model.config,
                    'step': step,
                    'val_loss': val_loss_accum.item()
                }  # add optimizer.state_dict(), rng seeds, etc. if resuming training
                torch.save(checkpoint, ckpt_path)

    def evaluate_hellaswag(self, step):
        """
        Evaluate the model on the HellaSwag validation set. Each example is rendered
        as a batch of 4 candidate completions of a shared context; the prediction is
        the candidate whose completion tokens are most likely under the model
        (as scored by get_most_likely_row).
        """
        n_total = 0
        n_correct_norm = 0
        for i, example in enumerate(iterate_examples('val')):
            # only process examples where i % ddp_world_size == ddp_rank
            if i % self.ddp_world_size != self.ddp_rank:
                continue
            # render the example into tokens and labels
            _, tokens, mask, label = render_example(example)  # tokens: (4,N), mask: (4,N), label: int
            tokens, mask = tokens.to(self.device), mask.to(self.device)
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(tokens)
                pred_norm = get_most_likely_row(tokens, mask, logits)
            n_total += 1
            n_correct_norm += int(pred_norm == label)
        # reduce the stats across all processes
        if self.ddp:
            n_total = torch.tensor(n_total, device=self.device, dtype=torch.long)
            n_correct_norm = torch.tensor(n_correct_norm, device=self.device, dtype=torch.long)
            dist.all_reduce(n_total, op=dist.ReduceOp.SUM)
            dist.all_reduce(n_correct_norm, op=dist.ReduceOp.SUM)
            n_total = n_total.item()
            n_correct_norm = n_correct_norm.item()
        acc_norm = n_correct_norm / n_total
        if self.master_process:
            print(f'HellaSwag accuracy: {n_correct_norm}/{n_total}={acc_norm:.4f}')
            with open(self.logpath, 'a') as f:
                f.write(f'{step} hellaswag {acc_norm:.4f}\n')

    def generate_sequences(self, num_seq=4, max_tokens=32):
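        """
        Autoregressively sample 'num_seq' continuations of a fixed prompt with
        top-50 sampling, using a rank-seeded RNG, and print the decoded text.
        """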
        self.model.eval()
        tokens = self.token_encoder.encode("Hello, I am a language model")
        tokens = torch.tensor(tokens, dtype=torch.long)  # (n,) n : current sequence length
        tokens = tokens.unsqueeze(0).repeat(num_seq, 1)  # (1,n) --> (num_seq, n)
        gen_tokens = tokens.to(self.device)
        # create a different rng generator so as not to impact the global rng state used for training
        sample_rng = torch.Generator(device=self.device)
        # adding 'ddp_rank' in seeding to generate different tokens for different rank processes
        sample_rng.manual_seed(42 + self.ddp_rank)
        # generate new tokens one token at a time until the sequence length becomes 'max_tokens'
        while gen_tokens.shape[-1] <= max_tokens:
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(gen_tokens)  # (num_seq, n, vocab_size)
                logits = logits[:, -1, :]                  # (num_seq, vocab_size)
                probs = F.softmax(logits, dim=-1)          # (num_seq, vocab_size)
                # take top-k 50 probs
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)  # (num_seq, 50), (num_seq, 50)
                # sample a token from top-50 probabilities
                ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng)  # (num_seq, 1)
                next_tok = torch.gather(topk_indices, -1, ix)  # (num_seq, 1)
                gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)
        # decode generated tokens and print generated text
        for i in range(num_seq):
            tokens = gen_tokens[i, :max_tokens].tolist()
            gen_text = self.token_encoder.decode(tokens)
            print(f"> rank {self.ddp_rank} sample {i}: {gen_text}")

    def estimate_lr(self, step, warmup_steps, max_steps, max_lr, min_lr):
        """
        Learning rate scheduler: cosine-decay schedule with linear warmup
        """
        # 1) linear warmup for 'warmup_steps' steps
        if step < warmup_steps:
            return max_lr * (step + 1) / warmup_steps
        # 2) if step > max_steps, return the minimum learning rate
        if step > max_steps:
            return min_lr
        # 3) in between, use cosine decay down to the minimum learning rate
        decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff starts at 1 and goes to 0
        return min_lr + coeff * (max_lr - min_lr)


@dataclass
class GPTConfig:
    context_length: int = 1024  # max context / sequence length
    vocab_size: int = 50257     # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    num_layers: int = 12
    embd_size: int = 768        # embedding dim
    num_heads: int = 12


def get_args():
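    """Parse command-line hyperparameters; the model-size defaults match GPT-2 (124M)."""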
    import argparse
    parser = argparse.ArgumentParser(description="Hyperparameter Configuration")
    parser.add_argument("--total_batch_size", type=int, default=524288, help="number of tokens processed for each weight update")  # 2^19 tokens per update (~0.5M tokens, as in the OpenAI GPT-3 paper)
    parser.add_argument("--mini_batch_size", type=int, default=32, help="micro-batch size; purely a performance knob (bigger GPU, bigger mini_batch_size)")
    parser.add_argument("--context_length", type=int, default=1024)  # max sequence length (can also try 2048)
    parser.add_argument("--num_layers", type=int, default=12)
    parser.add_argument("--embd_size", type=int, default=768)
    parser.add_argument("--num_heads", type=int, default=12)
    parser.add_argument("--max_lr", type=float, default=1e-3)
    parser.add_argument("--min_lr", type=float, default=1e-3 * 0.1)
    parser.add_argument("--warmup_steps", type=int, default=715)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--num_epochs", type=int, default=5)
    parser.add_argument("--steps_per_epoch", type=int, default=19073)  # 10^10 / 2^19 ~ 19073 steps for 1 epoch on FineWebEdu sample-10BT
    parser.add_argument("--eval_freq", type=int, default=250)
    # parser.add_argument("--use_torch_compile", action='store_true')  # default False
    parser.add_argument("--seed", type=int, default=1337, help="Random seed for reproducibility")
    parser.add_argument("--logdir", type=str, default="./logs/")
    return parser.parse_args()


def main():
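    """Set up (optional) DDP, seeding, data loaders, model and optimizer, then run training."""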
    args = get_args()
    # print the hyperparameters
    print("Hyperparameter Configuration:")
    for key, value in vars(args).items():
        print(f"{key}: {value}")

    # create the logs directory if it doesn't exist
    os.makedirs(args.logdir, exist_ok=True)
    logpath = os.path.join(args.logdir, 'log.txt')
    with open(logpath, 'w') as f:
        pass

    # set up DDP (distributed data parallel)
    # the 'torchrun' command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE.
    # RANK and LOCAL_RANK are the same for (single-node, multi-GPU) settings and may differ for
    # (multi-node, multi-GPU) settings.
    ddp = int(os.environ.get('RANK', -1)) != -1  # whether this is a ddp run or not
    if ddp:
        # use of DDP requires CUDA
        assert torch.cuda.is_available(), 'use of DDP requires CUDA'
        dist.init_process_group(backend='nccl')
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        device = f'cuda:{ddp_local_rank}'
        torch.cuda.set_device(device)
        # the master process (arbitrarily set to rank 0) does the printing, logging, checkpointing, etc.
        master_process = ddp_rank == 0
    else:
        # not using ddp
        ddp_rank = 0
        ddp_local_rank = 0
        ddp_world_size = 1
        master_process = True  # ddp_rank == 0
        device = 'cpu'
        if torch.cuda.is_available():
            device = 'cuda'
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            device = 'mps'  # for apple macbook GPUs
        print(f'using device: {device}')
    device_type = 'cuda' if device.startswith('cuda') else 'cpu'

    # setting seeds for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)  # sets seed for random number generation on CPU
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)      # sets seed for random number generation on the current GPU
        torch.cuda.manual_seed_all(args.seed)  # sets seed for all GPUs

    assert args.total_batch_size % (args.mini_batch_size * args.context_length * ddp_world_size) == 0, 'ensure total_batch_size is divisible by B * T * ddp_world_size'
    grad_accum_steps = args.total_batch_size // (args.mini_batch_size * args.context_length * ddp_world_size)
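    # e.g. with the defaults on a single GPU: 524288 // (32 * 1024 * 1) = 16 micro-batches per optimizer step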
    if master_process:
        print(f'desired batch size (number of tokens): {args.total_batch_size}')
        print(f'gradient accumulation steps: {grad_accum_steps}')
    print(f'GPU: {ddp_rank}, {ddp_local_rank}')

    train_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='train')
    val_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='val')

    # create the GPT model. each ddp process creates its own instance of the model, but since the seed is fixed,
    # they all create the identical model
    gpt_config = GPTConfig(vocab_size=50304,  # 50304 (a nice number with lots of powers of 2) used instead of 50257 (an ugly, odd number)
                           context_length=args.context_length,
                           num_layers=args.num_layers,
                           num_heads=args.num_heads,
                           embd_size=args.embd_size
                           )
    model = GPT(config=gpt_config)
    # model = GPT.from_pretrained('gpt2')  # init from OpenAI GPT-2
    model.to(device)  # move model to device
    if use_torch_compile:
        # use torch.compile almost always unless debugging (it costs compilation time, but makes training faster);
        # the speedup comes from reducing python overhead and GPU reads/writes
        model = torch.compile(model)
    if ddp:
        # wrap the model in the DDP container (the forward pass is unchanged, but after the backward pass,
        # gradients computed across the processes are averaged by DDP using 'AllReduce' and shared across
        # all processes so that each process ends up with the same gradients)
        model = DDP(model, device_ids=[ddp_local_rank])
    raw_model = model.module if ddp else model
    optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, lr=args.max_lr, device_type=device_type, master_process=master_process)
    token_encoder = tiktoken.get_encoding('gpt2')

    start_time = time.time()
    # init the trainer object
    trainer = Trainer(model, optimizer, train_loader, val_loader, token_encoder, args.eval_freq, grad_accum_steps,
                      ddp, ddp_rank, ddp_world_size, device, logpath)
    max_steps = args.steps_per_epoch * args.num_epochs
    trainer.train(max_steps, args.warmup_steps, args.max_lr, args.min_lr)
    dt = (time.time() - start_time) / (60 * 60)
    print(f"Total training time: {dt:.4f}hr")
    if ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()