|
|
from torchmetrics.text import ROUGEScore |
|
from model import build_transformer |
|
from shakespeare_config import (get_config, |
|
get_data_folder_path, |
|
get_weights_file_path, |
|
latest_weights_file_path, |
|
current_directory, |
|
causal_mask, |
|
get_gpt2_tokenizer) |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
import numpy as np |
|
import warnings |
|
import os |
|
from pathlib import Path |
|
|
|
def get_model(config): |
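    """Build the decoder-only transformer from the vocab size, context length, and model width in `config`."""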
|
model = build_transformer(vocab_size=config['vocab_size'], |
|
seq_len=config['seq_len'], |
|
d_model=config['d_model']) |
|
return model |
|
|
|
def get_batch(split, data_dir, block_size, batch_size):
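    """Sample a random batch of contiguous token windows from the memmapped split.

    Returns (x, y), where y is x shifted one position to the right, so each
    position in y is the next-token target for the same position in x.
    """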
|
|
|
|
|
    # memory-map the token file so the full dataset never has to fit in RAM;
    # any non-train split is served from test.bin
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'test.bin'), dtype=np.uint16, mode='r')
|
    # sample batch_size random window starts; y is x shifted one token to the
    # right, giving next-token prediction targets
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    return x, y
|
|
|
def greedy_decode(model,
                  decoder_input,
                  tokenizer,
                  max_len,
                  device):
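    """Greedily append one token at a time until `max_len` tokens or EOS is produced."""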
|
    while decoder_input.size(1) < max_len:
        # rebuild the causal mask so it always matches the current length
        mask = causal_mask(decoder_input.size(1)).to(device)
        out = model.decode(decoder_input, mask)

        # project the last position and greedily take the most likely token
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(decoder_input).fill_(next_word.item()).to(device)],
            dim=1
        )
        if next_word == tokenizer.eos_token_id:
            break
    return decoder_input.squeeze(0)
|
|
|
def run_validation(model, |
|
x, |
|
y, |
|
tokenizer, |
|
max_len, |
|
device, |
|
print_msg, |
|
global_step, |
|
writer, |
|
                   rouge: ROUGEScore):
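    """Greedy-decode a single validation window and log ROUGE scores.

    Expects a batch of size 1; ROUGE-1/2/L/Lsum F-measures are printed and,
    when a writer is given, written to TensorBoard.
    """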
|
|
|
model.eval() |
|
source_texts = [] |
|
expected = [] |
|
predicted = [] |
|
|
|
    with torch.no_grad():
        # seed generation with the first half of the window; a full-length
        # input would already be at max_len and nothing would be generated
        decoder_input = x[:, : x.size(1) // 2].to(device)

        assert decoder_input.size(0) == 1, "batch size must be 1 for validation"

        model_out = greedy_decode(model,
                                  decoder_input,
                                  tokenizer,
                                  max_len,
                                  device)
|
|
|
source_text = tokenizer.decode(x[0]) |
|
target_text = tokenizer.decode(y[0]) |
|
model_out_text = tokenizer.decode(model_out.detach().cpu().numpy()) |
|
|
|
source_texts.append(source_text) |
|
expected.append(target_text) |
|
predicted.append(model_out_text) |
|
|
|
|
|
    print_msg('-' * 100)
    print_msg(f"{'SOURCE: ':>12}{source_text}")
    print_msg(f"{'TARGET: ':>12}{target_text}")
    print_msg(f"{'PREDICTED: ':>12}{model_out_text}")

    rouge_score = rouge(predicted, expected)
    print_msg(f"{'ROUGE-1 Score: ':>12}{rouge_score['rouge1_fmeasure'].item()}")
    print_msg(f"{'ROUGE-2 Score: ':>12}{rouge_score['rouge2_fmeasure'].item()}")
    print_msg(f"{'ROUGE-L Score: ':>12}{rouge_score['rougeL_fmeasure'].item()}")
    print_msg('-' * 100)
|
|
|
if writer: |
|
writer.add_scalar('validation ROUGE/ROUGE-1', rouge_score["rouge1_fmeasure"].item(), global_step) |
|
writer.add_scalar('validation ROUGE/ROUGE-2', rouge_score["rouge2_fmeasure"].item(), global_step) |
|
writer.add_scalar('validation ROUGE/ROUGE-L', rouge_score["rougeL_fmeasure"].item(), global_step) |
|
        writer.add_scalar('validation ROUGE/ROUGE-Lsum', rouge_score["rougeLsum_fmeasure"].item(), global_step)
|
writer.flush() |
|
|
|
def train_model(config): |
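    """Run the training loop: one random batch per step, validation after every step, periodic checkpoints."""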
|
|
|
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
print("Using device:", device) |
|
|
|
    if device == 'cuda':
        print(f"Device name: {torch.cuda.get_device_name()}")
        print(f"Device memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GB")
    elif device == 'mps':
        print("Device name: <mps>")
    else:
        print("No GPU found, falling back to CPU")
|
|
|
device = torch.device(device) |
|
|
|
|
|
    # make sure the checkpoint folder exists before any weights are saved
    Path(f"{current_directory}/{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)
|
|
|
tokenizer = get_gpt2_tokenizer(config=config) |
|
model = get_model(config).to(device) |
|
|
|
writer = SummaryWriter(f"{current_directory}/{config['experiment_name']}") |
|
|
|
optimizer = torch.optim.Adam(model.parameters(), |
|
lr=config['lr'], |
|
eps=1e-9) |
|
    rouge: ROUGEScore = ROUGEScore()
|
|
|
|
|
initial_epoch = 0 |
|
global_step = 0 |
|
preload = config['preload'] |
|
    if preload == 'latest':
        model_filename = latest_weights_file_path(config)
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None
|
if model_filename: |
|
print(f'Preloading model {model_filename}') |
|
        # map_location lets a GPU-saved checkpoint load on the current device
        state = torch.load(model_filename, map_location=device)
|
model.load_state_dict(state['model_state_dict']) |
|
initial_epoch = state['epoch'] + 1 |
|
optimizer.load_state_dict(state['optimizer_state_dict']) |
|
global_step = state['global_step'] |
|
else: |
|
print('No model to preload, starting from scratch') |
|
|
|
    # assumes get_gpt2_tokenizer registers a '[PAD]' token; its id is excluded from the loss
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.convert_tokens_to_ids('[PAD]'), label_smoothing=0.1).to(device)
|
for epoch in range(initial_epoch, config['num_epochs']): |
|
torch.cuda.empty_cache() |
|
model.train() |
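
        # each "epoch" here draws a single random batch, so it is really one
        # optimization step rather than a full pass over the dataset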
|
|
|
X, y = get_batch(split='train', |
|
data_dir=get_data_folder_path(config=config), |
|
block_size=config['seq_len'], |
|
batch_size=config['batch_size']) |
|
        print(f'batch size: {len(X)}, shape: {X.shape}')
|
|
|
decoder_input = X.to(device) |
|
decoder_mask = causal_mask(config['seq_len']).to(device) |
|
|
|
|
|
        # forward pass: masked self-attention over the window, then projection
        # to vocabulary logits
        decoder_output = model.decode(decoder_input, decoder_mask)
        proj_output = model.project(decoder_output)
|
|
|
|
|
        label = y.to(device)

        # flatten logits to (batch * seq_len, vocab_size) and targets to
        # (batch * seq_len,) so the loss scores one prediction per position
        loss = loss_fn(proj_output.view(-1, config['vocab_size']),
                       label.view(-1))
|
|
|
print(f"loss: {loss.item():6.3f}") |
|
|
|
|
|
writer.add_scalar('train loss', loss.item(), global_step) |
|
writer.flush() |
|
|
|
|
|
        # backpropagate, update the weights, and reset gradients
        # (set_to_none=True frees gradient memory instead of zero-filling)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
|
|
|
global_step += 1 |
|
|
|
|
|
        # validate on one random window (batch size 1) after every step
        X_val, y_val = get_batch(split='val',
                                 data_dir=get_data_folder_path(config=config),
                                 block_size=config['seq_len'],
                                 batch_size=1)
|
run_validation(model, |
|
X_val, |
|
y_val, |
|
tokenizer, |
|
config['seq_len'], |
|
device, |
|
                       print,
|
global_step, |
|
writer, |
|
rouge) |
|
|
|
        if epoch % 1000 == 0 or epoch >= (config['num_epochs'] - 1):
|
|
|
model_filename = get_weights_file_path(config, f"{epoch:02d}") |
|
torch.save({ |
|
'epoch': epoch, |
|
'model_state_dict': model.state_dict(), |
|
'optimizer_state_dict': optimizer.state_dict(), |
|
'global_step': global_step |
|
}, model_filename) |
|
|
|
if __name__ == '__main__': |
|
warnings.filterwarnings("ignore") |
|
config = get_config() |
|
train_model(config) |