|
import os |
|
import math |
|
import numpy as np |
|
import time |
|
from dataclasses import dataclass |
|
import tiktoken |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.distributed as dist |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
|
|
from model import GPT |
|
from dataloader import DataLoaderLite |
|
from hellaswag_eval import render_example, iterate_examples, get_most_likely_row |
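# Assumed interfaces for the local modules, inferred from how they are used below:
#   GPT(config): forward(inp, tar=None) returns (logits, loss); also provides
#       .config and .configure_optimizers(weight_decay, lr, device_type, master_process)
#   DataLoaderLite(B, T, process_rank, num_processes, split): .next_batch() returns
#       (inp, tar) token tensors of shape (B, T); also provides .reset(), .B and .T
#   hellaswag_eval: render_example / iterate_examples / get_most_likely_row helpers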
|
|
|
torch.set_float32_matmul_precision('high')  # allow TF32 matmuls on supported GPUs for faster float32 training
|
|
|
|
|
use_torch_compile = False |
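# Typical launch commands (the script filename here is an assumption; adjust as needed):
#   single GPU / CPU: python train.py
#   multi-GPU (DDP): torchrun --standalone --nproc_per_node=NUM_GPUS train.py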
|
|
|
|
|
class Trainer: |
|
def __init__( |
|
self, |
|
model, |
|
optimizer, |
|
train_loader, |
|
val_loader, |
|
token_encoder, |
|
eval_freq, |
|
grad_accum_steps, |
|
ddp, |
|
ddp_rank, |
|
ddp_world_size, |
|
device, |
|
logpath |
|
): |
|
self.ddp = ddp |
|
self.ddp_rank = ddp_rank |
|
self.master_process = ddp_rank == 0 |
|
self.ddp_world_size = ddp_world_size |
|
|
|
self.model = model |
|
self.optimizer = optimizer |
|
self.train_loader = train_loader |
|
self.val_loader = val_loader |
|
self.token_encoder = token_encoder |
|
|
|
self.eval_freq = eval_freq |
|
self.grad_accum_steps = grad_accum_steps |
|
self.device = device |
|
self.device_type = 'cuda' if device.startswith('cuda') else 'cpu' |
|
self.logpath = logpath |
|
|
|
|
|
def train( |
|
self, |
|
max_steps, |
|
warmup_steps, |
|
max_lr, |
|
min_lr |
|
): |
|
for step in range(max_steps): |
|
t0 = time.time() |
|
self.is_last_step = (step == max_steps - 1) |
|
|
|
|
|
if step % self.eval_freq == 0 or self.is_last_step: |
|
self.evaluate_validation(step) |
|
|
|
|
|
if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile): |
|
                self.evaluate_hellaswag(step)
|
|
|
|
|
if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile): |
|
self.generate_sequences(num_seq=5, max_tokens=32) |
|
|
|
|
|
self.model.train() |
|
self.optimizer.zero_grad() |
|
batch_loss = 0.0 |
|
|
|
for mini_step in range(self.grad_accum_steps): |
|
inp, tar = self.train_loader.next_batch() |
|
inp, tar = inp.to(self.device), tar.to(self.device) |
|
|
|
|
|
|
|
with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16): |
|
logits, loss = self.model(inp, tar) |
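                # the loss returned above is a mean over this micro-batch only, so it
                # is scaled down by grad_accum_steps below; the gradients accumulated
                # across all micro-steps then match one update over the full batch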
|
|
|
|
|
|
|
|
|
loss /= self.grad_accum_steps |
|
batch_loss += loss.detach() |
|
|
|
if self.ddp: |
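                    # only all-reduce gradients on the final micro-step of the
                    # accumulation window; earlier backward passes accumulate
                    # gradients locally (same effect as DDP's no_sync())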
|
|
|
|
|
self.model.require_backward_grad_sync = (mini_step == self.grad_accum_steps - 1) |
|
|
|
|
|
|
|
|
|
loss.backward() |
|
|
|
if self.ddp: |
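                # batch_loss is used only for logging; average it across ranks so
                # every process reports the same value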
|
|
|
|
|
|
|
dist.all_reduce(batch_loss, op=dist.ReduceOp.AVG) |
|
|
|
|
|
norm = nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) |
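            # `norm` is the global gradient norm before clipping; logging it helps
            # spot training instabilities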
|
|
|
|
|
lr = self.estimate_lr(step, warmup_steps, max_steps, max_lr, min_lr) |
|
|
|
for param_group in self.optimizer.param_groups: |
|
param_group['lr'] = lr |
|
|
|
self.optimizer.step() |
|
if self.device_type == 'cuda': |
|
torch.cuda.synchronize() |
|
|
|
dt = (time.time() - t0) * 1000.0 |
|
tokens_processed = self.train_loader.B * self.train_loader.T * self.grad_accum_steps * self.ddp_world_size |
|
            tokens_per_sec = tokens_processed / (dt / 1000.0)  # dt is in milliseconds, so convert to seconds first
|
|
|
if self.master_process: |
|
print(f'step {step:4d} | loss: {batch_loss.item():.6f} | lr: {lr:.2e} | norm: {norm:.4f} | dt: {dt:.4f}ms | tok/sec: {tokens_per_sec:.4f}') |
|
with open(self.logpath, 'a') as f: |
|
f.write(f'{step} train {batch_loss.item():.6f}\n') |
|
|
|
|
|
def evaluate_validation(self, step): |
|
self.model.eval() |
|
self.val_loader.reset() |
|
|
|
with torch.no_grad(): |
|
val_loss_accum = 0.0 |
|
val_steps = 20 |
|
for _ in range(val_steps): |
|
inp, tar = self.val_loader.next_batch() |
|
inp, tar = inp.to(self.device), tar.to(self.device) |
|
with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16): |
|
logits, loss = self.model(inp, tar) |
|
loss /= val_steps |
|
val_loss_accum += loss.detach() |
|
|
|
if self.ddp: |
|
dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG) |
|
if self.master_process: |
|
print(f'Val loss: {val_loss_accum.item():.4f}') |
|
with open(self.logpath, 'a') as f: |
|
f.write(f'{step} val {val_loss_accum.item():.4f}\n') |
|
|
|
if step > 0 and (step % 10000 == 0 or self.is_last_step): |
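            # periodic checkpoint: only the model weights, config, step and val loss
            # are stored (no optimizer or dataloader state), so these checkpoints are
            # meant for evaluation rather than exact training resumption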
|
raw_model = self.model.module if self.ddp else self.model |
|
logdir = os.path.dirname(self.logpath) |
|
ckpt_path = os.path.join(logdir, f'model_{step:05d}.pt') |
|
checkpoint = { |
|
'model': raw_model.state_dict(), |
|
'config': raw_model.config, |
|
'step': step, |
|
'val_loss': val_loss_accum.item() |
|
} |
|
torch.save(checkpoint, ckpt_path) |
|
|
|
|
|
    def evaluate_hellaswag(self, step):
|
""" |
|
Construct a batch of 4 sequences and perform token completion using |
|
our model. |
|
""" |
|
n_total = 0 |
|
n_correct_norm = 0 |
|
for i, example in enumerate(iterate_examples('val')): |
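            # shard the examples across ranks: each process scores every
            # ddp_world_size-th example; the per-rank counts are summed below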
|
|
|
if i % self.ddp_world_size != self.ddp_rank: |
|
continue |
|
|
|
_, tokens, mask, label = render_example(example) |
|
tokens, mask = tokens.to(self.device), mask.to(self.device) |
|
with torch.no_grad(): |
|
with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16): |
|
logits, loss = self.model(tokens) |
|
pred_norm = get_most_likely_row(tokens, mask, logits) |
|
n_total += 1 |
|
n_correct_norm += int(pred_norm == label) |
|
|
|
if self.ddp: |
|
n_total = torch.tensor(n_total, device=self.device, dtype=torch.long) |
|
n_correct_norm = torch.tensor(n_correct_norm, device=self.device, dtype=torch.long) |
|
dist.all_reduce(n_total, op=dist.ReduceOp.SUM) |
|
dist.all_reduce(n_correct_norm, op=dist.ReduceOp.SUM) |
|
n_total = n_total.item() |
|
n_correct_norm = n_correct_norm.item() |
|
acc_norm = n_correct_norm / n_total |
|
if self.master_process: |
|
            print(f'HellaSwag accuracy: {n_correct_norm}/{n_total}={acc_norm:.4f}')
|
with open(self.logpath, 'a') as f: |
|
f.write(f'{step} hellaswag {acc_norm:.4f}\n') |
|
|
|
|
|
def generate_sequences(self, num_seq=4, max_tokens=32): |
|
self.model.eval() |
|
tokens = self.token_encoder.encode("Hello, I am a language model") |
|
tokens = torch.tensor(tokens, dtype=torch.long) |
|
tokens = tokens.unsqueeze(0).repeat(num_seq, 1) |
|
gen_tokens = tokens.to(self.device) |
|
|
|
sample_rng = torch.Generator(device=self.device) |
|
|
|
sample_rng.manual_seed(42 + self.ddp_rank) |
|
|
|
        while gen_tokens.shape[-1] < max_tokens:  # stop once the sequence reaches max_tokens
|
with torch.no_grad(): |
|
with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16): |
|
logits, loss = self.model(gen_tokens) |
|
logits = logits[:, -1, :] |
|
probs = F.softmax(logits, dim=-1) |
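                # top-k sampling with k=50: keep the 50 most probable tokens and
                # sample one of them in proportion to its probability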
|
|
|
topk_probs, topk_indices = torch.topk(probs, 50, dim=-1) |
|
|
|
ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng) |
|
next_tok = torch.gather(topk_indices, -1, ix) |
|
gen_tokens = torch.cat([gen_tokens, next_tok], dim=1) |
|
|
|
for i in range(num_seq): |
|
tokens = gen_tokens[i, :max_tokens].tolist() |
|
gen_text = self.token_encoder.decode(tokens) |
|
print(f"> rank {self.ddp_rank} sample {i}: {gen_text}") |
|
|
|
|
|
def estimate_lr(self, step, warmup_steps, max_steps, max_lr, min_lr): |
|
""" |
|
Learning rate scheduler: Cosine-decay learning schedule with warmup |
|
""" |
|
|
|
if step < warmup_steps: |
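            # linear warmup; using (step + 1) avoids a learning rate of exactly 0 at step 0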
|
return max_lr * (step+1) / warmup_steps |
|
|
|
if step > max_steps: |
|
return min_lr |
|
|
|
decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps) |
|
assert 0 <= decay_ratio <= 1 |
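        # cosine decay: coeff goes from 1 right after warmup to 0 at max_steps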
|
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) |
|
return min_lr + coeff * (max_lr - min_lr) |
|
|
|
|
|
@dataclass |
|
class GPTConfig: |
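    # defaults match the 124M-parameter GPT-2 configuration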
|
context_length: int = 1024 |
|
vocab_size: int = 50257 |
|
num_layers: int = 12 |
|
embd_size: int = 768 |
|
num_heads: int = 12 |
|
|
|
|
|
def get_args(): |
|
import argparse |
|
parser = argparse.ArgumentParser(description="Hyperparameter Configuration") |
|
parser.add_argument("--total_batch_size", type=int, default=524288, help="number of tokens processed for each weight update") |
|
parser.add_argument("--mini_batch_size", type=int, default=32, help="setting of mini_batch_size is just a performance optimization. bigger gpu, bigger mini_batch_size") |
|
parser.add_argument("--context_length", type=int, default=1024) |
|
parser.add_argument("--num_layers", type=int, default=12) |
|
parser.add_argument("--embd_size", type=int, default=768) |
|
parser.add_argument("--num_heads", type=int, default=12) |
|
parser.add_argument("--max_lr", type=float, default=1e-3) |
|
parser.add_argument("--min_lr", type=float, default=1e-3 * 0.1) |
|
parser.add_argument("--warmup_steps", type=int, default=715) |
|
parser.add_argument("--weight_decay", type=float, default=0.1) |
|
parser.add_argument("--num_epochs", type=int, default=5) |
|
parser.add_argument("--steps_per_epoch", type=int, default=19073) |
|
parser.add_argument("--eval_freq", type=int, default=250) |
|
|
|
parser.add_argument("--seed", type=int, default=1337, help="Random seed for reproducibility") |
|
parser.add_argument("--logdir", type=str, default="./logs/") |
|
return parser.parse_args() |
|
|
|
|
|
def main(): |
|
args = get_args() |
|
|
|
|
|
print("Hyperparameter Configuration:") |
|
for key, value in vars(args).items(): |
|
print(f"{key}: {value}") |
|
|
|
|
|
os.makedirs(args.logdir, exist_ok=True) |
|
logpath = os.path.join(args.logdir, 'log.txt') |
|
    with open(logpath, 'w') as f:  # opening in 'w' mode clears any log left over from a previous run
|
pass |
|
|
|
|
|
|
|
|
|
|
|
    ddp = int(os.environ.get('RANK', -1)) != -1  # torchrun sets RANK, LOCAL_RANK and WORLD_SIZE; the presence of RANK signals a DDP launch
|
if ddp: |
|
|
|
        assert torch.cuda.is_available(), 'use of DDP requires CUDA'
|
dist.init_process_group(backend='nccl') |
|
ddp_rank = int(os.environ['RANK']) |
|
ddp_local_rank = int(os.environ['LOCAL_RANK']) |
|
ddp_world_size = int(os.environ['WORLD_SIZE']) |
|
device = f'cuda:{ddp_local_rank}' |
|
torch.cuda.set_device(device) |
|
|
|
master_process = ddp_rank == 0 |
|
else: |
|
|
|
ddp_rank = 0 |
|
ddp_local_rank = 0 |
|
ddp_world_size = 1 |
|
master_process = True |
|
device = 'cpu' |
|
if torch.cuda.is_available(): |
|
device = 'cuda' |
|
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): |
|
device = 'mps' |
|
print(f'using device: {device}') |
|
|
|
device_type = 'cuda' if device.startswith('cuda') else 'cpu' |
|
|
|
|
|
np.random.seed(args.seed) |
|
torch.manual_seed(args.seed) |
|
if torch.cuda.is_available(): |
|
torch.cuda.manual_seed(args.seed) |
|
torch.cuda.manual_seed_all(args.seed) |
|
|
|
    assert args.total_batch_size % (args.mini_batch_size * args.context_length * ddp_world_size) == 0, 'total_batch_size must be divisible by mini_batch_size * context_length * ddp_world_size'
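    # total_batch_size is measured in tokens; each rank processes B * T tokens per
    # micro-step, so accumulate gradients for total // (B * T * world_size) micro-steps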
|
grad_accum_steps = args.total_batch_size // (args.mini_batch_size * args.context_length * ddp_world_size) |
|
if master_process: |
|
print(f'desired batch size (number of tokens): {args.total_batch_size}') |
|
print(f'gradient accumulation steps: {grad_accum_steps}') |
|
print(f'GPU: {ddp_rank}, {ddp_local_rank}') |
|
|
|
train_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='train') |
|
val_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='val') |
|
|
|
|
|
|
|
    # pad the 50257-token GPT-2 vocab up to 50304 (a multiple of 128) so the
    # embedding and lm-head matmuls hit more efficient kernel shapes
    gpt_config = GPTConfig(vocab_size=50304,
|
context_length=args.context_length, |
|
num_layers=args.num_layers, |
|
num_heads=args.num_heads, |
|
embd_size=args.embd_size |
|
) |
|
model = GPT(config=gpt_config) |
|
|
|
model.to(device) |
|
if use_torch_compile: |
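        # torch.compile can speed up training substantially, but it is off by default
        # because the HellaSwag eval and sampling steps are skipped when it is enabled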
|
|
|
|
|
model = torch.compile(model) |
|
|
|
if ddp: |
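        # wrap the model in DistributedDataParallel so gradients are averaged across
        # ranks during the backward pass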
|
|
|
|
|
|
|
model = DDP(model, device_ids=[ddp_local_rank]) |
|
|
|
    raw_model = model.module if ddp else model  # unwrap the DDP container to reach the underlying GPT module
|
optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, lr=args.max_lr, device_type=device_type, master_process=master_process) |
|
token_encoder = tiktoken.get_encoding('gpt2') |
|
|
|
start_time = time.time() |
|
|
|
trainer = Trainer(model, optimizer, train_loader, val_loader, token_encoder, args.eval_freq, grad_accum_steps, |
|
ddp, ddp_rank, ddp_world_size, device, logpath) |
|
|
|
max_steps = args.steps_per_epoch * args.num_epochs |
|
trainer.train(max_steps, args.warmup_steps, args.max_lr, args.min_lr) |
|
|
|
dt = (time.time() - start_time) / (60*60) |
|
print(f"Total training time: {dt:.4f}hr") |
|
|
|
if ddp: |
|
dist.destroy_process_group() |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|