"""
Training loop helpers
"""
from typing import Any, List, Optional, Union

import numpy as np
import torch

try:
    from transformers.tokenization_utils import PreTrainedTokenizer
except ImportError:
    # The tokenizer class is only used in annotations below; fall back so
    # these helpers stay importable where transformers is not installed.
    PreTrainedTokenizer = Any


def _to_numpy(token_ids: Any) -> np.ndarray:
    """
    Return token_ids as a NumPy array, moving torch tensors to host memory first
    """
    if isinstance(token_ids, torch.Tensor):
        # Tensor.numpy() is only valid for CPU tensors that do not require
        # grad, so detach and copy to host before converting.
        return token_ids.detach().cpu().numpy()
    return np.asarray(token_ids)


def replace_padding_tokens(token_ids: torch.Tensor,
                           pad_token_id: int,
                           ignore_token_id: int = -100,
                           ) -> Union[np.ndarray, List[Any]]:
    """
    Replace ignore_token_id tokens with pad_token_id,
    e.g., for printing inputs during training

    Args:
        token_ids: A tensor / array-like of token ids (on any device), or a
            list whose elements are token ids or tensors / arrays of them.
        pad_token_id: Value written wherever ignore_token_id occurs.
        ignore_token_id: Value to replace (defaults to -100).

    Returns:
        For a non-list input, a new np.ndarray with the same shape as the
        input. For a list input, a list with one entry per element: the
        replaced scalar for scalar elements, otherwise the first row
        (index 0 along the leading axis) of the replaced array.
    """
    def _replace(ids: Any) -> np.ndarray:
        ids = _to_numpy(ids)
        return np.where(ids != ignore_token_id, ids, pad_token_id)

    if isinstance(token_ids, list):
        replaced = [_replace(ids) for ids in token_ids]
        # NOTE(review): array elements are reduced to their first row, as
        # before (presumably each element is a (1, seq_len) batch -- confirm
        # with callers). Scalar elements used to raise IndexError on the
        # `[0]` lookup; they are now returned as plain Python scalars.
        return [r.item() if r.ndim == 0 else r[0] for r in replaced]
    return _replace(token_ids)


def decode_samples(outputs: torch.Tensor,
                   targets: torch.Tensor,
                   tokenizer: PreTrainedTokenizer,
                   sample_idx: Optional[int] = None) -> None:
    """
    Print first element of samples for debugging

    Args:
        outputs: Model outputs; predicted token ids are taken as the argmax
            over the last dimension (presumably (batch, seq_len, vocab)
            logits -- confirm with callers).
        targets: Target token ids; entries equal to -100 are swapped for
            tokenizer.pad_token_id before decoding.
        tokenizer: Object providing `pad_token_id` and `decode(ids)`.
        sample_idx: Optional label that is only interpolated into the
            printed headers.
    """
    print('=' * 20)
    print(f'*** TARGETS (sample {sample_idx})***')
    target_text = tokenizer.decode(
        replace_padding_tokens(targets[0], tokenizer.pad_token_id)
    )
    print(target_text)

    print('-' * 20)
    print(f'*** PREDICTIONS (sample {sample_idx}) ***')
    # Only the first sample is printed, so select it before the argmax rather
    # than reducing (and transferring) the whole batch.
    pred_ids = outputs[0].argmax(dim=-1).cpu()
    pred_text = tokenizer.decode(
        replace_padding_tokens(pred_ids, tokenizer.pad_token_id)
    )
    print(pred_text)