Add fairseq to PyPI (#495)

Summary:
- fairseq can now be installed via pip: `pip install fairseq`
- command-line tools are globally accessible: `fairseq-preprocess`, `fairseq-train`, `fairseq-generate`, etc.
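Once installed, the package is importable like any other distribution; a quick illustrative check (the version matches the 0.6.1 bump later in this diff):
```
# Assumes `pip install fairseq` has already run in the current environment.
import fairseq
print(fairseq.__version__)  # expected: 0.6.1 for this release
```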
Pull Request resolved: https://github.com/pytorch/fairseq/pull/495

Differential Revision: D14017761

Pulled By: myleott

fbshipit-source-id: 10c9f6634a3056074eac2f33324b4f1f404d4235
Authored by Myle Ott on 2019-02-08 22:00:46 -08:00; committed by Facebook Github Bot
parent cea0e4b9ea
commit fbd4cef9a5
30 changed files with 143 additions and 136 deletions


@ -45,10 +45,18 @@ Please follow the instructions here: https://github.com/pytorch/pytorch#installa
If you use Docker, make sure to increase the shared memory size either with
`--ipc=host` or `--shm-size` as command-line options to `nvidia-docker run`.
After PyTorch is installed, you can install fairseq with:
After PyTorch is installed, you can install fairseq with `pip`:
```
pip install -r requirements.txt
python setup.py build develop
pip install fairseq
```
**Installing from source**
To install fairseq from source and develop locally:
```
git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable .
```
# Getting Started


@ -5,81 +5,81 @@ Command-line Tools
Fairseq provides several command-line tools for training and evaluating models:
- :ref:`preprocess.py`: Data pre-processing: build vocabularies and binarize training data
- :ref:`train.py`: Train a new model on one or multiple GPUs
- :ref:`generate.py`: Translate pre-processed data with a trained model
- :ref:`interactive.py`: Translate raw text with a trained model
- :ref:`score.py`: BLEU scoring of generated translations against reference translations
- :ref:`eval_lm.py`: Language model evaluation
- :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data
- :ref:`fairseq-train`: Train a new model on one or multiple GPUs
- :ref:`fairseq-generate`: Translate pre-processed data with a trained model
- :ref:`fairseq-interactive`: Translate raw text with a trained model
- :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations
- :ref:`fairseq-eval-lm`: Language model evaluation
.. _preprocess.py:
.. _fairseq-preprocess:
preprocess.py
~~~~~~~~~~~~~
fairseq-preprocess
~~~~~~~~~~~~~~~~~~
.. automodule:: preprocess
.. argparse::
:module: preprocess
:func: get_parser
:prog: preprocess.py
:module: fairseq.options
:func: get_preprocessing_parser
:prog: fairseq-preprocess
.. _train.py:
.. _fairseq-train:
train.py
~~~~~~~~
fairseq-train
~~~~~~~~~~~~~
.. automodule:: train
.. argparse::
:module: fairseq.options
:func: get_training_parser
:prog: train.py
:prog: fairseq-train
.. _generate.py:
.. _fairseq-generate:
generate.py
~~~~~~~~~~~
fairseq-generate
~~~~~~~~~~~~~~~~
.. automodule:: generate
.. argparse::
:module: fairseq.options
:func: get_generation_parser
:prog: generate.py
:prog: fairseq-generate
.. _interactive.py:
.. _fairseq-interactive:
interactive.py
~~~~~~~~~~~~~~
fairseq-interactive
~~~~~~~~~~~~~~~~~~~
.. automodule:: interactive
.. argparse::
:module: fairseq.options
:func: get_interactive_generation_parser
:prog: interactive.py
:prog: fairseq-interactive
.. _score.py:
.. _fairseq-score:
score.py
~~~~~~~~
fairseq-score
~~~~~~~~~~~~~
.. automodule:: score
.. argparse::
:module: score
:module: fairseq_cli.score
:func: get_parser
:prog: score.py
:prog: fairseq-score
.. _eval_lm.py:
.. _fairseq-eval-lm:
eval_lm.py
~~~~~~~~~~
fairseq-eval-lm
~~~~~~~~~~~~~~~
.. automodule:: eval_lm
.. argparse::
:module: fairseq.options
:func: get_eval_lm_parser
:prog: eval_lm.py
:prog: fairseq-eval-lm


@ -60,9 +60,9 @@ github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/'
# built documents.
#
# The short X.Y version.
version = '0.6.0'
version = '0.6.1'
# The full version, including alpha/beta/rc tags.
release = '0.6.0'
release = '0.6.1'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.


@ -46,8 +46,6 @@ Dictionary
Iterators
---------
.. autoclass:: fairseq.data.BufferedIterator
:members:
.. autoclass:: fairseq.data.CountingIterator
:members:
.. autoclass:: fairseq.data.EpochBatchIterator


@ -15,17 +15,17 @@ done with the
script using the ``wmt14.en-fr.fconv-cuda/bpecodes`` file. ``@@`` is
used as a continuation marker and the original text can be easily
recovered with e.g. ``sed s/@@ //g`` or by passing the ``--remove-bpe``
flag to :ref:`generate.py`. Prior to BPE, input text needs to be tokenized
flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized
using ``tokenizer.perl`` from
`mosesdecoder <https://github.com/moses-smt/mosesdecoder>`__.
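As a rough illustration of what removing the marker does (plain Python standing in for the ``sed`` call above):
```
# "@@ " marks a split inside a word; deleting the marker and its trailing
# space restores the original tokenized text, like `sed s/@@ //g`.
bpe_line = "new@@ spaper ar@@ ticle"
print(bpe_line.replace("@@ ", ""))  # -> "newspaper article"
```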
Let's use :ref:`interactive.py` to generate translations
Let's use :ref:`fairseq-interactive` to generate translations
interactively. Here, we use a beam size of 5:
.. code-block:: console
> MODEL_DIR=wmt14.en-fr.fconv-py
> python interactive.py \
> fairseq-interactive \
--path $MODEL_DIR/model.pt $MODEL_DIR \
--beam 5 --source-lang en --target-lang fr
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
@ -66,7 +66,7 @@ datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT
> bash prepare-iwslt14.sh
> cd ../..
> TEXT=examples/translation/iwslt14.tokenized.de-en
> python preprocess.py --source-lang de --target-lang en \
> fairseq-preprocess --source-lang de --target-lang en \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/iwslt14.tokenized.de-en
@ -76,17 +76,17 @@ This will write binarized data that can be used for model training to
Training
--------
Use :ref:`train.py` to train a new model. Here a few example settings that work
Use :ref:`fairseq-train` to train a new model. Here are a few example settings that work
well for the IWSLT 2014 dataset:
.. code-block:: console
> mkdir -p checkpoints/fconv
> CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \
> CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
--lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
--arch fconv_iwslt_de_en --save-dir checkpoints/fconv
By default, :ref:`train.py` will use all available GPUs on your machine. Use the
By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the
``CUDA_VISIBLE_DEVICES`` environment variable to select specific GPUs and/or to
change the number of GPU devices that will be used.
@ -98,12 +98,12 @@ Generation
----------
Once your model is trained, you can generate translations using
:ref:`generate.py` **(for binarized data)** or
:ref:`interactive.py` **(for raw text)**:
:ref:`fairseq-generate` **(for binarized data)** or
:ref:`fairseq-interactive` **(for raw text)**:
.. code-block:: console
> python generate.py data-bin/iwslt14.tokenized.de-en \
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/fconv/checkpoint_best.pt \
--batch-size 128 --beam 5
| [de] dictionary: 35475 types
@ -136,7 +136,7 @@ to training on 8 GPUs:
.. code-block:: console
> CUDA_VISIBLE_DEVICES=0 python train.py --update-freq 8 (...)
> CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)
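Conceptually, ``--update-freq N`` accumulates gradients over N batches before each optimizer step, which is what lets a single GPU emulate N of them. A rough, self-contained PyTorch sketch of that idea (not fairseq's actual implementation, which also normalizes by the accumulated sample size):
```
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
update_freq = 8  # corresponds to --update-freq 8
batches = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(16)]

optimizer.zero_grad()
for i, (x, y) in enumerate(batches):
    loss = torch.nn.functional.cross_entropy(model(x), y)
    (loss / update_freq).backward()   # gradients accumulate across batches
    if (i + 1) % update_freq == 0:    # one parameter update per update_freq batches
        optimizer.step()
        optimizer.zero_grad()
```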
Training with half precision floating point (FP16)
--------------------------------------------------
@ -152,7 +152,7 @@ Fairseq supports FP16 training with the ``--fp16`` flag:
.. code-block:: console
> python train.py --fp16 (...)
> fairseq-train --fp16 (...)
Lazily loading large training datasets
--------------------------------------
@ -178,7 +178,7 @@ replacing ``node_rank=0`` with ``node_rank=1`` on the second node:
> python -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
--master_port=1234 \
train.py data-bin/wmt16_en_de_bpe32k \
$(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
--arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \


@ -29,6 +29,6 @@ epoch boundaries via :func:`step`.
.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.reduce_angular_lr_scheduler.TriangularSchedule
.. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
:members:
:undoc-members:


@ -49,7 +49,10 @@ new plug-ins.
**Loading plug-ins from another directory**
New plug-ins can be defined in a custom module stored in the user system. In order to import the module, and make the plugin available to *fairseq*, the command line supports the ``--user-dir`` flag that can be used to specify a custom location for additional modules to load into *fairseq*.
New plug-ins can be defined in a custom module stored on the user's system. In
order to import the module and make the plugin available to *fairseq*, the
command line supports the ``--user-dir`` flag, which can be used to specify a
custom location for additional modules to load into *fairseq*.
For example, assuming this directory tree::
@ -65,6 +68,6 @@ with ``__init__.py``::
def transformer_mmt_big(args):
transformer_vaswani_wmt_en_de_big(args)
it is possible to invoke the ``train.py`` script with the new architecture with::
it is possible to invoke the :ref:`fairseq-train` command with the new architecture::
python3 train.py ... --user-dir /home/user/my-module -a my_transformer --task translation
fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation
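For completeness, a self-contained sketch of what such a plug-in ``__init__.py`` might contain, assuming the standard ``register_model_architecture`` decorator from ``fairseq.models`` (the import and decorator lines are elided by the hunk above):
```
from fairseq.models import register_model_architecture
from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big

# Registers a new architecture name, selectable with `-a my_transformer`,
# that simply reuses the big transformer's hyperparameter defaults.
@register_model_architecture('transformer', 'my_transformer')
def transformer_mmt_big(args):
    transformer_vaswani_wmt_en_de_big(args)
```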


@ -28,7 +28,7 @@ train, valid and test sets.
Download and extract the data from here:
`tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_
Once extracted, let's preprocess the data using the :ref:`preprocess.py`
Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess`
command-line tool to create the dictionaries. While this tool is primarily
intended for sequence-to-sequence problems, we're able to reuse it here by
treating the label as a "target" sequence of length 1. We'll also output the
@ -37,7 +37,7 @@ enhance readability:
.. code-block:: console
> python preprocess.py \
> fairseq-preprocess \
--trainpref names/train --validpref names/valid --testpref names/test \
--source-lang input --target-lang label \
--destdir names-bin --output-format raw
@ -324,7 +324,7 @@ following contents::
4. Training the Model
---------------------
Now we're ready to train the model. We can use the existing :ref:`train.py`
Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
command-line tool for this, making sure to specify our new Task (``--task
simple_classification``) and Model architecture (``--arch
pytorch_tutorial_rnn``):
@ -332,11 +332,11 @@ pytorch_tutorial_rnn``):
.. note::
You can also configure the dimensionality of the hidden state by passing the
``--hidden-dim`` argument to :ref:`train.py`.
``--hidden-dim`` argument to :ref:`fairseq-train`.
.. code-block:: console
> python train.py names-bin \
> fairseq-train names-bin \
--task simple_classification \
--arch pytorch_tutorial_rnn \
--optimizer adam --lr 0.001 --lr-shrink 0.5 \


@ -341,7 +341,7 @@ function decorator. Thereafter this named architecture can be used with the
3. Training the Model
---------------------
Now we're ready to train the model. We can use the existing :ref:`train.py`
Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
command-line tool for this, making sure to specify our new Model architecture
(``--arch tutorial_simple_lstm``).
@ -352,7 +352,7 @@ command-line tool for this, making sure to specify our new Model architecture
.. code-block:: console
> python train.py data-bin/iwslt14.tokenized.de-en \
> fairseq-train data-bin/iwslt14.tokenized.de-en \
--arch tutorial_simple_lstm \
--encoder-dropout 0.2 --decoder-dropout 0.2 \
--optimizer adam --lr 0.005 --lr-shrink 0.5 \
@ -362,12 +362,12 @@ command-line tool for this, making sure to specify our new Model architecture
| epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954
The model files should appear in the :file:`checkpoints/` directory. While this
model architecture is not very good, we can use the :ref:`generate.py` script to
model architecture is not very good, we can use the :ref:`fairseq-generate` script to
generate translations and compute our BLEU score over the test set:
.. code-block:: console
> python generate.py data-bin/iwslt14.tokenized.de-en \
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/checkpoint_best.pt \
--beam 5 \
--remove-bpe
@ -498,7 +498,7 @@ Finally, we can rerun generation and observe the speedup:
# Before
> python generate.py data-bin/iwslt14.tokenized.de-en \
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/checkpoint_best.pt \
--beam 5 \
--remove-bpe
@ -508,7 +508,7 @@ Finally, we can rerun generation and observe the speedup:
# After
> python generate.py data-bin/iwslt14.tokenized.de-en \
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/checkpoint_best.pt \
--beam 5 \
--remove-bpe


@ -24,20 +24,20 @@ $ cd ../..
# Binarize the dataset:
$ TEXT=examples/language_model/wikitext-103
$ python preprocess.py --only-source \
$ fairseq-preprocess --only-source \
--trainpref $TEXT/wiki.train.tokens --validpref $TEXT/wiki.valid.tokens --testpref $TEXT/wiki.test.tokens \
--destdir data-bin/wikitext-103
# Train the model:
# If it runs out of memory, try to reduce max-tokens and max-target-positions
$ mkdir -p checkpoints/wikitext-103
$ python train.py --task language_modeling data-bin/wikitext-103 \
$ fairseq-train --task language_modeling data-bin/wikitext-103 \
--max-epoch 35 --arch fconv_lm_dauphin_wikitext103 --optimizer nag \
--lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
--clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
--adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024 \
--save-dir checkpoints/wikitext-103
# Evaluate:
$ python eval_lm.py data-bin/wikitext-103 --path 'checkpoints/wiki103/checkpoint_best.pt'
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/wikitext-103/checkpoint_best.pt'
```


@ -45,7 +45,7 @@ Training and evaluating DynamicConv (without GLU) on a GPU:
# Training
SAVE="save/dynamic_conv_iwslt"
mkdir -p $SAVE
CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \
CUDA_VISIBLE_DEVICES=0 $(which fairseq-train) data-bin/iwslt14.tokenized.de-en \
--clip-norm 0 --optimizer adam --lr 0.0005 \
--source-lang de --target-lang en --max-tokens 4000 --no-progress-bar \
--log-interval 100 --min-lr '1e-09' --weight-decay 0.0001 \
@ -61,7 +61,7 @@ python scripts/average_checkpoints.py --inputs $SAVE \
--num-epoch-checkpoints 10 --output "${SAVE}/checkpoint_last10_avg.pt"
# Evaluation
CUDA_VISIBLE_DEVICES=0 python generate.py data-bin/iwslt14.tokenized.de-en --path "${SAVE}/checkpoint_last10_avg.pt" --batch-size 128 --beam 4 --remove-bpe --lenpen 1 --gen-subset test --quiet
CUDA_VISIBLE_DEVICES=0 fairseq-generate data-bin/iwslt14.tokenized.de-en --path "${SAVE}/checkpoint_last10_avg.pt" --batch-size 128 --beam 4 --remove-bpe --lenpen 1 --gen-subset test --quiet
```
### WMT16 En-De
@ -70,7 +70,7 @@ Training and evaluating DynamicConv (with GLU) on WMT16 En-De using cosine sched
# Training
SAVE="save/dynamic_conv_wmt16en2de"
mkdir -p $SAVE
python -m torch.distributed.launch --nproc_per_node 8 train.py \
python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
data-bin/wmt16_en_de_bpe32k --fp16 --log-interval 100 --no-progress-bar \
--max-update 30000 --share-all-embeddings --optimizer adam \
--adam-betas '(0.9, 0.98)' --lr-scheduler inverse_sqrt \
@ -86,7 +86,7 @@ python -m torch.distributed.launch --nproc_per_node 8 train.py \
--encoder-glu 1 --decoder-glu 1
# Evaluation
CUDA_VISIBLE_DEVICES=0 python generate.py data-bin/wmt16.en-de.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.5 --gen-subset test > wmt16_gen.txt
CUDA_VISIBLE_DEVICES=0 fairseq-generate data-bin/wmt16.en-de.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.5 --gen-subset test > wmt16_gen.txt
bash scripts/compound_split_bleu.sh wmt16_gen.txt
```
@ -96,7 +96,7 @@ Training DynamicConv (with GLU) on WMT14 En-Fr using cosine scheduler on one mac
# Training
SAVE="save/dynamic_conv_wmt14en2fr"
mkdir -p $SAVE
python -m torch.distributed.launch --nproc_per_node 8 train.py \
python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
data-bin/wmt14_en_fr --fp16 --log-interval 100 --no-progress-bar \
--max-update 30000 --share-all-embeddings --optimizer adam \
--adam-betas '(0.9, 0.98)' --lr-scheduler inverse_sqrt \
@ -112,5 +112,5 @@ python -m torch.distributed.launch --nproc_per_node 8 train.py \
--encoder-glu 1 --decoder-glu 1
# Evaluation
CUDA_VISIBLE_DEVICES=0 python generate.py data-bin/wmt14.en-fr.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.9 --gen-subset test
CUDA_VISIBLE_DEVICES=0 fairseq-generate data-bin/wmt14.en-fr.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.9 --gen-subset test
```


@ -23,7 +23,7 @@ $ tar -xzvf wmt16_en_de.tar.gz -C $TEXT
2. Preprocess the dataset with a joined dictionary:
```
$ python preprocess.py --source-lang en --target-lang de \
$ fairseq-preprocess --source-lang en --target-lang de \
--trainpref $TEXT/train.tok.clean.bpe.32000 \
--validpref $TEXT/newstest2013.tok.bpe.32000 \
--testpref $TEXT/newstest2014.tok.bpe.32000 \
@ -34,7 +34,7 @@ $ python preprocess.py --source-lang en --target-lang de \
3. Train a model:
```
$ python train.py data-bin/wmt16_en_de_bpe32k \
$ fairseq-train data-bin/wmt16_en_de_bpe32k \
--arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \


@ -39,12 +39,12 @@ $ o.write(line.strip() + "\n")
# Binarize the dataset:
$ export TEXT=examples/stories/writingPrompts
$ python preprocess.py --source-lang wp_source --target-lang wp_target \
$ fairseq-preprocess --source-lang wp_source --target-lang wp_target \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/writingPrompts --padding-factor 1 --thresholdtgt 10 --thresholdsrc 10
# Train the model:
$ python train.py data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --clip-norm 0.1 --max-tokens 1500 --lr-scheduler reduce_lr_on_plateau --decoder-attention True --encoder-attention False --criterion label_smoothed_cross_entropy --weight-decay .0000001 --label-smoothing 0 --source-lang wp_source --target-lang wp_target --gated-attention True --self-attention True --project-input True --pretrained False
$ fairseq-train data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --clip-norm 0.1 --max-tokens 1500 --lr-scheduler reduce_lr_on_plateau --decoder-attention True --encoder-attention False --criterion label_smoothed_cross_entropy --weight-decay .0000001 --label-smoothing 0 --source-lang wp_source --target-lang wp_target --gated-attention True --self-attention True --project-input True --pretrained False
# Train a fusion model:
# add the arguments: --pretrained True --pretrained-checkpoint path/to/checkpoint
@ -52,7 +52,7 @@ $ python train.py data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --clip-
# Generate:
# Note: to load the pretrained model at generation time, pass a --model-overrides argument that tells the fusion model where the pretrained checkpoint is located. By default, it will load the exact path of the fusion model's pretrained model from training time. You should use --model-overrides if you have moved the pretrained model (or are using our provided models). If you are generating from a non-fusion model, the --model-overrides argument is not necessary.
$ python generate.py data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt --batch-size 32 --beam 1 --sampling --sampling-topk 10 --sampling-temperature 0.8 --nbest 1 --model-overrides "{'pretrained_checkpoint':'/path/to/pretrained/model/checkpoint'}"
$ fairseq-generate data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt --batch-size 32 --beam 1 --sampling --sampling-topk 10 --sampling-temperature 0.8 --nbest 1 --model-overrides "{'pretrained_checkpoint':'/path/to/pretrained/model/checkpoint'}"
```
## Citation


@ -18,17 +18,17 @@ Generation with the binarized test sets can be run in batch mode as follows, e.g
$ mkdir -p data-bin
$ curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin
$ curl https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin
$ python generate.py data-bin/wmt14.en-fr.newstest2014 \
$ fairseq-generate data-bin/wmt14.en-fr.newstest2014 \
--path data-bin/wmt14.en-fr.fconv-py/model.pt \
--beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
...
| Translated 3003 sentences (96311 tokens) in 166.0s (580.04 tokens/s)
| Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
# Scoring with score.py:
# Compute BLEU score
$ grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys
$ grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref
$ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
$ fairseq-score --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
```
@ -48,20 +48,20 @@ $ cd ../..
# Binarize the dataset:
$ TEXT=examples/translation/iwslt14.tokenized.de-en
$ python preprocess.py --source-lang de --target-lang en \
$ fairseq-preprocess --source-lang de --target-lang en \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/iwslt14.tokenized.de-en
# Train the model (better for a single GPU setup):
$ mkdir -p checkpoints/fconv
$ CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \
$ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
--lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--lr-scheduler fixed --force-anneal 200 \
--arch fconv_iwslt_de_en --save-dir checkpoints/fconv
# Generate:
$ python generate.py data-bin/iwslt14.tokenized.de-en \
$ fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/fconv/checkpoint_best.pt \
--batch-size 128 --beam 5 --remove-bpe
@ -73,7 +73,7 @@ To train transformer model on IWSLT'14 German to English:
# Train the model (better for a single GPU setup):
$ mkdir -p checkpoints/transformer
$ CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \
$ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
-a transformer_iwslt_de_en --optimizer adam --lr 0.0005 -s de -t en \
--label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 \
--min-lr '1e-09' --lr-scheduler inverse_sqrt --weight-decay 0.0001 \
@ -86,7 +86,7 @@ $ python scripts/average_checkpoints.py --inputs checkpoints/transformer \
--num-epoch-checkpoints 10 --output checkpoints/transformer/model.pt
# Generate:
$ python generate.py data-bin/iwslt14.tokenized.de-en \
$ fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/transformer/model.pt \
--batch-size 128 --beam 5 --remove-bpe
@ -113,21 +113,21 @@ $ cd ../..
# Binarize the dataset:
$ TEXT=examples/translation/wmt14_en_de
$ python preprocess.py --source-lang en --target-lang de \
$ fairseq-preprocess --source-lang en --target-lang de \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/wmt14_en_de --thresholdtgt 0 --thresholdsrc 0
# Train the model:
# If it runs out of memory, try to set --max-tokens 1500 instead
$ mkdir -p checkpoints/fconv_wmt_en_de
$ python train.py data-bin/wmt14_en_de \
$ fairseq-train data-bin/wmt14_en_de \
--lr 0.5 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--lr-scheduler fixed --force-anneal 50 \
--arch fconv_wmt_en_de --save-dir checkpoints/fconv_wmt_en_de
# Generate:
$ python generate.py data-bin/wmt14_en_de \
$ fairseq-generate data-bin/wmt14_en_de \
--path checkpoints/fconv_wmt_en_de/checkpoint_best.pt --beam 5 --remove-bpe
```
@ -145,21 +145,21 @@ $ cd ../..
# Binarize the dataset:
$ TEXT=examples/translation/wmt14_en_fr
$ python preprocess.py --source-lang en --target-lang fr \
$ fairseq-preprocess --source-lang en --target-lang fr \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/wmt14_en_fr --thresholdtgt 0 --thresholdsrc 0
# Train the model:
# If it runs out of memory, try to set --max-tokens 1000 instead
$ mkdir -p checkpoints/fconv_wmt_en_fr
$ python train.py data-bin/wmt14_en_fr \
$ fairseq-train data-bin/wmt14_en_fr \
--lr 0.5 --clip-norm 0.1 --dropout 0.1 --max-tokens 3000 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--lr-scheduler fixed --force-anneal 50 \
--arch fconv_wmt_en_fr --save-dir checkpoints/fconv_wmt_en_fr
# Generate:
$ python generate.py data-bin/fconv_wmt_en_fr \
$ fairseq-generate data-bin/wmt14_en_fr \
--path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt --beam 5 --remove-bpe
```


@ -8,7 +8,7 @@
from .multiprocessing_pdb import pdb
__all__ = ['pdb']
__version__ = '0.6.0'
__version__ = '0.6.1'
import fairseq.criterions
import fairseq.models


@ -13,7 +13,7 @@ try:
from fairseq import libbleu
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing libbleu.so. run `python setup.py build develop`\n')
sys.stderr.write('ERROR: missing libbleu.so. run `pip install --editable .`\n')
raise e


@ -41,9 +41,9 @@ class LanguageModelingTask(FairseqTask):
.. note::
The language modeling task is compatible with :mod:`train.py <train>`,
:mod:`generate.py <generate>`, :mod:`interactive.py <interactive>` and
:mod:`eval_lm.py <eval_lm>`.
The language modeling task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate`, :mod:`fairseq-interactive` and
:mod:`fairseq-eval-lm`.
The language modeling task provides the following additional command-line
arguments:


@ -33,8 +33,8 @@ class TranslationTask(FairseqTask):
.. note::
The translation task is compatible with :mod:`train.py <train>`,
:mod:`generate.py <generate>` and :mod:`interactive.py <interactive>`.
The translation task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate` and :mod:`fairseq-interactive`.
The translation task provides the following additional command-line
arguments:

fairseq_cli/__init__.py Normal file

fairseq_cli/eval_lm.py Symbolic link

@ -0,0 +1 @@
../eval_lm.py

fairseq_cli/generate.py Symbolic link

@ -0,0 +1 @@
../generate.py

fairseq_cli/interactive.py Symbolic link

@ -0,0 +1 @@
../interactive.py

fairseq_cli/preprocess.py Symbolic link

@ -0,0 +1 @@
../preprocess.py

fairseq_cli/score.py Symbolic link

@ -0,0 +1 @@
../score.py

fairseq_cli/setup.py Symbolic link

@ -0,0 +1 @@
../setup.py

fairseq_cli/train.py Symbolic link

@ -0,0 +1 @@
../train.py


@ -1,4 +0,0 @@
cffi
numpy
torch
tqdm


@ -17,4 +17,4 @@ fi
grep ^H $GEN | cut -f3- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS
grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF
python score.py --sys $SYS --ref $REF
fairseq-score --sys $SYS --ref $REF


@ -16,12 +16,6 @@ if sys.version_info < (3,):
with open('README.md') as f:
readme = f.read()
with open('LICENSE') as f:
license = f.read()
with open('requirements.txt') as f:
reqs = f.read()
bleu = Extension(
'fairseq.libbleu',
@ -35,22 +29,33 @@ bleu = Extension(
setup(
name='fairseq',
version='0.6.0',
version='0.6.1',
description='Facebook AI Research Sequence-to-Sequence Toolkit',
url='https://github.com/pytorch/fairseq',
classifiers=[
'Intended Audience :: Science/Research',
'License :: OSI Approved :: BSD License',
'Programming Language :: Python :: 3.6',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
],
long_description=readme,
license=license,
install_requires=reqs.strip().split('\n'),
packages=find_packages(),
install_requires=[
'cffi',
'numpy',
'torch',
'tqdm',
],
packages=find_packages(exclude=['scripts', 'tests']),
ext_modules=[bleu],
test_suite='tests',
entry_points={
'console_scripts': [
'fairseq-eval-lm = eval_lm:cli_main',
'fairseq-generate = generate:cli_main',
'fairseq-interactive = interactive:cli_main',
'fairseq-preprocess = preprocess:cli_main',
'fairseq-train = train:cli_main',
'fairseq-score = score:main',
'fairseq-eval-lm = fairseq_cli.eval_lm:cli_main',
'fairseq-generate = fairseq_cli.generate:cli_main',
'fairseq-interactive = fairseq_cli.interactive:cli_main',
'fairseq-preprocess = fairseq_cli.preprocess:cli_main',
'fairseq-train = fairseq_cli.train:cli_main',
'fairseq-score = fairseq_cli.score:main',
],
},
)
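Each console script above must point at a zero-argument callable; a hypothetical sketch of the ``cli_main`` pattern these entry points target, reusing the parser helpers from ``fairseq.options`` referenced in the docs above (the real ``fairseq_cli.train`` simply symlinks to the existing ``train.py``):
```
# Illustrative only: the shape of a console_scripts target such as
# `fairseq-train = fairseq_cli.train:cli_main`. main() is a stub here;
# in fairseq it is the existing training loop.
from fairseq import options


def main(args):
    print('would start training with architecture:', getattr(args, 'arch', None))


def cli_main():
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == '__main__':
    cli_main()
```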


@ -395,18 +395,8 @@ def cli_main():
port = random.randint(10000, 20000)
args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
args.distributed_rank = None # set based on device id
print(
'''| NOTE: you may get better performance with:
python -m torch.distributed.launch --nproc_per_node {ngpu} train.py {no_c10d}(...)
'''.format(
ngpu=args.distributed_world_size,
no_c10d=(
'--ddp-backend=no_c10d ' if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d'
else ''
),
)
)
if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
torch.multiprocessing.spawn(
fn=distributed_main,
args=(args, ),