Tutorial: Simple LSTM
=====================

In this tutorial we will extend fairseq by adding a new
:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
sentence with an LSTM and then passes the final hidden state to a second LSTM
that decodes the target sentence (without attention).

This tutorial covers:

1. **Building an Encoder and Decoder** to encode/decode the source/target
   sentence, respectively.
2. **Registering a new Model** so that it can be used with the existing
   :ref:`Command-line tools`.
3. **Training the Model** using the existing command-line tools.
4. **Making generation faster** by modifying the Decoder to use
   :ref:`Incremental decoding`.


1. Building an Encoder and Decoder
----------------------------------

In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
should implement the :class:`~fairseq.models.FairseqEncoder` interface and
Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface.
These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders
and FairseqDecoders can be written and used in the same ways as ordinary PyTorch
Modules.
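
To orient ourselves, here is a bare-bones sketch of the two interfaces as they
are used in the rest of this tutorial (method bodies omitted; the full
implementations follow below)::

  from fairseq.models import FairseqEncoder, FairseqDecoder

  class SketchEncoder(FairseqEncoder):

      def forward(self, src_tokens, src_lengths):
          # Encode the source and return anything the Decoder needs
          # (a tensor, a tuple, a dict, ...).
          raise NotImplementedError

      def reorder_encoder_out(self, encoder_out, new_order):
          # Reorder the batch dimension of *encoder_out* (needed for beam search).
          raise NotImplementedError

  class SketchDecoder(FairseqDecoder):

      def forward(self, prev_output_tokens, encoder_out):
          # Return a tuple of (output logits, attention weights or None).
          raise NotImplementedError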


Encoder
~~~~~~~

Our Encoder will embed the tokens in the source sentence, feed them to a
:class:`torch.nn.LSTM`, and return the final hidden state. To create our
encoder, save the following in a new file named
:file:`fairseq/models/simple_lstm.py`::

  import torch.nn as nn
  from fairseq import utils
  from fairseq.models import FairseqEncoder

  class SimpleLSTMEncoder(FairseqEncoder):

      def __init__(
          self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
      ):
          super().__init__(dictionary)
          self.args = args

          # Our encoder will embed the inputs before feeding them to the LSTM.
          self.embed_tokens = nn.Embedding(
              num_embeddings=len(dictionary),
              embedding_dim=embed_dim,
              padding_idx=dictionary.pad(),
          )
          self.dropout = nn.Dropout(p=dropout)

          # We'll use a single-layer, unidirectional LSTM for simplicity.
          self.lstm = nn.LSTM(
              input_size=embed_dim,
              hidden_size=hidden_dim,
              num_layers=1,
              bidirectional=False,
              batch_first=True,
          )

      def forward(self, src_tokens, src_lengths):
          # The inputs to the ``forward()`` function are determined by the
          # Task, and in particular the ``'net_input'`` key in each
          # mini-batch. We discuss Tasks in the next tutorial, but for now just
          # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
          # has shape `(batch)`.

          # Note that the source is typically padded on the left. This can be
          # configured by adding the `--left-pad-source "False"` command-line
          # argument, but here we'll make the Encoder handle either kind of
          # padding by converting everything to be right-padded.
          if self.args.left_pad_source:
              # Convert left-padding to right-padding.
              src_tokens = utils.convert_padding_direction(
                  src_tokens,
                  padding_idx=self.dictionary.pad(),
                  left_to_right=True
              )

          # Embed the source.
          x = self.embed_tokens(src_tokens)

          # Apply dropout.
          x = self.dropout(x)

          # Pack the sequence into a PackedSequence object to feed to the LSTM.
          # Note that ``pack_padded_sequence()`` expects the lengths on the CPU.
          x = nn.utils.rnn.pack_padded_sequence(
              x, src_lengths.cpu(), batch_first=True,
          )

          # Get the output from the LSTM.
          _outputs, (final_hidden, _final_cell) = self.lstm(x)

          # Return the Encoder's output. This can be any object and will be
          # passed directly to the Decoder.
          return {
              # this will have shape `(bsz, hidden_dim)`
              'final_hidden': final_hidden.squeeze(0),
          }

      # Encoders are required to implement this method so that we can rearrange
      # the order of the batch elements during inference (e.g., beam search).
      def reorder_encoder_out(self, encoder_out, new_order):
          """
          Reorder encoder output according to `new_order`.

          Args:
              encoder_out: output from the ``forward()`` method
              new_order (LongTensor): desired order

          Returns:
              `encoder_out` rearranged according to `new_order`
          """
          final_hidden = encoder_out['final_hidden']
          return {
              'final_hidden': final_hidden.index_select(0, new_order),
          }
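
As a quick sanity check (not part of the tutorial file), the Encoder can be
exercised in isolation with a toy vocabulary. The snippet below assumes the
class above has been defined and uses fairseq's ``Dictionary`` only to build a
tiny vocab::

  import argparse
  import torch
  from fairseq.data import Dictionary

  # Build a toy vocabulary and a fake mini-batch of two right-padded sentences.
  vocab = Dictionary()
  for word in ['hello', 'world']:
      vocab.add_symbol(word)

  args = argparse.Namespace(left_pad_source=False)
  encoder = SimpleLSTMEncoder(args, vocab, embed_dim=8, hidden_dim=8)

  src_tokens = torch.LongTensor([
      [vocab.index('hello'), vocab.index('world')],
      [vocab.index('hello'), vocab.pad()],
  ])
  src_lengths = torch.LongTensor([2, 1])  # sorted in descending order

  encoder_out = encoder(src_tokens, src_lengths)
  print(encoder_out['final_hidden'].shape)  # torch.Size([2, 8]) = (bsz, hidden_dim)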


Decoder
~~~~~~~

Our Decoder will predict the next word, conditioned on the Encoder's final
hidden state and an embedded representation of the previous target word. During
training the previous *ground-truth* target word is fed in, a technique known as
*teacher forcing*. More specifically, we'll use a :class:`torch.nn.LSTM` to
produce a sequence of hidden states that we'll project to the size of the output
vocabulary to predict each target word.

::

  import torch
  from fairseq.models import FairseqDecoder

  class SimpleLSTMDecoder(FairseqDecoder):

      def __init__(
          self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
          dropout=0.1,
      ):
          super().__init__(dictionary)

          # Our decoder will embed the inputs before feeding them to the LSTM.
          self.embed_tokens = nn.Embedding(
              num_embeddings=len(dictionary),
              embedding_dim=embed_dim,
              padding_idx=dictionary.pad(),
          )
          self.dropout = nn.Dropout(p=dropout)

          # We'll use a single-layer, unidirectional LSTM for simplicity.
          self.lstm = nn.LSTM(
              # For the first layer we'll concatenate the Encoder's final hidden
              # state with the embedded target tokens.
              input_size=encoder_hidden_dim + embed_dim,
              hidden_size=hidden_dim,
              num_layers=1,
              bidirectional=False,
          )

          # Define the output projection.
          self.output_projection = nn.Linear(hidden_dim, len(dictionary))

      # During training Decoders are expected to take the entire target sequence
      # (shifted right by one position) and produce logits over the vocabulary.
      # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
      # ``dictionary.eos()``, followed by the target sequence.
      def forward(self, prev_output_tokens, encoder_out):
          """
          Args:
              prev_output_tokens (LongTensor): previous decoder outputs of shape
                  `(batch, tgt_len)`, for teacher forcing
              encoder_out (dict): output from the Encoder's ``forward()``,
                  containing the final hidden state

          Returns:
              tuple:
                  - the decoder's output of shape
                    `(batch, tgt_len, vocab)`
                  - the decoder's attention weights of shape
                    `(batch, tgt_len, src_len)`, or ``None`` since this
                    model does not use attention
          """
          bsz, tgt_len = prev_output_tokens.size()

          # Extract the final hidden state from the Encoder.
          final_encoder_hidden = encoder_out['final_hidden']

          # Embed the target sequence, which has been shifted right by one
          # position and now starts with the end-of-sentence symbol.
          x = self.embed_tokens(prev_output_tokens)

          # Apply dropout.
          x = self.dropout(x)

          # Concatenate the Encoder's final hidden state to *every* embedded
          # target token.
          x = torch.cat(
              [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
              dim=2,
          )

          # Using PackedSequence objects in the Decoder is harder than in the
          # Encoder, since the targets are not sorted in descending length order,
          # which is a requirement of ``pack_padded_sequence()``. Instead we'll
          # feed nn.LSTM directly.
          initial_state = (
              final_encoder_hidden.unsqueeze(0),  # hidden
              torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
          )
          output, _ = self.lstm(
              x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`
              initial_state,
          )
          x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`

          # Project the outputs to the size of the vocabulary.
          x = self.output_projection(x)

          # Return the logits and ``None`` for the attention weights.
          return x, None
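
Continuing the toy example from the Encoder section (again, not part of the
tutorial file), we can feed the Encoder's output into the Decoder together with
an eos-prefixed target, i.e., the shifted sequence fairseq constructs for
teacher forcing::

  # The target "hello world" becomes [eos, hello, world] when shifted right.
  prev_output_tokens = torch.LongTensor([
      [vocab.eos(), vocab.index('hello'), vocab.index('world')],
      [vocab.eos(), vocab.index('hello'), vocab.pad()],
  ])

  decoder = SimpleLSTMDecoder(vocab, encoder_hidden_dim=8, embed_dim=8, hidden_dim=8)
  logits, _ = decoder(prev_output_tokens, encoder_out)
  print(logits.shape)  # (batch=2, tgt_len=3, vocab size)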


2. Registering the Model
------------------------

Now that we've defined our Encoder and Decoder, we must *register* our model
with fairseq using the :func:`~fairseq.models.register_model` function
decorator. Once the model is registered we'll be able to use it with the
existing :ref:`Command-line tools`.

All registered models must implement the
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
models (i.e., any model with a single Encoder and Decoder), we can instead
implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.

Create a small wrapper class in the same file and register it in fairseq with
the name ``'simple_lstm'``::

  from fairseq.models import FairseqEncoderDecoderModel, register_model

  # Note: the ``register_model`` decorator should immediately precede the
  # definition of the Model class.

  @register_model('simple_lstm')
  class SimpleLSTMModel(FairseqEncoderDecoderModel):

      @staticmethod
      def add_args(parser):
          # Models can override this method to add new command-line arguments.
          # Here we'll add some new command-line arguments to configure dropout
          # and the dimensionality of the embeddings and hidden states.
          parser.add_argument(
              '--encoder-embed-dim', type=int, metavar='N',
              help='dimensionality of the encoder embeddings',
          )
          parser.add_argument(
              '--encoder-hidden-dim', type=int, metavar='N',
              help='dimensionality of the encoder hidden state',
          )
          parser.add_argument(
              '--encoder-dropout', type=float, default=0.1,
              help='encoder dropout probability',
          )
          parser.add_argument(
              '--decoder-embed-dim', type=int, metavar='N',
              help='dimensionality of the decoder embeddings',
          )
          parser.add_argument(
              '--decoder-hidden-dim', type=int, metavar='N',
              help='dimensionality of the decoder hidden state',
          )
          parser.add_argument(
              '--decoder-dropout', type=float, default=0.1,
              help='decoder dropout probability',
          )

      @classmethod
      def build_model(cls, args, task):
          # Fairseq initializes models by calling the ``build_model()``
          # function. This provides more flexibility, since the returned model
          # instance can be of a different type than the class on which it was
          # called. In this case we'll just return a SimpleLSTMModel instance.

          # Initialize our Encoder and Decoder.
          encoder = SimpleLSTMEncoder(
              args=args,
              dictionary=task.source_dictionary,
              embed_dim=args.encoder_embed_dim,
              hidden_dim=args.encoder_hidden_dim,
              dropout=args.encoder_dropout,
          )
          decoder = SimpleLSTMDecoder(
              dictionary=task.target_dictionary,
              encoder_hidden_dim=args.encoder_hidden_dim,
              embed_dim=args.decoder_embed_dim,
              hidden_dim=args.decoder_hidden_dim,
              dropout=args.decoder_dropout,
          )
          model = SimpleLSTMModel(encoder, decoder)

          # Print the model architecture.
          print(model)

          return model

      # We could override ``forward()`` if we wanted more control over how
      # the encoder and decoder interact, but it's not necessary for this
      # tutorial since we can inherit the default implementation provided by
      # the FairseqEncoderDecoderModel base class, which looks like:
      #
      # def forward(self, src_tokens, src_lengths, prev_output_tokens):
      #     encoder_out = self.encoder(src_tokens, src_lengths)
      #     decoder_out = self.decoder(prev_output_tokens, encoder_out)
      #     return decoder_out

Finally, let's define a *named architecture* with the configuration for our
model. This is done with the :func:`~fairseq.models.register_model_architecture`
function decorator. Thereafter this named architecture can be used with the
``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``::

  from fairseq.models import register_model_architecture

  # The first argument to ``register_model_architecture()`` should be the name
  # of the model we registered above (i.e., 'simple_lstm'). The function we
  # register here should take a single argument *args* and modify it in-place
  # to match the desired architecture.

  @register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
  def tutorial_simple_lstm(args):
      # We use ``getattr()`` to prioritize arguments that are explicitly given
      # on the command-line, so that the defaults defined below are only used
      # when no other value has been specified.
      args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
      args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
      args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
      args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
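
To see why the ``getattr()`` pattern works, here is a toy illustration in plain
Python (not fairseq code): values already set on ``args``, e.g., from the
command line, take precedence, and the architecture defaults only fill in the
gaps::

  from argparse import Namespace

  args = Namespace(encoder_embed_dim=512)  # as if --encoder-embed-dim 512 was given

  args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
  args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)

  print(args.encoder_embed_dim)   # 512 (kept from the command line)
  print(args.encoder_hidden_dim)  # 256 (filled in by the architecture default)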


3. Training the Model
---------------------

Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
command-line tool for this, making sure to specify our new Model architecture
(``--arch tutorial_simple_lstm``).

.. note::

  Make sure you've already preprocessed the data from the IWSLT example in the
  :file:`examples/translation/` directory.

.. code-block:: console

  > fairseq-train data-bin/iwslt14.tokenized.de-en \
    --arch tutorial_simple_lstm \
    --encoder-dropout 0.2 --decoder-dropout 0.2 \
    --optimizer adam --lr 0.005 --lr-shrink 0.5 \
    --max-tokens 12000
  (...)
  | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
  | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954

The model files should appear in the :file:`checkpoints/` directory. While this
model architecture is not very good, we can use the :ref:`fairseq-generate` script to
generate translations and compute our BLEU score over the test set:

.. code-block:: console

  > fairseq-generate data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 \
    --remove-bpe
  (...)
  | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)


4. Making generation faster
---------------------------

While autoregressive generation from sequence-to-sequence models is inherently
slow, our implementation above is especially slow because it recomputes the
entire sequence of Decoder hidden states for every output token (i.e., it is
``O(n^2)`` in the target length). We can make this significantly faster by
caching the previous hidden states instead.
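
As a back-of-the-envelope illustration (not fairseq code), decoding a
length-``T`` output without caching re-runs the LSTM over the entire prefix at
every step, while caching needs exactly one LSTM step per output token::

  T = 100
  steps_without_cache = sum(range(1, T + 1))  # 1 + 2 + ... + T = T*(T+1)/2 = 5050
  steps_with_cache = T                        # 100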

In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a
special mode at inference time in which the Model only receives a single
timestep of input corresponding to the previous output token (the model's own
prediction from the preceding step) and must produce the next output
incrementally. Thus the model must cache any long-term state that is needed
about the sequence, e.g., hidden states, convolutional states, etc.

To implement incremental decoding we will modify our model to implement the
:class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
decoder interface allows ``forward()`` methods to take an extra keyword argument
(*incremental_state*) that can be used to cache state across time-steps.
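
Conceptually, the generator drives an incremental decoder with a loop like the
following sketch (greedy search for simplicity). This is only an illustration of
the calling convention, not fairseq's actual generator; it assumes a trained
``model`` (a ``SimpleLSTMModel`` with the incremental decoder defined below),
the ``encoder_out`` for a batch of size one, and the target dictionary
``tgt_dict``::

  import torch

  max_len = 100            # maximum number of tokens to generate
  incremental_state = {}   # cache shared across decoding steps
  tokens = [tgt_dict.eos()]  # generation starts from the end-of-sentence symbol

  for _ in range(max_len):
      prev_output_tokens = torch.LongTensor([tokens])
      logits, _ = model.decoder(
          prev_output_tokens, encoder_out,
          incremental_state=incremental_state,
      )
      next_token = logits[0, -1].argmax().item()
      tokens.append(next_token)
      if next_token == tgt_dict.eos():
          break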

Let's replace our ``SimpleLSTMDecoder`` with an incremental one::

  import torch
  from fairseq.models import FairseqIncrementalDecoder

  class SimpleLSTMDecoder(FairseqIncrementalDecoder):

      def __init__(
          self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
          dropout=0.1,
      ):
          # This remains the same as before.
          super().__init__(dictionary)
          self.embed_tokens = nn.Embedding(
              num_embeddings=len(dictionary),
              embedding_dim=embed_dim,
              padding_idx=dictionary.pad(),
          )
          self.dropout = nn.Dropout(p=dropout)
          self.lstm = nn.LSTM(
              input_size=encoder_hidden_dim + embed_dim,
              hidden_size=hidden_dim,
              num_layers=1,
              bidirectional=False,
          )
          self.output_projection = nn.Linear(hidden_dim, len(dictionary))

      # We now take an additional kwarg (*incremental_state*) for caching the
      # previous hidden and cell states.
      def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
          if incremental_state is not None:
              # If the *incremental_state* argument is not ``None`` then we are
              # in incremental inference mode. While *prev_output_tokens* will
              # still contain the entire decoded prefix, we will only use the
              # last step and assume that the rest of the state is cached.
              prev_output_tokens = prev_output_tokens[:, -1:]

          # This remains the same as before.
          bsz, tgt_len = prev_output_tokens.size()
          final_encoder_hidden = encoder_out['final_hidden']
          x = self.embed_tokens(prev_output_tokens)
          x = self.dropout(x)
          x = torch.cat(
              [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
              dim=2,
          )

          # We will now check the cache and load the cached previous hidden and
          # cell states, if they exist, otherwise we will initialize them to
          # zeros (as before). We will use the ``utils.get_incremental_state()``
          # and ``utils.set_incremental_state()`` helpers.
          initial_state = utils.get_incremental_state(
              self, incremental_state, 'prev_state',
          )
          if initial_state is None:
              # first time initialization, same as the original version
              initial_state = (
                  final_encoder_hidden.unsqueeze(0),  # hidden
                  torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
              )

          # Run the LSTM (a single step when decoding incrementally).
          output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

          # Update the cache with the latest hidden and cell states.
          utils.set_incremental_state(
              self, incremental_state, 'prev_state', latest_state,
          )

          # This remains the same as before
          x = output.transpose(0, 1)
          x = self.output_projection(x)
          return x, None

      # The ``FairseqIncrementalDecoder`` interface also requires implementing a
      # ``reorder_incremental_state()`` method, which is used during beam search
      # to select and reorder the incremental state.
      def reorder_incremental_state(self, incremental_state, new_order):
          # Load the cached state.
          prev_state = utils.get_incremental_state(
              self, incremental_state, 'prev_state',
          )

          # Reorder batches according to *new_order*.
          reordered_state = (
              prev_state[0].index_select(1, new_order),  # hidden
              prev_state[1].index_select(1, new_order),  # cell
          )

          # Update the cached state.
          utils.set_incremental_state(
              self, incremental_state, 'prev_state', reordered_state,
          )

Finally, we can rerun generation and observe the speedup:

.. code-block:: console

  # Before

  > fairseq-generate data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 \
    --remove-bpe
  (...)
  | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

  # After

  > fairseq-generate data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 \
    --remove-bpe
  (...)
  | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)