# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch import nn
from .file_utils import ModelOutput
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from .generation_stopping_criteria import (
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
)
from .utils import logging
logger = logging.get_logger(__name__)
@dataclass
class GreedySearchDecoderOnlyOutput(ModelOutput):
""" | |
Base class for outputs of decoder-only generation models using greedy search. | |
Args: | |
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or | |
shorter if all batches finished early due to the :obj:`eos_token_id`. | |
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): | |
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) | |
at each generation step. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`torch.FloatTensor` | |
with each tensor of shape :obj:`(batch_size, config.vocab_size)`). | |
attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`. | |
hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`. | |
""" | |
sequences: torch.LongTensor = None | |
scores: Optional[Tuple[torch.FloatTensor]] = None | |
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
@dataclass
class GreedySearchEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
weights of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the
decoder_hidden_states attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).
Args: | |
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or | |
shorter if all batches finished early due to the :obj:`eos_token_id`. | |
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): | |
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) | |
at each generation step. :obj:`(max_length-1,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor | |
of shape :obj:`(batch_size, config.vocab_size)`). | |
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer of the encoder) of shape :obj:`(batch_size,
num_heads, sequence_length, sequence_length)`.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) | |
of shape :obj:`(batch_size, sequence_length, hidden_size)`. | |
decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`. | |
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`. | |
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`. | |
""" | |
sequences: torch.LongTensor = None | |
scores: Optional[Tuple[torch.FloatTensor]] = None | |
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None | |
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None | |
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
@dataclass
class SampleDecoderOnlyOutput(ModelOutput):
""" | |
Base class for outputs of decoder-only generation models using sampling. | |
Args: | |
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`): | |
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or | |
shorter if all batches finished early due to the :obj:`eos_token_id`. | |
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): | |
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) | |
at each generation step. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`torch.FloatTensor` | |
with each tensor of shape :obj:`(batch_size*num_return_sequences, config.vocab_size)`). | |
attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(num_return_sequences*batch_size, num_heads, generated_length, | |
sequence_length)`. | |
hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(num_return_sequences*batch_size, generated_length, hidden_size)`. | |
""" | |
sequences: torch.LongTensor = None | |
scores: Optional[Tuple[torch.FloatTensor]] = None | |
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
@dataclass
class SampleEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of
the decoder (respectively the encoder) can be accessed via the decoder_attentions and the decoder_hidden_states
attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).
Args: | |
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`): | |
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or | |
shorter if all batches finished early due to the :obj:`eos_token_id`. | |
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): | |
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) | |
at each generation step. :obj:`(max_length-1,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor | |
of shape :obj:`(batch_size*num_return_sequences, config.vocab_size)`). | |
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer of the encoder) of shape
:obj:`(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) | |
of shape :obj:`(batch_size*num_return_sequences, sequence_length, hidden_size)`. | |
decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length, | |
sequence_length)`. | |
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`. | |
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`. | |
""" | |
sequences: torch.LongTensor = None | |
scores: Optional[Tuple[torch.FloatTensor]] = None | |
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None | |
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None | |
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
@dataclass
class BeamSearchDecoderOnlyOutput(ModelOutput):
""" | |
Base class for outputs of decoder-only generation models using beam search. | |
Args: | |
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`): | |
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or | |
shorter if all batches finished early due to the :obj:`eos_token_id`. | |
sequences_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): | |
Final beam scores of the generated ``sequences``. | |
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): | |
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log | |
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam | |
. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of | |
shape :obj:`(batch_size*num_beams*num_return_sequences, config.vocab_size)`). | |
attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, | |
sequence_length)`. | |
hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length, | |
hidden_size)`. | |
""" | |
sequences: torch.LongTensor = None | |
sequences_scores: Optional[torch.FloatTensor] = None | |
scores: Optional[Tuple[torch.FloatTensor]] = None | |
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
@dataclass
class BeamSearchEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the decoder_hidden_states
attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).
Args: | |
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`): | |
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or | |
shorter if all batches finished early due to the :obj:`eos_token_id`. | |
sequences_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): | |
Final beam scores of the generated ``sequences``. | |
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): | |
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log | |
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam | |
. :obj:`(max_length-1,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape | |
:obj:`(batch_size*num_beams, config.vocab_size)`). | |
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer of the encoder) of shape :obj:`(batch_size,
num_heads, sequence_length, sequence_length)`.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) | |
of shape :obj:`(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`. | |
decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads, | |
generated_length, sequence_length)`. | |
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`. | |
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length, | |
hidden_size)`. | |
""" | |
sequences: torch.LongTensor = None | |
sequences_scores: Optional[torch.FloatTensor] = None | |
scores: Optional[Tuple[torch.FloatTensor]] = None | |
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None | |
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None | |
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
@dataclass
class BeamSampleDecoderOnlyOutput(ModelOutput):
"""
Base class for outputs of decoder-only generation models using beam sampling.
Args: | |
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`): | |
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or | |
shorter if all batches finished early due to the :obj:`eos_token_id`. | |
sequences_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Final beam scores of the generated ``sequences``. | |
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): | |
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log | |
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam | |
. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of | |
shape :obj:`(batch_size*num_beams*num_return_sequences, config.vocab_size)`). | |
attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, | |
sequence_length)`. | |
hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`. | |
""" | |
sequences: torch.LongTensor = None | |
sequences_scores: Optional[torch.FloatTensor] = None | |
scores: Optional[Tuple[torch.FloatTensor]] = None | |
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
@dataclass
class BeamSampleEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention
weights of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the
decoder_hidden_states attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).
Args: | |
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size*num_beams, sequence_length)`): | |
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or | |
shorter if all batches finished early due to the :obj:`eos_token_id`. | |
sequences_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Final beam scores of the generated ``sequences``. | |
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): | |
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log | |
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam | |
. :obj:`(max_length-1,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape | |
:obj:`(batch_size*num_beams, config.vocab_size)`). | |
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer of the encoder) of shape :obj:`(batch_size,
num_heads, sequence_length, sequence_length)`.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) | |
of shape :obj:`(batch_size*num_beams, sequence_length, hidden_size)`. | |
decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, | |
sequence_length)`. | |
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`. | |
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`. | |
""" | |
sequences: torch.LongTensor = None | |
sequences_scores: Optional[torch.FloatTensor] = None | |
scores: Optional[Tuple[torch.FloatTensor]] = None | |
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None | |
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None | |
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | |
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] | |
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] | |
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] | |
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] | |
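# Illustrative usage sketch (not part of this module): the output classes above are what
# ``generate()`` returns when ``return_dict_in_generate=True``. The model name "gpt2" below is an
# assumption chosen only for the example.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     input_ids = tokenizer("Hello, world", return_tensors="pt").input_ids
#     out = model.generate(input_ids, max_length=20, return_dict_in_generate=True, output_scores=True)
#     # out is a GreedySearchDecoderOnlyOutput: out.sequences has shape (batch_size, sequence_length)
#     # and out.scores holds one (batch_size, vocab_size) tensor per generated token.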
class GenerationMixin: | |
""" | |
A class containing all of the functions supporting generation, to be used as a mixin in | |
:class:`~transformers.PreTrainedModel`. | |
""" | |
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: | |
""" | |
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the | |
generate method. | |
""" | |
return {"input_ids": input_ids} | |
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor: | |
""" | |
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in | |
the generate method. | |
""" | |
return logits | |
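# Illustrative sketch (an assumption, not taken from this file): a decoder-only subclass would
# typically override ``prepare_inputs_for_generation`` so that, once a cache is available, only the
# most recent token is fed back to the model, e.g.:
#
#     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
#         if past is not None:
#             input_ids = input_ids[:, -1:]  # with a cache, only the newest token is needed
#         return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs.get("use_cache")}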
def _prepare_input_ids_for_generation( | |
self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput] | |
) -> torch.LongTensor: | |
if self.config.is_encoder_decoder and encoder_outputs is not None: | |
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding | |
shape = encoder_outputs.last_hidden_state.size()[:-1] | |
return torch.ones(shape, dtype=torch.long, device=self.device) * -100 | |
if bos_token_id is None: | |
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") | |
return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id | |
def _prepare_attention_mask_for_generation( | |
self, input_ids: torch.Tensor, pad_token_id: int, eos_token_id: int | |
) -> torch.LongTensor: | |
is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids) | |
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( | |
(eos_token_id is not None) and (pad_token_id != eos_token_id) | |
) | |
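# Only derive the mask from `pad_token_id` when padding actually occurs in the inputs and the pad
# token cannot be confused with the EOS token; otherwise attend to every position.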
if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id: | |
return input_ids.ne(pad_token_id).long() | |
return input_ids.new_ones(input_ids.shape, dtype=torch.long) | |
def _prepare_encoder_decoder_kwargs_for_generation( | |
self, input_ids: torch.LongTensor, model_kwargs | |
) -> Dict[str, Any]: | |
if "encoder_outputs" not in model_kwargs: | |
# retrieve encoder hidden states | |
encoder = self.get_encoder() | |
encoder_kwargs = { | |
argument: value | |
for argument, value in model_kwargs.items() | |
if not (argument.startswith("decoder_") or argument.startswith("cross_attn")) | |
} | |
model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs) | |
return model_kwargs | |
def _prepare_decoder_input_ids_for_generation( | |
self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None | |
) -> torch.LongTensor: | |
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) | |
decoder_input_ids = ( | |
torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device) * decoder_start_token_id | |
) | |
return decoder_input_ids | |
def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int: | |
if pad_token_id is None and eos_token_id is not None: | |
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") | |
pad_token_id = eos_token_id | |
return pad_token_id | |
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: | |
decoder_start_token_id = ( | |
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id | |
) | |
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id | |
if decoder_start_token_id is not None: | |
return decoder_start_token_id | |
elif ( | |
hasattr(self.config, "decoder") | |
and hasattr(self.config.decoder, "decoder_start_token_id") | |
and self.config.decoder.decoder_start_token_id is not None | |
): | |
return self.config.decoder.decoder_start_token_id | |
elif bos_token_id is not None: | |
return bos_token_id | |
elif ( | |
hasattr(self.config, "decoder") | |
and hasattr(self.config.decoder, "bos_token_id") | |
and self.config.decoder.bos_token_id is not None | |
): | |
return self.config.decoder.bos_token_id | |
raise ValueError( | |
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." | |
) | |
@staticmethod
def _expand_inputs_for_generation(
input_ids: torch.LongTensor, | |
expand_size: int = 1, | |
is_encoder_decoder: bool = False, | |
attention_mask: torch.LongTensor = None, | |
encoder_outputs: ModelOutput = None, | |
**model_kwargs, | |
) -> Tuple[torch.LongTensor, Dict[str, Any]]: | |
expanded_return_idx = ( | |
torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) | |
) | |
input_ids = input_ids.index_select(0, expanded_return_idx) | |
if "token_type_ids" in model_kwargs: | |
token_type_ids = model_kwargs["token_type_ids"] | |
model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) | |
if attention_mask is not None: | |
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) | |
if is_encoder_decoder: | |
assert encoder_outputs is not None | |
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( | |
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) | |
) | |
model_kwargs["encoder_outputs"] = encoder_outputs | |
return input_ids, model_kwargs | |
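# Note: with `expand_size=3`, batch indices [0, 1] are expanded to [0, 0, 0, 1, 1, 1], so all
# copies (beams or return sequences) of the same example stay adjacent in the batch dimension.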
@staticmethod
def _update_model_kwargs_for_generation(
outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False | |
) -> Dict[str, Any]: | |
# update past | |
if "past_key_values" in outputs: | |
model_kwargs["past"] = outputs.past_key_values | |
elif "mems" in outputs: | |
model_kwargs["past"] = outputs.mems | |
elif "past_buckets_states" in outputs: | |
model_kwargs["past"] = outputs.past_buckets_states | |
else: | |
model_kwargs["past"] = None | |
# update token_type_ids with last value | |
if "token_type_ids" in model_kwargs: | |
token_type_ids = model_kwargs["token_type_ids"] | |
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) | |
# update attention mask | |
if not is_encoder_decoder: | |
if "attention_mask" in model_kwargs: | |
attention_mask = model_kwargs["attention_mask"] | |
model_kwargs["attention_mask"] = torch.cat( | |
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 | |
) | |
return model_kwargs | |
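# Note: after every decoding step the cache is carried forward as `model_kwargs["past"]` and, for
# decoder-only models, the attention mask grows by one column of ones to cover the new token.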
def _reorder_cache(self, past, beam_idx): | |
raise NotImplementedError( | |
f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}" | |
) | |
def _get_logits_warper( | |
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None | |
) -> LogitsProcessorList: | |
""" | |
This method returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
:obj:`~transformers.LogitsWarper` instances used for multinomial sampling.
""" | |
# init warp parameters | |
top_k = top_k if top_k is not None else self.config.top_k | |
top_p = top_p if top_p is not None else self.config.top_p | |
temperature = temperature if temperature is not None else self.config.temperature | |
# instantiate warpers list | |
warpers = LogitsProcessorList() | |
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files | |
# all samplers can be found in `generation_utils_samplers.py` | |
if temperature is not None and temperature != 1.0: | |
warpers.append(TemperatureLogitsWarper(temperature)) | |
if top_k is not None and top_k != 0: | |
warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) | |
if top_p is not None and top_p < 1.0: | |
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) | |
return warpers | |
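# Illustrative sketch (an assumption, not taken from this file): the returned list behaves like a
# single callable that rescales a (batch_size, vocab_size) score tensor before sampling, e.g.:
#
#     warpers = model._get_logits_warper(top_k=50, top_p=0.95, temperature=0.7, num_beams=1)
#     next_token_scores = warpers(input_ids, next_token_logits)  # temperature, then top-k, then top-p
#     probs = nn.functional.softmax(next_token_scores, dim=-1)
#     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)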
def _get_logits_processor( | |
self, | |
repetition_penalty: float, | |
no_repeat_ngram_size: int, | |
encoder_no_repeat_ngram_size: int, | |
encoder_input_ids: torch.LongTensor, | |
bad_words_ids: List[List[int]], | |
min_length: int, | |
max_length: int, | |
eos_token_id: int, | |
forced_bos_token_id: int, | |
forced_eos_token_id: int, | |
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], | |
num_beams: int, | |
num_beam_groups: int, | |
diversity_penalty: float, | |
remove_invalid_values: bool, | |
) -> LogitsProcessorList: | |
""" | |
This method returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
:obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head.
""" | |
processors = LogitsProcessorList() | |
# init warp parameters | |
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty | |
no_repeat_ngram_size = ( | |
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size | |
) | |
encoder_no_repeat_ngram_size = ( | |
encoder_no_repeat_ngram_size | |
if encoder_no_repeat_ngram_size is not None | |
else self.config.encoder_no_repeat_ngram_size | |
) | |
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids | |
min_length = min_length if min_length is not None else self.config.min_length | |
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id | |
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty | |
forced_bos_token_id = ( | |
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id | |
) | |
forced_eos_token_id = ( | |
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id | |
) | |
remove_invalid_values = ( | |
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values | |
) | |
# instantiate processors list | |
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files | |
# all samplers can be found in `generation_utils_samplers.py` | |
if diversity_penalty is not None and diversity_penalty > 0.0: | |
processors.append( | |
HammingDiversityLogitsProcessor( | |
diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups | |
) | |
) | |
if repetition_penalty is not None and repetition_penalty != 1.0: | |
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) | |
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: | |
processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) | |
if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0: | |
if self.config.is_encoder_decoder: | |
processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids)) | |
else: | |
raise ValueError( | |
"It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture" | |
) | |
if bad_words_ids is not None: | |
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) | |
if min_length is not None and eos_token_id is not None and min_length > -1: | |
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) | |
if prefix_allowed_tokens_fn is not None: | |
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups)) | |
if forced_bos_token_id is not None: | |
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) | |
if forced_eos_token_id is not None: | |
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) | |
if remove_invalid_values is True: | |
processors.append(InfNanRemoveLogitsProcessor()) | |
return processors | |
def _get_stopping_criteria( | |
self, max_length: Optional[int], max_time: Optional[float], max_new_tokens: Optional[int], start_length: int | |
) -> StoppingCriteriaList: | |
stopping_criteria = StoppingCriteriaList() | |
if max_length is not None: | |
stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) | |
if max_time is not None: | |
stopping_criteria.append(MaxTimeCriteria(max_time=max_time)) | |
if max_new_tokens is not None: | |
stopping_criteria.append(MaxNewTokensCriteria(start_length=start_length, max_new_tokens=max_new_tokens)) | |
return stopping_criteria | |
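# Note on the criteria above: `max_length` bounds the total sequence length, `max_new_tokens`
# bounds only the newly generated part (counted from `start_length`), and `max_time` bounds
# wall-clock time in seconds; generation stops as soon as any criterion is met. For example,
# `generate(input_ids, max_new_tokens=20)` on a 10-token prompt stops at 30 tokens total.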
def generate( | |
self, | |
input_ids: Optional[torch.LongTensor] = None, | |
max_length: Optional[int] = None, | |
min_length: Optional[int] = None, | |
do_sample: Optional[bool] = None, | |
early_stopping: Optional[bool] = None, | |
num_beams: Optional[int] = None, | |
temperature: Optional[float] = None, | |
top_k: Optional[int] = None, | |
top_p: Optional[float] = None, | |
repetition_penalty: Optional[float] = None, | |
bad_words_ids: Optional[Iterable[int]] = None, | |
bos_token_id: Optional[int] = None, | |
pad_token_id: Optional[int] = None, | |
eos_token_id: Optional[int] = None, | |
length_penalty: Optional[float] = None, | |
no_repeat_ngram_size: Optional[int] = None, | |
encoder_no_repeat_ngram_size: Optional[int] = None, | |
num_return_sequences: Optional[int] = None, | |
max_time: Optional[float] = None, | |
max_new_tokens: Optional[int] = None, | |
decoder_start_token_id: Optional[int] = None, | |
use_cache: Optional[bool] = None, | |
num_beam_groups: Optional[int] = None, | |
diversity_penalty: Optional[float] = None, | |
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, | |
output_attentions: Optional[bool] = None, | |
output_hidden_states: Optional[bool] = None, | |
output_scores: Optional[bool] = None, | |
return_dict_in_generate: Optional[bool] = None, | |
forced_bos_token_id: Optional[int] = None, | |
forced_eos_token_id: Optional[int] = None, | |
remove_invalid_values: Optional[bool] = None, | |
synced_gpus: Optional[bool] = None, | |
**model_kwargs, | |
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: | |
r""" | |
Generates sequences for models with a language modeling head. The method currently supports greedy decoding, | |
multinomial sampling, beam-search decoding, and beam-search multinomial sampling. | |
Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the | |
attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values | |
indicated are the default values of those config. | |
Most of these parameters are explained in more detail in `this blog post | |
<https://huggingface.co/blog/how-to-generate>`__. | |
Parameters: | |
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty | |
:obj:`torch.LongTensor` of shape :obj:`(1,)`. | |
max_length (:obj:`int`, `optional`, defaults to :obj:`model.config.max_length`): | |
The maximum length of the sequence to be generated. | |
max_new_tokens (:obj:`int`, `optional`, defaults to None):
The maximum number of tokens to generate, ignoring the number of tokens in the prompt. Use either
:obj:`max_new_tokens` or :obj:`max_length`, but not both, as they serve the same purpose.
min_length (:obj:`int`, `optional`, defaults to 10): | |
The minimum length of the sequence to be generated. | |
do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
Whether or not to use sampling ; use greedy decoding otherwise. | |
early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. | |
num_beams (:obj:`int`, `optional`, defaults to 1): | |
Number of beams for beam search. 1 means no beam search. | |
temperature (:obj:`float`, `optional`, defaults to 1.0): | |
The value used to module the next token probabilities. | |
top_k (:obj:`int`, `optional`, defaults to 50): | |
The number of highest probability vocabulary tokens to keep for top-k-filtering. | |
top_p (:obj:`float`, `optional`, defaults to 1.0): | |
If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or | |
higher are kept for generation. | |
repetition_penalty (:obj:`float`, `optional`, defaults to 1.0): | |
The parameter for repetition penalty. 1.0 means no penalty. See `this paper | |
<https://arxiv.org/pdf/1909.05858.pdf>`__ for more details. | |
pad_token_id (:obj:`int`, `optional`): | |
The id of the `padding` token. | |
bos_token_id (:obj:`int`, `optional`): | |
The id of the `beginning-of-sequence` token. | |
eos_token_id (:obj:`int`, `optional`): | |
The id of the `end-of-sequence` token. | |
length_penalty (:obj:`float`, `optional`, defaults to 1.0): | |
Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the | |
model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer | |
sequences. | |
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): | |
If set to int > 0, all ngrams of that size can only occur once. | |
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): | |
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the | |
``decoder_input_ids``. | |
bad_words_ids(:obj:`List[List[int]]`, `optional`): | |
List of token ids that are not allowed to be generated. In order to get the tokens of the words that | |
should not appear in the generated text, use :obj:`tokenizer(bad_word, | |
add_prefix_space=True).input_ids`. | |
num_return_sequences(:obj:`int`, `optional`, defaults to 1): | |
The number of independently computed returned sequences for each element in the batch. | |
max_time (:obj:`float`, `optional`, defaults to None):
The maximum amount of time, in seconds, that the computation is allowed to run. Generation will still
finish the current pass after the allocated time has elapsed.
attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for | |
tokens that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same | |
shape as :obj:`input_ids` that masks the pad token. `What are attention masks? | |
<../glossary.html#attention-mask>`__ | |
decoder_start_token_id (:obj:`int`, `optional`): | |
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token. | |
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to | |
speed up decoding. | |
num_beam_groups (:obj:`int`, `optional`, defaults to 1):
Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
diversity_penalty (:obj:`float`, `optional`, defaults to 0.0):
This value is subtracted from a beam's score if it generates a token that was already generated by any
beam from another group at the same time step. Note that :obj:`diversity_penalty` is only effective if
``group beam search`` is enabled.
prefix_allowed_tokens_fn (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
If provided, this function constrains the beam search to allowed tokens only at each step. If not
provided, no constraint is applied. This function takes 2 arguments: the batch ID :obj:`batch_id` and
:obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step,
conditioned on the batch ID :obj:`batch_id` and the previously generated tokens :obj:`input_ids`. This
argument is useful for constrained generation conditioned on the prefix, as described in
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
output_attentions (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under | |
returned tensors for more details. | |
output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors | |
for more details. | |
output_scores (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. | |
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | |
forced_bos_token_id (:obj:`int`, `optional`): | |
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`. | |
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token | |
needs to be the target language token. | |
forced_eos_token_id (:obj:`int`, `optional`): | |
The id of the token to force as the last generated token when :obj:`max_length` is reached. | |
remove_invalid_values (:obj:`bool`, `optional`):
Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method from
crashing. Note that using ``remove_invalid_values`` can slow down generation.
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
model_kwargs: | |
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the | |
model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific | |
kwargs should be prefixed with `decoder_`. | |
Return: | |
:class:`~transformers.file_utils.ModelOutput` or :obj:`torch.LongTensor`: A | |
:class:`~transformers.file_utils.ModelOutput` (if ``return_dict_in_generate=True`` or when | |
``config.return_dict_in_generate=True``) or a :obj:`torch.LongTensor`.
If the model is `not` an encoder-decoder model (``model.config.is_encoder_decoder=False``), the | |
possible :class:`~transformers.file_utils.ModelOutput` types are: | |
- :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`, | |
- :class:`~transformers.generation_utils.SampleDecoderOnlyOutput`, | |
- :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`, | |
- :class:`~transformers.generation_utils.BeamSampleDecoderOnlyOutput` | |
If the model is an encoder-decoder model (``model.config.is_encoder_decoder=True``), the possible | |
:class:`~transformers.file_utils.ModelOutput` types are: | |
- :class:`~transformers.generation_utils.GreedySearchEncoderDecoderOutput`, | |
- :class:`~transformers.generation_utils.SampleEncoderDecoderOutput`, | |
- :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput`, | |
- :class:`~transformers.generation_utils.BeamSampleEncoderDecoderOutput` | |
Examples:: | |
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM | |
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") | |
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") | |
>>> # do greedy decoding without providing a prompt | |
>>> outputs = model.generate(max_length=40) | |
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True)) | |
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base") | |
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") | |
>>> document = ( | |
... "at least two people were killed in a suspected bomb attack on a passenger bus " | |
... "in the strife-torn southern philippines on monday , the military said." | |
... ) | |
>>> # encode input context | |
>>> input_ids = tokenizer(document, return_tensors="pt").input_ids | |
>>> # generate 3 independent sequences using beam search decoding (5 beams) | |
>>> # with T5 encoder-decoder model conditioned on short news article. | |
>>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3) | |
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) | |
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") | |
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") | |
>>> input_context = "The dog" | |
>>> # encode input context | |
>>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids | |
>>> # generate 3 candidates using sampling | |
>>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True) | |
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) | |
>>> tokenizer = AutoTokenizer.from_pretrained("ctrl") | |
>>> model = AutoModelForCausalLM.from_pretrained("ctrl") | |
>>> # "Legal" is one of the control codes for ctrl | |
>>> input_context = "Legal My neighbor is" | |
>>> # encode input context | |
>>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids | |
>>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2) | |
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True)) | |
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2") | |
>>> model = AutoModelForCausalLM.from_pretrained("gpt2") | |
>>> input_context = "My cute dog" | |
>>> # get tokens of words that should not be generated | |
>>> bad_words_ids = [tokenizer(bad_word, add_prefix_space=True).input_ids for bad_word in ["idiot", "stupid", "shut up"]] | |
>>> # encode input context | |
>>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids | |
>>> # generate sequences without allowing bad_words to be generated | |
>>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids) | |
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True)) | |
""" | |
# set init values | |
if max_length is None and max_new_tokens is None: | |
# Both are None, default | |
max_length = self.config.max_length | |
elif max_length is not None and max_new_tokens is not None: | |
# Both are set, this is odd, raise a warning | |
warnings.warn( | |
"Both `max_length` and `max_new_tokens` have been set but they serve the same purpose.", UserWarning | |
) | |
max_length = max_length if max_length is not None else self.config.max_length | |
num_beams = num_beams if num_beams is not None else self.config.num_beams | |
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups | |
do_sample = do_sample if do_sample is not None else self.config.do_sample | |
num_return_sequences = ( | |
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences | |
) | |
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id | |
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id | |
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id | |
output_scores = output_scores if output_scores is not None else self.config.output_scores | |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
output_hidden_states = ( | |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
) | |
return_dict_in_generate = ( | |
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate | |
) | |
model_kwargs["output_attentions"] = output_attentions | |
model_kwargs["output_hidden_states"] = output_hidden_states | |
if input_ids is None and "inputs_embeds" not in model_kwargs: | |
# init `input_ids` with bos_token_id | |
input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs")) | |
if model_kwargs.get("attention_mask", None) is None: | |
# init `attention_mask` depending on `pad_token_id` | |
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( | |
input_ids, pad_token_id, eos_token_id | |
) | |
# special case if pad_token_id is not defined | |
if pad_token_id is None and eos_token_id is not None: | |
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") | |
pad_token_id = eos_token_id | |
# Storing encoder_input_ids for logits_processor that could use them | |
encoder_input_ids = input_ids if self.config.is_encoder_decoder else None | |
if self.config.is_encoder_decoder: | |
# add encoder_outputs to model_kwargs | |
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs) | |
# set input_ids as decoder_input_ids | |
if "decoder_input_ids" in model_kwargs: | |
input_ids = model_kwargs.pop("decoder_input_ids") | |
else: | |
input_ids = self._prepare_decoder_input_ids_for_generation( | |
input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id | |
) | |
if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput): | |
raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.") | |
if input_ids.shape[-1] >= max_length: | |
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" | |
logger.warning( | |
f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}." | |
"This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``." | |
) | |
# determine generation mode | |
is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False | |
is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True | |
is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False | |
is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True | |
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) | |
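# In short: greedy (num_beams=1, do_sample=False), multinomial sampling (num_beams=1, do_sample=True),
# beam search (num_beams>1, do_sample=False), beam-search multinomial sampling (num_beams>1,
# do_sample=True), and diverse / group beam search (num_beam_groups>1).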
if num_beam_groups > num_beams: | |
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") | |
if is_group_beam_gen_mode and do_sample is True: | |
raise ValueError( | |
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." | |
) | |
# set model_kwargs | |
model_kwargs["use_cache"] = use_cache | |
# get distribution pre_processing samplers | |
logits_processor = self._get_logits_processor( | |
repetition_penalty=repetition_penalty, | |
no_repeat_ngram_size=no_repeat_ngram_size, | |
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, | |
encoder_input_ids=encoder_input_ids, | |
bad_words_ids=bad_words_ids, | |
min_length=min_length, | |
max_length=max_length, | |
eos_token_id=eos_token_id, | |
forced_bos_token_id=forced_bos_token_id, | |
forced_eos_token_id=forced_eos_token_id, | |
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, | |
num_beams=num_beams, | |
num_beam_groups=num_beam_groups, | |
diversity_penalty=diversity_penalty, | |
remove_invalid_values=remove_invalid_values, | |
) | |
cur_len = input_ids.shape[-1] | |
stopping_criteria = self._get_stopping_criteria( | |
max_length=max_length, max_time=max_time, max_new_tokens=max_new_tokens, start_length=cur_len | |
) | |
if is_greedy_gen_mode: | |
if num_return_sequences > 1: | |
raise ValueError( | |
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." | |
) | |
# greedy search | |
return self.greedy_search( | |
input_ids, | |
logits_processor=logits_processor, | |
stopping_criteria=stopping_criteria, | |
pad_token_id=pad_token_id, | |
eos_token_id=eos_token_id, | |
output_scores=output_scores, | |
return_dict_in_generate=return_dict_in_generate, | |
synced_gpus=synced_gpus, | |
**model_kwargs, | |
) | |
elif is_sample_gen_mode: | |
# get probability distribution warper | |
logits_warper = self._get_logits_warper( | |
top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams | |
) | |
# expand input_ids with `num_return_sequences` additional sequences per batch | |
input_ids, model_kwargs = self._expand_inputs_for_generation( | |
input_ids, | |
expand_size=num_return_sequences, | |
is_encoder_decoder=self.config.is_encoder_decoder, | |
**model_kwargs, | |
) | |
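# e.g. a batch of 2 prompts with num_return_sequences=3 is expanded to 6 rows (each prompt | |
# repeated 3 times in a row), so every prompt is sampled 3 times independently | |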
# sample | |
return self.sample( | |
input_ids, | |
logits_processor=logits_processor, | |
logits_warper=logits_warper, | |
stopping_criteria=stopping_criteria, | |
pad_token_id=pad_token_id, | |
eos_token_id=eos_token_id, | |
output_scores=output_scores, | |
return_dict_in_generate=return_dict_in_generate, | |
synced_gpus=synced_gpus, | |
**model_kwargs, | |
) | |
elif is_beam_gen_mode: | |
batch_size = input_ids.shape[0] | |
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty | |
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping | |
if num_return_sequences > num_beams: | |
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") | |
if stopping_criteria.max_length is None: | |
raise ValueError("`max_length` needs to be a stopping_criteria for now.") | |
beam_scorer = BeamSearchScorer( | |
batch_size=batch_size, | |
num_beams=num_beams, | |
device=self.device, | |
length_penalty=length_penalty, | |
do_early_stopping=early_stopping, | |
num_beam_hyps_to_keep=num_return_sequences, | |
) | |
# interleave with `num_beams` | |
input_ids, model_kwargs = self._expand_inputs_for_generation( | |
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs | |
) | |
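# e.g. batch_size=2 with num_beams=4 yields 8 rows ordered [p0, p0, p0, p0, p1, p1, p1, p1], | |
# matching the (batch_size * num_beams) layout expected by `beam_search` and the beam scorer | |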
return self.beam_search( | |
input_ids, | |
beam_scorer, | |
logits_processor=logits_processor, | |
stopping_criteria=stopping_criteria, | |
pad_token_id=pad_token_id, | |
eos_token_id=eos_token_id, | |
output_scores=output_scores, | |
return_dict_in_generate=return_dict_in_generate, | |
synced_gpus=synced_gpus, | |
**model_kwargs, | |
) | |
elif is_beam_sample_gen_mode: | |
logits_warper = self._get_logits_warper( | |
top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams | |
) | |
batch_size = input_ids.shape[0] * num_return_sequences | |
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty | |
if stopping_criteria.max_length is None: | |
raise ValueError("`max_length` needs to be a stopping_criteria for now.") | |
beam_scorer = BeamSearchScorer( | |
batch_size=batch_size, | |
num_beams=num_beams, | |
device=self.device, | |
length_penalty=length_penalty, | |
do_early_stopping=early_stopping, | |
) | |
# interleave with `num_beams * num_return_sequences` | |
input_ids, model_kwargs = self._expand_inputs_for_generation( | |
input_ids, | |
expand_size=num_beams * num_return_sequences, | |
is_encoder_decoder=self.config.is_encoder_decoder, | |
**model_kwargs, | |
) | |
return self.beam_sample( | |
input_ids, | |
beam_scorer, | |
logits_processor=logits_processor, | |
logits_warper=logits_warper, | |
stopping_criteria=stopping_criteria, | |
pad_token_id=pad_token_id, | |
eos_token_id=eos_token_id, | |
output_scores=output_scores, | |
return_dict_in_generate=return_dict_in_generate, | |
synced_gpus=synced_gpus, | |
**model_kwargs, | |
) | |
elif is_group_beam_gen_mode: | |
batch_size = input_ids.shape[0] | |
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty | |
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping | |
if num_return_sequences > num_beams: | |
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") | |
if num_beams % num_beam_groups != 0: | |
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") | |
if stopping_criteria.max_length is None: | |
raise ValueError("`max_length` needs to be a stopping_criteria for now.") | |
diverse_beam_scorer = BeamSearchScorer( | |
batch_size=batch_size, | |
num_beams=num_beams, | |
max_length=stopping_criteria.max_length, | |
device=self.device, | |
length_penalty=length_penalty, | |
do_early_stopping=early_stopping, | |
num_beam_hyps_to_keep=num_return_sequences, | |
num_beam_groups=num_beam_groups, | |
) | |
# interleave with `num_beams` | |
input_ids, model_kwargs = self._expand_inputs_for_generation( | |
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs | |
) | |
return self.group_beam_search( | |
input_ids, | |
diverse_beam_scorer, | |
logits_processor=logits_processor, | |
stopping_criteria=stopping_criteria, | |
pad_token_id=pad_token_id, | |
eos_token_id=eos_token_id, | |
output_scores=output_scores, | |
return_dict_in_generate=return_dict_in_generate, | |
synced_gpus=synced_gpus, | |
**model_kwargs, | |
) | |
def greedy_search( | |
self, | |
input_ids: torch.LongTensor, | |
logits_processor: Optional[LogitsProcessorList] = None, | |
stopping_criteria: Optional[StoppingCriteriaList] = None, | |
max_length: Optional[int] = None, | |
pad_token_id: Optional[int] = None, | |
eos_token_id: Optional[int] = None, | |
output_attentions: Optional[bool] = None, | |
output_hidden_states: Optional[bool] = None, | |
output_scores: Optional[bool] = None, | |
return_dict_in_generate: Optional[bool] = None, | |
synced_gpus: Optional[bool] = None, | |
**model_kwargs, | |
) -> Union[GreedySearchOutput, torch.LongTensor]: | |
r""" | |
Generates sequences for models with a language modeling head using greedy decoding. | |
Parameters: | |
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty | |
:obj:`torch.LongTensor` of shape :obj:`(1,)`. | |
logits_processor (:obj:`LogitsProcessorList`, `optional`): | |
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from | |
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling | |
head applied at each generation step. | |
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`): | |
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from | |
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop. | |
max_length (:obj:`int`, `optional`, defaults to 20): | |
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of | |
generated tokens. The maximum length of the sequence to be generated. | |
pad_token_id (:obj:`int`, `optional`): | |
The id of the `padding` token. | |
eos_token_id (:obj:`int`, `optional`): | |
The id of the `end-of-sequence` token. | |
output_attentions (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under | |
returned tensors for more details. | |
output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors | |
for more details. | |
output_scores (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. | |
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | |
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
Whether to continue running the while loop until max_length is reached (needed for ZeRO stage 3). | |
model_kwargs: | |
Additional model specific keyword arguments will be forwarded to the :obj:`forward` function of the | |
model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. | |
Return: | |
:class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`, | |
:class:`~transformers.generation_utils.GreedySearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A | |
:obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a | |
:class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput` if | |
``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a | |
:class:`~transformers.generation_utils.GreedySearchEncoderDecoderOutput` if | |
``model.config.is_encoder_decoder=True``. | |
Examples:: | |
>>> from transformers import ( | |
... AutoTokenizer, | |
... AutoModelForCausalLM, | |
... LogitsProcessorList, | |
... MinLengthLogitsProcessor, | |
... ) | |
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2") | |
>>> model = AutoModelForCausalLM.from_pretrained("gpt2") | |
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token | |
>>> model.config.pad_token_id = model.config.eos_token_id | |
>>> input_prompt = "Today is a beautiful day, and" | |
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids | |
>>> # instantiate logits processors | |
>>> logits_processor = LogitsProcessorList([ | |
... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), | |
... ]) | |
>>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor) | |
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) | |
""" | |
# init values | |
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | |
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | |
if max_length is not None: | |
warnings.warn( | |
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | |
UserWarning, | |
) | |
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) | |
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id | |
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id | |
output_scores = output_scores if output_scores is not None else self.config.output_scores | |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
output_hidden_states = ( | |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
) | |
return_dict_in_generate = ( | |
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate | |
) | |
# init attention / hidden states / scores tuples | |
scores = () if (return_dict_in_generate and output_scores) else None | |
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None | |
cross_attentions = () if (return_dict_in_generate and output_attentions) else None | |
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None | |
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states | |
if return_dict_in_generate and self.config.is_encoder_decoder: | |
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None | |
encoder_hidden_states = ( | |
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None | |
) | |
# keep track of which sequences are already finished | |
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) | |
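# 1 = still generating, 0 = finished (hit `eos_token_id`) and only padded from now on | |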
cur_len = input_ids.shape[-1] | |
this_peer_finished = False # used by synced_gpus only | |
while True: | |
if synced_gpus: | |
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence. | |
# The following logic allows an early break if all peers finished generating their sequence | |
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) | |
# send 0.0 if we finished, 1.0 otherwise | |
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) | |
# did all peers finish? the reduced sum will be 0.0 then | |
if this_peer_finished_flag.item() == 0.0: | |
break | |
# prepare model inputs | |
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) | |
# forward pass to get next token | |
outputs = self( | |
**model_inputs, | |
return_dict=True, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
) | |
if synced_gpus and this_peer_finished: | |
cur_len = cur_len + 1 | |
continue # don't waste resources running the code we don't need | |
next_token_logits = outputs.logits[:, -1, :] | |
# Store scores, attentions and hidden_states when required | |
if return_dict_in_generate: | |
if output_scores: | |
scores += (next_token_logits,) | |
if output_attentions: | |
decoder_attentions += ( | |
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) | |
) | |
if self.config.is_encoder_decoder: | |
cross_attentions += (outputs.cross_attentions,) | |
if output_hidden_states: | |
decoder_hidden_states += ( | |
(outputs.decoder_hidden_states,) | |
if self.config.is_encoder_decoder | |
else (outputs.hidden_states,) | |
) | |
# pre-process distribution | |
next_tokens_scores = logits_processor(input_ids, next_token_logits) | |
# argmax | |
next_tokens = torch.argmax(next_tokens_scores, dim=-1) | |
# finished sentences should have their next token be a padding token | |
if eos_token_id is not None: | |
assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined." | |
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) | |
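# worked example: with unfinished_sequences = [1, 0] and pad_token_id = 0, the first (still | |
# unfinished) row keeps its argmax token while the second (finished) row emits pad_token_id | |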
# update generated ids, model inputs, and length for next step | |
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) | |
model_kwargs = self._update_model_kwargs_for_generation( | |
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder | |
) | |
cur_len = cur_len + 1 | |
# if eos_token was found in one sentence, set sentence to finished | |
if eos_token_id is not None: | |
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) | |
# stop when each sentence is finished, or if we exceed the maximum length | |
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): | |
if not synced_gpus: | |
break | |
else: | |
this_peer_finished = True | |
if return_dict_in_generate: | |
if self.config.is_encoder_decoder: | |
return GreedySearchEncoderDecoderOutput( | |
sequences=input_ids, | |
scores=scores, | |
encoder_attentions=encoder_attentions, | |
encoder_hidden_states=encoder_hidden_states, | |
decoder_attentions=decoder_attentions, | |
cross_attentions=cross_attentions, | |
decoder_hidden_states=decoder_hidden_states, | |
) | |
else: | |
return GreedySearchDecoderOnlyOutput( | |
sequences=input_ids, | |
scores=scores, | |
attentions=decoder_attentions, | |
hidden_states=decoder_hidden_states, | |
) | |
else: | |
return input_ids | |
def sample( | |
self, | |
input_ids: torch.LongTensor, | |
logits_processor: Optional[LogitsProcessorList] = None, | |
stopping_criteria: Optional[StoppingCriteriaList] = None, | |
logits_warper: Optional[LogitsProcessorList] = None, | |
max_length: Optional[int] = None, | |
pad_token_id: Optional[int] = None, | |
eos_token_id: Optional[int] = None, | |
output_attentions: Optional[bool] = None, | |
output_hidden_states: Optional[bool] = None, | |
output_scores: Optional[bool] = None, | |
return_dict_in_generate: Optional[bool] = None, | |
synced_gpus: Optional[bool] = None, | |
**model_kwargs, | |
) -> Union[SampleOutput, torch.LongTensor]: | |
r""" | |
Generates sequences for models with a language modeling head using multinomial sampling. | |
Parameters: | |
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty | |
:obj:`torch.LongTensor` of shape :obj:`(1,)`. | |
logits_processor (:obj:`LogitsProcessorList`, `optional`): | |
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from | |
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling | |
head applied at each generation step. | |
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`): | |
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from | |
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop. | |
logits_warper (:obj:`LogitsProcessorList`, `optional`): | |
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from | |
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language | |
modeling head applied before multinomial sampling at each generation step. | |
max_length (:obj:`int`, `optional`, defaults to 20): | |
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of | |
generated tokens. The maximum length of the sequence to be generated. | |
pad_token_id (:obj:`int`, `optional`): | |
The id of the `padding` token. | |
eos_token_id (:obj:`int`, `optional`): | |
The id of the `end-of-sequence` token. | |
output_attentions (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under | |
returned tensors for more details. | |
output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors | |
for more details. | |
output_scores (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. | |
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | |
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
Whether to continue running the while loop until max_length is reached (needed for ZeRO stage 3). | |
model_kwargs: | |
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If | |
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. | |
Return: | |
:class:`~transformers.generation_utils.SampleDecoderOnlyOutput`, | |
:class:`~transformers.generation_utils.SampleEncoderDecoderOutput` or :obj:`torch.LongTensor`: A | |
:obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a | |
:class:`~transformers.generation_utils.SampleDecoderOnlyOutput` if | |
``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a | |
:class:`~transformers.generation_utils.SampleEncoderDecoderOutput` if | |
``model.config.is_encoder_decoder=True``. | |
Examples:: | |
>>> from transformers import ( | |
... AutoTokenizer, | |
... AutoModelForCausalLM, | |
... LogitsProcessorList, | |
... MinLengthLogitsProcessor, | |
... TopKLogitsWarper, | |
... TemperatureLogitsWarper, | |
... ) | |
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2") | |
>>> model = AutoModelForCausalLM.from_pretrained("gpt2") | |
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token | |
>>> model.config.pad_token_id = model.config.eos_token_id | |
>>> input_prompt = "Today is a beautiful day, and" | |
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids | |
>>> # instantiate logits processors | |
>>> logits_processor = LogitsProcessorList([ | |
... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), | |
... ]) | |
>>> # instantiate logits warpers | |
>>> logits_warper = LogitsProcessorList([ | |
... TopKLogitsWarper(50), | |
... TemperatureLogitsWarper(0.7), | |
... ]) | |
>>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper) | |
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) | |
""" | |
# init values | |
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | |
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | |
if max_length is not None: | |
warnings.warn( | |
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | |
UserWarning, | |
) | |
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) | |
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() | |
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id | |
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id | |
output_scores = output_scores if output_scores is not None else self.config.output_scores | |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
output_hidden_states = ( | |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
) | |
return_dict_in_generate = ( | |
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate | |
) | |
# init attention / hidden states / scores tuples | |
scores = () if (return_dict_in_generate and output_scores) else None | |
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None | |
cross_attentions = () if (return_dict_in_generate and output_attentions) else None | |
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None | |
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states | |
if return_dict_in_generate and self.config.is_encoder_decoder: | |
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None | |
encoder_hidden_states = ( | |
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None | |
) | |
# keep track of which sequences are already finished | |
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) | |
cur_len = input_ids.shape[-1] | |
this_peer_finished = False # used by synced_gpus only | |
# auto-regressive generation | |
while True: | |
if synced_gpus: | |
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence. | |
# The following logic allows an early break if all peers finished generating their sequence | |
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) | |
# send 0.0 if we finished, 1.0 otherwise | |
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) | |
# did all peers finish? the reduced sum will be 0.0 then | |
if this_peer_finished_flag.item() == 0.0: | |
break | |
# prepare model inputs | |
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) | |
# forward pass to get next token | |
outputs = self( | |
**model_inputs, | |
return_dict=True, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
) | |
if synced_gpus and this_peer_finished: | |
cur_len = cur_len + 1 | |
continue # don't waste resources running the code we don't need | |
next_token_logits = outputs.logits[:, -1, :] | |
# pre-process distribution | |
next_token_scores = logits_processor(input_ids, next_token_logits) | |
next_token_scores = logits_warper(input_ids, next_token_scores) | |
# Store scores, attentions and hidden_states when required | |
if return_dict_in_generate: | |
if output_scores: | |
scores += (next_token_scores,) | |
if output_attentions: | |
decoder_attentions += ( | |
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) | |
) | |
if self.config.is_encoder_decoder: | |
cross_attentions += (outputs.cross_attentions,) | |
if output_hidden_states: | |
decoder_hidden_states += ( | |
(outputs.decoder_hidden_states,) | |
if self.config.is_encoder_decoder | |
else (outputs.hidden_states,) | |
) | |
# sample | |
probs = nn.functional.softmax(next_token_scores, dim=-1) | |
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) | |
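# the warped scores are normalized into a probability distribution and one token id is drawn | |
# per row; squeeze(1) turns the (batch_size, 1) draw into a (batch_size,) vector of token ids | |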
# finished sentences should have their next token be a padding token | |
if eos_token_id is not None: | |
assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined." | |
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) | |
# update generated ids, model inputs, and length for next step | |
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) | |
model_kwargs = self._update_model_kwargs_for_generation( | |
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder | |
) | |
cur_len = cur_len + 1 | |
# if eos_token was found in one sentence, set sentence to finished | |
if eos_token_id is not None: | |
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) | |
# stop when each sentence is finished, or if we exceed the maximum length | |
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): | |
if not synced_gpus: | |
break | |
else: | |
this_peer_finished = True | |
if return_dict_in_generate: | |
if self.config.is_encoder_decoder: | |
return SampleEncoderDecoderOutput( | |
sequences=input_ids, | |
scores=scores, | |
encoder_attentions=encoder_attentions, | |
encoder_hidden_states=encoder_hidden_states, | |
decoder_attentions=decoder_attentions, | |
cross_attentions=cross_attentions, | |
decoder_hidden_states=decoder_hidden_states, | |
) | |
else: | |
return SampleDecoderOnlyOutput( | |
sequences=input_ids, | |
scores=scores, | |
attentions=decoder_attentions, | |
hidden_states=decoder_hidden_states, | |
) | |
else: | |
return input_ids | |
def beam_search( | |
self, | |
input_ids: torch.LongTensor, | |
beam_scorer: BeamScorer, | |
logits_processor: Optional[LogitsProcessorList] = None, | |
stopping_criteria: Optional[StoppingCriteriaList] = None, | |
max_length: Optional[int] = None, | |
pad_token_id: Optional[int] = None, | |
eos_token_id: Optional[int] = None, | |
output_attentions: Optional[bool] = None, | |
output_hidden_states: Optional[bool] = None, | |
output_scores: Optional[bool] = None, | |
return_dict_in_generate: Optional[bool] = None, | |
synced_gpus: Optional[bool] = None, | |
**model_kwargs, | |
) -> Union[BeamSearchOutput, torch.LongTensor]: | |
r""" | |
Generates sequences for models with a language modeling head using beam search decoding. | |
Parameters: | |
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty | |
:obj:`torch.LongTensor` of shape :obj:`(1,)`. | |
beam_scorer (:obj:`BeamScorer`): | |
A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are | |
constructed, stored and sorted during generation. For more information, the documentation of | |
:class:`~transformers.BeamScorer` should be read. | |
logits_processor (:obj:`LogitsProcessorList`, `optional`): | |
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from | |
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling | |
head applied at each generation step. | |
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`): | |
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from | |
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop. | |
max_length (:obj:`int`, `optional`, defaults to 20): | |
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of | |
generated tokens. The maximum length of the sequence to be generated. | |
pad_token_id (:obj:`int`, `optional`): | |
The id of the `padding` token. | |
eos_token_id (:obj:`int`, `optional`): | |
The id of the `end-of-sequence` token. | |
output_attentions (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under | |
returned tensors for more details. | |
output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors | |
for more details. | |
output_scores (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. | |
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | |
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
Whether to continue running the while loop until max_length is reached (needed for ZeRO stage 3). | |
model_kwargs: | |
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If | |
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. | |
Return: | |
:class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`, | |
:class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A | |
:obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a | |
:class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if | |
``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a | |
:class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if | |
``model.config.is_encoder_decoder=True``. | |
Examples:: | |
>>> from transformers import ( | |
... AutoTokenizer, | |
... AutoModelForSeq2SeqLM, | |
... LogitsProcessorList, | |
... MinLengthLogitsProcessor, | |
... BeamSearchScorer, | |
... ) | |
>>> import torch | |
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base") | |
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") | |
>>> encoder_input_str = "translate English to German: How old are you?" | |
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids | |
>>> # let's run beam search using 3 beams | |
>>> num_beams = 3 | |
>>> # define decoder start token ids | |
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) | |
>>> input_ids = input_ids * model.config.decoder_start_token_id | |
>>> # add encoder_outputs to model keyword arguments | |
>>> model_kwargs = { | |
... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True) | |
... } | |
>>> # instantiate beam scorer | |
>>> beam_scorer = BeamSearchScorer( | |
... batch_size=1, | |
... num_beams=num_beams, | |
... device=model.device, | |
... ) | |
>>> # instantiate logits processors | |
>>> logits_processor = LogitsProcessorList([ | |
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), | |
... ]) | |
>>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) | |
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) | |
""" | |
# init values | |
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | |
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | |
if max_length is not None: | |
warnings.warn( | |
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | |
UserWarning, | |
) | |
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) | |
if len(stopping_criteria) == 0: | |
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) | |
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id | |
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id | |
output_scores = output_scores if output_scores is not None else self.config.output_scores | |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
output_hidden_states = ( | |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
) | |
return_dict_in_generate = ( | |
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate | |
) | |
# init attention / hidden states / scores tuples | |
scores = () if (return_dict_in_generate and output_scores) else None | |
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None | |
cross_attentions = () if (return_dict_in_generate and output_attentions) else None | |
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None | |
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states | |
if return_dict_in_generate and self.config.is_encoder_decoder: | |
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None | |
encoder_hidden_states = ( | |
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None | |
) | |
batch_size = len(beam_scorer._beam_hyps) | |
num_beams = beam_scorer.num_beams | |
batch_beam_size, cur_len = input_ids.shape | |
assert ( | |
num_beams * batch_size == batch_beam_size | |
), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." | |
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) | |
beam_scores[:, 1:] = -1e9 | |
beam_scores = beam_scores.view((batch_size * num_beams,)) | |
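# only the first beam of each batch starts at score 0; the other beams start at -1e9 so that | |
# the first top-k step does not select num_beams identical continuations (all beams initially | |
# hold the same prompt) | |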
this_peer_finished = False # used by synced_gpus only | |
while True: | |
if synced_gpus: | |
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence. | |
# The following logic allows an early break if all peers finished generating their sequence | |
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) | |
# send 0.0 if we finished, 1.0 otherwise | |
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) | |
# did all peers finish? the reduced sum will be 0.0 then | |
if this_peer_finished_flag.item() == 0.0: | |
break | |
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) | |
outputs = self( | |
**model_inputs, | |
return_dict=True, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
) | |
if synced_gpus and this_peer_finished: | |
cur_len = cur_len + 1 | |
continue # don't waste resources running the code we don't need | |
next_token_logits = outputs.logits[:, -1, :] | |
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` is never | |
# generated, both before and after the `nn.functional.log_softmax` operation. | |
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) | |
next_token_scores = nn.functional.log_softmax( | |
next_token_logits, dim=-1 | |
) # (batch_size * num_beams, vocab_size) | |
next_token_scores = logits_processor(input_ids, next_token_scores) | |
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) | |
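# each candidate score is the running log-probability of its beam plus the log-probability of | |
# the candidate token, i.e. the cumulative score of the extended hypothesis | |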
# Store scores, attentions and hidden_states when required | |
if return_dict_in_generate: | |
if output_scores: | |
scores += (next_token_scores,) | |
if output_attentions: | |
decoder_attentions += ( | |
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) | |
) | |
if self.config.is_encoder_decoder: | |
cross_attentions += (outputs.cross_attentions,) | |
if output_hidden_states: | |
decoder_hidden_states += ( | |
(outputs.decoder_hidden_states,) | |
if self.config.is_encoder_decoder | |
else (outputs.hidden_states,) | |
) | |
# reshape for beam search | |
vocab_size = next_token_scores.shape[-1] | |
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) | |
next_token_scores, next_tokens = torch.topk( | |
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True | |
) | |
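# keep the 2 * num_beams best candidates per batch so that, even if up to num_beams of them | |
# end in `eos_token_id`, enough open hypotheses remain to continue the search | |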
next_indices = next_tokens // vocab_size | |
next_tokens = next_tokens % vocab_size | |
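# worked example: with vocab_size = 50000, a flat index of 100003 decomposes into beam index | |
# 2 (100003 // 50000) and token id 3 (100003 % 50000) | |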
# stateless | |
beam_outputs = beam_scorer.process( | |
input_ids, | |
next_token_scores, | |
next_tokens, | |
next_indices, | |
pad_token_id=pad_token_id, | |
eos_token_id=eos_token_id, | |
) | |
beam_scores = beam_outputs["next_beam_scores"] | |
beam_next_tokens = beam_outputs["next_beam_tokens"] | |
beam_idx = beam_outputs["next_beam_indices"] | |
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) | |
model_kwargs = self._update_model_kwargs_for_generation( | |
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder | |
) | |
if model_kwargs["past"] is not None: | |
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) | |
# increase cur_len | |
cur_len = cur_len + 1 | |
if beam_scorer.is_done or stopping_criteria(input_ids, scores): | |
if not synced_gpus: | |
break | |
else: | |
this_peer_finished = True | |
sequence_outputs = beam_scorer.finalize( | |
input_ids, | |
beam_scores, | |
next_tokens, | |
next_indices, | |
pad_token_id=pad_token_id, | |
eos_token_id=eos_token_id, | |
max_length=stopping_criteria.max_length, | |
) | |
if return_dict_in_generate: | |
if not output_scores: | |
sequence_outputs["sequence_scores"] = None | |
if self.config.is_encoder_decoder: | |
return BeamSearchEncoderDecoderOutput( | |
sequences=sequence_outputs["sequences"], | |
sequences_scores=sequence_outputs["sequence_scores"], | |
scores=scores, | |
encoder_attentions=encoder_attentions, | |
encoder_hidden_states=encoder_hidden_states, | |
decoder_attentions=decoder_attentions, | |
cross_attentions=cross_attentions, | |
decoder_hidden_states=decoder_hidden_states, | |
) | |
else: | |
return BeamSearchDecoderOnlyOutput( | |
sequences=sequence_outputs["sequences"], | |
sequences_scores=sequence_outputs["sequence_scores"], | |
scores=scores, | |
attentions=decoder_attentions, | |
hidden_states=decoder_hidden_states, | |
) | |
else: | |
return sequence_outputs["sequences"] | |
def beam_sample( | |
self, | |
input_ids: torch.LongTensor, | |
beam_scorer: BeamScorer, | |
logits_processor: Optional[LogitsProcessorList] = None, | |
stopping_criteria: Optional[StoppingCriteriaList] = None, | |
logits_warper: Optional[LogitsProcessorList] = None, | |
max_length: Optional[int] = None, | |
pad_token_id: Optional[int] = None, | |
eos_token_id: Optional[int] = None, | |
output_attentions: Optional[bool] = None, | |
output_hidden_states: Optional[bool] = None, | |
output_scores: Optional[bool] = None, | |
return_dict_in_generate: Optional[bool] = None, | |
synced_gpus: Optional[bool] = None, | |
**model_kwargs, | |
) -> Union[BeamSampleOutput, torch.LongTensor]: | |
r""" | |
Generates sequences for models with a language modeling head using beam search with multinomial sampling. | |
Parameters: | |
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty | |
:obj:`torch.LongTensor` of shape :obj:`(1,)`. | |
beam_scorer (:obj:`BeamScorer`): | |
A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are | |
constructed, stored and sorted during generation. For more information, the documentation of | |
:class:`~transformers.BeamScorer` should be read. | |
logits_processor (:obj:`LogitsProcessorList`, `optional`): | |
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from | |
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling | |
head applied at each generation step. | |
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`): | |
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from | |
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop. | |
logits_warper (:obj:`LogitsProcessorList`, `optional`): | |
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from | |
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language | |
modeling head applied before multinomial sampling at each generation step. | |
max_length (:obj:`int`, `optional`, defaults to 20): | |
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of | |
generated tokens. The maximum length of the sequence to be generated. | |
pad_token_id (:obj:`int`, `optional`): | |
The id of the `padding` token. | |
eos_token_id (:obj:`int`, `optional`): | |
The id of the `end-of-sequence` token. | |
output_attentions (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under | |
returned tensors for more details. | |
output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors | |
for more details. | |
output_scores (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. | |
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | |
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
Whether to continue running the while loop until max_length is reached (needed for ZeRO stage 3). | |
model_kwargs: | |
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If | |
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. | |
Return: | |
:class:`~transformers.generation_utils.BeamSampleDecoderOnlyOutput`, | |
:class:`~transformers.generation_utils.BeamSampleEncoderDecoderOutput` or :obj:`torch.LongTensor`: A | |
:obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a | |
:class:`~transformers.generation_utils.BeamSampleDecoderOnlyOutput` if | |
``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a | |
:class:`~transformers.generation_utils.BeamSampleEncoderDecoderOutput` if | |
``model.config.is_encoder_decoder=True``. | |
Examples:: | |
>>> from transformers import ( | |
... AutoTokenizer, | |
... AutoModelForSeq2SeqLM, | |
... LogitsProcessorList, | |
... MinLengthLogitsProcessor, | |
... TopKLogitsWarper, | |
... TemperatureLogitsWarper, | |
... BeamSearchScorer, | |
... ) | |
>>> import torch | |
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base") | |
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") | |
>>> encoder_input_str = "translate English to German: How old are you?" | |
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids | |
>>> # let's run beam search using 3 beams | |
>>> num_beams = 3 | |
>>> # define decoder start token ids | |
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) | |
>>> input_ids = input_ids * model.config.decoder_start_token_id | |
>>> # add encoder_outputs to model keyword arguments | |
>>> model_kwargs = { | |
... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True) | |
... } | |
>>> # instantiate beam scorer | |
>>> beam_scorer = BeamSearchScorer( | |
... batch_size=1, | |
... max_length=model.config.max_length, | |
... num_beams=num_beams, | |
... device=model.device, | |
... ) | |
>>> # instantiate logits processors | |
>>> logits_processor = LogitsProcessorList([ | |
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id) | |
... ]) | |
>>> # instantiate logits warpers | |
>>> logits_warper = LogitsProcessorList([ | |
... TopKLogitsWarper(50), | |
... TemperatureLogitsWarper(0.7), | |
... ]) | |
>>> outputs = model.beam_sample( | |
... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs | |
... ) | |
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) | |
""" | |
# init values | |
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | |
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | |
if max_length is not None: | |
warnings.warn( | |
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | |
UserWarning, | |
) | |
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) | |
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id | |
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id | |
output_scores = output_scores if output_scores is not None else self.config.output_scores | |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
output_hidden_states = ( | |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
) | |
return_dict_in_generate = ( | |
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate | |
) | |
# init attention / hidden states / scores tuples | |
scores = () if (return_dict_in_generate and output_scores) else None | |
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None | |
cross_attentions = () if (return_dict_in_generate and output_attentions) else None | |
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None | |
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states | |
if return_dict_in_generate and self.config.is_encoder_decoder: | |
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None | |
encoder_hidden_states = ( | |
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None | |
) | |
batch_size = len(beam_scorer._beam_hyps) | |
num_beams = beam_scorer.num_beams | |
batch_beam_size, cur_len = input_ids.shape | |
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) | |
beam_scores = beam_scores.view((batch_size * num_beams,)) | |
this_peer_finished = False # used by synced_gpus only | |
while True: | |
if synced_gpus: | |
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence. | |
# The following logic allows an early break if all peers finished generating their sequence | |
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) | |
# send 0.0 if we finished, 1.0 otherwise | |
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) | |
# did all peers finish? the reduced sum will be 0.0 then | |
if this_peer_finished_flag.item() == 0.0: | |
break | |
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) | |
outputs = self( | |
**model_inputs, | |
return_dict=True, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
) | |
if synced_gpus and this_peer_finished: | |
cur_len = cur_len + 1 | |
continue # don't waste resources running the code we don't need | |
next_token_logits = outputs.logits[:, -1, :] | |
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` is never | |
# generated, both before and after the `nn.functional.log_softmax` operation. | |
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) | |
next_token_scores = nn.functional.log_softmax( | |
next_token_logits, dim=-1 | |
) # (batch_size * num_beams, vocab_size) | |
next_token_scores = logits_processor(input_ids, next_token_scores) | |
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) | |
next_token_scores = logits_warper(input_ids, next_token_scores) | |
# Store scores, attentions and hidden_states when required | |
if return_dict_in_generate: | |
if output_scores: | |
scores += (next_token_scores,) | |
if output_attentions: | |
decoder_attentions += ( | |
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) | |
) | |
if self.config.is_encoder_decoder: | |
cross_attentions += (outputs.cross_attentions,) | |
if output_hidden_states: | |
decoder_hidden_states += ( | |
(outputs.decoder_hidden_states,) | |
if self.config.is_encoder_decoder | |
else (outputs.hidden_states,) | |
) | |
# reshape for beam search | |
vocab_size = next_token_scores.shape[-1] | |
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) | |
probs = nn.functional.softmax(next_token_scores, dim=-1) | |
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) | |
next_token_scores = torch.gather(next_token_scores, -1, next_tokens) | |
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) | |
next_tokens = torch.gather(next_tokens, -1, _indices) | |
next_indices = next_tokens // vocab_size | |
next_tokens = next_tokens % vocab_size | |
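# 2 * num_beams candidates are sampled (and then sorted by score) so that `beam_scorer.process` | |
# still receives enough non-eos continuations to keep num_beams hypotheses alive per batch | |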
# stateless | |
beam_outputs = beam_scorer.process( | |
input_ids, | |
next_token_scores, | |
next_tokens, | |
next_indices, | |
pad_token_id=pad_token_id, | |
eos_token_id=eos_token_id, | |
) | |
beam_scores = beam_outputs["next_beam_scores"] | |
beam_next_tokens = beam_outputs["next_beam_tokens"] | |
beam_idx = beam_outputs["next_beam_indices"] | |
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) | |
model_kwargs = self._update_model_kwargs_for_generation( | |
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder | |
) | |
if model_kwargs["past"] is not None: | |
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) | |
# increase cur_len | |
cur_len = cur_len + 1 | |
if beam_scorer.is_done or stopping_criteria(input_ids, scores): | |
if not synced_gpus: | |
break | |
else: | |
this_peer_finished = True | |
sequence_outputs = beam_scorer.finalize( | |
input_ids, | |
beam_scores, | |
next_tokens, | |
next_indices, | |
pad_token_id=pad_token_id, | |
eos_token_id=eos_token_id, | |
max_length=stopping_criteria.max_length, | |
) | |
if return_dict_in_generate: | |
if not output_scores: | |
sequence_outputs["sequence_scores"] = None | |
if self.config.is_encoder_decoder: | |
return BeamSampleEncoderDecoderOutput( | |
sequences=sequence_outputs["sequences"], | |
sequences_scores=sequence_outputs["sequence_scores"], | |
scores=scores, | |
encoder_attentions=encoder_attentions, | |
encoder_hidden_states=encoder_hidden_states, | |
decoder_attentions=decoder_attentions, | |
cross_attentions=cross_attentions, | |
decoder_hidden_states=decoder_hidden_states, | |
) | |
else: | |
return BeamSampleDecoderOnlyOutput( | |
sequences=sequence_outputs["sequences"], | |
sequences_scores=sequence_outputs["sequence_scores"], | |
scores=scores, | |
attentions=decoder_attentions, | |
hidden_states=decoder_hidden_states, | |
) | |
else: | |
return sequence_outputs["sequences"] | |
def group_beam_search( | |
self, | |
input_ids: torch.LongTensor, | |
beam_scorer: BeamScorer, | |
logits_processor: Optional[LogitsProcessorList] = None, | |
stopping_criteria: Optional[StoppingCriteriaList] = None, | |
max_length: Optional[int] = None, | |
pad_token_id: Optional[int] = None, | |
eos_token_id: Optional[int] = None, | |
output_attentions: Optional[bool] = None, | |
output_hidden_states: Optional[bool] = None, | |
output_scores: Optional[bool] = None, | |
return_dict_in_generate: Optional[bool] = None, | |
synced_gpus: Optional[bool] = None, | |
**model_kwargs, | |
): | |
r""" | |
Generates sequences for models with a language modeling head using diverse beam search decoding. | |
Parameters: | |
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty | |
:obj:`torch.LongTensor` of shape :obj:`(1,)`. | |
beam_scorer (:obj:`BeamScorer`): | |
A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are | |
constructed, stored and sorted during generation. For more information, the documentation of | |
:class:`~transformers.BeamScorer` should be read. | |
logits_processor (:obj:`LogitsProcessorList`, `optional`): | |
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from | |
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling | |
head applied at each generation step. | |
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`): | |
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from | |
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop. | |
max_length (:obj:`int`, `optional`, defaults to 20): | |
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of | |
generated tokens. The maximum length of the sequence to be generated. | |
pad_token_id (:obj:`int`, `optional`): | |
The id of the `padding` token. | |
eos_token_id (:obj:`int`, `optional`): | |
The id of the `end-of-sequence` token. | |
output_attentions (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under | |
returned tensors for more details. | |
output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors | |
for more details. | |
output_scores (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. | |
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): | |
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | |
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
Whether to continue running the while loop until max_length is reached (needed for ZeRO stage 3). | |
model_kwargs: | |
Additional model specific kwargs that will be forwarded to the :obj:`forward` function of the model. If | |
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. | |
Return: | |
:class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`, | |
:class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A | |
:obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a | |
:class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if | |
``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a | |
:class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if | |
``model.config.is_encoder_decoder=True``. | |
Examples:: | |
>>> from transformers import ( | |
... AutoTokenizer, | |
... AutoModelForSeq2SeqLM, | |
... LogitsProcessorList, | |
... MinLengthLogitsProcessor, | |
... HammingDiversityLogitsProcessor, | |
... BeamSearchScorer, | |
... ) | |
>>> import torch | |
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base") | |
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") | |
>>> encoder_input_str = "translate English to German: How old are you?" | |
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids | |
>>> # let's run diverse beam search using 6 beams
>>> num_beams = 6 | |
>>> # define decoder start token ids | |
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) | |
>>> input_ids = input_ids * model.config.decoder_start_token_id | |
>>> # add encoder_outputs to model keyword arguments | |
>>> model_kwargs = { | |
... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True) | |
... } | |
>>> # instantiate beam scorer | |
>>> beam_scorer = BeamSearchScorer( | |
... batch_size=1, | |
... max_length=model.config.max_length, | |
... num_beams=num_beams, | |
... device=model.device, | |
... num_beam_groups=3 | |
... ) | |
>>> # instantiate logits processors | |
>>> logits_processor = LogitsProcessorList([ | |
... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3), | |
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), | |
... ]) | |
>>> outputs = model.group_beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) | |
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) | |
""" | |
# init values | |
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | |
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | |
if max_length is not None: | |
warnings.warn( | |
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | |
UserWarning, | |
) | |
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) | |
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id | |
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id | |
output_scores = output_scores if output_scores is not None else self.config.output_scores | |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
output_hidden_states = ( | |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
) | |
return_dict_in_generate = ( | |
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate | |
) | |
# init attention / hidden states / scores tuples | |
scores = () if (return_dict_in_generate and output_scores) else None | |
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None | |
cross_attentions = () if (return_dict_in_generate and output_attentions) else None | |
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None | |
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states | |
if return_dict_in_generate and self.config.is_encoder_decoder: | |
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None | |
encoder_hidden_states = ( | |
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None | |
) | |
batch_size = len(beam_scorer._beam_hyps) | |
num_beams = beam_scorer.num_beams | |
num_beam_groups = beam_scorer.num_beam_groups | |
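# number of beams that together make up one diversity group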
num_sub_beams = num_beams // num_beam_groups | |
device = input_ids.device | |
batch_beam_size, cur_len = input_ids.shape | |
assert ( | |
num_beams * batch_size == batch_beam_size | |
), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." | |
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) | |
# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
# the same group don't produce the same tokens every time.
beam_scores[:, ::num_sub_beams] = 0 | |
beam_scores = beam_scores.view((batch_size * num_beams,)) | |
this_peer_finished = False # used by synced_gpus only | |
while True: | |
if synced_gpus: | |
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence. | |
# The following logic allows an early break if all peers finished generating their sequence | |
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) | |
# send 0.0 if we finished, 1.0 otherwise | |
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) | |
# did all peers finish? the reduced sum will be 0.0 then | |
if this_peer_finished_flag.item() == 0.0: | |
break | |
# predicted tokens in cur_len step | |
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) | |
# indices which will form the beams in the next time step | |
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) | |
# do one decoder step on all beams of all sentences in batch | |
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) | |
outputs = self( | |
**model_inputs, | |
return_dict=True, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
) | |
if synced_gpus and this_peer_finished: | |
cur_len = cur_len + 1 | |
continue # don't waste resources running the code we don't need | |
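# beams are advanced one group at a time; the tokens already chosen by earlier groups in this step
# (`current_tokens`) are passed to the logits processors so that a diversity penalty can be applied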
for beam_group_idx in range(num_beam_groups): | |
group_start_idx = beam_group_idx * num_sub_beams | |
group_end_idx = min(group_start_idx + num_sub_beams, num_beams) | |
group_size = group_end_idx - group_start_idx | |
# indices of beams of current group among all sentences in batch | |
batch_group_indices = [] | |
if output_scores: | |
processed_score = torch.zeros_like(outputs.logits[:, -1, :]) | |
for batch_idx in range(batch_size): | |
batch_group_indices.extend( | |
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] | |
) | |
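# gather the sequences of the beams that belong to the current group (for every sentence in the batch)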
group_input_ids = input_ids[batch_group_indices] | |
# select outputs of beams of current group only | |
next_token_logits = outputs.logits[batch_group_indices, -1, :] | |
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` | |
# cannot be generated both before and after the `nn.functional.log_softmax` operation. | |
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) | |
next_token_scores = nn.functional.log_softmax( | |
next_token_logits, dim=-1 | |
) # (batch_size * group_size, vocab_size) | |
vocab_size = next_token_scores.shape[-1] | |
next_token_scores = logits_processor( | |
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx | |
) | |
next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as( | |
next_token_scores | |
) | |
if output_scores: | |
processed_score[batch_group_indices] = next_token_scores | |
# reshape for beam search | |
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) | |
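# keep 2 * group_size candidates so the beam scorer can set aside hypotheses that end in `eos_token_id`
# and still have enough open continuations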
next_token_scores, next_tokens = torch.topk( | |
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True | |
) | |
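# top-k was taken over the flattened (group_size * vocab_size) scores: recover the beam index within
# the group and the token id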
next_indices = next_tokens // vocab_size | |
next_tokens = next_tokens % vocab_size | |
# stateless | |
beam_outputs = beam_scorer.process( | |
group_input_ids, | |
next_token_scores, | |
next_tokens, | |
next_indices, | |
pad_token_id=pad_token_id, | |
eos_token_id=eos_token_id, | |
) | |
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] | |
beam_next_tokens = beam_outputs["next_beam_tokens"] | |
beam_idx = beam_outputs["next_beam_indices"] | |
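# re-order this group's sequences according to the selected beam indices and record the tokens chosen
# in this step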
input_ids[batch_group_indices] = group_input_ids[beam_idx] | |
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) | |
current_tokens[batch_group_indices] = group_input_ids[:, -1] | |
# (beam_idx // group_size) -> batch_idx | |
# (beam_idx % group_size) -> offset of idx inside the group | |
reordering_indices[batch_group_indices] = ( | |
num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size) | |
) | |
# Store scores, attentions and hidden_states when required | |
if return_dict_in_generate: | |
if output_scores: | |
scores += (processed_score,) | |
if output_attentions: | |
decoder_attentions += ( | |
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) | |
) | |
if self.config.is_encoder_decoder: | |
cross_attentions += (outputs.cross_attentions,) | |
if output_hidden_states: | |
decoder_hidden_states += ( | |
(outputs.decoder_hidden_states,) | |
if self.config.is_encoder_decoder | |
else (outputs.hidden_states,) | |
) | |
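# append the token selected for every beam (across all groups) to the running sequences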
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) | |
model_kwargs = self._update_model_kwargs_for_generation( | |
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder | |
) | |
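# re-order the cached past key/values so they match the beams selected in this step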
if model_kwargs["past"] is not None: | |
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices) | |
# increase cur_len | |
cur_len = cur_len + 1 | |
if beam_scorer.is_done or stopping_criteria(input_ids, scores): | |
if not synced_gpus: | |
break | |
else: | |
this_peer_finished = True | |
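# select the best hypotheses for every sentence in the batch and pad them to a common length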
sequence_outputs = beam_scorer.finalize( | |
input_ids, | |
beam_scores, | |
next_tokens, | |
next_indices, | |
pad_token_id=pad_token_id, | |
eos_token_id=eos_token_id, | |
max_length=stopping_criteria.max_length, | |
) | |
if return_dict_in_generate: | |
if not output_scores: | |
sequence_outputs["sequence_scores"] = None | |
if self.config.is_encoder_decoder: | |
return BeamSearchEncoderDecoderOutput( | |
sequences=sequence_outputs["sequences"], | |
sequences_scores=sequence_outputs["sequence_scores"], | |
scores=scores, | |
encoder_attentions=encoder_attentions, | |
encoder_hidden_states=encoder_hidden_states, | |
decoder_attentions=decoder_attentions, | |
cross_attentions=cross_attentions, | |
decoder_hidden_states=decoder_hidden_states, | |
) | |
else: | |
return BeamSearchDecoderOnlyOutput( | |
sequences=sequence_outputs["sequences"], | |
sequences_scores=sequence_outputs["sequence_scores"], | |
scores=scores, | |
attentions=decoder_attentions, | |
hidden_states=decoder_hidden_states, | |
) | |
else: | |
return sequence_outputs["sequences"] | |
def top_k_top_p_filtering( | |
logits: torch.FloatTensor, | |
top_k: int = 0, | |
top_p: float = 1.0, | |
filter_value: float = -float("Inf"), | |
min_tokens_to_keep: int = 1, | |
) -> torch.FloatTensor: | |
""" | |
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering | |
Args: | |
logits: logits distribution of shape (batch size, vocabulary size)
top_k: if > 0, keep only the top_k tokens with the highest probability (top-k filtering).
top_p: if < 1.0, keep the top tokens whose cumulative probability reaches top_p (nucleus filtering). Nucleus
filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).
filter_value: the value assigned to filtered-out tokens (defaults to -inf).
min_tokens_to_keep: keep at least this many tokens per batch example in the output.
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
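Examples (illustrative sketch; the logits values are made up and :func:`top_k_top_p_filtering` is assumed to be importable from the top-level ``transformers`` namespace)::
>>> import torch
>>> from transformers import top_k_top_p_filtering
>>> logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])  # (batch size 1, vocabulary size 5)
>>> filtered_logits = top_k_top_p_filtering(logits, top_k=2)  # all but the 2 highest-scoring tokens are set to filter_value (-inf)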
""" | |
if top_k > 0: | |
logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( | |
None, logits | |
) | |
if 0 <= top_p <= 1.0: | |
logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits) | |
return logits | |