"""
Training loop helpers
"""
from typing import List, Optional, Union

import torch
from transformers.tokenization_utils import PreTrainedTokenizer

def replace_padding_tokens(token_ids: Union[torch.Tensor, List[torch.Tensor]],
                           pad_token_id: int,
                           ignore_token_id: int = -100
                           ) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Replace ignore_token_id tokens with pad_token_id,
    e.g., for printing inputs during training.
    Accepts a single tensor or a list of tensors (e.g., a ragged batch).
    """
    if isinstance(token_ids, list):
        return [t.masked_fill(t == ignore_token_id, pad_token_id) for t in token_ids]
    return token_ids.masked_fill(token_ids == ignore_token_id, pad_token_id)
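
# Illustrative example (values assumed): label tensors produced by common
# data collators use -100 for positions excluded from the loss; this helper
# restores the pad id so the sequence can be decoded.
#   >>> labels = torch.tensor([15496, 995, -100, -100])
#   >>> replace_padding_tokens(labels, pad_token_id=0)
#   tensor([15496,   995,     0,     0])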

def decode_samples(outputs: torch.Tensor,
                   targets: torch.Tensor,
                   tokenizer: PreTrainedTokenizer,
                   sample_idx: Optional[int] = None) -> None:
    """
    Print the first element of a batch of samples for debugging.
    """
    print('=' * 20)
    print(f'*** TARGETS (sample {sample_idx}) ***')
    tokens = tokenizer.decode(
        replace_padding_tokens(targets[0], tokenizer.pad_token_id)
    )
    print(tokens)
    print('-' * 20)
    print(f'*** PREDICTIONS (sample {sample_idx}) ***')
    # argmax over the vocabulary dimension turns logits into token ids
    pred_ids = outputs.argmax(dim=-1).cpu()
    pred_tokens = tokenizer.decode(
        replace_padding_tokens(pred_ids[0], tokenizer.pad_token_id)
    )
    print(pred_tokens)
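

if __name__ == '__main__':
    # Minimal smoke test sketch (illustrative only; the tokenizer choice and
    # tensor shapes below are assumptions, not part of the training loop).
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token

    # Fake batch: (batch, seq_len, vocab) logits and (batch, seq_len) targets,
    # with -100 marking positions excluded from the loss.
    outputs = torch.randn(2, 8, tokenizer.vocab_size)
    targets = torch.randint(0, tokenizer.vocab_size, (2, 8))
    targets[:, -2:] = -100
    decode_samples(outputs, targets, tokenizer, sample_idx=0)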