Spaces:
Runtime error
Runtime error
| """Default Recurrent Neural Network Language Model in `lm_train.py`.""" | |
| from typing import Any | |
| from typing import List | |
| from typing import Tuple | |
| import logging | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from espnet.nets.lm_interface import LMInterface | |
| from espnet.nets.pytorch_backend.e2e_asr import to_device | |
| from espnet.nets.scorer_interface import BatchScorerInterface | |
| from espnet.utils.cli_utils import strtobool | |
class DefaultRNNLM(BatchScorerInterface, LMInterface, nn.Module):
    """Default RNNLM for `LMInterface` Implementation.

    The actual network is an :class:`RNNLM` wrapped in a
    :class:`ClassifierWithState`, stored as ``self.model``.

    Note:
        PyTorch seems to have memory leak when one GPU compute this after data parallel.
        If parallel GPUs compute this, it seems to be fine.
        See also https://github.com/espnet/espnet/issues/1075

    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser.

        Declared as a static method because it takes no ``self``; this keeps
        ``DefaultRNNLM.add_arguments(parser)`` working and additionally makes
        the call safe on an instance.

        Args:
            parser: Parser object providing ``add_argument``.

        Returns:
            The same ``parser`` with the RNNLM options registered.

        """
        parser.add_argument(
            "--type",
            type=str,
            default="lstm",
            nargs="?",
            choices=["lstm", "gru"],
            help="Which type of RNN to use",
        )
        parser.add_argument(
            "--layer", "-l", type=int, default=2, help="Number of hidden layers"
        )
        parser.add_argument(
            "--unit", "-u", type=int, default=650, help="Number of hidden units"
        )
        parser.add_argument(
            "--embed-unit",
            default=None,
            type=int,
            help="Number of hidden units in embedding layer, "
            "if it is not specified, it keeps the same number with hidden units.",
        )
        parser.add_argument(
            "--dropout-rate", type=float, default=0.5, help="dropout probability"
        )
        parser.add_argument(
            "--emb-dropout-rate",
            type=float,
            default=0.0,
            help="emb dropout probability",
        )
        parser.add_argument(
            "--tie-weights",
            type=strtobool,
            default=False,
            help="Tie input and output embeddings",
        )
        return parser

    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations. see py:method:`add_arguments`

        """
        nn.Module.__init__(self)
        # NOTE: for a compatibility with less than 0.5.0 version models
        dropout_rate = getattr(args, "dropout_rate", 0.0)
        # NOTE: for a compatibility with less than 0.6.1 version models
        embed_unit = getattr(args, "embed_unit", None)
        # NOTE: for a compatibility with less than 0.9.7 version models
        emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0)
        # NOTE: for a compatibility with less than 0.9.7 version models
        tie_weights = getattr(args, "tie_weights", False)

        self.model = ClassifierWithState(
            RNNLM(
                n_vocab,
                args.layer,
                args.unit,
                embed_unit,
                args.type,
                dropout_rate,
                emb_dropout_rate,
                tie_weights,
            )
        )

    def state_dict(self):
        """Dump state dict of the wrapped ``self.model``."""
        return self.model.state_dict()

    def load_state_dict(self, d):
        """Load state dict ``d`` into the wrapped ``self.model``."""
        self.model.load_state_dict(d)

    def forward(self, x, t):
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (accumulated loss divided by the batch size),
                the accumulated loss before that division and
                the number of non-zero entries in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)

            NOTE(review): the second value is the sum over time steps of
            ``mean(per-sample loss) * (number of non-zero inputs)``; confirm
            this is the intended negative log-likelihood.

        """
        loss = 0
        count = torch.tensor(0).long()
        state = None
        batch_size, sequence_length = x.shape
        for i in range(sequence_length):
            # Compute the loss at this time step and accumulate it
            state, loss_batch = self.model(state, x[:, i], t[:, i])
            # number of inputs at this step whose id is not 0
            non_zeros = torch.sum(x[:, i] != 0, dtype=loss_batch.dtype)
            loss += loss_batch.mean() * non_zeros
            count += int(non_zeros)
        return loss / batch_size, loss, count.to(loss.device)

    def score(self, y, state, x):
        """Score new token.

        Only the last token of ``y`` is fed to the model; ``x`` is not used.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys

        """
        new_state, scores = self.model.predict(state, y[-1].unsqueeze(0))
        return scores.squeeze(0), new_state

    def final_score(self, state):
        """Score eos.

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score

        """
        return self.model.final(state)

    # batch beam search API (see BatchScorerInterface)
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = self.model.predictor.n_layers
        if self.model.predictor.typ == "lstm":
            keys = ("c", "h")
        else:
            keys = ("h",)

        if states[0] is None:
            # only the first entry is inspected to decide that no state exists
            states = None
        else:
            # transpose state of [batch, key, layer] into [key, layer, batch]
            states = {
                k: [
                    torch.stack([states[b][k][i] for b in range(n_batch)])
                    for i in range(n_layers)
                ]
                for k in keys
            }
        states, logp = self.model.predict(states, ys[:, -1])

        # transpose state of [key, layer, batch] into [batch, key, layer]
        return (
            logp,
            [
                {k: [states[k][i][b] for i in range(n_layers)] for k in keys}
                for b in range(n_batch)
            ],
        )
