@@ -0,0 +1,77 @@
1 |
# Code of Conduct
2 |
3 |
## Our Pledge
4 |
5 |
In the interest of fostering an open and welcoming environment, we as
6 |
contributors and maintainers pledge to make participation in our project and
7 |
our community a harassment-free experience for everyone, regardless of age, body
8 |
size, disability, ethnicity, sex characteristics, gender identity and expression,
9 |
level of experience, education, socio-economic status, nationality, personal
10 |
appearance, race, religion, or sexual identity and orientation.
11 |
12 |
## Our Standards
13 |
14 |
Examples of behavior that contributes to creating a positive environment include:
15 |
16 |
17 |
* Using welcoming and inclusive language
18 |
* Being respectful of differing viewpoints and experiences
19 |
* Gracefully accepting constructive criticism
20 |
* Focusing on what is best for the community
21 |
* Showing empathy towards other community members
22 |
23 |
Examples of unacceptable behavior by participants include:
24 |
25 |
* The use of sexualized language or imagery and unwelcome sexual attention or advances
26 |
27 |
* Trolling, insulting/derogatory comments, and personal or political attacks
28 |
* Public or private harassment
29 |
* Publishing others' private information, such as a physical or electronic
30 |
address, without explicit permission
31 |
* Other conduct which could reasonably be considered inappropriate in a
32 |
professional setting
33 |
34 |
## Our Responsibilities
35 |
36 |
Project maintainers are responsible for clarifying the standards of acceptable
37 |
behavior and are expected to take appropriate and fair corrective action in
38 |
response to any instances of unacceptable behavior.
39 |
40 |
Project maintainers have the right and responsibility to remove, edit, or
41 |
reject comments, commits, code, wiki edits, issues, and other contributions
42 |
that are not aligned to this Code of Conduct, or to ban temporarily or
43 |
permanently any contributor for other behaviors that they deem inappropriate,
44 |
threatening, offensive, or harmful.
45 |
46 |
## Scope
47 |
48 |
This Code of Conduct applies within all project spaces, and it also applies when
49 |
an individual is representing the project or its community in public spaces.
50 |
Examples of representing a project or community include using an official
51 |
project e-mail address, posting via an official social media account, or acting
52 |
as an appointed representative at an online or offline event. Representation of
53 |
a project may be further defined and clarified by project maintainers.
54 |
55 |
## Enforcement
56 |
57 |
Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 |
reported by contacting the project team at <[email protected]>. All
59 |
complaints will be reviewed and investigated and will result in a response that
60 |
is deemed necessary and appropriate to the circumstances. The project team is
61 |
obligated to maintain confidentiality with regard to the reporter of an incident.
62 |
Further details of specific enforcement policies may be posted separately.
63 |
64 |
Project maintainers who do not follow or enforce the Code of Conduct in good
65 |
faith may face temporary or permanent repercussions as determined by other
66 |
members of the project's leadership.
67 |
68 |
## Attribution
69 |
70 |
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 |
available at
72 |
73 |
74 |
75 |
For answers to common questions about this code of conduct, see
76 |
77 |
@@ -0,0 +1,28 @@
1 |
# Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
2 |
We want to make contributing to this project as easy and transparent as possible.
3 |
4 |
5 |
## Pull Requests
6 |
We actively welcome your pull requests.
7 |
8 |
1. Fork the repo and create your branch from `master`.
9 |
2. If you've added code that should be tested, add tests.
10 |
3. If you've changed APIs, update the documentation.
11 |
4. Ensure the test suite passes.
12 |
5. Make sure your code lints.
13 |
6. If you haven't already, complete the Contributor License Agreement ("CLA").
14 |
15 |
## Contributor License Agreement ("CLA")
16 |
In order to accept your pull request, we need you to submit a CLA. You only need
17 |
to do this once to work on any of Facebook's open source projects.
18 |
19 |
Complete your CLA here: <>
20 |
21 |
## Issues
22 |
We use GitHub issues to track public bugs. Please ensure your description is
23 |
clear and has sufficient instructions to be able to reproduce the issue.
24 |
25 |
## License
26 |
By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
27 |
you agree that your contributions will be licensed under the LICENSE file in
28 |
the root directory of this source tree.
@@ -0,0 +1,21 @@
1 |
MIT License
2 |
3 |
Copyright (c) Facebook, Inc. and its affiliates.
4 |
5 |
Permission is hereby granted, free of charge, to any person obtaining a copy
6 |
of this software and associated documentation files (the "Software"), to deal
7 |
in the Software without restriction, including without limitation the rights
8 |
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 |
copies of the Software, and to permit persons to whom the Software is
10 |
furnished to do so, subject to the following conditions:
11 |
12 |
The above copyright notice and this permission notice shall be included in all
13 |
copies or substantial portions of the Software.
14 |
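THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.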
@@ -0,0 +1,14 @@
1 |
import gradio as gr
2 |
3 |
4 |
5 |
description = "HuBERT: Self-Supervised Speech Representation Learning. To use it, simply add your audio or click one of the examples to load them. Read more at the links below."
6 |
article = "<p style='text-align: center'><a href=''>HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units</a> | <a href=''>Github Repo</a></p>"
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
@@ -0,0 +1,20 @@
1 |
# Minimal makefile for Sphinx documentation
2 |
3 |
4 |
# You can set these variables from the command line.
5 |
6 |
SPHINXBUILD = python -msphinx
7 |
SPHINXPROJ = fairseq
8 |
9 |
BUILDDIR = _build
10 |
11 |
# Put it first so that "make" without argument is like "make help".
12 |
13 |
14 |
15 |
.PHONY: help Makefile
16 |
17 |
# Catch-all target: route all unknown targets to Sphinx using the new
18 |
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 |
%: Makefile
20 |
@@ -0,0 +1,9 @@
.wy-table-responsive table td kbd {
    white-space: nowrap;
}
.wy-table-responsive table td {
    white-space: normal !important;
}
.wy-table-responsive {
    overflow: visible !important;
}
@@ -0,0 +1,85 @@
1 |
.. _Command-line Tools:
2 |
3 |
Command-line Tools
4 |
5 |
6 |
Fairseq provides several command-line tools for training and evaluating models:
7 |
8 |
- :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data
9 |
- :ref:`fairseq-train`: Train a new model on one or multiple GPUs
10 |
- :ref:`fairseq-generate`: Translate pre-processed data with a trained model
11 |
- :ref:`fairseq-interactive`: Translate raw text with a trained model
12 |
- :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations
13 |
- :ref:`fairseq-eval-lm`: Language model evaluation
14 |
15 |
16 |
.. _fairseq-preprocess:
17 |
18 |
19 |
20 |
.. automodule:: fairseq_cli.preprocess
21 |
22 |
.. argparse::
23 |
:module: fairseq.options
24 |
:func: get_preprocessing_parser
25 |
:prog: fairseq-preprocess
26 |
27 |
28 |
.. _fairseq-train:
29 |
30 |
31 |
32 |
.. automodule:: fairseq_cli.train
33 |
34 |
.. argparse::
35 |
:module: fairseq.options
36 |
:func: get_training_parser
37 |
:prog: fairseq-train
38 |
39 |
40 |
.. _fairseq-generate:
41 |
42 |
43 |
44 |
.. automodule:: fairseq_cli.generate
45 |
46 |
.. argparse::
47 |
:module: fairseq.options
48 |
:func: get_generation_parser
49 |
:prog: fairseq-generate
50 |
51 |
52 |
.. _fairseq-interactive:
53 |
54 |
55 |
56 |
.. automodule:: fairseq_cli.interactive
57 |
58 |
.. argparse::
59 |
:module: fairseq.options
60 |
:func: get_interactive_generation_parser
61 |
:prog: fairseq-interactive
62 |
63 |
64 |
.. _fairseq-score:
65 |
66 |
67 |
68 |
.. automodule:: fairseq_cli.score
69 |
70 |
.. argparse::
71 |
:module: fairseq_cli.score
72 |
:func: get_parser
73 |
:prog: fairseq-score
74 |
75 |
76 |
.. _fairseq-eval-lm:
77 |
78 |
79 |
80 |
.. automodule:: fairseq_cli.eval_lm
81 |
82 |
.. argparse::
83 |
:module: fairseq.options
84 |
:func: get_eval_lm_parser
85 |
:prog: fairseq-eval-lm
@@ -0,0 +1,134 @@
1 |
#!/usr/bin/env python3
2 |
# -*- coding: utf-8 -*-
3 |
4 |
# fairseq documentation build configuration file, created by
5 |
# sphinx-quickstart on Fri Aug 17 21:45:30 2018.
6 |
7 |
# This file is execfile()d with the current directory set to its
8 |
# containing dir.
9 |
10 |
# Note that not all possible configuration values are present in this
11 |
# autogenerated file.
12 |
13 |
# All configuration values have a default; values that are commented out
14 |
# serve to show the default.
15 |
16 |
# If extensions (or modules to document with autodoc) are in another directory,
17 |
# add these directories to sys.path here. If the directory is relative to the
18 |
# documentation root, use os.path.abspath to make it absolute, like shown here.
19 |
20 |
import os
21 |
import sys
22 |
from fairseq import __version__
23 |
24 |
25 |
# source code directory, relative to this file, for sphinx-autobuild
26 |
sys.path.insert(0, os.path.abspath(".."))
27 |
28 |
source_suffix = [".rst"]
29 |
30 |
# -- General configuration ------------------------------------------------
31 |
32 |
# If your documentation needs a minimal Sphinx version, state it here.
33 |
34 |
# needs_sphinx = '1.0'
35 |
36 |
# Add any Sphinx extension module names here, as strings. They can be
37 |
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
38 |
# ones.
39 |
extensions = [
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
# Add any paths that contain templates here, relative to this directory.
48 |
templates_path = ["_templates"]
49 |
50 |
# The master toctree document.
51 |
master_doc = "index"
52 |
53 |
# General information about the project.
54 |
project = "fairseq"
55 |
copyright = "Facebook AI Research (FAIR)"
56 |
author = "Facebook AI Research (FAIR)"
57 |
58 |
github_doc_root = ""
59 |
60 |
# The version info for the project you're documenting, acts as replacement for
61 |
# |version| and |release|, also used in various other places throughout the
62 |
# built documents.
63 |
64 |
# The short X.Y version.
65 |
version = __version__
66 |
# The full version, including alpha/beta/rc tags.
67 |
release = __version__
68 |
69 |
# The language for content autogenerated by Sphinx. Refer to documentation
70 |
# for a list of supported languages.
71 |
72 |
# This is also used if you do content translation via gettext catalogs.
73 |
# Usually you set "language" from the command line for these cases.
74 |
language = None
75 |
76 |
# List of patterns, relative to source directory, that match files and
77 |
# directories to ignore when looking for source files.
78 |
# These patterns also affect html_static_path and html_extra_path
79 |
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
80 |
81 |
# The name of the Pygments (syntax highlighting) style to use.
82 |
pygments_style = "sphinx"
83 |
highlight_language = "python"
84 |
85 |
# If true, `todo` and `todoList` produce output, else they produce nothing.
86 |
todo_include_todos = False
87 |
88 |
89 |
# -- Options for HTML output ----------------------------------------------
90 |
91 |
# The theme to use for HTML and HTML Help pages. See the documentation for
92 |
# a list of builtin themes.
93 |
94 |
html_theme = "sphinx_rtd_theme"
95 |
96 |
# Theme options are theme-specific and customize the look and feel of a theme
97 |
# further. For a list of options available for each theme, see the
98 |
# documentation.
99 |
100 |
# html_theme_options = {}
101 |
102 |
# Add any paths that contain custom static files (such as style sheets) here,
103 |
# relative to this directory. They are copied after the builtin static files,
104 |
# so a file named "default.css" will overwrite the builtin "default.css".
105 |
html_static_path = ["_static"]
106 |
107 |
html_context = {
108 |
"css_files": [
109 |
"_static/theme_overrides.css", # override wide tables in RTD theme
110 |
111 |
112 |
113 |
# Custom sidebar templates, must be a dictionary that maps document names
114 |
# to template names.
115 |
116 |
# This is required for the alabaster theme
117 |
# refs:
118 |
# html_sidebars = {
119 |
# '**': [
120 |
# 'about.html',
121 |
# 'navigation.html',
122 |
# 'relations.html', # needs 'show_related': True theme option to display
123 |
# 'searchbox.html',
124 |
# 'donate.html',
125 |
# ]
126 |
# }
127 |
128 |
129 |
# Example configuration for intersphinx: refer to the Python standard library.
130 |
intersphinx_mapping = {
131 |
"numpy": ("", None),
132 |
"python": ("", None),
133 |
"torch": ("", None),
134 |
@@ -0,0 +1,31 @@
1 |
.. role:: hidden
2 |
:class: hidden-section
3 |
4 |
.. _Criterions:
5 |
6 |
7 |
8 |
9 |
Criterions compute the loss function given the model and batch, roughly::
10 |
11 |
loss = criterion(model, batch)
12 |
13 |
.. automodule:: fairseq.criterions
14 |
15 |
16 |
.. autoclass:: fairseq.criterions.FairseqCriterion
17 |
18 |
19 |
20 |
.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
21 |
22 |
23 |
.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
24 |
25 |
26 |
.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
27 |
28 |
29 |
.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
30 |
31 |
@@ -0,0 +1,58 @@
1 |
.. role:: hidden
2 |
:class: hidden-section
3 |
4 |
.. module::
5 |
6 |
Data Loading and Utilities
7 |
8 |
9 |
.. _datasets:
10 |
11 |
12 |
13 |
14 |
**Datasets** define the data format and provide helpers for creating
15 |
16 |
17 |
.. autoclass::
18 |
19 |
.. autoclass::
20 |
21 |
.. autoclass::
22 |
23 |
24 |
**Helper Datasets**
25 |
26 |
These datasets wrap other :class:`` instances and
27 |
provide additional functionality:
28 |
29 |
.. autoclass::
30 |
31 |
.. autoclass::
32 |
33 |
.. autoclass::
34 |
35 |
.. autoclass::
36 |
37 |
.. autoclass::
38 |
39 |
40 |
41 |
42 |
43 |
44 |
.. autoclass::
45 |
46 |
47 |
48 |
49 |
50 |
51 |
.. autoclass::
52 |
53 |
.. autoclass::
54 |
55 |
.. autoclass::
56 |
57 |
.. autoclass::
58 |
@@ -0,0 +1,2 @@
1 |
2 |
@@ -0,0 +1,216 @@
1 |
Evaluating Pre-trained Models
2 |
3 |
4 |
First, download a pre-trained model along with its vocabularies:
5 |
6 |
.. code-block:: console
7 |
8 |
> curl | tar xvjf -
9 |
10 |
This model uses a `Byte Pair Encoding (BPE)
11 |
vocabulary <>`__, so we'll have to apply
12 |
the encoding to the source text before it can be translated. This can be
13 |
done with the
14 |
`apply\ <>`__
15 |
script using the ``wmt14.en-fr.fconv-py/bpecodes`` file. ``@@`` is
16 |
used as a continuation marker and the original text can be easily
17 |
recovered with e.g. ``sed s/@@ //g`` or by passing the ``--remove-bpe``
18 |
flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized
19 |
using ``tokenizer.perl`` from
20 |
`mosesdecoder <>`__.
21 |
22 |
Let's use :ref:`fairseq-interactive` to generate translations interactively.
23 |
Here, we use a beam size of 5 and preprocess the input with the Moses
24 |
tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically
25 |
remove the BPE continuation markers and detokenize the output.
26 |
27 |
.. code-block:: console
28 |
29 |
> MODEL_DIR=wmt14.en-fr.fconv-py
30 |
> fairseq-interactive \
31 |
32 |
--beam 5 --source-lang en --target-lang fr \
33 |
--tokenizer moses \
34 |
--bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
35 |
| loading model(s) from wmt14.en-fr.fconv-py/
36 |
| [en] dictionary: 44206 types
37 |
| [fr] dictionary: 44463 types
38 |
| Type the input sentence and press return:
39 |
Why is it rare to discover new marine mammal species?
40 |
S-0 Why is it rare to discover new marine mam@@ mal species ?
41 |
H-0 -0.0643349438905716 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
42 |
P-0 -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015
43 |
44 |
This generation script produces three types of outputs: a line prefixed
with *S* is a copy of the source sentence after tokenization and BPE; *H* is the
hypothesis along with an average log-likelihood; and *P* is the
positional score per token position, including the
end-of-sentence marker which is omitted from the text.
49 |
50 |
Other types of output lines you might see are *D*, the detokenized hypothesis;
*T*, the reference target; *A*, alignment info; and *E*, the history of generation steps.
52 |
53 |
See the `README <>`__ for a
54 |
full list of pre-trained models available.
55 |
56 |
Training a New Model
57 |
58 |
59 |
The following tutorial is for machine translation. For an example of how
60 |
to use Fairseq for other tasks, such as :ref:`language modeling`, please see the
61 |
``examples/`` directory.
62 |
63 |
Data Pre-processing
64 |
65 |
66 |
Fairseq contains example pre-processing scripts for several translation
67 |
datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT
68 |
2014 (English-German). To pre-process and binarize the IWSLT dataset:
69 |
70 |
.. code-block:: console
71 |
72 |
> cd examples/translation/
73 |
> bash
74 |
> cd ../..
75 |
> TEXT=examples/translation/
76 |
> fairseq-preprocess --source-lang de --target-lang en \
77 |
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
78 |
--destdir data-bin/
79 |
80 |
This will write binarized data that can be used for model training to
81 |
82 |
83 |
84 |
85 |
86 |
Use :ref:`fairseq-train` to train a new model. Here are a few example settings that work
87 |
well for the IWSLT 2014 dataset:
88 |
89 |
.. code-block:: console
90 |
91 |
> mkdir -p checkpoints/fconv
92 |
> CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/ \
93 |
--optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
94 |
--arch fconv_iwslt_de_en --save-dir checkpoints/fconv
95 |
96 |
By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the
97 |
``CUDA_VISIBLE_DEVICES`` environment variable to select specific GPUs and/or to
98 |
change the number of GPU devices that will be used.
99 |
100 |
Also note that the batch size is specified in terms of the maximum
101 |
number of tokens per batch (``--max-tokens``). You may need to use a
102 |
smaller value depending on the available GPU memory on your system.
103 |
104 |
105 |
106 |
107 |
Once your model is trained, you can generate translations using
108 |
:ref:`fairseq-generate` **(for binarized data)** or
109 |
:ref:`fairseq-interactive` **(for raw text)**:
110 |
111 |
.. code-block:: console
112 |
113 |
> fairseq-generate data-bin/ \
114 |
--path checkpoints/fconv/ \
115 |
--batch-size 128 --beam 5
116 |
| [de] dictionary: 35475 types
117 |
| [en] dictionary: 24739 types
118 |
| data-bin/ test 6750 examples
119 |
| model fconv
120 |
| loaded checkpoint trainings/fconv/
121 |
S-721 danke .
122 |
T-721 thank you .
123 |
124 |
125 |
To generate translations with only a CPU, use the ``--cpu`` flag. BPE
126 |
continuation markers can be removed with the ``--remove-bpe`` flag.
127 |
128 |
Advanced Training Options
129 |
130 |
131 |
Large mini-batch training with delayed updates
132 |
133 |
134 |
The ``--update-freq`` option can be used to accumulate gradients from
135 |
multiple mini-batches and delay updating, creating a larger effective
136 |
batch size. Delayed updates can also improve training speed by reducing
137 |
inter-GPU communication costs and by saving idle time caused by variance
138 |
in workload across GPUs. See `Ott et al.
139 |
(2018) <>`__ for more details.
140 |
141 |
To train on a single GPU with an effective batch size that is equivalent
142 |
to training on 8 GPUs:
143 |
144 |
.. code-block:: console
145 |
146 |
> CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)
147 |
148 |
Training with half precision floating point (FP16)
149 |
150 |
151 |
.. note::
152 |
153 |
FP16 training requires a Volta GPU and CUDA 9.1 or greater
154 |
155 |
Recent GPUs enable efficient half precision floating point computation,
156 |
e.g., using `Nvidia Tensor Cores
157 |
158 |
Fairseq supports FP16 training with the ``--fp16`` flag:
159 |
160 |
.. code-block:: console
161 |
162 |
> fairseq-train --fp16 (...)
163 |
164 |
Distributed training
165 |
166 |
167 |
Distributed training in fairseq is implemented on top of ``torch.distributed``.
168 |
The easiest way to launch jobs is with the `torch.distributed.launch
169 |
<>`__ tool.
170 |
171 |
For example, to train a large English-German Transformer model on 2 nodes each
172 |
with 8 GPUs (in total 16 GPUs), run the following command on each node,
173 |
replacing ``node_rank=0`` with ``node_rank=1`` on the second node and making
174 |
sure to update ``--master_addr`` to the IP address of the first node:
175 |
176 |
.. code-block:: console
177 |
178 |
> python -m torch.distributed.launch --nproc_per_node=8 \
179 |
--nnodes=2 --node_rank=0 --master_addr="" \
180 |
--master_port=12345 \
181 |
$(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
182 |
--arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
183 |
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
184 |
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
185 |
--lr 0.0005 \
186 |
--dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
187 |
--max-tokens 3584 \
188 |
--max-epoch 70 \
189 |
190 |
191 |
On SLURM clusters, fairseq will automatically detect the number of nodes and
192 |
GPUs, but a port number must be provided:
193 |
194 |
.. code-block:: console
195 |
196 |
> salloc --gpus=16 --nodes 2 (...)
197 |
> srun fairseq-train --distributed-port 12345 (...)
198 |
199 |
Sharding very large datasets
200 |
201 |
202 |
It can be challenging to train over very large datasets, particularly if your
203 |
machine does not have much system RAM. Most tasks in fairseq support training
204 |
over "sharded" datasets, in which the original dataset has been preprocessed
205 |
into non-overlapping chunks (or "shards").
206 |
207 |
For example, instead of preprocessing all your data into a single "data-bin"
208 |
directory, you can split the data and create "data-bin1", "data-bin2", etc.
209 |
Then you can adapt your training command like so:
210 |
211 |
.. code-block:: console
212 |
213 |
> fairseq-train data-bin1:data-bin2:data-bin3 (...)
214 |
215 |
Training will now iterate over each shard, one by one, with each shard
216 |
corresponding to an "epoch", thus reducing system memory usage.
@@ -0,0 +1,284 @@
1 |
## Hydra
2 |
3 |
Hydra is an open-source Python
4 |
framework that simplifies the development of research and other complex
5 |
applications. The key feature is the ability to dynamically create a
6 |
hierarchical configuration by composition and override it through config files
7 |
and the command line. The name Hydra comes from its ability to run multiple
8 |
similar jobs - much like a Hydra with multiple heads.
9 |
10 |
## Motivation
11 |
12 |
Until recently, all components in fairseq were configured through a shared
13 |
`args` namespace that was created at application startup. Components declared
14 |
their own `add_args` method to update the argparse parser, hoping that the names
15 |
would not clash with arguments from other components. While this model works for
16 |
smaller applications, as fairseq grew and became integrated into other
17 |
applications, this became problematic. In order to determine how to configure
18 |
each component, one needed to a) examine what args were added by this component,
19 |
and b) read the code to figure out what shared arguments it is using that were
20 |
added in other places. Reproducing models involved sharing commands that often
21 |
contained dozens of command line switches.
22 |
23 |
The model described above is still supported by fairseq for backward
24 |
compatibility, but will be deprecated some time in the future.
25 |
26 |
New components in fairseq should now create a dataclass that encapsulates all
27 |
parameters required to configure this component. The dataclass is registered
28 |
along with the component, and fairseq takes care of constructing and providing
29 |
this configuration object to the component's constructor. Note that sharing
30 |
parameters can optionally still work, but one has to explicitly point to the
31 |
"source of truth" (see inheritance example below). These changes make components
32 |
in fairseq more independent and re-usable by other applications: all that is
33 |
needed to create a component is to initialize its dataclass and overwrite some
34 |
of the defaults.
35 |
36 |
While configuring fairseq through the command line (using either the legacy argparse
37 |
based or the new Hydra based entry points) is still fully supported, you can now
38 |
take advantage of configuring fairseq completely or piece-by-piece through
39 |
hierarchical YAML configuration files. These files can also be shipped as
40 |
examples that others can use to run an identically configured job.
41 |
42 |
Additionally, Hydra has a rich and growing library of plugins that
provide functionality such as hyperparameter sweeping (including using Bayesian
optimization through the Ax library), job
launching across various platforms, and more.
47 |
48 |
## Creating or migrating components
49 |
50 |
In general, each new (or updated) component should provide a companion
51 |
dataclass. These dataclasses are
52 |
typically located in the same file as the component and are passed as arguments
53 |
to the `register_*()` functions. Top-level configs that should be present in
54 |
every fairseq application are placed in the
55 |
[global](fairseq/dataclass/ config file and added to the
56 |
`FairseqConfig` object.
57 |
58 |
Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These
59 |
classes are decorated with a `@dataclass` decorator, and typically inherit from
60 |
`FairseqDataclass` (which adds some functionality for backward compatibility).
61 |
Each field must have a type, and generally has metadata (such as a help string)
62 |
and a default value. Only primitive types or other config objects are allowed as
63 |
data types for each field.
64 |
65 |
#### Example:
66 |
```python
from dataclasses import dataclass, field
from fairseq.dataclass import FairseqDataclass


@dataclass
class InteractiveConfig(FairseqDataclass):
    # Number of sentences to buffer before processing.
    buffer_size: int = field(
        default=0,
        metadata={
            "help": "read this many sentences into a buffer before processing them"
        },
    )
    # Input source; "-" means standard input.
    input: str = field(
        default="-",
        metadata={"help": "file to read from; use - for stdin"},
    )
```
84 |
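As a rough illustration of why each field carries a type, a default, and help metadata, the sketch below uses only the standard library to turn a dataclass into command-line options. `DemoConfig` and `parser_from_dataclass` are hypothetical names invented for this sketch, loosely modeled on the example above. This is not how fairseq itself builds its parsers.

```python
import argparse
from dataclasses import MISSING, dataclass, field, fields


@dataclass
class DemoConfig:
    # Hypothetical config for this sketch; it does not inherit from
    # FairseqDataclass so that it runs without fairseq installed.
    buffer_size: int = field(
        default=0,
        metadata={"help": "read this many sentences into a buffer"},
    )
    input: str = field(
        default="-",
        metadata={"help": "file to read from; use - for stdin"},
    )


def parser_from_dataclass(cfg_cls) -> argparse.ArgumentParser:
    """Create one --option per dataclass field, reusing type, default and help."""
    parser = argparse.ArgumentParser()
    for f in fields(cfg_cls):
        default = None if f.default is MISSING else f.default
        parser.add_argument(
            "--" + f.name.replace("_", "-"),
            type=f.type,  # works here because the annotations are plain classes
            default=default,
            help=f.metadata.get("help"),
        )
    return parser


parser = parser_from_dataclass(DemoConfig)
args = parser.parse_args(["--buffer-size", "16"])
cfg = DemoConfig(**vars(args))
print(cfg)
```

Because the help text and default live next to the field, the component's configuration is described in one place instead of being spread across several `add_args` calls.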
85 |
### Inheriting values
86 |
87 |
Some components require sharing a value. For example, a learning rate scheduler
88 |
and an optimizer may both need to know the initial learning rate value. One can
89 |
declare a field that, by default, will inherit its value from another config
90 |
node in the same hierarchy:
91 |
92 |
93 |
94 |
95 |
96 |
lr: List[float] = II("")
97 |
98 |
99 |
100 |
`II("")` is syntactic sugar for `"${}"`, which is
101 |
the value one can use in a YAML config file or through command line to achieve
102 |
the same effect. Note that this assumes that there is an "optimization" config
103 |
object in the root config and it has a field called "lr".
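To make the interpolation idea concrete, here is a toy resolver written with the standard library only. It is a sketch under assumptions: the real resolution is performed by Hydra and OmegaConf, `lookup` and `resolve` are hypothetical helpers, and the `optimization.lr` key is taken from the note above.

```python
import re

# Matches a whole-string interpolation such as "${some.dotted.path}".
_INTERPOLATION = re.compile(r"^\$\{([^}]+)\}$")


def lookup(root: dict, dotted_path: str):
    """Walk a nested dict following a dotted path such as 'optimization.lr'."""
    node = root
    for key in dotted_path.split("."):
        node = node[key]
    return node


def resolve(root: dict, value):
    """Return the referenced value if `value` is an interpolation string."""
    if isinstance(value, str):
        match = _INTERPOLATION.match(value)
        if match:
            return resolve(root, lookup(root, match.group(1)))
    return value


config = {
    "optimization": {"lr": [0.25]},
    # The scheduler stores a pointer to the source of truth, not its own value.
    "lr_scheduler": {"lr": "${optimization.lr}"},
}

scheduler_lr = resolve(config, config["lr_scheduler"]["lr"])
print(scheduler_lr)
```

Changing `optimization.lr` in this toy config changes what the scheduler sees, which is the behavior the inherited field is meant to provide.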
104 |
105 |
### Tasks and Models
106 |
107 |
Creating Tasks and Models works the same as before, except that legacy
108 |
implementations now inherit from `LegacyFairseq*` base classes, while new
109 |
components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass
110 |
to the `register_*()` functions.
111 |
112 |
#### Task example:
113 |
114 |
115 |
116 |
class LanguageModelingConfig(FairseqDataclass):
117 |
data: Optional[str] = field(
118 |
default=None, metadata={"help": "path to data directory"}
119 |
120 |
121 |
122 |
@register_task("language_modeling", dataclass=LanguageModelingConfig)
123 |
class LanguageModelingTask(FairseqTask):
124 |
125 |
126 |
def setup_task(cls, cfg: LanguageModelingConfig):
127 |
128 |
129 |
130 |
#### Model example:
131 |
132 |
133 |
134 |
class TransformerLanguageModelConfig(FairseqDataclass):
135 |
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
136 |
default="relu", metadata={"help": "activation function to use"}
137 |
138 |
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
139 |
140 |
141 |
@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
142 |
class TransformerLanguageModel(FairseqLanguageModel):
143 |
144 |
145 |
def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask):
146 |
147 |
148 |
149 |
### Other components
150 |
151 |
Other components work as before, but they now take their configuration dataclass
152 |
as the only constructor argument:
153 |
154 |
155 |
156 |
class MosesTokenizerConfig(FairseqDataclass):
157 |
source_lang: str = field(default="en", metadata={"help": "source language"})
158 |
159 |
160 |
@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
161 |
class MosesTokenizer(object):
162 |
def __init__(self, cfg: MosesTokenizerConfig):
163 |
164 |
165 |
166 |
Note that if you are adding a new registry for a new set of components, you need
167 |
to add it to the `FairseqConfig` object in `fairseq/dataclass/`:
168 |
169 |
170 |
171 |
class FairseqConfig(object):
172 |
173 |
my_new_registry: Any = None
174 |
175 |
176 |
## Training with `fairseq-hydra-train`
177 |
178 |
To take full advantage of the configuration flexibility offered by Hydra, you may
179 |
want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI
180 |
tools such as `fairseq-train` will remain supported for the foreseeable future
181 |
but will be deprecated eventually.
182 |
183 |
On startup, Hydra will create a configuration object that contains a hierarchy
184 |
of all the necessary dataclasses populated with their default values in the
185 |
code. The default values are overwritten by values found in YAML files in
186 |
the `fairseq/config` directory (which currently sets minimal defaults) and then
187 |
further overwritten by values provided through command line arguments.
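The precedence described above (dataclass defaults, then YAML files, then command-line arguments) can be sketched with plain dictionaries. This only illustrates the ordering and is not Hydra's implementation; `deep_merge`, `apply_override`, and all sample values are hypothetical.

```python
import copy


def deep_merge(base: dict, overlay: dict) -> dict:
    """Return a copy of `base` with values from `overlay` layered on top."""
    merged = copy.deepcopy(base)
    for key, value in overlay.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def apply_override(config: dict, override: str) -> None:
    """Apply one 'a.b=value' override in place.

    Values stay strings unless they look like integers, which is a
    simplification made for this sketch.
    """
    dotted_key, raw_value = override.split("=", 1)
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = int(raw_value) if raw_value.lstrip("-").isdigit() else raw_value


# 1. Defaults as they would come from the dataclasses (made-up values).
code_defaults = {"dataset": {"batch_size": 8, "num_workers": 1}}
# 2. Values as they would come from a YAML file (made-up values).
yaml_values = {"dataset": {"batch_size": 4}}
# 3. Command-line overrides are applied last, so they win.
cli_overrides = ["dataset.batch_size=2"]

config = deep_merge(code_defaults, yaml_values)
for item in cli_overrides:
    apply_override(config, item)

print(config)
```

Keys that no later layer mentions, such as `num_workers` here, keep the value from the earlier layer.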
188 |
189 |
Some of the most common use cases are shown below:
190 |
191 |
### 1. Override default values through command line:
192 |
193 |
```shell script
194 |
$ fairseq-hydra-train \
195 |
distributed_training.distributed_world_size=1 \
196 |
dataset.batch_size=2 \
197 |
+ \
198 |
model=transformer_lm/transformer_lm_gpt \
199 |
task=language_modeling \
200 |
201 |
202 |
203 |
Note that along with explicitly providing values for parameters such as
204 |
`dataset.batch_size`, this also tells Hydra to overlay configuration found in
205 |
`fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default
206 |
values in the dataclass. If you want to train a model without specifying a
207 |
particular architecture you can simply specify `model=transformer_lm`. This only
208 |
works for migrated tasks and models.
209 |
210 |
### 2. Replace bundled configs with an external config:
211 |
212 |
```shell script
213 |
$ fairseq-hydra-train \
214 |
--config-dir /path/to/external/configs \
215 |
--config-name wiki103
```
217 |
218 |
where `/path/to/external/configs/wiki103.yaml` contains:

```yaml
# @package _group_

model:
  _name: transformer_lm
distributed_training:
  distributed_world_size: 1
dataset:
  batch_size: 2
task:
  _name: language_modeling
  data: /path/to/data
  add_bos_token: false
  max_target_positions: 1024
optimization:
  max_update: 50000
  lr: [ 0.25 ]
criterion: cross_entropy
optimizer: adam
lr_scheduler:
  _name: cosine
```

Note that the bundled configs from the `fairseq/config` directory are not used
here; however, the defaults from each dataclass will still be used (unless
overwritten by your external config).

Additionally, you can choose to break up your configs by creating a directory
structure in the same location as your main config file, with the names of the
top-level fields (such as "model", "dataset", etc.), and placing config files
with meaningful names that would populate that specific section of your
top-level config file (for example, you might have
`model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc.). You
can then specify the correct configuration via command line, defaults in the
main config, or even launch all of them as a sweep (see Hydra documentation on
how to do this).

### 3. Add an external config directory to Hydra search path:

This allows combining default configuration (including using any bundled config
files), while specifying your own config files for some parts of the
configuration.

```shell script
$ fairseq-hydra-train \
    distributed_training.distributed_world_size=1 \
    dataset.batch_size=2 \
    + \
    model=transformer_lm/2_layers \
    task=language_modeling \
    optimization.max_update=5000 \
    --config-dir /path/to/external/configs
```

where `/path/to/external/configs` has the following structure:

```
+-- model
|   +-- transformer_lm
|   |   +-- 2_layers.yaml
```

and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with
`decoder_layers` set to 2. You can add other configs to configure other
components as well.
@@ -0,0 +1,49 @@
1 |
.. fairseq documentation master file, created by
2 |
sphinx-quickstart on Fri Aug 17 21:45:30 2018.
3 |
You can adapt this file completely to your liking, but it should at least
4 |
contain the root `toctree` directive.
5 |
6 |
7 |
8 |
9 |
fairseq documentation
10 |
11 |
12 |
Fairseq is a sequence modeling toolkit written in `PyTorch
13 |
<>`_ that allows researchers and developers to
14 |
train custom models for translation, summarization, language modeling and other
15 |
text generation tasks.
16 |
17 |
.. toctree::
18 |
:maxdepth: 1
19 |
:caption: Getting Started
20 |
21 |
22 |
23 |
24 |
.. toctree::
25 |
:maxdepth: 1
26 |
:caption: Extending Fairseq
27 |
28 |
29 |
30 |
31 |
32 |
.. toctree::
33 |
:maxdepth: 2
34 |
:caption: Library Reference
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
Indices and tables
46 |
47 |
48 |
* :ref:`genindex`
49 |
* :ref:`search`
@@ -0,0 +1,34 @@
1 |
.. role:: hidden
2 |
:class: hidden-section
3 |
4 |
.. _Learning Rate Schedulers:

Learning Rate Schedulers
========================

Learning Rate Schedulers update the learning rate over the course of training.
Learning rates can be updated after each update via :func:`step_update` or at
epoch boundaries via :func:`step`.
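
To make the two hooks concrete, here is a toy scheduler with a linear warmup followed by inverse square root decay. It is a sketch for illustration only and is not fairseq's implementation. The class name ``ToyWarmupInverseSqrt`` and its parameters are hypothetical:

```python
class ToyWarmupInverseSqrt:
    """Toy scheduler exposing the two hooks described above."""

    def __init__(self, base_lr, warmup_updates):
        self.base_lr = base_lr
        self.warmup_updates = warmup_updates
        self.lr = 0.0

    def step_update(self, num_updates):
        """Called after each optimizer update; returns the new learning rate."""
        if num_updates < self.warmup_updates:
            # Linear warmup from 0 up to base_lr.
            self.lr = self.base_lr * num_updates / self.warmup_updates
        else:
            # Decay proportionally to 1 / sqrt(num_updates).
            self.lr = self.base_lr * (self.warmup_updates / num_updates) ** 0.5
        return self.lr

    def step(self, epoch):
        """Called at epoch boundaries; this toy schedule only changes per update."""
        return self.lr


scheduler = ToyWarmupInverseSqrt(base_lr=1.0, warmup_updates=4)
rates = [scheduler.step_update(n) for n in (2, 4, 16)]
print(rates)  # [0.5, 1.0, 0.5]
```

A schedule that reacts to validation loss would instead do its work in ``step``, since that hook runs once per epoch.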
12 |
13 |
.. automodule:: fairseq.optim.lr_scheduler
14 |
15 |
16 |
.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
17 |
18 |
19 |
20 |
.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
21 |
22 |
23 |
.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
24 |
25 |
26 |
.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
27 |
28 |
29 |
.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
30 |
31 |
32 |
.. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
33 |
34 |
@@ -0,0 +1,36 @@
1 |
2 |
3 |
pushd %~dp0
4 |
5 |
REM Command file for Sphinx documentation
6 |
7 |
if "%SPHINXBUILD%" == "" (
8 |
set SPHINXBUILD=python -msphinx
9 |
10 |
11 |
set BUILDDIR=_build
12 |
set SPHINXPROJ=fairseq
13 |
14 |
if "%1" == "" goto help
15 |
16 |
17 |
if errorlevel 9009 (
18 |
19 |
echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20 |
echo.then set the SPHINXBUILD environment variable to point to the full
21 |
echo.path of the 'sphinx-build' executable. Alternatively you may add the
22 |
echo.Sphinx directory to PATH.
23 |
24 |
echo.If you don't have Sphinx installed, grab it from
25 |
26 |
exit /b 1
27 |
28 |
29 |
30 |
goto end
31 |
32 |
33 |
34 |
35 |
36 |
@@ -0,0 +1,104 @@
1 |
.. role:: hidden
2 |
:class: hidden-section
3 |
4 |
.. module:: fairseq.models
5 |
6 |
.. _Models:
7 |
8 |
9 |
10 |
11 |
A Model defines the neural network's ``forward()`` method and encapsulates all
12 |
of the learnable parameters in the network. Each model also provides a set of
13 |
named *architectures* that define the precise network configuration (e.g.,
14 |
embedding dimension, number of layers, etc.).
15 |
16 |
Both the model type and architecture are selected via the ``--arch``
17 |
command-line argument. Once selected, a model may expose additional command-line
18 |
arguments for further configuration.
19 |
20 |
.. note::
21 |
22 |
All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends
23 |
:class:`torch.nn.Module`. Thus any fairseq Model can be used as a
24 |
stand-alone Module in other PyTorch code.
25 |
26 |
27 |
Convolutional Neural Networks (CNN)
28 |
29 |
30 |
.. module:: fairseq.models.fconv
31 |
.. autoclass:: fairseq.models.fconv.FConvModel
32 |
33 |
.. autoclass:: fairseq.models.fconv.FConvEncoder
34 |
35 |
36 |
.. autoclass:: fairseq.models.fconv.FConvDecoder
37 |
38 |
39 |
40 |
Long Short-Term Memory (LSTM) networks
41 |
42 |
43 |
.. module:: fairseq.models.lstm
44 |
.. autoclass:: fairseq.models.lstm.LSTMModel
45 |
46 |
.. autoclass:: fairseq.models.lstm.LSTMEncoder
47 |
48 |
.. autoclass:: fairseq.models.lstm.LSTMDecoder
49 |
50 |
51 |
52 |
Transformer (self-attention) networks
53 |
54 |
55 |
.. module:: fairseq.models.transformer
56 |
.. autoclass:: fairseq.models.transformer.TransformerModel
57 |
58 |
.. autoclass:: fairseq.models.transformer.TransformerEncoder
59 |
60 |
.. autoclass:: fairseq.models.transformer.TransformerEncoderLayer
61 |
62 |
.. autoclass:: fairseq.models.transformer.TransformerDecoder
63 |
64 |
.. autoclass:: fairseq.models.transformer.TransformerDecoderLayer
65 |
66 |
67 |
68 |
Adding new models
69 |
70 |
71 |
.. currentmodule:: fairseq.models
72 |
.. autofunction:: fairseq.models.register_model
73 |
.. autofunction:: fairseq.models.register_model_architecture
74 |
.. autoclass:: fairseq.models.BaseFairseqModel
75 |
76 |
77 |
.. autoclass:: fairseq.models.FairseqEncoderDecoderModel
78 |
79 |
80 |
.. autoclass:: fairseq.models.FairseqEncoderModel
81 |
82 |
83 |
.. autoclass:: fairseq.models.FairseqLanguageModel
84 |
85 |
86 |
.. autoclass:: fairseq.models.FairseqMultiModel
87 |
88 |
89 |
.. autoclass:: fairseq.models.FairseqEncoder
90 |
91 |
.. autoclass:: fairseq.models.CompositeEncoder
92 |
93 |
.. autoclass:: fairseq.models.FairseqDecoder
94 |
95 |
96 |
97 |
.. _Incremental decoding:
98 |
99 |
Incremental decoding
100 |
101 |
102 |
.. autoclass:: fairseq.models.FairseqIncrementalDecoder
103 |
104 |
@@ -0,0 +1,9 @@
1 |
2 |
3 |
4 |
Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
5 |
be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
6 |
7 |
.. automodule:: fairseq.modules
8 |
9 |
@@ -0,0 +1,38 @@
1 |
.. role:: hidden
2 |
:class: hidden-section
3 |
4 |
.. _optimizers:
5 |
6 |
7 |
8 |
9 |
Optimizers update the Model parameters based on the gradients.
10 |
11 |
.. automodule:: fairseq.optim
12 |
13 |
14 |
.. autoclass:: fairseq.optim.FairseqOptimizer
15 |
16 |
17 |
18 |
.. autoclass:: fairseq.optim.adadelta.Adadelta
19 |
20 |
21 |
.. autoclass:: fairseq.optim.adagrad.Adagrad
22 |
23 |
24 |
.. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
25 |
26 |
27 |
.. autoclass:: fairseq.optim.adam.FairseqAdam
28 |
29 |
30 |
.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
31 |
32 |
33 |
.. autoclass:: fairseq.optim.nag.FairseqNAG
34 |
35 |
36 |
.. autoclass:: fairseq.optim.sgd.SGD
37 |
38 |
@@ -0,0 +1,74 @@
1 |
2 |
3 |

Fairseq can be extended through user-supplied `plug-ins
<>`_. We support five kinds of
plug-ins:

- :ref:`Models` define the neural network architecture and encapsulate all of the
  learnable parameters.
- :ref:`Criterions` compute the loss function given the model outputs and targets.
- :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over
  Datasets, initializing the Model/Criterion and calculating the loss.
- :ref:`Optimizers` update the Model parameters based on the gradients.
- :ref:`Learning Rate Schedulers` update the learning rate over the course of
  training.

**Training Flow**

Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``,
fairseq implements the following high-level training flow::

    for epoch in range(num_epochs):
        itr = task.get_batch_iterator(task.dataset('train'))
        for num_updates, batch in enumerate(itr):
            task.train_step(batch, model, criterion, optimizer)
            average_and_clip_gradients()
            optimizer.step()
            lr_scheduler.step_update(num_updates)
        lr_scheduler.step(epoch)

where the default implementation for ``task.train_step`` is roughly::

    def train_step(self, batch, model, criterion, optimizer, **unused):
        loss = criterion(model, batch)
        optimizer.backward(loss)
        return loss

**Registering new plug-ins**

New plug-ins are *registered* through a set of ``@register`` function
decorators, for example::

    @register_model('my_lstm')
    class MyLSTM(FairseqEncoderDecoderModel):
        (...)

Once registered, new plug-ins can be used with the existing :ref:`Command-line
Tools`. See the Tutorial sections for more detailed walkthroughs of how to add
new plug-ins.
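
For readers unfamiliar with decorator-based registration, the sketch below shows the general pattern with a plain dictionary as the registry. It is a simplified illustration, not fairseq's code. ``TOY_MODEL_REGISTRY``, ``toy_register_model`` and ``ToyLSTM`` are hypothetical names:

```python
TOY_MODEL_REGISTRY = {}


def toy_register_model(name):
    """Return a class decorator that stores the class under `name`."""

    def decorator(cls):
        if name in TOY_MODEL_REGISTRY:
            raise ValueError(f"Cannot register duplicate model ({name})")
        TOY_MODEL_REGISTRY[name] = cls
        return cls  # the class itself is returned unchanged

    return decorator


@toy_register_model("toy_lstm")
class ToyLSTM:
    pass


# A command-line tool can now look the class up by its registered name.
print(TOY_MODEL_REGISTRY["toy_lstm"].__name__)  # ToyLSTM
```

Because registration happens when the decorated class is defined, the module containing the class has to be imported before the name can be looked up.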
50 |
51 |
**Loading plug-ins from another directory**
52 |
53 |
New plug-ins can be defined in a custom module stored in the user system. In
54 |
order to import the module, and make the plugin available to *fairseq*, the
55 |
command line supports the ``--user-dir`` flag that can be used to specify a
56 |
custom location for additional modules to load into *fairseq*.
57 |
58 |
For example, assuming this directory tree::
59 |
60 |
61 |
62 |
63 |
with ````::

    from fairseq.models import register_model_architecture
    from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big

    @register_model_architecture('transformer', 'my_transformer')
    def transformer_mmt_big(args):
        transformer_vaswani_wmt_en_de_big(args)

it is possible to invoke the :ref:`fairseq-train` script with the new architecture with::

    fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation
@@ -0,0 +1,2 @@
1 |
2 |
@@ -0,0 +1,61 @@
1 |
.. role:: hidden
2 |
:class: hidden-section
3 |
4 |
.. module:: fairseq.tasks
5 |
6 |
.. _Tasks:
7 |
8 |
9 |
10 |
11 |
Tasks store dictionaries and provide helpers for loading/iterating over
12 |
Datasets, initializing the Model/Criterion and calculating the loss.
13 |
14 |
Tasks can be selected via the ``--task`` command-line argument. Once selected, a
15 |
task may expose additional command-line arguments for further configuration.
16 |
17 |
Example usage::

    # setup the task (e.g., load dictionaries)
    task = fairseq.tasks.setup_task(args)

    # build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)

    # load datasets
    task.load_dataset('train')
    task.load_dataset('valid')

    # iterate over mini-batches of data
    batch_itr = task.get_batch_iterator(
        task.dataset('train'), max_tokens=4096,
    )
    for batch in batch_itr:
        # compute the loss
        loss, sample_size, logging_output = task.get_loss(
            model, criterion, batch,
        )
        loss.backward()

41 |
42 |
43 |
44 |
45 |
.. autoclass:: fairseq.tasks.translation.TranslationTask
46 |
47 |
.. _language modeling:
48 |
49 |
Language Modeling
50 |
51 |
52 |
.. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask
53 |
54 |
55 |
Adding new tasks
56 |
57 |
58 |
.. autofunction:: fairseq.tasks.register_task
59 |
.. autoclass:: fairseq.tasks.FairseqTask
60 |
61 |
@@ -0,0 +1,415 @@
1 |
Tutorial: Classifying Names with a Character-Level RNN
2 |
3 |
4 |
In this tutorial we will extend fairseq to support *classification* tasks. In
5 |
particular we will re-implement the PyTorch tutorial for `Classifying Names with
6 |
a Character-Level RNN <>`_
7 |
in fairseq. It is recommended to quickly skim that tutorial before beginning
8 |
this one.
9 |
10 |
This tutorial covers:
11 |
12 |
1. **Preprocessing the data** to create dictionaries.
13 |
2. **Registering a new Model** that encodes an input sentence with a simple RNN
14 |
and predicts the output label.
15 |
3. **Registering a new Task** that loads our dictionaries and dataset.
16 |
4. **Training the Model** using the existing command-line tools.
17 |
5. **Writing an evaluation script** that imports fairseq and allows us to
18 |
interactively evaluate our model on new inputs.
19 |
20 |
21 |
1. Preprocessing the data
22 |
23 |
24 |
The original tutorial provides raw data, but we'll work with a modified version
25 |
of the data that is already tokenized into characters and split into separate
26 |
train, valid and test sets.
27 |
28 |
Download and extract the data from here:
29 |
`tutorial_names.tar.gz <>`_
30 |
31 |
Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess`
32 |
command-line tool to create the dictionaries. While this tool is primarily
33 |
intended for sequence-to-sequence problems, we're able to reuse it here by
34 |
treating the label as a "target" sequence of length 1. We'll also output the
35 |
preprocessed files in "raw" format using the ``--dataset-impl`` option to
36 |
enhance readability:
37 |
38 |
.. code-block:: console
39 |
40 |
> fairseq-preprocess \
41 |
--trainpref names/train --validpref names/valid --testpref names/test \
42 |
--source-lang input --target-lang label \
43 |
--destdir names-bin --dataset-impl raw
44 |
45 |
After running the above command you should see a new directory,
46 |
:file:`names-bin/`, containing the dictionaries for *inputs* and *labels*.
47 |
48 |
49 |
2. Registering a new Model
50 |
51 |
52 |
Next we'll register a new model in fairseq that will encode an input sentence
53 |
with a simple RNN and predict the output label. Compared to the original PyTorch
54 |
tutorial, our version will also work with batches of data and GPU Tensors.
55 |
56 |
First let's copy the simple RNN module implemented in the `PyTorch tutorial
57 |
58 |
Create a new file named :file:`fairseq/models/` with the
59 |
following contents::

    import torch
    import torch.nn as nn

    class RNN(nn.Module):

        def __init__(self, input_size, hidden_size, output_size):
            super(RNN, self).__init__()

            self.hidden_size = hidden_size

            self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
            self.i2o = nn.Linear(input_size + hidden_size, output_size)
            self.softmax = nn.LogSoftmax(dim=1)

        def forward(self, input, hidden):
            # Concatenate the current input with the previous hidden state.
            combined = torch.cat((input, hidden), 1)
            hidden = self.i2h(combined)
            output = self.i2o(combined)
            output = self.softmax(output)
            return output, hidden

        def initHidden(self):
            return torch.zeros(1, self.hidden_size)

We must also *register* this model with fairseq using the
86 |
:func:`~fairseq.models.register_model` function decorator. Once the model is
87 |
registered we'll be able to use it with the existing :ref:`Command-line Tools`.
88 |
89 |
All registered models must implement the :class:`~fairseq.models.BaseFairseqModel`
90 |
interface, so we'll create a small wrapper class in the same file and register
91 |
it in fairseq with the name ``'rnn_classifier'``::

    from fairseq.models import BaseFairseqModel, register_model

    # Note: the register_model "decorator" should immediately precede the
    # definition of the Model class.

    @register_model('rnn_classifier')
    class FairseqRNNClassifier(BaseFairseqModel):

        @staticmethod
        def add_args(parser):
            # Models can override this method to add new command-line arguments.
            # Here we'll add a new command-line argument to configure the
            # dimensionality of the hidden state.
            parser.add_argument(
                '--hidden-dim', type=int, metavar='N',
                help='dimensionality of the hidden state',
            )

        @classmethod
        def build_model(cls, args, task):
            # Fairseq initializes models by calling the ``build_model()``
            # function. This provides more flexibility, since the returned model
            # instance can be of a different type than the one that was called.
            # In this case we'll just return a FairseqRNNClassifier instance.

            # Initialize our RNN module
            rnn = RNN(
                # We'll define the Task in the next section, but for now just
                # notice that the task holds the dictionaries for the "source"
                # (i.e., the input sentence) and "target" (i.e., the label).
                input_size=len(task.source_dictionary),
                hidden_size=args.hidden_dim,
                output_size=len(task.target_dictionary),
            )

            # Return the wrapped version of the module
            return FairseqRNNClassifier(
                rnn=rnn,
                input_vocab=task.source_dictionary,
            )

        def __init__(self, rnn, input_vocab):
            super(FairseqRNNClassifier, self).__init__()

            self.rnn = rnn
            self.input_vocab = input_vocab

            # The RNN module in the tutorial expects one-hot inputs, so we can
            # precompute the identity matrix to help convert from indices to
            # one-hot vectors. We register it as a buffer so that it is moved to
            # the GPU when ``cuda()`` is called.
            self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))

        def forward(self, src_tokens, src_lengths):
            # The inputs to the ``forward()`` function are determined by the
            # Task, and in particular the ``'net_input'`` key in each
            # mini-batch. We'll define the Task in the next section, but for
            # now just know that *src_tokens* has shape `(batch, src_len)` and
            # *src_lengths* has shape `(batch)`.
            bsz, max_src_len = src_tokens.size()

            # Initialize the RNN hidden state. Compared to the original PyTorch
            # tutorial we'll also handle batched inputs and work on the GPU.
            hidden = self.rnn.initHidden()
            hidden = hidden.repeat(bsz, 1)  # expand for batched inputs
            hidden = hidden.to(src_tokens.device)  # move to GPU

            for i in range(max_src_len):
                # WARNING: The inputs have padding, so we should mask those
                # elements here so that padding doesn't affect the results.
                # This is left as an exercise for the reader. The padding symbol
                # is given by ``self.input_vocab.pad()`` and the unpadded length
                # of each input is given by *src_lengths*.

                # One-hot encode a batch of input characters.
                input = self.one_hot_inputs[src_tokens[:, i].long()]

                # Feed the input to our RNN.
                output, hidden = self.rnn(input, hidden)

            # Return the final output state for making a prediction
            return output

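
The ``one_hot_inputs`` buffer used in the model above turns token indices into one-hot vectors by selecting rows of an identity matrix. The following sketch shows the same idea with plain Python lists instead of tensors. The vocabulary size and token indices are made up for illustration:

```python
def identity(n):
    """Build an n x n identity matrix as nested lists."""
    return [[1.0 if row == col else 0.0 for col in range(n)] for row in range(n)]


vocab_size = 4                       # hypothetical vocabulary size
one_hot_inputs = identity(vocab_size)

# One token index per batch element, as in ``src_tokens[:, i]``.
token_indices = [2, 0, 3]
batch = [one_hot_inputs[idx] for idx in token_indices]

print(batch[0])  # [0.0, 0.0, 1.0, 0.0]
```

Row ``k`` of the identity matrix is exactly the one-hot vector for index ``k``, so no explicit encoding step is needed at each time step.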
Finally let's define a *named architecture* with the configuration for our
177 |
model. This is done with the :func:`~fairseq.models.register_model_architecture`
178 |
function decorator. Thereafter this named architecture can be used with the
179 |
``--arch`` command-line argument, e.g., ``--arch pytorch_tutorial_rnn``::

    from fairseq.models import register_model_architecture

    # The first argument to ``register_model_architecture()`` should be the name
    # of the model we registered above (i.e., 'rnn_classifier'). The function we
    # register here should take a single argument *args* and modify it in-place
    # to match the desired architecture.

    @register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
    def pytorch_tutorial_rnn(args):
        # We use ``getattr()`` to prioritize arguments that are explicitly given
        # on the command-line, so that the defaults defined below are only used
        # when no other value has been specified.
        args.hidden_dim = getattr(args, 'hidden_dim', 128)
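
The ``getattr()`` pattern can be tried in isolation with an ``argparse.Namespace``. This is a minimal sketch, and ``toy_architecture`` is a hypothetical stand-in for the registered architecture function:

```python
from argparse import Namespace


def toy_architecture(args):
    # Keep an explicitly given value; otherwise fall back to the default.
    args.hidden_dim = getattr(args, "hidden_dim", 128)


from_cli = Namespace(hidden_dim=256)   # value was given on the command line
unset = Namespace()                    # attribute is absent

toy_architecture(from_cli)
toy_architecture(unset)
print(from_cli.hidden_dim, unset.hidden_dim)  # 256 128
```

Note that the default applies only when the attribute is absent. An attribute that exists with the value ``None`` is returned as ``None``.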
194 |
195 |
196 |
3. Registering a new Task
197 |
198 |
Now we'll register a new :class:`~fairseq.tasks.FairseqTask` that will load our
dictionaries and dataset. Tasks can also control how the data is batched into
mini-batches, but in this tutorial we'll reuse the batching provided by
:class:`fairseq.data.LanguagePairDataset`.

Create a new file named :file:`fairseq/tasks/` with the
205 |
following contents::

    import os
    import torch

    from fairseq.data import Dictionary, LanguagePairDataset
    from fairseq.tasks import LegacyFairseqTask, register_task


    @register_task('simple_classification')
    class SimpleClassificationTask(LegacyFairseqTask):

        @staticmethod
        def add_args(parser):
            # Add some command-line arguments for specifying where the data is
            # located and the maximum supported input length.
            parser.add_argument('data', metavar='FILE',
                                help='file prefix for data')
            parser.add_argument('--max-positions', default=1024, type=int,
                                help='max input length')

        @classmethod
        def setup_task(cls, args, **kwargs):
            # Here we can perform any setup required for the task. This may include
            # loading Dictionaries, initializing shared Embedding layers, etc.
            # In this case we'll just load the Dictionaries.
            input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
            label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
            print('| [input] dictionary: {} types'.format(len(input_vocab)))
            print('| [label] dictionary: {} types'.format(len(label_vocab)))

            return SimpleClassificationTask(args, input_vocab, label_vocab)

        def __init__(self, args, input_vocab, label_vocab):
            super().__init__(args)
            self.input_vocab = input_vocab
            self.label_vocab = label_vocab

        def load_dataset(self, split, **kwargs):
            """Load a given dataset split (e.g., train, valid, test)."""

            prefix = os.path.join(self.args.data, '{}.input-label'.format(split))

            # Read input sentences.
            sentences, lengths = [], []
            with open(prefix + '.input', encoding='utf-8') as file:
                for line in file:
                    sentence = line.strip()

                    # Tokenize the sentence, splitting on spaces
                    tokens = self.input_vocab.encode_line(
                        sentence, add_if_not_exist=False,
                    )

                    sentences.append(tokens)
                    lengths.append(tokens.numel())

            # Read labels.
            labels = []
            with open(prefix + '.label', encoding='utf-8') as file:
                for line in file:
                    label = line.strip()
                    labels.append(
                        # Convert label to a numeric ID.
                        torch.LongTensor([self.label_vocab.add_symbol(label)])
                    )

            assert len(sentences) == len(labels)
            print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))

            # We reuse LanguagePairDataset since classification can be modeled as a
            # sequence-to-sequence task where the target sequence has length 1.
            self.datasets[split] = LanguagePairDataset(
                src=sentences,
                src_sizes=lengths,
                src_dict=self.input_vocab,
                tgt=labels,
                tgt_sizes=torch.ones(len(labels)),  # targets have length 1
                tgt_dict=self.label_vocab,
                left_pad_source=False,
                # Since our target is a single class label, there's no need for
                # teacher forcing. If we set this to ``True`` then our Model's
                # ``forward()`` method would receive an additional argument called
                # *prev_output_tokens* that would contain a shifted version of the
                # target sequence.
                input_feeding=False,
            )

        def max_positions(self):
            """Return the max input length allowed by the task."""
            # The source should be less than *args.max_positions* and the "target"
            # has max length 1.
            return (self.args.max_positions, 1)

        @property
        def source_dictionary(self):
            """Return the source :class:`~fairseq.data.Dictionary`."""
            return self.input_vocab

        @property
        def target_dictionary(self):
            """Return the target :class:`~fairseq.data.Dictionary`."""
            return self.label_vocab

        # We could override this method if we wanted more control over how batches
        # are constructed, but it's not necessary for this tutorial since we can
        # reuse the batching provided by LanguagePairDataset.
        #
        # def get_batch_iterator(
        #     self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        #     ignore_invalid_inputs=False, required_batch_size_multiple=1,
        #     seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
        #     data_buffer_size=0, disable_iterator_cache=False,
        # ):
        #     (...)


322 |
4. Training the Model
323 |
324 |
Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
command-line tool for this, making sure to specify our new Task (``--task
simple_classification``) and Model architecture (``--arch
pytorch_tutorial_rnn``):

330 |
.. note::
331 |
332 |
You can also configure the dimensionality of the hidden state by passing the
333 |
``--hidden-dim`` argument to :ref:`fairseq-train`.
334 |
335 |
.. code-block:: console
336 |
337 |
> fairseq-train names-bin \
338 |
--task simple_classification \
339 |
--arch pytorch_tutorial_rnn \
340 |
--optimizer adam --lr 0.001 --lr-shrink 0.5 \
341 |
--max-tokens 1000
342 |
343 |
| epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
344 |
| epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
345 |
| done training in 31.6 seconds
346 |
347 |
The model files should appear in the :file:`checkpoints/` directory.
348 |
349 |
350 |
5. Writing an evaluation script
351 |
352 |
353 |
Finally we can write a short script to evaluate our model on new inputs. Create
354 |
a new file named :file:`` with the following contents::

    from fairseq import checkpoint_utils, data, options, tasks

    # Parse command-line arguments for generation
    parser = options.get_generation_parser(default_task='simple_classification')
    args = options.parse_args_and_arch(parser)

    # Setup task
    task = tasks.setup_task(args)

    # Load model
    print('| loading model from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
    model = models[0]

    while True:
        sentence = input('\nInput: ')

        # Tokenize into characters
        chars = ' '.join(list(sentence.strip()))
        tokens = task.source_dictionary.encode_line(
            chars, add_if_not_exist=False,
        )

        # Build mini-batch to feed to the model
        batch = data.language_pair_dataset.collate(
            samples=[{'id': -1, 'source': tokens}],  # bsz = 1
            pad_idx=task.source_dictionary.pad(),
            eos_idx=task.source_dictionary.eos(),
            left_pad_source=False,
            input_feeding=False,
        )

        # Feed batch to the model and get predictions
        preds = model(**batch['net_input'])

        # Print top 3 predictions and their log-probabilities
        top_scores, top_labels = preds[0].topk(k=3)
        for score, label_idx in zip(top_scores, top_labels):
            label_name = task.target_dictionary.string([label_idx])
            print('({:.2f})\t{}'.format(score, label_name))

Now we can evaluate our model interactively. Note that we have included the
398 |
original data path (:file:`names-bin/`) so that the dictionaries can be loaded:
399 |
400 |
.. code-block:: console
401 |
402 |
> python names-bin --path checkpoints/
403 |
| [input] dictionary: 64 types
404 |
| [label] dictionary: 24 types
405 |
| loading model from checkpoints/
406 |
407 |
Input: Satoshi
408 |
(-0.61) Japanese
409 |
(-1.20) Arabic
410 |
(-2.86) Italian
411 |
412 |
Input: Sinbad
413 |
(-0.30) Arabic
414 |
(-1.76) English
415 |
(-4.08) Russian
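
The scores printed above are log-probabilities, since the RNN ends in a ``LogSoftmax`` layer. As a small worked example, the sketch below converts the three scores shown for "Satoshi" back into probabilities. The numbers are copied from the sample output, and everything else is illustrative:

```python
import math

# Log-probabilities copied from the sample output for "Satoshi".
log_probs = {"Japanese": -0.61, "Arabic": -1.20, "Italian": -2.86}

# exp() undoes the log, giving a probability for each label.
probs = {label: math.exp(lp) for label, lp in log_probs.items()}

for label, p in probs.items():
    print(f"{label}: {p:.3f}")
```

The three probabilities sum to less than 1 because the remaining labels, which are not printed, share the rest of the probability mass.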
@@ -0,0 +1,518 @@
1 |
Tutorial: Simple LSTM
2 |
3 |
4 |
In this tutorial we will extend fairseq by adding a new
5 |
:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
6 |
sentence with an LSTM and then passes the final hidden state to a second LSTM
7 |
that decodes the target sentence (without attention).
8 |
9 |
This tutorial covers:
10 |
11 |
1. **Writing an Encoder and Decoder** to encode/decode the source/target
12 |
sentence, respectively.
13 |
2. **Registering a new Model** so that it can be used with the existing
14 |
:ref:`Command-line tools`.
15 |
3. **Training the Model** using the existing command-line tools.
16 |
4. **Making generation faster** by modifying the Decoder to use
17 |
:ref:`Incremental decoding`.
18 |
19 |
20 |
1. Building an Encoder and Decoder
21 |
22 |
In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
should implement the :class:`~fairseq.models.FairseqEncoder` interface and
Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface.
These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders
and FairseqDecoders can be written and used in the same ways as ordinary PyTorch
Modules.
29 |
30 |
31 |
32 |
33 |
34 |
Our Encoder will embed the tokens in the source sentence, feed them to a
35 |
:class:`torch.nn.LSTM` and return the final hidden state. To create our encoder
36 |
save the following in a new file named :file:`fairseq/models/`::

    import torch.nn as nn
    from fairseq import utils
    from fairseq.models import FairseqEncoder

    class SimpleLSTMEncoder(FairseqEncoder):

        def __init__(
            self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
        ):
            super().__init__(dictionary)
            self.args = args

            # Our encoder will embed the inputs before feeding them to the LSTM.
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)

            # We'll use a single-layer, unidirectional LSTM for simplicity.
            self.lstm = nn.LSTM(
                input_size=embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
                batch_first=True,
            )

        def forward(self, src_tokens, src_lengths):
            # The inputs to the ``forward()`` function are determined by the
            # Task, and in particular the ``'net_input'`` key in each
            # mini-batch. We discuss Tasks in the next tutorial, but for now just
            # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
            # has shape `(batch)`.

            # Note that the source is typically padded on the left. This can be
            # configured by adding the `--left-pad-source "False"` command-line
            # argument, but here we'll make the Encoder handle either kind of
            # padding by converting everything to be right-padded.
            if self.args.left_pad_source:
                # Convert left-padding to right-padding.
                src_tokens = utils.convert_padding_direction(
                    src_tokens,
                    padding_idx=self.dictionary.pad(),
                    left_to_right=True
                )

            # Embed the source.
            x = self.embed_tokens(src_tokens)

            # Apply dropout.
            x = self.dropout(x)

            # Pack the sequence into a PackedSequence object to feed to the LSTM.
            x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)

            # Get the output from the LSTM.
            _outputs, (final_hidden, _final_cell) = self.lstm(x)

            # Return the Encoder's output. This can be any object and will be
            # passed directly to the Decoder.
            return {
                # this will have shape `(bsz, hidden_dim)`
                'final_hidden': final_hidden.squeeze(0),
            }

        # Encoders are required to implement this method so that we can rearrange
        # the order of the batch elements during inference (e.g., beam search).
        def reorder_encoder_out(self, encoder_out, new_order):
            """
            Reorder encoder output according to `new_order`.

            Args:
                encoder_out: output from the ``forward()`` method
                new_order (LongTensor): desired order

            Returns:
                `encoder_out` rearranged according to `new_order`
            """
            final_hidden = encoder_out['final_hidden']
            return {
                'final_hidden': final_hidden.index_select(0, new_order),
            }

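
The left-to-right padding conversion performed in the Encoder's ``forward()`` can be illustrated with plain Python lists. This is only a sketch of the idea, not the fairseq utility. The function name ``left_to_right_padding`` and the pad index are hypothetical:

```python
def left_to_right_padding(batch, pad):
    """Move the leading pad symbols of each row to the end of that row."""
    converted = []
    for row in batch:
        num_pad = 0
        # Count how many pad symbols the row starts with.
        while num_pad < len(row) and row[num_pad] == pad:
            num_pad += 1
        converted.append(row[num_pad:] + [pad] * num_pad)
    return converted


PAD = 1  # hypothetical padding index
left_padded = [[PAD, PAD, 5, 6], [4, 5, 6, 7]]
print(left_to_right_padding(left_padded, PAD))  # [[5, 6, 1, 1], [4, 5, 6, 7]]
```

Right-padding matters here because packing a padded batch keeps the first ``length`` positions of each row, so the real tokens have to come first.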
Our Decoder will predict the next word, conditioned on the Encoder's final
128 |
hidden state and an embedded representation of the previous target word -- which
129 |
is sometimes called *teacher forcing*. More specifically, we'll use a
130 |
:class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll project
131 |
to the size of the output vocabulary to predict each target word.
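To make the teacher-forcing input concrete, here is a minimal sketch of how the Decoder input relates to the target sequence. The helper ``shift_right`` and the index ``EOS = 2`` are hypothetical and chosen only for illustration; in practice *prev_output_tokens* arrives already prepared in each mini-batch.

```python
def shift_right(target, eos):
    """Build decoder inputs from a target sequence that ends with `eos`.

    Hypothetical helper for illustration: the end-of-sentence symbol is
    moved to the front, so position t of the result holds the token the
    decoder should condition on when predicting target[t].
    """
    if not target or target[-1] != eos:
        raise ValueError("target is expected to end with the eos index")
    return [eos] + target[:-1]


EOS = 2  # assumed index, not taken from a real dictionary
target = [11, 12, 13, EOS]
prev_output_tokens = shift_right(target, EOS)
print(prev_output_tokens)  # [2, 11, 12, 13]
```

At every position the Decoder sees the gold previous token and is trained to predict the token at the same position of ``target``.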
132 |
133 |
134 |
135 |
import torch
136 |
from fairseq.models import FairseqDecoder
137 |
138 |
class SimpleLSTMDecoder(FairseqDecoder):
139 |
140 |
def __init__(
141 |
self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
142 |
143 |
144 |
145 |
146 |
# Our decoder will embed the inputs before feeding them to the LSTM.
147 |
self.embed_tokens = nn.Embedding(
148 |
149 |
150 |
151 |
152 |
self.dropout = nn.Dropout(p=dropout)
153 |
154 |
# We'll use a single-layer, unidirectional LSTM for simplicity.
155 |
self.lstm = nn.LSTM(
156 |
# For the first layer we'll concatenate the Encoder's final hidden
157 |
# state with the embedded target tokens.
158 |
input_size=encoder_hidden_dim + embed_dim,
159 |
160 |
161 |
162 |
163 |
164 |
# Define the output projection.
165 |
self.output_projection = nn.Linear(hidden_dim, len(dictionary))
166 |
167 |
# During training Decoders are expected to take the entire target sequence
168 |
# (shifted right by one position) and produce logits over the vocabulary.
169 |
# The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
170 |
# ``dictionary.eos()``, followed by the target sequence.
171 |
def forward(self, prev_output_tokens, encoder_out):
172 |
"""
173 |
Args:
174 |
prev_output_tokens (LongTensor): previous decoder outputs of shape
175 |
`(batch, tgt_len)`, for teacher forcing
176 |
encoder_out (dict): output from the Encoder's ``forward()`` method;
177 |
only its ``'final_hidden'`` entry is used here
178 |
179 |
Returns:
180 |
tuple:
181 |
- the last decoder layer's output of shape
182 |
`(batch, tgt_len, vocab)`
183 |
- ``None`` in place of attention weights, since this Decoder
184 |
has no attention mechanism
185 |
"""
186 |
bsz, tgt_len = prev_output_tokens.size()
187 |
188 |
# Extract the final hidden state from the Encoder.
189 |
final_encoder_hidden = encoder_out['final_hidden']
190 |
191 |
# Embed the target sequence, which has been shifted right by one
192 |
# position and now starts with the end-of-sentence symbol.
193 |
x = self.embed_tokens(prev_output_tokens)
194 |
195 |
# Apply dropout.
196 |
x = self.dropout(x)
197 |
198 |
# Concatenate the Encoder's final hidden state to *every* embedded
199 |
# target token.
200 |
x = torch.cat(
201 |
[x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
202 |
dim=2,
203 |
)
204 |
205 |
# Using PackedSequence objects in the Decoder is harder than in the
206 |
# Encoder, since the targets are not sorted in descending length order,
207 |
# which is a requirement of ``pack_padded_sequence()``. Instead we'll
208 |
# feed nn.LSTM directly.
209 |
initial_state = (
210 |
final_encoder_hidden.unsqueeze(0), # hidden
211 |
torch.zeros_like(final_encoder_hidden).unsqueeze(0), # cell
212 |
)
213 |
output, _ = self.lstm(
214 |
x.transpose(0, 1), # convert to shape `(tgt_len, bsz, dim)`
215 |
initial_state,
216 |
)
217 |
x = output.transpose(0, 1) # convert to shape `(bsz, tgt_len, hidden)`
218 |
219 |
# Project the outputs to the size of the vocabulary.
220 |
x = self.output_projection(x)
221 |
222 |
# Return the logits and ``None`` for the attention weights
223 |
return x, None
224 |
225 |
226 |
2. Registering the Model
227 |
228 |
229 |
Now that we've defined our Encoder and Decoder we must *register* our model with
230 |
fairseq using the :func:`~fairseq.models.register_model` function decorator.
231 |
Once the model is registered we'll be able to use it with the existing
232 |
:ref:`Command-line Tools`.
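As a rough mental model of what registration does, the sketch below implements a tiny decorator-based registry in plain Python. It is not fairseq's implementation: ``TOY_MODEL_REGISTRY``, ``register_toy_model`` and ``ToyModel`` are hypothetical names, and the real decorator performs additional checks.

```python
TOY_MODEL_REGISTRY = {}  # hypothetical registry: name -> model class


def register_toy_model(name):
    """Return a class decorator that records the class under `name`."""

    def decorator(cls):
        if name in TOY_MODEL_REGISTRY:
            raise ValueError(f"cannot register duplicate model ({name})")
        TOY_MODEL_REGISTRY[name] = cls
        return cls  # the class itself is returned unchanged

    return decorator


@register_toy_model("toy_lstm")
class ToyModel:
    @classmethod
    def build_model(cls, args):
        # A real model would read hyperparameters from `args` here.
        return cls()


# A command-line tool can now look the class up by its string name.
model = TOY_MODEL_REGISTRY["toy_lstm"].build_model(args=None)
print(type(model).__name__)  # ToyModel
```

The point of the pattern is that a tool only needs the string name to find and build the model class.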
233 |
234 |
All registered models must implement the
235 |
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
236 |
models (i.e., any model with a single Encoder and Decoder), we can instead
237 |
implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.
238 |
239 |
Create a small wrapper class in the same file and register it in fairseq with
240 |
the name ``'simple_lstm'``::
241 |
242 |
from fairseq.models import FairseqEncoderDecoderModel, register_model
243 |
244 |
# Note: the register_model "decorator" should immediately precede the
245 |
# definition of the Model class.
246 |
247 |
@register_model('simple_lstm')
248 |
class SimpleLSTMModel(FairseqEncoderDecoderModel):
249 |
250 |
@staticmethod
251 |
def add_args(parser):
252 |
# Models can override this method to add new command-line arguments.
253 |
# Here we'll add some new command-line arguments to configure dropout
254 |
# and the dimensionality of the embeddings and hidden states.
255 |
parser.add_argument(
256 |
'--encoder-embed-dim', type=int, metavar='N',
257 |
help='dimensionality of the encoder embeddings',
258 |
)
259 |
parser.add_argument(
260 |
'--encoder-hidden-dim', type=int, metavar='N',
261 |
help='dimensionality of the encoder hidden state',
262 |
)
263 |
parser.add_argument(
264 |
'--encoder-dropout', type=float, default=0.1,
265 |
help='encoder dropout probability',
266 |
)
267 |
parser.add_argument(
268 |
'--decoder-embed-dim', type=int, metavar='N',
269 |
help='dimensionality of the decoder embeddings',
270 |
)
271 |
parser.add_argument(
272 |
'--decoder-hidden-dim', type=int, metavar='N',
273 |
help='dimensionality of the decoder hidden state',
274 |
)
275 |
parser.add_argument(
276 |
'--decoder-dropout', type=float, default=0.1,
277 |
help='decoder dropout probability',
278 |
)
279 |
280 |
@classmethod
281 |
def build_model(cls, args, task):
282 |
# Fairseq initializes models by calling the ``build_model()``
283 |
# function. This provides more flexibility, since the returned model
284 |
# instance can be of a different type than the one that was called.
285 |
# In this case we'll just return a SimpleLSTMModel instance.
286 |
287 |
# Initialize our Encoder and Decoder.
288 |
encoder = SimpleLSTMEncoder(
289 |
290 |
291 |
292 |
293 |
294 |
295 |
decoder = SimpleLSTMDecoder(
296 |
297 |
298 |
299 |
300 |
301 |
302 |
model = SimpleLSTMModel(encoder, decoder)
303 |
304 |
# Print the model architecture.
305 |
306 |
307 |
return model
308 |
309 |
# We could override the ``forward()`` if we wanted more control over how
310 |
# the encoder and decoder interact, but it's not necessary for this
311 |
# tutorial since we can inherit the default implementation provided by
312 |
# the FairseqEncoderDecoderModel base class, which looks like:
313 |
314 |
# def forward(self, src_tokens, src_lengths, prev_output_tokens):
315 |
# encoder_out = self.encoder(src_tokens, src_lengths)
316 |
# decoder_out = self.decoder(prev_output_tokens, encoder_out)
317 |
# return decoder_out
318 |
319 |
Finally let's define a *named architecture* with the configuration for our
320 |
model. This is done with the :func:`~fairseq.models.register_model_architecture`
321 |
function decorator. Thereafter this named architecture can be used with the
322 |
``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``::
323 |
324 |
from fairseq.models import register_model_architecture
325 |
326 |
# The first argument to ``register_model_architecture()`` should be the name
327 |
# of the model we registered above (i.e., 'simple_lstm'). The function we
328 |
# register here should take a single argument *args* and modify it in-place
329 |
# to match the desired architecture.
330 |
331 |
@register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
332 |
def tutorial_simple_lstm(args):
333 |
# We use ``getattr()`` to prioritize arguments that are explicitly given
334 |
# on the command-line, so that the defaults defined below are only used
335 |
# when no other value has been specified.
336 |
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
337 |
args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
338 |
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
339 |
args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
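The ``getattr()`` pattern can be checked in isolation. The following sketch uses a plain ``argparse.Namespace`` as a stand-in for the parsed arguments; ``toy_architecture`` is a hypothetical function, and the sketch assumes that an option not given on the command line is simply absent from the namespace.

```python
from argparse import Namespace


def toy_architecture(args):
    """Fill in defaults only for attributes that are not already set."""
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)


# Pretend the user passed only --encoder-embed-dim 512.
args = Namespace(encoder_embed_dim=512)
toy_architecture(args)

print(args.encoder_embed_dim)  # 512 (explicit value is kept)
print(args.decoder_embed_dim)  # 256 (architecture default is used)
```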
340 |
341 |
342 |
3. Training the Model
343 |
344 |
345 |
Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
346 |
command-line tool for this, making sure to specify our new Model architecture
347 |
(``--arch tutorial_simple_lstm``).
348 |
349 |
.. note::
350 |
351 |
Make sure you've already preprocessed the data from the IWSLT example in the
352 |
:file:`examples/translation/` directory.
353 |
354 |
.. code-block:: console
355 |
356 |
> fairseq-train data-bin/ \
357 |
--arch tutorial_simple_lstm \
358 |
--encoder-dropout 0.2 --decoder-dropout 0.2 \
359 |
--optimizer adam --lr 0.005 --lr-shrink 0.5 \
360 |
--max-tokens 12000
361 |
362 |
| epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
363 |
| epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954
364 |
365 |
The model files should appear in the :file:`checkpoints/` directory. While this
366 |
model architecture is not very good, we can use the :ref:`fairseq-generate` script to
367 |
generate translations and compute our BLEU score over the test set:
368 |
369 |
.. code-block:: console
370 |
371 |
> fairseq-generate data-bin/ \
372 |
--path checkpoints/ \
373 |
--beam 5 \
374 |
375 |
376 |
| Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
377 |
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
378 |
379 |
380 |
4. Making generation faster
381 |
382 |
383 |
While autoregressive generation from sequence-to-sequence models is inherently
384 |
slow, our implementation above is especially slow because it recomputes the
385 |
entire sequence of Decoder hidden states for every output token (i.e., it is
386 |
``O(n^2)``). We can make this significantly faster by instead caching the
387 |
previous hidden states.
388 |
389 |
In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a
390 |
special mode at inference time where the Model only receives a single timestep
391 |
of input corresponding to the immediately previous output token (for teacher
392 |
forcing) and must produce the next output incrementally. Thus the model must
393 |
cache any long-term state that is needed about the sequence, e.g., hidden
394 |
states, convolutional states, etc.
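The effect of caching can be illustrated without any neural network library. The sketch below replaces the LSTM with a toy scalar recurrence and counts how many recurrent steps each strategy performs; all names (``step``, ``decode_full``, ``decode_incremental``) and the weights are made up for this illustration and are not part of fairseq.

```python
import math


def step(h, x, w=0.5, u=0.25):
    """One step of a toy scalar recurrence standing in for an LSTM cell."""
    return math.tanh(w * h + u * x)


def decode_full(prefix):
    """Recompute every hidden state of the prefix; return (state, steps)."""
    h, steps = 0.0, 0
    for x in prefix:
        h = step(h, x)
        steps += 1
    return h, steps


def decode_incremental(x, cache):
    """Advance the cached state by a single token; return (state, steps)."""
    h = step(cache.get("prev_state", 0.0), x)
    cache["prev_state"] = h
    return h, 1


tokens = [1.0, -2.0, 0.5, 3.0]
full_steps, incr_steps, cache = 0, 0, {}
for t in range(1, len(tokens) + 1):
    h_full, n = decode_full(tokens[:t])  # sees the whole prefix
    h_incr, m = decode_incremental(tokens[t - 1], cache)  # sees one token
    full_steps += n
    incr_steps += m
    assert math.isclose(h_full, h_incr)  # both strategies agree

print(full_steps, incr_steps)  # 10 4
```

For a sequence of length ``n`` the recomputing variant performs ``n(n+1)/2`` steps while the cached variant performs ``n``, which is the quadratic versus linear behaviour described above.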
395 |
396 |
To implement incremental decoding we will modify our model to implement the
397 |
:class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
398 |
standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
399 |
decoder interface allows ``forward()`` methods to take an extra keyword argument
400 |
(*incremental_state*) that can be used to cache state across time-steps.
401 |
402 |
Let's replace our ``SimpleLSTMDecoder`` with an incremental one::
403 |
404 |
import torch
405 |
from fairseq.models import FairseqIncrementalDecoder
406 |
407 |
class SimpleLSTMDecoder(FairseqIncrementalDecoder):
408 |
409 |
def __init__(
410 |
self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
411 |
412 |
413 |
# This remains the same as before.
414 |
415 |
self.embed_tokens = nn.Embedding(
416 |
417 |
418 |
419 |
420 |
self.dropout = nn.Dropout(p=dropout)
421 |
self.lstm = nn.LSTM(
422 |
input_size=encoder_hidden_dim + embed_dim,
423 |
424 |
425 |
426 |
427 |
self.output_projection = nn.Linear(hidden_dim, len(dictionary))
428 |
429 |
# We now take an additional kwarg (*incremental_state*) for caching the
430 |
# previous hidden and cell states.
431 |
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
432 |
if incremental_state is not None:
433 |
# If the *incremental_state* argument is not ``None`` then we are
434 |
# in incremental inference mode. While *prev_output_tokens* will
435 |
# still contain the entire decoded prefix, we will only use the
436 |
# last step and assume that the rest of the state is cached.
437 |
prev_output_tokens = prev_output_tokens[:, -1:]
438 |
439 |
# This remains the same as before.
440 |
bsz, tgt_len = prev_output_tokens.size()
441 |
final_encoder_hidden = encoder_out['final_hidden']
442 |
x = self.embed_tokens(prev_output_tokens)
443 |
x = self.dropout(x)
444 |
x = torch.cat(
445 |
[x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
446 |
dim=2,
447 |
)
448 |
449 |
# We will now check the cache and load the cached previous hidden and
450 |
# cell states, if they exist, otherwise we will initialize them to
451 |
# zeros (as before). We will use the ``utils.get_incremental_state()``
452 |
# and ``utils.set_incremental_state()`` helpers.
453 |
initial_state = utils.get_incremental_state(
454 |
self, incremental_state, 'prev_state',
455 |
)
456 |
if initial_state is None:
457 |
# first time initialization, same as the original version
458 |
initial_state = (
459 |
final_encoder_hidden.unsqueeze(0), # hidden
460 |
torch.zeros_like(final_encoder_hidden).unsqueeze(0), # cell
461 |
)
462 |
463 |
# Run one step of our LSTM.
464 |
output, latest_state = self.lstm(x.transpose(0, 1), initial_state)
465 |
466 |
# Update the cache with the latest hidden and cell states.
467 |
utils.set_incremental_state(
468 |
self, incremental_state, 'prev_state', latest_state,
469 |
)
470 |
471 |
# This remains the same as before
472 |
x = output.transpose(0, 1)
473 |
x = self.output_projection(x)
474 |
return x, None
475 |
476 |
# The ``FairseqIncrementalDecoder`` interface also requires implementing a
477 |
# ``reorder_incremental_state()`` method, which is used during beam search
478 |
# to select and reorder the incremental state.
479 |
def reorder_incremental_state(self, incremental_state, new_order):
480 |
# Load the cached state.
481 |
prev_state = utils.get_incremental_state(
482 |
self, incremental_state, 'prev_state',
483 |
)
484 |
485 |
# Reorder batches according to *new_order*.
486 |
reordered_state = (
487 |
prev_state[0].index_select(1, new_order), # hidden
488 |
prev_state[1].index_select(1, new_order), # cell
489 |
)
490 |
491 |
# Update the cached state.
492 |
utils.set_incremental_state(
493 |
self, incremental_state, 'prev_state', reordered_state,
494 |
)
495 |
496 |
Finally, we can rerun generation and observe the speedup:
497 |
498 |
.. code-block:: console
499 |
500 |
# Before
501 |
502 |
> fairseq-generate data-bin/ \
503 |
--path checkpoints/ \
504 |
--beam 5 \
505 |
506 |
507 |
| Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
508 |
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
509 |
510 |
# After
511 |
512 |
> fairseq-generate data-bin/ \
513 |
--path checkpoints/ \
514 |
--beam 5 \
515 |
516 |
517 |
| Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
518 |
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
@@ -0,0 +1,2 @@
1 |
2 |
@@ -0,0 +1,9 @@
1 |
# Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 |
# This source code is licensed under the MIT license found in the
4 |
# LICENSE file in the root directory of this source tree.
5 |
6 |
7 |
from fairseq.version import __version__ # noqa
8 |
except ImportError:
9 |
@@ -0,0 +1,90 @@
1 |
# Adaptive Span
2 |
3 |
Adaptive Span is a novel self-attention mechanism that can learn its optimal
4 |
attention span. This allows us to significantly extend the maximum context size
5 |
used in Transformers, while maintaining control over their memory footprint
6 |
and computational time. It uses the Truncated BPTT technique for training,
7 |
as in [transformerXL](
8 |
9 |
Adaptive Span was introduced in the paper
10 |
[Adaptive Attention Span in Transformers](,
11 |
which achieved state-of-the-art language modeling results at the time of publication.
12 |
13 |
We managed to reproduce their results in fairseq while keeping most of the
14 |
[original implementation]( untouched.
15 |
You can also refer to their sweep file if any combination of hyperparameters is unclear.
16 |
17 |
##### 0. Setup
18 |
19 |
First you need to process the Enwik8 dataset. We use the pre-tokenized dataset
20 |
from [adaptive span paper](
21 |
You can download the dataset, and then run:
22 |
23 |
fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
24 |
--validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
25 |
--destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
26 |
27 |
28 |
##### 1. Train an Adaptive Span model on Enwik8
29 |
30 |
We will train a 12-layer Adaptive Span model following the [hyperparameters
31 |
used in the original
32 |
33 |
34 |
The following command assumes 4 GPUs, so that the total batch size is 64
35 |
sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
36 |
37 |
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
38 |
--user-dir examples/adaptive_span \
39 |
--data ~/data/enwik8/data-bin/ \
40 |
--fp16 --fp16-no-flatten-grads --max-update 600000 \
41 |
--task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
42 |
--n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
43 |
--attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
44 |
--validate-interval-updates 1000 \
45 |
--lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
46 |
--lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
47 |
--seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
48 |
49 |
This should land around 1.05 on validation, 1.03 on test. You can lower the
50 |
`--aux-loss-scaler` for better performance (longer span). It gives a ~0.03 bpc
51 |
improvement over the transformerXL baseline here.
52 |
If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
53 |
and simulate training on 4 GPUs.
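
The batch-size arithmetic behind this advice can be written down as a small sketch. `effective_batch_size` is a hypothetical helper, not a fairseq function; it only assumes that each GPU processes `--batch-size` sequences per step and that gradients are accumulated over `--update-freq` steps before each update.

```python
def effective_batch_size(num_gpus, batch_size, update_freq=1):
    """Sequences contributing to one parameter update (illustrative only)."""
    return num_gpus * batch_size * update_freq


# 4 GPUs, --batch-size 16, --update-freq 1
print(effective_batch_size(4, 16, 1))  # 64

# 1 GPU, --batch-size 16, --update-freq 4 gives the same total
print(effective_batch_size(1, 16, 4))  # 64
```
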
54 |
You can also reproduce the transformerXL result on enwik8 using this code base.
55 |
It should land around 1.06 on test, matching the [original paper](
56 |
You can try it with:
57 |
58 |
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
59 |
--user-dir examples/truncated_bptt \
60 |
~/data/enwik8/data-bin/ \
61 |
--task truncated_bptt_lm --fp16 --max-update 400000 \
62 |
--tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
63 |
--d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
64 |
--dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
65 |
--lr-scheduler cosine --warmup-updates 0 \
66 |
--lr 0.0 --lr 0.00025 --batch-size 15 \
67 |
--update-freq 1 --seed 2 --log-format json --log-interval 25 \
68 |
69 |
70 |
71 |
##### 2. Evaluate
72 |
For Adaptive Span:
73 |
74 |
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/ \
75 |
--user-dir examples/adaptive_span \
76 |
--task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
77 |
78 |
For Transformer-XL evaluation:
79 |
80 |
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/ \
81 |
--user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
82 |
--tokens-per-sample 80 \
83 |
--model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
84 |
--gen-subset valid
85 |
86 |
87 |
*Note:* During training the model saw 512 tokens of context
88 |
(``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation
89 |
settings from [the original
90 |
@@ -0,0 +1,19 @@
1 |
# Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 |
# This source code is licensed under the MIT license found in the
4 |
# LICENSE file in the root directory of this source tree.
5 |
6 |
import importlib
7 |
import os
8 |
9 |
# automatically import any Python files in the current directory
10 |
cur_dir = os.path.dirname(__file__)
11 |
for file in os.listdir(cur_dir):
12 |
path = os.path.join(cur_dir, file)
13 |
if (
14 |
not file.startswith("_")
15 |
and not file.startswith(".")
16 |
and (file.endswith(".py") or os.path.isdir(path))
17 |
18 |
mod_name = file[: file.find(".py")] if file.endswith(".py") else file
19 |
module = importlib.import_module(__name__ + "." + mod_name)
@@ -0,0 +1,128 @@
1 |
# Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 |
# This source code is licensed under the MIT license found in the
4 |
# LICENSE file in the root directory of this source tree.
5 |
6 |
from torch.optim import Adagrad
7 |
8 |
from fairseq.optim import LegacyFairseqOptimizer, register_optimizer
9 |
10 |
11 |
12 |
class FairseqAdagradWithGradClip(LegacyFairseqOptimizer):
13 |
def __init__(self, args, params):
14 |
15 |
self._optimizer = AdagradWithGradClip(params, **self.optimizer_config)
16 |
17 |
18 |
def add_args(parser):
19 |
"""Add optimizer-specific arguments to the parser."""
20 |
# fmt: off
21 |
parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
22 |
help='weight decay')
23 |
parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D',
24 |
help='internal grad clip')
25 |
# fmt: on
26 |
27 |
28 |
def optimizer_config(self):
29 |
30 |
Return a kwarg dictionary that will be used to override optimizer
31 |
args stored in checkpoints. This allows us to load a checkpoint and
32 |
resume training using a different set of optimizer args, e.g., with a
33 |
different learning rate.
34 |
35 |
return {
36 |
37 |
"weight_decay": self.args.weight_decay,
38 |
"grad_clip": self.args.adagrad_clip,
39 |
40 |
41 |
42 |
def supports_flat_params(self):
43 |
return False
44 |
45 |
46 |
def _clip_grad(clr, grad, group_grad_clip):
47 |
if group_grad_clip > 0:
48 |
norm = grad.norm(2).item()
49 |
if norm > group_grad_clip:
50 |
clr *= group_grad_clip / (norm + 1e-10)
51 |
return clr
52 |
53 |
54 |
class AdagradWithGradClip(Adagrad):
55 |
"""Adagrad algorithm with custom gradient clipping"""
56 |
57 |
def __init__(
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
self.defaults["grad_clip"] = grad_clip
75 |
self.param_groups[0].setdefault("grad_clip", grad_clip)
76 |
77 |
def step(self, closure=None):
78 |
loss = None
79 |
if closure is not None:
80 |
loss = closure()
81 |
82 |
for group in self.param_groups:
83 |
for p in group["params"]:
84 |
if p.grad is None:
85 |
86 |
87 |
grad = p.grad.data
88 |
state = self.state[p]
89 |
90 |
state["step"] += 1
91 |
92 |
if group["weight_decay"] != 0:
93 |
94 |
raise RuntimeError(
95 |
"weight_decay option is "
96 |
"not compatible with sparse "
97 |
98 |
99 |
grad = grad.add(group["weight_decay"],
100 |
101 |
clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])
102 |
103 |
# clip
104 |
clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"])
105 |
106 |
if grad.is_sparse:
107 |
# the update is non-linear so indices must be unique
108 |
grad = grad.coalesce()
109 |
grad_indices = grad._indices()
110 |
grad_values = grad._values()
111 |
size = grad.size()
112 |
113 |
def make_sparse(values):
114 |
constructor = grad.new
115 |
if grad_indices.dim() == 0 or values.dim() == 0:
116 |
return constructor().resize_as_(grad)
117 |
return constructor(grad_indices, values, size)
118 |
119 |
120 |
std = state["sum"]._sparse_mask(grad)
121 |
std_values = std._values().sqrt_().add_(1e-10)
122 |
p.data.add_(-clr, make_sparse(grad_values / std_values))
123 |
else:
124 |
state["sum"].addcmul_(1, grad, grad)
125 |
std = state["sum"].sqrt().add_(1e-10)
126 |
p.data.addcdiv_(-clr, grad, std)
127 |
128 |
return loss
@@ -0,0 +1,160 @@
1 |
# Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 |
# This source code is licensed under the MIT license found in the
4 |
# LICENSE file in the root directory of this source tree.
5 |
import math
6 |
7 |
import torch
8 |
import torch.nn as nn
9 |
import torch.nn.functional as F
10 |
11 |
12 |
class AdaptiveMask(nn.Module):
13 |
"""Soft masking function for adaptive size.
14 |
It masks out the last K values of an input. The masking value
15 |
goes from 1 to 0 gradually, so K can be learned with
16 |
back-propagation.
17 |
18 |
max_size: maximum size (i.e. input dimension)
19 |
ramp_size: size of the ramp going from 0 to 1
20 |
init_val: initial size proportion not to be masked out
21 |
shape: learn multiple sizes independent of each other
22 |
23 |
24 |
def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
25 |
26 |
self._max_size = max_size
27 |
self._ramp_size = ramp_size
28 |
self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
29 |
mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
30 |
self.register_buffer("mask_template", mask_template)
31 |
32 |
def forward(self, x):
33 |
mask = self.mask_template.float() + self.current_val.float() * self._max_size
34 |
mask = mask / self._ramp_size + 1
35 |
mask = mask.clamp(0, 1)
36 |
if x.size(-1) < self._max_size:
37 |
# the input could have been trimmed beforehand to save computation
38 |
mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1))
39 |
x = (x * mask).type_as(x)
40 |
return x
41 |
42 |
def get_current_max_size(self, include_ramp=True):
43 |
current_size = math.ceil(self.current_val.max().item() * self._max_size)
44 |
if include_ramp:
45 |
current_size += self._ramp_size
46 |
current_size = max(0, min(self._max_size, current_size))
47 |
return current_size
48 |
49 |
def get_current_avg_size(self, include_ramp=True):
50 |
current_size = math.ceil(
51 |
self.current_val.float().mean().item() * self._max_size
52 |
53 |
if include_ramp:
54 |
current_size += self._ramp_size
55 |
current_size = max(0, min(self._max_size, current_size))
56 |
return current_size
57 |
58 |
def clamp_param(self):
59 |
"""this needs to be called after each update"""
60 |
self.current_val.data.clamp_(0, 1)
61 |
62 |
63 |
class AdaptiveSpan(nn.Module):
64 |
"""Adaptive attention span for Transformer.
65 |
This module learns an attention span length from data for each
66 |
self-attention head.
67 |
68 |
attn_span: maximum attention span
69 |
adapt_span_loss: loss coefficient for the span length
70 |
adapt_span_ramp: length of the masking ramp
71 |
adapt_span_init: initial size ratio
72 |
adapt_span_cache: adapt cache size to reduce memory usage
73 |
74 |
75 |
def __init__(
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
self._max_span = attn_span
86 |
self._n_head = n_head
87 |
self._adapt_span_layer = adapt_span_layer
88 |
if self._adapt_span_layer:
89 |
self._mask = AdaptiveMask(
90 |
91 |
92 |
93 |
94 |
95 |
self._mask = AdaptiveMask(
96 |
97 |
98 |
99 |
shape=(n_head, 1, 1),
100 |
101 |
102 |
def forward(self, attn, normalize=True):
103 |
"""mask attention with the right span"""
104 |
# batch and head dimensions are merged together, so separate them first
105 |
106 |
if self._adapt_span_layer:
107 |
attn = self._mask(attn)
108 |
109 |
B = attn.size(0) # batch size
110 |
M = attn.size(1) # block size
111 |
attn = attn.reshape(B // self._n_head, self._n_head, M, -1)
112 |
attn = self._mask(attn)
113 |
attn = attn.view(B, M, -1)
114 |
return attn
115 |
116 |
def get_trim_len(self):
117 |
"""how much of memory can be trimmed to reduce computation"""
118 |
L = self._max_span
119 |
trim_len = min(L - 1, L - self._mask.get_current_max_size())
120 |
# too fine granularity might be bad for the memory management
121 |
trim_len = math.floor(trim_len / 64) * 64
122 |
return trim_len
123 |
124 |
def trim_memory(self, query, key, value, key_pe):
125 |
"""trim out unnecessary memory beforehand to reduce computation"""
126 |
trim_len = self.get_trim_len()
127 |
cache_size = key.size(1) - query.size(1)
128 |
trim_len_cache = trim_len - (self._max_span - cache_size)
129 |
if trim_len_cache > 0:
130 |
key = key[:, trim_len_cache:, :]
131 |
value = value[:, trim_len_cache:, :]
132 |
elif trim_len_cache < 0:
133 |
# cache is too short! this happens when validation resumes
134 |
# after a lot of updates.
135 |
key = F.pad(key, [0, 0, -trim_len_cache, 0])
136 |
value = F.pad(value, [0, 0, -trim_len_cache, 0])
137 |
if trim_len > 0:
138 |
if key_pe is not None:
139 |
key_pe = key_pe[:, :, trim_len:]
140 |
return key, value, key_pe
141 |
142 |
def get_cache_size(self):
143 |
"""determine how long the cache should be"""
144 |
trim_len = self.get_trim_len()
145 |
# give a buffer of 64 steps since a span might increase
146 |
# in future updates
147 |
return min(self._max_span, self._max_span - trim_len + 64)
148 |
149 |
def get_loss(self):
150 |
"""a loss term for regularizing the span length"""
151 |
return self._max_span * self._mask.current_val.float().mean()
152 |
153 |
def get_current_max_span(self):
154 |
return self._mask.get_current_max_size()
155 |
156 |
def get_current_avg_span(self):
157 |
return self._mask.get_current_avg_size()
158 |
159 |
def clamp_param(self):
160 |
@@ -0,0 +1,106 @@
1 |
# Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 |
# This source code is licensed under the MIT license found in the
4 |
# LICENSE file in the root directory of this source tree.
5 |
6 |
import math
7 |
from dataclasses import dataclass
8 |
9 |
import torch.nn.functional as F
10 |
from fairseq import metrics, utils
11 |
from fairseq.criterions import register_criterion
12 |
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
13 |
from fairseq.dataclass import FairseqDataclass
14 |
from omegaconf import II
15 |
16 |
17 |
18 |
class AdaptiveSpanCriterionConfig(FairseqDataclass):
19 |
sentence_avg: bool = II("optimization.sentence_avg")
20 |
21 |
22 |
@register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig)
23 |
class AdaptiveSpanCriterion(CrossEntropyCriterion):
24 |
def __init__(self, task, sentence_avg):
25 |
super().__init__(task, sentence_avg)
26 |
27 |
def forward(self, model, sample, reduce=True):
28 |
"""Compute the loss for the given sample.
29 |
30 |
Returns a tuple with three elements:
31 |
1) the loss here is summed, different from the adaptive span code
32 |
2) the sample size, which is used as the denominator for the gradient
33 |
3) logging outputs to display while training
34 |
35 |
net_output = model(**sample["net_input"])
36 |
loss, aux_loss, avg_span, max_span = self.compute_loss(
37 |
model, net_output, sample, reduce=reduce
38 |
39 |
sample_size = (
40 |
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
41 |
42 |
loss /= sample_size
43 |
total_loss = loss + aux_loss
44 |
sample_size = 1
45 |
46 |
logging_output = {
47 |
48 |
"ntokens": sample["ntokens"],
49 |
"nsentences": sample["target"].size(0),
50 |
"sample_size": sample_size,
51 |
52 |
"avg_span": avg_span * sample_size,
53 |
"max_span": max_span * sample_size,
54 |
55 |
return total_loss, sample_size, logging_output
56 |
57 |
def compute_loss(self, model, net_output, sample, reduce=True):
58 |
loss, _ = super().compute_loss(model, net_output, sample, reduce)
59 |
aux_loss = model.get_aux_loss()
60 |
avg_span = model.get_current_avg_span()
61 |
max_span = model.get_current_max_span()
62 |
return loss, aux_loss, avg_span, max_span
63 |
64 |
65 |
def reduce_metrics(logging_outputs) -> None:
66 |
"""Aggregate logging outputs from data parallel training."""
67 |
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
68 |
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
69 |
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
70 |
total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)
71 |
avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs)
72 |
max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs)
73 |
74 |
# we divide by log(2) to convert the loss from base e to base 2
75 |
76 |
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
77 |
78 |
metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3)
79 |
metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3)
80 |
# total loss contains the L1 norm on adaptive-span
81 |
82 |
83 |
total_loss_sum / sample_size / math.log(2),
84 |
85 |
86 |
87 |
if sample_size != ntokens:
88 |
89 |
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
90 |
91 |
92 |
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
93 |
94 |
95 |
96 |
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
97 |
98 |
99 |
100 |
def logging_outputs_can_be_summed() -> bool:
101 |
102 |
Whether the logging outputs returned by `forward` can be summed
103 |
across workers prior to calling `reduce_metrics`. Setting this
104 |
to True will improve distributed training speed.
105 |
106 |
return True
@@ -0,0 +1,263 @@
1 |
# Copyright (c) Facebook, Inc. and its affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import math
8 |
9 |
import torch
10 |
import torch.nn as nn
11 |
import torch.nn.functional as F
12 |
13 |
from fairseq.modules.layer_norm import LayerNorm
14 |
15 |
from .adaptive_span_attention import AdaptiveSpan
16 |
17 |
# Size notations:
18 |
# B = batch_size, H = d_model, M = block_size, L = attn_span
19 |
20 |
def _skew(X, pad_value):
    """shift every row 1 step to right"""
    # X = B x M x L
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
    X = X.view(B, -1)  # B x ML+MM+M
    X = X[:, :-M]  # B x ML+MM
    X = X.view(B, M, M + L)  # B x M x L+M
    return X


def _unskew(X):
    """reverse _skew operation"""
    # X = B x M x L+M
    B, M, L = X.size()
    L -= M
    X = X.view(B, -1)  # B x ML+MM
    X = F.pad(X, (0, M))  # B x ML+MM+M
    X = X.view(B, M, M + L + 1)  # B x M x L+M+1
    X = X[:, :, :L]  # B x M x L
    return X
42 |
43 |
44 |
class SeqAttention(nn.Module):
45 |
"""Sequential self-attention layer.
46 |
Each token will attend to its previous fixed number of steps.
47 |
Note that attention doesn't include the current step itself.
48 |
49 |
50 |
def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs):
51 |
52 |
self.dropout = nn.Dropout(dropout)
53 |
self.d_model = d_model # size of a single head
54 |
self.attn_span = attn_span
55 |
self.adaptive_span = AdaptiveSpan(
56 |
57 |
58 |
59 |
60 |
61 |
62 |
def forward(self, query, key, value, key_pe):
63 |
# query size = B x M x H
64 |
# key, value sizes = B x (M+L) x H
65 |
66 |
key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe)
67 |
68 |
# compute attention from context
69 |
# B x M (dest) x (M+L) (src)
70 |
attn_cont = torch.matmul(query, key.transpose(-1, -2))
71 |
attn_cont = _unskew(attn_cont) # B x M x L
72 |
73 |
# compute the effect of position embedding
74 |
attn_pos = torch.matmul(query, key_pe) # B x M x L_pos
75 |
attn = attn_cont + attn_pos
76 |
77 |
attn = attn / math.sqrt(self.d_model) # B x M X L_pos
78 |
79 |
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
80 |
81 |
# trim attention lengths according to the learned span
82 |
attn = self.adaptive_span(attn)
83 |
84 |
attn = self.dropout(attn) # B x M X L_pos
85 |
86 |
attn_cont = _skew(attn, 0) # B x M X (L+M)
87 |
out = torch.matmul(attn_cont, value) # B x M x H
88 |
return out
89 |
90 |
def get_cache_size(self):
91 |
return self.adaptive_span.get_cache_size()
92 |
93 |
94 |
class MultiHeadSeqAttention(nn.Module):
95 |
def __init__(self, d_model, n_head, **kargs):
96 |
97 |
assert d_model % n_head == 0
98 |
self.n_head = n_head
99 |
self.head_dim = d_model // n_head
100 |
self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs)
101 |
self.proj_query = nn.Linear(d_model, d_model, bias=False)
102 |
103 |
self.proj_out = nn.Linear(d_model, d_model, bias=False)
104 |
105 |
self.proj_val = nn.Linear(d_model, d_model, bias=False)
106 |
107 |
self.proj_key = nn.Linear(d_model, d_model, bias=False)
108 |
109 |
110 |
def head_reshape(self, x):
111 |
K = self.n_head
112 |
D = self.head_dim
113 |
x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D
114 |
x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D
115 |
x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D
116 |
return x
117 |
118 |
def forward(self, query, key, value, key_pe):
119 |
B = query.size(0)
120 |
K = self.n_head
121 |
D = self.head_dim
122 |
M = query.size(1)
123 |
124 |
query = self.proj_query(query)
125 |
query = self.head_reshape(query)
126 |
value = self.proj_val(value)
127 |
value = self.head_reshape(value)
128 |
key = self.proj_key(key)
129 |
key = self.head_reshape(key)
130 |
131 |
out = self.attn(query, key, value, key_pe) # B_K x M x D
132 |
out = out.view(B, K, M, D) # B x K x M x D
133 |
out = out.transpose(1, 2).contiguous() # B x M x K x D
134 |
out = out.view(B, M, -1) # B x M x K_D
135 |
out = self.proj_out(out)
136 |
return out
137 |
138 |
139 |
class FeedForwardLayer(nn.Module):
140 |
def __init__(self, d_model, d_inner, dropout, **kargs):
141 |
142 |
self.fc1 = nn.Linear(d_model, d_inner)
143 |
self.fc2 = nn.Linear(d_inner, d_model)
144 |
145 |
146 |
self.dropout = nn.Dropout(dropout)
147 |
148 |
def forward(self, h):
149 |
h1 = F.relu(self.fc1(h))
150 |
h1 = self.dropout(h1)
151 |
h2 = self.fc2(h1)
152 |
return h2
153 |
154 |
155 |
class TransformerSeqLayer(nn.Module):
156 |
def __init__(self, d_model, **kargs):
157 |
158 |
self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs)
159 |
self.norm1 = LayerNorm(d_model)
160 |
self.ff = FeedForwardLayer(d_model=d_model, **kargs)
161 |
self.norm2 = LayerNorm(d_model)
162 |
    def forward(self, h, h_cache, key_pe):
        # h = B x M x H
        # h_cache = B x L x H
        h_all = torch.cat([h_cache, h], dim=1)  # B x (M+L) x H
        attn_out = self.attn(h, h_all, h_all, key_pe)
        h = self.norm1(h + attn_out)  # B x M x H
        if self.ff is not None:
            ff_out = self.ff(h)
            out = self.norm2(h + ff_out)  # B x M x H
        else:
            out = h
        return out
175 |
176 |
def get_cache_size(self):
177 |
return self.attn.attn.get_cache_size()
178 |
179 |
180 |
class TransformerSeq(nn.Module):
181 |
def __init__(
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
# token embeddings
195 |
self.in_emb = nn.Embedding(vocab_size, d_model)
196 |
nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5)
197 |
self.out_emb = nn.Linear(d_model, vocab_size)
198 |
self.aux_loss_scaler = aux_loss_scaler
199 |
if emb_dropout > 0:
200 |
self.emb_dropout = nn.Dropout(emb_dropout)
201 |
202 |
self.emb_dropout = None
203 |
# position embeddings
204 |
self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span))
205 |
206 |
self.layers = nn.ModuleList()
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
for _ in range(n_layer)
216 |
217 |
218 |
def forward(self, x, h_cache, target=None):
219 |
# x size = B x M
220 |
block_size = x.size(1)
221 |
h = self.in_emb(x) # B x M x H
222 |
if self.emb_dropout is not None:
223 |
h = self.emb_dropout(h)
224 |
225 |
h_cache_next = []
226 |
for l, layer in enumerate(self.layers):
227 |
cache_size = layer.attn.attn.get_cache_size()
            if cache_size > block_size:
                # keep the tail of the old cache and append the current block
                h_cache_next_l = torch.cat(
                    [h_cache[l][:, -cache_size + block_size :, :], h], dim=1
                ).detach()
            else:
                h_cache_next_l = h[:, -cache_size:, :].detach()
234 |
235 |
h = layer(h, h_cache[l], self.key_pe) # B x M x H
236 |
237 |
if self.emb_dropout is not None:
238 |
h = self.emb_dropout(h)
239 |
240 |
out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h)
241 |
dummy_loss = None
242 |
243 |
return out, h_cache_next, dummy_loss
244 |
245 |
def get_aux_loss(self):
246 |
loss = 0.0
247 |
for layer in self.layers:
248 |
loss += layer.attn.attn.adaptive_span.get_loss()
249 |
return self.aux_loss_scaler * loss
250 |
251 |
def get_current_max_span(self):
252 |
max_span = 0.0
253 |
for layer in self.layers:
254 |
max_span = max(
255 |
max_span, layer.attn.attn.adaptive_span.get_current_max_span()
256 |
257 |
return max_span
258 |
259 |
def get_current_avg_span(self):
260 |
avg_span = 0.0
261 |
for layer in self.layers:
262 |
avg_span += layer.attn.attn.adaptive_span.get_current_avg_span()
263 |
return avg_span / len(self.layers)
@@ -0,0 +1,145 @@
1 |
# Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 |
# This source code is licensed under the MIT license found in the
4 |
# LICENSE file in the root directory of this source tree.
5 |
6 |
import logging
7 |
from dataclasses import dataclass
8 |
from typing import Dict, List, Optional
9 |
10 |
import torch
11 |
from fairseq.dataclass import FairseqDataclass
12 |
from fairseq.models import (
13 |
14 |
15 |
16 |
17 |
from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel
18 |
19 |
20 |
logger = logging.getLogger(__name__)
21 |
22 |
23 |
24 |
class AdaptiveSpanSmallConfig(FairseqDataclass):
25 |
# defaults come from
26 |
vocab_size: int = 50
27 |
d_model: int = 256
28 |
n_head: int = 4
29 |
d_inner: int = 1024
30 |
n_layer: int = 8
31 |
attn_span: int = 1024
32 |
dropout: float = 0.0
33 |
emb_dropout: float = 0.0
34 |
adapt_span_ramp: int = 32
35 |
adapt_span_init: float = 0.0
36 |
aux_loss_scaler: float = 0.000002
37 |
adapt_span_layer: bool = False
38 |
39 |
40 |
@register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
41 |
class AdaptiveSpanTransformer(FairseqLanguageModel):
42 |
43 |
def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
44 |
return cls(AdaptiveSpanDecoder(cfg, task))
45 |
46 |
def get_aux_loss(self):
47 |
return self.decoder.get_aux_loss()
48 |
49 |
def get_current_max_span(self):
50 |
return self.decoder.get_current_max_span()
51 |
52 |
def get_current_avg_span(self):
53 |
return self.decoder.get_current_avg_span()
54 |
55 |
56 |
class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
57 |
def __init__(self, cfg, task):
58 |
59 |
60 |
61 |
self.config = cfg
62 |
config = AdaptiveSpanSmallConfig(
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
self.model = AdaptiveSpanTransformerModel(**config.__dict__)
78 |
79 |
self._mems = None
80 |
81 |
def forward(
82 |
83 |
84 |
incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
85 |
86 |
        bsz = src_tokens.size(0)
        if incremental_state is not None:  # used during inference
            # get_incremental_state takes the state dict as its first argument
            mems = self.get_incremental_state(incremental_state, "mems")
            src_tokens = src_tokens[:, -1:]  # only keep the most recent token
        else:
            mems = self._mems

        if mems is None:
            # first time init
            mems = self.init_hid_cache(bsz)
        output = self.model(x=src_tokens, h_cache=mems)
        if incremental_state is not None:
            self.set_incremental_state(incremental_state, "mems", output[1])
        else:
            self._mems = output[1]
        return (output[0],)
103 |
104 |
def max_positions(self):
105 |
return self.config.attn_span
106 |
107 |
def init_hid_cache(self, batch_sz):
108 |
hid = []
109 |
for layer in self.model.layers:
110 |
param = next(self.model.parameters())
111 |
h = torch.zeros(
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
return hid
120 |
121 |
def get_aux_loss(self):
122 |
return self.model.get_aux_loss()
123 |
124 |
def get_current_max_span(self):
125 |
return self.model.get_current_max_span()
126 |
127 |
def get_current_avg_span(self):
128 |
return self.model.get_current_avg_span()
129 |
130 |
def reorder_incremental_state(
131 |
132 |
incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
133 |
new_order: torch.Tensor,
134 |
135 |
"""Reorder incremental state.
136 |
137 |
This will be called when the order of the input has changed from the
138 |
previous time step. A typical use case is beam search, where the input
139 |
order changes between time steps based on the selection of beams.
140 |
141 |
raise NotImplementedError("This is required for generation/beam search")
142 |
# mems = self.get_incremental_state(incremental_state, "mems")
143 |
# if mems is not None:
144 |
# new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
145 |
# self.set_incremental_state(incremental_state, "mems", new_mems)
@@ -0,0 +1 @@
1 |
@@ -0,0 +1,297 @@
1 |
# Understanding Back-Translation at Scale (Edunov et al., 2018)
2 |
3 |
This page includes pre-trained models from the paper Understanding Back-Translation at Scale (Edunov et al., 2018).
4 |
5 |
## Pre-trained models
6 |
Model | Description | Dataset | Download
---|---|---|---
`transformer.wmt18.en-de` | Transformer <br> (Edunov et al., 2018) <br> WMT'18 winner | WMT'18 English-German | download (.tar.gz) <br> See NOTE in the archive
10 |
11 |
## Example usage (torch.hub)
12 |
13 |
We require a few additional Python dependencies for preprocessing:
14 |
15 |
pip install subword_nmt sacremoses
16 |
17 |
18 |
Then to generate translations from the full model ensemble:
19 |
20 |
import torch
21 |
22 |
# List available models
23 |
torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt18.en-de', ... ]
24 |
25 |
# Load the WMT'18 En-De ensemble
26 |
en2de_ensemble = torch.hub.load(
27 |
'pytorch/fairseq', 'transformer.wmt18.en-de',
28 |
29 |
tokenizer='moses', bpe='subword_nmt')
30 |
31 |
# The ensemble contains 5 models
32 |
33 |
# 5
34 |
35 |
# Translate
36 |
en2de_ensemble.translate('Hello world!')
37 |
# 'Hallo Welt!'
38 |
39 |
40 |
## Training your own model (WMT'18 English-German)
41 |
42 |
The following instructions can be adapted to reproduce the models from the paper.
43 |
44 |
45 |
#### Step 1. Prepare parallel data and optionally train a baseline (English-German) model
46 |
47 |
First download and preprocess the data:
48 |
49 |
# Download and prepare the data
50 |
cd examples/backtranslation/
51 |
52 |
cd ../..
53 |
54 |
# Binarize the data
55 |
56 |
fairseq-preprocess \
57 |
--joined-dictionary \
58 |
--source-lang en --target-lang de \
59 |
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
60 |
--destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \
61 |
--workers 20
62 |
63 |
# Copy the BPE code into the data-bin directory for future use
64 |
cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code
65 |
66 |
67 |
(Optionally) Train a baseline model (English-German) using just the parallel data:
68 |
69 |
70 |
fairseq-train --fp16 \
71 |
data-bin/wmt18_en_de \
72 |
--source-lang en --target-lang de \
73 |
--arch transformer_wmt_en_de_big --share-all-embeddings \
74 |
--dropout 0.3 --weight-decay 0.0 \
75 |
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
76 |
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
77 |
--lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
78 |
--max-tokens 3584 --update-freq 16 \
79 |
--max-update 30000 \
80 |
--save-dir $CHECKPOINT_DIR
81 |
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
82 |
# different number of GPUs.
83 |
84 |
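The note about `--update-freq` can be made concrete with a small calculation. The sketch below assumes the goal is to keep the number of tokens per optimizer step equal to the 8-GPU reference setup (`--max-tokens 3584 --update-freq 16`); that goal, and the helper names `update_freq_for` and `tokens_per_update`, are illustrative assumptions rather than part of fairseq.

```python
# Reference setup taken from the training command above.
REFERENCE_GPUS = 8
REFERENCE_UPDATE_FREQ = 16
MAX_TOKENS = 3584


def update_freq_for(num_gpus: int) -> int:
    """Return an --update-freq that keeps tokens-per-update at the reference value.

    Hypothetical helper: assumes the effective batch should stay constant.
    """
    total_steps = REFERENCE_GPUS * REFERENCE_UPDATE_FREQ
    if num_gpus <= 0 or total_steps % num_gpus != 0:
        raise ValueError(
            f"cannot split {total_steps} gradient steps evenly over {num_gpus} GPUs"
        )
    return total_steps // num_gpus


def tokens_per_update(num_gpus: int) -> int:
    """Upper bound on the tokens that contribute to one optimizer step."""
    return MAX_TOKENS * update_freq_for(num_gpus) * num_gpus


if __name__ == "__main__":
    for n in (1, 2, 4, 8, 16):
        print(n, update_freq_for(n), tokens_per_update(n))
```

Under this assumption, halving the number of GPUs doubles `--update-freq`, and the token budget per update stays the same.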
85 |
Average the last 10 checkpoints:
86 |
87 |
python scripts/ \
88 |
--inputs $CHECKPOINT_DIR \
89 |
--num-epoch-checkpoints 10 \
90 |
91 |
92 |
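Conceptually, averaging checkpoints means taking the element-wise mean of each parameter across the saved models. The sketch below illustrates that idea on plain Python lists; it is not the averaging script's implementation, and `StateDict` and `average_checkpoints` are hypothetical names used only for illustration.

```python
from typing import Dict, List

# Hypothetical stand-in for a real checkpoint: parameter name -> flat list of values.
StateDict = Dict[str, List[float]]


def average_checkpoints(states: List[StateDict]) -> StateDict:
    """Element-wise mean of every parameter across the given checkpoints."""
    if not states:
        raise ValueError("need at least one checkpoint")
    keys = states[0].keys()
    for state in states:
        if state.keys() != keys:
            raise KeyError("checkpoints do not share the same parameter names")
    n = len(states)
    averaged: StateDict = {}
    for key in keys:
        lengths = {len(state[key]) for state in states}
        if len(lengths) != 1:
            raise ValueError(f"parameter {key!r} has inconsistent sizes")
        # zip groups the i-th value of this parameter from every checkpoint
        averaged[key] = [sum(column) / n for column in zip(*(s[key] for s in states))]
    return averaged


if __name__ == "__main__":
    print(average_checkpoints([{"w": [1.0, 3.0]}, {"w": [3.0, 5.0]}]))
```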
93 |
Evaluate BLEU:
94 |
95 |
# tokenized BLEU on newstest2017:
96 |
bash examples/backtranslation/ \
97 |
wmt17 \
98 |
en-de \
99 |
data-bin/wmt18_en_de \
100 |
data-bin/wmt18_en_de/code \
101 |
102 |
# BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152)
103 |
# compare to 29.46 in Table 1, which is also for tokenized BLEU
104 |
105 |
# generally it's better to report (detokenized) sacrebleu though:
106 |
bash examples/backtranslation/ \
107 |
wmt17 \
108 |
en-de \
109 |
data-bin/wmt18_en_de \
110 |
data-bin/wmt18_en_de/code \
111 |
112 |
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287)
113 |
114 |
115 |
116 |
#### Step 2. Back-translate monolingual German data
117 |
118 |
Train a reverse model (German-English) to do the back-translation:
119 |
120 |
121 |
fairseq-train --fp16 \
122 |
data-bin/wmt18_en_de \
123 |
--source-lang de --target-lang en \
124 |
--arch transformer_wmt_en_de_big --share-all-embeddings \
125 |
--dropout 0.3 --weight-decay 0.0 \
126 |
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
127 |
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
128 |
--lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
129 |
--max-tokens 3584 --update-freq 16 \
130 |
--max-update 30000 \
131 |
--save-dir $CHECKPOINT_DIR
132 |
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
133 |
# different number of GPUs.
134 |
135 |
136 |
Let's evaluate the back-translation (BT) model to make sure it is well trained:
137 |
138 |
bash examples/backtranslation/ \
139 |
wmt17 \
140 |
de-en \
141 |
data-bin/wmt18_en_de \
142 |
data-bin/wmt18_en_de/code \
143 |
144 |
# = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399)
145 |
# compare to the best system from WMT'17 which scored 35.1:
146 |
147 |
148 |
Next prepare the monolingual data:
149 |
150 |
# Download and prepare the monolingual data
151 |
# By default the script samples 25M monolingual sentences, which after
152 |
# deduplication should be just over 24M sentences. These are split into 25
153 |
# shards, each with 1M sentences (except for the last shard).
154 |
cd examples/backtranslation/
155 |
156 |
cd ../..
157 |
158 |
# Binarize each shard of the monolingual data
159 |
160 |
for SHARD in $(seq -f "%02g" 0 24); do \
161 |
fairseq-preprocess \
162 |
--only-source \
163 |
--source-lang de --target-lang en \
164 |
--joined-dictionary \
165 |
--srcdict data-bin/wmt18_en_de/ \
166 |
--testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
167 |
--destdir data-bin/wmt18_de_mono/shard${SHARD} \
168 |
--workers 20; \
169 |
cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
170 |
171 |
172 |
173 |
Now we're ready to perform back-translation over the monolingual data. The
174 |
following command generates via sampling, but it's possible to use greedy
175 |
decoding (`--beam 1`), beam search (`--beam 5`),
176 |
top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.:
177 |
178 |
mkdir backtranslation_output
179 |
for SHARD in $(seq -f "%02g" 0 24); do \
180 |
fairseq-generate --fp16 \
181 |
data-bin/wmt18_de_mono/shard${SHARD} \
182 |
183 |
--skip-invalid-size-inputs-valid-test \
184 |
--max-tokens 4096 \
185 |
--sampling --beam 1 \
186 |
> backtranslation_output/sampling.shard${SHARD}.out; \
187 |
188 |
189 |
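The decoding options listed above differ only in a few flags. As a rough sketch, they can be kept in one table and spliced into the `fairseq-generate` call; the strategy names, the `generate_command` helper and the `model_path` argument are illustrative assumptions, while the flags themselves are the ones quoted in this section plus `--path` for the checkpoint.

```python
from typing import List

# Flags per decoding strategy, copied from the options mentioned in the text.
DECODING_FLAGS = {
    "sampling": ["--sampling", "--beam", "1"],
    "greedy": ["--beam", "1"],
    "beam": ["--beam", "5"],
    "topk": ["--sampling", "--beam", "1", "--sampling-topk", "10"],
}


def generate_command(shard: str, model_path: str, strategy: str) -> List[str]:
    """Build the argument list for back-translating one monolingual shard."""
    if strategy not in DECODING_FLAGS:
        raise ValueError(f"unknown strategy {strategy!r}")
    return [
        "fairseq-generate",
        "--fp16",
        f"data-bin/wmt18_de_mono/shard{shard}",
        "--path",
        model_path,  # hypothetical: path to the trained German-English checkpoint
        "--skip-invalid-size-inputs-valid-test",
        "--max-tokens",
        "4096",
        *DECODING_FLAGS[strategy],
    ]


if __name__ == "__main__":
    print(" ".join(generate_command("00", "checkpoint.pt", "topk")))
```

The returned list can be passed to `subprocess.run` with stdout redirected to a per-shard output file, mirroring the shell loop.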
190 |
After BT, use the `` script to re-combine the shards, extract
191 |
the back-translations and apply length ratio filters:
192 |
193 |
python examples/backtranslation/ \
194 |
--minlen 1 --maxlen 250 --ratio 1.5 \
195 |
--output backtranslation_output/bt_data --srclang en --tgtlang de \
196 |
197 |
198 |
# Ensure lengths are the same:
199 |
# wc -l backtranslation_output/bt_data.{en,de}
200 |
# 21795614 backtranslation_output/bt_data.en
201 |
# 21795614 backtranslation_output/
202 |
# 43591228 total
203 |
204 |
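The length and ratio filters applied here (`--minlen 1 --maxlen 250 --ratio 1.5`) amount to a simple predicate on each sentence pair. The sketch below restates that logic with a hypothetical helper, `keep_pair`; it splits on any whitespace, which is a simplifying assumption and may differ slightly from the extraction script.

```python
def keep_pair(
    src: str, tgt: str, minlen: int = 1, maxlen: int = 250, ratio: float = 1.5
) -> bool:
    """Return True if the pair passes the min/max length and length-ratio filters."""
    srclen = len(src.split())
    tgtlen = len(tgt.split())
    shorter, longer = sorted((srclen, tgtlen))
    if shorter < minlen or longer > maxlen:
        return False
    if shorter == 0:
        # Only reachable when minlen is 0; an empty side has no defined ratio.
        return False
    # Pairs are dropped only when the ratio strictly exceeds the threshold.
    return longer / shorter <= ratio


if __name__ == "__main__":
    print(keep_pair("a b c", "x y"))    # ratio exactly 1.5
    print(keep_pair("a b c d", "x y"))  # ratio 2.0
```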
205 |
Binarize the filtered BT data and combine it with the parallel data:
206 |
207 |
208 |
fairseq-preprocess \
209 |
--source-lang en --target-lang de \
210 |
--joined-dictionary \
211 |
--srcdict data-bin/wmt18_en_de/dict.en.txt \
212 |
--trainpref $TEXT/bt_data \
213 |
--destdir data-bin/wmt18_en_de_bt \
214 |
--workers 20
215 |
216 |
# We want to train on the combined data, so we'll symlink the parallel + BT data
217 |
# in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train"
218 |
# and the BT data as "train1", so that fairseq will combine them automatically
219 |
# and so that we can use the `--upsample-primary` option to upsample the
220 |
# parallel data (if desired).
221 |
PARA_DATA=$(readlink -f data-bin/wmt18_en_de)
222 |
BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt)
223 |
224 |
mkdir -p $COMB_DATA
225 |
for LANG in en de; do \
226 |
ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \
227 |
for EXT in bin idx; do \
228 |
ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \
229 |
ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \
230 |
ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \
231 |
ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \
232 |
done; \
233 |
234 |
235 |
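For readers who prefer to script this step in Python, the symlink layout can be sketched as follows. The function names and arguments are hypothetical; the file naming (`train`, `train1`, `valid`, `test`, `dict.<lang>.txt`) follows the shell loop above.

```python
import os
from pathlib import Path


def _link(target: Path, link: Path) -> None:
    """Create a symlink unless one is already present at that location."""
    if not link.is_symlink():
        os.symlink(target.resolve(), link)


def link_combined(para_dir: Path, bt_dir: Path, comb_dir: Path, pair: str = "en-de") -> None:
    """Expose parallel data as `train` and back-translated data as `train1`."""
    comb_dir.mkdir(parents=True, exist_ok=True)
    for lang in pair.split("-"):
        _link(para_dir / f"dict.{lang}.txt", comb_dir / f"dict.{lang}.txt")
        for ext in ("bin", "idx"):
            name = f"{pair}.{lang}.{ext}"
            _link(para_dir / f"train.{name}", comb_dir / f"train.{name}")
            _link(bt_dir / f"train.{name}", comb_dir / f"train1.{name}")
            _link(para_dir / f"valid.{name}", comb_dir / f"valid.{name}")
            _link(para_dir / f"test.{name}", comb_dir / f"test.{name}")
```

Because the parallel data is linked as the primary `train` split, `--upsample-primary` applies to it and not to the back-translated `train1` split.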
236 |
237 |
#### Step 3. Train an English-German model over the combined parallel + BT data
238 |
239 |
Finally we can train a model over the parallel + BT data:
240 |
241 |
242 |
fairseq-train --fp16 \
243 |
data-bin/wmt18_en_de_para_plus_bt \
244 |
--upsample-primary 16 \
245 |
--source-lang en --target-lang de \
246 |
--arch transformer_wmt_en_de_big --share-all-embeddings \
247 |
--dropout 0.3 --weight-decay 0.0 \
248 |
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
249 |
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
250 |
--lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
251 |
--max-tokens 3584 --update-freq 16 \
252 |
--max-update 100000 \
253 |
--save-dir $CHECKPOINT_DIR
254 |
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
255 |
# different number of GPUs.
256 |
257 |
258 |
Average the last 10 checkpoints:
259 |
260 |
python scripts/ \
261 |
--inputs $CHECKPOINT_DIR \
262 |
--num-epoch-checkpoints 10 \
263 |
264 |
265 |
266 |
Evaluate BLEU:
267 |
268 |
# tokenized BLEU on newstest2017:
269 |
bash examples/backtranslation/ \
270 |
wmt17 \
271 |
en-de \
272 |
data-bin/wmt18_en_de \
273 |
data-bin/wmt18_en_de/code \
274 |
275 |
# BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152)
276 |
# compare to 32.35 in Table 1, which is also for tokenized BLEU
277 |
278 |
# generally it's better to report (detokenized) sacrebleu:
279 |
bash examples/backtranslation/ \
280 |
wmt17 \
281 |
en-de \
282 |
data-bin/wmt18_en_de \
283 |
data-bin/wmt18_en_de/code \
284 |
285 |
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287)
286 |
287 |
288 |
289 |
## Citation
290 |
291 |
292 |
title = {Understanding Back-Translation at Scale},
293 |
author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
294 |
booktitle = {Conference of the Association for Computational Linguistics (ACL)},
295 |
year = 2018,
296 |
297 |
@@ -0,0 +1,41 @@
1 |
2 |
# Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 |
# This source code is licensed under the MIT license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import argparse
8 |
import fileinput
9 |
import hashlib
10 |
import sys
11 |
from multiprocessing import Pool
12 |
13 |
14 |
def get_hashes_and_lines(raw_line):
15 |
hash = hashlib.md5(raw_line).hexdigest()
16 |
return hash, raw_line
17 |
18 |
19 |
def main():
20 |
parser = argparse.ArgumentParser()
21 |
parser.add_argument("--workers", type=int, default=10)
22 |
parser.add_argument("files", nargs="*", help="input files")
23 |
args = parser.parse_args()
24 |
25 |
seen = set()
26 |
with fileinput.input(args.files, mode="rb") as h:
27 |
pool = Pool(args.workers)
28 |
results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
29 |
for i, (hash, raw_line) in enumerate(results):
30 |
if hash not in seen:
31 |
32 |
33 |
if i % 1000000 == 0:
34 |
print(i, file=sys.stderr, end="", flush=True)
35 |
elif i % 100000 == 0:
36 |
print(".", file=sys.stderr, end="", flush=True)
37 |
print(file=sys.stderr, flush=True)
38 |
39 |
40 |
if __name__ == "__main__":
41 |
@@ -0,0 +1,72 @@
1 |
#!/usr/bin/env python
2 |
# Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 |
# This source code is licensed under the MIT license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import argparse
8 |
import fileinput
9 |
10 |
from tqdm import tqdm
11 |
12 |
13 |
def main():
14 |
parser = argparse.ArgumentParser(
15 |
16 |
"Extract back-translations from the stdout of fairseq-generate. "
17 |
"If there are multiple hypotheses for a source, we only keep the first one. "
18 |
19 |
20 |
parser.add_argument("--output", required=True, help="output prefix")
21 |
22 |
"--srclang", required=True, help="source language (extracted from H-* lines)"
23 |
24 |
25 |
"--tgtlang", required=True, help="target language (extracted from S-* lines)"
26 |
27 |
parser.add_argument("--minlen", type=int, help="min length filter")
28 |
parser.add_argument("--maxlen", type=int, help="max length filter")
29 |
parser.add_argument("--ratio", type=float, help="ratio filter")
30 |
parser.add_argument("files", nargs="*", help="input files")
31 |
args = parser.parse_args()
32 |
33 |
def validate(src, tgt):
34 |
srclen = len(src.split(" ")) if src != "" else 0
35 |
tgtlen = len(tgt.split(" ")) if tgt != "" else 0
36 |
if (
37 |
(args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
38 |
or (
39 |
args.maxlen is not None
40 |
and (srclen > args.maxlen or tgtlen > args.maxlen)
41 |
42 |
or (
43 |
args.ratio is not None
44 |
and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
45 |
46 |
47 |
return False
48 |
return True
49 |
50 |
def safe_index(toks, index, default):
51 |
52 |
return toks[index]
53 |
except IndexError:
54 |
return default
55 |
56 |
with open(args.output + "." + args.srclang, "w") as src_h, open(
57 |
args.output + "." + args.tgtlang, "w"
58 |
) as tgt_h:
59 |
for line in tqdm(fileinput.input(args.files)):
60 |
if line.startswith("S-"):
61 |
tgt = safe_index(line.rstrip().split("\t"), 1, "")
62 |
elif line.startswith("H-"):
63 |
if tgt is not None:
64 |
src = safe_index(line.rstrip().split("\t"), 2, "")
65 |
if validate(src, tgt):
66 |
print(src, file=src_h)
67 |
print(tgt, file=tgt_h)
68 |
tgt = None
69 |
70 |
71 |
if __name__ == "__main__":
72 |
@@ -0,0 +1,98 @@
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
mkdir -p $OUTDIR $tmp
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
cd $orig
50 |
for ((i=0;i<${#URLS[@]};++i)); do
51 |
52 |
if [ -f $file ]; then
53 |
echo "$file already exists, skipping download"
54 |
55 |
56 |
wget "$url"
57 |
58 |
59 |
cd ..
60 |
61 |
62 |
if [ -f $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
63 |
echo "found monolingual sample, skipping shuffle/sample/tokenize"
64 |
65 |
gzip -c -d -k $(for FILE in "${FILES[@]}"; do echo $orig/$FILE; done) \
66 |
| shuf -n $SUBSAMPLE_SIZE \
67 |
| perl $NORM_PUNC $LANG \
68 |
69 |
| perl $TOKENIZER -threads 8 -a -l $LANG \
70 |
> $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG}
71 |
72 |
73 |
74 |
if [ -f $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
75 |
echo "found BPE monolingual sample, skipping BPE step"
76 |
77 |
python $BPEROOT/ -c $BPE_CODE \
78 |
< $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} \
79 |
> $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG}
80 |
81 |
82 |
83 |
if [ -f $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} ]; then
84 |
echo "found deduplicated monolingual sample, skipping deduplication step"
85 |
86 |
python $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} \
87 |
> $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG}
88 |
89 |
90 |
91 |
if [ -f $OUTDIR/ ]; then
92 |
echo "found sharded data, skipping sharding step"
93 |
94 |
split --lines 1000000 --numeric-suffixes \
95 |
--additional-suffix .${LANG} \
96 |
$tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} \
97 |
98 |
@@ -0,0 +1,135 @@
1 |
2 |
# Adapted from
3 |
4 |
echo 'Cloning Moses github repository (for tokenization scripts)...'
5 |
git clone
6 |
7 |
echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
8 |
git clone
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
if [ ! -d "$SCRIPTS" ]; then
42 |
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
43 |
exit 1
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
mkdir -p $orig $tmp $prep
56 |
57 |
cd $orig
58 |
59 |
for ((i=0;i<${#URLS[@]};++i)); do
60 |
61 |
if [ -f $file ]; then
62 |
echo "$file already exists, skipping download"
63 |
64 |
65 |
wget "$url"
66 |
if [ -f $file ]; then
67 |
echo "$url successfully downloaded."
68 |
69 |
echo "$url not successfully downloaded."
70 |
exit 1
71 |
72 |
if [ ${file: -4} == ".tgz" ]; then
73 |
tar zxvf $file
74 |
elif [ ${file: -4} == ".tar" ]; then
75 |
tar xvf $file
76 |
77 |
78 |
79 |
cd ..
80 |
81 |
echo "pre-processing train data..."
82 |
for l in $src $tgt; do
83 |
rm $tmp/train.tags.$lang.tok.$l
84 |
for f in "${CORPORA[@]}"; do
85 |
cat $orig/$f.$l | \
86 |
perl $NORM_PUNC $l | \
87 |
88 |
perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
89 |
90 |
91 |
92 |
echo "pre-processing test data..."
93 |
for l in $src $tgt; do
94 |
if [ "$l" == "$src" ]; then
95 |
96 |
97 |
98 |
99 |
grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
100 |
sed -e 's/<seg id="[0-9]*">\s*//g' | \
101 |
sed -e 's/\s*<\/seg>\s*//g' | \
102 |
sed -e "s/\’/\'/g" | \
103 |
perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
104 |
echo ""
105 |
106 |
107 |
echo "splitting train and valid..."
108 |
for l in $src $tgt; do
109 |
awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
110 |
awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
111 |
112 |
113 |
114 |
115 |
rm -f $TRAIN
116 |
for l in $src $tgt; do
117 |
cat $tmp/train.$l >> $TRAIN
118 |
119 |
120 |
echo " on ${TRAIN}..."
121 |
122 |
123 |
for L in $src $tgt; do
124 |
for f in train.$L valid.$L test.$L; do
125 |
echo " to ${f}..."
126 |
python $BPEROOT/ -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
127 |
128 |
129 |
130 |
perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
131 |
perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
132 |
133 |
for L in $src $tgt; do
134 |
cp $tmp/bpe.test.$L $prep/test.$L
135 |
@@ -0,0 +1,37 @@
1 |
2 |
3 |
if [ $# -ne 5 ]; then
4 |
echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
16 |
TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
17 |
18 |
19 |
20 |
if [ ! -e $BPEROOT ]; then
21 |
22 |
if [ ! -e $BPEROOT ]; then
23 |
echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
24 |
git clone
25 |
26 |
27 |
28 |
29 |
sacrebleu -t $DATASET -l $LANGPAIR --echo src \
30 |
| sacremoses tokenize -a -l $SRCLANG -q \
31 |
| python $BPEROOT/ -c $BPECODE \
32 |
| fairseq-interactive $DATABIN --path $MODEL \
33 |
34 |
--beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
35 |
| grep ^H- | cut -f 3- \
36 |
| sacremoses detokenize -l $TGTLANG -q \
37 |
| sacrebleu -t $DATASET -l $LANGPAIR
@@ -0,0 +1,46 @@
1 |
2 |
3 |
if [ $# -ne 5 ]; then
4 |
echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
16 |
TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
17 |
18 |
19 |
20 |
if [ ! -e $BPEROOT ]; then
21 |
22 |
if [ ! -e $BPEROOT ]; then
23 |
echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
24 |
git clone
25 |
26 |
27 |
28 |
29 |
30 |
31 |
sacrebleu -t $DATASET -l $LANGPAIR --echo ref -q \
32 |
| sacremoses normalize -l $TGTLANG -q \
33 |
| sacremoses tokenize -a -l $TGTLANG -q \
34 |
35 |
36 |
sacrebleu -t $DATASET -l $LANGPAIR --echo src -q \
37 |
| sacremoses normalize -l $SRCLANG -q \
38 |
| sacremoses tokenize -a -l $SRCLANG -q \
39 |
| python $BPEROOT/ -c $BPECODE \
40 |
| fairseq-interactive $DATABIN --path $MODEL \
41 |
42 |
--beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
43 |
| grep ^H- | cut -f 3- \
44 |
| fairseq-score --ref $TMP_REF
45 |
46 |
rm -f $TMP_REF
@@ -0,0 +1,99 @@
1 |
# Fine-tuning BART on GLUE tasks
2 |
3 |
### 1) Download the data from the GLUE website using the following commands:
4 |
5 |
6 |
python --data_dir glue_data --tasks all
7 |
8 |
9 |
### 2) Preprocess GLUE task data (same as RoBERTa):
10 |
11 |
./examples/roberta/ glue_data <glue_task_name>
12 |
13 |
`glue_task_name` is one of the following:
14 |
15 |
Use `ALL` for preprocessing all the glue tasks.
16 |
17 |
### 3) Fine-tuning on GLUE task:
18 |
Example fine-tuning command for the `RTE` task:
19 |
20 |
TOTAL_NUM_UPDATES=2036 # 10 epochs through RTE for bsz 16
21 |
WARMUP_UPDATES=61 # 6 percent of the number of updates
22 |
LR=1e-05 # Peak LR for polynomial LR scheduler.
23 |
24 |
MAX_SENTENCES=16 # Batch size.
25 |
26 |
27 |
CUDA_VISIBLE_DEVICES=0,1 fairseq-train RTE-bin/ \
28 |
--restore-file $BART_PATH \
29 |
--batch-size $MAX_SENTENCES \
30 |
--max-tokens 4400 \
31 |
--task sentence_prediction \
32 |
--add-prev-output-tokens \
33 |
--layernorm-embedding \
34 |
--share-all-embeddings \
35 |
--share-decoder-input-output-embed \
36 |
--reset-optimizer --reset-dataloader --reset-meters \
37 |
--required-batch-size-multiple 1 \
38 |
--init-token 0 \
39 |
--arch bart_large \
40 |
--criterion sentence_prediction \
41 |
--num-classes $NUM_CLASSES \
42 |
--dropout 0.1 --attention-dropout 0.1 \
43 |
--weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
44 |
--clip-norm 0.0 \
45 |
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
46 |
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
47 |
--max-epoch 10 \
48 |
--find-unused-parameters \
49 |
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
50 |
51 |
52 |
For each GLUE task, you will need to use the following command-line arguments:
53 |
54 |
Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
55 |
56 |
`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
57 |
`--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
58 |
`bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
59 |
`--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
60 |
`--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107
61 |
62 |
For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.
63 |
64 |
65 |
66 |
a) `--total-num-update` is used by the `polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--batch-size=32/64/128`, depending on the task.
67 |
68 |
b) The above command-line arguments and hyperparameters were tested on an Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--batch-size`.
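
As a rough illustration of notes (a) and (b), the sketch below computes an effective batch size and a 6 percent warmup count. It assumes that `--batch-size` applies per GPU, so the effective batch size is the product of batch size, GPU count, and `--update-freq`. The helper names are hypothetical and not part of fairseq.

```python
def effective_batch_size(batch_size: int, num_gpus: int = 1, update_freq: int = 1) -> int:
    """Sentences contributing to one optimizer step (assumes a per-GPU batch size)."""
    return batch_size * num_gpus * update_freq


def warmup_updates(total_num_updates: int, percent: int = 6) -> int:
    """Warmup steps as an integer percentage of total updates, rounded down."""
    return total_num_updates * percent // 100


# Halving --batch-size while doubling --update-freq keeps the effective size.
print(effective_batch_size(32, num_gpus=1, update_freq=1))  # 32
print(effective_batch_size(16, num_gpus=1, update_freq=2))  # 32
print(warmup_updates(1018))  # 61, the RTE entry in the table above
```

Rounding down reproduces most, but not all, of the warmup values in the table (SST-2 lists 314, where rounding down gives 313), so the table values should take precedence.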
69 |
70 |
### Inference on GLUE task
71 |
After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using the following Python code snippet:
72 |
73 |
74 |
from fairseq.models.bart import BARTModel
75 |
76 |
bart = BARTModel.from_pretrained(
77 |
78 |
79 |
80 |
81 |
82 |
label_fn = lambda label: bart.task.label_dictionary.string(
83 |
[label + bart.task.label_dictionary.nspecial]
84 |
85 |
ncorrect, nsamples = 0, 0
86 |
87 |
88 |
with open('glue_data/RTE/dev.tsv') as fin:
89 |
90 |
for index, line in enumerate(fin):
91 |
tokens = line.strip().split('\t')
92 |
sent1, sent2, target = tokens[1], tokens[2], tokens[3]
93 |
tokens = bart.encode(sent1, sent2)
94 |
prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
95 |
prediction_label = label_fn(prediction)
96 |
ncorrect += int(prediction_label == target)
97 |
nsamples += 1
98 |
print('| Accuracy: ', float(ncorrect)/float(nsamples))
99 |
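
The accuracy bookkeeping in the loop above can be factored out so that it is testable without loading a model. The following is a minimal sketch: `evaluate_accuracy` and the `predict_label` callable are hypothetical names. In practice `predict_label` would wrap `bart.encode`, `bart.predict`, and `label_fn` as in the snippet above. The default column indices are the ones used for RTE.

```python
from typing import Callable, Iterable


def evaluate_accuracy(
    lines: Iterable[str],
    predict_label: Callable[[str, str], str],
    sent1_col: int = 1,
    sent2_col: int = 2,
    target_col: int = 3,
    skip_header: bool = True,
) -> float:
    """Return the fraction of tab-separated rows whose predicted label matches the target."""
    ncorrect, nsamples = 0, 0
    for index, line in enumerate(lines):
        if skip_header and index == 0:
            continue  # assumes the first row of the .tsv file is a header
        fields = line.rstrip("\n").split("\t")
        prediction = predict_label(fields[sent1_col], fields[sent2_col])
        ncorrect += int(prediction == fields[target_col])
        nsamples += 1
    return ncorrect / nsamples if nsamples else 0.0


# Toy rows and a stub predictor stand in for the dev set and the model.
rows = [
    "index\tsentence1\tsentence2\tlabel",
    "0\tA cat sleeps.\tAn animal sleeps.\tentailment",
    "1\tA cat sleeps.\tA dog barks.\tnot_entailment",
]
print(evaluate_accuracy(rows, lambda s1, s2: "entailment"))  # 0.5
```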
@@ -0,0 +1,228 @@
1 |
# BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
2 |
3 |
4 |
5 |
## Introduction
6 |
7 |
BART is a sequence-to-sequence model trained with denoising as its pretraining objective. We show that this pretraining objective is more generic, and that we can match [RoBERTa](../roberta) results on SQuAD and GLUE and achieve state-of-the-art results on summarization (XSum, CNN dataset), long-form generative question answering (ELI5) and dialog response generation (ConvAI2). See the associated paper for more details.
8 |
9 |
## Pre-trained models
10 |
11 |
Model | Description | # params | Download
12 |
13 |
`bart.base` | BART model with 6 encoder and decoder layers | 140M | [bart.base.tar.gz](
14 |
`bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](
15 |
`bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](
16 |
`bart.large.cnn` | `bart.large` finetuned on `CNN-DM` | 400M | [bart.large.cnn.tar.gz](
17 |
`bart.large.xsum` | `bart.large` finetuned on `Xsum` | 400M | [bart.large.xsum.tar.gz](
18 |
19 |
## Results
20 |
21 |
**[GLUE (Wang et al., 2019)](**
22 |
_(dev set, single model, single-task finetuning)_
23 |
24 |
Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
25 |
26 |
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
27 |
`bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2
28 |
29 |
**[SQuAD (Rajpurkar et al., 2018)](**
30 |
_(dev set, no additional data used)_
31 |
32 |
Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
33 |
34 |
`roberta.large` | 88.9/94.6 | 86.5/89.4
35 |
`bart.large` | 88.8/94.6 | 86.1/89.2
36 |
37 |
**[CNN/Daily Mail](**
38 |
_(test set, no additional data used)_
39 |
40 |
Model | R1 | R2 | RL
41 |
42 |
`BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
43 |
`bart.large` | 44.16 | 21.28 | 40.90
44 |
45 |
## Example usage
46 |
47 |
##### Load BART from torch.hub (PyTorch >= 1.1):
48 |
49 |
import torch
50 |
bart = torch.hub.load('pytorch/fairseq', 'bart.large')
51 |
bart.eval() # disable dropout (or leave in train mode to finetune)
52 |
53 |
54 |
##### Load BART (for PyTorch 1.0 or custom models):
55 |
56 |
# Download bart.large model
57 |
58 |
tar -xzvf bart.large.tar.gz
59 |
60 |
# Load the model in fairseq
61 |
from fairseq.models.bart import BARTModel
62 |
bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='')
63 |
bart.eval() # disable dropout (or leave in train mode to finetune)
64 |
65 |
66 |
##### Apply Byte-Pair Encoding (BPE) to input text:
67 |
68 |
tokens = bart.encode('Hello world!')
69 |
assert tokens.tolist() == [0, 31414, 232, 328, 2]
70 |
bart.decode(tokens) # 'Hello world!'
71 |
72 |
73 |
##### Extract features from BART:
74 |
75 |
# Extract the last layer's features
76 |
last_layer_features = bart.extract_features(tokens)
77 |
assert last_layer_features.size() == torch.Size([1, 5, 1024])
78 |
79 |
# Extract all layers' features from decoder (layer 0 is the embedding layer)
80 |
all_layers = bart.extract_features(tokens, return_all_hiddens=True)
81 |
assert len(all_layers) == 13
82 |
assert torch.all(all_layers[-1] == last_layer_features)
83 |
84 |
85 |
##### Use BART for sentence-pair classification tasks:
86 |
87 |
# Download BART already finetuned for MNLI
88 |
bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
89 |
bart.eval() # disable dropout for evaluation
90 |
91 |
# Encode a pair of sentences and make a prediction
92 |
tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
93 |
bart.predict('mnli', tokens).argmax() # 0: contradiction
94 |
95 |
# Encode another pair of sentences
96 |
tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
97 |
bart.predict('mnli', tokens).argmax() # 2: entailment
98 |
99 |
100 |
##### Register a new (randomly initialized) classification head:
101 |
102 |
bart.register_classification_head('new_task', num_classes=3)
103 |
logprobs = bart.predict('new_task', tokens)
104 |
105 |
106 |
##### Batched prediction:
107 |
108 |
import torch
109 |
from import collate_tokens
110 |
111 |
bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
112 |
113 |
114 |
batch_of_pairs = [
115 |
['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
116 |
['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
117 |
118 |
119 |
batch = collate_tokens(
120 |
[bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
121 |
122 |
123 |
logprobs = bart.predict('mnli', batch)
124 |
125 |
# tensor([0, 2])
126 |
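
Batching requires sequences of equal length, which is why `collate_tokens` is given a padding index. The sketch below illustrates the idea in plain Python with right-padding. `pad_batch` is a hypothetical helper for illustration, not fairseq's implementation, and the second token sequence is made up.

```python
from typing import List


def pad_batch(sequences: List[List[int]], pad_idx: int = 1) -> List[List[int]]:
    """Right-pad every sequence with pad_idx up to the longest length in the batch."""
    max_len = max((len(seq) for seq in sequences), default=0)
    return [seq + [pad_idx] * (max_len - len(seq)) for seq in sequences]


batch = pad_batch([[0, 31414, 232, 328, 2], [0, 31414, 2]], pad_idx=1)
print(batch)  # [[0, 31414, 232, 328, 2], [0, 31414, 2, 1, 1]]
```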
127 |
128 |
##### Using the GPU:
129 |
130 |
131 |
bart.predict('new_task', tokens)
132 |
133 |
134 |
#### Filling masks:
135 |
136 |
BART can be used to fill multiple `<mask>` tokens in the input.
137 |
138 |
bart = torch.hub.load('pytorch/fairseq', 'bart.base')
139 |
140 |
bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10)
141 |
# [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]]
142 |
143 |
144 |
Note that by default we enforce the output length to match the input length.
145 |
This can be disabled by setting ``match_source_len=False``:
146 |
147 |
bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10, match_source_len=False)
148 |
# [[('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))]]
149 |
150 |
151 |
Example code to fill masks for a batch of sentences using the GPU:
152 |
153 |
154 |
bart.fill_mask(['The cat <mask> on the <mask>.', 'The dog <mask> on the <mask>.'], topk=3, beam=10)
155 |
# [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))], [('The dog was on the ground.', tensor(-0.6190)), ('The dog lay on the ground.', tensor(-0.6711)),
156 |
('The dog was asleep on the couch', tensor(-0.6796))]]
157 |
158 |
159 |
#### Evaluating the `bart.large.mnli` model:
160 |
161 |
Example python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
162 |
163 |
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
164 |
ncorrect, nsamples = 0, 0
165 |
166 |
167 |
with open('glue_data/MNLI/dev_matched.tsv') as fin:
168 |
169 |
for index, line in enumerate(fin):
170 |
tokens = line.strip().split('\t')
171 |
sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
172 |
tokens = bart.encode(sent1, sent2)
173 |
prediction = bart.predict('mnli', tokens).argmax().item()
174 |
prediction_label = label_map[prediction]
175 |
ncorrect += int(prediction_label == target)
176 |
nsamples += 1
177 |
print('| Accuracy: ', float(ncorrect)/float(nsamples))
178 |
# Expected output: 0.9010
179 |
180 |
181 |
#### Evaluating the `bart.large.cnn` model:
182 |
- Follow instructions [here]( to download and process the data into files such that `test.source` and `` have one line for each non-tokenized sample.
183 |
- For simpler preprocessing, you can also `wget`, although there is no guarantee of identical scores
184 |
- `huggingface/transformers` has a simpler interface that supports [single-gpu]( and [multi-gpu]( beam search.
185 |
In `huggingface/transformers`, the BART models' paths are `facebook/bart-large-cnn` and `facebook/bart-large-xsum`.
186 |
187 |
In `fairseq`, summaries can be generated using:
188 |
189 |
190 |
cp data-bin/cnn_dm/dict.source.txt checkpoints/
191 |
python examples/bart/ \
192 |
--model-dir pytorch/fairseq \
193 |
--model-file bart.large.cnn \
194 |
--src cnn_dm/test.source \
195 |
--out cnn_dm/test.hypo
196 |
197 |
198 |
For calculating rouge, install `files2rouge` from [here](
199 |
200 |
201 |
export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar
202 |
203 |
# Tokenize hypothesis and target files.
204 |
cat test.hypo | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.tokenized
205 |
cat | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines >
206 |
files2rouge test.hypo.tokenized
207 |
# Expected output: (ROUGE-2 Average_F: 0.21238)
208 |
209 |
210 |
211 |
## Finetuning
212 |
213 |
- [Finetuning on GLUE](
214 |
- [Finetuning on CNN-DM](
215 |
216 |
## Citation
217 |
218 |
219 |
220 |
title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
221 |
Language Generation, Translation, and Comprehension},
222 |
author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
223 |
Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov
224 |
and Luke Zettlemoyer },
225 |
journal={arXiv preprint arXiv:1910.13461},
226 |
year = {2019},
227 |
228 |
@@ -0,0 +1,102 @@
1 |
# Fine-tuning BART on CNN-Dailymail summarization task
2 |
3 |
### 1) Download the CNN and Daily Mail data and preprocess it into data files with non-tokenized cased samples.
4 |
5 |
Follow the instructions [here]( to download the original CNN and Daily Mail datasets. To preprocess the data, refer to the pointers in [this issue]( or check out the code [here](
6 |
7 |
Follow the instructions [here]( to download the original Extreme Summarization datasets, or check out the code [here](. Please keep the raw dataset and make sure that no tokenization or BPE has been applied to it.
8 |
9 |
### 2) BPE preprocess:
10 |
11 |
12 |
wget -N ''
13 |
wget -N ''
14 |
wget -N ''
15 |
16 |
17 |
for SPLIT in train val
18 |
19 |
for LANG in source target
20 |
21 |
python -m examples.roberta.multiprocessing_bpe_encoder \
22 |
--encoder-json encoder.json \
23 |
--vocab-bpe vocab.bpe \
24 |
--inputs "$TASK/$SPLIT.$LANG" \
25 |
--outputs "$TASK/$SPLIT.bpe.$LANG" \
26 |
--workers 60 \
27 |
28 |
29 |
30 |
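
To make the loop above concrete, this sketch lists the input and output paths it produces. It assumes `TASK=cnn_dm`, which is inferred from the later commands rather than stated in this step. `bpe_paths` is a hypothetical helper.

```python
from itertools import product
from typing import List, Tuple


def bpe_paths(task: str = "cnn_dm") -> List[Tuple[str, str]]:
    """Return (raw input, BPE output) path pairs for each split and side."""
    return [
        (f"{task}/{split}.{lang}", f"{task}/{split}.bpe.{lang}")
        for split, lang in product(["train", "val"], ["source", "target"])
    ]


for raw, encoded in bpe_paths():
    print(raw, "->", encoded)
```

The resulting `train.bpe.*` and `val.bpe.*` files correspond to the prefixes passed to `--trainpref` and `--validpref` in the binarization step.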
31 |
32 |
### 3) Binarize dataset:
33 |
34 |
fairseq-preprocess \
35 |
--source-lang "source" \
36 |
--target-lang "target" \
37 |
--trainpref "${TASK}/train.bpe" \
38 |
--validpref "${TASK}/val.bpe" \
39 |
--destdir "${TASK}-bin/" \
40 |
--workers 60 \
41 |
--srcdict dict.txt \
42 |
--tgtdict dict.txt;
43 |
44 |
45 |
### 4) Fine-tuning on CNN-DM summarization task:
46 |
Example of fine-tuning on CNN-DM:
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train cnn_dm-bin \
56 |
--restore-file $BART_PATH \
57 |
--max-tokens $MAX_TOKENS \
58 |
--task translation \
59 |
--source-lang source --target-lang target \
60 |
--truncate-source \
61 |
--layernorm-embedding \
62 |
--share-all-embeddings \
63 |
--share-decoder-input-output-embed \
64 |
--reset-optimizer --reset-dataloader --reset-meters \
65 |
--required-batch-size-multiple 1 \
66 |
--arch bart_large \
67 |
--criterion label_smoothed_cross_entropy \
68 |
--label-smoothing 0.1 \
69 |
--dropout 0.1 --attention-dropout 0.1 \
70 |
--weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
71 |
--clip-norm 0.1 \
72 |
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
73 |
--fp16 --update-freq $UPDATE_FREQ \
74 |
--skip-invalid-size-inputs-valid-test \
75 |
76 |
77 |
The command above is expected to run on `1` node with `8` 32GB V100 GPUs.
78 |
Expected training time is about `5 hours`. Training time can be reduced with distributed training on `4` nodes and `--update-freq 1`.
79 |
80 |
Use `TOTAL_NUM_UPDATES=15000 UPDATE_FREQ=2` for the XSum task.
81 |
82 |
### Inference for CNN-DM test data using above trained checkpoint.
83 |
After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using ``, for example:
84 |
85 |
86 |
cp data-bin/cnn_dm/dict.source.txt checkpoints/
87 |
python examples/bart/ \
88 |
--model-dir checkpoints \
89 |
--model-file \
90 |
--src cnn_dm/test.source \
91 |
--out cnn_dm/test.hypo
92 |
93 |
For XSUM, which uses beam=6, lenpen=1.0, max_len_b=60, min_len=10:
94 |
95 |
cp data-bin/cnn_dm/dict.source.txt checkpoints/
96 |
python examples/bart/ \
97 |
--model-dir checkpoints \
98 |
--model-file \
99 |
--src cnn_dm/test.source \
100 |
--out cnn_dm/test.hypo \
101 |
102 |
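
The inference script reads the source file and decodes it in fixed-size batches. A minimal sketch of that batching pattern follows. `batched_lines` and `summarize_all` are hypothetical helpers, and the `summarize` argument stands in for a model call such as `bart.sample(batch, **eval_kwargs)`.

```python
from typing import Callable, Iterable, Iterator, List


def batched_lines(lines: Iterable[str], bsz: int = 32) -> Iterator[List[str]]:
    """Yield stripped lines in lists of at most bsz items."""
    batch: List[str] = []
    for line in lines:
        batch.append(line.strip())
        if len(batch) == bsz:
            yield batch
            batch = []
    if batch:  # emit the final, possibly smaller, batch
        yield batch


def summarize_all(
    lines: Iterable[str],
    summarize: Callable[[List[str]], List[str]],
    bsz: int = 32,
) -> List[str]:
    """Apply a batch summarizer to every batch and collect the outputs in order."""
    outputs: List[str] = []
    for batch in batched_lines(lines, bsz):
        outputs.extend(summarize(batch))
    return outputs


# A stub summarizer (upper-casing) stands in for the model.
print(summarize_all(["a\n", "b\n", "c\n"], lambda b: [s.upper() for s in b], bsz=2))
# ['A', 'B', 'C']
```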
@@ -0,0 +1,100 @@
1 |
# Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 |
# This source code is licensed under the MIT license found in the
4 |
# LICENSE file in the root directory of this source tree.
5 |
6 |
import torch
7 |
from fairseq.models.bart import BARTModel
8 |
import argparse
9 |
10 |
XSUM_KWARGS = dict(beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3)
11 |
CNN_KWARGS = dict(beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
12 |
13 |
14 |
15 |
def generate(bart, infile, outfile="bart_hypo.txt", bsz=32, n_obs=None, **eval_kwargs):
16 |
count = 1
17 |
18 |
# if n_obs is not None: bsz = min(bsz, n_obs)
19 |
20 |
with open(infile) as source, open(outfile, "w") as fout:
21 |
sline = source.readline().strip()
22 |
slines = [sline]
23 |
for sline in source:
24 |
if n_obs is not None and count > n_obs:
25 |
26 |
if count % bsz == 0:
27 |
hypotheses_batch = bart.sample(slines, **eval_kwargs)
28 |
for hypothesis in hypotheses_batch:
29 |
fout.write(hypothesis + "\n")
30 |
31 |
slines = []
32 |
33 |
34 |
count += 1
35 |
36 |
if slines != []:
37 |
hypotheses_batch = bart.sample(slines, **eval_kwargs)
38 |
for hypothesis in hypotheses_batch:
39 |
fout.write(hypothesis + "\n")
40 |
41 |
42 |
43 |
def main():
44 |
45 |
46 |
47 |
python examples/bart/ \
48 |
--model-dir $HOME/bart.large.cnn \
49 |
--model-file \
50 |
--src $HOME/data-bin/cnn_dm/test.source
51 |
52 |
parser = argparse.ArgumentParser()
53 |
54 |
55 |
56 |
57 |
58 |
help="path containing model file and src_dict.txt",
59 |
60 |
61 |
62 |
63 |
help="where in model_dir are weights saved",
64 |
65 |
66 |
"--src", default="test.source", help="text to summarize", type=str
67 |
68 |
69 |
"--out", default="test.hypo", help="where to save summaries", type=str
70 |
71 |
parser.add_argument("--bsz", default=32, help="batch size", type=int)
72 |
73 |
"--n", default=None, help="how many examples to summarize", type=int
74 |
75 |
76 |
77 |
78 |
79 |
help="if true use XSUM_KWARGS else CNN_KWARGS",
80 |
81 |
args = parser.parse_args()
82 |
eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS
83 |
if args.model_dir == "pytorch/fairseq":
84 |
bart = torch.hub.load("pytorch/fairseq", args.model_file)
85 |
86 |
bart = BARTModel.from_pretrained(
87 |
88 |
89 |
90 |
91 |
bart = bart.eval()
92 |
if torch.cuda.is_available():
93 |
bart = bart.cuda().half()
94 |
95 |
bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs
96 |
97 |
98 |
99 |
if __name__ == "__main__":
100 |
@@ -0,0 +1,88 @@
1 |
# Neural Machine Translation with Byte-Level Subwords
2 |
3 |
4 |
5 |
We provide an implementation of byte-level byte-pair encoding (BBPE), taking IWSLT 2017 Fr-En translation as
6 |
7 |
8 |
## Data
9 |
Get data and generate fairseq binary dataset:
10 |
11 |
bash ./
12 |
13 |
14 |
## Model Training
15 |
Train Transformer model with Bi-GRU embedding contextualization (implemented in ``):
16 |
17 |
# VOCAB=bytes
18 |
# VOCAB=chars
19 |
20 |
# VOCAB=bpe2048
21 |
# VOCAB=bbpe4096
22 |
# VOCAB=bpe4096
23 |
# VOCAB=bpe16384
24 |
25 |
26 |
fairseq-train "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
27 |
--arch gru_transformer --encoder-layers 2 --decoder-layers 2 --dropout 0.3 --share-all-embeddings \
28 |
--optimizer adam --adam-betas '(0.9, 0.98)' \
29 |
--lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
30 |
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
31 |
--log-format 'simple' --log-interval 100 --save-dir "checkpoints/${VOCAB}" \
32 |
--batch-size 100 --max-update 100000 --update-freq 2
33 |
34 |
35 |
## Generation
36 |
`fairseq-generate` requires a bytes (BBPE) decoder to convert the byte-level representation back to characters:
37 |
38 |
# BPE=--bpe bytes
39 |
# BPE=--bpe characters
40 |
BPE='--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe2048.model'
41 |
# BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe2048.model
42 |
# BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe4096.model
43 |
# BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe4096.model
44 |
# BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe16384.model
45 |
46 |
47 |
48 |
fairseq-generate "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
49 |
--source-lang fr --gen-subset test --sacrebleu --path "checkpoints/${VOCAB}/" \
50 |
--tokenizer moses --moses-target-lang en ${BPE}
51 |
52 |
When using `fairseq-interactive`, a bytes (BBPE) encoder/decoder is required to tokenize the input data and detokenize the model predictions:
53 |
54 |
fairseq-interactive "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
55 |
--path "checkpoints/${VOCAB}/" --input data/ --tokenizer moses --moses-source-lang fr \
56 |
--moses-target-lang en ${BPE} --buffer-size 1000 --max-tokens 10000
57 |
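
As background for the byte-level vocabularies, the sketch below shows how text maps to UTF-8 bytes and back in plain Python. It only illustrates the representation. It does not reproduce fairseq's `bytes` or `byte_bpe` encoders, and the helper names are hypothetical.

```python
from typing import List


def to_byte_ids(text: str) -> List[int]:
    """Encode text as a list of UTF-8 byte values (0-255)."""
    return list(text.encode("utf-8"))


def from_byte_ids(byte_ids: List[int]) -> str:
    """Decode byte values back to text, replacing invalid sequences."""
    return bytes(byte_ids).decode("utf-8", errors="replace")


ids = to_byte_ids("café")
print(ids)                 # [99, 97, 102, 195, 169]
print(from_byte_ids(ids))  # café
```

Because a model may emit byte sequences that are not valid UTF-8, the decoding step needs some form of error handling; `errors="replace"` is one simple choice used here for illustration.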
58 |
59 |
## Results
60 |
| Vocabulary | Model | BLEU |
61 |
62 |
| Joint BPE 16k ([Kudo, 2018]( | 512d LSTM 2+2 | 33.81 |
63 |
| Joint BPE 16k | Transformer base 2+2 (w/ GRU) | 36.64 (36.72) |
64 |
| Joint BPE 4k | Transformer base 2+2 (w/ GRU) | 35.49 (36.10) |
65 |
| Joint BBPE 4k | Transformer base 2+2 (w/ GRU) | 35.61 (35.82) |
66 |
| Joint BPE 2k | Transformer base 2+2 (w/ GRU) | 34.87 (36.13) |
67 |
| Joint BBPE 2k | Transformer base 2+2 (w/ GRU) | 34.98 (35.43) |
68 |
| Characters | Transformer base 2+2 (w/ GRU) | 31.78 (33.30) |
69 |
| Bytes | Transformer base 2+2 (w/ GRU) | 31.57 (33.62) |
70 |
71 |
72 |
## Citation
73 |
74 |
75 |
title={Neural Machine Translation with Byte-Level Subwords},
76 |
author={Changhan Wang and Kyunghyun Cho and Jiatao Gu},
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
## Contact
86 |
Changhan Wang,
87 |
Kyunghyun Cho,
88 |
Jiatao Gu