Spaces:
Runtime error
Runtime error
| """Scorer interface module.""" | |
import warnings
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple

import torch
class ScorerInterface:
    """Scorer interface for beam search.

    The scorer performs scoring of all the tokens in the vocabulary.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`

    """

    def init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        The default implementation is stateless and returns ``None``.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state

        """
        return None

    def select_state(
        self, state: Any, i: int, new_id: Optional[int] = None
    ) -> Any:
        """Select state with relative ids in the main beam search.

        The default implementation indexes ``state`` with ``i`` and ignores
        ``new_id``; a ``None`` state is passed through unchanged.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (Optional[int]): New label index to select a state if
                necessary (unused by the default implementation)

        Returns:
            state: pruned state

        """
        return None if state is None else state[i]

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token that has a shape of `(n_vocab)`
                and next state for ys

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        """
        # Name the concrete subclass so a missing override is easy to locate.
        raise NotImplementedError(
            "{} must implement score()".format(self.__class__.__name__)
        )

    def final_score(self, state: Any) -> float:
        """Score eos (optional).

        The default implementation adds nothing at end of sequence.

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score (``0.0`` by default)

        """
        return 0.0
class BatchScorerInterface(ScorerInterface):
    """Batch scorer interface.

    The default :meth:`batch_score` loops over the batch, calls :meth:`score`
    once per hypothesis and emits a warning. Override it to provide a
    parallelized implementation.

    """

    def batch_init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        Delegates to :meth:`init_state` by default.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state

        """
        return self.init_state(x)

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batched scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        # stacklevel=2 attributes the warning to the caller (the search code
        # that chose this fallback) rather than to this library line.
        warnings.warn(
            "{} batch score is implemented through for loop not parallelized".format(
                self.__class__.__name__
            ),
            stacklevel=2,
        )
        scores = []
        outstates = []
        for y, state, x in zip(ys, states, xs):
            score, outstate = self.score(y, state, x)
            outstates.append(outstate)
            scores.append(score)
        # Concatenate the per-hypothesis scores and reshape to one row per
        # batch entry.
        scores = torch.cat(scores, 0).view(ys.shape[0], -1)
        return scores, outstates
class PartialScorerInterface(ScorerInterface):
    """Partial scorer interface for beam search.

    The partial scorer runs after the non-partial scorers have finished. It
    receives pre-pruned next tokens to score, because scoring every token in
    the vocabulary would be too expensive.

    Examples:
        * Prefix search for connectionist-temporal-classification models
            * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`

    """

    def score_partial(
        self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            y (torch.Tensor): 1D prefix token
            next_tokens (torch.Tensor): torch.int64 next token to score
            state: decoder state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
                and next state for ys

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        """
        # Name the concrete subclass so a missing override is easy to locate.
        raise NotImplementedError(
            "{} must implement score_partial()".format(self.__class__.__name__)
        )
class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
    """Batch partial scorer interface for beam search."""

    def batch_score_partial(
        self,
        ys: torch.Tensor,
        next_tokens: torch.Tensor,
        states: List[Any],
        xs: torch.Tensor,
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
                and next states for ys

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        """
        # Name the concrete subclass so a missing override is easy to locate.
        raise NotImplementedError(
            "{} must implement batch_score_partial()".format(
                self.__class__.__name__
            )
        )