class ClassifierWithState(nn.Module):
    """A wrapper for pytorch RNNLM.

    Couples a stateful ``predictor`` with a loss function: ``forward``
    returns the new state and the loss, ``predict`` returns the new state
    and log-probabilities.
    """

    def __init__(self, predictor, lossfun=None, label_key=-1):
        """Initialize class.

        :param torch.nn.Module predictor : The RNNLM
        :param function lossfun : The loss function to use. When ``None``
            (default) a fresh ``nn.CrossEntropyLoss(reduction="none")`` is
            created for this instance instead of sharing one module object
            built at function-definition time.
        :param int/str label_key : Position in ``args`` (int) or name in
            ``kwargs`` (str) of the ground-truth labels passed to ``forward``
        :raises TypeError: if ``label_key`` is neither ``int`` nor ``str``
        """
        if not (isinstance(label_key, (int, str))):
            raise TypeError("label_key must be int or str, but is %s" % type(label_key))
        super(ClassifierWithState, self).__init__()
        if lossfun is None:
            lossfun = nn.CrossEntropyLoss(reduction="none")
        self.lossfun = lossfun
        # last predictor output and last loss, refreshed by every forward()
        self.y = None
        self.loss = None
        self.label_key = label_key
        self.predictor = predictor

    def forward(self, state, *args, **kwargs):
        """Compute the loss value for an input and label pair.

        Notes:
            The predictor output and the loss are also stored in the
            attributes ``y`` and ``loss``.
            When ``label_key`` is ``int``, the corresponding element in ``args``
            is treated as ground truth labels. And when it is ``str``, the
            element in ``kwargs`` is used.
            The all elements of ``args`` and ``kwargs`` except the groundtruth
            labels are features.
            It feeds features to the predictor and compare the result
            with ground truth labels.

        :param torch.Tensor state : the LM state
        :param list[torch.Tensor] args : Input minibatch
        :param dict[torch.Tensor] kwargs : Input minibatch
        :return: tuple of (new state, loss value)
        :rtype: (Any, torch.Tensor)
        :raises ValueError: if the label cannot be found via ``label_key``
        """
        if isinstance(self.label_key, int):
            if not (-len(args) <= self.label_key < len(args)):
                msg = "Label key %d is out of bounds" % self.label_key
                raise ValueError(msg)
            t = args[self.label_key]
            if self.label_key == -1:
                # args[-1 + 1:] would be args[0:], so handle the last slot apart
                args = args[:-1]
            else:
                args = args[: self.label_key] + args[self.label_key + 1 :]
        elif isinstance(self.label_key, str):
            if self.label_key not in kwargs:
                msg = 'Label key "%s" is not found' % self.label_key
                raise ValueError(msg)
            t = kwargs[self.label_key]
            del kwargs[self.label_key]

        # reset first so stale values are not kept if the predictor raises
        self.y = None
        self.loss = None
        state, self.y = self.predictor(state, *args, **kwargs)
        self.loss = self.lossfun(self.y, t)
        return state, self.loss

    def predict(self, state, x):
        """Predict log probabilities for given state and input x using the predictor.

        If the predictor has a truthy ``normalized`` attribute its output is
        returned unchanged; otherwise ``log_softmax`` over dim 1 is applied.

        :param torch.Tensor state : The current state
        :param torch.Tensor x : The input
        :return a tuple (new state, log prob vector)
        :rtype (torch.Tensor, torch.Tensor)
        """
        if hasattr(self.predictor, "normalized") and self.predictor.normalized:
            return self.predictor(state, x)
        else:
            state, z = self.predictor(state, x)
            return state, F.log_softmax(z, dim=1)

    def buff_predict(self, state, x, n):
        """Predict new tokens from buffered inputs.

        A predictor whose class is named ``RNNLM`` is called once on the whole
        input; any other predictor is called once per item for the first ``n``
        items, and the results are collected.

        :param state: ``None`` or an indexable holding one state per item
        :param torch.Tensor x : The buffered inputs
        :param int n : Number of items to process one by one
        :return a tuple (new state(s), concatenated log probs)
        """
        if self.predictor.__class__.__name__ == "RNNLM":
            return self.predict(state, x)

        new_state = []
        new_log_y = []
        for i in range(n):
            state_i = None if state is None else state[i]
            state_i, log_y = self.predict(state_i, x[i].unsqueeze(0))
            new_state.append(state_i)
            new_log_y.append(log_y)

        return new_state, torch.cat(new_log_y)

    def final(self, state, index=None):
        """Predict final log probabilities for given state using the predictor.

        :param state: The state
        :param index: Optional index selecting ``state[index]`` before it is
            handed to the predictor
        :return The final log probabilities, or ``0.0`` when the predictor
            defines no ``final`` method
        :rtype torch.Tensor
        """
        if hasattr(self.predictor, "final"):
            if index is not None:
                return self.predictor.final(state[index])
            else:
                return self.predictor.final(state)
        else:
            return 0.0
