# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
from abc import ABC
from typing import Callable, Iterable, List

import numpy as np
import torch

from .file_utils import add_start_docstrings
from .utils.logging import get_logger


logger = get_logger(__name__)

LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.BertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
            using beam search, or log-softmax scores for each vocabulary token when using beam search.
        kwargs:
            Additional logits processor specific kwargs.

    Return:
        :obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.

"""

class LogitsProcessor(ABC):
    """Abstract base class for all logit processors that can be applied during generation."""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Torch method for processing logits."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

class LogitsWarper(ABC):
    """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Torch method for warping logits."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

class LogitsProcessorList(list):
    """
    This class can be used to create a list of :class:`~transformers.LogitsProcessor` or
    :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from
    list and adds a specific :obj:`__call__` method to apply each :class:`~transformers.LogitsProcessor` or
    :class:`~transformers.LogitsWarper` to the inputs.
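
    Example (an illustrative sketch; the tensors below are toy values)::

        import torch

        processors = LogitsProcessorList(
            [
                MinLengthLogitsProcessor(min_length=5, eos_token_id=2),
                TemperatureLogitsWarper(temperature=0.7),
            ]
        )
        input_ids = torch.tensor([[0, 1]])  # two tokens generated so far
        scores = torch.randn(1, 10)  # dummy logits over a 10-token vocabulary
        scores = processors(input_ids, scores)  # each processor is applied in order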
""" | |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor: | |
for processor in self: | |
function_args = inspect.signature(processor.__call__).parameters | |
if len(function_args) > 2: | |
assert all( | |
arg in kwargs for arg in list(function_args.keys())[2:] | |
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor." | |
scores = processor(input_ids, scores, **kwargs) | |
else: | |
scores = processor(input_ids, scores) | |
return scores | |

class MinLengthLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.

    Args:
        min_length (:obj:`int`):
            The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.
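
    Example (an illustrative sketch with toy tensors)::

        import torch

        processor = MinLengthLogitsProcessor(min_length=5, eos_token_id=2)
        input_ids = torch.tensor([[0, 1, 3]])  # only 3 tokens generated so far
        scores = torch.zeros(1, 10)  # dummy logits over a 10-token vocabulary
        scores = processor(input_ids, scores)
        assert scores[0, 2] == -float("inf")  # EOS stays masked until min_length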
""" | |
def __init__(self, min_length: int, eos_token_id: int): | |
if not isinstance(min_length, int) or min_length < 0: | |
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") | |
if not isinstance(eos_token_id, int) or eos_token_id < 0: | |
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") | |
self.min_length = min_length | |
self.eos_token_id = eos_token_id | |

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        cur_len = input_ids.shape[-1]
        if cur_len < self.min_length:
            scores[:, self.eos_token_id] = -float("inf")
        return scores

class TemperatureLogitsWarper(LogitsWarper):
    r"""
    :class:`transformers.LogitsWarper` for temperature (exponential scaling of the output probability distribution).

    Args:
        temperature (:obj:`float`):
            The value used to modulate the logits distribution.
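
    Example (an illustrative sketch with toy values; input_ids is unused by this warper)::

        import torch

        warper = TemperatureLogitsWarper(temperature=0.5)
        scores = torch.tensor([[1.0, 2.0]])
        scores = warper(torch.tensor([[0]]), scores)
        # dividing by 0.5 doubles the logits, sharpening the softmax distribution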
""" | |
def __init__(self, temperature: float): | |
if not isinstance(temperature, float) or not (temperature > 0): | |
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}") | |
self.temperature = temperature | |
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: | |
scores = scores / self.temperature | |
return scores | |

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.

    Args:
        penalty (:obj:`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
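
    Example (an illustrative sketch with toy tensors)::

        import torch

        processor = RepetitionPenaltyLogitsProcessor(penalty=2.0)
        input_ids = torch.tensor([[0, 1]])  # tokens 0 and 1 were already generated
        scores = torch.tensor([[2.0, -2.0, 2.0]])
        scores = processor(input_ids, scores)
        # token 0: 2.0 / 2.0 = 1.0; token 1: -2.0 * 2.0 = -4.0; token 2 is untouched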
""" | |
def __init__(self, penalty: float): | |
if not isinstance(penalty, float) or not (penalty > 0): | |
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") | |
self.penalty = penalty | |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: | |
score = torch.gather(scores, 1, input_ids) | |
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability | |
score = torch.where(score < 0, score * self.penalty, score / self.penalty) | |
scores.scatter_(1, input_ids, score) | |
return scores | |

class TopPLogitsWarper(LogitsWarper):
    """
    :class:`transformers.LogitsWarper` that performs top-p filtering, i.e. restricting generation to the smallest set
    of most probable tokens whose cumulative probability is at least :obj:`top_p`.

    Args:
        top_p (:obj:`float`):
            If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
            kept for generation.
        filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens that cannot be filtered.
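
    Example (an illustrative sketch with toy probabilities; input_ids is unused by this warper)::

        import torch

        warper = TopPLogitsWarper(top_p=0.9)
        scores = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
        scores = warper(torch.tensor([[0]]), scores)
        # tokens are kept until the cumulative probability first exceeds top_p:
        # 0.5 + 0.3 + 0.15 = 0.95 >= 0.9, so only the 0.05 token is set to -inf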
""" | |
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): | |
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0): | |
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") | |
self.top_p = top_p | |
self.filter_value = filter_value | |
self.min_tokens_to_keep = min_tokens_to_keep | |

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with cumulative probability above the threshold (tokens with 0 probability are kept)
        sorted_indices_to_remove = cumulative_probs > self.top_p
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores

class TopKLogitsWarper(LogitsWarper):
    r"""
    :class:`transformers.LogitsWarper` that performs top-k filtering, i.e. restricting to the k highest probability
    elements.

    Args:
        top_k (:obj:`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens that cannot be filtered.
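
    Example (an illustrative sketch with toy logits; input_ids is unused by this warper)::

        import torch

        warper = TopKLogitsWarper(top_k=2)
        scores = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
        scores = warper(torch.tensor([[0]]), scores)
        # only the two highest logits (3.0 and 2.0) survive; the rest become -inf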
""" | |
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): | |
if not isinstance(top_k, int) or top_k <= 0: | |
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") | |
self.top_k = top_k | |
self.filter_value = filter_value | |
self.min_tokens_to_keep = min_tokens_to_keep | |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: | |
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check | |
# Remove all tokens with a probability less than the last token of the top-k | |
indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] | |
scores = scores.masked_fill(indices_to_remove, self.filter_value) | |
return scores | |

def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    """For each hypothesis, map every (ngram_size - 1)-token prefix to the list of tokens that followed it."""
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams

def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    # Before decoding the next token, prevent decoding of ngrams that have already appeared
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])

def _calc_banned_ngram_tokens(
    ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]

    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)

    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens

class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
    <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__.

    Args:
        ngram_size (:obj:`int`):
            All ngrams of size :obj:`ngram_size` can only occur once.
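
    Example (an illustrative sketch with toy tensors)::

        import torch

        processor = NoRepeatNGramLogitsProcessor(ngram_size=2)
        # the bigram (1, 2) already occurred and the sequence currently ends in 1,
        # so token 2 is banned for the next step
        input_ids = torch.tensor([[1, 2, 3, 1]])
        scores = torch.zeros(1, 5)
        scores = processor(input_ids, scores)
        assert scores[0, 2] == -float("inf")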
""" | |
def __init__(self, ngram_size: int): | |
if not isinstance(ngram_size, int) or ngram_size <= 0: | |
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") | |
self.ngram_size = ngram_size | |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: | |
num_batch_hypotheses = scores.shape[0] | |
cur_len = input_ids.shape[-1] | |
banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len) | |
for i, banned_tokens in enumerate(banned_batch_tokens): | |
scores[i, banned_tokens] = -float("inf") | |
return scores | |

class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that enforces no repetition of encoder input ids n-grams for the decoder
    ids. See `ParlAI
    <https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/torch_generator_agent.py#L1350>`__.

    Args:
        encoder_ngram_size (:obj:`int`):
            All ngrams of size :obj:`encoder_ngram_size` that occur in the encoder input ids cannot be repeated in
            the decoder ids.
        encoder_input_ids (:obj:`torch.LongTensor`):
            The encoder_input_ids that should not be repeated within the decoder ids.
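
    Example (an illustrative sketch with toy tensors)::

        import torch

        # forbid the decoder from repeating any bigram of the encoder input
        encoder_input_ids = torch.tensor([[7, 8, 9]])
        processor = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids)
        input_ids = torch.tensor([[0, 7]])  # decoder output currently ends in 7
        scores = torch.zeros(1, 10)
        scores = processor(input_ids, scores)
        assert scores[0, 8] == -float("inf")  # (7, 8) appears in the encoder input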
""" | |
def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor): | |
if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0: | |
raise ValueError( | |
f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}" | |
) | |
self.ngram_size = encoder_ngram_size | |
if len(encoder_input_ids.shape) == 1: | |
encoder_input_ids = encoder_input_ids.unsqueeze(0) | |
self.batch_size = encoder_input_ids.shape[0] | |
self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size) | |

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # B x num_beams
        num_hypos = scores.shape[0]
        num_beams = num_hypos // self.batch_size
        cur_len = input_ids.shape[-1]
        banned_batch_tokens = [
            _get_generated_ngrams(
                self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len
            )
            for hypo_idx in range(num_hypos)
        ]

        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")

        return scores

class NoBadWordsLogitsProcessor(LogitsProcessor):
    """
    :class:`transformers.LogitsProcessor` that enforces that specified sequences will never be sampled.

    Args:
        bad_words_ids (:obj:`List[List[int]]`):
            List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
            that should not appear in the generated text, use :obj:`tokenizer(bad_word,
            add_prefix_space=True).input_ids`.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.
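
    Example (an illustrative sketch with toy tensors)::

        import torch

        # ban the single token 4 and the two-token sequence [3, 5]
        processor = NoBadWordsLogitsProcessor(bad_words_ids=[[4], [3, 5]], eos_token_id=0)
        input_ids = torch.tensor([[1, 3]])  # the last generated token is 3
        scores = torch.zeros(1, 6)
        scores = processor(input_ids, scores)
        # token 4 is always banned; token 5 is banned because the prefix [3] matches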
""" | |
def __init__(self, bad_words_ids: Iterable[Iterable[int]], eos_token_id: int): | |
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0: | |
raise ValueError(f"`bad_words_ids` has to be a non-emtpy list, but is {bad_words_ids}.") | |
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids): | |
raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.") | |
if any( | |
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids) | |
for bad_word_ids in bad_words_ids | |
): | |
raise ValueError( | |
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}." | |
) | |
self.bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) | |
for banned_token_seq in self.bad_words_ids: | |
assert len(banned_token_seq) > 0, f"Banned words token sequences {bad_words_ids} cannot have an empty list" | |

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        banned_tokens = self._calc_banned_bad_words_ids(input_ids)
        scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
        return scores

    def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
        if len(tokens) == 0:
            # if the bad-word sequence is a single token, its prefix is empty, so its last token is always banned
            return True
        elif len(tokens) > len(prev_tokens):
            # if the bad-word tokens are longer than the previous input_ids, they can't match
            return False
        elif prev_tokens[-len(tokens) :].tolist() == tokens:
            # if tokens match
            return True
        else:
            return False

    def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
        banned_tokens = []
        for prev_input_ids_slice in prev_input_ids:
            banned_tokens_slice = []
            for banned_token_seq in self.bad_words_ids:
                if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
                    # if tokens do not match continue
                    continue

                banned_tokens_slice.append(banned_token_seq[-1])

            banned_tokens.append(banned_tokens_slice)

        return banned_tokens

    def _set_scores_to_inf_for_banned_tokens(
        self, scores: torch.Tensor, banned_tokens: List[List[int]]
    ) -> torch.Tensor:
        """
        Returns the scores with the banned token positions set to `-inf`. :obj:`banned_tokens` is expected to be a
        list (one entry per batch element) of lists of banned token ids.

        Args:
            scores: logits distribution of shape (batch size, vocabulary size)
            banned_tokens: list of lists of tokens to ban, of length (batch_size)
        """
        banned_mask_list = []
        for idx, batch_banned_tokens in enumerate(banned_tokens):
            for token in batch_banned_tokens:
                # Eliminates invalid bad word IDs that are over the vocabulary size.
                if token < scores.shape[1]:
                    banned_mask_list.append([idx, token])
                else:
                    logger.error(
                        f"An invalid bad word ID is defined: {token}. This ID is not contained in the "
                        f"vocabulary, and is therefore ignored."
                    )
        if not banned_mask_list:
            return scores

        banned_mask = torch.LongTensor(banned_mask_list)
        indices = torch.ones(len(banned_mask))
        # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to a dense
        # tensor generates:
        # [ 0 1 1 ]
        # [ 0 0 0 ]
        # [ 1 0 0 ]

        banned_mask = (
            torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
        )
        scores = scores.masked_fill(banned_mask, -float("inf"))
        return scores

class PrefixConstrainedLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that enforces constrained generation and is useful for prefix-conditioned
    constrained generation. See `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__ for more
    information.

    Args:
        prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`):
            This function constrains the beam search to allowed tokens only at each step. It takes 2 arguments: the
            batch ID :obj:`batch_id` and the previously generated tokens :obj:`input_ids`. It has to return a list
            with the allowed tokens for the next generation step, conditioned on :obj:`batch_id` and :obj:`input_ids`.
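
    Example (an illustrative sketch with a toy constraint function)::

        import torch

        def allowed_fn(batch_id, sent):
            # only ever allow tokens 0 and 1, regardless of the prefix
            return [0, 1]

        processor = PrefixConstrainedLogitsProcessor(allowed_fn, num_beams=1)
        input_ids = torch.tensor([[2]])
        scores = torch.zeros(1, 4)
        scores = processor(input_ids, scores)
        # tokens 2 and 3 get -inf; tokens 0 and 1 keep their original scores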
""" | |
def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int): | |
self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn | |
self._num_beams = num_beams | |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: | |
mask = torch.full_like(scores, -math.inf) | |
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])): | |
for beam_id, sent in enumerate(beam_sent): | |
mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0 | |
return scores + mask | |

class HammingDiversityLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that enforces diverse beam search. Note that this logits processor is only
    effective for :meth:`transformers.PreTrainedModel.group_beam_search`. See `Diverse Beam Search: Decoding Diverse
    Solutions from Neural Sequence Models <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.

    Args:
        diversity_penalty (:obj:`float`):
            This value is subtracted from a beam's score if it generates a token that any beam from another group
            generated at the same time step. Note that :obj:`diversity_penalty` is only effective if ``group beam
            search`` is enabled.
        num_beams (:obj:`int`):
            Number of beams used for group beam search. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for
            more details.
        num_beam_groups (:obj:`int`):
            Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
            beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
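
    Example (an illustrative sketch; shapes follow group beam search with a batch of one)::

        import torch

        processor = HammingDiversityLogitsProcessor(diversity_penalty=1.0, num_beams=4, num_beam_groups=2)
        # the first group (beams 0 and 1) already picked tokens 3 and 5 at this step
        current_tokens = torch.tensor([3, 5, 0, 0])
        scores = torch.zeros(2, 10)  # logits for the second group's two beams
        input_ids = torch.tensor([[0], [0]])  # unused by this processor
        scores = processor(input_ids, scores, current_tokens=current_tokens, beam_group_idx=1)
        # tokens 3 and 5 are each penalized by diversity_penalty for both beams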
""" | |
def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int): | |
if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0): | |
raise ValueError("`diversity_penalty` should be a float strictly larger than 0.") | |
self._diversity_penalty = diversity_penalty | |
if not isinstance(num_beams, int) or num_beams < 2: | |
raise ValueError("`num_beams` should be an integer strictly larger than 1.") | |
self._num_beams = num_beams | |
if not isinstance(num_beam_groups, int) or num_beam_groups < 2: | |
raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.") | |
if num_beam_groups > num_beams: | |
raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.") | |
self._num_sub_beams = num_beams // num_beam_groups | |

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        current_tokens: torch.LongTensor,
        beam_group_idx: int,
    ) -> torch.FloatTensor:
        # hamming diversity: penalise using the same token in the current group which was used in previous groups at
        # the same time step
        batch_size = current_tokens.shape[0] // self._num_beams
        group_start_idx = beam_group_idx * self._num_sub_beams
        group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
        group_size = group_end_idx - group_start_idx
        vocab_size = scores.shape[-1]

        if group_start_idx == 0:
            return scores

        for batch_idx in range(batch_size):
            # predicted tokens of last time step of previous groups
            previous_group_tokens = current_tokens[
                batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
            ]
            token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
            scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency

        return scores

class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
    r"""
    :class:`~transformers.LogitsProcessor` that enforces the specified token as the first generated token.

    Args:
        bos_token_id (:obj:`int`):
            The id of the token to force as the first generated token.
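
    Example (an illustrative sketch with toy tensors)::

        import torch

        processor = ForcedBOSTokenLogitsProcessor(bos_token_id=0)
        input_ids = torch.tensor([[5]])  # only the decoder start token so far
        scores = torch.randn(1, 4)
        scores = processor(input_ids, scores)
        # every token except bos_token_id is masked, so token 0 is forced next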
""" | |
def __init__(self, bos_token_id: int): | |
self.bos_token_id = bos_token_id | |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: | |
cur_len = input_ids.shape[-1] | |
if cur_len == 1: | |
num_tokens = scores.shape[1] | |
scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf") | |
scores[:, self.bos_token_id] = 0 | |
return scores | |

class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
    r"""
    :class:`~transformers.LogitsProcessor` that enforces the specified token as the last generated token when
    :obj:`max_length` is reached.

    Args:
        max_length (:obj:`int`):
            The maximum length of the sequence to be generated.
        eos_token_id (:obj:`int`):
            The id of the token to force as the last generated token when :obj:`max_length` is reached.
    """

    def __init__(self, max_length: int, eos_token_id: int):
        self.max_length = max_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        cur_len = input_ids.shape[-1]
        if cur_len == self.max_length - 1:
            num_tokens = scores.shape[1]
            scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
            scores[:, self.eos_token_id] = 0
        return scores

class InfNanRemoveLogitsProcessor(LogitsProcessor):
    r"""
    :class:`~transformers.LogitsProcessor` that removes all :obj:`nan` and :obj:`inf` values to prevent the
    generation method from failing. Note that this logits processor should only be used if necessary, since it can
    slow down generation.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # set all nan values to 0.0 (nan != nan, so this comparison selects exactly the nan entries)
        scores[scores != scores] = 0.0

        # set all +inf values to the max possible finite value for the dtype
        scores[scores == float("inf")] = torch.finfo(scores.dtype).max

        return scores