"""
Training loop helpers
"""

from typing import List, Optional, Union

import torch
import numpy as np

from transformers.tokenization_utils import PreTrainedTokenizer


def replace_padding_tokens(token_ids: Union[torch.Tensor, List[torch.Tensor]],
                           pad_token_id: int,
                           ignore_token_id: int = -100
                           ) -> Union[np.ndarray, List[np.ndarray]]:
    """
    Replace ignore_token_id tokens (e.g., the -100 used to mask labels
    out of the loss) with pad_token_id,
    e.g., for printing inputs during training
    """
    if isinstance(token_ids, list):
        # Convert each sequence in the list individually.
        return [np.where(t != ignore_token_id, t, pad_token_id) for t in token_ids]
    # Keep tokens as-is; substitute pad_token_id where the ignore value sits.
    return np.where(token_ids != ignore_token_id, token_ids, pad_token_id)
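
# A minimal usage sketch (hypothetical values, not from a real run): label
# tensors often mask non-target positions with ignore_token_id = -100, and
# swapping those for the pad token makes the sequence safe to decode.
#
#   labels = torch.tensor([5, 17, 42, -100, -100])
#   replace_padding_tokens(labels, pad_token_id=0)
#   # -> array([ 5, 17, 42,  0,  0])

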
def decode_samples(outputs: torch.Tensor,
                   targets: torch.Tensor,
                   tokenizer: PreTrainedTokenizer,
                   sample_idx: Optional[int] = None) -> None:
    """
    Print the first sample of a batch of targets and predictions
    for debugging
    """
    print('=' * 20)
    print(f'*** TARGETS (sample {sample_idx}) ***')
    target_text = tokenizer.decode(
        replace_padding_tokens(targets[0].cpu(), tokenizer.pad_token_id)
    )
    print(target_text)
    print('-' * 20)
    print(f'*** PREDICTIONS (sample {sample_idx}) ***')
    # argmax over the vocabulary dimension turns logits into token ids.
    pred_token_ids = outputs.argmax(dim=-1).cpu()
    pred_text = tokenizer.decode(
        replace_padding_tokens(pred_token_ids[0], tokenizer.pad_token_id)
    )
    print(pred_text)
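

# A minimal smoke-test sketch (fabricated shapes and random values, assuming
# the GPT-2 tokenizer is available; any PreTrainedTokenizer would do).
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('gpt2')
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token  # GPT-2 ships without a pad token

    batch_size, seq_len = 2, 8
    # Random logits and targets standing in for a real model's outputs.
    fake_outputs = torch.randn(batch_size, seq_len, tok.vocab_size)
    fake_targets = torch.randint(0, tok.vocab_size, (batch_size, seq_len))
    fake_targets[:, -3:] = -100  # mask the tail, as a loss function would see it

    decode_samples(fake_outputs, fake_targets, tok, sample_idx=0)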