| # Definition of a recurrent net for language modeling | |
class RNNLM(nn.Module):
    """Stack of LSTM or GRU cells with an embedding input and a linear output."""

    def __init__(
        self,
        n_vocab,
        n_layers,
        n_units,
        n_embed=None,
        typ="lstm",
        dropout_rate=0.5,
        emb_dropout_rate=0.0,
        tie_weights=False,
    ):
        """Build the network and draw every parameter from U(-0.1, 0.1).

        :param int n_vocab: The size of the vocabulary
        :param int n_layers: The number of layers to create
        :param int n_units: The number of units per layer
        :param int n_embed: The embedding size; ``n_units`` is used when None
        :param str typ: The RNN type; ``"lstm"`` builds LSTM cells, any other
            value builds GRU cells
        :param float dropout_rate: Dropout probability applied before every
            cell and before the output layer
        :param float emb_dropout_rate: Dropout probability applied to the
            embedding; ``0.0`` disables that extra dropout layer
        :param bool tie_weights: Share the output layer weight with the
            embedding weight (requires ``n_embed == n_units``)
        """
        super(RNNLM, self).__init__()
        embed_dim = n_units if n_embed is None else n_embed

        self.embed = nn.Embedding(n_vocab, embed_dim)
        self.embed_drop = (
            nn.Dropout(emb_dropout_rate) if emb_dropout_rate != 0.0 else None
        )

        # the first cell consumes the embedding, the others the previous layer
        cell_type = nn.LSTMCell if typ == "lstm" else nn.GRUCell
        input_sizes = [embed_dim] + [n_units] * (n_layers - 1)
        self.rnn = nn.ModuleList([cell_type(size, n_units) for size in input_sizes])

        # one dropout per cell input plus one in front of the output layer
        self.dropout = nn.ModuleList(
            [nn.Dropout(dropout_rate) for _ in range(n_layers + 1)]
        )
        self.lo = nn.Linear(n_units, n_vocab)

        self.n_layers = n_layers
        self.n_units = n_units
        self.typ = typ

        logging.info("Tie weights set to {}".format(tie_weights))
        logging.info("Dropout set to {}".format(dropout_rate))
        logging.info("Emb Dropout set to {}".format(emb_dropout_rate))

        if tie_weights:
            assert (
                embed_dim == n_units
            ), "Tie Weights: True need embedding and final dimensions to match"
            self.lo.weight = self.embed.weight

        # initialize parameters from uniform distribution
        for weight in self.parameters():
            weight.data.uniform_(-0.1, 0.1)

    def zero_state(self, batchsize):
        """Return a ``(batchsize, n_units)`` zero tensor matching the parameters.

        Device and dtype are taken from the first parameter of the module.
        """
        reference = next(self.parameters())
        zeros = torch.zeros(batchsize, self.n_units)
        return zeros.to(device=reference.device, dtype=reference.dtype)

    def forward(self, state, x):
        """Run one step through all layers.

        :param state: ``None`` to start from zeros, or a dict with key ``"h"``
            (and ``"c"`` for LSTM) holding one tensor per layer
        :param torch.Tensor x: Token ids fed to the embedding
        :return a tuple (new state dict, output of the final linear layer)
        """
        use_lstm = self.typ == "lstm"

        if state is None:
            initial_h = [
                to_device(x, self.zero_state(x.size(0))) for _ in range(self.n_layers)
            ]
            if use_lstm:
                initial_c = [
                    to_device(x, self.zero_state(x.size(0)))
                    for _ in range(self.n_layers)
                ]
                state = {"c": initial_c, "h": initial_h}
            else:
                state = {"h": initial_h}

        layer_input = self.embed(x)
        if self.embed_drop is not None:
            layer_input = self.embed_drop(layer_input)

        hidden = []
        cell = []
        for n in range(self.n_layers):
            dropped = self.dropout[n](layer_input)
            if use_lstm:
                h_n, c_n = self.rnn[n](dropped, (state["h"][n], state["c"][n]))
                cell.append(c_n)
            else:
                h_n = self.rnn[n](dropped, state["h"][n])
            hidden.append(h_n)
            # the output of this layer is the input of the next one
            layer_input = h_n

        new_state = {"c": cell, "h": hidden} if use_lstm else {"h": hidden}
        y = self.lo(self.dropout[-1](hidden[-1]))
        return new_state, y