fairseq/examples/language_model
Sosuke Kobayashi 920b85d4bd Minor update of README.md of language model example (#1063)
Summary:
With this white space, the command might fail.
```
fairseq-preprocess: error: unrecognized arguments:
zsh: command not found: --destdir
```
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1063

Differential Revision: D17072516

Pulled By: myleott

fbshipit-source-id: 68bb9d05b40b215b18aceac2bff3f5ec1ef2f537
2019-08-27 08:40:54 -07:00
conv_lm Update READMEs 2019-08-14 08:28:36 -07:00
transformer_lm Update READMEs 2019-08-14 08:28:36 -07:00
prepare-wikitext-103.sh 0.6.1 -> 0.6.2 (#577) 2019-03-15 10:27:01 -07:00
README.md Minor update of README.md of language model example (#1063) 2019-08-27 08:40:54 -07:00

Neural Language Modeling

Pre-trained models

| Model | Description | Dataset | Download |
|---|---|---|---|
| transformer_lm.gbw.adaptive_huge | Adaptive Inputs (Baevski and Auli, 2018), 1026M params | Google Billion Words | download (.tar.bz2) |
| transformer_lm.wiki103.adaptive | Adaptive Inputs (Baevski and Auli, 2018), 247M params | WikiText-103 | download (.tar.bz2) |
| transformer_lm.wmt19.en | English LM (Ng et al., 2019) | WMT News Crawl | download (.tar.gz) |
| transformer_lm.wmt19.de | German LM (Ng et al., 2019) | WMT News Crawl | download (.tar.gz) |
| transformer_lm.wmt19.ru | Russian LM (Ng et al., 2019) | WMT News Crawl | download (.tar.gz) |

Example usage

To sample from a language model using PyTorch Hub:

```python
import torch

# List available models
torch.hub.list('pytorch/fairseq')  # [..., 'transformer_lm.wmt19.en', ...]

# Load an English LM trained on WMT'19 News Crawl data
en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')

# Sample from the language model
en_lm.sample('Barack Obama', beam=1, sampling=True, sampling_topk=10, temperature=0.8)
# "Barack Obama is coming to Sydney and New Zealand (...)"

# The same interface can be used with custom models as well
from fairseq.models.transformer_lm import TransformerLanguageModel
custom_lm = TransformerLanguageModel.from_pretrained('/path/to/model/dir', 'checkpoint100.pt', tokenizer='moses', bpe='fastbpe')
custom_lm.sample('Barack Obama', beam=5)
# "Barack Obama (...)"
```
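The `sample()` call above passes `sampling=True`, `sampling_topk=10` and `temperature=0.8`, which selects top-k sampling instead of beam search. The sketch below is a standalone illustration of the general technique, not fairseq's own decoding code, and shows what those two settings do. The function `sample_topk` and the toy logits are hypothetical. The function keeps the k highest-scoring candidates, divides their logits by the temperature, and draws one token from the resulting softmax.

```python
import math
import random
from typing import Dict


def sample_topk(logits: Dict[str, float], k: int, temperature: float,
                rng: random.Random) -> str:
    """Draw one token using top-k sampling with a temperature (illustrative only)."""
    if k < 1 or temperature <= 0:
        raise ValueError("k must be >= 1 and temperature must be > 0")
    # Keep only the k candidates with the largest logits.
    top = sorted(logits.items(), key=lambda kv: kv[1], reverse=True)[:k]
    # Dividing by the temperature sharpens (< 1.0) or flattens (> 1.0) the distribution.
    scaled = [score / temperature for _, score in top]
    # Softmax weights, shifted by the max logit for numerical stability.
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    # random.choices normalizes the weights, so they need not sum to 1.
    return rng.choices([tok for tok, _ in top], weights=weights, k=1)[0]


if __name__ == "__main__":
    rng = random.Random(0)
    # Made-up next-token logits for a toy vocabulary.
    toy_logits = {"is": 2.0, "was": 1.5, "said": 0.5, "the": -1.0}
    print(sample_topk(toy_logits, k=2, temperature=0.8, rng=rng))  # "is" or "was"
```

With `k=1` this reduces to greedy decoding. Temperatures below 1.0 concentrate probability on the top candidates, so 0.8 tends to give somewhat more conservative samples than 1.0.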

Training a transformer language model with the CLI tools

1) Preprocess the data

First, download and prepare the WikiText-103 dataset:

```bash
cd examples/language_model/
bash prepare-wikitext-103.sh
cd ../..
```

Next, preprocess and binarize the data:

```bash
TEXT=examples/language_model/wikitext-103
fairseq-preprocess \
    --only-source \
    --trainpref $TEXT/wiki.train.tokens \
    --validpref $TEXT/wiki.valid.tokens \
    --testpref $TEXT/wiki.test.tokens \
    --destdir data-bin/wikitext-103 \
    --workers 20
```
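`fairseq-preprocess` does two things. It builds a vocabulary from the training split, saved as a `dict.txt` file in the destination directory. It then converts every split into integer token ids stored in a binary format. `--only-source` is used because a language model has no target side, and `--workers` parallelizes the binarization.

The sketch below covers only the vocabulary and id-mapping step, under these simplifying assumptions:
- tokens are split on whitespace;
- ids are assigned in order of decreasing frequency;
- four special symbols come first, in the order `<s>`, `<pad>`, `</s>`, `<unk>` (assumed here to match fairseq's default);
- an end-of-sentence id is appended to each line.

`build_vocab` and `encode` are hypothetical helpers, and the on-disk binary format is not modeled.

```python
from collections import Counter
from typing import Dict, List

# Assumed special-symbol order (ids 0-3); hypothetical sketch, not fairseq's Dictionary.
SPECIALS = ["<s>", "<pad>", "</s>", "<unk>"]


def build_vocab(lines: List[str]) -> Dict[str, int]:
    """Map each whitespace-separated token to an id, most frequent first."""
    counts = Counter(tok for line in lines for tok in line.split())
    vocab = {sym: idx for idx, sym in enumerate(SPECIALS)}
    # Sort by descending count, then alphabetically so ties get stable ids.
    for tok, _ in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        # WikiText-103 already contains literal <unk> tokens; keep the reserved id.
        if tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab


def encode(line: str, vocab: Dict[str, int]) -> List[int]:
    """Convert one line to ids, mapping unknown tokens to <unk> and appending </s>."""
    unk, eos = vocab["<unk>"], vocab["</s>"]
    return [vocab.get(tok, unk) for tok in line.split()] + [eos]


if __name__ == "__main__":
    vocab = build_vocab(["the cat sat", "the dog <unk>"])
    print(encode("the bird sat", vocab))  # [4, 3, 7, 2]
```

The `if tok not in vocab` guard matters for WikiText-103 in particular. Without it, the corpus's own `<unk>` tokens would be assigned a second id.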

2) Train a language model

Next, we'll train a transformer language model using adaptive inputs:

```bash
fairseq-train --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir checkpoints/transformer_wikitext-103 \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
```
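The scheduler flags describe a linear warmup followed by a cosine cycle. The 16000 warmup updates plus the 270000-update first period add up to exactly the 286000-update `--max-update`. So this run covers warmup plus one cosine cycle and ends near the cycle's floor.

The function below approximately re-implements that schedule, based on a reading of fairseq's cosine scheduler from around this release. In that reading, `--lr` acts as the floor and `--max-lr` as the peak of each cycle. Treat it as a sketch for reasoning about the schedule, not as the library's code. `--min-lr` is not modeled; here it is assumed to be a training-stop threshold rather than part of this curve.

```python
import math


def cosine_lr(num_updates: int,
              warmup_updates: int = 16000,   # --warmup-updates
              warmup_init_lr: float = 1e-7,  # --warmup-init-lr
              lr: float = 1e-4,              # --lr: floor of each cosine cycle (assumed)
              max_lr: float = 1.0,           # --max-lr: peak of each cosine cycle
              period: int = 270000,          # --lr-period-updates: first cycle length
              t_mult: float = 2.0,           # --t-mult: growth factor of cycle length
              lr_shrink: float = 0.75        # --lr-shrink: decay of floor/peak per restart
              ) -> float:
    """Approximate learning rate after `num_updates` optimizer steps (sketch)."""
    if t_mult < 1:
        raise ValueError("this sketch only handles t_mult >= 1")
    if num_updates < warmup_updates:
        # Linear warmup from warmup_init_lr towards max_lr.
        return warmup_init_lr + num_updates * (max_lr - warmup_init_lr) / warmup_updates
    t = num_updates - warmup_updates
    if t_mult > 1:
        # Cycle lengths grow geometrically: period, period * t_mult, ...
        # Solve for the index i of the cycle that contains step t.
        i = math.floor(math.log(1 + t / period * (t_mult - 1), t_mult))
        t_i = period * t_mult ** i                              # length of cycle i
        t_curr = t - period * (t_mult ** i - 1) / (t_mult - 1)  # steps into cycle i
    else:
        i = t // period
        t_i = period
        t_curr = t - i * period
    shrink = lr_shrink ** i
    floor_lr, peak_lr = lr * shrink, max_lr * shrink
    # Half-cosine decay from peak_lr down to floor_lr within the current cycle.
    return floor_lr + 0.5 * (peak_lr - floor_lr) * (1 + math.cos(math.pi * t_curr / t_i))


if __name__ == "__main__":
    for step in (0, 8000, 16000, 151000, 285999):
        print(step, cosine_lr(step))
```

With the defaults above, the rate rises linearly to 1.0 at update 16000 and passes roughly 0.5 halfway through the cycle. It approaches the 1e-4 floor just before update 286000.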

If the above command runs out of memory, try reducing --max-tokens (the maximum number of tokens per batch) or --tokens-per-sample (the maximum sequence length). You can also increase --update-freq to accumulate gradients over several batches before each update, which simulates training on more GPUs.
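These memory-related flags trade off against each other. Each GPU processes up to `--max-tokens` tokens per batch and accumulates gradients over `--update-freq` batches before each optimizer step. The tokens per update therefore scale with the product of those two values and the number of GPUs.

The helpers below are hypothetical. The 8-GPU count is an assumption for illustration only, since the command above does not specify one.

```python
def tokens_per_update(max_tokens: int, update_freq: int, num_gpus: int) -> int:
    """Upper bound on tokens contributing to a single optimizer step."""
    # Each GPU sees at most max_tokens per batch and accumulates gradients
    # over update_freq batches before the optimizer step.
    return max_tokens * update_freq * num_gpus


def update_freq_for(target_tokens: int, max_tokens: int, num_gpus: int) -> int:
    """Smallest --update-freq whose tokens per update reach target_tokens."""
    per_pass = max_tokens * num_gpus
    return -(-target_tokens // per_pass)  # ceiling division on integers


if __name__ == "__main__":
    # Hypothetical 8-GPU setup with the flags from the training command.
    target = tokens_per_update(max_tokens=3072, update_freq=3, num_gpus=8)
    print(target)                                                # 73728
    print(update_freq_for(target, max_tokens=1536, num_gpus=8))  # 6
    print(update_freq_for(target, max_tokens=3072, num_gpus=1))  # 24
```

Halving `--max-tokens` and doubling `--update-freq` keeps the tokens per update roughly constant, at the cost of more forward/backward passes per step.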

3) Evaluate

```bash
fairseq-eval-lm data-bin/wikitext-103 \
    --path checkpoints/transformer_wikitext-103/checkpoint_best.pt \
    --sample-break-mode complete --max-tokens 3072 \
    --context-window 2560 --softmax-batch 1024
```
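`--context-window 2560` is intended to give every scored token access to at least that many preceding tokens where possible. `--softmax-batch 1024` splits the softmax over the vocabulary into smaller chunks so it fits in GPU memory.

The sketch below illustrates the sliding-window idea on a plain list of token ids, using toy block and context sizes. Tokens are scored in consecutive blocks, and each block is prefixed with up to `context` earlier tokens that are fed to the model but not scored. It also shows how perplexity follows from per-token log-probabilities. Both functions are hypothetical helpers rather than fairseq's own evaluation code, and natural logs are assumed.

```python
import math
from typing import List, Tuple


def context_windows(tokens: List[int], block: int,
                    context: int) -> List[Tuple[List[int], List[int]]]:
    """Split `tokens` into consecutive scoring blocks of size `block`.

    Each block is paired with up to `context` preceding tokens that would be fed
    to the model as history but not scored, so every scored token sees at least
    min(context, its position) earlier tokens.
    """
    if block < 1 or context < 0:
        raise ValueError("block must be >= 1 and context must be >= 0")
    windows = []
    for start in range(0, len(tokens), block):
        history = tokens[max(0, start - context):start]
        targets = tokens[start:start + block]
        windows.append((history, targets))
    return windows


def perplexity(token_logprobs: List[float]) -> float:
    """exp of the mean negative log-likelihood (natural log) over scored tokens."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


if __name__ == "__main__":
    for history, targets in context_windows(list(range(10)), block=4, context=3):
        print(history, targets)
    # A model that spreads probability uniformly over 4 tokens has perplexity 4.
    print(perplexity([math.log(0.25)] * 8))
```

Every token is scored exactly once, so the windowing changes how much history each prediction sees without double-counting any token in the perplexity.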

Convolutional language models

Please see the convolutional LM README for instructions to train convolutional language models.