# shakespeare_inference.py
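"""Inference for the decoder-only Shakespeare transformer: load the latest
checkpoint and generate text from a prompt with greedy search and beam search."""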

import heapq
import warnings

import torch

from shakespeare_config import (get_config,
                                latest_weights_file_path,
                                get_gpt2_tokenizer,
                                causal_mask,
                                current_directory)
from train import build_transformer
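
# Note: `causal_mask(size)` (imported above) is assumed to return a mask that
# hides future positions, e.g. something like
#     torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int) == 0
# so each position may only attend to itself and earlier tokens.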

def predict_with_greedy_search(start_str: str) -> None:
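    """Greedy decoding: repeatedly pick the single most likely next token and
    append it to the input until the configured sequence length is reached."""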
    config: dict = get_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    tokenizer = get_gpt2_tokenizer(config=config)
    model = build_transformer(vocab_size=config['vocab_size'],
                              seq_len=config['seq_len'],
                              d_model=config['d_model']).to(device)
    # Load the pretrained weights (map_location keeps CPU-only machines working)
    model_filename = latest_weights_file_path(config)
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()
    output = start_str
    with torch.no_grad():
        start_tokens = tokenizer.encode(start_str)
        print(start_tokens)
        input_ids = torch.tensor(data=start_tokens, dtype=torch.int64).unsqueeze(dim=0).to(device)
        # print(input_ids)
        while input_ids.size(1) <= config['seq_len']:
            # Use a causal mask, otherwise the model can attend to future
            # positions and tends to generate repetitive words.
            mask = causal_mask(input_ids.size(1)).to(device)
            out = model.decode(input_ids, mask)
            prob = model.project(out[:, -1])
            # Greedy step: take the single most likely next token.
            _, next_word = torch.max(prob, dim=1)
            input_ids = torch.cat(
                [
                    input_ids,
                    torch.empty(1, 1).type_as(input_ids).fill_(next_word.item()).to(device)
                ],
                dim=1
            )
            output += tokenizer.decode(next_word.item())
    print(f'Model output: {output}')

def predict_with_beam_search(start_str: str,
                             beam_width: int = 3) -> None:
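    """Beam search decoding: keep the `beam_width` most probable partial
    sequences at every step, scored by cumulative negative log-probability,
    and decode the best-scoring beam at the end."""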
    config: dict = get_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    tokenizer = get_gpt2_tokenizer(config=config)
    model = build_transformer(vocab_size=config['vocab_size'],
                              seq_len=config['seq_len'],
                              d_model=config['d_model']).to(device)
    # Load the pretrained weights (map_location keeps CPU-only machines working)
    model_filename = latest_weights_file_path(config)
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()
    # Initial input
    start_tokens = tokenizer.encode(start_str)
    input_ids = torch.tensor(data=start_tokens, dtype=torch.int64).unsqueeze(dim=0).to(device)  # (1, prompt_len)
    # Beam search variables
    beams = [(0.0, input_ids, [])]  # Each beam is a tuple of (score, sequence, tokens_generated)
    # Inference only, so no gradients are needed (the greedy path does this too)
    with torch.no_grad():
        for _ in range(config['seq_len']):
            all_candidates = []
            # Process each beam
            for score, seq, tokens in beams:
                # Use a causal mask, otherwise the model can attend to future
                # positions and tends to generate repetitive words.
                mask = causal_mask(seq.size(1)).to(device)
                out = model.decode(seq, mask)
                prob = model.project(out[:, -1])
                # Score with log-probabilities. log_softmax is idempotent, so this is
                # correct whether model.project returns raw logits or log-probs; taking
                # torch.log of raw (possibly negative) logits would produce NaNs.
                log_probs = torch.log_softmax(prob, dim=1)
                # Get the top k predictions
                top_k_log_probs, top_k_indices = torch.topk(log_probs, beam_width, dim=1)
                # Generate new beams for each of the top k tokens
                for i in range(beam_width):
                    new_token = top_k_indices[0, i].item()
                    new_score = score - top_k_log_probs[0, i].item()  # Negate so a lower score means a more probable sequence
                    new_seq = torch.cat([seq, torch.tensor([[new_token]], device=device)], dim=1)
                    new_tokens = tokens + [new_token]
                    all_candidates.append((new_score, new_seq, new_tokens))
            # Keep only the `beam_width` lowest-scoring (most probable) candidates
            beams = heapq.nsmallest(beam_width, all_candidates, key=lambda x: x[0])
            # Stop early once every beam has reached the maximum sequence length
            if all(beam[1].shape[1] >= config['seq_len'] for beam in beams):
                break
    # Retrieve the best beam (lowest cumulative negative log-probability,
    # i.e. the most probable sequence)
    best_beam = beams[0]
    best_tokens = best_beam[2]
    # Decode the generated tokens and prepend the prompt, matching the greedy output
    output = start_str + tokenizer.decode(best_tokens, skip_special_tokens=True)
    print(f'Model output: {output}')
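
# Greedy search commits to the single most likely token at every step, so one
# early mistake is locked in for the rest of the sequence. Beam search keeps
# `beam_width` candidate sequences alive in parallel and picks the one with the
# best total score at the end, trading roughly beam_width times the compute
# for usually more coherent text.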

if __name__ == '__main__':
    warnings.filterwarnings("ignore")
    start_str = 'Now sadder, that you come so'
    predict_with_greedy_search(start_str=start_str)
    print('--' * 100)
    predict_with_beam_search(start_str=start_str)
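
# Run directly (e.g. `python shakespeare_inference.py`) to compare the two
# decoding strategies on the same prompt.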