# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from abc import ABC

import jax
import jax.lax as lax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla

from .file_utils import add_start_docstrings
from .utils.logging import get_logger


logger = get_logger(__name__)

LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
            using beam search, or log softmax for each vocabulary token when using beam search.
        kwargs:
            Additional logits processor specific kwargs.

    Return:
        :obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
"""


class FlaxLogitsProcessor(ABC):
    """Abstract base class for all logit processors that can be applied during generation."""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
        """Flax method for processing logits."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )


class FlaxLogitsWarper(ABC):
    """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
        """Flax method for warping logits."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )


class FlaxLogitsProcessorList(list):
    """
    This class can be used to create a list of :class:`~transformers.FlaxLogitsProcessor` or
    :class:`~transformers.FlaxLogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits
    from list and adds a specific `__call__` method to apply each :class:`~transformers.FlaxLogitsProcessor` or
    :class:`~transformers.FlaxLogitsWarper` to the inputs.
    """

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(
        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int, **kwargs
    ) -> jax_xla.DeviceArray:
        for processor in self:
            function_args = inspect.signature(processor.__call__).parameters
            if len(function_args) > 3:
                assert all(
                    arg in kwargs for arg in list(function_args.keys())[2:]
                ), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
                scores = processor(input_ids, scores, cur_len, **kwargs)
            else:
                scores = processor(input_ids, scores, cur_len)
        return scores
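

# Illustrative usage sketch, not part of the original module: this mirrors how
# `generate` chains processors/warpers internally via a FlaxLogitsProcessorList.
# The shapes, vocabulary size, and token ids below are made-up assumptions for
# demonstration only.
def _example_processor_list():
    processors = FlaxLogitsProcessorList(
        [FlaxTemperatureLogitsWarper(0.7), FlaxTopKLogitsWarper(top_k=5)]
    )
    input_ids = jnp.ones((2, 4), dtype=jnp.int32)  # assumed (batch_size, sequence_length)
    scores = jnp.zeros((2, 100))  # assumed (batch_size, vocab_size) logits
    # cur_len is the current length of the generated sequences
    return processors(input_ids, scores, cur_len=4)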


class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
    r"""
    :class:`transformers.FlaxLogitsWarper` for temperature (exponential scaling of the output probability
    distribution).

    Args:
        temperature (:obj:`float`):
            The value used to modulate the logits distribution.
    """

    def __init__(self, temperature: float):
        if not isinstance(temperature, float) or not (temperature > 0):
            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")

        self.temperature = temperature

    def __call__(
        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
    ) -> jax_xla.DeviceArray:
        scores = scores / self.temperature
        return scores
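

# Minimal sketch (illustration only, not from the original file): a temperature
# below 1 sharpens the softmax distribution, a temperature above 1 flattens it.
# The toy logits are arbitrary assumptions.
def _example_temperature():
    scores = jnp.array([[1.0, 2.0, 3.0]])  # assumed (batch_size=1, vocab_size=3)
    sharp = FlaxTemperatureLogitsWarper(0.5)(None, scores, cur_len=0)
    flat = FlaxTemperatureLogitsWarper(2.0)(None, scores, cur_len=0)
    # softmax(sharp) concentrates mass on the largest logit; softmax(flat) spreads it out
    return jax.nn.softmax(sharp), jax.nn.softmax(flat)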


class FlaxTopPLogitsWarper(FlaxLogitsWarper):
    """
    :class:`transformers.FlaxLogitsWarper` that performs top-p filtering, i.e. restricting to the smallest set of the
    most probable tokens whose cumulative probability reaches :obj:`top_p`.

    Args:
        top_p (:obj:`float`):
            If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
            kept for generation.
        filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
            raise ValueError(f"`top_p` has to be a float between 0 and 1, but is {top_p}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(
        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
    ) -> jax_xla.DeviceArray:
        # sort the whole vocabulary in descending score order
        topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])

        mask_scores = jnp.full_like(scores, self.filter_value)
        cumulative_probs = jax.nn.softmax(topk_scores, axis=-1).cumsum(axis=-1)
        score_mask = cumulative_probs < self.top_p

        # include the token that is higher than top_p as well
        score_mask |= jax.ops.index_update(jnp.roll(score_mask, 1), jax.ops.index[:, 0], True)

        # min tokens to keep
        score_mask = jax.ops.index_update(score_mask, jax.ops.index[:, : self.min_tokens_to_keep], True)

        topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
        # scatter the filtered scores back into the original vocabulary order
        next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]

        return next_scores
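

# Hedged sketch (not in the original module): tokens are kept in probability
# order until their cumulative mass first reaches top_p; every other logit is
# set to filter_value. The toy distribution below is an assumption.
def _example_top_p():
    warper = FlaxTopPLogitsWarper(top_p=0.9)
    # softmax of these logits is [0.64, 0.24, 0.09, 0.03]: the first two tokens
    # cover 0.88, so the third is also kept to cross the 0.9 cutoff
    scores = jnp.log(jnp.array([[0.64, 0.24, 0.09, 0.03]]))
    return warper(None, scores, cur_len=0)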


class FlaxTopKLogitsWarper(FlaxLogitsWarper):
    r"""
    :class:`transformers.FlaxLogitsWarper` that performs top-k filtering, i.e. restricting to the k highest
    probability elements.

    Args:
        top_k (:obj:`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.top_k = top_k
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(
        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
    ) -> jax_xla.DeviceArray:
        batch_size, vocab_size = scores.shape
        next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)

        topk = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1])  # Safety check
        topk_scores, topk_indices = lax.top_k(scores, topk)
        # offset each row's indices into the flattened (batch * vocab) layout
        shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten()
        topk_scores_flat = topk_scores.flatten()
        topk_indices_flat = topk_indices.flatten() + shift

        next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat)
        next_scores = next_scores_flat.reshape(batch_size, vocab_size)
        return next_scores
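

# Minimal sketch (illustration only): keep the two highest logits per row and
# set every other position to -inf. The toy values are assumptions.
def _example_top_k():
    warper = FlaxTopKLogitsWarper(top_k=2)
    scores = jnp.array([[0.1, 0.4, 0.2, 0.3]])  # assumed (batch_size=1, vocab_size=4)
    # result keeps 0.4 and 0.3 in place; indices 0 and 2 become -inf
    return warper(None, scores, cur_len=0)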


class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
    r"""
    :class:`~transformers.FlaxLogitsProcessor` that enforces the specified token as the first generated token.

    Args:
        bos_token_id (:obj:`int`):
            The id of the token to force as the first generated token.
    """

    def __init__(self, bos_token_id: int):
        self.bos_token_id = bos_token_id

    def __call__(
        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
    ) -> jax_xla.DeviceArray:
        new_scores = jnp.full(scores.shape, -float("inf"))

        # apply_penalty is 1 only at the first generation step (cur_len == 1)
        apply_penalty = 1 - jnp.bool_(cur_len - 1)

        scores = jnp.where(
            apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.bos_token_id], 0), scores
        )

        return scores
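

# Hedged sketch (not from the original file): at the first generation step
# every logit except the BOS token's is -inf, so generation must start with
# bos_token_id; later steps pass through unchanged. Ids are illustrative
# assumptions.
def _example_forced_bos():
    processor = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=0)
    scores = jnp.zeros((1, 5))  # assumed (batch_size=1, vocab_size=5)
    forced = processor(None, scores, cur_len=1)  # only index 0 stays finite
    untouched = processor(None, scores, cur_len=3)  # returned unchanged
    return forced, untouched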


class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
    r"""
    :class:`~transformers.FlaxLogitsProcessor` that enforces the specified token as the last generated token when
    :obj:`max_length` is reached.

    Args:
        max_length (:obj:`int`):
            The maximum length of the sequence to be generated.
        eos_token_id (:obj:`int`):
            The id of the token to force as the last generated token when :obj:`max_length` is reached.
    """

    def __init__(self, max_length: int, eos_token_id: int):
        self.max_length = max_length
        self.eos_token_id = eos_token_id

    def __call__(
        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
    ) -> jax_xla.DeviceArray:
        new_scores = jnp.full(scores.shape, -float("inf"))

        # apply_penalty is 1 only at the last step (cur_len == max_length - 1)
        apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)

        scores = jnp.where(
            apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.eos_token_id], 0), scores
        )

        return scores
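

# Minimal sketch (illustration only): once cur_len == max_length - 1, every
# logit except the EOS token's is -inf, so the sequence must end with EOS at
# max_length. The numbers are assumptions.
def _example_forced_eos():
    processor = FlaxForcedEOSTokenLogitsProcessor(max_length=8, eos_token_id=2)
    scores = jnp.zeros((1, 5))  # assumed (batch_size=1, vocab_size=5)
    return processor(None, scores, cur_len=7)  # 7 == max_length - 1: EOS forced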


class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
    r"""
    :class:`transformers.FlaxLogitsProcessor` enforcing a min-length by setting EOS probability to 0.

    Args:
        min_length (:obj:`int`):
            The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.
    """

    def __init__(self, min_length: int, eos_token_id: int):
        if not isinstance(min_length, int) or min_length < 0:
            raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

        if not isinstance(eos_token_id, int) or eos_token_id < 0:
            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")

        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(
        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
    ) -> jax_xla.DeviceArray:
        # create boolean flag to decide if min length penalty should be applied
        apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)

        scores = jnp.where(
            apply_penalty, jax.ops.index_update(scores, jax.ops.index[:, self.eos_token_id], -float("inf")), scores
        )

        return scores
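

# Hedged sketch (not part of the original module): while the sequence has not
# yet exceeded min_length, the EOS logit is pinned to -inf so generation cannot
# stop early. The toy values are assumptions.
def _example_min_length():
    processor = FlaxMinLengthLogitsProcessor(min_length=5, eos_token_id=2)
    scores = jnp.zeros((1, 4))  # assumed (batch_size=1, vocab_size=4)
    blocked = processor(None, scores, cur_len=3)  # EOS (index 2) set to -inf
    allowed = processor(None, scores, cur_len=6)  # past min_length: unchanged
    return blocked, allowed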