Multilingual training example (#527)

Summary:
* Add example for multilingual translation on IWSLT'17
* Match dataset ordering for multilingual_translation and translation
* Fix bug with LegacyDistributedDataParallel when calling forward of sub-modules
Pull Request resolved: https://github.com/pytorch/fairseq/pull/527

Differential Revision: D14218372

Pulled By: myleott

fbshipit-source-id: 2e3fe24aa39476bcc5c9af68ef9a40192db34a3b
Authored by Myle Ott on 2019-02-25 18:40:37 -08:00; committed by Facebook Github Bot
parent 44d27e645b
commit 00493490ba
10 changed files with 388 additions and 23 deletions


@@ -92,7 +92,6 @@ $ fairseq-generate data-bin/iwslt14.tokenized.de-en \
```
### prepare-wmt14en2de.sh
The WMT English to German dataset can be preprocessed using the `prepare-wmt14en2de.sh` script.
@@ -163,3 +162,71 @@ $ fairseq-generate data-bin/fconv_wmt_en_fr \
--path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt --beam 5 --remove-bpe
```
## Multilingual Translation
We also support training multilingual translation models. In this example we'll
train a multilingual `{de,fr}-en` translation model using the IWSLT'17 datasets.
Note that we use slightly different preprocessing here than for the IWSLT'14
En-De data above. In particular we learn a joint BPE code for all three
languages and use fairseq-interactive and sacrebleu for scoring the test set.
```
# First install sacrebleu and sentencepiece
$ pip install sacrebleu sentencepiece
# Then download and preprocess the data
$ cd examples/translation/
$ bash prepare-iwslt17-multilingual.sh
$ cd ../..
# Binarize the de-en dataset
$ TEXT=examples/translation/iwslt17.de_fr.en.bpe16k
$ fairseq-preprocess --source-lang de --target-lang en \
--trainpref $TEXT/train.bpe.de-en --validpref $TEXT/valid.bpe.de-en \
--joined-dictionary \
--destdir data-bin/iwslt17.de_fr.en.bpe16k \
--workers 10
# Binarize the fr-en dataset
# NOTE: it's important to reuse the en dictionary from the previous step
$ fairseq-preprocess --source-lang fr --target-lang en \
--trainpref $TEXT/train.bpe.fr-en --validpref $TEXT/valid.bpe.fr-en \
--joined-dictionary --tgtdict data-bin/iwslt17.de_fr.en.bpe16k/dict.en.txt \
--destdir data-bin/iwslt17.de_fr.en.bpe16k \
--workers 10
# Train a multilingual transformer model
# NOTE: the command below assumes 1 GPU, but accumulates gradients from
# 8 fwd/bwd passes to simulate training on 8 GPUs
$ mkdir -p checkpoints/multilingual_transformer
$ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \
--max-epoch 50 \
--ddp-backend=no_c10d \
--task multilingual_translation --lang-pairs de-en,fr-en \
--arch multilingual_transformer_iwslt_de_en \
--share-decoders --share-decoder-input-output-embed \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \
--warmup-updates 4000 --warmup-init-lr '1e-07' \
--label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
--dropout 0.3 --weight-decay 0.0001 \
--save-dir checkpoints/multilingual_transformer \
--max-tokens 4000 \
--update-freq 8
# Generate and score the test set with sacrebleu
$ SRC=de
$ sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
| python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
> iwslt17.test.${SRC}-en.${SRC}.bpe
$ cat iwslt17.test.${SRC}-en.${SRC}.bpe | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
--task multilingual_translation --lang-pairs de-en,fr-en \
--source-lang ${SRC} --target-lang en \
--path checkpoints/multilingual_transformer/checkpoint_best.pt \
--buffer-size 2000 --batch-size 128 \
--beam 5 --remove-bpe=sentencepiece \
> iwslt17.test.${SRC}-en.en.sys
$ grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \
| sacrebleu --test-set iwslt17 --language-pair ${SRC}-en
```
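
To evaluate both directions it may be convenient to wrap the generation and
scoring commands above in a small loop; this is just a sketch that repeats
them with `SRC` set to each source language:
```
for SRC in de fr; do
    sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
        | python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
        > iwslt17.test.${SRC}-en.${SRC}.bpe
    cat iwslt17.test.${SRC}-en.${SRC}.bpe \
        | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
          --task multilingual_translation --lang-pairs de-en,fr-en \
          --source-lang ${SRC} --target-lang en \
          --path checkpoints/multilingual_transformer/checkpoint_best.pt \
          --buffer-size 2000 --batch-size 128 \
          --beam 5 --remove-bpe=sentencepiece \
        > iwslt17.test.${SRC}-en.en.sys
    grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \
        | sacrebleu --test-set iwslt17 --language-pair ${SRC}-en
done
```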


@@ -0,0 +1,126 @@
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
SRCS=(
"de"
"fr"
)
TGT=en
ROOT=$(dirname "$0")
SCRIPTS=$ROOT/../../scripts
SPM_TRAIN=$SCRIPTS/spm_train.py
SPM_ENCODE=$SCRIPTS/spm_encode.py
BPESIZE=16384
ORIG=$ROOT/iwslt17_orig
DATA=$ROOT/iwslt17.de_fr.en.bpe16k
mkdir -p "$ORIG" "$DATA"
TRAIN_MINLEN=1 # remove sentences with <1 BPE token
TRAIN_MAXLEN=250 # remove sentences with >250 BPE tokens
URLS=(
"https://wit3.fbk.eu/archive/2017-01-trnted/texts/de/en/de-en.tgz"
"https://wit3.fbk.eu/archive/2017-01-trnted/texts/fr/en/fr-en.tgz"
)
ARCHIVES=(
"de-en.tgz"
"fr-en.tgz"
)
VALID_SETS=(
"IWSLT17.TED.dev2010.de-en IWSLT17.TED.tst2010.de-en IWSLT17.TED.tst2011.de-en IWSLT17.TED.tst2012.de-en IWSLT17.TED.tst2013.de-en IWSLT17.TED.tst2014.de-en IWSLT17.TED.tst2015.de-en"
"IWSLT17.TED.dev2010.fr-en IWSLT17.TED.tst2010.fr-en IWSLT17.TED.tst2011.fr-en IWSLT17.TED.tst2012.fr-en IWSLT17.TED.tst2013.fr-en IWSLT17.TED.tst2014.fr-en IWSLT17.TED.tst2015.fr-en"
)
# download and extract data
for ((i=0;i<${#URLS[@]};++i)); do
ARCHIVE=$ORIG/${ARCHIVES[i]}
if [ -f "$ARCHIVE" ]; then
echo "$ARCHIVE already exists, skipping download"
else
URL=${URLS[i]}
wget -P "$ORIG" "$URL"
if [ -f "$ARCHIVE" ]; then
echo "$URL successfully downloaded."
else
echo "$URL not successfully downloaded."
exit 1
fi
fi
    FILE=${ARCHIVE%.tgz}  # the archive extracts into this directory
if [ -e "$FILE" ]; then
echo "$FILE already exists, skipping extraction"
else
tar -C "$ORIG" -xzvf "$ARCHIVE"
fi
done
echo "pre-processing train data..."
for SRC in "${SRCS[@]}"; do
for LANG in "${SRC}" "${TGT}"; do
cat "$ORIG/${SRC}-${TGT}/train.tags.${SRC}-${TGT}.${LANG}" \
| grep -v '<url>' \
| grep -v '<talkid>' \
| grep -v '<keywords>' \
| grep -v '<speaker>' \
| grep -v '<reviewer' \
| grep -v '<translator' \
| grep -v '<doc' \
| grep -v '</doc>' \
| sed -e 's/<title>//g' \
| sed -e 's/<\/title>//g' \
| sed -e 's/<description>//g' \
| sed -e 's/<\/description>//g' \
| sed 's/^\s*//g' \
| sed 's/\s*$//g' \
> "$DATA/train.${SRC}-${TGT}.${LANG}"
done
done
echo "pre-processing valid data..."
for ((i=0;i<${#SRCS[@]};++i)); do
SRC=${SRCS[i]}
VALID_SET=${VALID_SETS[i]}
for FILE in ${VALID_SET[@]}; do
for LANG in "$SRC" "$TGT"; do
grep '<seg id' "$ORIG/${SRC}-${TGT}/${FILE}.${LANG}.xml" \
| sed -e 's/<seg id="[0-9]*">\s*//g' \
| sed -e 's/\s*<\/seg>\s*//g' \
            | sed -e "s/\’/\'/g" \
> "$DATA/valid.${SRC}-${TGT}.${LANG}"
done
done
done
# learn BPE with sentencepiece
TRAIN_FILES=$(for SRC in "${SRCS[@]}"; do echo $DATA/train.${SRC}-${TGT}.${SRC}; echo $DATA/train.${SRC}-${TGT}.${TGT}; done | tr "\n" ",")
echo "learning joint BPE over ${TRAIN_FILES}..."
python "$SPM_TRAIN" \
--input=$TRAIN_FILES \
--model_prefix=$DATA/sentencepiece.bpe \
--vocab_size=$BPESIZE \
--character_coverage=1.0 \
--model_type=bpe
# encode train/valid/test
echo "encoding train/valid with learned BPE..."
for SRC in "${SRCS[@]}"; do
for LANG in "$SRC" "$TGT"; do
python "$SPM_ENCODE" \
--model "$DATA/sentencepiece.bpe.model" \
--output_format=piece \
--inputs "$DATA/train.${SRC}-${TGT}.${SRC} $DATA/train.${SRC}-${TGT}.${TGT}" \
--outputs "$DATA/train.bpe.${SRC}-${TGT}.${SRC} $DATA/train.bpe.${SRC}-${TGT}.${TGT}" \
--min-len $TRAIN_MINLEN --max-len $TRAIN_MAXLEN
python "$SPM_ENCODE" \
--model "$DATA/sentencepiece.bpe.model" \
--output_format=piece \
--inputs "$DATA/valid.${SRC}-${TGT}.${SRC} $DATA/valid.${SRC}-${TGT}.${TGT}" \
--outputs "$DATA/valid.bpe.${SRC}-${TGT}.${SRC} $DATA/valid.bpe.${SRC}-${TGT}.${TGT}"
done
done
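
A quick optional sanity check after the script finishes (a suggestion, not part of the script) is to confirm that each BPE-encoded pair is still parallel, i.e. that source and target files have the same number of lines:
```
for PAIR in de-en fr-en; do
    SRC=${PAIR%-*}
    # both line counts printed for a pair should be identical
    wc -l examples/translation/iwslt17.de_fr.en.bpe16k/train.bpe.${PAIR}.${SRC} \
          examples/translation/iwslt17.de_fr.en.bpe16k/train.bpe.${PAIR}.en
done
```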


@@ -39,12 +39,11 @@ class RoundRobinZipDatasets(FairseqDataset):
self.longest_dataset = dataset
self.longest_dataset_key = key
self._ordered_indices = OrderedDict([
(key, dataset.ordered_indices())
for key, dataset in datasets.items()
])
self._ordered_indices = None
def _map_index(self, key, index):
assert self._ordered_indices is not None, \
'Must call RoundRobinZipDatasets.ordered_indices() first'
return self._ordered_indices[key][index % len(self.datasets[key])]
def __getitem__(self, index):
@@ -102,6 +101,14 @@ class RoundRobinZipDatasets(FairseqDataset):
def ordered_indices(self):
"""Ordered indices for batching."""
if self._ordered_indices is None:
# Call the underlying dataset's ordered_indices() here, so that we
# get the same random ordering as we would have from using the
# underlying dataset directly.
self._ordered_indices = OrderedDict([
(key, dataset.ordered_indices())
for key, dataset in self.datasets.items()
])
return np.arange(len(self))
@property
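
The reason the per-dataset orderings are now computed lazily is that `ordered_indices()` typically consumes random state (e.g. a shuffle), so calling it at construction time can yield a different permutation than calling it when batching starts, which is what the plain `translation` task does. A toy illustration (hypothetical classes, not fairseq's API):
```
import numpy as np

class ToyDataset:
    """Toy stand-in whose ordering depends on the RNG state at call time."""
    def __init__(self, n):
        self.n = n

    def ordered_indices(self):
        return np.random.permutation(self.n)

np.random.seed(0)
eager = ToyDataset(5).ordered_indices()  # as if computed in __init__
_ = np.random.rand(100)                  # other setup code advances the RNG...
lazy = ToyDataset(5).ordered_indices()   # ...so a later call gives a different order
print(eager, lazy)
```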


@@ -75,7 +75,6 @@ class LegacyDistributedDataParallel(nn.Module):
self._register_grad_hook()
def forward(self, *inputs, **kwargs):
self.need_reduction = True
return self.module(*inputs, **kwargs)
def _register_grad_hook(self):
@@ -166,6 +165,7 @@ class LegacyDistributedDataParallel(nn.Module):
for p in self.module.parameters():
def allreduce_hook(*unused):
self.need_reduction = True
Variable._execution_engine.queue_callback(reduction_fn)
if p.requires_grad:
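
The bug fix referenced in the summary moves the `need_reduction` flag from `forward()` into the gradient hook, so the flag is still set when callers invoke sub-modules directly and bypass the wrapper's `forward()`. A minimal sketch of the idea (a simplified toy wrapper, not fairseq's LegacyDistributedDataParallel):
```
import torch
import torch.nn as nn

class NeedsReductionFlag(nn.Module):
    """Toy wrapper: mark that grads need reducing as soon as any parameter
    receives a gradient, rather than when the wrapper's forward() runs."""
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.need_reduction = False
        for p in self.module.parameters():
            if p.requires_grad:
                p.register_hook(self._mark_dirty)  # fires during backward

    def _mark_dirty(self, grad):
        self.need_reduction = True
        return grad

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

wrapped = NeedsReductionFlag(nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2)))
out = wrapped.module[0](torch.randn(1, 4))  # call a sub-module, bypassing forward()
out.sum().backward()
print(wrapped.need_reduction)  # True: the hook still ran
```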


@@ -226,7 +226,6 @@ class FairseqTask(object):
- logging outputs to display while training
"""
model.train()
loss, sample_size, logging_output = criterion(model, sample)
if ignore_grad:
loss *= 0


@@ -50,6 +50,7 @@ class Trainer(object):
self._num_updates = 0
self._optim_history = None
self._optimizer = None
self._prev_grad_norm = None
self._wrapped_model = None
self.init_meters(args)
@@ -215,12 +216,15 @@ class Trainer(object):
# gather logging outputs from all replicas
if self.args.distributed_world_size > 1:
logging_outputs, sample_sizes, ooms = zip(*distributed_utils.all_gather_list(
[logging_outputs, sample_sizes, ooms],
))
logging_outputs, sample_sizes, ooms, prev_norms = \
zip(*distributed_utils.all_gather_list(
[logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
))
logging_outputs = list(chain.from_iterable(logging_outputs))
sample_sizes = list(chain.from_iterable(sample_sizes))
ooms = sum(ooms)
assert all(norm == prev_norms[0] for norm in prev_norms), \
'Fatal error: gradients are inconsistent between workers'
self.meters['oom'].update(ooms, len(samples))
if ooms == self.args.distributed_world_size * len(samples):
@@ -246,6 +250,7 @@ class Trainer(object):
# clip grads
grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
self._prev_grad_norm = grad_norm
# take an optimization step
self.optimizer.step()
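
The new assertion compares the previous gradient norm across workers to catch silent divergence between replicas. A minimal sketch of the same check using `torch.distributed.all_gather_object` (a different helper than fairseq's `all_gather_list`; assumes a process group has already been initialized, e.g. via `torchrun`):
```
import torch.distributed as dist

def check_grad_norm_consistency(prev_grad_norm):
    """Every worker contributes its previous grad norm; all must agree."""
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, prev_grad_norm)
    assert all(n == gathered[0] for n in gathered), \
        'Fatal error: gradients are inconsistent between workers'

# usage inside a training step, after clipping:
# grad_norm = optimizer.clip_grad_norm(clip_norm)
# check_grad_norm_consistency(float(grad_norm))
```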


@@ -56,37 +56,37 @@ def main(args):
padding_factor=args.padding_factor,
)
if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
raise FileExistsError(dict_path(args.source_lang))
if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
raise FileExistsError(dict_path(args.target_lang))
if args.joined_dictionary:
assert (
not args.srcdict or not args.tgtdict
), "cannot use both --srcdict and --tgtdict with --joined-dictionary"
assert not args.srcdict or not args.tgtdict, \
"cannot use both --srcdict and --tgtdict with --joined-dictionary"
if args.srcdict:
src_dict = task.load_dictionary(args.srcdict)
elif args.tgtdict:
src_dict = task.load_dictionary(args.tgtdict)
else:
assert (
args.trainpref
), "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary({train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True)
assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary(
{train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True
)
tgt_dict = src_dict
else:
if args.srcdict:
src_dict = task.load_dictionary(args.srcdict)
else:
assert (
args.trainpref
), "--trainpref must be set if --srcdict is not specified"
assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary([train_path(args.source_lang)], src=True)
if target:
if args.tgtdict:
tgt_dict = task.load_dictionary(args.tgtdict)
else:
assert (
args.trainpref
), "--trainpref must be set if --tgtdict is not specified"
assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
else:
tgt_dict = None

scripts/spm_decode.py (new file, 45 lines)

@@ -0,0 +1,45 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import sentencepiece as spm
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True,
help="sentencepiece model to use for decoding")
parser.add_argument("--input", required=True, help="input file to decode")
parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
args = parser.parse_args()
sp = spm.SentencePieceProcessor()
sp.Load(args.model)
if args.input_format == "piece":
def decode(l):
return "".join(sp.DecodePieces(l))
elif args.input_format == "id":
def decode(l):
return "".join(sp.DecodeIds(l))
else:
raise NotImplementedError
def tok2int(tok):
# remap reference-side <unk> (represented as <<unk>>) to 0
return int(tok) if tok != "<<unk>>" else 0
    with open(args.input, "r", encoding="utf-8") as h:
        for line in h:
            if args.input_format == "id":
                # ids are integers; remap <<unk>> to 0 before converting
                print(decode(list(map(tok2int, line.rstrip().split()))))
            else:
                # pieces are strings and can be decoded directly
                print(decode(line.rstrip().split()))
if __name__ == "__main__":
main()
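
A usage example (the input and output file names are just illustrative; the input is the BPE'd source file produced in the README pipeline above):
```
python scripts/spm_decode.py \
    --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
    --input iwslt17.test.de-en.de.bpe \
    > iwslt17.test.de-en.de.detok
```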

scripts/spm_encode.py (new file, 99 lines)

@@ -0,0 +1,99 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import contextlib
import sys
import sentencepiece as spm
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True,
help="sentencepiece model to use for encoding")
parser.add_argument("--inputs", nargs="+", default=['-'],
help="input files to filter/encode")
parser.add_argument("--outputs", nargs="+", default=['-'],
help="path to save encoded outputs")
parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
parser.add_argument("--min-len", type=int, metavar="N",
help="filter sentence pairs with fewer than N tokens")
parser.add_argument("--max-len", type=int, metavar="N",
help="filter sentence pairs with more than N tokens")
args = parser.parse_args()
assert len(args.inputs) == len(args.outputs), \
"number of input and output paths should match"
sp = spm.SentencePieceProcessor()
sp.Load(args.model)
if args.output_format == "piece":
def encode(l):
return sp.EncodeAsPieces(l)
elif args.output_format == "id":
def encode(l):
return list(map(str, sp.EncodeAsIds(l)))
else:
raise NotImplementedError
if args.min_len is not None or args.max_len is not None:
def valid(line):
return (
(args.min_len is None or len(line) >= args.min_len)
and (args.max_len is None or len(line) <= args.max_len)
)
else:
        def valid(line):
return True
with contextlib.ExitStack() as stack:
inputs = [
stack.enter_context(open(input, "r", encoding="utf-8")) \
if input != "-" else sys.stdin
for input in args.inputs
]
outputs = [
stack.enter_context(open(output, "w", encoding="utf-8")) \
if output != "-" else sys.stdout
for output in args.outputs
]
stats = {
"num_empty": 0,
"num_filtered": 0,
}
def encode_line(line):
line = line.strip()
if len(line) > 0:
line = encode(line)
if valid(line):
return line
else:
stats["num_filtered"] += 1
else:
stats["num_empty"] += 1
return None
for i, lines in enumerate(zip(*inputs), start=1):
enc_lines = list(map(encode_line, lines))
if not any(enc_line is None for enc_line in enc_lines):
for enc_line, output_h in zip(enc_lines, outputs):
print(" ".join(enc_line), file=output_h)
if i % 10000 == 0:
print("processed {} lines".format(i), file=sys.stderr)
print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
if __name__ == "__main__":
main()
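
A usage example showing the pair-wise filtering (file names are placeholders): the script reads the parallel files in lock-step and drops a pair whenever either side falls outside the length bounds, so the outputs stay parallel.
```
python scripts/spm_encode.py \
    --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
    --output_format piece \
    --inputs corpus.de corpus.en \
    --outputs corpus.bpe.de corpus.bpe.en \
    --min-len 1 --max-len 250
```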

scripts/spm_train.py (new file, 17 lines)

@@ -0,0 +1,17 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import shlex
import sys
import sentencepiece as spm
if __name__ == "__main__":
spm.SentencePieceTrainer.Train(" ".join(map(shlex.quote, sys.argv[1:])))
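
The script simply forwards its arguments to `SentencePieceTrainer`, so any sentencepiece training flag can be passed through; for example (placeholder corpus file names):
```
python scripts/spm_train.py \
    --input=corpus.de,corpus.en \
    --model_prefix=sentencepiece.bpe \
    --vocab_size=16384 \
    --character_coverage=1.0 \
    --model_type=bpe
```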