Add code for "Pay Less Attention with Lightweight and Dynamic Convolutions" (#473)

Summary:
Changelog:
- `e330f56`: Add code for the "Pay Less Attention with Lightweight and Dynamic Convolutions" paper
- `5e3b98c`: Add scripts for computing tokenized BLEU with compound splitting and sacrebleu
- update READMEs
- misc fixes
Pull Request resolved: https://github.com/pytorch/fairseq/pull/473

Differential Revision: D13819717

Pulled By: myleott

fbshipit-source-id: f2dc12ea89a436b950cafec3593ed1b04af808e9
This commit is contained in:
Myle Ott 2019-01-25 15:35:01 -08:00 committed by Facebook Github Bot
parent bc8ae44932
commit b41c74dc5b
21 changed files with 1832 additions and 114 deletions

View File

@ -8,14 +8,16 @@ of various sequence-to-sequence models, including:
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](https://arxiv.org/abs/1612.08083)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://arxiv.org/abs/1711.04956)
- **_New_** [Fan et al. (2018): Hierarchical Neural Story Generation](https://arxiv.org/abs/1805.04833)
- [Fan et al. (2018): Hierarchical Neural Story Generation](https://arxiv.org/abs/1805.04833)
- **LightConv and DynamicConv models**
- **_New_** [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](https://openreview.net/pdf?id=SkVhlh09tX)
- **Long Short-Term Memory (LSTM) networks**
- [Luong et al. (2015): Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)
- [Wiseman and Rush (2016): Sequence-to-Sequence Learning as Beam-Search Optimization](https://arxiv.org/abs/1606.02960)
- **Transformer (self-attention) networks**
- [Vaswani et al. (2017): Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- **_New_** [Ott et al. (2018): Scaling Neural Machine Translation](https://arxiv.org/abs/1806.00187)
- **_New_** [Edunov et al. (2018): Understanding Back-Translation at Scale](https://arxiv.org/abs/1808.09381)
- [Ott et al. (2018): Scaling Neural Machine Translation](https://arxiv.org/abs/1806.00187)
- [Edunov et al. (2018): Understanding Back-Translation at Scale](https://arxiv.org/abs/1808.09381)
Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
@ -27,7 +29,7 @@ Fairseq features:
- fast half-precision floating point (FP16) training
- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers
We also provide [pre-trained models](#pre-trained-models) for several benchmark
We also provide [pre-trained models](#pre-trained-models-and-examples) for several benchmark
translation and language modeling datasets.
![Model](fairseq.gif)
@ -55,73 +57,27 @@ The [full documentation](https://fairseq.readthedocs.io/) contains instructions
for getting started, training new models and extending fairseq with new model
types and tasks.
# Pre-trained Models
# Pre-trained models and examples
We provide the following pre-trained models and pre-processed, binarized test sets:
We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
as well as example training and evaluation commands.
### Translation
- [Translation](examples/translation/README.md): convolutional and transformer models are available
- [Language Modeling](examples/language_model/README.md): convolutional models are available
Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.en-de.newstest2014.tar.bz2)
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt17.v2.en-de.newstest2014.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt18.en-de.ensemble.tar.bz2) | See NOTE in the archive
### Language models
Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/gbw_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/gbw_test_lm.tar.bz2)
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wiki103_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wiki103_test_lm.tar.bz2)
### Stories
Description | Dataset | Model | Test set(s)
---|---|---|---
Stories with Convolutional Model <br> ([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://arxiv.org/abs/1805.04833) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/stories_test.tar.bz2)
### Usage
Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti:
```
$ curl https://s3.amazonaws.com/fairseq-py/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin
$ curl https://s3.amazonaws.com/fairseq-py/data/wmt14.v2.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin
$ python generate.py data-bin/wmt14.en-fr.newstest2014 \
--path data-bin/wmt14.en-fr.fconv-py/model.pt \
--beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
...
| Translated 3003 sentences (96311 tokens) in 166.0s (580.04 tokens/s)
| Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
# Scoring with score.py:
$ grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys
$ grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref
$ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
```
We also have more detailed READMEs to reproduce results from specific papers:
- [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/conv_lm/README.md)
# Join the fairseq community
* Facebook page: https://www.facebook.com/groups/fairseq.users
* Google group: https://groups.google.com/forum/#!forum/fairseq-users
# Citation
If you use the code in your paper, then please cite it as:
```
@inproceedings{gehring2017convs2s,
author = {Gehring, Jonas, and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N},
title = "{Convolutional Sequence to Sequence Learning}",
booktitle = {Proc. of ICML},
year = 2017,
}
```
# License
fairseq(-py) is BSD-licensed.
The license applies to the pre-trained models as well.

View File

@ -5,7 +5,7 @@ First, download a pre-trained model along with its vocabularies:
.. code-block:: console
> curl https://s3.amazonaws.com/fairseq-py/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
> curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
This model uses a `Byte Pair Encoding (BPE)
vocabulary <https://arxiv.org/abs/1508.07909>`__, so we'll have to apply

View File

@ -26,7 +26,7 @@ of the data that is already tokenized into characters and split into separate
train, valid and test sets.
Download and extract the data from here:
`tutorial_names.tar.gz <https://s3.amazonaws.com/fairseq-py/data/tutorial_names.tar.gz>`_
`tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_
Once extracted, let's preprocess the data using the :ref:`preprocess.py`
command-line tool to create the dictionaries. While this tool is primarily

View File

@ -0,0 +1,19 @@
# Understanding Back-Translation at Scale (Edunov et al., 2018)
This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381).
## Pre-trained models
Description | Dataset | Model | Test set(s)
---|---|---|---
Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.bz2) | See NOTE in the archive
## Citation
```bibtex
@inproceedings{edunov2018backtranslation,
title = {Understanding Back-Translation at Scale},
author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
booktitle = {Conference of the Association for Computational Linguistics (ACL)},
year = 2018,
}
```

View File

@ -0,0 +1,26 @@
# Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)
## Pre-trained models
Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/gbw_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/gbw_test_lm.tar.bz2)
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wiki103_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wiki103_test_lm.tar.bz2)
## Example usage
See the [language modeling README](../language_model/README.md) for instructions on reproducing results for WikiText-103
using the `fconv_lm_dauphin_wikitext103` model architecture.
## Citation
```bibtex
@inproceedings{dauphin2017language,
title={Language Modeling with Gated Convolutional Networks},
author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David},
booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70},
pages={933--941},
year={2017},
organization={JMLR}
}
```

View File

@ -0,0 +1,25 @@
# Convolutional Sequence to Sequence Learning (Gehring et al., 2017)
## Pre-trained models
Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2)
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
## Example usage
See the [translation README](../translation/README.md) for instructions on reproducing results for WMT'14 En-De and
WMT'14 En-Fr using the `fconv_wmt_en_de` and `fconv_wmt_en_fr` model architectures.
## Citation
```bibtex
@inproceedings{gehring2017convs2s,
title = {Convolutional Sequence to Sequence Learning},
author = {Gehring, Jonas, and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N},
booktitle = {Proc. of ICML},
year = 2017,
}
```

View File

@ -1,8 +1,17 @@
Sample data processing scripts for the FAIR Sequence-to-Sequence Toolkit
# Neural Language Modeling
## Pre-trained models
Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/gbw_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/gbw_test_lm.tar.bz2)
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wiki103_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wiki103_test_lm.tar.bz2)
## Example usage
These scripts provide an example of pre-processing data for the Language Modeling task.
# prepare-wikitext-103.sh
### prepare-wikitext-103.sh
Provides an example of pre-processing for [WikiText-103 language modeling task](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/):

View File

@ -0,0 +1,116 @@
# Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)
This page contains pointers to pre-trained models as well as instructions on how to train new models for [our paper](https://openreview.net/pdf?id=SkVhlh09tX)
## Citation:
```bibtex
@inproceedings{wu2018pay,
title = {Pay Less Attention with Lightweight and Dynamic Convolutions},
author = {Felix Wu and Angela Fan and Alexei Baevski and Yann Dauphin and Michael Auli},
booktitle = {International Conference on Learning Representations},
year = {2019},
url = {https://openreview.net/forum?id=SkVhlh09tX},
}
```
## Translation
### Pre-trained models
For some datasets we release models without GLUs which are faster at inference.
Description | Dataset | Model | Test set(s)
---|---|---|---
LightConv (without GLUs) | [IWSLT14 German-English](https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.lightconv.tar.bz2) | IWSLT14 test: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/iwslt14.de-en.test.tar.bz2)
DynamicConv (without GLUs) | [IWSLT14 German-English](https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.dynamicconv.tar.bz2) | IWSLT14 test: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/iwslt14.de-en.test.tar.bz2)
LightConv (without GLUs) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
DynamicConv (without GLUs) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
LightConv | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
DynamicConv | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
LightConv | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.lightconv-glu.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
DynamicConv | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.dynamicconv-glu.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
LightConv | [WMT17 Chinese-English](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.lightconv-glu.tar.bz2) | newstest2017: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.zh-en.newstest2017.tar.bz2)
DynamicConv | [WMT17 Chinese-English](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.dynamicconv-glu.tar.bz2) | newstest2017: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.zh-en.newstest2017.tar.bz2)
### Preprocessing the training datasets
Please follow the instructions in [`examples/translation/README.md`](../translation/README.md) to preprocess the data.
### Training and evaluation options:
To use the model without GLU, please set `--encoder-glu 0 --decoder-glu 0`.
For LightConv, please use `--encoder-conv-type lightweight --decoder-conv-type lightweight`, otherwise the default is DynamicConv.
For best BLEU results, lenpen may need to be manually tuned.
### IWSLT14 De-En
Training and evaluating DynamicConv (without GLU) on a GPU:
```sh
# Training
SAVE="save/dynamic_conv_iwslt"
mkdir -p $SAVE
CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \
--clip-norm 0 --optimizer adam --lr 0.0005 \
--source-lang de --target-lang en --max-tokens 4000 --no-progress-bar \
--log-interval 100 --min-lr '1e-09' --weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--lr-scheduler inverse_sqrt \
--ddp-backend=no_c10d \
--max-update 50000 --warmup-updates 4000 --warmup-init-lr '1e-07' \
--adam-betas '(0.9, 0.98)' --keep-last-epochs 10 \
-a lightconv_iwslt_de_en --save-dir $SAVE \
--dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \
--encoder-glu 0 --decoder-glu 0
python scripts/average_checkpoints.py --inputs $SAVE \
--num-epoch-checkpoints 10 --output "${SAVE}/checkpoint_last10_avg.pt"
# Evaluation
CUDA_VISIBLE_DEVICES=0 python generate.py data-bin/iwslt14.tokenized.de-en --path "${SAVE}/checkpoint_last10_avg.pt" --batch-size 128 --beam 4 --remove-bpe --lenpen 1 --gen-subset test --quiet
```
### WMT16 En-De
Training and evaluating DynamicConv (with GLU) on WMT16 En-De using cosine scheduler on one machine with 8 V100 GPUs:
```sh
# Training
SAVE="save/dynamic_conv_wmt16en2de"
mkdir -p $SAVE
python -m torch.distributed.launch --nproc_per_node 8 train.py \
data-bin/wmt16_en_de_bpe32k --fp16 --log-interval 100 --no-progress-bar \
--max-update 30000 --share-all-embeddings --optimizer adam \
--adam-betas '(0.9, 0.98)' --lr-scheduler inverse_sqrt \
--clip-norm 0.0 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \
--ddp-backend=no_c10d --max-tokens 3584 \
--lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \
--lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 --warmup-init-lr 1e-07 \
--t-mult 1 --lr-period-updates 20000 \
--arch lightconv_wmt_en_de_big --save-dir $SAVE \
--dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \
--encoder-glu 1 --decoder-glu 1
# Evaluation
CUDA_VISIBLE_DEVICES=0 python generate.py data-bin/wmt16.en-de.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.5 --gen-subset test > wmt16_gen.txt
bash scripts/compound_split_bleu.sh wmt16_gen.txt
```
### WMT14 En-Fr
Training DynamicConv (with GLU) on WMT14 En-Fr using cosine scheduler on one machine with 8 V100 GPUs:
```sh
# Training
SAVE="save/dynamic_conv_wmt14en2fr"
mkdir -p $SAVE
python -m torch.distributed.launch --nproc_per_node 8 train.py \
data-bin/wmt14_en_fr --fp16 --log-interval 100 --no-progress-bar \
--max-update 30000 --share-all-embeddings --optimizer adam \
--adam-betas '(0.9, 0.98)' --lr-scheduler inverse_sqrt \
--clip-norm 0.0 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \
--ddp-backend=no_c10d --max-tokens 3584 \
--lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \
--lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 --warmup-init-lr 1e-07 \
--t-mult 1 --lr-period-updates 70000 \
--arch lightconv_wmt_en_fr_big --save-dir $SAVE \
--dropout 0.1 --attention-dropout 0.1 --weight-dropout 0.1 \
--encoder-glu 1 --decoder-glu 1
# Evaluation
CUDA_VISIBLE_DEVICES=0 python generate.py data-bin/wmt14.en-fr.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.9 --gen-subset test
```

View File

@ -0,0 +1,62 @@
# Scaling Neural Machine Translation (Ott et al., 2018)
This page includes instructions for reproducing results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187).
## Pre-trained models
Description | Dataset | Model | Test set(s)
---|---|---|---
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
## Training a new model on WMT'16 En-De
Please first download the [preprocessed WMT'16 En-De data provided by Google](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8).
Then:
1. Extract the WMT'16 En-De data:
```
$ TEXT=wmt16_en_de_bpe32k
$ mkdir $TEXT
$ tar -xzvf wmt16_en_de.tar.gz -C $TEXT
```
2. Preprocess the dataset with a joined dictionary:
```
$ python preprocess.py --source-lang en --target-lang de \
--trainpref $TEXT/train.tok.clean.bpe.32000 \
--validpref $TEXT/newstest2013.tok.bpe.32000 \
--testpref $TEXT/newstest2014.tok.bpe.32000 \
--destdir data-bin/wmt16_en_de_bpe32k \
--nwordssrc 32768 --nwordstgt 32768 \
--joined-dictionary
```
3. Train a model:
```
$ python train.py data-bin/wmt16_en_de_bpe32k \
--arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
--lr 0.0005 --min-lr 1e-09 \
--dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 3584 \
--fp16
```
Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU.
If you want to train the above model with big batches (assuming your machine has 8 GPUs):
- add `--update-freq 16` to simulate training on 8*16=128 GPUs
- increase the learning rate; 0.001 works well for big batches
## Citation
```bibtex
@inproceedings{ott2018scaling,
title = {Scaling Neural Machine Translation},
author = {Ott, Myle and Edunov, Sergey and Grangier, David and Auli, Michael},
booktitle = {Proceedings of the Third Conference on Machine Translation (WMT)},
year = 2018,
}
```

View File

@ -1,18 +1,28 @@
FAIR Sequence-to-Sequence Toolkit for Story Generation
# Hierarchical Neural Story Generation (Fan et al., 2018)
The following commands provide an example of pre-processing data, training a model, and generating text for story generation with the WritingPrompts dataset.
## Pre-trained models
Description | Dataset | Model | Test set(s)
---|---|---|---
Stories with Convolutional Model <br> ([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://arxiv.org/abs/1805.04833) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2)
## Dataset
The dataset can be downloaded like this:
```
cd examples/stories
curl https://s3.amazonaws.com/fairseq-py/data/writingPrompts.tar.gz | tar xvzf -
curl https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz | tar xvzf -
```
and contains a train, test, and valid split. The dataset is described here: https://arxiv.org/abs/1805.04833. We model only the first 1000 words of each story, including one newLine token.
Example usage:
## Example usage
```
# Preprocess the dataset:
# Note that the dataset release is the full data, but the paper models the first 1000 words of each story
@ -44,3 +54,13 @@ $ python train.py data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --clip-
$ python generate.py data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt --batch-size 32 --beam 1 --sampling --sampling-topk 10 --sampling-temperature 0.8 --nbest 1 --model-overrides "{'pretrained_checkpoint':'/path/to/pretrained/model/checkpoint'}"
```
## Citation
```bibtex
@inproceedings{fan2018hierarchical,
title = {Hierarchical Neural Story Generation},
author = {Fan, Angela and Lewis, Mike and Dauphin, Yann},
booktitle = {Conference of the Association for Computational Linguistics (ACL)},
year = 2018,
}
```

View File

@ -1,10 +1,41 @@
# Example usage for Neural Machine Translation
# Neural Machine Translation
These scripts provide an example of pre-processing data for the NMT task
and instructions for how to replicate the results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187).
## Pre-trained models
Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2)
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.bz2) | See NOTE in the archive
## Example usage
Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti:
```
$ mkdir -p data-bin
$ curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin
$ curl https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin
$ python generate.py data-bin/wmt14.en-fr.newstest2014 \
--path data-bin/wmt14.en-fr.fconv-py/model.pt \
--beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
...
| Translated 3003 sentences (96311 tokens) in 166.0s (580.04 tokens/s)
| Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
# Scoring with score.py:
$ grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys
$ grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref
$ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
```
## Preprocessing
These scripts provide an example of pre-processing data for the NMT task.
### prepare-iwslt14.sh
Provides an example of pre-processing for IWSLT'14 German to English translation task: ["Report on the 11th IWSLT evaluation campaign" by Cettolo et al.](http://workshop2014.iwslt.org/downloads/proceeding.pdf)
@ -64,9 +95,10 @@ $ python generate.py data-bin/iwslt14.tokenized.de-en \
### prepare-wmt14en2de.sh
Provides an example of pre-processing for the WMT'14 English to German translation task. By default it will produce a dataset that was modeled after ["Attention Is All You Need" by Vaswani et al.](https://arxiv.org/abs/1706.03762) that includes news-commentary-v12 data.
The WMT English to German dataset can be preprocessed using the `prepare-wmt14en2de.sh` script.
By default it will produce a dataset that was modeled after ["Attention Is All You Need" (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762), but with news-commentary-v12 data from WMT'17.
To use only data available in WMT'14 or to replicate results obtained in the original paper ["Convolutional Sequence to Sequence Learning" by Gehring et al.](https://arxiv.org/abs/1705.03122) run it with --icml17 instead:
To use only data available in WMT'14 or to replicate results obtained in the original ["Convolutional Sequence to Sequence Learning" (Gehring et al., 2017)](https://arxiv.org/abs/1705.03122) paper, please use the `--icml17` option.
```
$ bash prepare-wmt14en2de.sh --icml17
@ -131,42 +163,3 @@ $ python generate.py data-bin/fconv_wmt_en_fr \
--path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt --beam 5 --remove-bpe
```
## Replicating results from "Scaling Neural Machine Translation"
To replicate results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187),
please first download the [preprocessed WMT'16 En-De data provided by Google](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8).
1. Extract the WMT'16 En-De data:
```
$ TEXT=wmt16_en_de_bpe32k
$ mkdir $TEXT
$ tar -xzvf wmt16_en_de.tar.gz -C $TEXT
```
2. Preprocess the dataset with a joined dictionary:
```
$ python preprocess.py --source-lang en --target-lang de \
--trainpref $TEXT/train.tok.clean.bpe.32000 \
--validpref $TEXT/newstest2013.tok.bpe.32000 \
--testpref $TEXT/newstest2014.tok.bpe.32000 \
--destdir data-bin/wmt16_en_de_bpe32k \
--nwordssrc 32768 --nwordstgt 32768 \
--joined-dictionary
```
3. Train a model:
```
$ python train.py data-bin/wmt16_en_de_bpe32k \
--arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
--lr 0.0005 --min-lr 1e-09 \
--dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 3584 \
--fp16
```
Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU.
If you want to train the above model with big batches (assuming your machine has 8 GPUs):
- add `--update-freq 16` to simulate training on 8*16=128 GPUs
- increase the learning rate; 0.001 works well for big batches

931
fairseq/models/lightconv.py Normal file
View File

@ -0,0 +1,931 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options
from fairseq import utils
from fairseq.modules import (
AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding, DynamicConv1dTBC, LightweightConv1dTBC
)
from . import (
FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel, FairseqModel, register_model,
register_model_architecture,
)
@register_model('lightconv')
class LightConvModel(FairseqModel):
"""
LightConv and DynamicConv model from `"Pay Less Attention with Lightweight and Dynamic Convolutions" (Wu, et al, 2019)
<https://openreview.net/pdf?id=SkVhlh09tX>`_.
To use LightConv please set --encoder-conv-type lightweight --decoder-conv-type lightweight
To use DynamicConv please set --encoder-conv-type dynamic --decoder-conv-type dynamic
Args:
encoder (LightConvEncoder): the encoder
decoder (LightConvDecoder): the decoder
The LightConv model provides the following named architectures and
command-line arguments:
.. argparse::
:ref: fairseq.models.lightconv_parser
:prog:
"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--relu-dropout', type=float, metavar='D',
help='dropout probability after ReLU in FFN')
parser.add_argument('--input-dropout', type=float, metavar='D',
help='dropout probability of the inputs')
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
help='path to pre-trained encoder embedding')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-conv-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads or LightConv/DynamicConv heads')
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-conv-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads or LightConv/DynamicConv heads')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
"""LightConv and DynamicConv arguments"""
parser.add_argument('--encoder-kernel-size-list', type=lambda x: options.eval_str_list(x, int),
help='list of kernel size (default: "[3,7,15,31,31,31,31]")')
parser.add_argument('--decoder-kernel-size-list', type=lambda x: options.eval_str_list(x, int),
help='list of kernel size (default: "[3,7,15,31,31,31]")')
parser.add_argument('--encoder-glu', type=options.eval_bool,
help='glu after in proj')
parser.add_argument('--decoder-glu', type=options.eval_bool,
help='glu after in proj')
parser.add_argument('--encoder-conv-type', default='dynamic', type=str,
choices=['dynamic', 'lightweight'],
help='type of convolution')
parser.add_argument('--decoder-conv-type', default='dynamic', type=str,
choices=['dynamic', 'lightweight'],
help='type of convolution')
parser.add_argument('--weight-softmax', default=True, type=options.eval_bool)
parser.add_argument('--weight-dropout', type=float, metavar='D',
help='dropout probability for conv weights')
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = 1024
if not hasattr(args, 'max_target_positions'):
args.max_target_positions = 1024
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim, path=None):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
emb = Embedding(num_embeddings, embed_dim, padding_idx)
# if provided, load from preloaded dictionaries
if path:
embed_dict = utils.parse_embedding(path)
utils.load_embedding(embed_dict, dictionary, emb)
return emb
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise RuntimeError('--share-all-embeddings requires a joined dictionary')
if args.encoder_embed_dim != args.decoder_embed_dim:
raise RuntimeError(
'--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path):
raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path')
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = build_embedding(
tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
)
encoder = LightConvEncoder(args, src_dict, encoder_embed_tokens)
decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens)
return LightConvModel(encoder, decoder)
@register_model('lightconv_lm')
class LightConvLanguageModel(FairseqLanguageModel):
def __init__(self, decoder):
super().__init__(decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--relu-dropout', default=0., type=float, metavar='D',
help='dropout probability after ReLU in FFN')
parser.add_argument('--input-dropout', type=float, metavar='D',
help='dropout probability of the inputs')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-output-dim', type=int, metavar='N',
help='decoder output dimension')
parser.add_argument('--decoder-input-dim', type=int, metavar='N',
help='decoder input dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads or LightConv/DynamicConv heads')
parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion')
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--adaptive-softmax-factor', type=float, metavar='N',
help='adaptive input factor')
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
help='if set, disables positional embeddings (outside self attention)')
parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--character-embeddings', default=False, action='store_true',
help='if set, uses character embedding convolutions to produce token embeddings')
parser.add_argument('--character-filters', type=str, metavar='LIST',
default='[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]',
help='size of character embeddings')
parser.add_argument('--character-embedding-dim', type=int, metavar='N', default=4,
help='size of character embeddings')
parser.add_argument('--char-embedder-highway-layers', type=int, metavar='N', default=2,
help='number of highway layers for character token embeddder')
parser.add_argument('--adaptive-input', default=False, action='store_true',
help='if set, uses adaptive input')
parser.add_argument('--adaptive-input-factor', type=float, metavar='N',
help='adaptive input factor')
parser.add_argument('--adaptive-input-cutoff', metavar='EXPR',
help='comma separated list of adaptive input cutoff points.')
parser.add_argument('--tie-adaptive-weights', action='store_true',
help='if set, ties the weights of adaptive softmax and adaptive input')
parser.add_argument('--tie-adaptive-proj', action='store_true',
help='if set, ties the projection weights of adaptive softmax and adaptive input')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
"""LightConv and DynamicConv arguments"""
parser.add_argument('--decoder-kernel-size-list', type=lambda x: options.eval_str_list(x, int),
help='list of kernel size (default: "[3,7,15,31,31,31]")')
parser.add_argument('--decoder-glu', type=options.eval_bool,
help='glu after in proj')
parser.add_argument('--decoder-conv-type', default='dynamic', type=str,
choices=['dynamic', 'lightweight'],
help='type of convolution')
parser.add_argument('--weight-softmax', default=True, type=options.eval_bool)
parser.add_argument('--weight-dropout', type=float, metavar='D',
help='dropout probability for conv weights')
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_lm_architecture(args)
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = args.tokens_per_sample
if not hasattr(args, 'max_target_positions'):
args.max_target_positions = args.tokens_per_sample
if args.character_embeddings:
embed_tokens = CharacterTokenEmbedder(task.dictionary, eval(args.character_filters),
args.character_embedding_dim,
args.decoder_embed_dim,
args.char_embedder_highway_layers,
)
elif args.adaptive_input:
embed_tokens = AdaptiveInput(len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
args.adaptive_input_factor, args.decoder_embed_dim,
options.eval_str_list(args.adaptive_input_cutoff, type=int))
else:
embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
if args.tie_adaptive_weights:
assert args.adaptive_input
assert args.adaptive_input_factor == args.adaptive_softmax_factor
assert args.adaptive_softmax_cutoff == args.adaptive_input_cutoff, '{} != {}'.format(
args.adaptive_softmax_cutoff, args.adaptive_input_cutoff)
assert args.decoder_input_dim == args.decoder_output_dim
decoder = LightConvDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
return LightConvLanguageModel(decoder)
class LightConvEncoder(FairseqEncoder):
"""
LightConv encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`LightConvEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
left_pad (bool, optional): whether the input is left-padded. Default:
``True``
"""
def __init__(self, args, dictionary, embed_tokens, left_pad=True):
super().__init__(dictionary)
self.dropout = args.dropout
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = args.max_source_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
args.max_source_positions, embed_dim, self.padding_idx,
left_pad=left_pad,
learned=args.encoder_learned_pos,
) if not args.no_token_positional_embeddings else None
self.layers = nn.ModuleList([])
self.layers.extend([
LightConvEncoderLayer(args, kernel_size=args.encoder_kernel_size_list[i])
for i in range(args.encoder_layers)
])
self.register_buffer('version', torch.Tensor([2]))
self.normalize = args.encoder_normalize_before
if self.normalize:
self.layer_norm = LayerNorm(embed_dim)
def forward(self, src_tokens, src_lengths):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
"""
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(src_tokens)
if self.embed_positions is not None:
x += self.embed_positions(src_tokens)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
if not encoder_padding_mask.any():
encoder_padding_mask = None
# encoder layers
for layer in self.layers:
x = layer(x, encoder_padding_mask)
if self.normalize:
x = self.layer_norm(x)
return {
'encoder_out': x, # T x B x C
'encoder_padding_mask': encoder_padding_mask, # B x T
}
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if encoder_out['encoder_out'] is not None:
encoder_out['encoder_out'] = \
encoder_out['encoder_out'].index_select(1, new_order)
if encoder_out['encoder_padding_mask'] is not None:
encoder_out['encoder_padding_mask'] = \
encoder_out['encoder_padding_mask'].index_select(0, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
if self.embed_positions is None:
return self.max_source_positions
return min(self.max_source_positions, self.embed_positions.max_positions())
class LightConvDecoder(FairseqIncrementalDecoder):
"""
LightConv decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`LightConvDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs.
Default: ``False``
left_pad (bool, optional): whether the input is left-padded. Default:
``False``
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
super().__init__(dictionary)
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
input_embed_dim = embed_tokens.embedding_dim
embed_dim = args.decoder_embed_dim
output_embed_dim = args.decoder_output_dim
padding_idx = embed_tokens.padding_idx
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim
self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None
self.embed_positions = PositionalEmbedding(
args.max_target_positions, embed_dim, padding_idx,
left_pad=left_pad,
learned=args.decoder_learned_pos,
) if not args.no_token_positional_embeddings else None
self.layers = nn.ModuleList([])
self.layers.extend([
LightConvDecoderLayer(args, no_encoder_attn, kernel_size=args.decoder_kernel_size_list[i])
for i in range(args.decoder_layers)
])
self.adaptive_softmax = None
self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \
if embed_dim != output_embed_dim and not args.tie_adaptive_weights else None
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
output_embed_dim,
options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
factor=args.adaptive_softmax_factor,
tie_proj=args.tie_adaptive_proj,
)
elif not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5)
self.register_buffer('version', torch.Tensor([2]))
self.normalize = args.decoder_normalize_before and final_norm
if self.normalize:
self.layer_norm = LayerNorm(embed_dim)
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len,
vocab)`
- the last decoder layer's attention weights of shape `(batch,
tgt_len, src_len)`
"""
# embed positions
positions = self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
) if self.embed_positions is not None else None
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
attn = None
inner_states = [x]
# decoder layers
for layer in self.layers:
x, attn = layer(
x,
encoder_out['encoder_out'] if encoder_out is not None else None,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
incremental_state,
)
inner_states.append(x)
if self.normalize:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
if self.adaptive_softmax is None:
# project back to size of vocabulary
if self.share_input_output_embed:
x = F.linear(x, self.embed_tokens.weight)
else:
x = F.linear(x, self.embed_out)
return x, {'attn': attn, 'inner_states': inner_states}
def max_positions(self):
"""Maximum output length supported by the decoder."""
if self.embed_positions is None:
return self.max_target_positions
return min(self.max_target_positions, self.embed_positions.max_positions())
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
return self._future_mask[:dim, :dim]
class LightConvEncoderLayer(nn.Module):
"""Encoder layer block.
Args:
args (argparse.Namespace): parsed command-line arguments
kernel_size: kernel size of the convolution
"""
def __init__(self, args, kernel_size=0):
super().__init__()
self.embed_dim = args.encoder_embed_dim
self.conv_dim = args.encoder_conv_dim
padding_l = kernel_size // 2 if kernel_size % 2 == 1 else ((kernel_size - 1) // 2, kernel_size // 2)
if args.encoder_glu:
self.linear1 = Linear(self.embed_dim, 2*self.conv_dim)
self.act = nn.GLU()
else:
self.linear1 = Linear(self.embed_dim, self.conv_dim)
self.act = None
if args.encoder_conv_type == 'lightweight':
self.conv = LightweightConv1dTBC(self.conv_dim, kernel_size, padding_l=padding_l,
weight_softmax=args.weight_softmax,
num_heads=args.encoder_attention_heads,
weight_dropout=args.weight_dropout)
elif args.encoder_conv_type == 'dynamic':
self.conv = DynamicConv1dTBC(self.conv_dim, kernel_size, padding_l=padding_l,
weight_softmax=args.weight_softmax,
num_heads=args.encoder_attention_heads,
weight_dropout=args.weight_dropout)
else:
raise NotImplementedError
self.linear2 = Linear(self.conv_dim, self.embed_dim)
self.dropout = args.dropout
self.relu_dropout = args.relu_dropout
self.input_dropout = args.input_dropout
self.normalize_before = args.encoder_normalize_before
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for _ in range(2)])
def forward(self, x, encoder_padding_mask):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
Returns:
encoded output of shape `(batch, src_len, embed_dim)`
"""
residual = x
x = self.maybe_layer_norm(0, x, before=True)
x = F.dropout(x, p=self.input_dropout, training=self.training)
x = self.linear1(x)
if self.act is not None:
x = self.act(x)
if encoder_padding_mask is not None:
x = x.masked_fill(encoder_padding_mask.transpose(0, 1).unsqueeze(2), 0)
x = self.conv(x)
x = self.linear2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(0, x, after=True)
residual = x
x = self.maybe_layer_norm(1, x, before=True)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(1, x, after=True)
return x
def maybe_layer_norm(self, i, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return self.layer_norms[i](x)
else:
return x
def extra_repr(self):
return 'dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}'.format(
self.dropout, self.relu_dropout, self.input_dropout, self.normalize_before)
class LightConvDecoderLayer(nn.Module):
"""Decoder layer block.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs.
Default: ``False``
kernel_size: kernel size of the convolution
"""
def __init__(self, args, no_encoder_attn=False, kernel_size=0):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.conv_dim = args.decoder_conv_dim
if args.decoder_glu:
self.linear1 = Linear(self.embed_dim, 2*self.conv_dim)
self.act = nn.GLU()
else:
self.linear1 = Linear(self.embed_dim, self.conv_dim)
self.act = None
if args.decoder_conv_type == 'lightweight':
self.conv = LightweightConv1dTBC(self.conv_dim, kernel_size, padding_l=kernel_size-1,
weight_softmax=args.weight_softmax,
num_heads=args.decoder_attention_heads,
weight_dropout=args.weight_dropout)
elif args.decoder_conv_type == 'dynamic':
self.conv = DynamicConv1dTBC(self.conv_dim, kernel_size, padding_l=kernel_size-1,
weight_softmax=args.weight_softmax,
num_heads=args.decoder_attention_heads,
weight_dropout=args.weight_dropout)
else:
raise NotImplementedError
self.linear2 = Linear(self.conv_dim, self.embed_dim)
self.dropout = args.dropout
self.relu_dropout = args.relu_dropout
self.input_dropout = args.input_dropout
self.normalize_before = args.decoder_normalize_before
self.conv_layer_norm = LayerNorm(self.embed_dim)
if no_encoder_attn:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = MultiheadAttention(
self.embed_dim, args.decoder_attention_heads,
dropout=args.attention_dropout,
)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)
self.need_attn = True
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
prev_conv_state=None, prev_attn_state=None, conv_mask=None,
conv_padding_mask=None):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
Returns:
encoded output of shape `(batch, src_len, embed_dim)`
"""
residual = x
x = self.maybe_layer_norm(self.conv_layer_norm, x, before=True)
if prev_conv_state is not None:
if incremental_state is None:
incremental_state = {}
self.conv._set_input_buffer(incremental_state, prev_conv_state)
x = F.dropout(x, p=self.input_dropout, training=self.training)
x = self.linear1(x)
if self.act is not None:
x = self.act(x)
x = self.conv(x, incremental_state=incremental_state)
x = self.linear2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.conv_layer_norm, x, after=True)
attn = None
if self.encoder_attn is not None:
residual = x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
if prev_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_attn_state
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=(not self.training and self.need_attn),
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)
residual = x
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
return x, attn
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return layer_norm(x)
else:
return x
def make_generation_fast_(self, need_attn=False, **kwargs):
self.need_attn = need_attn
def extra_repr(self):
return 'dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}'.format(
self.dropout, self.relu_dropout, self.input_dropout, self.normalize_before)
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
def LayerNorm(embedding_dim):
m = nn.LayerNorm(embedding_dim)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
nn.init.constant_(m.bias, 0.)
return m
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
if learned:
m = LearnedPositionalEmbedding(num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings + padding_idx + 1)
return m
@register_model_architecture('lightconv_lm', 'lightconv_lm')
def base_lm_architecture(args):
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048)
args.decoder_layers = getattr(args, 'decoder_layers', 6)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4)
args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
args.character_embeddings = getattr(args, 'character_embeddings', False)
args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
# The model training is not stable without this
args.decoder_normalize_before = True
args.adaptive_input = getattr(args, 'adaptive_input', False)
args.adaptive_input_factor = getattr(args, 'adaptive_input_factor', 4)
args.adaptive_input_cutoff = getattr(args, 'adaptive_input_cutoff', None)
args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False)
args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False)
args.decoder_kernel_size_list = getattr(args, 'decoder_kernel_size_list', [3, 7, 15, 31, 31, 31])
if len(args.decoder_kernel_size_list) == 1:
args.decoder_kernel_size_list = args.decoder_kernel_size_list * args.decoder_layers
@register_model_architecture('lightconv_lm', 'lightconv_lm_gbw')
def lightconv_lm_gbw(args):
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.dropout = getattr(args, 'dropout', 0.1)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
base_lm_architecture(args)
@register_model_architecture('lightconv', 'lightconv')
def base_architecture(args):
args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
args.encoder_layers = getattr(args, 'encoder_layers', 7)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
args.decoder_layers = getattr(args, 'decoder_layers', 6)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
args.attention_dropout = getattr(args, 'attention_dropout', 0.)
args.relu_dropout = getattr(args, 'relu_dropout', 0.)
args.dropout = getattr(args, 'dropout', 0.1)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
args.encoder_conv_dim = getattr(args, 'encoder_conv_dim', args.encoder_embed_dim)
args.decoder_conv_dim = getattr(args, 'decoder_conv_dim', args.decoder_embed_dim)
args.encoder_kernel_size_list = getattr(args, 'encoder_kernel_size_list', [3, 7, 15, 31, 31, 31, 31])
args.decoder_kernel_size_list = getattr(args, 'decoder_kernel_size_list', [3, 7, 15, 31, 31, 31])
if len(args.encoder_kernel_size_list) == 1:
args.encoder_kernel_size_list = args.encoder_kernel_size_list * args.encoder_layers
if len(args.decoder_kernel_size_list) == 1:
args.decoder_kernel_size_list = args.decoder_kernel_size_list * args.decoder_layers
assert len(args.encoder_kernel_size_list) == args.encoder_layers, "encoder_kernel_size_list doesn't match encoder_layers"
assert len(args.decoder_kernel_size_list) == args.decoder_layers, "decoder_kernel_size_list doesn't match decoder_layers"
args.encoder_glu = getattr(args, 'encoder_glu', True)
args.decoder_glu = getattr(args, 'decoder_glu', True)
args.input_dropout = getattr(args, 'input_dropout', 0.1)
args.weight_dropout = getattr(args, 'weight_dropout', args.attention_dropout)
@register_model_architecture('lightconv', 'lightconv_iwslt_de_en')
def lightconv_iwslt_de_en(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
args.encoder_layers = getattr(args, 'encoder_layers', 7)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
args.decoder_layers = getattr(args, 'decoder_layers', 6)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.weight_dropout = getattr(args, 'weight_dropout', 0.1)
args.encoder_glu = getattr(args, 'encoder_glu', False)
args.decoder_glu = getattr(args, 'decoder_glu', False)
args.input_dropout = getattr(args, 'input_dropout', 0.0)
base_architecture(args)
@register_model_architecture('lightconv', 'lightconv_wmt_en_de')
def lightconv_wmt_en_de(args):
base_architecture(args)
@register_model_architecture('lightconv', 'lightconv_wmt_en_de_big')
def lightconv_wmt_en_de_big(args):
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
args.dropout = getattr(args, 'dropout', 0.3)
base_architecture(args)
@register_model_architecture('lightconv', 'lightconv_wmt_en_fr_big')
def lightconv_wmt_en_fr_big(args):
args.dropout = getattr(args, 'dropout', 0.1)
lightconv_wmt_en_de_big(args)
@register_model_architecture('lightconv', 'lightconv_wmt_zh_en_big')
def lightconv_wmt_zh_en_big(args):
args.dropout = getattr(args, 'dropout', 0.2)
args.attention_dropout = getattr(args, 'attention_dropout', 0.2)
args.weight_dropout = getattr(args, 'weight_dropout', 0.2)
lightconv_wmt_en_de_big(args)

View File

@ -11,13 +11,16 @@ from .beamable_mm import BeamableMM
from .character_token_embedder import CharacterTokenEmbedder
from .conv_tbc import ConvTBC
from .downsampled_multihead_attention import DownsampledMultiHeadAttention
from .dynamic_convolution import DynamicConv1dTBC
from .grad_multiply import GradMultiply
from .highway import Highway
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv1dTBC
from .linearized_convolution import LinearizedConvolution
from .multihead_attention import MultiheadAttention
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
from .unfold1d import unfold1d
__all__ = [
'AdaptiveInput',
@ -26,11 +29,14 @@ __all__ = [
'CharacterTokenEmbedder',
'ConvTBC',
'DownsampledMultiHeadAttention',
'DynamicConv1dTBC',
'GradMultiply',
'Highway',
'LearnedPositionalEmbedding',
'LightweightConv1dTBC',
'LinearizedConvolution',
'MultiheadAttention',
'ScalarBias',
'SinusoidalPositionalEmbedding',
'unfold1d',
]

View File

@ -0,0 +1,227 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.modules import unfold1d
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
nn.init.constant_(m.bias, 0.)
return m
class DynamicConv1dTBC(nn.Module):
'''Dynamic lightweight convolution taking T x B x C inputs
Args:
input_size: # of channels of the input
kernel_size: convolution channels
padding_l: padding to the left when using "same" padding
num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size)
weight_dropout: the drop rate of the DropConnect to drop the weight
weight_softmax: normalize the weight with softmax before the convolution
renorm_padding: re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1)
bias: use bias
conv_bias: bias of the convolution
query_size: specified when feeding a different input as the query
in_proj: project the input and generate the filter together
Shape:
Input: TxBxC, i.e. (timesteps, batch_size, input_size)
Output: TxBxC, i.e. (timesteps, batch_size, input_size)
Attributes:
weight: the learnable weights of the module of shape
`(num_heads, 1, kernel_size)`
bias: the learnable bias of the module of shape `(input_size)`
'''
def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1,
weight_dropout=0., weight_softmax=False,
renorm_padding=False, bias=False, conv_bias=False,
query_size=None, in_proj=False):
super().__init__()
self.input_size = input_size
self.query_size = input_size if query_size is None else query_size
self.kernel_size = kernel_size
self.padding_l = padding_l
self.num_heads = num_heads
self.weight_dropout = weight_dropout
self.weight_softmax = weight_softmax
self.renorm_padding = renorm_padding
if in_proj:
self.weight_linear = Linear(self.input_size, self.input_size + num_heads * kernel_size * 1)
else:
self.weight_linear = Linear(self.query_size, num_heads * kernel_size * 1, bias=bias)
if conv_bias:
self.conv_bias = nn.Parameter(torch.Tensor(input_size))
else:
self.conv_bias = None
self.reset_parameters()
@property
def in_proj(self):
return self.weight_linear.out_features == self.input_size + self.num_heads * self.kernel_size
def reset_parameters(self):
self.weight_linear.reset_parameters()
if self.conv_bias is not None:
nn.init.constant_(self.conv_bias, 0.)
def forward(self, x, incremental_state=None, query=None, unfold=None):
'''Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C
args:
x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
incremental_state: A dict to keep the state
unfold: unfold the input or not. If not, we use the matrix trick instead
query: use the specified query to predict the conv filters
'''
unfold = x.size(0) > 512 if unfold is None else unfold # use unfold mode as default for long sequence to save memory
unfold = unfold or (incremental_state is not None)
assert query is None or not self.in_proj
if query is None:
query = x
if unfold:
output = self._forward_unfolded(x, incremental_state, query)
else:
output = self._forward_expanded(x, incremental_state, query)
if self.conv_bias is not None:
output = output + self.conv_bias.view(1, 1, -1)
return output
def _forward_unfolded(self, x, incremental_state, query):
'''The conventional implementation of convolutions.
Unfolding the input by having a window shifting to the right.'''
T, B, C = x.size()
K, H = self.kernel_size, self.num_heads
R = C // H
assert R * H == C == self.input_size
if self.in_proj:
proj = self.weight_linear(x)
x = proj.narrow(2, 0, self.input_size).contiguous()
weight = proj.narrow(2, self.input_size, H*K).contiguous().view(T*B*H, -1)
else:
weight = self.weight_linear(query).view(T*B*H, -1)
# renorm_padding is only implemented in _forward_expanded
assert not self.renorm_padding or incremental_state is not None
if incremental_state is not None:
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is None:
input_buffer = x.new()
x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
if self.kernel_size > 1:
self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:])
x_unfold = x_unfold.view(T*B*H, R, -1)
else:
padding_l = self.padding_l
if K > T and padding_l == K-1:
weight = weight.narrow(1, K-T, T)
K, padding_l = T, T-1
# unfold the input: T x B x C --> T' x B x C x K
x_unfold = unfold1d(x, K, padding_l, 0)
x_unfold = x_unfold.view(T*B*H, R, K)
if self.weight_softmax and not self.renorm_padding:
weight = F.softmax(weight, dim=1)
weight = weight.narrow(1, 0, K)
if incremental_state is not None:
weight = weight[:, -x_unfold.size(2):]
K = weight.size(1)
if self.weight_softmax and self.renorm_padding:
weight = F.softmax(weight, dim=1)
weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False)
output = torch.bmm(x_unfold, weight.unsqueeze(2)) # T*B*H x R x 1
output = output.view(T, B, C)
return output
def _forward_expanded(self, x, incremental_stat, query):
'''Turn the convolution filters into band matrices and do matrix multiplication.
This is faster when the sequence is short, but less memory efficient.
This is not used in the decoder during inference.
'''
T, B, C = x.size()
K, H = self.kernel_size, self.num_heads
R = C // H
assert R * H == C == self.input_size
if self.in_proj:
proj = self.weight_linear(x)
x = proj.narrow(2, 0, self.input_size).contiguous()
weight = proj.narrow(2, self.input_size, H*K).contiguous().view(T*B*H, -1)
else:
weight = self.weight_linear(query).view(T*B*H, -1)
if not self.renorm_padding:
if self.weight_softmax:
weight = F.softmax(weight, dim=1)
weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False)
weight = weight.narrow(1, 0, K).contiguous()
weight = weight.view(T, B*H, K).transpose(0, 1)
x = x.view(T, B*H, R).transpose(0, 1)
if self.weight_softmax and self.renorm_padding:
# turn the convolution filters into band matrices
weight_expanded = weight.new(B*H, T, T+K-1).fill_(float('-inf'))
weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
weight_expanded = weight_expanded.narrow(2, self.padding_l, T)
# normalize the weight over valid positions like self-attention
weight_expanded = F.softmax(weight_expanded, dim=2)
weight_expanded = F.dropout(weight_expanded, self.weight_dropout, training=self.training, inplace=False)
else:
P = self.padding_l
# For efficieny, we cut the kernel size and reduce the padding when the kernel is larger than the length
if K > T and P == K-1:
weight = weight.narrow(2, K-T, T)
K, P = T, T-1
# turn the convolution filters into band matrices
weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False)
weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T
output = torch.bmm(weight_expanded, x)
output = output.transpose(0, 1).contiguous().view(T, B, C)
return output
def reorder_incremental_state(self, incremental_state, new_order):
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
input_buffer = input_buffer.index_select(1, new_order)
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
return utils.get_incremental_state(self, incremental_state, 'input_buffer')
def _set_input_buffer(self, incremental_state, new_buffer):
return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
def extra_repr(self):
s = '{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}'.format(
self.input_size, self.kernel_size, self.padding_l,
self.num_heads, self.weight_softmax, self.conv_bias is not None, self.renorm_padding,
self.in_proj,
)
if self.query_size != self.input_size:
s += ', query_size={}'.format(self.query_size)
if self.weight_dropout > 0.:
s += ', weight_dropout={}'.format(self.weight_dropout)
return s

View File

@ -0,0 +1,233 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.modules import unfold1d
class LightweightConv1d(nn.Module):
'''Lightweight Convolution assuming the input is BxCxT
This is just an example that explains LightConv clearer than the TBC version.
We don't use this module in the model.
Args:
input_size: # of channels of the input and output
kernel_size: convolution channels
padding: padding
num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size)
weight_softmax: normalize the weight with softmax before the convolution
Shape:
Input: BxCxT, i.e. (batch_size, input_size, timesteps)
Output: BxCxT, i.e. (batch_size, input_size, timesteps)
Attributes:
weight: the learnable weights of the module of shape
`(num_heads, 1, kernel_size)`
bias: the learnable bias of the module of shape `(input_size)`
'''
def __init__(self, input_size, kernel_size=1, padding=0, num_heads=1,
weight_softmax=False, bias=False, weight_dropout=0.):
super().__init__()
self.input_size = input_size
self.kernel_size = kernel_size
self.num_heads = num_heads
self.padding = padding
self.weight_softmax = weight_softmax
self.weight = nn.Parameter(torch.Tensor(num_heads, 1, kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(input_size))
else:
self.bias = None
self.weight_dropout = weight_dropout
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
nn.init.constant_(self.bias, 0.)
def forward(self, input):
'''
input size: B x C x T
output size: B x C x T
'''
B, C, T = input.size()
H = self.num_heads
weight = self.weight
if self.weight_softmax:
weight = F.softmax(weight, dim=-1)
weight = F.dropout(weight, self.weight_dropout, training=self.training)
# Merge every C/H entries into the batch dimension (C = self.input_size)
# B x C x T -> (B * C/H) x H x T
# One can also expand the weight to C x 1 x K by a factor of C/H
# and do not reshape the input instead, which is slow though
input = input.view(-1, H, T)
output = F.conv1d(input, weight, padding=self.padding, groups=self.num_heads)
output = output.view(B, C, T)
if self.bias is not None:
output = output + self.bias.view(1, -1, 1)
return output
class LightweightConv1dTBC(nn.Module):
'''Lightweight Convolution assuming the input is TxBxC
Args:
input_size: # of channels of the input
kernel_size: convolution channels
padding_l: padding to the left when using "same" padding
num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size)
weight_dropout: the drop rate of the DropConnect to drop the weight
weight_softmax: normalize the weight with softmax before the convolution
bias: use bias
Shape:
Input: TxBxC, i.e. (timesteps, batch_size, input_size)
Output: TxBxC, i.e. (timesteps, batch_size, input_size)
Attributes:
weight: the learnable weights of the module of shape
`(num_heads, 1, kernel_size)`
bias: the learnable bias of the module of shape `(input_size)`
'''
def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1,
weight_dropout=0., weight_softmax=False, bias=False):
super().__init__()
self.input_size = input_size
self.kernel_size = kernel_size
self.padding_l = padding_l
self.num_heads = num_heads
self.weight_dropout = weight_dropout
self.weight_softmax = weight_softmax
self.weight = nn.Parameter(torch.Tensor(num_heads, 1, kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(input_size))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
nn.init.constant_(self.bias, 0.)
def forward(self, x, incremental_state=None, unfold=False):
'''Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C
args:
x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
incremental_state: A dict to keep the state
unfold: unfold the input or not. If not, we use the matrix trick instead
'''
unfold = unfold or (incremental_state is not None)
if unfold:
output = self._forward_unfolded(x, incremental_state)
else:
output = self._forward_expanded(x, incremental_state)
if self.bias is not None:
output = output + self.bias.view(1, 1, -1)
return output
def _forward_unfolded(self, x, incremental_state):
'''The conventional implementation of convolutions.
Unfolding the input by having a window shifting to the right.'''
T, B, C = x.size()
K, H = self.kernel_size, self.num_heads
R = C // H
assert R * H == C == self.input_size
weight = self.weight.view(H, K)
if incremental_state is not None:
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is None:
input_buffer = x.new()
x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
if self.kernel_size > 1:
self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:])
x_unfold = x_unfold.view(T*B*H, R, -1)
else:
# unfold the input: T x B x C --> T' x B x C x K
x_unfold = unfold1d(x, self.kernel_size, self.padding_l, 0)
x_unfold = x_unfold.view(T*B*H, R, K)
if self.weight_softmax:
weight = F.softmax(weight.float(), dim=1).type_as(weight)
if incremental_state is not None:
weight = weight[:, -x_unfold.size(2):]
K = weight.size(1)
weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1)
weight = F.dropout(weight, self.weight_dropout, training=self.training)
output = torch.bmm(x_unfold, weight) # T*B*H x R x 1
output = output.view(T, B, C)
return output
def _forward_expanded(self, x, incremental_state):
'''Turn the convolution filters into band matrices and do matrix multiplication.
This is faster when the sequence is short, but less memory efficient.
This is not used in the decoder during inference.
'''
T, B, C = x.size()
K, H = self.kernel_size, self.num_heads
R = C // H
assert R * H == C == self.input_size
weight = self.weight.view(H, K)
if self.weight_softmax:
weight = F.softmax(weight.float(), dim=1).type_as(weight)
weight = weight.view(1, H, K).expand(T*B, H, K).contiguous()
weight = weight.view(T, B*H, K).transpose(0, 1)
x = x.view(T, B*H, R).transpose(0, 1)
P = self.padding_l
if K > T and P == K-1:
weight = weight.narrow(2, K-T, T)
K, P = T, T-1
# turn the convolution filters into band matrices
weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False)
weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
weight_expanded = weight_expanded.narrow(2, P, T)
weight_expanded = F.dropout(weight_expanded, self.weight_dropout, training=self.training)
output = torch.bmm(weight_expanded, x)
output = output.transpose(0, 1).contiguous().view(T, B, C)
return output
def reorder_incremental_state(self, incremental_state, new_order):
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
input_buffer = input_buffer.index_select(1, new_order)
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
return utils.get_incremental_state(self, incremental_state, 'input_buffer')
def _set_input_buffer(self, incremental_state, new_buffer):
return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
def extra_repr(self):
s = '{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, bias={}'.format(
self.input_size, self.kernel_size, self.padding_l,
self.num_heads, self.weight_softmax, self.bias is not None
)
if self.weight_dropout > 0.:
s += ', weight_dropout={}'.format(self.weight_dropout)
return s

View File

@ -0,0 +1,19 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn.functional as F
def unfold1d(x, kernel_size, padding_l, pad_value=0):
'''unfold T x B x C to T x B x C x K'''
if kernel_size > 1:
T, B, C = x.size()
x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value)
x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C))
else:
x = x.unsqueeze(3)
return x

View File

@ -0,0 +1,20 @@
#!/bin/bash
if [ $# -ne 1 ]; then
echo "usage: $0 GENERATE_PY_OUTPUT"
exit 1
fi
GEN=$1
SYS=$GEN.sys
REF=$GEN.ref
if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then
echo "not done generating"
exit
fi
grep ^H $GEN | cut -f3- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS
grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF
python score.py --sys $SYS --ref $REF

28
scripts/sacrebleu_pregen.sh Executable file
View File

@ -0,0 +1,28 @@
#!/bin/bash
if [ $# -ne 4 ]; then
echo "usage: $0 TESTSET SRCLANG TGTLANG GEN"
exit 1
fi
TESTSET=$1
SRCLANG=$2
TGTLANG=$3
GEN=$4
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
SCRIPTS=mosesdecoder/scripts
DETOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
grep ^H $GEN \
| sed 's/^H\-//' \
| sort -n -k 1 \
| cut -f 3 \
| perl $DETOKENIZER -l $TGTLANG \
| sed "s/ - /-/g" \
> $GEN.sorted.detok
sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok

View File

@ -138,6 +138,28 @@ class TestTranslation(unittest.TestCase):
train_translation_model(data_dir, 'transformer_iwslt_de_en')
generate_main(data_dir)
def test_lightconv(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_lightconv') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'lightconv_iwslt_de_en', [
'--encoder-conv-type', 'lightweight',
'--decoder-conv-type', 'lightweight',
])
generate_main(data_dir)
def test_dynamicconv(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_dynamicconv') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'lightconv_iwslt_de_en', [
'--encoder-conv-type', 'dynamic',
'--decoder-conv-type', 'dynamic',
])
generate_main(data_dir)
class TestStories(unittest.TestCase):

View File

@ -62,6 +62,7 @@ class TestLoadCheckpoint(unittest.TestCase):
'os.makedirs': MagicMock(),
'os.path.join': MagicMock(),
'os.path.isfile': MagicMock(return_value=True),
'os.path.isabs': MagicMock(return_value=False),
}
self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
[p.start() for p in self.applied_patches]

View File

@ -328,7 +328,10 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
def load_checkpoint(args, trainer, epoch_itr):
"""Load a checkpoint and replay dataloader to match."""
os.makedirs(args.save_dir, exist_ok=True)
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
if os.path.isabs(args.restore_file):
checkpoint_path = args.restore_file
else:
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
if os.path.isfile(checkpoint_path):
extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
eval(args.optimizer_overrides))
@ -344,6 +347,8 @@ def load_checkpoint(args, trainer, epoch_itr):
if 'best' in extra_state:
save_checkpoint.best = extra_state['best']
return True
else:
print('| no existing checkpoint found {}'.format(checkpoint_path))
return False