From b41c74dc5be15918d5fd21f199b66b78a601192c Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 25 Jan 2019 15:35:01 -0800 Subject: [PATCH] Add code for "Pay Less Attention with Lightweight and Dynamic Convolutions" (#473) Summary: Changelog: - `e330f56`: Add code for the "Pay Less Attention with Lightweight and Dynamic Convolutions" paper - `5e3b98c`: Add scripts for computing tokenized BLEU with compound splitting and sacrebleu - update READMEs - misc fixes Pull Request resolved: https://github.com/pytorch/fairseq/pull/473 Differential Revision: D13819717 Pulled By: myleott fbshipit-source-id: f2dc12ea89a436b950cafec3593ed1b04af808e9 --- README.md | 80 +- docs/getting_started.rst | 2 +- docs/tutorial_classifying_names.rst | 2 +- examples/backtranslation/README.md | 19 + examples/conv_lm/README.md | 26 + examples/conv_seq2seq/README.md | 25 + examples/language_model/README.md | 13 +- examples/pay_less_attention_paper/README.md | 116 +++ examples/scaling_nmt/README.md | 62 ++ examples/stories/README.md | 26 +- examples/translation/README.md | 81 +- fairseq/models/lightconv.py | 931 ++++++++++++++++++++ fairseq/modules/__init__.py | 6 + fairseq/modules/dynamic_convolution.py | 227 +++++ fairseq/modules/lightweight_convolution.py | 233 +++++ fairseq/modules/unfold1d.py | 19 + scripts/compound_split_bleu.sh | 20 + scripts/sacrebleu_pregen.sh | 28 + tests/test_binaries.py | 22 + tests/test_train.py | 1 + train.py | 7 +- 21 files changed, 1832 insertions(+), 114 deletions(-) create mode 100644 examples/backtranslation/README.md create mode 100644 examples/conv_lm/README.md create mode 100644 examples/conv_seq2seq/README.md create mode 100644 examples/pay_less_attention_paper/README.md create mode 100644 examples/scaling_nmt/README.md create mode 100644 fairseq/models/lightconv.py create mode 100644 fairseq/modules/dynamic_convolution.py create mode 100644 fairseq/modules/lightweight_convolution.py create mode 100644 fairseq/modules/unfold1d.py create mode 100644 scripts/compound_split_bleu.sh create mode 100755 scripts/sacrebleu_pregen.sh diff --git a/README.md b/README.md index c9b761643..07e430d41 100644 --- a/README.md +++ b/README.md @@ -8,14 +8,16 @@ of various sequence-to-sequence models, including: - [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](https://arxiv.org/abs/1612.08083) - [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122) - [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://arxiv.org/abs/1711.04956) - - **_New_** [Fan et al. (2018): Hierarchical Neural Story Generation](https://arxiv.org/abs/1805.04833) + - [Fan et al. (2018): Hierarchical Neural Story Generation](https://arxiv.org/abs/1805.04833) +- **LightConv and DynamicConv models** + - **_New_** [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](https://openreview.net/pdf?id=SkVhlh09tX) - **Long Short-Term Memory (LSTM) networks** - [Luong et al. (2015): Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025) - [Wiseman and Rush (2016): Sequence-to-Sequence Learning as Beam-Search Optimization](https://arxiv.org/abs/1606.02960) - **Transformer (self-attention) networks** - [Vaswani et al. (2017): Attention Is All You Need](https://arxiv.org/abs/1706.03762) - - **_New_** [Ott et al. (2018): Scaling Neural Machine Translation](https://arxiv.org/abs/1806.00187) - - **_New_** [Edunov et al. 
(2018): Understanding Back-Translation at Scale](https://arxiv.org/abs/1808.09381) + - [Ott et al. (2018): Scaling Neural Machine Translation](https://arxiv.org/abs/1806.00187) + - [Edunov et al. (2018): Understanding Back-Translation at Scale](https://arxiv.org/abs/1808.09381) Fairseq features: - multi-GPU (distributed) training on one machine or across multiple machines @@ -27,7 +29,7 @@ Fairseq features: - fast half-precision floating point (FP16) training - extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers -We also provide [pre-trained models](#pre-trained-models) for several benchmark +We also provide [pre-trained models](#pre-trained-models-and-examples) for several benchmark translation and language modeling datasets. ![Model](fairseq.gif) @@ -55,73 +57,27 @@ The [full documentation](https://fairseq.readthedocs.io/) contains instructions for getting started, training new models and extending fairseq with new model types and tasks. -# Pre-trained Models +# Pre-trained models and examples -We provide the following pre-trained models and pre-processed, binarized test sets: +We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below, +as well as example training and evaluation commands. -### Translation +- [Translation](examples/translation/README.md): convolutional and transformer models are available +- [Language Modeling](examples/language_model/README.md): convolutional models are available -Description | Dataset | Model | Test set(s) ----|---|---|--- -Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.v2.en-fr.newstest2014.tar.bz2)
newstest2012/2013:
[download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.v2.en-fr.ntst1213.tar.bz2) -Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.en-de.newstest2014.tar.bz2) -Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt17.v2.en-de.newstest2014.tar.bz2) -Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2) -Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) -Transformer
([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt18.en-de.ensemble.tar.bz2) | See NOTE in the archive - -### Language models - -Description | Dataset | Model | Test set(s) ----|---|---|--- -Convolutional
([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/gbw_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/gbw_test_lm.tar.bz2) -Convolutional
([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wiki103_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wiki103_test_lm.tar.bz2) - -### Stories - -Description | Dataset | Model | Test set(s) ----|---|---|--- -Stories with Convolutional Model
([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://arxiv.org/abs/1805.04833) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/stories_test.tar.bz2) - - -### Usage - -Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti: -``` -$ curl https://s3.amazonaws.com/fairseq-py/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin -$ curl https://s3.amazonaws.com/fairseq-py/data/wmt14.v2.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin -$ python generate.py data-bin/wmt14.en-fr.newstest2014 \ - --path data-bin/wmt14.en-fr.fconv-py/model.pt \ - --beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out -... -| Translated 3003 sentences (96311 tokens) in 166.0s (580.04 tokens/s) -| Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787) - -# Scoring with score.py: -$ grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys -$ grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref -$ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref -BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787) -``` +We also have more detailed READMEs to reproduce results from specific papers: +- [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md) +- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel) +- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md) +- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md) +- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md) +- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/conv_lm/README.md) # Join the fairseq community * Facebook page: https://www.facebook.com/groups/fairseq.users * Google group: https://groups.google.com/forum/#!forum/fairseq-users -# Citation - -If you use the code in your paper, then please cite it as: - -``` -@inproceedings{gehring2017convs2s, - author = {Gehring, Jonas, and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N}, - title = "{Convolutional Sequence to Sequence Learning}", - booktitle = {Proc. of ICML}, - year = 2017, -} -``` - # License fairseq(-py) is BSD-licensed. The license applies to the pre-trained models as well. diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 1658d4109..9912d1830 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -5,7 +5,7 @@ First, download a pre-trained model along with its vocabularies: .. code-block:: console - > curl https://s3.amazonaws.com/fairseq-py/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - + > curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - This model uses a `Byte Pair Encoding (BPE) vocabulary `__, so we'll have to apply diff --git a/docs/tutorial_classifying_names.rst b/docs/tutorial_classifying_names.rst index 262c90ef4..461675041 100644 --- a/docs/tutorial_classifying_names.rst +++ b/docs/tutorial_classifying_names.rst @@ -26,7 +26,7 @@ of the data that is already tokenized into characters and split into separate train, valid and test sets. 
Download and extract the data from here:
-`tutorial_names.tar.gz <https://s3.amazonaws.com/fairseq-py/data/tutorial_names.tar.gz>`_
+`tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_

Once extracted, let's preprocess the data using the :ref:`preprocess.py`
command-line tool to create the dictionaries. While this tool is primarily
diff --git a/examples/backtranslation/README.md b/examples/backtranslation/README.md
new file mode 100644
index 000000000..c499ecccc
--- /dev/null
+++ b/examples/backtranslation/README.md
@@ -0,0 +1,19 @@
+# Understanding Back-Translation at Scale (Edunov et al., 2018)
+
+This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381).
+
+## Pre-trained models
+
+Description | Dataset | Model | Test set(s)
+---|---|---|---
+Transformer
([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.bz2) | See NOTE in the archive
+
+## Citation
+```bibtex
+@inproceedings{edunov2018backtranslation,
+  title = {Understanding Back-Translation at Scale},
+  author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
+  booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
+  year = 2018,
+}
+```
diff --git a/examples/conv_lm/README.md b/examples/conv_lm/README.md
new file mode 100644
index 000000000..a4e42a2cf
--- /dev/null
+++ b/examples/conv_lm/README.md
@@ -0,0 +1,26 @@
+# Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)
+
+## Pre-trained models
+
+Description | Dataset | Model | Test set(s)
+---|---|---|---
+Convolutional
([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/gbw_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/gbw_test_lm.tar.bz2) +Convolutional
([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wiki103_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wiki103_test_lm.tar.bz2) + +## Example usage + +See the [language modeling README](../language_model/README.md) for instructions on reproducing results for WikiText-103 +using the `fconv_lm_dauphin_wikitext103` model architecture. + +## Citation + +```bibtex +@inproceedings{dauphin2017language, + title={Language Modeling with Gated Convolutional Networks}, + author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David}, + booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70}, + pages={933--941}, + year={2017}, + organization={JMLR} +} +``` diff --git a/examples/conv_seq2seq/README.md b/examples/conv_seq2seq/README.md new file mode 100644 index 000000000..95fe7e790 --- /dev/null +++ b/examples/conv_seq2seq/README.md @@ -0,0 +1,25 @@ +# Convolutional Sequence to Sequence Learning (Gehring et al., 2017) + +## Pre-trained models + +Description | Dataset | Model | Test set(s) +---|---|---|--- +Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2)
newstest2012/2013:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2) +Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2) +Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
+
+## Example usage
+
+See the [translation README](../translation/README.md) for instructions on reproducing results for WMT'14 En-De and
+WMT'14 En-Fr using the `fconv_wmt_en_de` and `fconv_wmt_en_fr` model architectures.
+
+## Citation
+
+```bibtex
+@inproceedings{gehring2017convs2s,
+  title = {Convolutional Sequence to Sequence Learning},
+  author = {Gehring, Jonas and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N},
+  booktitle = {Proc. of ICML},
+  year = 2017,
+}
+```
diff --git a/examples/language_model/README.md b/examples/language_model/README.md
index e5107df6b..eeef21f08 100644
--- a/examples/language_model/README.md
+++ b/examples/language_model/README.md
@@ -1,8 +1,17 @@
-Sample data processing scripts for the FAIR Sequence-to-Sequence Toolkit
+# Neural Language Modeling
+
+## Pre-trained models
+
+Description | Dataset | Model | Test set(s)
+---|---|---|---
+Convolutional
([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/gbw_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/gbw_test_lm.tar.bz2) +Convolutional
([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wiki103_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wiki103_test_lm.tar.bz2) + +## Example usage These scripts provide an example of pre-processing data for the Language Modeling task. -# prepare-wikitext-103.sh +### prepare-wikitext-103.sh Provides an example of pre-processing for [WikiText-103 language modeling task](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/): diff --git a/examples/pay_less_attention_paper/README.md b/examples/pay_less_attention_paper/README.md new file mode 100644 index 000000000..89b61d61b --- /dev/null +++ b/examples/pay_less_attention_paper/README.md @@ -0,0 +1,116 @@ +# Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019) +This page contains pointers to pre-trained models as well as instructions on how to train new models for [our paper](https://openreview.net/pdf?id=SkVhlh09tX) + +## Citation: +```bibtex +@inproceedings{wu2018pay, + title = {Pay Less Attention with Lightweight and Dynamic Convolutions}, + author = {Felix Wu and Angela Fan and Alexei Baevski and Yann Dauphin and Michael Auli}, + booktitle = {International Conference on Learning Representations}, + year = {2019}, + url = {https://openreview.net/forum?id=SkVhlh09tX}, +} +``` + +## Translation + +### Pre-trained models +For some datasets we release models without GLUs which are faster at inference. + +Description | Dataset | Model | Test set(s) +---|---|---|--- +LightConv (without GLUs) | [IWSLT14 German-English](https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.lightconv.tar.bz2) | IWSLT14 test:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/iwslt14.de-en.test.tar.bz2) +DynamicConv (without GLUs) | [IWSLT14 German-English](https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.dynamicconv.tar.bz2) | IWSLT14 test:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/iwslt14.de-en.test.tar.bz2) +LightConv (without GLUs) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) +DynamicConv (without GLUs) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) +LightConv | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) +DynamicConv | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) +LightConv | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.lightconv-glu.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2) +DynamicConv | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.dynamicconv-glu.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2) +LightConv | [WMT17 Chinese-English](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.lightconv-glu.tar.bz2) | newstest2017:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.zh-en.newstest2017.tar.bz2) +DynamicConv | [WMT17 Chinese-English](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.dynamicconv-glu.tar.bz2) | newstest2017:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.zh-en.newstest2017.tar.bz2) + +### Preprocessing the training datasets + +Please follow the instructions in [`examples/translation/README.md`](../translation/README.md) to preprocess the data. + +### Training and evaluation options: +To use the model without GLU, please set `--encoder-glu 0 --decoder-glu 0`. +For LightConv, please use `--encoder-conv-type lightweight --decoder-conv-type lightweight`, otherwise the default is DynamicConv. +For best BLEU results, lenpen may need to be manually tuned. + +### IWSLT14 De-En +Training and evaluating DynamicConv (without GLU) on a GPU: +```sh +# Training +SAVE="save/dynamic_conv_iwslt" +mkdir -p $SAVE +CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \ + --clip-norm 0 --optimizer adam --lr 0.0005 \ + --source-lang de --target-lang en --max-tokens 4000 --no-progress-bar \ + --log-interval 100 --min-lr '1e-09' --weight-decay 0.0001 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --lr-scheduler inverse_sqrt \ + --ddp-backend=no_c10d \ + --max-update 50000 --warmup-updates 4000 --warmup-init-lr '1e-07' \ + --adam-betas '(0.9, 0.98)' --keep-last-epochs 10 \ + -a lightconv_iwslt_de_en --save-dir $SAVE \ + --dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \ + --encoder-glu 0 --decoder-glu 0 +python scripts/average_checkpoints.py --inputs $SAVE \ + --num-epoch-checkpoints 10 --output "${SAVE}/checkpoint_last10_avg.pt" + +# Evaluation +CUDA_VISIBLE_DEVICES=0 python generate.py data-bin/iwslt14.tokenized.de-en --path "${SAVE}/checkpoint_last10_avg.pt" --batch-size 128 --beam 4 --remove-bpe --lenpen 1 --gen-subset test --quiet +``` + +### WMT16 En-De +Training and evaluating DynamicConv (with GLU) on WMT16 En-De using cosine scheduler on one machine with 8 V100 GPUs: +```sh +# Training +SAVE="save/dynamic_conv_wmt16en2de" +mkdir -p $SAVE +python -m torch.distributed.launch --nproc_per_node 8 train.py \ + data-bin/wmt16_en_de_bpe32k --fp16 --log-interval 100 --no-progress-bar \ + --max-update 30000 --share-all-embeddings --optimizer adam \ + --adam-betas '(0.9, 0.98)' --lr-scheduler inverse_sqrt \ + --clip-norm 0.0 --weight-decay 0.0 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ + --ddp-backend=no_c10d --max-tokens 3584 \ + --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \ + --lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 --warmup-init-lr 1e-07 \ + --t-mult 1 --lr-period-updates 20000 \ + --arch lightconv_wmt_en_de_big --save-dir $SAVE \ + --dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \ + --encoder-glu 1 --decoder-glu 1 + +# Evaluation +CUDA_VISIBLE_DEVICES=0 python generate.py data-bin/wmt16.en-de.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.5 --gen-subset test > wmt16_gen.txt +bash scripts/compound_split_bleu.sh wmt16_gen.txt +``` + +### WMT14 En-Fr +Training DynamicConv (with GLU) on WMT14 En-Fr using cosine scheduler on one machine with 8 V100 GPUs: +```sh +# Training +SAVE="save/dynamic_conv_wmt14en2fr" +mkdir -p $SAVE +python -m torch.distributed.launch --nproc_per_node 8 train.py \ + data-bin/wmt14_en_fr --fp16 --log-interval 100 --no-progress-bar \ + --max-update 30000 --share-all-embeddings --optimizer adam \ + --adam-betas '(0.9, 0.98)' --lr-scheduler inverse_sqrt \ + --clip-norm 0.0 --weight-decay 0.0 \ + --criterion 
label_smoothed_cross_entropy --label-smoothing 0.1 \ + --min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ + --ddp-backend=no_c10d --max-tokens 3584 \ + --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \ + --lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 --warmup-init-lr 1e-07 \ + --t-mult 1 --lr-period-updates 70000 \ + --arch lightconv_wmt_en_fr_big --save-dir $SAVE \ + --dropout 0.1 --attention-dropout 0.1 --weight-dropout 0.1 \ + --encoder-glu 1 --decoder-glu 1 + +# Evaluation +CUDA_VISIBLE_DEVICES=0 python generate.py data-bin/wmt14.en-fr.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.9 --gen-subset test +``` diff --git a/examples/scaling_nmt/README.md b/examples/scaling_nmt/README.md new file mode 100644 index 000000000..a48f8ac17 --- /dev/null +++ b/examples/scaling_nmt/README.md @@ -0,0 +1,62 @@ +# Scaling Neural Machine Translation (Ott et al., 2018) + +This page includes instructions for reproducing results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187). + +## Pre-trained models + +Description | Dataset | Model | Test set(s) +---|---|---|--- +Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2) +Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) + +## Training a new model on WMT'16 En-De + +Please first download the [preprocessed WMT'16 En-De data provided by Google](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8). +Then: + +1. Extract the WMT'16 En-De data: +``` +$ TEXT=wmt16_en_de_bpe32k +$ mkdir $TEXT +$ tar -xzvf wmt16_en_de.tar.gz -C $TEXT +``` + +2. Preprocess the dataset with a joined dictionary: +``` +$ python preprocess.py --source-lang en --target-lang de \ + --trainpref $TEXT/train.tok.clean.bpe.32000 \ + --validpref $TEXT/newstest2013.tok.bpe.32000 \ + --testpref $TEXT/newstest2014.tok.bpe.32000 \ + --destdir data-bin/wmt16_en_de_bpe32k \ + --nwordssrc 32768 --nwordstgt 32768 \ + --joined-dictionary +``` + +3. Train a model: +``` +$ python train.py data-bin/wmt16_en_de_bpe32k \ + --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \ + --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ + --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \ + --lr 0.0005 --min-lr 1e-09 \ + --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --max-tokens 3584 \ + --fp16 +``` + +Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU. + +If you want to train the above model with big batches (assuming your machine has 8 GPUs): +- add `--update-freq 16` to simulate training on 8*16=128 GPUs +- increase the learning rate; 0.001 works well for big batches + +## Citation + +```bibtex +@inproceedings{ott2018scaling, + title = {Scaling Neural Machine Translation}, + author = {Ott, Myle and Edunov, Sergey and Grangier, David and Auli, Michael}, + booktitle = {Proceedings of the Third Conference on Machine Translation (WMT)}, + year = 2018, +} +``` diff --git a/examples/stories/README.md b/examples/stories/README.md index 6f476c56a..a633870a5 100644 --- a/examples/stories/README.md +++ b/examples/stories/README.md @@ -1,18 +1,28 @@ -FAIR Sequence-to-Sequence Toolkit for Story Generation +# Hierarchical Neural Story Generation (Fan et al., 2018) The following commands provide an example of pre-processing data, training a model, and generating text for story generation with the WritingPrompts dataset. +## Pre-trained models + +Description | Dataset | Model | Test set(s) +---|---|---|--- +Stories with Convolutional Model
([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://arxiv.org/abs/1805.04833) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2) + + +## Dataset + The dataset can be downloaded like this: ``` cd examples/stories -curl https://s3.amazonaws.com/fairseq-py/data/writingPrompts.tar.gz | tar xvzf - +curl https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz | tar xvzf - ``` and contains a train, test, and valid split. The dataset is described here: https://arxiv.org/abs/1805.04833. We model only the first 1000 words of each story, including one newLine token. -Example usage: +## Example usage + ``` # Preprocess the dataset: # Note that the dataset release is the full data, but the paper models the first 1000 words of each story @@ -44,3 +54,13 @@ $ python train.py data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --clip- $ python generate.py data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt --batch-size 32 --beam 1 --sampling --sampling-topk 10 --sampling-temperature 0.8 --nbest 1 --model-overrides "{'pretrained_checkpoint':'/path/to/pretrained/model/checkpoint'}" ``` + +## Citation +```bibtex +@inproceedings{fan2018hierarchical, + title = {Hierarchical Neural Story Generation}, + author = {Fan, Angela and Lewis, Mike and Dauphin, Yann}, + booktitle = {Conference of the Association for Computational Linguistics (ACL)}, + year = 2018, +} +``` diff --git a/examples/translation/README.md b/examples/translation/README.md index 3d84c52a3..7cbda849f 100644 --- a/examples/translation/README.md +++ b/examples/translation/README.md @@ -1,10 +1,41 @@ -# Example usage for Neural Machine Translation +# Neural Machine Translation -These scripts provide an example of pre-processing data for the NMT task -and instructions for how to replicate the results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187). +## Pre-trained models + +Description | Dataset | Model | Test set(s) +---|---|---|--- +Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2)
newstest2012/2013:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2) +Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2) +Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2) +Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2) +Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) +Transformer
([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.bz2) | See NOTE in the archive + +## Example usage + +Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti: +``` +$ mkdir -p data-bin +$ curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin +$ curl https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin +$ python generate.py data-bin/wmt14.en-fr.newstest2014 \ + --path data-bin/wmt14.en-fr.fconv-py/model.pt \ + --beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out +... +| Translated 3003 sentences (96311 tokens) in 166.0s (580.04 tokens/s) +| Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787) + +# Scoring with score.py: +$ grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys +$ grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref +$ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref +BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787) +``` ## Preprocessing +These scripts provide an example of pre-processing data for the NMT task. + ### prepare-iwslt14.sh Provides an example of pre-processing for IWSLT'14 German to English translation task: ["Report on the 11th IWSLT evaluation campaign" by Cettolo et al.](http://workshop2014.iwslt.org/downloads/proceeding.pdf) @@ -64,9 +95,10 @@ $ python generate.py data-bin/iwslt14.tokenized.de-en \ ### prepare-wmt14en2de.sh -Provides an example of pre-processing for the WMT'14 English to German translation task. By default it will produce a dataset that was modeled after ["Attention Is All You Need" by Vaswani et al.](https://arxiv.org/abs/1706.03762) that includes news-commentary-v12 data. +The WMT English to German dataset can be preprocessed using the `prepare-wmt14en2de.sh` script. +By default it will produce a dataset that was modeled after ["Attention Is All You Need" (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762), but with news-commentary-v12 data from WMT'17. -To use only data available in WMT'14 or to replicate results obtained in the original paper ["Convolutional Sequence to Sequence Learning" by Gehring et al.](https://arxiv.org/abs/1705.03122) run it with --icml17 instead: +To use only data available in WMT'14 or to replicate results obtained in the original ["Convolutional Sequence to Sequence Learning" (Gehring et al., 2017)](https://arxiv.org/abs/1705.03122) paper, please use the `--icml17` option. ``` $ bash prepare-wmt14en2de.sh --icml17 @@ -131,42 +163,3 @@ $ python generate.py data-bin/fconv_wmt_en_fr \ --path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt --beam 5 --remove-bpe ``` - -## Replicating results from "Scaling Neural Machine Translation" - -To replicate results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187), -please first download the [preprocessed WMT'16 En-De data provided by Google](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8). - -1. Extract the WMT'16 En-De data: -``` -$ TEXT=wmt16_en_de_bpe32k -$ mkdir $TEXT -$ tar -xzvf wmt16_en_de.tar.gz -C $TEXT -``` -2. 
Preprocess the dataset with a joined dictionary: -``` -$ python preprocess.py --source-lang en --target-lang de \ - --trainpref $TEXT/train.tok.clean.bpe.32000 \ - --validpref $TEXT/newstest2013.tok.bpe.32000 \ - --testpref $TEXT/newstest2014.tok.bpe.32000 \ - --destdir data-bin/wmt16_en_de_bpe32k \ - --nwordssrc 32768 --nwordstgt 32768 \ - --joined-dictionary -``` -3. Train a model: -``` -$ python train.py data-bin/wmt16_en_de_bpe32k \ - --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \ - --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ - --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \ - --lr 0.0005 --min-lr 1e-09 \ - --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ - --max-tokens 3584 \ - --fp16 -``` - -Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU. - -If you want to train the above model with big batches (assuming your machine has 8 GPUs): -- add `--update-freq 16` to simulate training on 8*16=128 GPUs -- increase the learning rate; 0.001 works well for big batches diff --git a/fairseq/models/lightconv.py b/fairseq/models/lightconv.py new file mode 100644 index 000000000..750fcb5ef --- /dev/null +++ b/fairseq/models/lightconv.py @@ -0,0 +1,931 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import options +from fairseq import utils + +from fairseq.modules import ( + AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention, + SinusoidalPositionalEmbedding, DynamicConv1dTBC, LightweightConv1dTBC +) + +from . import ( + FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel, FairseqModel, register_model, + register_model_architecture, +) + + +@register_model('lightconv') +class LightConvModel(FairseqModel): + """ + LightConv and DynamicConv model from `"Pay Less Attention with Lightweight and Dynamic Convolutions" (Wu, et al, 2019) + `_. + To use LightConv please set --encoder-conv-type lightweight --decoder-conv-type lightweight + To use DynamicConv please set --encoder-conv-type dynamic --decoder-conv-type dynamic + + Args: + encoder (LightConvEncoder): the encoder + decoder (LightConvDecoder): the decoder + + The LightConv model provides the following named architectures and + command-line arguments: + + .. 
argparse:: + :ref: fairseq.models.lightconv_parser + :prog: + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--attention-dropout', type=float, metavar='D', + help='dropout probability for attention weights') + parser.add_argument('--relu-dropout', type=float, metavar='D', + help='dropout probability after ReLU in FFN') + parser.add_argument('--input-dropout', type=float, metavar='D', + help='dropout probability of the inputs') + parser.add_argument('--encoder-embed-path', type=str, metavar='STR', + help='path to pre-trained encoder embedding') + parser.add_argument('--encoder-embed-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-conv-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', + help='encoder embedding dimension for FFN') + parser.add_argument('--encoder-layers', type=int, metavar='N', + help='num encoder layers') + parser.add_argument('--encoder-attention-heads', type=int, metavar='N', + help='num encoder attention heads or LightConv/DynamicConv heads') + parser.add_argument('--encoder-normalize-before', action='store_true', + help='apply layernorm before each encoder block') + parser.add_argument('--encoder-learned-pos', action='store_true', + help='use learned positional embeddings in the encoder') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-conv-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', + help='decoder embedding dimension for FFN') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='num decoder layers') + parser.add_argument('--decoder-attention-heads', type=int, metavar='N', + help='num decoder attention heads or LightConv/DynamicConv heads') + parser.add_argument('--decoder-learned-pos', action='store_true', + help='use learned positional embeddings in the decoder') + parser.add_argument('--decoder-normalize-before', action='store_true', + help='apply layernorm before each decoder block') + parser.add_argument('--share-decoder-input-output-embed', action='store_true', + help='share decoder input and output embeddings') + parser.add_argument('--share-all-embeddings', action='store_true', + help='share encoder, decoder and output embeddings' + ' (requires shared dictionary and embed dim)') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. 
' + 'Must be used with adaptive_loss criterion'), + parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', + help='sets adaptive softmax dropout for the tail projections') + + """LightConv and DynamicConv arguments""" + parser.add_argument('--encoder-kernel-size-list', type=lambda x: options.eval_str_list(x, int), + help='list of kernel size (default: "[3,7,15,31,31,31,31]")') + parser.add_argument('--decoder-kernel-size-list', type=lambda x: options.eval_str_list(x, int), + help='list of kernel size (default: "[3,7,15,31,31,31]")') + parser.add_argument('--encoder-glu', type=options.eval_bool, + help='glu after in proj') + parser.add_argument('--decoder-glu', type=options.eval_bool, + help='glu after in proj') + parser.add_argument('--encoder-conv-type', default='dynamic', type=str, + choices=['dynamic', 'lightweight'], + help='type of convolution') + parser.add_argument('--decoder-conv-type', default='dynamic', type=str, + choices=['dynamic', 'lightweight'], + help='type of convolution') + parser.add_argument('--weight-softmax', default=True, type=options.eval_bool) + parser.add_argument('--weight-dropout', type=float, metavar='D', + help='dropout probability for conv weights') + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if not hasattr(args, 'max_source_positions'): + args.max_source_positions = 1024 + if not hasattr(args, 'max_target_positions'): + args.max_target_positions = 1024 + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim, path=None): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx) + # if provided, load from preloaded dictionaries + if path: + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + return emb + + if args.share_all_embeddings: + if src_dict != tgt_dict: + raise RuntimeError('--share-all-embeddings requires a joined dictionary') + if args.encoder_embed_dim != args.decoder_embed_dim: + raise RuntimeError( + '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim') + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path): + raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path') + encoder_embed_tokens = build_embedding( + src_dict, args.encoder_embed_dim, args.encoder_embed_path + ) + decoder_embed_tokens = encoder_embed_tokens + args.share_decoder_input_output_embed = True + else: + encoder_embed_tokens = build_embedding( + src_dict, args.encoder_embed_dim, args.encoder_embed_path + ) + decoder_embed_tokens = build_embedding( + tgt_dict, args.decoder_embed_dim, args.decoder_embed_path + ) + + encoder = LightConvEncoder(args, src_dict, encoder_embed_tokens) + decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens) + return LightConvModel(encoder, decoder) + + +@register_model('lightconv_lm') +class LightConvLanguageModel(FairseqLanguageModel): + def __init__(self, decoder): + super().__init__(decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument('--dropout', default=0.1, type=float, metavar='D', + help='dropout probability') + parser.add_argument('--attention-dropout', default=0., type=float, metavar='D', + help='dropout probability for attention weights') + 
parser.add_argument('--relu-dropout', default=0., type=float, metavar='D', + help='dropout probability after ReLU in FFN') + parser.add_argument('--input-dropout', type=float, metavar='D', + help='dropout probability of the inputs') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-output-dim', type=int, metavar='N', + help='decoder output dimension') + parser.add_argument('--decoder-input-dim', type=int, metavar='N', + help='decoder input dimension') + parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', + help='decoder embedding dimension for FFN') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='num decoder layers') + parser.add_argument('--decoder-attention-heads', type=int, metavar='N', + help='num decoder attention heads or LightConv/DynamicConv heads') + parser.add_argument('--decoder-normalize-before', default=False, action='store_true', + help='apply layernorm before each decoder block') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. ' + 'Must be used with adaptive_loss criterion') + parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', + help='sets adaptive softmax dropout for the tail projections') + parser.add_argument('--adaptive-softmax-factor', type=float, metavar='N', + help='adaptive input factor') + parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', + help='if set, disables positional embeddings (outside self attention)') + parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true', + help='share decoder input and output embeddings') + parser.add_argument('--character-embeddings', default=False, action='store_true', + help='if set, uses character embedding convolutions to produce token embeddings') + parser.add_argument('--character-filters', type=str, metavar='LIST', + default='[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]', + help='size of character embeddings') + parser.add_argument('--character-embedding-dim', type=int, metavar='N', default=4, + help='size of character embeddings') + parser.add_argument('--char-embedder-highway-layers', type=int, metavar='N', default=2, + help='number of highway layers for character token embeddder') + parser.add_argument('--adaptive-input', default=False, action='store_true', + help='if set, uses adaptive input') + parser.add_argument('--adaptive-input-factor', type=float, metavar='N', + help='adaptive input factor') + parser.add_argument('--adaptive-input-cutoff', metavar='EXPR', + help='comma separated list of adaptive input cutoff points.') + parser.add_argument('--tie-adaptive-weights', action='store_true', + help='if set, ties the weights of adaptive softmax and adaptive input') + parser.add_argument('--tie-adaptive-proj', action='store_true', + help='if set, ties the projection weights of adaptive softmax and adaptive input') + parser.add_argument('--decoder-learned-pos', action='store_true', + help='use learned positional embeddings in the decoder') + + """LightConv and DynamicConv arguments""" + parser.add_argument('--decoder-kernel-size-list', type=lambda x: options.eval_str_list(x, int), + help='list of kernel size (default: "[3,7,15,31,31,31]")') + parser.add_argument('--decoder-glu', type=options.eval_bool, + help='glu after in proj') + parser.add_argument('--decoder-conv-type', 
default='dynamic', type=str, + choices=['dynamic', 'lightweight'], + help='type of convolution') + parser.add_argument('--weight-softmax', default=True, type=options.eval_bool) + parser.add_argument('--weight-dropout', type=float, metavar='D', + help='dropout probability for conv weights') + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_lm_architecture(args) + + if not hasattr(args, 'max_source_positions'): + args.max_source_positions = args.tokens_per_sample + if not hasattr(args, 'max_target_positions'): + args.max_target_positions = args.tokens_per_sample + + if args.character_embeddings: + embed_tokens = CharacterTokenEmbedder(task.dictionary, eval(args.character_filters), + args.character_embedding_dim, + args.decoder_embed_dim, + args.char_embedder_highway_layers, + ) + elif args.adaptive_input: + embed_tokens = AdaptiveInput(len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim, + args.adaptive_input_factor, args.decoder_embed_dim, + options.eval_str_list(args.adaptive_input_cutoff, type=int)) + else: + embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad()) + + if args.tie_adaptive_weights: + assert args.adaptive_input + assert args.adaptive_input_factor == args.adaptive_softmax_factor + assert args.adaptive_softmax_cutoff == args.adaptive_input_cutoff, '{} != {}'.format( + args.adaptive_softmax_cutoff, args.adaptive_input_cutoff) + assert args.decoder_input_dim == args.decoder_output_dim + + decoder = LightConvDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False) + return LightConvLanguageModel(decoder) + + +class LightConvEncoder(FairseqEncoder): + """ + LightConv encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`LightConvEncoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): encoding dictionary + embed_tokens (torch.nn.Embedding): input embedding + left_pad (bool, optional): whether the input is left-padded. 
Default: + ``True`` + """ + + def __init__(self, args, dictionary, embed_tokens, left_pad=True): + super().__init__(dictionary) + self.dropout = args.dropout + + embed_dim = embed_tokens.embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_source_positions = args.max_source_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) + self.embed_positions = PositionalEmbedding( + args.max_source_positions, embed_dim, self.padding_idx, + left_pad=left_pad, + learned=args.encoder_learned_pos, + ) if not args.no_token_positional_embeddings else None + + self.layers = nn.ModuleList([]) + self.layers.extend([ + LightConvEncoderLayer(args, kernel_size=args.encoder_kernel_size_list[i]) + for i in range(args.encoder_layers) + ]) + self.register_buffer('version', torch.Tensor([2])) + self.normalize = args.encoder_normalize_before + if self.normalize: + self.layer_norm = LayerNorm(embed_dim) + + def forward(self, src_tokens, src_lengths): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + """ + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(src_tokens) + if self.embed_positions is not None: + x += self.embed_positions(src_tokens) + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # compute padding mask + encoder_padding_mask = src_tokens.eq(self.padding_idx) + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + # encoder layers + for layer in self.layers: + x = layer(x, encoder_padding_mask) + + if self.normalize: + x = self.layer_norm(x) + + return { + 'encoder_out': x, # T x B x C + 'encoder_padding_mask': encoder_padding_mask, # B x T + } + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + if encoder_out['encoder_out'] is not None: + encoder_out['encoder_out'] = \ + encoder_out['encoder_out'].index_select(1, new_order) + if encoder_out['encoder_padding_mask'] is not None: + encoder_out['encoder_padding_mask'] = \ + encoder_out['encoder_padding_mask'].index_select(0, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + if self.embed_positions is None: + return self.max_source_positions + return min(self.max_source_positions, self.embed_positions.max_positions()) + + +class LightConvDecoder(FairseqIncrementalDecoder): + """ + LightConv decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`LightConvDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs. + Default: ``False`` + left_pad (bool, optional): whether the input is left-padded. 
Default: + ``False`` + """ + + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True): + super().__init__(dictionary) + self.dropout = args.dropout + self.share_input_output_embed = args.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = args.decoder_embed_dim + output_embed_dim = args.decoder_output_dim + + padding_idx = embed_tokens.padding_idx + self.max_target_positions = args.max_target_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + + self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None + + self.embed_positions = PositionalEmbedding( + args.max_target_positions, embed_dim, padding_idx, + left_pad=left_pad, + learned=args.decoder_learned_pos, + ) if not args.no_token_positional_embeddings else None + + self.layers = nn.ModuleList([]) + self.layers.extend([ + LightConvDecoderLayer(args, no_encoder_attn, kernel_size=args.decoder_kernel_size_list[i]) + for i in range(args.decoder_layers) + ]) + + self.adaptive_softmax = None + + self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \ + if embed_dim != output_embed_dim and not args.tie_adaptive_weights else None + + if args.adaptive_softmax_cutoff is not None: + self.adaptive_softmax = AdaptiveSoftmax( + len(dictionary), + output_embed_dim, + options.eval_str_list(args.adaptive_softmax_cutoff, type=int), + dropout=args.adaptive_softmax_dropout, + adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, + factor=args.adaptive_softmax_factor, + tie_proj=args.tie_adaptive_proj, + ) + elif not self.share_input_output_embed: + self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim)) + nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5) + self.register_buffer('version', torch.Tensor([2])) + self.normalize = args.decoder_normalize_before and final_norm + if self.normalize: + self.layer_norm = LayerNorm(embed_dim) + + def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for input feeding/teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + + Returns: + tuple: + - the last decoder layer's output of shape `(batch, tgt_len, + vocab)` + - the last decoder layer's attention weights of shape `(batch, + tgt_len, src_len)` + """ + # embed positions + positions = self.embed_positions( + prev_output_tokens, + incremental_state=incremental_state, + ) if self.embed_positions is not None else None + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + + inner_states = [x] + + # decoder layers + for layer in self.layers: + x, attn = layer( + x, + encoder_out['encoder_out'] if encoder_out is not None else None, + encoder_out['encoder_padding_mask'] if 
encoder_out is not None else None, + incremental_state, + ) + inner_states.append(x) + + if self.normalize: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + if self.adaptive_softmax is None: + # project back to size of vocabulary + if self.share_input_output_embed: + x = F.linear(x, self.embed_tokens.weight) + else: + x = F.linear(x, self.embed_out) + + return x, {'attn': attn, 'inner_states': inner_states} + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions()) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device: + self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) + if self._future_mask.size(0) < dim: + self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) + return self._future_mask[:dim, :dim] + + +class LightConvEncoderLayer(nn.Module): + """Encoder layer block. + + Args: + args (argparse.Namespace): parsed command-line arguments + kernel_size: kernel size of the convolution + """ + + def __init__(self, args, kernel_size=0): + super().__init__() + self.embed_dim = args.encoder_embed_dim + self.conv_dim = args.encoder_conv_dim + padding_l = kernel_size // 2 if kernel_size % 2 == 1 else ((kernel_size - 1) // 2, kernel_size // 2) + + if args.encoder_glu: + self.linear1 = Linear(self.embed_dim, 2*self.conv_dim) + self.act = nn.GLU() + else: + self.linear1 = Linear(self.embed_dim, self.conv_dim) + self.act = None + if args.encoder_conv_type == 'lightweight': + self.conv = LightweightConv1dTBC(self.conv_dim, kernel_size, padding_l=padding_l, + weight_softmax=args.weight_softmax, + num_heads=args.encoder_attention_heads, + weight_dropout=args.weight_dropout) + elif args.encoder_conv_type == 'dynamic': + self.conv = DynamicConv1dTBC(self.conv_dim, kernel_size, padding_l=padding_l, + weight_softmax=args.weight_softmax, + num_heads=args.encoder_attention_heads, + weight_dropout=args.weight_dropout) + else: + raise NotImplementedError + self.linear2 = Linear(self.conv_dim, self.embed_dim) + + self.dropout = args.dropout + self.relu_dropout = args.relu_dropout + self.input_dropout = args.input_dropout + self.normalize_before = args.encoder_normalize_before + self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim) + self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim) + self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for _ in range(2)]) + + def forward(self, x, encoder_padding_mask): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor): binary ByteTensor of shape + `(batch, src_len)` where padding elements are indicated by ``1``. 
+ + Returns: + encoded output of shape `(batch, src_len, embed_dim)` + """ + residual = x + x = self.maybe_layer_norm(0, x, before=True) + x = F.dropout(x, p=self.input_dropout, training=self.training) + x = self.linear1(x) + if self.act is not None: + x = self.act(x) + if encoder_padding_mask is not None: + x = x.masked_fill(encoder_padding_mask.transpose(0, 1).unsqueeze(2), 0) + x = self.conv(x) + x = self.linear2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(0, x, after=True) + + residual = x + x = self.maybe_layer_norm(1, x, before=True) + x = F.relu(self.fc1(x)) + x = F.dropout(x, p=self.relu_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(1, x, after=True) + return x + + def maybe_layer_norm(self, i, x, before=False, after=False): + assert before ^ after + if after ^ self.normalize_before: + return self.layer_norms[i](x) + else: + return x + + def extra_repr(self): + return 'dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}'.format( + self.dropout, self.relu_dropout, self.input_dropout, self.normalize_before) + + +class LightConvDecoderLayer(nn.Module): + """Decoder layer block. + + Args: + args (argparse.Namespace): parsed command-line arguments + no_encoder_attn (bool, optional): whether to attend to encoder outputs. + Default: ``False`` + kernel_size: kernel size of the convolution + """ + + def __init__(self, args, no_encoder_attn=False, kernel_size=0): + super().__init__() + self.embed_dim = args.decoder_embed_dim + self.conv_dim = args.decoder_conv_dim + if args.decoder_glu: + self.linear1 = Linear(self.embed_dim, 2*self.conv_dim) + self.act = nn.GLU() + else: + self.linear1 = Linear(self.embed_dim, self.conv_dim) + self.act = None + if args.decoder_conv_type == 'lightweight': + self.conv = LightweightConv1dTBC(self.conv_dim, kernel_size, padding_l=kernel_size-1, + weight_softmax=args.weight_softmax, + num_heads=args.decoder_attention_heads, + weight_dropout=args.weight_dropout) + elif args.decoder_conv_type == 'dynamic': + self.conv = DynamicConv1dTBC(self.conv_dim, kernel_size, padding_l=kernel_size-1, + weight_softmax=args.weight_softmax, + num_heads=args.decoder_attention_heads, + weight_dropout=args.weight_dropout) + else: + raise NotImplementedError + self.linear2 = Linear(self.conv_dim, self.embed_dim) + + self.dropout = args.dropout + self.relu_dropout = args.relu_dropout + self.input_dropout = args.input_dropout + self.normalize_before = args.decoder_normalize_before + + self.conv_layer_norm = LayerNorm(self.embed_dim) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = MultiheadAttention( + self.embed_dim, args.decoder_attention_heads, + dropout=args.attention_dropout, + ) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) + + self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) + self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) + + self.final_layer_norm = LayerNorm(self.embed_dim) + self.need_attn = True + + def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, + prev_conv_state=None, prev_attn_state=None, conv_mask=None, + conv_padding_mask=None): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor): binary ByteTensor of shape + `(batch, src_len)` where padding elements are indicated by ``1``. 
+ + Returns: + encoded output of shape `(batch, src_len, embed_dim)` + """ + residual = x + x = self.maybe_layer_norm(self.conv_layer_norm, x, before=True) + if prev_conv_state is not None: + if incremental_state is None: + incremental_state = {} + self.conv._set_input_buffer(incremental_state, prev_conv_state) + x = F.dropout(x, p=self.input_dropout, training=self.training) + x = self.linear1(x) + if self.act is not None: + x = self.act(x) + x = self.conv(x, incremental_state=incremental_state) + x = self.linear2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.conv_layer_norm, x, after=True) + + attn = None + if self.encoder_attn is not None: + residual = x + x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) + if prev_attn_state is not None: + if incremental_state is None: + incremental_state = {} + prev_key, prev_value = prev_attn_state + saved_state = {"prev_key": prev_key, "prev_value": prev_value} + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=(not self.training and self.need_attn), + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) + + residual = x + x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) + x = F.relu(self.fc1(x)) + x = F.dropout(x, p=self.relu_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) + return x, attn + + def maybe_layer_norm(self, layer_norm, x, before=False, after=False): + assert before ^ after + if after ^ self.normalize_before: + return layer_norm(x) + else: + return x + + def make_generation_fast_(self, need_attn=False, **kwargs): + self.need_attn = need_attn + + def extra_repr(self): + return 'dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}'.format( + self.dropout, self.relu_dropout, self.input_dropout, self.normalize_before) + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def LayerNorm(embedding_dim): + m = nn.LayerNorm(embedding_dim) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.) 
+ return m + + +def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False): + if learned: + m = LearnedPositionalEmbedding(num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + else: + m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings + padding_idx + 1) + return m + + +@register_model_architecture('lightconv_lm', 'lightconv_lm') +def base_lm_architecture(args): + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) + args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048) + args.decoder_layers = getattr(args, 'decoder_layers', 6) + args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8) + args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) + args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) + args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4) + args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False) + + args.character_embeddings = getattr(args, 'character_embeddings', False) + + args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim) + args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim) + + # The model training is not stable without this + args.decoder_normalize_before = True + + args.adaptive_input = getattr(args, 'adaptive_input', False) + args.adaptive_input_factor = getattr(args, 'adaptive_input_factor', 4) + args.adaptive_input_cutoff = getattr(args, 'adaptive_input_cutoff', None) + + args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False) + args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False) + + args.decoder_kernel_size_list = getattr(args, 'decoder_kernel_size_list', [3, 7, 15, 31, 31, 31]) + if len(args.decoder_kernel_size_list) == 1: + args.decoder_kernel_size_list = args.decoder_kernel_size_list * args.decoder_layers + + +@register_model_architecture('lightconv_lm', 'lightconv_lm_gbw') +def lightconv_lm_gbw(args): + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) + args.dropout = getattr(args, 'dropout', 0.1) + args.attention_dropout = getattr(args, 'attention_dropout', 0.1) + args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096) + args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16) + base_lm_architecture(args) + + +@register_model_architecture('lightconv', 'lightconv') +def base_architecture(args): + args.encoder_embed_path = getattr(args, 'encoder_embed_path', None) + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) + args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048) + args.encoder_layers = getattr(args, 'encoder_layers', 7) + args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8) + args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False) + args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False) + args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim) + args.decoder_layers = getattr(args, 'decoder_layers', 6) + args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8) + 
args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False) + args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False) + args.attention_dropout = getattr(args, 'attention_dropout', 0.) + args.relu_dropout = getattr(args, 'relu_dropout', 0.) + args.dropout = getattr(args, 'dropout', 0.1) + args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) + args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) + args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) + args.share_all_embeddings = getattr(args, 'share_all_embeddings', False) + args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False) + + args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim) + args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim) + + args.encoder_conv_dim = getattr(args, 'encoder_conv_dim', args.encoder_embed_dim) + args.decoder_conv_dim = getattr(args, 'decoder_conv_dim', args.decoder_embed_dim) + + args.encoder_kernel_size_list = getattr(args, 'encoder_kernel_size_list', [3, 7, 15, 31, 31, 31, 31]) + args.decoder_kernel_size_list = getattr(args, 'decoder_kernel_size_list', [3, 7, 15, 31, 31, 31]) + if len(args.encoder_kernel_size_list) == 1: + args.encoder_kernel_size_list = args.encoder_kernel_size_list * args.encoder_layers + if len(args.decoder_kernel_size_list) == 1: + args.decoder_kernel_size_list = args.decoder_kernel_size_list * args.decoder_layers + assert len(args.encoder_kernel_size_list) == args.encoder_layers, "encoder_kernel_size_list doesn't match encoder_layers" + assert len(args.decoder_kernel_size_list) == args.decoder_layers, "decoder_kernel_size_list doesn't match decoder_layers" + args.encoder_glu = getattr(args, 'encoder_glu', True) + args.decoder_glu = getattr(args, 'decoder_glu', True) + args.input_dropout = getattr(args, 'input_dropout', 0.1) + args.weight_dropout = getattr(args, 'weight_dropout', args.attention_dropout) + + +@register_model_architecture('lightconv', 'lightconv_iwslt_de_en') +def lightconv_iwslt_de_en(args): + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) + args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024) + args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4) + args.encoder_layers = getattr(args, 'encoder_layers', 7) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) + args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024) + args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4) + args.decoder_layers = getattr(args, 'decoder_layers', 6) + args.attention_dropout = getattr(args, 'attention_dropout', 0.1) + args.weight_dropout = getattr(args, 'weight_dropout', 0.1) + args.encoder_glu = getattr(args, 'encoder_glu', False) + args.decoder_glu = getattr(args, 'decoder_glu', False) + args.input_dropout = getattr(args, 'input_dropout', 0.0) + base_architecture(args) + + +@register_model_architecture('lightconv', 'lightconv_wmt_en_de') +def lightconv_wmt_en_de(args): + base_architecture(args) + + +@register_model_architecture('lightconv', 'lightconv_wmt_en_de_big') +def lightconv_wmt_en_de_big(args): + args.attention_dropout = getattr(args, 'attention_dropout', 0.1) + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) + args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096) + args.encoder_attention_heads = getattr(args, 
'encoder_attention_heads', 16) + args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) + args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096) + args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16) + args.dropout = getattr(args, 'dropout', 0.3) + base_architecture(args) + + +@register_model_architecture('lightconv', 'lightconv_wmt_en_fr_big') +def lightconv_wmt_en_fr_big(args): + args.dropout = getattr(args, 'dropout', 0.1) + lightconv_wmt_en_de_big(args) + + +@register_model_architecture('lightconv', 'lightconv_wmt_zh_en_big') +def lightconv_wmt_zh_en_big(args): + args.dropout = getattr(args, 'dropout', 0.2) + args.attention_dropout = getattr(args, 'attention_dropout', 0.2) + args.weight_dropout = getattr(args, 'weight_dropout', 0.2) + lightconv_wmt_en_de_big(args) diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index 4db8e4ffe..be0fa5518 100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -11,13 +11,16 @@ from .beamable_mm import BeamableMM from .character_token_embedder import CharacterTokenEmbedder from .conv_tbc import ConvTBC from .downsampled_multihead_attention import DownsampledMultiHeadAttention +from .dynamic_convolution import DynamicConv1dTBC from .grad_multiply import GradMultiply from .highway import Highway from .learned_positional_embedding import LearnedPositionalEmbedding +from .lightweight_convolution import LightweightConv1dTBC from .linearized_convolution import LinearizedConvolution from .multihead_attention import MultiheadAttention from .scalar_bias import ScalarBias from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding +from .unfold1d import unfold1d __all__ = [ 'AdaptiveInput', @@ -26,11 +29,14 @@ __all__ = [ 'CharacterTokenEmbedder', 'ConvTBC', 'DownsampledMultiHeadAttention', + 'DynamicConv1dTBC', 'GradMultiply', 'Highway', 'LearnedPositionalEmbedding', + 'LightweightConv1dTBC', 'LinearizedConvolution', 'MultiheadAttention', 'ScalarBias', 'SinusoidalPositionalEmbedding', + 'unfold1d', ] diff --git a/fairseq/modules/dynamic_convolution.py b/fairseq/modules/dynamic_convolution.py new file mode 100644 index 000000000..7b3e921de --- /dev/null +++ b/fairseq/modules/dynamic_convolution.py @@ -0,0 +1,227 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import utils +from fairseq.modules import unfold1d + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.) + return m + + +class DynamicConv1dTBC(nn.Module): + '''Dynamic lightweight convolution taking T x B x C inputs + Args: + input_size: # of channels of the input + kernel_size: convolution channels + padding_l: padding to the left when using "same" padding + num_heads: number of heads used. 
The weight is of shape (num_heads, 1, kernel_size) + weight_dropout: the drop rate of the DropConnect to drop the weight + weight_softmax: normalize the weight with softmax before the convolution + renorm_padding: re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1) + bias: use bias + conv_bias: bias of the convolution + query_size: specified when feeding a different input as the query + in_proj: project the input and generate the filter together + + Shape: + Input: TxBxC, i.e. (timesteps, batch_size, input_size) + Output: TxBxC, i.e. (timesteps, batch_size, input_size) + + Attributes: + weight: the learnable weights of the module of shape + `(num_heads, 1, kernel_size)` + bias: the learnable bias of the module of shape `(input_size)` + ''' + def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1, + weight_dropout=0., weight_softmax=False, + renorm_padding=False, bias=False, conv_bias=False, + query_size=None, in_proj=False): + super().__init__() + self.input_size = input_size + self.query_size = input_size if query_size is None else query_size + self.kernel_size = kernel_size + self.padding_l = padding_l + self.num_heads = num_heads + self.weight_dropout = weight_dropout + self.weight_softmax = weight_softmax + self.renorm_padding = renorm_padding + + if in_proj: + self.weight_linear = Linear(self.input_size, self.input_size + num_heads * kernel_size * 1) + else: + self.weight_linear = Linear(self.query_size, num_heads * kernel_size * 1, bias=bias) + if conv_bias: + self.conv_bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.conv_bias = None + self.reset_parameters() + + @property + def in_proj(self): + return self.weight_linear.out_features == self.input_size + self.num_heads * self.kernel_size + + def reset_parameters(self): + self.weight_linear.reset_parameters() + if self.conv_bias is not None: + nn.init.constant_(self.conv_bias, 0.) + + def forward(self, x, incremental_state=None, query=None, unfold=None): + '''Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C + args: + x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size) + incremental_state: A dict to keep the state + unfold: unfold the input or not. If not, we use the matrix trick instead + query: use the specified query to predict the conv filters + ''' + unfold = x.size(0) > 512 if unfold is None else unfold # use unfold mode as default for long sequence to save memory + unfold = unfold or (incremental_state is not None) + assert query is None or not self.in_proj + + if query is None: + query = x + + if unfold: + output = self._forward_unfolded(x, incremental_state, query) + else: + output = self._forward_expanded(x, incremental_state, query) + + if self.conv_bias is not None: + output = output + self.conv_bias.view(1, 1, -1) + return output + + def _forward_unfolded(self, x, incremental_state, query): + '''The conventional implementation of convolutions. 
+ Unfolding the input by having a window shifting to the right.''' + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + + if self.in_proj: + proj = self.weight_linear(x) + x = proj.narrow(2, 0, self.input_size).contiguous() + weight = proj.narrow(2, self.input_size, H*K).contiguous().view(T*B*H, -1) + else: + weight = self.weight_linear(query).view(T*B*H, -1) + + # renorm_padding is only implemented in _forward_expanded + assert not self.renorm_padding or incremental_state is not None + + if incremental_state is not None: + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is None: + input_buffer = x.new() + x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) + if self.kernel_size > 1: + self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:]) + x_unfold = x_unfold.view(T*B*H, R, -1) + else: + padding_l = self.padding_l + if K > T and padding_l == K-1: + weight = weight.narrow(1, K-T, T) + K, padding_l = T, T-1 + # unfold the input: T x B x C --> T' x B x C x K + x_unfold = unfold1d(x, K, padding_l, 0) + x_unfold = x_unfold.view(T*B*H, R, K) + + if self.weight_softmax and not self.renorm_padding: + weight = F.softmax(weight, dim=1) + weight = weight.narrow(1, 0, K) + + if incremental_state is not None: + weight = weight[:, -x_unfold.size(2):] + K = weight.size(1) + + if self.weight_softmax and self.renorm_padding: + weight = F.softmax(weight, dim=1) + + weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False) + + output = torch.bmm(x_unfold, weight.unsqueeze(2)) # T*B*H x R x 1 + output = output.view(T, B, C) + return output + + def _forward_expanded(self, x, incremental_stat, query): + '''Turn the convolution filters into band matrices and do matrix multiplication. + This is faster when the sequence is short, but less memory efficient. + This is not used in the decoder during inference. 
+ ''' + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + if self.in_proj: + proj = self.weight_linear(x) + x = proj.narrow(2, 0, self.input_size).contiguous() + weight = proj.narrow(2, self.input_size, H*K).contiguous().view(T*B*H, -1) + else: + weight = self.weight_linear(query).view(T*B*H, -1) + + if not self.renorm_padding: + if self.weight_softmax: + weight = F.softmax(weight, dim=1) + weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False) + weight = weight.narrow(1, 0, K).contiguous() + weight = weight.view(T, B*H, K).transpose(0, 1) + + x = x.view(T, B*H, R).transpose(0, 1) + if self.weight_softmax and self.renorm_padding: + # turn the convolution filters into band matrices + weight_expanded = weight.new(B*H, T, T+K-1).fill_(float('-inf')) + weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) + weight_expanded = weight_expanded.narrow(2, self.padding_l, T) + # normalize the weight over valid positions like self-attention + weight_expanded = F.softmax(weight_expanded, dim=2) + weight_expanded = F.dropout(weight_expanded, self.weight_dropout, training=self.training, inplace=False) + else: + P = self.padding_l + # For efficieny, we cut the kernel size and reduce the padding when the kernel is larger than the length + if K > T and P == K-1: + weight = weight.narrow(2, K-T, T) + K, P = T, T-1 + # turn the convolution filters into band matrices + weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False) + weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) + weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T + + output = torch.bmm(weight_expanded, x) + output = output.transpose(0, 1).contiguous().view(T, B, C) + return output + + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(1, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer(self, incremental_state): + return utils.get_incremental_state(self, incremental_state, 'input_buffer') + + def _set_input_buffer(self, incremental_state, new_buffer): + return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) + + def extra_repr(self): + s = '{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}'.format( + self.input_size, self.kernel_size, self.padding_l, + self.num_heads, self.weight_softmax, self.conv_bias is not None, self.renorm_padding, + self.in_proj, + ) + + if self.query_size != self.input_size: + s += ', query_size={}'.format(self.query_size) + if self.weight_dropout > 0.: + s += ', weight_dropout={}'.format(self.weight_dropout) + return s diff --git a/fairseq/modules/lightweight_convolution.py b/fairseq/modules/lightweight_convolution.py new file mode 100644 index 000000000..d19c36931 --- /dev/null +++ b/fairseq/modules/lightweight_convolution.py @@ -0,0 +1,233 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
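# ---------------------------------------------------------------------------
# Editor's note -- illustrative sketch only, not part of the patch. The two
# code paths of DynamicConv1dTBC above (_forward_unfolded and _forward_expanded)
# compute the same per-head, per-timestep convolution and differ only in the
# speed/memory trade-off. Assuming weight_softmax has already been applied and
# dropout is disabled, both reduce to the naive reference below, where every
# block of R = C // H channels inside a head shares one predicted kernel and
# positions outside [0, T) contribute zero.

import torch

def naive_dynamic_conv(x, w, padding_l):
    """x: T x B x C inputs, w: T x B x H x K per-timestep kernels."""
    T, B, C = x.size()
    H, K = w.size(2), w.size(3)
    R = C // H
    out = x.new_zeros(T, B, C)
    for t in range(T):
        for k in range(K):
            s = t - padding_l + k
            if 0 <= s < T:
                # one scalar weight per head, broadcast over its R channels
                out[t] += (w[t, :, :, k].unsqueeze(-1)
                           * x[s].view(B, H, R)).view(B, C)
    return out
# ---------------------------------------------------------------------------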
+ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import utils +from fairseq.modules import unfold1d + + +class LightweightConv1d(nn.Module): + '''Lightweight Convolution assuming the input is BxCxT + This is just an example that explains LightConv clearer than the TBC version. + We don't use this module in the model. + + Args: + input_size: # of channels of the input and output + kernel_size: convolution channels + padding: padding + num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size) + weight_softmax: normalize the weight with softmax before the convolution + Shape: + Input: BxCxT, i.e. (batch_size, input_size, timesteps) + Output: BxCxT, i.e. (batch_size, input_size, timesteps) + + Attributes: + weight: the learnable weights of the module of shape + `(num_heads, 1, kernel_size)` + bias: the learnable bias of the module of shape `(input_size)` + ''' + + def __init__(self, input_size, kernel_size=1, padding=0, num_heads=1, + weight_softmax=False, bias=False, weight_dropout=0.): + super().__init__() + self.input_size = input_size + self.kernel_size = kernel_size + self.num_heads = num_heads + self.padding = padding + self.weight_softmax = weight_softmax + self.weight = nn.Parameter(torch.Tensor(num_heads, 1, kernel_size)) + + if bias: + self.bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.bias = None + self.weight_dropout = weight_dropout + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + nn.init.constant_(self.bias, 0.) + + def forward(self, input): + ''' + input size: B x C x T + output size: B x C x T + ''' + B, C, T = input.size() + H = self.num_heads + + weight = self.weight + if self.weight_softmax: + weight = F.softmax(weight, dim=-1) + + weight = F.dropout(weight, self.weight_dropout, training=self.training) + # Merge every C/H entries into the batch dimension (C = self.input_size) + # B x C x T -> (B * C/H) x H x T + # One can also expand the weight to C x 1 x K by a factor of C/H + # and do not reshape the input instead, which is slow though + input = input.view(-1, H, T) + output = F.conv1d(input, weight, padding=self.padding, groups=self.num_heads) + output = output.view(B, C, T) + if self.bias is not None: + output = output + self.bias.view(1, -1, 1) + + return output + + +class LightweightConv1dTBC(nn.Module): + '''Lightweight Convolution assuming the input is TxBxC + Args: + input_size: # of channels of the input + kernel_size: convolution channels + padding_l: padding to the left when using "same" padding + num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size) + weight_dropout: the drop rate of the DropConnect to drop the weight + weight_softmax: normalize the weight with softmax before the convolution + bias: use bias + + Shape: + Input: TxBxC, i.e. (timesteps, batch_size, input_size) + Output: TxBxC, i.e. 
(timesteps, batch_size, input_size) + + Attributes: + weight: the learnable weights of the module of shape + `(num_heads, 1, kernel_size)` + bias: the learnable bias of the module of shape `(input_size)` + ''' + def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1, + weight_dropout=0., weight_softmax=False, bias=False): + super().__init__() + self.input_size = input_size + self.kernel_size = kernel_size + self.padding_l = padding_l + self.num_heads = num_heads + self.weight_dropout = weight_dropout + self.weight_softmax = weight_softmax + + self.weight = nn.Parameter(torch.Tensor(num_heads, 1, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + nn.init.constant_(self.bias, 0.) + + def forward(self, x, incremental_state=None, unfold=False): + '''Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C + args: + x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size) + incremental_state: A dict to keep the state + unfold: unfold the input or not. If not, we use the matrix trick instead + ''' + unfold = unfold or (incremental_state is not None) + + if unfold: + output = self._forward_unfolded(x, incremental_state) + else: + output = self._forward_expanded(x, incremental_state) + + if self.bias is not None: + output = output + self.bias.view(1, 1, -1) + return output + + def _forward_unfolded(self, x, incremental_state): + '''The conventional implementation of convolutions. + Unfolding the input by having a window shifting to the right.''' + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + + weight = self.weight.view(H, K) + if incremental_state is not None: + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is None: + input_buffer = x.new() + x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) + if self.kernel_size > 1: + self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:]) + x_unfold = x_unfold.view(T*B*H, R, -1) + else: + # unfold the input: T x B x C --> T' x B x C x K + x_unfold = unfold1d(x, self.kernel_size, self.padding_l, 0) + x_unfold = x_unfold.view(T*B*H, R, K) + + if self.weight_softmax: + weight = F.softmax(weight.float(), dim=1).type_as(weight) + + if incremental_state is not None: + weight = weight[:, -x_unfold.size(2):] + K = weight.size(1) + + weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1) + + weight = F.dropout(weight, self.weight_dropout, training=self.training) + output = torch.bmm(x_unfold, weight) # T*B*H x R x 1 + output = output.view(T, B, C) + return output + + def _forward_expanded(self, x, incremental_state): + '''Turn the convolution filters into band matrices and do matrix multiplication. + This is faster when the sequence is short, but less memory efficient. + This is not used in the decoder during inference. 
+ ''' + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + + weight = self.weight.view(H, K) + if self.weight_softmax: + weight = F.softmax(weight.float(), dim=1).type_as(weight) + weight = weight.view(1, H, K).expand(T*B, H, K).contiguous() + weight = weight.view(T, B*H, K).transpose(0, 1) + + x = x.view(T, B*H, R).transpose(0, 1) + P = self.padding_l + if K > T and P == K-1: + weight = weight.narrow(2, K-T, T) + K, P = T, T-1 + # turn the convolution filters into band matrices + weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False) + weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) + weight_expanded = weight_expanded.narrow(2, P, T) + weight_expanded = F.dropout(weight_expanded, self.weight_dropout, training=self.training) + + output = torch.bmm(weight_expanded, x) + output = output.transpose(0, 1).contiguous().view(T, B, C) + return output + + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(1, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer(self, incremental_state): + return utils.get_incremental_state(self, incremental_state, 'input_buffer') + + def _set_input_buffer(self, incremental_state, new_buffer): + return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) + + def extra_repr(self): + s = '{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, bias={}'.format( + self.input_size, self.kernel_size, self.padding_l, + self.num_heads, self.weight_softmax, self.bias is not None + ) + if self.weight_dropout > 0.: + s += ', weight_dropout={}'.format(self.weight_dropout) + return s diff --git a/fairseq/modules/unfold1d.py b/fairseq/modules/unfold1d.py new file mode 100644 index 000000000..fde3b4f03 --- /dev/null +++ b/fairseq/modules/unfold1d.py @@ -0,0 +1,19 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
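# ---------------------------------------------------------------------------
# Editor's note -- illustrative sketch only, not part of the patch. unfold1d
# (defined just below and exported from fairseq.modules) expands a T x B x C
# tensor into T x B x C x K sliding windows: entry [t, :, :, k] reads
# x[t - padding_l + k], with out-of-range positions filled with pad_value.
# A small standalone self-check of that reading, assuming fairseq with this
# patch applied is importable:

import torch
from fairseq.modules import unfold1d

T, B, C, K, padding_l = 5, 2, 4, 3, 2     # padding_l = K - 1 gives causal windows
x = torch.randn(T, B, C)
u = unfold1d(x, K, padding_l, 0)          # T x B x C x K
for t in range(T):
    for k in range(K):
        s = t - padding_l + k
        expected = x[s] if 0 <= s < T else torch.zeros(B, C)
        assert torch.equal(u[t, :, :, k], expected)
# ---------------------------------------------------------------------------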
+ +import torch.nn.functional as F + + +def unfold1d(x, kernel_size, padding_l, pad_value=0): + '''unfold T x B x C to T x B x C x K''' + if kernel_size > 1: + T, B, C = x.size() + x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value) + x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C)) + else: + x = x.unsqueeze(3) + return x diff --git a/scripts/compound_split_bleu.sh b/scripts/compound_split_bleu.sh new file mode 100644 index 000000000..958c5d126 --- /dev/null +++ b/scripts/compound_split_bleu.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +if [ $# -ne 1 ]; then + echo "usage: $0 GENERATE_PY_OUTPUT" + exit 1 +fi + +GEN=$1 + +SYS=$GEN.sys +REF=$GEN.ref + +if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then + echo "not done generating" + exit +fi + +grep ^H $GEN | cut -f3- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS +grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF +python score.py --sys $SYS --ref $REF diff --git a/scripts/sacrebleu_pregen.sh b/scripts/sacrebleu_pregen.sh new file mode 100755 index 000000000..2599b94d6 --- /dev/null +++ b/scripts/sacrebleu_pregen.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +if [ $# -ne 4 ]; then + echo "usage: $0 TESTSET SRCLANG TGTLANG GEN" + exit 1 +fi + +TESTSET=$1 +SRCLANG=$2 +TGTLANG=$3 + +GEN=$4 + +echo 'Cloning Moses github repository (for tokenization scripts)...' +git clone https://github.com/moses-smt/mosesdecoder.git + +SCRIPTS=mosesdecoder/scripts +DETOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl + +grep ^H $GEN \ +| sed 's/^H\-//' \ +| sort -n -k 1 \ +| cut -f 3 \ +| perl $DETOKENIZER -l $TGTLANG \ +| sed "s/ - /-/g" \ +> $GEN.sorted.detok + +sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 406876e2c..6168ebc96 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -138,6 +138,28 @@ class TestTranslation(unittest.TestCase): train_translation_model(data_dir, 'transformer_iwslt_de_en') generate_main(data_dir) + def test_lightconv(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_lightconv') as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model(data_dir, 'lightconv_iwslt_de_en', [ + '--encoder-conv-type', 'lightweight', + '--decoder-conv-type', 'lightweight', + ]) + generate_main(data_dir) + + def test_dynamicconv(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_dynamicconv') as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model(data_dir, 'lightconv_iwslt_de_en', [ + '--encoder-conv-type', 'dynamic', + '--decoder-conv-type', 'dynamic', + ]) + generate_main(data_dir) + class TestStories(unittest.TestCase): diff --git a/tests/test_train.py b/tests/test_train.py index cfffc3ed5..d86db8d1c 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -62,6 +62,7 @@ class TestLoadCheckpoint(unittest.TestCase): 'os.makedirs': MagicMock(), 'os.path.join': MagicMock(), 'os.path.isfile': MagicMock(return_value=True), + 'os.path.isabs': MagicMock(return_value=False), } self.applied_patches = [patch(p, d) for p, d in self.patches.items()] [p.start() for p in self.applied_patches] diff --git a/train.py b/train.py index b4c3aa368..d5e42cd6c 100644 --- a/train.py +++ b/train.py @@ -328,7 +328,10 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): def load_checkpoint(args, 
trainer, epoch_itr): """Load a checkpoint and replay dataloader to match.""" os.makedirs(args.save_dir, exist_ok=True) - checkpoint_path = os.path.join(args.save_dir, args.restore_file) + if os.path.isabs(args.restore_file): + checkpoint_path = args.restore_file + else: + checkpoint_path = os.path.join(args.save_dir, args.restore_file) if os.path.isfile(checkpoint_path): extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler, eval(args.optimizer_overrides)) @@ -344,6 +347,8 @@ def load_checkpoint(args, trainer, epoch_itr): if 'best' in extra_state: save_checkpoint.best = extra_state['best'] return True + else: + print('| no existing checkpoint found {}'.format(checkpoint_path)) return False
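
Editor's note: the final train.py hunk above lets --restore-file be given as an absolute path instead of always being joined onto --save-dir. A minimal sketch of the resolution rule it implements (the file names below are only examples, POSIX-style paths assumed):

import os

def resolve_restore_file(save_dir, restore_file):
    # absolute paths are used as-is; bare names are looked up inside save_dir
    if os.path.isabs(restore_file):
        return restore_file
    return os.path.join(save_dir, restore_file)

assert resolve_restore_file('/ckpts', 'checkpoint_last.pt') == '/ckpts/checkpoint_last.pt'
assert resolve_restore_file('/ckpts', '/elsewhere/checkpoint_best.pt') == '/elsewhere/checkpoint_best.pt'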
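
A final editor's note, as a quick usage sketch (not part of the patch, and assuming fairseq with this change installed): both new convolution modules consume and produce T x B x C tensors. LightweightConv1dTBC learns one softmax-normalized kernel per head, shared by the C // H channels of that head, while DynamicConv1dTBC predicts a kernel from the input at every timestep. The padding choices mirror LightConvEncoderLayer (kernel_size // 2 for odd kernels) and LightConvDecoderLayer (kernel_size - 1, i.e. causal):

import torch
from fairseq.modules import LightweightConv1dTBC, DynamicConv1dTBC

T, B, C, H, K = 10, 2, 16, 4, 3
x = torch.randn(T, B, C)                    # (timesteps, batch, channels)

# fixed learned kernels, shared by the C // H channels of each head
light = LightweightConv1dTBC(C, kernel_size=K, padding_l=K // 2,
                             num_heads=H, weight_softmax=True)

# kernels predicted from the input at every timestep, causal padding
dyn = DynamicConv1dTBC(C, kernel_size=K, padding_l=K - 1,
                       num_heads=H, weight_softmax=True)

assert light(x).shape == (T, B, C)
assert dyn(x).shape == (T, B, C)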