Initial commit

Sergey Edunov, 2017-09-14 17:22:43 -07:00
commit e734b0fa58
46 changed files with 4773 additions and 0 deletions

.gitignore vendored (new file, 104 lines)

@@ -0,0 +1,104 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Checkpoints
checkpoints
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/

LICENSE (new file, 30 lines)

@@ -0,0 +1,30 @@
BSD License
For fairseq software
Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name Facebook nor the names of its contributors may be used to
endorse or promote products derived from this software without specific
prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

PATENTS (new file, 33 lines)

@@ -0,0 +1,33 @@
Additional Grant of Patent Rights Version 2
"Software" means the fairseq software distributed by Facebook, Inc.
Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software
("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable
(subject to the termination provision below) license under any Necessary
Claims, to make, have made, use, sell, offer to sell, import, and otherwise
transfer the Software. For avoidance of doubt, no license is granted under
Facebooks rights in any patent claims that are infringed by (i) modifications
to the Software made by you or any third party or (ii) the Software in
combination with any software or other technology.
The license granted hereunder will terminate, automatically and without notice,
if you (or any of your subsidiaries, corporate affiliates or agents) initiate
directly or indirectly, or take a direct financial interest in, any Patent
Assertion: (i) against Facebook or any of its subsidiaries or corporate
affiliates, (ii) against any party if such Patent Assertion arises in whole or
in part from any software, technology, product or service of Facebook or any of
its subsidiaries or corporate affiliates, or (iii) against any party relating
to the Software. Notwithstanding the foregoing, if Facebook or any of its
subsidiaries or corporate affiliates files a lawsuit alleging patent
infringement against you in the first instance, and you respond by filing a
patent infringement counterclaim in that lawsuit against that party that is
unrelated to the Software, the license granted hereunder will not terminate
under section (i) of this paragraph due to such counterclaim.
A "Necessary Claim" is a claim of a patent owned by Facebook that is
necessarily infringed by the Software standing alone.
A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
or contributory infringement or inducement to infringe any patent, including a
cross-claim or counterclaim.

README.md (new file, 191 lines)

@@ -0,0 +1,191 @@
# Introduction
FAIR Sequence-to-Sequence Toolkit (PyTorch)
This is a PyTorch version of [fairseq](https://github.com/facebookresearch/fairseq), a sequence-to-sequence learning toolkit from Facebook AI Research. The original authors of this reimplementation are (in no particular order) Sergey Edunov, Myle Ott, and Sam Gross. The toolkit implements the fully convolutional model described in [Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122). The toolkit features multi-GPU training on a single machine as well as fast beam search generation on both CPU and GPU. We provide pre-trained models for English to French and English to German translation.
![Model](fairseq.gif)
# Citation
If you use the code in your paper, then please cite it as:
```
@inproceedings{gehring2017convs2s,
  author = {Gehring, Jonas and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N},
title = "{Convolutional Sequence to Sequence Learning}",
booktitle = {Proc. of ICML},
year = 2017,
}
```
# Requirements and Installation
* A computer running macOS or Linux
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* Python version 3.6
* A [PyTorch installation](http://pytorch.org/)
Currently fairseq-py requires PyTorch from the GitHub repository. There are multiple ways of installing it.
We suggest using [Miniconda3](https://conda.io/miniconda.html) and the following instructions.
* Install Miniconda3 from https://conda.io/miniconda.html, then create and activate a Python 3 environment.
```
conda install gcc numpy cudnn nccl
conda install magma-cuda80 -c soumith
pip install cmake
pip install cffi
git clone https://github.com/pytorch/pytorch.git
cd pytorch
git reset --hard a03e5cb40938b6b3f3e6dbddf9cff8afdff72d1b
git submodule update --init
pip install -r requirements.txt
NO_DISTRIBUTED=1 python setup.py install
```
Install fairseq by cloning the GitHub repository and running:
```
pip install -r requirements.txt
python setup.py build
python setup.py develop
```
The following command-line tools are available:
* `python preprocess.py`: Data pre-processing: build vocabularies and binarize training data
* `python train.py`: Train a new model on one or multiple GPUs
* `python generate.py`: Translate pre-processed data with a trained model
* `python generate.py -i`: Translate raw text with a trained model
* `python score.py`: BLEU scoring of generated translations against reference translations
# Quick Start
## Evaluating Pre-trained Models [TO BE ADAPTED]
First, download a pre-trained model along with its vocabularies:
```
$ curl https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.fconv-py.tar.bz2 | tar xvjf -
```
This model uses a [Byte Pair Encoding (BPE) vocabulary](https://arxiv.org/abs/1508.07909), so we'll have to apply the encoding to the source text before it can be translated.
This can be done with the [apply_bpe.py](https://github.com/rsennrich/subword-nmt/blob/master/apply_bpe.py) script using the `wmt14.en-fr.fconv-py/bpecodes` file.
`@@` is used as a continuation marker and the original text can be easily recovered with e.g. `sed s/@@ //g` or by passing the `--remove-bpe` flag to `generate.py`.
Prior to BPE, input text needs to be tokenized using `tokenizer.perl` from [mosesdecoder](https://github.com/moses-smt/mosesdecoder).
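For example, assuming `mosesdecoder` and `subword-nmt` are cloned into the working directory (the paths below are illustrative, not part of the distribution), a single sentence could be prepared roughly as follows:
```
$ echo "Why is it rare to discover new marine mammal species?" \
    | perl mosesdecoder/scripts/tokenizer/tokenizer.perl -l en \
    | python subword-nmt/apply_bpe.py -c wmt14.en-fr.fconv-py/bpecodes
```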
Let's use `python generate.py -i` to generate translations.
Here, we use a beam size of 5:
```
$ MODEL_DIR=wmt14.en-fr.fconv-py
$ python generate.py -i \
--path $MODEL_DIR/model.pt $MODEL_DIR \
--beam 5
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| model fconv_wmt_en_fr
| loaded checkpoint /private/home/edunov/wmt14.en-fr.fconv-py/model.pt (epoch 37)
> Why is it rare to discover new marine mam@@ mal species ?
S Why is it rare to discover new marine mam@@ mal species ?
O Why is it rare to discover new marine mam@@ mal species ?
H -0.08662842959165573 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
A 0 1 3 3 5 6 6 10 8 8 8 11 12
```
This generation script produces four types of outputs: a line prefixed with *S* shows the supplied source sentence after applying the vocabulary; *O* is a copy of the original source sentence; *H* is the hypothesis along with an average log-likelihood; and *A* is the attention maxima for each word in the hypothesis, including the end-of-sentence marker which is omitted from the text.
Check [below](#pre-trained-models) for a full list of pre-trained models available.
## Training a New Model
### Data Pre-processing
The fairseq source distribution contains an example pre-processing script for
the IWSLT 2014 German-English corpus.
Pre-process and binarize the data as follows:
```
$ cd data/
$ bash prepare-iwslt14.sh
$ cd ..
$ TEXT=data/iwslt14.tokenized.de-en
$ python preprocess.py --source-lang de --target-lang en \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--thresholdtgt 3 --thresholdsrc 3 --destdir data-bin/iwslt14.tokenized.de-en
```
This will write binarized data that can be used for model training to `data-bin/iwslt14.tokenized.de-en`.
### Training
Use `python train.py` to train a new model.
Here are a few example settings that work well for the IWSLT 2014 dataset:
```
$ mkdir -p trainings/fconv
$ CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \
--lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
--encoder-layers "[(256, 3)] * 4" --decoder-layers "[(256, 3)] * 3" \
--encoder-embed-dim 256 --decoder-embed-dim 256 --save-dir trainings/fconv
```
By default, `python train.py` will use all available GPUs on your machine.
Use the [CUDA_VISIBLE_DEVICES](http://acceleware.com/blog/cudavisibledevices-masking-gpus) environment variable to select specific GPUs and/or to change the number of GPU devices that will be used.
Also note that the batch size is specified in terms of the maximum number of tokens per batch (`--max-tokens`).
You may need to use a smaller value depending on the available GPU memory on your system.
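For example, here is a sketch of the same IWSLT command restricted to two GPUs with a smaller per-batch token budget (the values are illustrative):
```
$ CUDA_VISIBLE_DEVICES=0,1 python train.py data-bin/iwslt14.tokenized.de-en \
  --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 2000 \
  --encoder-layers "[(256, 3)] * 4" --decoder-layers "[(256, 3)] * 3" \
  --encoder-embed-dim 256 --decoder-embed-dim 256 --save-dir trainings/fconv
```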
### Generation
Once your model is trained, you can generate translations using `python generate.py` **(for binarized data)** or `python generate.py -i` **(for raw text)**:
```
$ python generate.py data-bin/iwslt14.tokenized.de-en \
--path trainings/fconv/checkpoint_best.pt \
--batch-size 128 --beam 5
| [de] dictionary: 35475 types
| [en] dictionary: 24739 types
| data-bin/iwslt14.tokenized.de-en test 6750 examples
| model fconv
| loaded checkpoint trainings/fconv/checkpoint_best.pt
S-721 danke .
T-721 thank you .
...
```
To generate translations with only a CPU, use the `--cpu` flag.
BPE continuation markers can be removed with the `--remove-bpe` flag.
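For instance, combining both flags with the generation command above (an illustrative invocation):
```
$ python generate.py data-bin/iwslt14.tokenized.de-en \
  --path trainings/fconv/checkpoint_best.pt \
  --batch-size 128 --beam 5 --cpu --remove-bpe
```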
# Pre-trained Models
We provide the following pre-trained fully convolutional sequence-to-sequence models:
* [wmt14.en-fr.fconv-py.tar.bz2](https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.fconv-py.tar.bz2): Pre-trained model for [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) including vocabularies
* [wmt14.en-de.fconv-py.tar.bz2](https://s3.amazonaws.com/fairseq-py/models/wmt14.en-de.fconv-py.tar.bz2): Pre-trained model for [WMT14 English-German](https://nlp.stanford.edu/projects/nmt) including vocabularies
In addition, we provide pre-processed and binarized test sets for the models above:
* [wmt14.en-fr.newstest2014.tar.bz2](https://s3.amazonaws.com/fairseq-py/data/wmt14.en-fr.newstest2014.tar.bz2): newstest2014 test set for WMT14 English-French
* [wmt14.en-fr.ntst1213.tar.bz2](https://s3.amazonaws.com/fairseq-py/data/wmt14.en-fr.ntst1213.tar.bz2): newstest2012 and newstest2013 test sets for WMT14 English-French
* [wmt14.en-de.newstest2014.tar.bz2](https://s3.amazonaws.com/fairseq-py/data/wmt14.en-de.newstest2014.tar.bz2): newstest2014 test set for WMT14 English-German
Generation with the binarized test sets can be run in batch mode as follows, e.g. for English-French on a GTX-1080ti:
```
$ curl https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin
$ curl https://s3.amazonaws.com/fairseq-py/data/wmt14.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin
$ python generate.py data-bin/wmt14.en-fr.newstest2014 \
--path data-bin/wmt14.en-fr.fconv-py/model.pt \
--beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
...
| Translated 3003 sentences (95451 tokens) in 136.3s (700.49 tokens/s)
| Timings: setup 0.1s (0.1%), encoder 1.9s (1.4%), decoder 108.9s (79.9%), search_results 0.0s (0.0%), search_prune 12.5s (9.2%)
| BLEU4 = 43.43, 68.2/49.2/37.4/28.8 (BP=0.996, ratio=1.004, sys_len=92087, ref_len=92448)
# Word-level BLEU scoring:
$ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
TODO: update scores
BLEU4 = 40.55, 67.6/46.5/34.0/25.3 (BP=1.000, ratio=0.998, sys_len=81369, ref_len=81194)
```
# Join the fairseq community
* Facebook page: https://www.facebook.com/groups/fairseq.users
* Google group: https://groups.google.com/forum/#!forum/fairseq-users
# License
fairseq is BSD-licensed.
The license applies to the pre-trained models as well.
We also provide an additional patent grant.

data/prepare-iwslt14.sh (new file, 93 lines)

@@ -0,0 +1,93 @@
#!/usr/bin/env bash
#
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
LC=$SCRIPTS/tokenizer/lowercase.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
URL="https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz"
GZ=de-en.tgz
if [ ! -d "$SCRIPTS" ]; then
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
exit 1
fi
src=de
tgt=en
lang=de-en
prep=iwslt14.tokenized.de-en
tmp=$prep/tmp
orig=orig
mkdir -p $orig $tmp $prep
echo "Downloading data from ${URL}..."
cd $orig
wget "$URL"
if [ -f $GZ ]; then
echo "Data successfully downloaded."
else
echo "Data not successfully downloaded."
exit 1
fi
tar zxvf $GZ
cd ..
echo "pre-processing train data..."
for l in $src $tgt; do
f=train.tags.$lang.$l
tok=train.tags.$lang.tok.$l
cat $orig/$lang/$f | \
grep -v '<url>' | \
grep -v '<talkid>' | \
grep -v '<keywords>' | \
sed -e 's/<title>//g' | \
sed -e 's/<\/title>//g' | \
sed -e 's/<description>//g' | \
sed -e 's/<\/description>//g' | \
perl $TOKENIZER -threads 8 -l $l > $tmp/$tok
echo ""
done
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175
for l in $src $tgt; do
perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l
done
echo "pre-processing valid/test data..."
for l in $src $tgt; do
for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do
fname=${o##*/}
f=$tmp/${fname%.*}
echo $o $f
grep '<seg id' $o | \
sed -e 's/<seg id="[0-9]*">\s*//g' | \
sed -e 's/\s*<\/seg>\s*//g' | \
sed -e "s/\/\'/g" | \
perl $TOKENIZER -threads 8 -l $l | \
perl $LC > $f
echo ""
done
done
echo "creating train, valid, test..."
for l in $src $tgt; do
awk '{if (NR%23 == 0) print $0; }' $tmp/train.tags.de-en.$l > $prep/valid.$l
awk '{if (NR%23 != 0) print $0; }' $tmp/train.tags.de-en.$l > $prep/train.$l
cat $tmp/IWSLT14.TED.dev2010.de-en.$l \
$tmp/IWSLT14.TEDX.dev2012.de-en.$l \
$tmp/IWSLT14.TED.tst2010.de-en.$l \
$tmp/IWSLT14.TED.tst2011.de-en.$l \
$tmp/IWSLT14.TED.tst2012.de-en.$l \
> $prep/test.$l
done

fairseq.gif (new binary file, 2.5 MiB; not shown)

fairseq/__init__.py (new file, 9 lines)

@@ -0,0 +1,9 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from .multiprocessing_pdb import pdb

fairseq/bleu.py (new file, 106 lines)

@@ -0,0 +1,106 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import ctypes
import math
import torch
try:
from fairseq import libbleu
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing libbleu.so. run `python setup.py install`\n')
raise e
C = ctypes.cdll.LoadLibrary(libbleu.__file__)
class BleuStat(ctypes.Structure):
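# mirrors the bleu_stat struct defined in the companion C library (libbleu)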
_fields_ = [
('reflen', ctypes.c_size_t),
('predlen', ctypes.c_size_t),
('match1', ctypes.c_size_t),
('count1', ctypes.c_size_t),
('match2', ctypes.c_size_t),
('count2', ctypes.c_size_t),
('match3', ctypes.c_size_t),
('count3', ctypes.c_size_t),
('match4', ctypes.c_size_t),
('count4', ctypes.c_size_t),
]
class Scorer(object):
def __init__(self, pad, eos, unk):
self.stat = BleuStat()
self.pad = pad
self.eos = eos
self.unk = unk
self.reset()
def reset(self, one_init=False):
if one_init:
C.bleu_one_init(ctypes.byref(self.stat))
else:
C.bleu_zero_init(ctypes.byref(self.stat))
def add(self, ref, pred):
if not isinstance(ref, torch.IntTensor):
raise TypeError('ref must be a torch.IntTensor (got {})'
.format(type(ref)))
if not isinstance(pred, torch.IntTensor):
raise TypeError('pred must be a torch.IntTensor (got {})'
.format(type(pred)))
assert self.unk > 0, 'unknown token index must be >0'
rref = ref.clone()
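# flip unk tokens in the reference to a negative id so they can never
# match (positive) tokens in the hypothesis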
rref.apply_(lambda x: x if x != self.unk else -x)
rref = rref.contiguous().view(-1)
pred = pred.contiguous().view(-1)
C.bleu_add(
ctypes.byref(self.stat),
ctypes.c_size_t(rref.size(0)),
ctypes.c_void_p(rref.data_ptr()),
ctypes.c_size_t(pred.size(0)),
ctypes.c_void_p(pred.data_ptr()),
ctypes.c_int(self.pad),
ctypes.c_int(self.eos))
def score(self, order=4):
psum = sum(math.log(p) if p > 0 else float('-Inf')
for p in self.precision()[:order])
return self.brevity() * math.exp(psum / order) * 100
def precision(self):
def ratio(a, b):
return a / b if b > 0 else 0
return [
ratio(self.stat.match1, self.stat.count1),
ratio(self.stat.match2, self.stat.count2),
ratio(self.stat.match3, self.stat.count3),
ratio(self.stat.match4, self.stat.count4),
]
def brevity(self):
r = self.stat.reflen / self.stat.predlen
return min(1, math.exp(1 - r))
def result_string(self, order=4):
assert order <= 4, "BLEU scores for order > 4 aren't supported"
fmt = 'BLEU{} = {:2.2f}, {:2.1f}'
for i in range(1, order):
fmt += '/{:2.1f}'
fmt += ' (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})'
bleup = [p * 100 for p in self.precision()[:order]]
return fmt.format(order, self.score(order=order), *bleup,
self.brevity(), self.stat.reflen/self.stat.predlen,
self.stat.predlen, self.stat.reflen)

@@ -0,0 +1,132 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <map>
#include <array>
#include <cstring>
#include <cstdio>
typedef struct
{
size_t reflen;
size_t predlen;
size_t match1;
size_t count1;
size_t match2;
size_t count2;
size_t match3;
size_t count3;
size_t match4;
size_t count4;
} bleu_stat;
// left trim (remove pad)
void bleu_ltrim(size_t* len, int** sent, int pad) {
size_t start = 0;
while(start < *len) {
if (*(*sent + start) != pad) { break; }
start++;
}
*sent += start;
*len -= start;
}
// right trim (remove eos)
void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
size_t end = *len - 1;
while (end > 0) {
if (*(*sent + end) != eos && *(*sent + end) != pad) { break; }
end--;
}
*len = end + 1;
}
// left and right trim
void bleu_trim(size_t* len, int** sent, int pad, int eos) {
bleu_ltrim(len, sent, pad);
bleu_rtrim(len, sent, pad, eos);
}
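// FNV-1a hash of the n-gram, treating its token ids as raw bytes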
size_t bleu_hash(int len, int* data) {
size_t h = 14695981039346656037ul;
size_t prime = 0x100000001b3;
char* b = (char*) data;
size_t blen = sizeof(int) * len;
while (blen-- > 0) {
h ^= *b++;
h *= prime;
}
return h;
}
void bleu_addngram(
size_t *ntotal, size_t *nmatch, size_t n,
size_t reflen, int* ref, size_t predlen, int* pred) {
if (predlen < n) { return; }
predlen = predlen - n + 1;
(*ntotal) += predlen;
if (reflen < n) { return; }
reflen = reflen - n + 1;
std::map<size_t, size_t> count;
while (predlen > 0) {
size_t w = bleu_hash(n, pred++);
count[w]++;
predlen--;
}
while (reflen > 0) {
size_t w = bleu_hash(n, ref++);
if (count[w] > 0) {
(*nmatch)++;
count[w] -=1;
}
reflen--;
}
}
extern "C" {
void bleu_zero_init(bleu_stat* stat) {
std::memset(stat, 0, sizeof(bleu_stat));
}
void bleu_one_init(bleu_stat* stat) {
bleu_zero_init(stat);
stat->count1 = 1;
stat->count2 = 1;
stat->count3 = 1;
stat->count4 = 1;
stat->match1 = 1;
stat->match2 = 1;
stat->match3 = 1;
stat->match4 = 1;
}
void bleu_add(
bleu_stat* stat,
size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) {
bleu_trim(&reflen, &ref, pad, eos);
bleu_trim(&predlen, &pred, pad, eos);
stat->reflen += reflen;
stat->predlen += predlen;
bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
}
}

@@ -0,0 +1,37 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <Python.h>
static PyMethodDef method_def[] = {
{NULL, NULL, 0, NULL}
};
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"libbleu", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
method_def
};
#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_libbleu()
#else
PyMODINIT_FUNC PyInit_libbleu()
#endif
{
PyObject *m = PyModule_Create(&module_def);
if (!m) {
return NULL;
}
return m;
}

@@ -0,0 +1,130 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <stdio.h>
#include <string.h>
#include <stdexcept>
#include <ATen/ATen.h>
using at::Tensor;
extern THCState* state;
at::Type& getDataType(const char* dtype) {
if (strcmp(dtype, "torch.cuda.FloatTensor") == 0) {
return at::getType(at::kCUDA, at::kFloat);
} else if (strcmp(dtype, "torch.FloatTensor") == 0) {
return at::getType(at::kCPU, at::kFloat);
} else {
throw std::runtime_error(std::string("Unsupported data type: ") + dtype);
}
}
inline at::Tensor t(at::Type& type, void* i) {
return type.unsafeTensorFromTH(i, true);
}
extern "C" void TemporalConvolutionTBC_forward(
const char* dtype,
void* _input,
void* _output,
void* _weight,
void* _bias)
{
auto& type = getDataType(dtype);
Tensor input = t(type, _input);
Tensor output = t(type, _output);
Tensor weight = t(type, _weight);
Tensor bias = t(type, _bias);
auto input_size = input.sizes();
auto output_size = output.sizes();
auto ilen = input_size[0];
auto batchSize = input_size[1];
auto inputPlanes = input_size[2];
auto outputPlanes = output_size[2];
auto olen = output_size[0];
auto kw = weight.sizes()[0];
int pad = (olen - ilen + kw - 1) / 2;
// input * weights + bias -> output_features
output.copy_(bias.expand(output.sizes()));
for (int k = 0; k < kw; k++) {
int iShift = std::max(0, k - pad);
int oShift = std::max(0, pad - k);
int t = std::min(ilen + pad - k, olen) - oShift;
// Note: gemm assumes column-major matrices
// input is l*m (row-major)
// weight is m*r (row-major)
// output is l*r (row-major)
if (t > 0) {
auto W = weight[k];
auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
auto O = output.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
at::addmm_out(1, O, 1, I, W, O);
}
}
}
extern "C" void TemporalConvolutionTBC_backward(
const char* dtype,
void* _dOutput,
void* _dInput,
void* _dWeight,
void* _dBias,
void* _input,
void* _weight)
{
auto& type = getDataType(dtype);
Tensor dOutput = t(type, _dOutput);
Tensor dInput = t(type, _dInput);
Tensor dWeight = t(type, _dWeight);
Tensor dBias = t(type, _dBias);
Tensor input = t(type, _input);
Tensor weight = t(type, _weight);
auto input_size = input.sizes();
auto output_size = dOutput.sizes();
auto ilen = input_size[0];
auto batchSize = input_size[1];
auto inputPlanes = input_size[2];
auto outputPlanes = output_size[2];
auto olen = output_size[0];
auto kw = weight.sizes()[0];
int pad = (olen - ilen + kw - 1) / 2;
for (int k = 0; k < kw; k++) {
int iShift = std::max(0, k - pad);
int oShift = std::max(0, pad - k);
int t = std::min(ilen + pad - k, olen) - oShift;
// dOutput * T(weight) -> dInput
if (t > 0) {
auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
auto dI = dInput.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
at::addmm_out(1, dI, 1, dO, weight[k].t(), dI);
}
}
for (int k = 0; k < kw; k++) {
int iShift = std::max(0, k - pad);
int oShift = std::max(0, pad - k);
int t = std::min(ilen + pad - k, olen) - oShift;
// T(input) * dOutput -> dWeight
if (t > 0) {
auto dW = dWeight[k];
auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes}).t();
at::addmm_out(1, dW, 1, I, dO, dW);
}
}
auto tmp = dOutput.sum(0, false);
at::sum_out(tmp, 0, dBias);
}

@@ -0,0 +1,23 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
void TemporalConvolutionTBC_forward(
const char* dtype,
void* input,
void* output,
void* weight,
void* bias);
void TemporalConvolutionTBC_backward(
const char* dtype,
void* _dOutput,
void* _dInput,
void* _dWeight,
void* _dBias,
void* _input,
void* _weight);

@@ -0,0 +1,16 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from .cross_entropy import CrossEntropyCriterion
from .fairseq_criterion import FairseqCriterion
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
__all__ = [
'CrossEntropyCriterion',
'LabelSmoothedCrossEntropyCriterion',
]

@@ -0,0 +1,31 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch.nn.functional as F
from .fairseq_criterion import FairseqCriterion
class CrossEntropyCriterion(FairseqCriterion):
def __init__(self, padding_idx):
super().__init__()
self.padding_idx = padding_idx
def prepare(self, samples):
self.denom = sum(s['ntokens'] if s else 0 for s in samples)
def forward(self, net_output, sample):
input = net_output.view(-1, net_output.size(-1))
target = sample['target'].view(-1)
loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
return loss / self.denom
def aggregate(self, losses):
return sum(losses) / math.log(2)

@@ -0,0 +1,31 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from torch.nn.modules.loss import _Loss
class FairseqCriterion(_Loss):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def prepare(self, samples):
"""Prepare criterion for DataParallel training."""
raise NotImplementedError
def forward(self, net_output, sample):
"""Compute the loss for the given sample and network output."""
raise NotImplementedError
def aggregate(self, losses):
"""Aggregate losses from DataParallel training.
Takes a list of losses as input (as returned by forward) and
aggregates them into the total loss for the mini-batch.
"""
raise NotImplementedError

@@ -0,0 +1,62 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
from torch.autograd.variable import Variable
import torch.nn.functional as F
from .fairseq_criterion import FairseqCriterion
class LabelSmoothedCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, input, target, eps, padding_idx, weights):
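# construct d(loss)/d(input) for the smoothed loss in closed form; the
# loss itself is recovered below as the dot product <grad_input, input>,
# and backward() simply rescales the saved grad_input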
grad_input = input.new(input.size()).zero_()
target = target.view(target.size(0), 1)
grad_input = grad_input.scatter_(grad_input.dim() - 1, target, eps - 1)
norm = grad_input.size(-1)
if weights is not None:
norm = weights.sum()
grad_input.mul_(weights.view(1, weights.size(0)).expand_as(grad_input))
if padding_idx is not None:
norm -= 1 if weights is None else weights[padding_idx]
grad_input.select(grad_input.dim() - 1, padding_idx).fill_(0)
grad_input = grad_input.add(-eps / norm)
ctx.grad_input = grad_input
return input.new([grad_input.view(-1).dot(input.view(-1))])
@staticmethod
def backward(ctx, grad):
return Variable(ctx.grad_input, volatile=True) * grad, None, None, None, None
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def __init__(self, eps, padding_idx=None, weights=None):
super().__init__()
self.eps = eps
self.padding_idx = padding_idx
self.weights = weights
def prepare(self, samples):
self.denom = sum(s['ntokens'] if s else 0 for s in samples)
def forward(self, net_output, sample):
input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
target = sample['target'].view(-1)
loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
return loss / self.denom
def aggregate(self, losses):
return sum(losses) / math.log(2)

fairseq/data.py (new file, 313 lines)

@@ -0,0 +1,313 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import contextlib
import itertools
import numpy as np
import os
import torch
import torch.utils.data
from fairseq.dictionary import Dictionary
from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset
def load_with_check(path, src=None, dst=None):
"""Loads the train, valid, and test sets from the specified folder
and checks that training files exist."""
def find_language_pair(files):
for filename in files:
parts = filename.split('.')
if parts[0] == 'train' and parts[-1] == 'idx':
return parts[1].split('-')
def train_file_exists(src, dst):
filename = 'train.{0}-{1}.{0}.idx'.format(src, dst)
return os.path.exists(os.path.join(path, filename))
if src is None and dst is None:
# find language pair automatically
src, dst = find_language_pair(os.listdir(path))
elif train_file_exists(src, dst):
# check for src-dst langcode
pass
elif train_file_exists(dst, src):
# check for dst-src langcode
src, dst = dst, src
else:
raise ValueError('training file not found for {}-{}'.format(src, dst))
dataset = load(path, src, dst)
return dataset
def load(path, src, dst):
"""Loads the train, valid, and test sets from the specified folder."""
langcode = '{}-{}'.format(src, dst)
def fmt_path(fmt, *args):
return os.path.join(path, fmt.format(*args))
src_dict = Dictionary.load(fmt_path('dict.{}.txt', src))
dst_dict = Dictionary.load(fmt_path('dict.{}.txt', dst))
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
for split in ['train', 'valid', 'test']:
for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '')
src_path = fmt_path('{}.{}.{}', prefix, langcode, src)
if not IndexedInMemoryDataset.exists(src_path):
break
dataset.splits[prefix] = LanguagePairDataset(
IndexedInMemoryDataset(src_path),
IndexedInMemoryDataset(fmt_path('{}.{}.{}', prefix, langcode, dst)),
padding_value=dataset.src_dict.pad(),
eos=dataset.src_dict.eos(),
)
return dataset
class LanguageDatasets(object):
def __init__(self, src, dst, src_dict, dst_dict):
self.src = src
self.dst = dst
self.src_dict = src_dict
self.dst_dict = dst_dict
self.splits = {}
def dataloader(self, split, batch_size=1, num_workers=0,
max_tokens=None, seed=None, epoch=1,
sample_without_replacement=0, max_positions=1024):
dataset = self.splits[split]
if split.startswith('train'):
with numpy_seed(seed):
batch_sampler = shuffled_batches_by_size(
dataset.src, dataset.dst,
max_tokens=max_tokens, epoch=epoch,
sample=sample_without_replacement,
max_positions=max_positions)
elif split.startswith('valid'):
batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, dst=dataset.dst,
max_positions=max_positions))
else:
batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, max_positions=max_positions))
return torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
pin_memory=torch.cuda.is_available(),
collate_fn=PaddingCollater(self.src_dict.pad()),
batch_sampler=batch_sampler)
def skip_group_enumerator(it, ngpus, offset=0):
res = []
idx = 0
for i, sample in enumerate(it):
if i < offset:
continue
res.append(sample)
if len(res) >= ngpus:
yield (i, res)
res = []
idx = i + 1
if len(res) > 0:
yield (idx, res)
class PaddingCollater(object):
def __init__(self, padding_value=1):
self.padding_value = padding_value
def __call__(self, samples):
def merge(key, pad_begin):
return self.merge_with_pad([s[key] for s in samples], pad_begin)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': torch.LongTensor([s['id'].item() for s in samples]),
'input_tokens': merge('input_tokens', pad_begin=True),
'input_positions': merge('input_positions', pad_begin=True),
'target': merge('target', pad_begin=True),
'src_tokens': merge('src_tokens', pad_begin=False),
'src_positions': merge('src_positions', pad_begin=False),
'ntokens': ntokens,
}
def merge_with_pad(self, values, pad_begin):
size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(self.padding_value)
for i, v in enumerate(values):
if pad_begin:
res[i][size-len(v):].copy_(v)
else:
res[i][:len(v)].copy_(v)
return res
class LanguagePairDataset(object):
def __init__(self, src, dst, padding_value=1, eos=2):
self.src = src
self.dst = dst
self.padding_value = padding_value
self.eos = eos
def __getitem__(self, i):
src = self.src[i].long() - 1
target = self.dst[i].long() - 1
input = target.new(target.size())
input[0] = self.eos
input[1:].copy_(target[:-1])
return {
'id': i,
'input_tokens': input,
'input_positions': self.make_positions(input),
'target': target,
'src_tokens': src,
'src_positions': self.make_positions(src),
}
def make_positions(self, x):
start = self.padding_value + 1
return torch.arange(start, start + len(x)).type_as(x)
def __len__(self):
return len(self.src)
def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024):
"""Returns batches of indices sorted by size. Sequences of different lengths
are not allowed in the same batch."""
assert isinstance(src, IndexedDataset)
assert dst is None or isinstance(dst, IndexedDataset)
if max_tokens is None:
max_tokens = float('Inf')
sizes = src.sizes
indices = np.argsort(sizes, kind='mergesort')
if dst is not None:
sizes = np.maximum(sizes, dst.sizes)
batch = []
def yield_batch(next_idx, num_tokens):
if len(batch) == 0:
return False
if len(batch) == batch_size:
return True
if sizes[batch[0]] != sizes[next_idx]:
return True
if num_tokens >= max_tokens:
return True
return False
cur_max_size = 0
for idx in indices:
# - 2 here stems from make_positions() where we offset positions
# by padding_value + 1
if src.sizes[idx] < 2 or \
(dst is not None and dst.sizes[idx] < 2) or \
sizes[idx] > max_positions - 2:
raise Exception("Unable to handle input id {} of "
"size {} / {}.".format(idx, src.sizes[idx], dst.sizes[idx]))
if yield_batch(idx, cur_max_size * (len(batch) + 1)):
yield batch
batch = []
cur_max_size = 0
batch.append(idx)
cur_max_size = max(cur_max_size, sizes[idx])
if len(batch) > 0:
yield batch
def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_positions=1024):
"""Returns batches of indices, bucketed by size and then shuffled. Batches
may contain sequences of different lengths."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
if max_tokens is None:
max_tokens = float('Inf')
indices = np.random.permutation(len(src))
# sort by sizes
indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]
def make_batches():
batch = []
sample_len = 0
ignored = []
for idx in indices:
# - 2 here stems from make_positions() where we offset positions
# by padding_value + 1
if src.sizes[idx] < 2 or dst.sizes[idx] < 2 or \
src.sizes[idx] > max_positions - 2 or \
dst.sizes[idx] > max_positions - 2:
ignored.append(idx)
continue
sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
if len(batch) > 0 and (len(batch) + 1) * sample_len > max_tokens:
yield batch
batch = []
sample_len = max(src.sizes[idx], dst.sizes[idx])
batch.append(idx)
if len(batch) > 0:
yield batch
if len(ignored) > 0:
print("Warning! {} samples are either too short or too long "
"and will be ignored, sample ids={}".format(len(ignored), ignored))
batches = list(make_batches())
np.random.shuffle(batches)
if sample:
offset = (epoch - 1) * sample
while offset > len(batches):
np.random.shuffle(batches)
offset -= len(batches)
result = batches[offset:(offset + sample)]
while len(result) < sample:
np.random.shuffle(batches)
result += batches[:(sample - len(result))]
assert len(result) == sample, \
"batch length is not correct {}".format(len(result))
batches = result
else:
for i in range(epoch - 1):
np.random.shuffle(batches)
return batches
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and
restores the state afterward"""
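# e.g. wrapping a shuffle in `with numpy_seed(seed):` makes it reproducible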
if seed is None:
yield
return
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)

fairseq/dictionary.py (new file, 117 lines)

@@ -0,0 +1,117 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
class Dictionary(object):
"""A mapping from symbols to consecutive integers"""
def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'):
self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
self.add_symbol('<Lua heritage>')
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
self.nspecial = len(self.symbols)
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
return self.unk_word
def __len__(self):
"""Returns the number of symbols in the dictionary"""
return len(self.symbols)
def index(self, sym):
"""Returns the index of the specified symbol"""
if sym in self.indices:
return self.indices[sym]
return self.unk_index
def string(self, tensor):
if torch.is_tensor(tensor) and tensor.dim() == 2:
sentences = [self.string(line) for line in tensor]
return '\n'.join(sentences)
eos = self.eos()
return ' '.join([self[i] for i in tensor if i != eos])
def add_symbol(self, word, n=1):
"""Adds a word to the dictionary"""
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def finalize(self):
"""Sort symbols by frequency in descending order, ignoring special ones."""
self.count, self.symbols = zip(
*sorted(zip(self.count, self.symbols),
key=(lambda x: math.inf if self.indices[x[1]] < self.nspecial else x[0]),
reverse=True)
)
def pad(self):
"""Helper to get index of pad symbol"""
return self.pad_index
def eos(self):
"""Helper to get index of end-of-sentence symbol"""
return self.eos_index
def unk(self):
"""Helper to get index of unk symbol"""
return self.unk_index
@staticmethod
def load(f):
"""Loads the dictionary from a text file with the format:
```
<symbol0> <count0>
<symbol1> <count1>
...
```
"""
if isinstance(f, str):
with open(f, 'r') as fd:
return Dictionary.load(fd)
d = Dictionary()
for line in f.readlines():
idx = line.rfind(' ')
word = line[:idx]
count = int(line[idx+1:])
d.indices[word] = len(d.symbols)
d.symbols.append(word)
d.count.append(count)
return d
def save(self, f, threshold=3, nwords=-1):
"""Stores dictionary into a text file"""
if isinstance(f, str):
with open(f, 'w') as fd:
return self.save(fd, threshold, nwords)
cnt = 0
for i, t in enumerate(zip(self.symbols, self.count)):
if i >= self.nspecial and t[1] >= threshold \
and (nwords < 0 or cnt < nwords):
print('{} {}'.format(t[0], t[1]), file=f)
cnt += 1

fairseq/indexed_dataset.py (new file, 143 lines)

@@ -0,0 +1,143 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import numpy as np
import os
import struct
import torch
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
f.readinto(a)
return a
def write_longs(f, a):
f.write(np.array(a, dtype=np.int64))
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float,
7: np.double,
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
class IndexedDataset(object):
"""Loader for TorchNet IndexedDataset"""
def __init__(self, path):
with open(path + '.idx', 'rb') as f:
magic = f.read(8)
assert magic == b'TNTIDX\x00\x00'
version = f.read(8)
assert struct.unpack('<Q', version) == (1,)
code, self.element_size = struct.unpack('<QQ', f.read(16))
self.dtype = dtypes[code]
self.size, self.s = struct.unpack('<QQ', f.read(16))
self.dim_offsets = read_longs(f, self.size + 1)
self.data_offsets = read_longs(f, self.size + 1)
self.sizes = read_longs(f, self.s)
self.read_data(path)
def read_data(self, path):
self.data_file = open(path + '.bin', 'rb', buffering=0)
def __del__(self):
self.data_file.close()
def __getitem__(self, i):
if i < 0 or i >= self.size:
raise IndexError('index out of range')
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
return torch.from_numpy(a)
def __len__(self):
return self.size
@staticmethod
def exists(path):
return os.path.exists(path + '.idx')
class IndexedInMemoryDataset(IndexedDataset):
"""Loader for TorchNet IndexedDataset, keeps all the data in memory"""
def read_data(self, path):
self.data_file = open(path + '.bin', 'rb')
self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
self.data_file.readinto(self.buffer)
self.data_file.close()
def __del__(self):
pass
def __getitem__(self, i):
if i < 0 or i >= self.size:
raise IndexError('index out of range')
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
return torch.from_numpy(a)
class IndexedDatasetBuilder(object):
element_sizes = {
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float: 4,
np.double: 8
}
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb')
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
self.sizes = []
self.element_size = self.element_sizes[self.dtype]
def add_item(self, tensor):
# +1 for Lua compatibility
bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype))
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
for s in tensor.size():
self.sizes.append(s)
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def finalize(self, index_file):
self.out_file.close()
index = open(index_file, 'wb')
index.write(b'TNTIDX\x00\x00')
index.write(struct.pack('<Q', 1))
index.write(struct.pack('<QQ', code(self.dtype),
self.element_size))
index.write(struct.pack('<QQ', len(self.data_offsets) - 1,
len(self.sizes)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
index.close()

fairseq/meters.py (new file, 74 lines)

@@ -0,0 +1,74 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import time
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class TimeMeter(object):
"""Computes the average occurence of some event per second"""
def __init__(self):
self.reset()
def reset(self):
self.start = time.time()
self.n = 0
def update(self, val=1):
self.n += val
@property
def avg(self):
delta = time.time() - self.start
return self.n / delta
@property
def elapsed_time(self):
return time.time() - self.start
class StopwatchMeter(object):
"""Computes the sum/avg duration of some event in seconds"""
def __init__(self):
self.reset()
def start(self):
self.start_time = time.time()
def stop(self, n=1):
if self.start_time is not None:
delta = time.time() - self.start_time
self.sum += delta
self.n += n
self.start_time = None
def reset(self):
self.sum = 0
self.n = 0
self.start_time = None
@property
def avg(self):
return self.sum / self.n

@@ -0,0 +1,14 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from .fconv import *
__all__ = [
'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de',
'fconv_wmt_en_fr',
]

fairseq/models/fconv.py (new file, 485 lines)

@@ -0,0 +1,485 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules import BeamableMM, LinearizedConvolution
class FConvModel(nn.Module):
def __init__(self, encoder, decoder, padding_idx=1):
super(FConvModel, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.encoder.num_attention_layers = sum([layer is not None for layer in decoder.attention])
self.padding_idx = padding_idx
self._is_generation_fast = False
def forward(self, src_tokens, src_positions, input_tokens, input_positions):
encoder_out = self.encoder(src_tokens, src_positions)
decoder_out = self.decoder(input_tokens, input_positions, encoder_out)
return decoder_out.view(-1, decoder_out.size(-1))
def make_generation_fast_(self, beam_size, use_beamable_mm=False):
"""Optimize model for faster generation.
Optimizations include:
- remove WeightNorm
- (optionally) use BeamableMM in attention layers
The optimized model should not be used again for training.
Note: this can be combined with incremental inference in the Decoder for
even faster generation.
"""
if self._is_generation_fast:
return # only apply once
self._is_generation_fast = True
# remove weight norm from all modules in the network
def remove_weight_norm(m):
try:
nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(remove_weight_norm)
# use BeamableMM in attention layers
if use_beamable_mm:
self.decoder._use_beamable_mm(beam_size)
def train(mode):
if mode:
raise RuntimeError('cannot train after make_generation_fast')
# this model should no longer be used for training
self.eval()
self.train = train
class Encoder(nn.Module):
"""Convolutional encoder"""
def __init__(self, num_embeddings, embed_dim=512, max_positions=1024,
convolutions=((512, 3),) * 20, dropout=0.1, padding_idx=1):
super(Encoder, self).__init__()
self.dropout = dropout
self.num_attention_layers = None
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
in_channels = convolutions[0][0]
self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
self.projections = nn.ModuleList()
self.convolutions = nn.ModuleList()
for (out_channels, kernel_size) in convolutions:
pad = (kernel_size - 1) // 2
self.projections.append(Linear(in_channels, out_channels)
if in_channels != out_channels else None)
self.convolutions.append(
ConvTBC(in_channels, out_channels * 2, kernel_size, padding=pad,
dropout=dropout))
in_channels = out_channels
self.fc2 = Linear(in_channels, embed_dim)
def forward(self, tokens, positions):
# embed tokens and positions
x = self.embed_tokens(tokens) + self.embed_positions(positions)
x = F.dropout(x, p=self.dropout, training=self.training)
input_embedding = x
# project to size of convolution
x = self.fc1(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# temporal convolutions
for proj, conv in zip(self.projections, self.convolutions):
residual = x if proj is None else proj(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = conv(x)
x = F.glu(x, dim=-1)
x = (x + residual) * math.sqrt(0.5)
# T x B x C -> B x T x C
x = x.transpose(1, 0)
# project back to size of embedding
x = self.fc2(x)
# scale gradients (this only affects backward, not forward)
x = grad_multiply(x, 1.0 / (2.0 * self.num_attention_layers))
# add output to input embedding for attention
y = (x + input_embedding) * math.sqrt(0.5)
return x, y
class AttentionLayer(nn.Module):
def __init__(self, conv_channels, embed_dim, bmm=None):
super(AttentionLayer, self).__init__()
# projects from output of convolution to embedding dimension
self.in_projection = Linear(conv_channels, embed_dim)
# projects from embedding dimension to convolution size
self.out_projection = Linear(embed_dim, conv_channels)
self.bmm = bmm if bmm is not None else torch.bmm
def forward(self, x, target_embedding, encoder_out):
residual = x
# attention
x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5)
x = self.bmm(x, encoder_out[0])
# softmax over last dim
sz = x.size()
x = F.softmax(x.view(sz[0] * sz[1], sz[2]))
x = x.view(sz)
attn_scores = x
x = self.bmm(x, encoder_out[1])
# scale attention output
s = encoder_out[1].size(1)
x = x * (s * math.sqrt(1.0 / s))
# project back
x = (self.out_projection(x) + residual) * math.sqrt(0.5)
return x, attn_scores
class Decoder(nn.Module):
"""Convolutional decoder"""
def __init__(self, num_embeddings, embed_dim=512, out_embed_dim=256,
max_positions=1024, convolutions=((512, 3),) * 20,
attention=True, dropout=0.1, padding_idx=1):
super(Decoder, self).__init__()
self.dropout = dropout
in_channels = convolutions[0][0]
if isinstance(attention, bool):
# expand True into [True, True, ...] and do the same with False
attention = [attention] * len(convolutions)
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
self.projections = nn.ModuleList()
self.convolutions = nn.ModuleList()
self.attention = nn.ModuleList()
for i, (out_channels, kernel_size) in enumerate(convolutions):
pad = kernel_size - 1
self.projections.append(Linear(in_channels, out_channels)
if in_channels != out_channels else None)
self.convolutions.append(
LinearizedConv1d(in_channels, out_channels * 2, kernel_size,
padding=pad, dropout=dropout))
self.attention.append(AttentionLayer(out_channels, embed_dim)
if attention[i] else None)
in_channels = out_channels
self.fc2 = Linear(in_channels, out_embed_dim)
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
self._is_inference_incremental = False
def forward(self, tokens, positions, encoder_out):
# embed tokens and positions
x = self.embed_tokens(tokens) + self.embed_positions(positions)
x = F.dropout(x, p=self.dropout, training=self.training)
target_embedding = x
# project to size of convolution
x = self.fc1(x)
# transpose only once to speed up attention layers
encoder_a, encoder_b = encoder_out
encoder_a = encoder_a.transpose(1, 2).contiguous()
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# temporal convolutions
for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
residual = x if proj is None else proj(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = conv(x)
x = conv.remove_future_timesteps(x)
x = F.glu(x)
# attention
if attention is not None:
x = x.transpose(1, 0)
x, _ = attention(x, target_embedding, (encoder_a, encoder_b))
x = x.transpose(1, 0)
# residual
x = (x + residual) * math.sqrt(0.5)
# T x B x C -> B x T x C
x = x.transpose(1, 0)
# project back to size of vocabulary
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.fc3(x)
return x
def context_size(self):
"""Maximum number of input elements each output element depends on"""
context = 1
for conv in self.convolutions:
context += conv.kernel_size[0] - 1
return context
def incremental_inference(self):
"""Context manager for incremental inference.
This provides an optimized forward pass for incremental inference
(i.e., it predicts one time step at a time). If the input order changes
between time steps, call model.decoder.reorder_incremental_state to
update the relevant buffers. To generate a fresh sequence, first call
model.decoder.clear_incremental_state.
Usage:
```
with model.decoder.incremental_inference():
for step in range(maxlen):
out = model.decoder(tokens[:, :step], positions[:, :step],
encoder_out)
probs = F.log_softmax(out[:, -1, :])
```
"""
class IncrementalInference(object):
def __init__(self, decoder):
self.decoder = decoder
def __enter__(self):
self.decoder._start_incremental_inference()
def __exit__(self, *args):
self.decoder._stop_incremental_inference()
return IncrementalInference(self)
def _start_incremental_inference(self):
assert not self._is_inference_incremental, \
'already performing incremental inference'
self._is_inference_incremental = True
# save original forward and convolution layers
self._orig_forward = self.forward
self._orig_conv = self.convolutions
# switch to incremental forward
self.forward = self._incremental_forward
# start a fresh sequence
self.clear_incremental_state()
def _stop_incremental_inference(self):
# restore original forward and convolution layers
self.forward = self._orig_forward
self.convolutions = self._orig_conv
self._is_inference_incremental = False
def _incremental_forward(self, tokens, positions, encoder_out):
assert self._is_inference_incremental
# setup initial state
if self.prev_state is None:
# transpose encoder output once to speed up attention layers
encoder_a, encoder_b = encoder_out
encoder_a = encoder_a.transpose(1, 2).contiguous()
self.prev_state = {
'encoder_out': (encoder_a, encoder_b),
}
# load previous state
encoder_a, encoder_b = self.prev_state['encoder_out']
# keep only the last token for incremental forward pass
tokens = tokens[:, -1:]
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_tokens(tokens) + self.embed_positions(positions)
target_embedding = x
# project to size of convolution
x = self.fc1(x)
# temporal convolutions
avg_attn_scores = None
num_attn_layers = len(self.attention)
for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
residual = x if proj is None else proj(x)
x = conv.incremental_forward(x)
x = F.glu(x)
# attention
if attention is not None:
x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b))
attn_scores = attn_scores / num_attn_layers
if avg_attn_scores is None:
avg_attn_scores = attn_scores
else:
avg_attn_scores += attn_scores
# residual
x = (x + residual) * math.sqrt(0.5)
# project back to size of vocabulary
x = self.fc2(x)
x = self.fc3(x)
return x, avg_attn_scores
def clear_incremental_state(self):
"""Clear all state used for incremental generation.
**For incremental inference only**
This should be called before generating a fresh sequence.
"""
if self._is_inference_incremental:
self.prev_state = None
for conv in self.convolutions:
conv.clear_buffer()
def reorder_incremental_state(self, new_order):
"""Reorder buffered internal state (for incremental generation).
**For incremental inference only**
This should be called when the order of the input has changed from the
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the choice of beams.
"""
if self._is_inference_incremental:
for conv in self.convolutions:
conv.reorder_buffer(new_order)
def _use_beamable_mm(self, beam_size):
"""Replace torch.bmm with BeamableMM in attention layers."""
beamable_mm = BeamableMM(beam_size)
for attn in self.attention:
attn.bmm = beamable_mm
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
m.weight.data.normal_(0, 0.1)
return m
def Linear(in_features, out_features, dropout=0):
"""Weight-normalized Linear layer (input: N x T x C)"""
m = nn.Linear(in_features, out_features)
m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
m.bias.data.zero_()
return nn.utils.weight_norm(m)
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
"""Weight-normalized Conv1d layer optimized for decoding"""
m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
m.weight.data.normal_(mean=0, std=std)
m.bias.data.zero_()
return nn.utils.weight_norm(m)
def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
"""Weight-normalized Conv1d layer"""
from fairseq.modules import ConvTBC
m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
m.weight.data.normal_(mean=0, std=std)
m.bias.data.zero_()
return nn.utils.weight_norm(m, dim=2)
def grad_multiply(x, scale):
return GradMultiply.apply(x, scale)
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
ctx.mark_shared_storage((x, res))
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
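# GradMultiply is the identity in the forward pass and scales gradients by
# `scale` in the backward pass. A minimal sketch (shapes and values are
# illustrative):
#
#   x = torch.autograd.Variable(torch.ones(2, 3), requires_grad=True)
#   y = grad_multiply(x, 0.5)  # y equals x, but dL/dx will be halved
#   y.sum().backward()         # x.grad is filled with 0.5 instead of 1.0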
def fconv_iwslt_de_en(dataset, dropout, **kwargs):
encoder_convs = [(256, 3)] * 4
decoder_convs = [(256, 3)] * 3
return fconv(dataset, dropout, 256, encoder_convs, 256, decoder_convs, **kwargs)
def fconv_wmt_en_ro(dataset, dropout, **kwargs):
convs = [(512, 3)] * 20
return fconv(dataset, dropout, 512, convs, 512, convs, **kwargs)
def fconv_wmt_en_de(dataset, dropout, **kwargs):
    convs = [(512, 3)] * 9    # first 9 layers have 512 units
    convs += [(1024, 3)] * 4  # next 4 layers have 1024 units
    convs += [(2048, 1)] * 2  # final 2 layers are 1x1
return fconv(dataset, dropout, 768, convs, 768, convs,
decoder_out_embed_dim=512,
**kwargs)
def fconv_wmt_en_fr(dataset, dropout, **kwargs):
    convs = [(512, 3)] * 6    # first 6 layers have 512 units
    convs += [(768, 3)] * 4   # next 4 layers have 768 units
    convs += [(1024, 3)] * 3  # next 3 layers have 1024 units
convs += [(2048, 1)] * 1 # next 1 layer is 1x1
convs += [(4096, 1)] * 1 # final 1 layer is 1x1
return fconv(dataset, dropout, 768, convs, 768, convs,
decoder_out_embed_dim=512,
**kwargs)
def fconv(dataset, dropout, encoder_embed_dim, encoder_convolutions,
decoder_embed_dim, decoder_convolutions, attention=True,
decoder_out_embed_dim=256, max_positions=1024):
padding_idx = dataset.dst_dict.pad()
encoder = Encoder(
len(dataset.src_dict),
embed_dim=encoder_embed_dim,
convolutions=encoder_convolutions,
dropout=dropout,
padding_idx=padding_idx,
max_positions=max_positions)
decoder = Decoder(
len(dataset.dst_dict),
embed_dim=decoder_embed_dim,
convolutions=decoder_convolutions,
out_embed_dim=decoder_out_embed_dim,
attention=attention,
dropout=dropout,
padding_idx=padding_idx,
max_positions=max_positions)
return FConvModel(encoder, decoder, padding_idx)

15
fairseq/modules/__init__.py Normal file
@ -0,0 +1,15 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from .beamable_mm import *
from .linearized_convolution import *
from .conv_tbc import ConvTBC
__all__ = [
'BeamableMM', 'LinearizedConvolution', 'ConvTBC',
]

47
fairseq/modules/beamable_mm.py Normal file
@ -0,0 +1,47 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
import torch.nn as nn
class BeamableMM(nn.Module):
"""This module provides an optimized MM for beam decoding with attention.
    It leverages the fact that the source-side of the input is replicated beam
times and the target-side of the input is of width one. This layer speeds up
inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
"""
def __init__(self, beam_size):
super(BeamableMM, self).__init__()
self.beam_size = beam_size
def forward(self, input1, input2):
if (
not self.training and # test mode
self.beam_size > 0 and # beam size is set
input1.dim() == 3 and # only support batched input
input1.size(1) == 1 # single time step update
):
bsz, beam = input1.size(0), self.beam_size
# bsz x 1 x nhu --> bsz/beam x beam x nhu
input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1)
# bsz x sz2 x nhu --> bsz/beam x sz2 x nhu
input2 = input2.unfold(0, beam, beam)[:, :, :, 0]
# use non batched operation if bsz = beam
if input1.size(0) == 1:
output = torch.mm(input1[0, :, :], input2[0, :, :])
else:
output = input1.bmm(input2)
return output.view(bsz, 1, -1)
else:
return input1.bmm(input2)
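# Shape walk-through (illustrative): with bsz=6 and beam_size=3, input1
# (6 x 1 x nhu) is refolded into (2 x 3 x nhu) and input2 keeps a single
# copy per beam (2 batches instead of 6), so the batched matmul runs over
# bsz/beam batches before the output is viewed back as (6 x 1 x -1).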

105
fairseq/modules/conv_tbc.py Normal file
@ -0,0 +1,105 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
from torch.autograd import Variable, Function
from torch.nn.modules.utils import _single
try:
from fairseq import temporal_convolution_tbc
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing temporal_convolution_tbc, run `python setup.py install`\n')
raise e
class ConvTBC(torch.nn.Module):
"""1D convolution over an input of shape (time x batch x channel)
The implementation uses gemm to perform the convolution. This implementation
is faster than cuDNN for small kernel sizes.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0):
super(ConvTBC, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _single(kernel_size)
self.stride = _single(stride)
self.padding = _single(padding)
assert self.stride == (1,)
self.weight = torch.nn.Parameter(torch.Tensor(
self.kernel_size[0], in_channels, out_channels))
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
def forward(self, input):
return ConvTBCFunction.apply(
input.contiguous(), self.weight, self.bias, self.padding[0])
def __repr__(self):
s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
', padding={padding}')
if self.bias is None:
s += ', bias=False'
s += ')'
return s.format(name=self.__class__.__name__, **self.__dict__)
class ConvTBCFunction(Function):
@staticmethod
def forward(ctx, input, weight, bias, pad):
input_size = input.size()
weight_size = weight.size()
kernel_size = weight_size[0]
output = input.new(
input_size[0] - kernel_size + 1 + pad * 2,
input_size[1],
weight_size[2])
ctx.input_size = input_size
ctx.weight_size = weight_size
ctx.save_for_backward(input, weight)
temporal_convolution_tbc.TemporalConvolutionTBC_forward(
input.type().encode('utf-8'),
input,
output,
weight,
bias)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_output = grad_output.data.contiguous()
grad_input = grad_output.new(ctx.input_size).zero_()
grad_weight = grad_output.new(ctx.weight_size).zero_()
grad_bias = grad_output.new(ctx.weight_size[2])
temporal_convolution_tbc.TemporalConvolutionTBC_backward(
input.type().encode('utf-8'),
grad_output,
grad_input,
grad_weight,
grad_bias,
input,
weight)
grad_input = Variable(grad_input, volatile=True)
grad_weight = Variable(grad_weight, volatile=True)
grad_bias = Variable(grad_bias, volatile=True)
return grad_input, grad_weight, grad_bias, None
def conv_tbc(input, weight, bias=None, stride=1, padding=0):
    # normalize padding to a 1-tuple so both ints and tuples are accepted
    return ConvTBCFunction.apply(
        input.contiguous(), weight, bias, _single(padding)[0])
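# Usage sketch (illustrative sizes): ConvTBC expects time-major input, i.e. a
# (time x batch x channel) tensor, unlike nn.Conv1d's (batch x channel x time):
#
#   conv = ConvTBC(in_channels=256, out_channels=512, kernel_size=3, padding=2)
#   out = conv(Variable(torch.randn(20, 8, 256)))  # -> (22 x 8 x 512)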

84
fairseq/modules/linearized_convolution.py Normal file
@ -0,0 +1,84 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
import torch.nn.functional as F
from .conv_tbc import ConvTBC
class LinearizedConvolution(ConvTBC):
"""An optimized version of nn.Conv1d.
This module replaces convolutions with linear layers as appropriate
and supports optimizations for incremental inference.
"""
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
self.clear_buffer()
self._linearized_weight = None
self.register_backward_hook(self._clear_linearized_weight)
def remove_future_timesteps(self, x):
"""Remove future time steps created by padding."""
if self.kernel_size[0] > 1 and self.padding[0] > 0:
x = x[:-self.padding[0], :, :]
return x
def incremental_forward(self, input):
"""Forward convolution one time step at a time.
        This function maintains an internal state to buffer the signal and
accepts a single frame as input. If the input order changes
between time steps, call reorder_buffer. To apply to fresh
inputs, call clear_buffer.
"""
if self.training:
raise RuntimeError('LinearizedConvolution only supports inference')
# run forward pre hooks (e.g., weight norm)
for hook in self._forward_pre_hooks.values():
hook(self, input)
# reshape weight
weight = self._get_linearized_weight()
kw = self.kernel_size[0]
bsz = input.size(0) # input: bsz x len x dim
if kw > 1:
input = input.data
if self.input_buffer is None:
self.input_buffer = input.new(bsz, kw, input.size(2))
self.input_buffer.zero_()
else:
# shift buffer
self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
# append next input
self.input_buffer[:, -1, :] = input[:, -1, :]
input = torch.autograd.Variable(self.input_buffer, volatile=True)
output = F.linear(input.view(bsz, -1), weight, self.bias)
return output.view(bsz, 1, -1)
def clear_buffer(self):
self.input_buffer = None
def reorder_buffer(self, new_order):
if self.input_buffer is not None:
self.input_buffer = self.input_buffer.index_select(0, new_order)
def _get_linearized_weight(self):
if self._linearized_weight is None:
kw = self.kernel_size[0]
weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
assert weight.size() == (self.out_channels, kw, self.in_channels)
self._linearized_weight = weight.view(self.out_channels, -1)
return self._linearized_weight
def _clear_linearized_weight(self, *args):
self._linearized_weight = None
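# Incremental decoding sketch (illustrative): outside of training, a caller
# can feed one frame at a time and let the module buffer the past kw - 1
# inputs:
#
#   conv.clear_buffer()                    # start a fresh sequence
#   for t in range(num_steps):
#       y = conv.incremental_forward(x_t)  # x_t: (bsz x 1 x in_channels)
#                                          # y:   (bsz x 1 x out_channels)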

167
fairseq/multiprocessing_event_loop.py Normal file
@ -0,0 +1,167 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import os
import signal
import threading
from torch import multiprocessing
class MultiprocessingEventLoop(object):
"""Start a multiprocessing event loop."""
def __init__(self, device_ids=None, multiprocessing_method='spawn'):
super().__init__()
self.device_ids = tuple(device_ids)
self.num_replicas = len(device_ids)
self.rank = None
self._mp = multiprocessing.get_context(multiprocessing_method)
self._start_error_handler()
self._start_multiprocessing()
def call_async(self, rank, action, **kwargs):
"""Asynchronously call a function in each child process.
Call a function named `action` on the rank'th process and return
a Future with the result.
"""
def result_generator():
yield self.return_pipes[rank].recv()
assert not self.return_pipes[rank].poll(), \
'return pipe must be consumed before calling another function'
self.input_pipes[rank].send((action, kwargs))
return Future(result_generator())
def stop(self, interrupt_children=False):
"""Stop multiprocessing."""
for rank in range(self.num_replicas):
self.input_pipes[rank].close()
self.return_pipes[rank].close()
if interrupt_children:
# send KeyboardInterrupt to children
os.kill(self.procs[rank].pid, signal.SIGINT)
else:
self.procs[rank].join()
self.error_queue.put((None, None)) # poison pill
def _start_error_handler(self):
"""Error handler to catch exceptions in child processes."""
# create a thread to listen for errors in the child processes
self.error_queue = self._mp.SimpleQueue()
error_thread = threading.Thread(target=self._error_listener,
daemon=True)
error_thread.start()
# create signal handler that executes in the main process/thread and
# handles errors from child processes
signal.signal(signal.SIGUSR1, self._signal_handler)
def _error_listener(self):
"""A thread that listens for errors in the child processes.
Errors are handled in a signal handler in the main thread.
"""
(rank, original_trace) = self.error_queue.get()
if rank is None: # poison pill, return
return
# requeue error and switch to main thread for handling the error
self.error_queue.put((rank, original_trace))
os.kill(os.getpid(), signal.SIGUSR1)
def _signal_handler(self, signal, frame):
"""Signal handler that handles errors from child processes.
        This signal handler executes in the main process/thread.
"""
self.stop(interrupt_children=True)
(rank, original_trace) = self.error_queue.get()
msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
msg += original_trace
raise Exception(msg)
def _start_multiprocessing(self):
"""Create child processes to run async event loop.
Each process reads input from a Pipe, performs some computation,
and returns its output to another Pipe.
"""
# create child processes
input_pipes = []
return_pipes = []
procs = []
for rank, id in enumerate(self.device_ids):
recv_input_pipe, send_input_pipe = self._mp.Pipe(duplex=False)
recv_return_pipe, send_return_pipe = self._mp.Pipe(duplex=False)
proc = self._mp.Process(
target=self._process_event_loop,
args=(rank, id, recv_input_pipe, send_return_pipe),
daemon=True)
proc.start()
input_pipes.append(send_input_pipe)
return_pipes.append(recv_return_pipe)
procs.append(proc)
self.input_pipes = input_pipes
self.return_pipes = return_pipes
self.procs = procs
def _process_event_loop(self, rank, device_id, input_pipe, return_pipe):
"""Event loop that runs in each child process.
Event loop:
- take an action from the input pipe
- call the corresponding function in this process
- put the return value in the return pipe
Any exceptions are put in the error queue.
"""
self.rank = rank
try:
# event loop
while True:
action, kwargs = input_pipe.recv()
action_fn = getattr(self, action)
return_pipe.send(action_fn(rank, device_id, **kwargs))
except EOFError:
# input pipe was closed, do nothing
pass
except KeyboardInterrupt:
# killed by parent, do nothing
pass
except Exception:
# propagate exception from child to parent process, keeping
# original traceback
import traceback
self.error_queue.put((rank, traceback.format_exc()))
finally:
# cleanup pipes
input_pipe.close()
return_pipe.close()
class Future(object):
"""A wrapper around a Python generator, with syntactic sugar."""
def __init__(self, generator):
self.generator = generator
def gen(self):
return next(self.generator)
@staticmethod
def gen_list(gens):
return [g.gen() for g in gens]
@staticmethod
def gen_tuple_list(gens):
list = [g.gen() for g in gens]
return zip(*list)
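# Usage sketch (illustrative; `_async_foo` is a placeholder action): dispatch
# a call to every replica, then gather the results once they are all ready:
#
#   futures = [loop.call_async(rank, '_async_foo', arg=arg)
#              for rank in range(loop.num_replicas)]
#   results = Future.gen_list(futures)  # blocks on each return pipe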

40
fairseq/multiprocessing_pdb.py Normal file
@ -0,0 +1,40 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import multiprocessing
import os
import pdb
import sys
class MultiprocessingPdb(pdb.Pdb):
"""A Pdb wrapper that works in a multiprocessing environment.
Usage: `from fairseq import pdb; pdb.set_trace()`
"""
_stdin_fd = sys.stdin.fileno()
_stdin = None
_stdin_lock = multiprocessing.Lock()
def __init__(self):
pdb.Pdb.__init__(self, nosigint=True)
def _cmdloop(self):
stdin_bak = sys.stdin
with self._stdin_lock:
try:
if not self._stdin:
self._stdin = os.fdopen(self._stdin_fd)
sys.stdin = self._stdin
self.cmdloop()
finally:
sys.stdin = stdin_bak
pdb = MultiprocessingPdb()

260
fairseq/multiprocessing_trainer.py Normal file
@ -0,0 +1,260 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
"""
Train a network on multiple GPUs using multiprocessing.
"""
import torch
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from fairseq import nccl, utils
from fairseq.criterions import FairseqCriterion
from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
from fairseq.nag import NAG
class MultiprocessingTrainer(MultiprocessingEventLoop):
"""Main class for multi-GPU training.
Each GPU has a full copy of the model and is assigned to its own Python
process. Gradients are accumulated with all-reduce and all model replicas
are updated synchronously after each batch.
The methods in this class are divided into synchronous functions, which
prepare and dispatch the input to each process, and asynchronous functions
(prefixed with `_async_`), which run on each process in parallel.
"""
def __init__(self, args, model, device_ids=None,
multiprocessing_method='spawn'):
if device_ids is None:
device_ids = tuple(range(torch.cuda.device_count()))
super().__init__(device_ids, multiprocessing_method)
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
model = model.share_memory()
nccl_uid = nccl.get_unique_id()
Future.gen_list([
self.call_async(rank, '_async_init', args=args, model=model,
nccl_uid=nccl_uid)
for rank in range(self.num_replicas)
])
def _async_init(self, rank, device_id, args, model, nccl_uid):
"""Initialize child processes."""
self.args = args
# set torch.seed in this process
torch.manual_seed(args.seed)
# set CUDA device
torch.cuda.set_device(device_id)
# initialize NCCL
nccl.initialize(self.num_replicas, nccl_uid, device_id)
# copy model to current device
self.model = model.cuda()
# initialize optimizer
self.optimizer = NAG(self.model.parameters(), lr=self.args.lr,
momentum=self.args.momentum,
weight_decay=self.args.weight_decay)
self.flat_grads = None
# initialize LR scheduler
self.lr_scheduler = self._build_lr_scheduler()
def _build_lr_scheduler(self):
if self.args.force_anneal > 0:
def anneal(e):
if e < self.args.force_anneal:
return 1
else:
return self.args.lrshrink ** (e + 1 - self.args.force_anneal)
lr_scheduler = LambdaLR(self.optimizer, anneal)
lr_scheduler.best = None
else:
# decay the LR by 0.1 every time the validation loss plateaus
lr_scheduler = ReduceLROnPlateau(self.optimizer, patience=0)
return lr_scheduler
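    # Worked example (illustrative): with --force-anneal 5 and --lrshrink 0.1,
    # anneal(e) = 1 for e < 5 and 0.1 ** (e + 1 - 5) afterwards, so the base
    # learning rate is multiplied by 0.1 at epoch 5, 0.01 at epoch 6, etc.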
def get_model(self):
"""Get one of the model replicas."""
# just return the first model, since all replicas are the same
return self.call_async(0, '_async_get_model').gen()
def _async_get_model(self, rank, device_id):
return self.model
def save_checkpoint(self, args, epoch, batch_offset, val_loss=None):
"""Save a checkpoint for the current model."""
self.call_async(0, '_async_save_checkpoint', args=args, epoch=epoch,
batch_offset=batch_offset, val_loss=val_loss).gen()
def _async_save_checkpoint(self, rank, device_id, args, epoch, batch_offset, val_loss):
utils.save_checkpoint(args, epoch, batch_offset, self.model,
self.optimizer, self.lr_scheduler, val_loss)
def load_checkpoint(self, filename):
"""Load a checkpoint into the model replicas in each process."""
results = Future.gen_list([
self.call_async(rank, '_async_load_checkpoint', filename=filename)
for rank in range(self.num_replicas)
])
epoch, batch_offset = results[0]
return epoch, batch_offset
def _async_load_checkpoint(self, rank, device_id, filename):
return utils.load_checkpoint(filename, self.model, self.optimizer,
self.lr_scheduler, cuda_device=device_id)
def train_step(self, samples, criterion):
"""Do forward, backward and gradient step in parallel."""
assert isinstance(criterion, FairseqCriterion)
# scatter sample across GPUs
samples, data_events = self._scatter_samples(samples)
criterion.prepare(samples)
# forward pass, backward pass and gradient step
losses = [
self.call_async(rank, '_async_train_step', sample=samples[rank],
criterion=criterion, data_event=event)
for rank, event in enumerate(data_events)
]
# aggregate losses and gradient norms
losses, grad_norms = Future.gen_tuple_list(losses)
loss = criterion.aggregate(losses)
return loss, grad_norms[0]
def _async_train_step(self, rank, device_id, sample, criterion, data_event):
data_event.wait()
self.model.train()
# zero grads even if net_input is None, since we will all-reduce them
self.optimizer.zero_grad()
# calculate loss and grads
loss = 0
if sample is not None:
net_output = self.model(**sample['net_input'])
loss_ = criterion(net_output, sample)
loss_.backward()
loss = loss_.data[0]
# flatten grads into a contiguous block of memory
if self.flat_grads is None:
self.flat_grads = self._flatten_grads_(self.model)
# all-reduce grads
nccl.all_reduce(self.flat_grads)
# clip grads
grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)
# take an optimization step
self.optimizer.step()
return loss, grad_norm
def _flatten_grads_(self, model):
num_params = sum(p.data.numel() for p in model.parameters())
flat_grads = next(model.parameters()).data.new(num_params)
offset = 0
for p in model.parameters():
grad = p.grad.data
numel, sz = grad.numel(), grad.size()
flat_grads[offset:offset+numel] = grad.view(-1)
grad.set_(flat_grads[offset:offset+numel])
grad.resize_(sz) # preserve original shape
offset += numel
return flat_grads
def _clip_grads_(self, flat_grads, clipv):
norm = flat_grads.norm()
if clipv > 0 and norm > clipv:
coef = max(norm, 1e-6) / clipv
flat_grads.div_(coef)
return norm
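    # Worked example (illustrative): with clip_norm=25 and a gradient norm of
    # 50, coef = 50 / 25 = 2 and the flattened gradients are divided by 2,
    # rescaling them to norm 25; gradients already within the threshold are
    # left untouched.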
def valid_step(self, samples, criterion):
"""Do forward pass in parallel."""
# scatter sample across GPUs
samples, data_events = self._scatter_samples(samples, volatile=True)
criterion.prepare(samples)
# forward pass
losses = [
self.call_async(rank, '_async_valid_step', sample=samples[rank],
criterion=criterion, data_event=event)
for rank, event in enumerate(data_events)
]
# aggregate losses
loss = criterion.aggregate(Future.gen_list(losses))
return loss
def _async_valid_step(self, rank, device_id, sample, criterion, data_event):
if sample is None:
return 0
data_event.wait()
self.model.eval()
net_output = self.model(**sample['net_input'])
loss = criterion(net_output, sample)
return loss.data[0]
def get_lr(self):
"""Get the current learning rate."""
return self.call_async(0, '_async_get_lr').gen()
def _async_get_lr(self, rank, device_id):
return self.optimizer.param_groups[0]['lr']
def lr_step(self, val_loss=None, epoch=None):
"""Adjust the learning rate depending on the validation loss."""
lr = Future.gen_list([
self.call_async(rank, '_async_lr_step', val_loss=val_loss, epoch=epoch)
for rank in range(self.num_replicas)
])
return lr[0]
def _async_lr_step(self, rank, device_id, epoch, val_loss):
# update the learning rate
if self.args.force_anneal > 0:
self.lr_scheduler.step(epoch)
else:
self.lr_scheduler.step(val_loss, epoch)
return self.optimizer.param_groups[0]['lr']
def _scatter_samples(self, samples, volatile=False):
"""Split and distribute a sample across GPUs."""
res = [utils.prepare_sample(sample, volatile=volatile,
cuda_device=device_id)
for sample, device_id in zip(samples, self.device_ids)]
# Pad with None until its size is equal to the number of replicas.
res = res + [None]*(self.num_replicas - len(samples))
# Synchronize GPU devices after data is sent to prevent
# race conditions.
events = []
for d in self.device_ids:
with torch.cuda.device(d):
event = torch.cuda.Event(interprocess=True)
event.record()
events.append(event)
return res, events

52
fairseq/nag.py Normal file
@ -0,0 +1,52 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from torch.optim.optimizer import Optimizer, required
class NAG(Optimizer):
def __init__(self, params, lr=required, momentum=0, weight_decay=0):
defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
super(NAG, self).__init__(params, defaults)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
lr = group['lr']
for p in group['params']:
if p.grad is None:
continue
d_p = p.grad.data
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
param_state['momentum_buffer'] = d_p.clone().zero_()
buf = param_state['momentum_buffer']
p.data.add_(momentum * momentum, buf)
p.data.add_(-(1 + momentum) * lr, d_p)
buf.mul_(momentum).add_(-lr, d_p)
return loss
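# The three in-place updates above implement Nesterov momentum applied at the
# lookahead point; with velocity v, momentum mu and learning rate lr they are
# equivalent to:
#
#   v_{t+1} = mu * v_t - lr * g_t
#   p_{t+1} = p_t + mu^2 * v_t - (1 + mu) * lr * g_t
#           = p_t + mu * v_{t+1} - lr * g_t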

165
fairseq/nccl.py Normal file
@ -0,0 +1,165 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
"""
A modified version of torch.cuda.nccl.all_reduce for launching kernels on each
GPU separately.
"""
import ctypes
import warnings
lib = None
_uid = None
_rank = None
_num_devices = None
_comm = None
__all__ = ['all_reduce', 'initialize', 'get_unique_id']
def _libnccl():
global lib
if lib is None:
lib = ctypes.cdll.LoadLibrary(None)
if hasattr(lib, 'ncclCommDestroy'):
lib.ncclCommDestroy.restype = None
lib.ncclGetErrorString.restype = ctypes.c_char_p
else:
lib = None
return lib
def is_available(tensors):
devices = set()
for tensor in tensors:
if not tensor.is_contiguous():
return False
if not tensor.is_cuda:
return False
device = tensor.get_device()
if device in devices:
return False
devices.add(device)
if _libnccl() is None:
warnings.warn('NCCL library not found. Check your LD_LIBRARY_PATH')
return False
return True
_communicators = {}
# ncclDataType_t
ncclChar = 0
ncclInt = 1
ncclHalf = 2
ncclFloat = 3
ncclDouble = 4
ncclInt64 = 5
ncclUint64 = 6
# ncclRedOp_t
SUM = 0
PROD = 1
MAX = 2
MIN = 3
nccl_types = {
'torch.cuda.ByteTensor': ncclChar,
'torch.cuda.CharTensor': ncclChar,
'torch.cuda.IntTensor': ncclInt,
'torch.cuda.HalfTensor': ncclHalf,
'torch.cuda.FloatTensor': ncclFloat,
'torch.cuda.DoubleTensor': ncclDouble,
'torch.cuda.LongTensor': ncclInt64,
}
class NcclError(RuntimeError):
def __init__(self, status):
self.status = status
msg = '{0} ({1})'.format(lib.ncclGetErrorString(status), status)
super(NcclError, self).__init__(msg)
class NcclComm(ctypes.c_void_p):
def __del__(self):
lib.ncclCommDestroy(self)
class NcclUniqueId(ctypes.Structure):
_fields_ = [
('internal', ctypes.c_uint8 * 128)
]
def check_error(status):
if status != 0:
raise NcclError(status)
_uids = []
def get_unique_id():
if _libnccl() is None:
raise RuntimeError('Unable to load NCCL library')
uid = NcclUniqueId()
check_error(lib.ncclGetUniqueId(ctypes.byref(uid)))
_uids.append(uid) # Don't allow UIDs to be collected
return uid
def initialize(num_devices, uid, rank):
global _num_devices, _uid, _rank
if _libnccl() is None:
raise RuntimeError('Unable to load NCCL library')
_num_devices = num_devices
if rank != 0:
_uid = NcclUniqueId.from_buffer_copy(uid)
else:
_uid = uid
_rank = rank
def communicator():
global _comm
if _uid is None:
raise RuntimeError('NCCL not initialized')
if _comm is None:
comm = ctypes.c_void_p()
check_error(lib.ncclCommInitRank(
ctypes.byref(comm),
ctypes.c_int(_num_devices),
_uid,
ctypes.c_int(_rank)))
_comm = comm
return _comm
def all_reduce(input, output=None, op=SUM, stream=None):
comm = communicator()
if output is None:
output = input
if stream is not None:
stream = stream.cuda_stream
data_type = nccl_types[input.type()]
check_error(lib.ncclAllReduce(
ctypes.c_void_p(input.data_ptr()),
ctypes.c_void_p(output.data_ptr()),
ctypes.c_size_t(input.numel()),
data_type,
op,
comm,
ctypes.c_void_p(stream)))
return output
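# Usage sketch (illustrative): every training process initializes NCCL with a
# shared unique id and its own rank, then reduces CUDA tensors in place:
#
#   uid = get_unique_id()             # created once, shared with all ranks
#   initialize(num_devices, uid, rank)
#   all_reduce(grads)                 # sums `grads` across participating GPUs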

132
fairseq/options.py Normal file
@ -0,0 +1,132 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import argparse
from fairseq import models
def get_parser(desc):
parser = argparse.ArgumentParser(
description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
help='log progress every N updates (when progress bar is disabled)')
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
return parser
def add_dataset_args(parser):
group = parser.add_argument_group('Dataset and data loading')
group.add_argument('data', metavar='DIR',
help='path to data directory')
group.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language')
group.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language')
group.add_argument('-j', '--workers', default=1, type=int, metavar='N',
help='number of data loading workers (default: 1)')
group.add_argument('--max-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the sequence')
return group
def add_optimization_args(parser):
group = parser.add_argument_group('Optimization')
group.add_argument('--lr', '--learning-rate', default=0.25, type=float, metavar='LR',
help='initial learning rate')
group.add_argument('--min-lr', metavar='LR', default=1e-5, type=float,
help='minimum learning rate')
group.add_argument('--force-anneal', '--fa', default=0, type=int, metavar='N',
help='force annealing at specified epoch')
group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
help='force stop training at specified epoch')
group.add_argument('--lrshrink', default=0.1, type=float, metavar='LS',
help='learning rate shrink factor for annealing, lr_new = (lr * lrshrink)')
group.add_argument('--momentum', default=0.99, type=float, metavar='M',
help='momentum factor')
group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
help='clip threshold of gradients')
group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
help='weight decay')
group.add_argument('--sample-without-replacement', default=0, type=int, metavar='N',
help='If bigger than 0, use that number of mini-batches for each epoch,'
                            ' where each sample is drawn randomly without replacement from the'
' dataset')
return group
def add_checkpoint_args(parser):
group = parser.add_argument_group('Checkpointing')
group.add_argument('--save-dir', metavar='DIR', default='checkpoints',
help='path to save checkpoints')
group.add_argument('--restore-file', default='checkpoint_last.pt',
help='filename in save-dir from which to load checkpoint')
group.add_argument('--save-interval', type=int, default=-1,
help='checkpoint every this many batches')
group.add_argument('--no-save', action='store_true',
help='don\'t save models and checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true',
help='only store last and best checkpoints')
return group
def add_generation_args(parser):
group = parser.add_argument_group('Generation')
group.add_argument('--beam', default=5, type=int, metavar='N',
help='beam size')
group.add_argument('--nbest', default=1, type=int, metavar='N',
help='number of hypotheses to output')
group.add_argument('--max-len-a', default=0, type=int, metavar='N',
help=('generate sequence of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--max-len-b', default=200, type=int, metavar='N',
help=('generate sequence of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--remove-bpe', action='store_true',
help='remove BPE tokens before scoring')
group.add_argument('--no-early-stop', action='store_true',
help=('continue searching even after finalizing k=beam '
'hypotheses; this is more correct, but increases '
'generation time by 50%%'))
group.add_argument('--unnormalized', action='store_true',
help='compare unnormalized hypothesis scores')
group.add_argument('--cpu', action='store_true', help='generate on CPU')
group.add_argument('--no-beamable-mm', action='store_true',
help='don\'t use BeamableMM in attention layers')
group.add_argument('--lenpen', default=1, type=float,
help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
group.add_argument('--unk-replace-dict', default='', type=str,
help='performs unk word replacement')
return group
def add_model_args(parser):
group = parser.add_argument_group('Model configuration')
group.add_argument('--arch', '-a', default='fconv', metavar='ARCH',
choices=models.__all__,
help='model architecture ({})'.format(', '.join(models.__all__)))
group.add_argument('--encoder-embed-dim', default=512, type=int, metavar='N',
help='encoder embedding dimension')
group.add_argument('--encoder-layers', default='[(512, 3)] * 20', type=str, metavar='EXPR',
help='encoder layers [(dim, kernel_size), ...]')
group.add_argument('--decoder-embed-dim', default=512, type=int, metavar='N',
help='decoder embedding dimension')
group.add_argument('--decoder-layers', default='[(512, 3)] * 20', type=str, metavar='EXPR',
help='decoder layers [(dim, kernel_size), ...]')
group.add_argument('--decoder-attention', default='True', type=str, metavar='EXPR',
help='decoder attention [True, ...]')
group.add_argument('--decoder-out-embed-dim', default=256, type=int, metavar='N',
help='decoder output embedding dimension')
group.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability')
group.add_argument('--label-smoothing', default=0, type=float, metavar='D',
help='epsilon for label smoothing, 0 means no label smoothing')
return group
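# The EXPR-valued options above are Python expressions that are eval'ed into
# layer lists when the model is built (see utils.build_model). For example
# (illustrative):
#
#   --encoder-layers '[(512, 3)] * 20'  ->  20 conv layers, 512 units, width 3
#   --decoder-attention 'True'          ->  attention in every decoder layer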

55
fairseq/progress_bar.py Normal file
@ -0,0 +1,55 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
"""
Progress bar wrapper around tqdm which handles non-tty outputs
"""
import sys
from tqdm import tqdm
class progress_bar(tqdm):
enabled = sys.stderr.isatty()
print_interval = 1000
def __new__(cls, *args, **kwargs):
if cls.enabled:
return tqdm(*args, **kwargs)
else:
return simple_progress_bar(cls.print_interval, *args, **kwargs)
class simple_progress_bar(tqdm):
def __init__(self, print_interval, *args, **kwargs):
super(simple_progress_bar, self).__init__(*args, **kwargs)
self.print_interval = print_interval
def __iter__(self):
size = len(self.iterable)
for i, obj in enumerate(self.iterable):
yield obj
if i > 0 and i % self.print_interval == 0:
msg = '{} {:5d} / {:d} {}\n'.format(self.desc, i, size, self.postfix)
sys.stdout.write(msg)
sys.stdout.flush()
@classmethod
def write(cls, s, file=None, end="\n"):
fp = file if file is not None else sys.stdout
fp.write(s)
fp.write(end)
fp.flush()
@staticmethod
def status_printer(file):
def print_status(s):
pass
return print_status

345
fairseq/sequence_generator.py Normal file
@ -0,0 +1,345 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from contextlib import ExitStack
import math
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from fairseq import utils
class SequenceGenerator(object):
def __init__(self, models, dst_dict, beam_size=1, minlen=1, maxlen=200,
stop_early=True, normalize_scores=True, len_penalty=1):
"""Generates translations of a given source sentence.
Args:
min/maxlen: The length of the generated output will be bounded by
minlen and maxlen (not including the end-of-sentence marker).
stop_early: Stop generation immediately after we finalize beam_size
hypotheses, even though longer hypotheses might have better
normalized scores.
normalize_scores: Normalize scores by the length of the output.
"""
self.models = models
self.dict = dst_dict
self.pad = dst_dict.pad()
self.eos = dst_dict.eos()
self.vocab_size = len(dst_dict)
self.beam_size = beam_size
self.minlen = minlen
self.maxlen = maxlen
self.positions = torch.LongTensor(range(self.pad + 1, self.pad + maxlen + 2))
self.decoder_context = models[0].decoder.context_size()
self.stop_early = stop_early
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
def cuda(self):
for model in self.models:
model.cuda()
self.positions = self.positions.cuda()
return self
def generate_batched_itr(self, data_itr, maxlen_a=0, maxlen_b=200,
cuda_device=None, timer=None):
"""Iterate over a batched dataset and yield individual translations.
Args:
maxlen_a/b: generate sequences of maximum length ax + b,
where x is the source sentence length.
cuda_device: GPU on which to do generation.
timer: StopwatchMeter for timing generations.
"""
def lstrip_pad(tensor):
return tensor[tensor.eq(self.pad).sum():]
for sample in data_itr:
s = utils.prepare_sample(sample, volatile=True, cuda_device=cuda_device)
input = s['net_input']
srclen = input['src_tokens'].size(1)
if timer is not None:
timer.start()
hypos = self.generate(input['src_tokens'], input['src_positions'],
maxlen=(maxlen_a*srclen + maxlen_b))
if timer is not None:
timer.stop(s['ntokens'])
for i, id in enumerate(s['id']):
src = input['src_tokens'].data[i, :]
# remove padding from ref, which appears at the beginning
ref = lstrip_pad(s['target'].data[i, :])
yield id, src, ref, hypos[i]
def generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
"""Generate a batch of translations."""
with ExitStack() as stack:
for model in self.models:
stack.enter_context(model.decoder.incremental_inference())
return self._generate(src_tokens, src_positions, beam_size, maxlen)
def _generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
bsz = src_tokens.size(0)
beam_size = beam_size if beam_size is not None else self.beam_size
maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
encoder_outs = []
for model in self.models:
model.eval()
model.decoder.clear_incremental_state() # start a fresh sequence
# compute the encoder output and expand to beam size
encoder_out = model.encoder(src_tokens, src_positions)
encoder_out = self._expand_encoder_out(encoder_out, beam_size)
encoder_outs.append(encoder_out)
# initialize buffers
scores = encoder_outs[0][0].data.new(bsz * beam_size).fill_(0)
tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos
align = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(-1)
align_buf = align.clone()
# list of completed sentences
finalized = [[] for i in range(bsz)]
finished = [False for i in range(bsz)]
worst_finalized = [{'idx': None, 'score': float('Inf')} for i in range(bsz)]
num_remaining_sent = bsz
# number of candidate hypos per step
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
# offset arrays for converting between different indexing schemes
bbsz_offsets = (torch.arange(0, bsz)*beam_size).unsqueeze(1).type_as(tokens)
cand_offsets = torch.arange(0, cand_size).type_as(tokens)
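        # Worked example (illustrative): with bsz=2 and beam_size=3,
        # bbsz_offsets is [[0], [3]] and cand_offsets is [0, 1, ..., 5];
        # adding bbsz_offsets converts per-sentence beam indices into flat
        # indices in the range [0, bsz*beam_size).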
# helper function for allocating buffers on the fly
buffers = {}
def buffer(name, type_of=tokens):
if name not in buffers:
buffers[name] = type_of.new()
return buffers[name]
def is_finished(sent):
"""
Check whether we've finished generation for a given sentence, by
comparing the worst score among finalized hypotheses to the best
possible score among unfinalized hypotheses.
"""
assert len(finalized[sent]) <= beam_size
if len(finalized[sent]) == beam_size:
if self.stop_early:
return True
# stop if the best unfinalized score is worse than the worst
# finalized one
bbsz = sent*beam_size
best_unfinalized_score = scores[bbsz:bbsz+beam_size].max()
if self.normalize_scores:
best_unfinalized_score /= maxlen
if worst_finalized[sent]['score'] >= best_unfinalized_score:
return True
return False
def finalize_hypos(step, bbsz_idx, scores):
"""
Finalize the given hypotheses at this step, while keeping the total
number of finalized hypotheses per sentence <= beam_size.
Note: the input must be in the desired finalization order, so that
hypotheses that appear earlier in the input are preferred to those
that appear later.
Args:
step: current time step
bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
indicating which hypotheses to finalize
scores: A vector of the same size as bbsz_idx containing scores
for each hypothesis
"""
assert bbsz_idx.numel() == scores.numel()
norm_scores = scores/math.pow(step+1, self.len_penalty) if self.normalize_scores else scores
sents_seen = set()
for idx, score in zip(bbsz_idx.cpu(), norm_scores.cpu()):
sent = idx // beam_size
sents_seen.add(sent)
def get_hypo():
hypo = tokens[idx, 1:step+2].clone()
hypo[step] = self.eos
alignment = align[idx, 1:step+2].clone()
return {
'tokens': hypo,
'score': score,
'alignment': alignment,
}
if len(finalized[sent]) < beam_size:
finalized[sent].append(get_hypo())
elif score > worst_finalized[sent]['score']:
# replace worst hypo for this sentence with new/better one
worst_idx = worst_finalized[sent]['idx']
finalized[sent][worst_idx] = get_hypo()
# find new worst finalized hypo for this sentence
idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
worst_finalized[sent] = {
'score': s['score'],
'idx': idx,
}
# return number of hypotheses finished this step
num_finished = 0
for sent in sents_seen:
# check termination conditions for this sentence
if not finished[sent] and is_finished(sent):
finished[sent] = True
num_finished += 1
return num_finished
reorder_state = None
for step in range(maxlen + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
for model in self.models:
model.decoder.reorder_incremental_state(reorder_state)
probs, avg_attn_scores = self._decode(tokens[:, :step+1], encoder_outs)
if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam
probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
else:
# make probs contain cumulative scores for each hypothesis
probs.add_(scores.view(-1, 1))
# record alignment to source tokens, based on attention
_ignore_scores = buffer('_ignore_scores', type_of=scores)
avg_attn_scores.topk(1, out=(_ignore_scores, align[:, step+1].unsqueeze(1)))
# take the best 2 x beam_size predictions. We'll choose the first
# beam_size of these which don't predict eos to continue with.
cand_scores = buffer('cand_scores', type_of=scores)
cand_indices = buffer('cand_indices')
cand_beams = buffer('cand_beams')
probs.view(bsz, -1).topk(cand_size, out=(cand_scores, cand_indices))
torch.div(cand_indices, self.vocab_size, out=cand_beams)
cand_indices.fmod_(self.vocab_size)
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
cand_bbsz_idx = cand_beams.add_(bbsz_offsets)
# finalize hypotheses that end in eos
eos_mask = cand_indices.eq(self.eos)
if step >= self.minlen:
eos_bbsz_idx = buffer('eos_bbsz_idx')
cand_bbsz_idx.masked_select(eos_mask, out=eos_bbsz_idx)
if eos_bbsz_idx.numel() > 0:
eos_scores = buffer('eos_scores', type_of=scores)
cand_scores.masked_select(eos_mask, out=eos_scores)
num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx, eos_scores)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
# set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
active_mask = buffer('active_mask')
torch.add((eos_mask*cand_size).type_as(cand_offsets), cand_offsets,
out=active_mask)
# get the top beam_size active hypotheses, which are just the hypos
# with the smallest values in active_mask
active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
active_mask.topk(beam_size, 1, largest=False, out=(_ignore, active_hypos))
active_bbsz_idx = buffer('active_bbsz_idx')
cand_bbsz_idx.gather(1, active_hypos, out=active_bbsz_idx)
active_scores = cand_scores.gather(1, active_hypos,
out=scores.view(bsz, beam_size))
active_bbsz_idx = active_bbsz_idx.view(-1)
active_scores = active_scores.view(-1)
# finalize all active hypotheses once we hit maxlen
# finalize_hypos will take care of adding the EOS markers
if step == maxlen:
num_remaining_sent -= finalize_hypos(step, active_bbsz_idx, active_scores)
assert num_remaining_sent == 0
break
# copy tokens for active hypotheses
torch.index_select(tokens[:, :step+1], dim=0, index=active_bbsz_idx,
out=tokens_buf[:, :step+1])
cand_indices.gather(1, active_hypos,
out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1])
# copy attention/alignment for active hypotheses
torch.index_select(align[:, :step+2], dim=0, index=active_bbsz_idx,
out=align_buf[:, :step+2])
# swap buffers
old_tokens = tokens
tokens = tokens_buf
tokens_buf = old_tokens
old_align = align
align = align_buf
align_buf = old_align
# reorder incremental state in decoder
reorder_state = active_bbsz_idx
# sort by score descending
for sent in range(bsz):
finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
return finalized
def _decode(self, tokens, encoder_outs):
length = tokens.size(1)
# repeat the first length positions to fill batch
positions = self.positions[:length].view(1, length)
# wrap in Variables
tokens = Variable(tokens, volatile=True)
positions = Variable(positions, volatile=True)
avg_probs = None
avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs):
decoder_out, attn = model.decoder(tokens, positions, encoder_out)
probs = F.softmax(decoder_out[:, -1, :]).data
attn = attn[:, -1, :].data
if avg_probs is None or avg_attn is None:
avg_probs = probs
avg_attn = attn
else:
avg_probs.add_(probs)
avg_attn.add_(attn)
avg_probs.div_(len(self.models))
avg_probs.log_()
avg_attn.div_(len(self.models))
return avg_probs, avg_attn
def _expand_encoder_out(self, encoder_out, beam_size):
res = []
for tensor in encoder_out:
res.append(
# repeat beam_size times along second dimension
tensor.repeat(1, beam_size, *[1 for i in range(tensor.dim()-2)]) \
# then collapse into [bsz*beam, ...original dims...]
.view(-1, *tensor.size()[1:])
)
return tuple(res)
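    # Shape example (illustrative): a (bsz x srclen x nhu) encoder output
    # becomes (bsz*beam_size x srclen x nhu), with each source sentence
    # repeated beam_size times so every beam hypothesis attends to its own
    # copy.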

77
fairseq/tokenizer.py Normal file
@ -0,0 +1,77 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import re
import torch
from fairseq import dictionary
def tokenize_line(line):
line = re.sub(r"\t", "", line)
line = re.sub(r"^\s+", "", line)
line = re.sub(r"\s+$", "", line)
line = re.sub(r"\s+", " ", line)
return line.split()
class Tokenizer:
@staticmethod
def build_dictionary(filename, tokenize=tokenize_line):
dict = dictionary.Dictionary()
Tokenizer.add_file_to_dictionary(filename, dict, tokenize)
dict.finalize()
return dict
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize):
with open(filename, 'r') as f:
for line in f.readlines():
for word in tokenize(line):
dict.add_symbol(word)
dict.add_symbol(dict.eos_word)
@staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line):
nseq, ntok, nunk = 0, 0, 0
replaced = {}
with open(filename, 'r') as f:
for line in f.readlines():
words = tokenize(line)
nwords = len(words)
ids = torch.IntTensor(nwords + 1)
nseq = nseq + 1
for i in range(0, len(words)):
word = words[i]
idx = dict.index(word)
if idx == dict.unk_index and word != dict.unk_word:
nunk = nunk + 1
if word in replaced:
replaced[word] = replaced[word] + 1
else:
replaced[word] = 1
ids[i] = idx
ids[nwords] = dict.eos_index
consumer(ids)
ntok = ntok + len(ids)
return {'nseq': nseq, 'nunk': nunk, 'ntok': ntok, 'replaced': len(replaced)}
@staticmethod
def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True):
words = tokenize(line)
nwords = len(words)
ids = torch.IntTensor(nwords + 1)
for i in range(0, len(words)):
if add_if_not_exist:
ids[i] = dict.add_symbol(words[i])
else:
ids[i] = dict.index(words[i])
ids[nwords] = dict.eos_index
return ids
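# Usage sketch (illustrative file name): build a dictionary from a corpus,
# then map a line of text to a tensor of token ids (with </s> appended):
#
#   dict = Tokenizer.build_dictionary('train.en')
#   ids = Tokenizer.tokenize('hello world', dict, add_if_not_exist=False)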

138
fairseq/utils.py Normal file
@ -0,0 +1,138 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import os
import torch
from torch.autograd import Variable
from torch.serialization import default_restore_location
from fairseq import criterions, data, models
def build_model(args, dataset):
if args.arch == 'fconv':
encoder_layers = eval(args.encoder_layers)
decoder_layers = eval(args.decoder_layers)
decoder_attention = eval(args.decoder_attention)
model = models.fconv(
dataset, args.dropout, args.encoder_embed_dim, encoder_layers,
args.decoder_embed_dim, decoder_layers, decoder_attention,
decoder_out_embed_dim=args.decoder_out_embed_dim,
max_positions=args.max_positions)
else:
model = models.__dict__[args.arch](dataset, args.dropout,
max_positions=args.max_positions)
return model
def build_criterion(args, dataset):
padding_idx = dataset.dst_dict.pad()
if args.label_smoothing > 0:
return criterions.LabelSmoothedCrossEntropyCriterion(args.label_smoothing, padding_idx)
else:
return criterions.CrossEntropyCriterion(padding_idx)
def torch_persistent_save(*args, **kwargs):
    for i in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if i == 2:  # re-raise only after the final retry
                raise
def save_checkpoint(args, epoch, batch_offset, model, optimizer, lr_scheduler, val_loss=None):
state_dict = {
'args': args,
'epoch': epoch,
'batch_offset': batch_offset,
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'best_loss': lr_scheduler.best,
'val_loss': val_loss,
}
if batch_offset == 0:
if not args.no_epoch_checkpoints:
epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
torch_persistent_save(state_dict, epoch_filename)
assert val_loss is not None
if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
save_checkpoint.best = val_loss
best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
torch_persistent_save(state_dict, best_filename)
last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
torch_persistent_save(state_dict, last_filename)
def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
if not os.path.exists(filename):
return 1, 0
if cuda_device is None:
state = torch.load(filename)
else:
state = torch.load(
filename,
map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
)
model.load_state_dict(state['model'])
optimizer.load_state_dict(state['optimizer'])
lr_scheduler.best = state['best_loss']
epoch = state['epoch'] + 1
batch_offset = state['batch_offset']
gpu_str = ' on GPU #{}'.format(cuda_device) if cuda_device is not None else ''
print('| loaded checkpoint {} (epoch {}){}'.format(filename, epoch, gpu_str))
return epoch, batch_offset
def load_ensemble_for_inference(models, data_path):
# load model architectures and weights
states = []
for model in models:
if not os.path.exists(model):
raise IOError('Model file not found: ' + model)
states.append(
torch.load(model, map_location=lambda s, l: default_restore_location(s, 'cpu'))
)
# load dataset
args = states[0]['args']
dataset = data.load(data_path, args.source_lang, args.target_lang)
# build models
models = []
for state in states:
model = build_model(args, dataset)
model.load_state_dict(state['model'])
models.append(model)
return models, dataset
def prepare_sample(sample, volatile=False, cuda_device=None):
"""Wrap input tensors in Variable class."""
def make_variable(tensor):
if cuda_device is not None and torch.cuda.is_available():
tensor = tensor.cuda(async=True, device=cuda_device)
return Variable(tensor, volatile=volatile)
return {
'id': sample['id'],
'ntokens': sample['ntokens'],
'target': make_variable(sample['target']),
'net_input': {
key: make_variable(sample[key])
for key in ['src_tokens', 'src_positions', 'input_tokens', 'input_positions']
},
}

176
generate.py Normal file
@ -0,0 +1,176 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import sys
import torch
from torch.autograd import Variable
from fairseq import bleu, options, utils, tokenizer
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.progress_bar import progress_bar
from fairseq.sequence_generator import SequenceGenerator
def main():
parser = options.get_parser('Generation')
parser.add_argument('--path', metavar='FILE', required=True, action='append',
help='path(s) to model file(s)')
dataset_args = options.add_dataset_args(parser)
dataset_args.add_argument('-i', '--interactive', action='store_true',
help='generate translations in interactive mode')
dataset_args.add_argument('--batch-size', default=32, type=int, metavar='N',
help='batch size')
dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
help='data subset to generate (train, valid, test)')
options.add_generation_args(parser)
args = parser.parse_args()
print(args)
if args.no_progress_bar:
progress_bar.enabled = False
use_cuda = torch.cuda.is_available() and not args.cpu
# Load model and dataset
print('| loading model(s) from {}'.format(', '.join(args.path)))
models, dataset = utils.load_ensemble_for_inference(args.path, args.data)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
if not args.interactive:
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
# Optimize model for generation
for model in models:
model.make_generation_fast_(args.beam, not args.no_beamable_mm)
# Initialize generator
translator = SequenceGenerator(models, dataset.dst_dict, beam_size=args.beam,
stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized),
len_penalty=args.lenpen)
align_dict = {}
if args.unk_replace_dict != '':
        assert args.interactive, "Unknown word replacement requires access to the original source " \
                                 "and is only supported in interactive mode"
with open(args.unk_replace_dict, 'r') as f:
for line in f:
l = line.split()
align_dict[l[0]] = l[1]
def replace_unk(hypo_str, align_str, src, unk):
hypo_tokens = hypo_str.split()
src_tokens = tokenizer.tokenize_line(src)
align_idx = [int(i) for i in align_str.split()]
for i, ht in enumerate(hypo_tokens):
if ht == unk:
src_token = src_tokens[align_idx[i]]
if src_token in align_dict:
hypo_tokens[i] = align_dict[src_token]
else:
hypo_tokens[i] = src_token
return ' '.join(hypo_tokens)
if use_cuda:
translator.cuda()
bpe_symbol = '@@ ' if args.remove_bpe else None
def display_hypotheses(id, src, orig, ref, hypos):
id_str = '' if id is None else '-{}'.format(id)
src_str = to_sentence(dataset.src_dict, src, bpe_symbol)
print('S{}\t{}'.format(id_str, src_str))
if orig is not None:
print('O{}\t{}'.format(id_str, orig.strip()))
if ref is not None:
print('T{}\t{}'.format(id_str, to_sentence(dataset.dst_dict, ref, bpe_symbol, ref_unk=True)))
for hypo in hypos:
hypo_str = to_sentence(dataset.dst_dict, hypo['tokens'], bpe_symbol)
align_str = ' '.join(map(str, hypo['alignment']))
if args.unk_replace_dict != '':
hypo_str = replace_unk(hypo_str, align_str, orig, unk_symbol(dataset.dst_dict))
print('H{}\t{}\t{}'.format(
id_str, hypo['score'], hypo_str))
print('A{}\t{}'.format(id_str, align_str))
if args.interactive:
for line in sys.stdin:
tokens = tokenizer.Tokenizer.tokenize(line, dataset.src_dict, add_if_not_exist=False).long()
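            # position ids start at pad()+1; index pad() itself stays reserved
            # for the padding embedding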
start = dataset.src_dict.pad() + 1
positions = torch.arange(start, start + len(tokens)).type_as(tokens)
if use_cuda:
positions = positions.cuda()
tokens = tokens.cuda()
translations = translator.generate(Variable(tokens.view(1, -1)), Variable(positions.view(1, -1)))
hypos = translations[0]
display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)])
else:
non_bpe_dict = {}
def maybe_remove_bpe_and_reindex(tokens):
"""Helper for removing BPE symbols from a tensor of indices.
If BPE removal is enabled, the returned tensor is reindexed
using a new dictionary that is created on-the-fly."""
if not args.remove_bpe:
return tokens
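            # a single hypothesis/reference tensor should never contain padding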
assert (tokens == dataset.dst_dict.pad()).sum() == 0
return torch.IntTensor([
non_bpe_dict.setdefault(w, len(non_bpe_dict))
for w in to_sentence(dataset.dst_dict, tokens, bpe_symbol).split(' ')
])
# Generate and compute BLEU score
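        # with BPE removal the tokens are reindexed by a fresh dictionary,
        # so the original pad/eos indices no longer apply (hence -1)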
scorer = bleu.Scorer(
dataset.dst_dict.pad() if not args.remove_bpe else -1,
dataset.dst_dict.eos() if not args.remove_bpe else -1,
dataset.dst_dict.unk())
itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size, max_positions=args.max_positions)
num_sentences = 0
with progress_bar(itr, smoothing=0, leave=False) as t:
wps_meter = TimeMeter()
gen_timer = StopwatchMeter()
translations = translator.generate_batched_itr(
t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
cuda_device=0 if use_cuda else None, timer=gen_timer)
for id, src, ref, hypos in translations:
ref = ref.int().cpu()
top_hypo = hypos[0]['tokens'].int().cpu()
scorer.add(maybe_remove_bpe_and_reindex(ref), maybe_remove_bpe_and_reindex(top_hypo))
display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)])
wps_meter.update(src.size(0))
t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)))
num_sentences += 1
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
def to_token(dict, i, runk):
return runk if i == dict.unk() else dict[i]
def unk_symbol(dict, ref_unk=False):
return '<{}>'.format(dict.unk_word) if ref_unk else dict.unk_word
def to_sentence(dict, tokens, bpe_symbol=None, ref_unk=False):
if torch.is_tensor(tokens) and tokens.dim() == 2:
sentences = [to_sentence(dict, token) for token in tokens]
return '\n'.join(sentences)
eos = dict.eos()
runk = unk_symbol(dict, ref_unk=ref_unk)
sent = ' '.join([to_token(dict, i, runk) for i in tokens if i != eos])
if bpe_symbol is not None:
sent = sent.replace(bpe_symbol, '')
return sent
if __name__ == '__main__':
main()

preprocess.py Normal file

@@ -0,0 +1,119 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import argparse
import os
from itertools import zip_longest
from fairseq import dictionary, indexed_dataset
from fairseq.tokenizer import Tokenizer
def main():
parser = argparse.ArgumentParser(
description='Data pre-processing: Create dictionary and store data in binary format')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='target language')
    parser.add_argument('--trainpref', metavar='FP', default='train', help='train file prefix')
    parser.add_argument('--validpref', metavar='FP', default='valid', help='comma-separated valid file prefixes')
    parser.add_argument('--testpref', metavar='FP', default='test', help='comma-separated test file prefixes')
parser.add_argument('--destdir', metavar='DIR', default='data-bin', help='destination dir')
parser.add_argument('--thresholdtgt', metavar='N', default=0, type=int,
help='map words appearing less than threshold times to unknown')
parser.add_argument('--thresholdsrc', metavar='N', default=0, type=int,
help='map words appearing less than threshold times to unknown')
parser.add_argument('--nwordstgt', metavar='N', default=-1, type=int, help='number of target words to retain')
parser.add_argument('--nwordssrc', metavar='N', default=-1, type=int, help='number of source words to retain')
parser.add_argument('--alignfile', metavar='ALIGN', default=None, help='an alignment file (optional)')
args = parser.parse_args()
print(args)
os.makedirs(args.destdir, exist_ok=True)
src_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.source_lang))
src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)),
threshold=args.thresholdsrc, nwords=args.nwordssrc)
tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))
tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)),
threshold=args.thresholdtgt, nwords=args.nwordstgt)
def make_dataset(input_prefix, output_prefix, lang):
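        """Binarize one language side of a dataset split using the saved dictionary."""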
dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(lang)))
print('| [{}] Dictionary: {} types'.format(lang, len(dict) - 1))
ds = indexed_dataset.IndexedDatasetBuilder(
'{}/{}.{}-{}.{}.bin'.format(args.destdir, output_prefix, args.source_lang,
args.target_lang, lang)
)
def consumer(tensor):
ds.add_item(tensor)
input_file = '{}.{}'.format(input_prefix, lang)
res = Tokenizer.binarize(input_file, dict, consumer)
print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
lang, input_file, res['nseq'], res['ntok'],
100 * res['nunk'] / res['ntok'], dict.unk_word))
ds.finalize('{}/{}.{}-{}.{}.idx'.format(
args.destdir, output_prefix,
args.source_lang, args.target_lang, lang))
make_dataset(args.trainpref, 'train', args.source_lang)
make_dataset(args.trainpref, 'train', args.target_lang)
for k, validpref in enumerate(args.validpref.split(',')):
outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
make_dataset(validpref, outprefix, args.source_lang)
make_dataset(validpref, outprefix, args.target_lang)
for k, testpref in enumerate(args.testpref.split(',')):
outprefix = 'test{}'.format(k) if k > 0 else 'test'
make_dataset(testpref, outprefix, args.source_lang)
make_dataset(testpref, outprefix, args.target_lang)
print('| Wrote preprocessed data to {}'.format(args.destdir))
if args.alignfile:
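        # build a word-alignment dictionary: for each source word, keep the
        # target word it is most frequently aligned with (can be used for
        # unknown-word replacement at generation time)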
src_file_name = '{}.{}'.format(args.trainpref, args.source_lang)
tgt_file_name = '{}.{}'.format(args.trainpref, args.target_lang)
src_dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)))
freq_map = {}
with open(args.alignfile, 'r') as align_file:
with open(src_file_name, 'r') as src_file:
with open(tgt_file_name, 'r') as tgt_file:
for a, s, t in zip_longest(align_file, src_file, tgt_file):
si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
ai = list(map(lambda x: tuple(x.split('-')), a.split()))
for sai, tai in ai:
srcidx = si[int(sai)]
tgtidx = ti[int(tai)]
if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
assert srcidx != src_dict.pad()
assert srcidx != src_dict.eos()
assert tgtidx != tgt_dict.pad()
assert tgtidx != tgt_dict.eos()
if srcidx not in freq_map:
freq_map[srcidx] = {}
if tgtidx not in freq_map[srcidx]:
freq_map[srcidx][tgtidx] = 1
else:
freq_map[srcidx][tgtidx] += 1
align_dict = {}
for srcidx in freq_map.keys():
align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)
with open(os.path.join(args.destdir, 'alignment.{}-{}.txt'.format(
args.source_lang, args.target_lang)), 'w') as f:
for k, v in align_dict.items():
print('{} {}'.format(src_dict[k], tgt_dict[v]), file=f)
if __name__ == '__main__':
main()

requirements.txt Normal file

@@ -0,0 +1,3 @@
numpy
torch
tqdm

score.py Normal file

@@ -0,0 +1,58 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import argparse
import os
import sys
from fairseq import bleu, dictionary, tokenizer
def main():
parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.')
parser.add_argument('-s', '--sys', default='-', help='system output')
parser.add_argument('-r', '--ref', default='', help='references')
parser.add_argument('-o', '--order', default=4, metavar='N',
type=int, help='consider ngrams up to this order')
parser.add_argument('--ignore-case', action='store_true',
help='case-insensitive scoring')
args = parser.parse_args()
print(args)
assert args.sys == '-' or os.path.exists(args.sys), \
"System output file {} does not exist".format(args.sys)
assert os.path.exists(args.ref), \
"Reference file {} does not exist".format(args.ref)
dict = dictionary.Dictionary()
def readlines(fd):
for line in fd.readlines():
            if args.ignore_case:
                yield line.lower()
            else:
                yield line
def score(fdsys):
with open(args.ref) as fdref:
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
scorer.add(ref_tok, sys_tok)
print(scorer.result_string(args.order))
if args.sys == '-':
score(sys.stdin)
else:
with open(args.sys, 'r') as f:
score(f)
if __name__ == '__main__':
main()

scripts/build_sym_alignment.py Normal file

@@ -0,0 +1,99 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
"""
Use this script in order to build symmetric alignments for your translation
dataset.
This script depends on fast_align and mosesdecoder tools. You will need to
build those before running the script.
fast_align:
github: http://github.com/clab/fast_align
instructions: follow the instructions in README.md
mosesdecoder:
github: http://github.com/moses-smt/mosesdecoder
instructions: http://www.statmt.org/moses/?n=Development.GetStarted
The script produces the following files under --output_dir:
text.joined - concatenation of lines from the source_file and the
target_file.
align.forward - forward pass of fast_align.
align.backward - backward pass of fast_align.
aligned.sym_heuristic - symmetrized alignment.
"""
import argparse
import os
from itertools import zip_longest
def main():
    parser = argparse.ArgumentParser(description='symmetric alignment builder')
parser.add_argument('--fast_align_dir',
help='path to fast_align build directory')
parser.add_argument('--mosesdecoder_dir',
help='path to mosesdecoder root directory')
parser.add_argument('--sym_heuristic',
help='heuristic to use for symmetrization',
default='grow-diag-final-and')
parser.add_argument('--source_file',
help='path to a file with sentences '
'in the source language')
parser.add_argument('--target_file',
help='path to a file with sentences '
'in the target language')
parser.add_argument('--output_dir',
help='output directory')
args = parser.parse_args()
fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align')
symal_bin = os.path.join(args.mosesdecoder_dir, 'bin', 'symal')
sym_fast_align_bin = os.path.join(
args.mosesdecoder_dir, 'scripts', 'ems',
'support', 'symmetrize-fast-align.perl')
# create joined file
joined_file = os.path.join(args.output_dir, 'text.joined')
with open(args.source_file, 'r') as src, open(args.target_file, 'r') as tgt:
with open(joined_file, 'w') as joined:
for s, t in zip_longest(src, tgt):
print('{} ||| {}'.format(s.strip(), t.strip()), file=joined)
# run forward alignment
fwd_align_file = os.path.join(args.output_dir, 'align.forward')
fwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v > {FWD}'.format(
FASTALIGN=fast_align_bin,
JOINED=joined_file,
FWD=fwd_align_file)
assert os.system(fwd_fast_align_cmd) == 0
# run backward alignment
bwd_align_file = os.path.join(args.output_dir, 'align.backward')
bwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}'.format(
FASTALIGN=fast_align_bin,
JOINED=joined_file,
BWD=bwd_align_file)
assert os.system(bwd_fast_align_cmd) == 0
# run symmetrization
sym_out_file = os.path.join(args.output_dir, 'aligned')
sym_cmd = '{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}'.format(
SYMFASTALIGN=sym_fast_align_bin,
FWD=fwd_align_file,
BWD=bwd_align_file,
SRC=args.source_file,
TGT=args.target_file,
OUT=sym_out_file,
HEURISTIC=args.sym_heuristic,
SYMAL=symal_bin
)
assert os.system(sym_cmd) == 0
if __name__ == '__main__':
main()

scripts/convert_dictionary.lua Normal file

@@ -0,0 +1,36 @@
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
-- Usage: convert_dictionary.lua <dict.th7>
require 'fairseq'
require 'torch'
require 'paths'
if #arg < 1 then
print('usage: convert_dictionary.lua <dict.th7>')
os.exit(1)
end
if not paths.filep(arg[1]) then
  print('error: file does not exist: ' .. arg[1])
os.exit(1)
end
dict = torch.load(arg[1])
dst = paths.basename(arg[1]):gsub('%.th7$', '.txt')
assert(dst:match('.txt$'))
f = io.open(dst, 'w')
for idx, symbol in ipairs(dict.index_to_symbol) do
if idx > dict.cutoff then
break
end
f:write(symbol)
f:write(' ')
f:write(dict.index_to_freq[idx])
f:write('\n')
end
f:close()

scripts/convert_model.lua Normal file

@@ -0,0 +1,110 @@
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
-- Usage: convert_model.lua <model_epoch1.th7>
require 'torch'
local fairseq = require 'fairseq'
model = torch.load(arg[1])
function find_weight_norm(container, module)
for _, wn in ipairs(container:listModules()) do
if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then
return wn
end
end
end
function push_state(dict, key, module)
if torch.type(module) == 'nn.Linear' then
local wn = find_weight_norm(model.module, module)
assert(wn)
dict[key .. '.weight_v'] = wn.v:float()
dict[key .. '.weight_g'] = wn.g:float()
elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then
local wn = find_weight_norm(model.module, module)
assert(wn)
local v = wn.v:float():view(wn.viewOut):transpose(2, 3)
dict[key .. '.weight_v'] = v
dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1)
else
dict[key .. '.weight'] = module.weight:float()
end
if module.bias then
dict[key .. '.bias'] = module.bias:float()
end
end
encoder_dict = {}
decoder_dict = {}
combined_dict = {}
function encoder_state(encoder)
luts = encoder:findModules('nn.LookupTable')
push_state(encoder_dict, 'embed_tokens', luts[1])
push_state(encoder_dict, 'embed_positions', luts[2])
fcs = encoder:findModules('nn.Linear')
assert(#fcs >= 2)
local nInputPlane = fcs[1].weight:size(1)
push_state(encoder_dict, 'fc1', table.remove(fcs, 1))
push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs))
for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do
push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module)
if nInputPlane ~= module.weight:size(3) / 2 then
push_state(encoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
end
nInputPlane = module.weight:size(3) / 2
end
assert(#fcs == 0)
end
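-- The decoder's nn.Linear modules appear in creation order: fc1 first,
-- then one projection (where the width changes) plus attention in/out
-- projections per conv layer, and finally fc2 and fc3; each is popped
-- off the list as it is mapped to a flat key.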
function decoder_state(decoder)
luts = decoder:findModules('nn.LookupTable')
push_state(decoder_dict, 'embed_tokens', luts[1])
push_state(decoder_dict, 'embed_positions', luts[2])
fcs = decoder:findModules('nn.Linear')
local nInputPlane = fcs[1].weight:size(1)
push_state(decoder_dict, 'fc1', table.remove(fcs, 1))
push_state(decoder_dict, 'fc2', fcs[#fcs - 1])
push_state(decoder_dict, 'fc3', fcs[#fcs])
table.remove(fcs, #fcs)
table.remove(fcs, #fcs)
for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do
if nInputPlane ~= module.weight:size(3) / 2 then
push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
end
nInputPlane = module.weight:size(3) / 2
local prefix = 'attention.' .. tostring(i - 1)
push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1))
push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1))
push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module)
end
assert(#fcs == 0)
end
_encoder = model.module.modules[2]
_decoder = model.module.modules[3]
encoder_state(_encoder)
decoder_state(_decoder)
for k, v in pairs(encoder_dict) do
combined_dict['encoder.' .. k] = v
end
for k, v in pairs(decoder_dict) do
combined_dict['decoder.' .. k] = v
end
torch.save('state_dict.t7', combined_dict)

setup.py Normal file

@@ -0,0 +1,71 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from setuptools import setup, find_packages, Extension
from setuptools.command.build_py import build_py
import sys
from torch.utils.ffi import create_extension
if sys.version_info < (3,):
    sys.exit('Sorry, Python 3 is required for fairseq.')
with open('README.md') as f:
readme = f.read()
with open('LICENSE') as f:
license = f.read()
with open('requirements.txt') as f:
reqs = f.read()
bleu = Extension(
'fairseq.libbleu',
sources=[
'fairseq/clib/libbleu/libbleu.cpp',
'fairseq/clib/libbleu/module.cpp',
],
extra_compile_args=['-std=c++11'],
)
conv_tbc = create_extension(
'fairseq.temporal_convolution_tbc',
relative_to='fairseq',
headers=['fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.h'],
sources=['fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.cpp'],
define_macros=[('WITH_CUDA', None)],
with_cuda=True,
extra_compile_args=['-std=c++11'],
)
class build_py_hook(build_py):
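    """Build the temporal_convolution_tbc FFI extension before the regular build_py step."""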
def run(self):
conv_tbc.build()
build_py.run(self)
setup(
name='fairseq',
version='0.1.0',
description='Facebook AI Research Sequence-to-Sequence Toolkit',
long_description=readme,
license=license,
install_requires=reqs.strip().split('\n'),
packages=find_packages(),
ext_modules=[bleu],
# build and install PyTorch extensions
package_data={
'fairseq': ['temporal_convolution_tbc/*.so'],
},
include_package_data=True,
cmdclass={
'build_py': build_py_hook,
},
)

tests/test_label_smoothing.py Normal file

@@ -0,0 +1,35 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
import unittest
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropy
from torch.autograd import Variable, gradcheck
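# gradcheck compares analytical against numerical gradients and needs
# double precision to be numerically reliable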
torch.set_default_tensor_type('torch.DoubleTensor')
class TestLabelSmoothing(unittest.TestCase):
def test_label_smoothing(self):
input = Variable(torch.randn(3, 5), requires_grad=True)
idx = torch.rand(3) * 4
target = Variable(idx.long())
criterion = LabelSmoothedCrossEntropy()
self.assertTrue(gradcheck(
lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target)
))
weights = torch.ones(5)
weights[2] = 0
self.assertTrue(gradcheck(lambda x, y: criterion.apply(x, y, 0.1, None, weights), (input, target)))
self.assertTrue(gradcheck(lambda x, y: criterion.apply(x, y, 0.1, None, None), (input, target)))
if __name__ == '__main__':
unittest.main()

train.py Normal file

@@ -0,0 +1,210 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import collections
import os
import torch
import math
from fairseq import bleu, data, options, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.multiprocessing_trainer import MultiprocessingTrainer
from fairseq.progress_bar import progress_bar
from fairseq.sequence_generator import SequenceGenerator
def main():
parser = options.get_parser('Trainer')
dataset_args = options.add_dataset_args(parser)
dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
help='maximum number of tokens in a batch')
dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
choices=['train', 'valid', 'test'],
help='data subset to use for training (train, valid, test)')
dataset_args.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                              help='comma-separated list of data subsets '
                                   'to use for validation (train, valid, valid1, test, test1)')
dataset_args.add_argument('--test-subset', default='test', metavar='SPLIT',
                              help='comma-separated list of data subsets '
                                   'to use for testing (train, valid, test)')
options.add_optimization_args(parser)
options.add_checkpoint_args(parser)
options.add_model_args(parser)
args = parser.parse_args()
print(args)
if args.no_progress_bar:
progress_bar.enabled = False
progress_bar.print_interval = args.log_interval
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
torch.manual_seed(args.seed)
# Load dataset
dataset = data.load_with_check(args.data, args.source_lang, args.target_lang)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args, so that it's saved in checkpoints
args.source_lang, args.target_lang = dataset.src, dataset.dst
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
for split in dataset.splits:
print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
num_gpus = torch.cuda.device_count()
print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))
# Build model
print('| model {}'.format(args.arch))
model = utils.build_model(args, dataset)
criterion = utils.build_criterion(args, dataset)
# Start multiprocessing
trainer = MultiprocessingTrainer(args, model)
# Load the latest checkpoint if one is available
epoch, batch_offset = trainer.load_checkpoint(os.path.join(args.save_dir, args.restore_file))
# Train until the learning rate gets too small
val_loss = None
max_epoch = args.max_epoch or math.inf
lr = trainer.get_lr()
train_meter = StopwatchMeter()
train_meter.start()
while lr > args.min_lr and epoch <= max_epoch:
# train for one epoch
train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus)
        # evaluate on the validation set
for k, subset in enumerate(args.valid_subset.split(',')):
val_loss = validate(args, epoch, trainer, criterion, dataset, subset, num_gpus)
if k == 0:
if not args.no_save:
# save checkpoint
trainer.save_checkpoint(args, epoch, 0, val_loss)
# only use first validation loss to update the learning schedule
lr = trainer.lr_step(val_loss, epoch)
epoch += 1
batch_offset = 0
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
# Generate on test set and compute BLEU score
for beam in [1, 5, 10, 20]:
for subset in args.test_subset.split(','):
scorer = score_test(args, trainer.get_model(), dataset, subset, beam,
cuda_device=(0 if num_gpus > 0 else None))
print('| Test on {} with beam={}: {}'.format(subset, beam, scorer.result_string()))
# Stop multiprocessing
trainer.stop()
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
"""Train the model for one epoch."""
itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
max_positions=args.max_positions,
sample_without_replacement=args.sample_without_replacement)
loss_meter = AverageMeter()
bsz_meter = AverageMeter() # sentences per batch
wpb_meter = AverageMeter() # words per batch
wps_meter = TimeMeter() # words per second
clip_meter = AverageMeter() # % of updates clipped
gnorm_meter = AverageMeter() # gradient norm
desc = '| epoch {:03d}'.format(epoch)
lr = trainer.get_lr()
with progress_bar(itr, desc, leave=False) as t:
for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
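            # `sample` is a list of sub-batches, one per GPU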
loss, grad_norm = trainer.train_step(sample, criterion)
ntokens = sum(s['ntokens'] for s in sample)
src_size = sum(s['src_tokens'].size(0) for s in sample)
loss_meter.update(loss, ntokens)
bsz_meter.update(src_size)
wpb_meter.update(ntokens)
wps_meter.update(ntokens)
clip_meter.update(1 if grad_norm > args.clip_norm else 0)
gnorm_meter.update(grad_norm)
t.set_postfix(collections.OrderedDict([
('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
('wps', '{:5d}'.format(round(wps_meter.avg))),
('wpb', '{:5d}'.format(round(wpb_meter.avg))),
('bsz', '{:5d}'.format(round(bsz_meter.avg))),
('lr', lr),
('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
]))
if i == 0:
# ignore the first mini-batch in words-per-second calculation
wps_meter.reset()
if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
trainer.save_checkpoint(args, epoch, i + 1)
fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
t.write(fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
round(wps_meter.elapsed_time),
round(wps_meter.avg),
round(wpb_meter.avg),
round(bsz_meter.avg),
lr, clip_meter.avg * 100,
gnorm_meter.avg))
def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
"""Evaluate the model on the validation set and return the average loss."""
itr = dataset.dataloader(subset, batch_size=None,
max_tokens=args.max_tokens,
max_positions=args.max_positions)
loss_meter = AverageMeter()
desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
with progress_bar(itr, desc, leave=False) as t:
for _, sample in data.skip_group_enumerator(t, ngpus):
ntokens = sum(s['ntokens'] for s in sample)
loss = trainer.valid_step(sample, criterion)
loss_meter.update(loss, ntokens)
t.set_postfix(loss='{:.2f}'.format(loss_meter.avg))
val_loss = loss_meter.avg
t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'
.format(val_loss, math.pow(2, val_loss)))
# update and return the learning rate
return val_loss
def score_test(args, model, dataset, subset, beam, cuda_device):
"""Evaluate the model on the test set and return the BLEU scorer."""
translator = SequenceGenerator([model], dataset.dst_dict, beam_size=beam)
if torch.cuda.is_available():
translator.cuda()
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
itr = dataset.dataloader(subset, batch_size=4, max_positions=args.max_positions)
for _, _, ref, hypos in translator.generate_batched_itr(itr, cuda_device=cuda_device):
scorer.add(ref.int().cpu(), hypos[0]['tokens'].int().cpu())
return scorer
if __name__ == '__main__':
main()