import inspect |
|
import math |
|
from abc import ABC |
|
from typing import Callable, Iterable, List |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from .file_utils import add_start_docstrings |
|
from .utils.logging import get_logger |
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): |
|
Indices of input sequence tokens in the vocabulary. |
|
|
|
Indices can be obtained using :class:`~transformers.BertTokenizer`. See |
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for |
|
details. |
|
|
|
`What are input IDs? <../glossary.html#input-ids>`__ |
|
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`): |
|
            Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
            using beam search or log softmax for each vocabulary token when using beam search.
|
kwargs: |
|
Additional logits processor specific kwargs. |
|
|
|
Return: |
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores. |
|
|
|
""" |
|
|
|
|
|
class LogitsProcessor(ABC): |
|
"""Abstract base class for all logit processors that can be applied during generation.""" |
|
|
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) |
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: |
|
"""Torch method for processing logits.""" |
|
raise NotImplementedError( |
|
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." |
|
) |
|
|
|
|
|
class LogitsWarper(ABC): |
|
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" |
|
|
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) |
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: |
|
"""Torch method for warping logits.""" |
|
raise NotImplementedError( |
|
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." |
|
) |
|
|
|
|
|
class LogitsProcessorList(list): |
|
""" |
|
This class can be used to create a list of :class:`~transformers.LogitsProcessor` or |
|
:class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from |
|
list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or |
|
:class:`~transformers.LogitsWarper` to the inputs. |
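
    Example (a minimal sketch on toy tensors rather than a full generation setup; both processors used here are
    defined further down in this module)::

        >>> import torch
        >>> from transformers import LogitsProcessorList, MinLengthLogitsProcessor, TemperatureLogitsWarper

        >>> processors = LogitsProcessorList(
        ...     [MinLengthLogitsProcessor(min_length=5, eos_token_id=0), TemperatureLogitsWarper(temperature=0.7)]
        ... )
        >>> input_ids = torch.tensor([[2, 3]])  # 2 tokens generated so far
        >>> scores = torch.randn(1, 8)  # toy vocabulary of size 8
        >>> scores = processors(input_ids, scores)  # EOS is masked to -inf, then all logits are divided by 0.7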
|
""" |
|
|
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) |
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor: |
|
for processor in self: |
|
function_args = inspect.signature(processor.__call__).parameters |
|
if len(function_args) > 2: |
|
assert all( |
|
arg in kwargs for arg in list(function_args.keys())[2:] |
|
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor." |
|
scores = processor(input_ids, scores, **kwargs) |
|
else: |
|
scores = processor(input_ids, scores) |
|
return scores |
|
|
|
|
|
class MinLengthLogitsProcessor(LogitsProcessor): |
|
r""" |
|
:class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0. |
|
|
|
Args: |
|
min_length (:obj:`int`): |
|
The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`. |
|
eos_token_id (:obj:`int`): |
|
The id of the `end-of-sequence` token. |
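
    Example (a minimal sketch on toy tensors, not a full generation setup)::

        >>> import torch
        >>> from transformers import MinLengthLogitsProcessor

        >>> processor = MinLengthLogitsProcessor(min_length=5, eos_token_id=2)
        >>> input_ids = torch.tensor([[0, 4, 7]])  # only 3 tokens generated so far
        >>> scores = torch.zeros(1, 10)  # toy vocabulary of size 10
        >>> scores = processor(input_ids, scores)
        >>> bool(scores[0, 2] == -float("inf"))  # EOS can no longer be sampled at this step
        True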
|
""" |
|
|
|
def __init__(self, min_length: int, eos_token_id: int): |
|
        if not isinstance(min_length, int) or min_length < 0:
            raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")

        if not isinstance(eos_token_id, int) or eos_token_id < 0:
            raise ValueError(f"`eos_token_id` has to be a non-negative integer, but is {eos_token_id}")
|
|
|
self.min_length = min_length |
|
self.eos_token_id = eos_token_id |
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: |
|
cur_len = input_ids.shape[-1] |
|
if cur_len < self.min_length: |
|
scores[:, self.eos_token_id] = -float("inf") |
|
return scores |
|
|
|
|
|
class TemperatureLogitsWarper(LogitsWarper): |
|
r""" |
|
:class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution). |
|
|
|
Args: |
|
temperature (:obj:`float`): |
|
            The value used to modulate the logits distribution.
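
    Example (a minimal sketch on toy logits; a temperature below 1.0 sharpens the distribution, above 1.0 flattens
    it)::

        >>> import torch
        >>> from transformers import TemperatureLogitsWarper

        >>> warper = TemperatureLogitsWarper(temperature=0.5)
        >>> input_ids = torch.tensor([[0]])  # not used by this warper
        >>> warper(input_ids, torch.tensor([[1.0, 2.0, 4.0]]))
        tensor([[2., 4., 8.]])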
|
""" |
|
|
|
def __init__(self, temperature: float): |
|
if not isinstance(temperature, float) or not (temperature > 0): |
|
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}") |
|
|
|
self.temperature = temperature |
|
|
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: |
|
scores = scores / self.temperature |
|
return scores |
|
|
|
|
|
class RepetitionPenaltyLogitsProcessor(LogitsProcessor): |
|
r""" |
|
:class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences. |
|
|
|
Args: |
|
        penalty (:obj:`float`):
|
The parameter for repetition penalty. 1.0 means no penalty. See `this paper |
|
<https://arxiv.org/pdf/1909.05858.pdf>`__ for more details. |
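
    Example (a minimal sketch on toy logits with hand-picked token ids)::

        >>> import torch
        >>> from transformers import RepetitionPenaltyLogitsProcessor

        >>> processor = RepetitionPenaltyLogitsProcessor(penalty=2.0)
        >>> input_ids = torch.tensor([[0, 1]])  # tokens 0 and 1 have already been generated
        >>> scores = torch.tensor([[2.0, -2.0, 2.0]])
        >>> processor(input_ids, scores)  # positive scores are divided by the penalty, negative ones multiplied
        tensor([[ 1., -4.,  2.]])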
|
""" |
|
|
|
def __init__(self, penalty: float): |
|
if not isinstance(penalty, float) or not (penalty > 0): |
|
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") |
|
|
|
self.penalty = penalty |
|
|
|
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        scores.scatter_(1, input_ids, score)
        return scores
|
|
|
|
|
class TopPLogitsWarper(LogitsWarper): |
|
""" |
|
    :class:`transformers.LogitsWarper` that performs top-p filtering, i.e. restricting to the smallest set of most
    probable tokens whose cumulative probability adds up to :obj:`top_p` or higher.
|
|
|
Args: |
|
top_p (:obj:`float`): |
|
If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are |
|
kept for generation. |
|
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`): |
|
All filtered values will be set to this float value. |
|
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1): |
|
Minimum number of tokens that cannot be filtered. |
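
    Example (a minimal sketch on a toy probability distribution passed in as log-probabilities)::

        >>> import torch
        >>> from transformers import TopPLogitsWarper

        >>> warper = TopPLogitsWarper(top_p=0.5)
        >>> input_ids = torch.tensor([[0]])  # not used by this warper
        >>> scores = torch.log(torch.tensor([[0.1, 0.2, 0.3, 0.4]]))  # toy distribution over 4 tokens
        >>> filtered = warper(input_ids, scores)  # only the two most probable tokens (0.4 and 0.3) survive
        >>> bool(torch.isinf(filtered[0, 0]) and torch.isinf(filtered[0, 1]))
        True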
|
""" |
|
|
|
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): |
|
        if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
            raise ValueError(f"`top_p` has to be a float between 0 and 1, but is {top_p}")
|
|
|
self.top_p = top_p |
|
self.filter_value = filter_value |
|
self.min_tokens_to_keep = min_tokens_to_keep |
|
|
|
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with a cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > self.top_p
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (use min_tokens_to_keep - 1 because the first token is re-added below)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to the original (unsorted) indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
|
|
|
|
|
class TopKLogitsWarper(LogitsWarper): |
|
r""" |
|
:class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements. |
|
|
|
Args: |
|
top_k (:obj:`int`): |
|
The number of highest probability vocabulary tokens to keep for top-k-filtering. |
|
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`): |
|
All filtered values will be set to this float value. |
|
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1): |
|
Minimum number of tokens that cannot be filtered. |
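
    Example (a minimal sketch on toy logits)::

        >>> import torch
        >>> from transformers import TopKLogitsWarper

        >>> warper = TopKLogitsWarper(top_k=2)
        >>> input_ids = torch.tensor([[0]])  # not used by this warper
        >>> filtered = warper(input_ids, torch.tensor([[1.0, 4.0, 3.0, 2.0]]))  # only 4.0 and 3.0 are kept
        >>> int(torch.isinf(filtered).sum())
        2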
|
""" |
|
|
|
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): |
|
if not isinstance(top_k, int) or top_k <= 0: |
|
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") |
|
|
|
self.top_k = top_k |
|
self.filter_value = filter_value |
|
self.min_tokens_to_keep = min_tokens_to_keep |
|
|
|
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Safety check: never keep fewer than min_tokens_to_keep tokens and never more than the vocabulary size
        top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1))
        # Remove all tokens with a score less than the last token of the top-k
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
|
|
|
|
|
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int): |
|
generated_ngrams = [{} for _ in range(num_hypos)] |
|
for idx in range(num_hypos): |
|
gen_tokens = prev_input_ids[idx].tolist() |
|
generated_ngram = generated_ngrams[idx] |
|
for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]): |
|
prev_ngram_tuple = tuple(ngram[:-1]) |
|
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] |
|
return generated_ngrams |
|
|
|
|
|
def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    # ban any token that would turn the last (ngram_size - 1) tokens into an n-gram already present in banned_ngrams
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])
|
|
|
|
|
def _calc_banned_ngram_tokens( |
|
ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int |
|
) -> List[Iterable[int]]: |
|
"""Copied from fairseq for no_repeat_ngram in beam_search""" |
|
    if cur_len + 1 < ngram_size:
        # no banned tokens if fewer than ngram_size tokens have been generated so far
        return [[] for _ in range(num_hypos)]
|
|
|
generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos) |
|
|
|
banned_tokens = [ |
|
_get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len) |
|
for hypo_idx in range(num_hypos) |
|
] |
|
return banned_tokens |
|
|
|
|
|
class NoRepeatNGramLogitsProcessor(LogitsProcessor): |
|
r""" |
|
:class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq |
|
<https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__. |
|
|
|
Args: |
|
ngram_size (:obj:`int`): |
|
All ngrams of size :obj:`ngram_size` can only occur once. |
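
    Example (a minimal sketch on toy tensors; with :obj:`ngram_size=2` a bigram that already occurred cannot be
    generated a second time)::

        >>> import torch
        >>> from transformers import NoRepeatNGramLogitsProcessor

        >>> processor = NoRepeatNGramLogitsProcessor(ngram_size=2)
        >>> input_ids = torch.tensor([[1, 2, 3, 1]])  # contains the bigram (1, 2) and currently ends with 1
        >>> scores = torch.zeros(1, 5)  # toy vocabulary of size 5
        >>> scores = processor(input_ids, scores)
        >>> bool(scores[0, 2] == -float("inf"))  # generating 2 now would repeat the bigram (1, 2)
        True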
|
""" |
|
|
|
def __init__(self, ngram_size: int): |
|
if not isinstance(ngram_size, int) or ngram_size <= 0: |
|
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") |
|
self.ngram_size = ngram_size |
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: |
|
num_batch_hypotheses = scores.shape[0] |
|
cur_len = input_ids.shape[-1] |
|
banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len) |
|
|
|
for i, banned_tokens in enumerate(banned_batch_tokens): |
|
scores[i, banned_tokens] = -float("inf") |
|
|
|
return scores |
|
|
|
|
|
class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor): |
|
r""" |
|
:class:`transformers.LogitsProcessor` that enforces no repetition of encoder input ids n-grams for the decoder ids. |
|
See `ParlAI <https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/torch_generator_agent.py#L1350>`__. |
|
|
|
Args: |
|
        encoder_ngram_size (:obj:`int`):
            All ngrams of size :obj:`encoder_ngram_size` that occur in the encoder input ids cannot be repeated in
            the decoder ids.
        encoder_input_ids (:obj:`torch.LongTensor`):
            The encoder input ids whose ngrams should not be repeated within the decoder ids.
|
""" |
|
|
|
def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor): |
|
if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0: |
|
raise ValueError( |
|
f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}" |
|
) |
|
self.ngram_size = encoder_ngram_size |
|
if len(encoder_input_ids.shape) == 1: |
|
encoder_input_ids = encoder_input_ids.unsqueeze(0) |
|
self.batch_size = encoder_input_ids.shape[0] |
|
self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size) |
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: |
|
|
|
num_hypos = scores.shape[0] |
|
num_beams = num_hypos // self.batch_size |
|
cur_len = input_ids.shape[-1] |
|
banned_batch_tokens = [ |
|
_get_generated_ngrams( |
|
self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len |
|
) |
|
for hypo_idx in range(num_hypos) |
|
] |
|
|
|
for i, banned_tokens in enumerate(banned_batch_tokens): |
|
scores[i, banned_tokens] = -float("inf") |
|
|
|
return scores |
|
|
|
|
|
class NoBadWordsLogitsProcessor(LogitsProcessor): |
|
""" |
|
:class:`transformers.LogitsProcessor` that enforces that specified sequences will never be sampled. |
|
|
|
Args: |
|
bad_words_ids (:obj:`List[List[int]]`): |
|
List of list of token ids that are not allowed to be generated. In order to get the tokens of the words |
|
that should not appear in the generated text, use :obj:`tokenizer(bad_word, |
|
add_prefix_space=True).input_ids`. |
|
eos_token_id (:obj:`int`): |
|
The id of the `end-of-sequence` token. |
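
    Example (a minimal sketch with hand-picked token ids instead of ids produced by a tokenizer)::

        >>> import torch
        >>> from transformers import NoBadWordsLogitsProcessor

        >>> processor = NoBadWordsLogitsProcessor(bad_words_ids=[[5], [3, 4]], eos_token_id=2)
        >>> input_ids = torch.tensor([[0, 3]])  # the last generated token is 3
        >>> scores = torch.zeros(1, 6)  # toy vocabulary of size 6
        >>> scores = processor(input_ids, scores)  # 5 is always banned; 4 would complete the bad word [3, 4]
        >>> torch.isinf(scores[0]).nonzero().flatten().tolist()
        [4, 5]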
|
""" |
|
|
|
def __init__(self, bad_words_ids: Iterable[Iterable[int]], eos_token_id: int): |
|
|
|
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0: |
|
            raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
|
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids): |
|
raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.") |
|
if any( |
|
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids) |
|
for bad_word_ids in bad_words_ids |
|
): |
|
            raise ValueError(
                f"Each list in `bad_words_ids` has to be a list of non-negative integers, but is {bad_words_ids}."
            )
|
|
|
self.bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) |
|
|
|
for banned_token_seq in self.bad_words_ids: |
|
assert len(banned_token_seq) > 0, f"Banned words token sequences {bad_words_ids} cannot have an empty list" |
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: |
|
banned_tokens = self._calc_banned_bad_words_ids(input_ids) |
|
scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens) |
|
|
|
return scores |
|
|
|
    def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
        if len(tokens) == 0:
            # if the bad word is a single token, it is always banned
            return True
        elif len(tokens) > len(prev_tokens):
            # if the bad word prefix is longer than the generated sequence, it cannot match
            return False
        elif prev_tokens[-len(tokens) :].tolist() == tokens:
            # the end of the generated sequence matches the bad word prefix
            return True
        else:
            return False
|
|
|
def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> Iterable[int]: |
|
banned_tokens = [] |
|
for prev_input_ids_slice in prev_input_ids: |
|
banned_tokens_slice = [] |
|
for banned_token_seq in self.bad_words_ids: |
|
if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False: |
|
|
|
continue |
|
|
|
banned_tokens_slice.append(banned_token_seq[-1]) |
|
|
|
banned_tokens.append(banned_tokens_slice) |
|
|
|
return banned_tokens |
|
|
|
    def _set_scores_to_inf_for_banned_tokens(
        self, scores: torch.Tensor, banned_tokens: List[List[int]]
    ) -> torch.Tensor:
        """
        Returns the scores with the banned token positions set to `-inf`.

        Args:
            scores: logits distribution of shape (batch size, vocabulary size)
            banned_tokens: list of length (batch_size), where each element is the list of token ids to ban for the
                corresponding batch entry
        """
|
banned_mask_list = [] |
|
        for idx, batch_banned_tokens in enumerate(banned_tokens):
            for token in batch_banned_tokens:
                # a token id is a valid column index into `scores` only if it is strictly smaller than the vocab size
                if token < scores.shape[1]:
                    banned_mask_list.append([idx, token])
                else:
                    logger.error(
                        f"An invalid bad word ID is defined: {token}. This ID is not contained in the "
                        "vocabulary, and is therefore ignored."
                    )
|
if not banned_mask_list: |
|
return scores |
|
|
|
        banned_mask = torch.LongTensor(banned_mask_list)
        indices = torch.ones(len(banned_mask))

        # A sparse tensor is built from the [batch index, token index] coordinates collected above, e.g. for the
        # coordinates [[0, 1], [0, 2], [2, 0]] and a 3 x 3 score matrix it densifies to:
        # [ 0  1  1 ]
        # [ 0  0  0 ]
        # [ 1  0  0 ]
        # Converting it to a boolean mask lets us ban every position with a single masked_fill.
        banned_mask = (
            torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
        )
        scores = scores.masked_fill(banned_mask, -float("inf"))
        return scores
|
|
|
|
|
class PrefixConstrainedLogitsProcessor(LogitsProcessor): |
|
r""" |
|
:class:`transformers.LogitsProcessor` that enforces constrained generation and is useful for prefix-conditioned |
|
constrained generation. See `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__ for more |
|
information. |
|
|
|
Args: |
|
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`): |
|
            This function constrains the beam search to allowed tokens only at each step. This function takes 2
|
arguments :obj:`inputs_ids` and the batch ID :obj:`batch_id`. It has to return a list with the allowed |
|
tokens for the next generation step conditioned on the previously generated tokens :obj:`inputs_ids` and |
|
the batch ID :obj:`batch_id`. |
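
    Example (a minimal sketch; ``allowed_tokens_fn`` below is a toy constraint that only allows token 1 after an odd
    token and token 0 otherwise)::

        >>> import torch
        >>> from transformers import PrefixConstrainedLogitsProcessor

        >>> def allowed_tokens_fn(batch_id, input_ids):
        ...     return [1] if int(input_ids[-1]) % 2 else [0]

        >>> processor = PrefixConstrainedLogitsProcessor(allowed_tokens_fn, num_beams=1)
        >>> input_ids = torch.tensor([[2, 3]])  # the last token is odd
        >>> scores = torch.zeros(1, 4)  # toy vocabulary of size 4
        >>> scores = processor(input_ids, scores)  # every token except 1 now has score -inf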
|
""" |
|
|
|
def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int): |
|
self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn |
|
self._num_beams = num_beams |
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: |
|
mask = torch.full_like(scores, -math.inf) |
|
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])): |
|
for beam_id, sent in enumerate(beam_sent): |
|
mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0 |
|
|
|
return scores + mask |
|
|
|
|
|
class HammingDiversityLogitsProcessor(LogitsProcessor): |
|
r""" |
|
:class:`transformers.LogitsProcessor` that enforces diverse beam search. Note that this logits processor is only |
|
effective for :meth:`transformers.PreTrainedModel.group_beam_search`. See `Diverse Beam Search: Decoding Diverse |
|
Solutions from Neural Sequence Models <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details. |
|
|
|
Args: |
|
diversity_penalty (:obj:`float`): |
|
            This value is subtracted from a beam's score if it generates a token that was already chosen by any beam
            from another group at a particular generation step. Note that :obj:`diversity_penalty` is only effective
            if ``group beam search`` is enabled.
|
num_beams (:obj:`int`): |
|
Number of beams used for group beam search. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for |
|
more details. |
|
num_beam_groups (:obj:`int`): |
|
Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of |
|
beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details. |
|
""" |
|
|
|
def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int): |
|
if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0): |
|
raise ValueError("`diversity_penalty` should be a float strictly larger than 0.") |
|
self._diversity_penalty = diversity_penalty |
|
if not isinstance(num_beams, int) or num_beams < 2: |
|
raise ValueError("`num_beams` should be an integer strictly larger than 1.") |
|
self._num_beams = num_beams |
|
if not isinstance(num_beam_groups, int) or num_beam_groups < 2: |
|
raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.") |
|
        if num_beam_groups > num_beams:
            raise ValueError("`num_beam_groups` has to be smaller than or equal to `num_beams`.")
|
self._num_sub_beams = num_beams // num_beam_groups |
|
|
|
def __call__( |
|
self, |
|
input_ids: torch.LongTensor, |
|
scores: torch.FloatTensor, |
|
current_tokens: torch.LongTensor, |
|
beam_group_idx: int, |
|
) -> torch.FloatTensor: |
|
|
|
|
|
batch_size = current_tokens.shape[0] // self._num_beams |
|
group_start_idx = beam_group_idx * self._num_sub_beams |
|
group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams) |
|
group_size = group_end_idx - group_start_idx |
|
vocab_size = scores.shape[-1] |
|
|
|
if group_start_idx == 0: |
|
return scores |
|
|
|
for batch_idx in range(batch_size): |
|
|
|
previous_group_tokens = current_tokens[ |
|
batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx |
|
] |
|
token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device) |
|
scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency |
|
|
|
return scores |
|
|
|
|
|
class ForcedBOSTokenLogitsProcessor(LogitsProcessor): |
|
r""" |
|
:class:`~transformers.LogitsProcessor` that enforces the specified token as the first generated token. |
|
|
|
Args: |
|
bos_token_id (:obj:`int`): |
|
The id of the token to force as the first generated token. |
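
    Example (a minimal sketch on toy tensors; right after the decoder start token, only :obj:`bos_token_id` can be
    generated)::

        >>> import torch
        >>> from transformers import ForcedBOSTokenLogitsProcessor

        >>> processor = ForcedBOSTokenLogitsProcessor(bos_token_id=3)
        >>> input_ids = torch.tensor([[0]])  # only the decoder start token has been generated
        >>> scores = processor(input_ids, torch.zeros(1, 5))  # every column except 3 is now -inf, column 3 is 0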
|
""" |
|
|
|
def __init__(self, bos_token_id: int): |
|
self.bos_token_id = bos_token_id |
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: |
|
cur_len = input_ids.shape[-1] |
|
if cur_len == 1: |
|
num_tokens = scores.shape[1] |
|
scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf") |
|
scores[:, self.bos_token_id] = 0 |
|
return scores |
|
|
|
|
|
class ForcedEOSTokenLogitsProcessor(LogitsProcessor): |
|
r""" |
|
:class:`~transformers.LogitsProcessor` that enforces the specified token as the last generated token when |
|
:obj:`max_length` is reached. |
|
|
|
Args: |
|
max_length (:obj:`int`): |
|
The maximum length of the sequence to be generated. |
|
eos_token_id (:obj:`int`): |
|
The id of the token to force as the last generated token when :obj:`max_length` is reached. |
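
    Example (a minimal sketch on toy tensors; one step before :obj:`max_length` only :obj:`eos_token_id` can be
    generated)::

        >>> import torch
        >>> from transformers import ForcedEOSTokenLogitsProcessor

        >>> processor = ForcedEOSTokenLogitsProcessor(max_length=5, eos_token_id=2)
        >>> input_ids = torch.tensor([[0, 8, 6, 7]])  # 4 tokens so far, so the next token is the last one allowed
        >>> scores = processor(input_ids, torch.zeros(1, 10))  # every column except 2 is now -inf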
|
""" |
|
|
|
def __init__(self, max_length: int, eos_token_id: int): |
|
self.max_length = max_length |
|
self.eos_token_id = eos_token_id |
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: |
|
cur_len = input_ids.shape[-1] |
|
if cur_len == self.max_length - 1: |
|
num_tokens = scores.shape[1] |
|
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf") |
|
scores[:, self.eos_token_id] = 0 |
|
return scores |
|
|
|
|
|
class InfNanRemoveLogitsProcessor(LogitsProcessor): |
|
r""" |
|
    :class:`~transformers.LogitsProcessor` that removes all :obj:`nan` and :obj:`inf` values in the scores to keep the
    generation method from failing. Note that this logits processor should only be used if necessary since it can slow
    down the generation method.
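
    Example (a minimal sketch on toy tensors)::

        >>> import torch
        >>> from transformers import InfNanRemoveLogitsProcessor

        >>> processor = InfNanRemoveLogitsProcessor()
        >>> scores = torch.tensor([[1.0, float("nan"), float("inf")]])
        >>> cleaned = processor(torch.tensor([[0]]), scores)  # nan becomes 0.0, +inf the largest finite float value
        >>> bool(torch.isfinite(cleaned).all())
        True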
|
""" |
|
|
|
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # set all nan values to 0.0 (nan != nan, so this selects exactly the nan entries)
        scores[scores != scores] = 0.0

        # set all +inf values to the largest finite value representable by the score dtype
        scores[scores == float("inf")] = torch.finfo(scores.dtype).max

        return scores
|
|